package tag

import (
	"errors"
	"reflect"
	"strings"

	"git.sxidc.com/go-tools/utils/reflectutils"
	"git.sxidc.com/go-tools/utils/strutils"
	"git.sxidc.com/service-supports/fserr"
	"github.com/iancoleman/strcase"
)

// OnSqlResultParsedFieldTagFunc is the callback invoked for every field whose
// `sqlresult` tag was parsed and that is not a nested (non-time) struct.
//
// fieldName is the Go field name, entityFieldElemValue is the field value with
// one level of pointer indirection removed, and sqlResult is the parsed tag.
// Returning a non-nil error aborts the traversal.
type OnSqlResultParsedFieldTagFunc func(fieldName string, entityFieldElemValue reflect.Value, sqlResult *SqlResultTag) error

// UseSqlResultTag walks the fields of the struct pointed to by e, parses each
// field's `sqlresult` tag and calls onParsedFieldTagFunc for every leaf field.
//
// A nil e is a no-op. Any value that is not a pointer to a struct is rejected
// with an error. Errors returned by tag parsing or by the callback are
// returned unchanged.
func UseSqlResultTag(e any, onParsedFieldTagFunc OnSqlResultParsedFieldTagFunc) error {
	if e == nil {
		return nil
	}

	entityValue := reflect.ValueOf(e)

	// Type check: only a pointer to a struct is accepted.
	if !reflectutils.IsValueStructPointer(entityValue) {
		return fserr.New("参数不是结构指针")
	}

	return parseEntitySqlResultTag(reflectutils.PointerValueElem(entityValue), onParsedFieldTagFunc)
}

// parseEntitySqlResultTag iterates over the fields of the struct value
// entityElemValue. Fields tagged with "-" are skipped and left untouched.
// Nested struct fields (other than time values) are descended into
// recursively; every other field is reported to onParsedFieldTagFunc.
func parseEntitySqlResultTag(entityElemValue reflect.Value, onParsedFieldTagFunc OnSqlResultParsedFieldTagFunc) error {
	entityType := entityElemValue.Type()

	for i := 0; i < entityElemValue.NumField(); i++ {
		entityField := entityType.Field(i)
		entityFieldValue := entityElemValue.Field(i)

		// Invalid values are not mapped.
		if !entityFieldValue.IsValid() {
			continue
		}

		// Parse the tag before touching the field so that ignored fields are
		// never mutated.
		tag, err := parseSqlResultTag(entityField, entityField.Tag.Get(sqlResultTagKey))
		if err != nil {
			return err
		}

		if tag == nil {
			continue
		}

		// Allocate nil pointer fields so the callback always receives a usable
		// element value.
		if entityFieldValue.Kind() == reflect.Pointer && entityFieldValue.IsNil() {
			// A field that cannot be set (e.g. unexported) cannot be allocated
			// or written to; calling Set on it would panic, so skip it.
			if !entityFieldValue.CanSet() {
				continue
			}

			entityFieldValue.Set(reflect.New(entityField.Type.Elem()))
		}

		entityFieldElemValue := reflectutils.PointerValueElem(entityFieldValue)

		// Struct-typed fields (except time values) are parsed recursively.
		if entityFieldElemValue.Kind() == reflect.Struct && !reflectutils.IsValueTime(entityFieldElemValue) {
			if err := parseEntitySqlResultTag(entityFieldElemValue, onParsedFieldTagFunc); err != nil {
				return err
			}

			continue
		}

		if err := onParsedFieldTagFunc(entityField.Name, entityFieldElemValue, tag); err != nil {
			return err
		}
	}

	return nil
}

// Separators used inside a `sqlresult` tag and the default list separator.
const (
	sqlResultDefaultSplitWith         = "::"
	sqlResultTagPartSeparator         = ";"
	sqlResultTagPartKeyValueSeparator = ":"
)

// Tag key and the option keys recognised inside the tag value, e.g.
//
//	`sqlresult:"column:user_name;parseTime:'2006-01-02';splitWith:'::'"`
//
// A tag value of "-" makes the field ignored.
const (
	sqlResultTagKey    = "sqlresult"
	sqlResultIgnore    = "-"
	sqlResultColumn    = "column"
	sqlResultParseTime = "parseTime"
	sqlResultAes       = "aes"
	sqlResultSplitWith = "splitWith"
)

// SqlResultTag holds the options parsed from one field's `sqlresult` tag.
type SqlResultTag struct {
	// Name is the value of the `column` option; it defaults to the snake_case
	// form of the Go field name.
	Name string
	// ParseTime is the raw value of the `parseTime` option (presumably a time
	// layout used by the callback — confirm against callers).
	ParseTime string
	// AESKey is the value of the `aes` option; it must be exactly 32 bytes.
	AESKey string
	// SplitWith is the value of the `splitWith` option; it defaults to "::".
	SplitWith string
}

// parseSqlResultTag parses the `sqlresult` tag string of field.
//
// It returns (nil, nil) when the tag is "-", meaning the field is ignored.
// An empty tag yields the defaults. Otherwise the tag is split on ";" into
// `key:value` parts; values may be wrapped in single quotes, which are
// stripped. Unknown keys are ignored. A known key without a value is treated
// as having an empty value: `column` then keeps the default name, `parseTime`
// stays empty, and `aes` / `splitWith` report an error.
func parseSqlResultTag(field reflect.StructField, tagStr string) (*SqlResultTag, error) {
	if tagStr == sqlResultIgnore {
		return nil, nil
	}

	sqlResultTag := &SqlResultTag{
		Name:      strcase.ToSnake(field.Name),
		ParseTime: "",
		AESKey:    "",
		SplitWith: sqlResultDefaultSplitWith,
	}

	if strutils.IsStringEmpty(tagStr) {
		return sqlResultTag, nil
	}

	for _, sqlResultPart := range strings.Split(tagStr, sqlResultTagPartSeparator) {
		// Cut at the first ":" only, so values may themselves contain ":".
		// When the separator is missing, value is "" and no index is touched.
		key, value, _ := strings.Cut(strings.TrimSpace(sqlResultPart), sqlResultTagPartKeyValueSeparator)
		value = strings.Trim(value, "'")

		switch key {
		case sqlResultColumn:
			// An empty column name is useless; keep the derived default.
			if value != "" {
				sqlResultTag.Name = value
			}
		case sqlResultParseTime:
			sqlResultTag.ParseTime = value
		case sqlResultAes:
			if len(value) != 32 {
				return nil, errors.New("AES密钥长度应该为32个字节")
			}

			sqlResultTag.AESKey = value
		case sqlResultSplitWith:
			if strutils.IsStringEmpty(value) {
				// Name the offending option key, not the default separator.
				return nil, errors.New(sqlResultSplitWith + "没有赋值分隔符")
			}

			sqlResultTag.SplitWith = value
		}
	}

	return sqlResultTag, nil
}