package tag import ( "errors" "git.sxidc.com/go-tools/utils/reflectutils" "git.sxidc.com/go-tools/utils/strutils" "git.sxidc.com/service-supports/fserr" "github.com/iancoleman/strcase" "reflect" "strings" ) type OnSqlMappingParsedFieldTagFunc func(fieldName string, entityFieldElemValue reflect.Value, tag *SqlMappingTag) error func BuildExecuteParams(e any, onParsedFieldTagFunc OnSqlMappingParsedFieldTagFunc) error { if e == nil { return nil } entityValue := reflect.ValueOf(e) // 类型校验 if !reflectutils.IsValueStructOrStructPointer(entityValue) { return fserr.New("参数不是结构或结构指针") } entityElemValue := reflectutils.PointerValueElem(entityValue) err := parseEntitySqlMappingTag(entityElemValue, onParsedFieldTagFunc) if err != nil { return err } return nil } func parseEntitySqlMappingTag(entityElemValue reflect.Value, onParsedFieldTagFunc OnSqlMappingParsedFieldTagFunc) error { for i := 0; i < entityElemValue.NumField(); i++ { entityField := entityElemValue.Type().Field(i) entityFieldValue := entityElemValue.Field(i) // 无效零值,不进行映射 if !entityFieldValue.IsValid() { continue } if entityFieldValue.Kind() == reflect.Pointer && entityFieldValue.IsNil() { entityFieldValue.Set(reflect.New(entityField.Type.Elem())) } entityFieldElemValue := reflectutils.PointerValueElem(entityFieldValue) tagStr := entityField.Tag.Get(sqlMappingTagKey) tag, err := parseSqlMappingTag(entityField, tagStr) if err != nil { return err } if tag == nil { continue } // 结构类型的字段,解析结构内部 if entityFieldElemValue.Kind() == reflect.Struct && !reflectutils.IsValueTime(entityFieldElemValue) { err := parseEntitySqlMappingTag(entityFieldElemValue, onParsedFieldTagFunc) if err != nil { return err } continue } err = onParsedFieldTagFunc(entityField.Name, entityFieldElemValue, tag) if err != nil { return err } } return nil } const ( sqlMappingDefaultKeyColumnName = "id" sqlMappingDefaultJoinWith = "::" sqlMappingTagPartSeparator = ";" sqlMappingTagPartKeyValueSeparator = ":" ) const ( sqlMappingTagKey = "sqlmapping" sqlMappingIgnore = "-" sqlMappingColumn = "column" sqlMappingKey = "key" sqlMappingNotUpdate = "notUpdate" sqlMappingUpdateClear = "updateClear" sqlMappingAes = "aes" sqlMappingJoinWith = "joinWith" sqlMappingTrim = "trim" sqlMappingTrimPrefix = "trimPrefix" sqlMappingTrimSuffix = "trimSuffix" ) type SqlMappingTag struct { Name string IsKey bool CanUpdate bool CanUpdateClear bool AESKey string JoinWith string Trim string TrimPrefix string TrimSuffix string } func parseSqlMappingTag(field reflect.StructField, tagStr string) (*SqlMappingTag, error) { if tagStr == sqlMappingIgnore { return nil, nil } sqlColumn := &SqlMappingTag{ Name: strcase.ToSnake(field.Name), IsKey: false, CanUpdate: true, CanUpdateClear: false, AESKey: "", JoinWith: sqlMappingDefaultJoinWith, Trim: "", TrimPrefix: "", TrimSuffix: "", } if sqlColumn.Name == sqlMappingDefaultKeyColumnName { sqlColumn.IsKey = true sqlColumn.CanUpdate = false } if strutils.IsStringEmpty(tagStr) { return sqlColumn, nil } sqlMappingParts := strings.Split(tagStr, sqlMappingTagPartSeparator) if sqlMappingParts != nil || len(sqlMappingParts) != 0 { for _, sqlMappingPart := range sqlMappingParts { sqlMappingPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2) if sqlMappingPartKeyValue != nil && len(sqlMappingPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlMappingPartKeyValue[1]) { sqlMappingPartKeyValue[1] = strings.Trim(sqlMappingPartKeyValue[1], "'") } switch sqlMappingPartKeyValue[0] { case sqlMappingColumn: if strutils.IsStringEmpty(sqlMappingPartKeyValue[1]) { return nil, errors.New("column没有赋值列名") } sqlColumn.Name = sqlMappingPartKeyValue[1] case sqlMappingKey: sqlColumn.IsKey = true sqlColumn.CanUpdate = false case sqlMappingNotUpdate: sqlColumn.CanUpdate = false sqlColumn.CanUpdateClear = false case sqlMappingUpdateClear: if !sqlColumn.CanUpdate { sqlColumn.CanUpdateClear = false } else { sqlColumn.CanUpdateClear = true } case sqlMappingAes: if len(sqlMappingPartKeyValue[1]) != 32 { return nil, errors.New("AES密钥长度应该为32个字节") } sqlColumn.AESKey = sqlMappingPartKeyValue[1] case sqlMappingJoinWith: if strutils.IsStringEmpty(sqlMappingPartKeyValue[1]) { return nil, errors.New(sqlMappingJoinWith + "没有赋值分隔符") } sqlColumn.JoinWith = sqlMappingPartKeyValue[1] case sqlMappingTrim: sqlColumn.Trim = sqlMappingPartKeyValue[1] case sqlMappingTrimPrefix: sqlColumn.TrimPrefix = sqlMappingPartKeyValue[1] case sqlMappingTrimSuffix: sqlColumn.TrimSuffix = sqlMappingPartKeyValue[1] default: continue } } } return sqlColumn, nil }