package sql

import (
	"errors"
	"reflect"
	"strings"
	"time"

	"git.sxidc.com/go-tools/utils/strutils"
	"github.com/iancoleman/strcase"
)

const (
	// sqlMappingDefaultKeyColumnName is the column name that is marked as a
	// key (and not updatable) even without an explicit "key" tag part.
	sqlMappingDefaultKeyColumnName = "id"
	// sqlMappingDefaultJoinWith is the JoinWith separator used when a field
	// has no "joinWith" tag part.
	sqlMappingDefaultJoinWith = "::"
	// sqlMappingTagPartSeparator separates the parts of a sqlmapping tag.
	sqlMappingTagPartSeparator = ";"
	// sqlMappingTagPartKeyValueSeparator separates a tag part's name from its
	// value; only the first occurrence splits, so values may contain ":".
	sqlMappingTagPartKeyValueSeparator = ":"
)

// Recognized sqlmapping tag name and tag parts, for example:
//
//	Name string   `sqlmapping:"column:user_name;notUpdate"`
//	Tags []string `sqlmapping:"joinWith:','"`
//	Skip string   `sqlmapping:"-"`
const (
	sqlMappingTagKey      = "sqlmapping"  // struct tag name
	sqlMappingIgnore      = "-"           // whole tag value that excludes the field
	sqlMappingColumn      = "column"      // column:<name> overrides the column name
	sqlMappingKey         = "key"         // marks the column as a key; implies not updatable
	sqlMappingNotUpdate   = "notUpdate"   // column is not updatable
	sqlMappingUpdateClear = "updateClear" // sets CanUpdateClear (only kept if the column is updatable)
	sqlMappingAes         = "aes"         // aes:<32-byte key> sets AESKey
	sqlMappingJoinWith    = "joinWith"    // joinWith:<separator>, only valid on []string fields
)

// Mapping is the parsed sqlmapping description of a struct.
type Mapping struct {
	// MappingElement maps a Go struct field name to either a *MappingColumn
	// (non-struct fields and time.Time, pointers included) or a nested
	// *Mapping (struct-typed fields other than time.Time).
	MappingElement map[string]any
}

// MappingColumn describes how a single struct field maps to a SQL column.
type MappingColumn struct {
	// Name is the column name: snake_case of the field name unless overridden
	// by a "column:<name>" tag part.
	Name string
	// IsKey is set by a "key" tag part or when the default column name is "id".
	IsKey bool
	// CanUpdate is true unless the column is a key or tagged "notUpdate".
	CanUpdate bool
	// CanUpdateClear is set by "updateClear" and is always false when
	// CanUpdate is false.
	CanUpdateClear bool
	// AESKey is the 32-byte key from an "aes" tag part, or "" when absent.
	AESKey string
	// JoinWith is the separator from a "joinWith" tag part (only accepted on
	// []string fields), defaulting to "::".
	JoinWith string

	// Reflection data of the original field (may be a pointer type).
	OriginFieldType  reflect.Type
	OriginFieldValue reflect.Value

	// Reflection data of the value type: the pointee when the field is a
	// pointer, otherwise the same as the origin. For a nil pointer field,
	// FieldValueElem is a fresh zero value of the pointee type.
	FieldTypeElem  reflect.Type
	FieldValueElem reflect.Value
}

// ParseSqlMappingTag builds a Mapping from the `sqlmapping` struct tags of e.
//
// e must be a struct or a non-nil pointer to a struct. Every field not tagged
// `sqlmapping:"-"` produces one entry in Mapping.MappingElement, keyed by the
// Go field name: struct-typed fields (other than time.Time, pointers included)
// are parsed recursively into a nested *Mapping, all other fields become a
// *MappingColumn.
//
// It returns an error when e is nil or a nil pointer, when e is not a struct,
// or when any field's tag is invalid.
func ParseSqlMappingTag(e any) (*Mapping, error) {
	if e == nil {
		return nil, errors.New("没有传递实体")
	}

	entityType := reflect.TypeOf(e)
	entityValue := reflect.ValueOf(e)
	if entityType.Kind() == reflect.Ptr {
		entityType = entityType.Elem()
		// Elem of a nil pointer yields the zero Value; checked below.
		entityValue = entityValue.Elem()
	}

	if entityType.Kind() != reflect.Struct {
		return nil, errors.New("传递的实体不是结构类型")
	}

	// A typed nil pointer such as (*T)(nil) has no field values to read;
	// reject it instead of panicking in entityValue.Field below.
	if !entityValue.IsValid() {
		return nil, errors.New("没有传递实体")
	}

	fieldNum := entityType.NumField()
	sqlMapping := &Mapping{
		MappingElement: make(map[string]any, fieldNum),
	}

	for i := 0; i < fieldNum; i++ {
		field := entityType.Field(i)

		element, err := parseSqlMappingElement(field, entityValue.Field(i))
		if err != nil {
			return nil, err
		}

		// nil means the field is excluded via `sqlmapping:"-"`.
		if element == nil {
			continue
		}

		sqlMapping.MappingElement[field.Name] = element
	}

	return sqlMapping, nil
}

// parseSqlMappingElement parses one struct field and its sqlmapping tag.
//
// It returns:
//   - nil, nil when the field is tagged `sqlmapping:"-"`;
//   - a *Mapping when the field (after dereferencing a pointer) is a struct
//     other than time.Time; its nested fields are parsed recursively and the
//     field's own tag parts are not applied;
//   - a *MappingColumn otherwise.
//
// A nil pointer field is treated as a zero value of its pointee type. Invalid
// tag parts ("column"/"joinWith" without a value, an "aes" key that is not 32
// bytes, "joinWith" on a non-[]string field) yield an error.
func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value) (any, error) {
	sqlMappingTag := field.Tag.Get(sqlMappingTagKey)
	if sqlMappingTag == sqlMappingIgnore {
		return nil, nil
	}

	fieldTypeElem := field.Type
	fieldValueElem := fieldValue
	if field.Type.Kind() == reflect.Ptr {
		fieldTypeElem = field.Type.Elem()
		if fieldValue.IsNil() {
			fieldValueElem = reflect.New(fieldTypeElem).Elem()
		} else {
			fieldValueElem = fieldValue.Elem()
		}
	}

	if fieldTypeElem.Kind() == reflect.Struct && fieldTypeElem != reflect.TypeOf(time.Time{}) {
		// NOTE(review): reflect.Value.Interface panics for values reached via an
		// unexported field; unexported struct fields are not skipped here.
		nested, err := ParseSqlMappingTag(fieldValueElem.Interface())
		if err != nil {
			// Return a literal nil rather than a typed-nil *Mapping in an any.
			return nil, err
		}
		return nested, nil
	}

	sqlColumn := &MappingColumn{
		Name:             strcase.ToSnake(field.Name),
		IsKey:            false,
		CanUpdate:        true,
		CanUpdateClear:   false,
		AESKey:           "",
		JoinWith:         sqlMappingDefaultJoinWith,
		OriginFieldType:  field.Type,
		OriginFieldValue: fieldValue,
		FieldTypeElem:    fieldTypeElem,
		FieldValueElem:   fieldValueElem,
	}

	// A field whose default column name is "id" is treated as the key.
	if sqlColumn.Name == sqlMappingDefaultKeyColumnName {
		sqlColumn.IsKey = true
		sqlColumn.CanUpdate = false
	}

	if strutils.IsStringEmpty(sqlMappingTag) {
		return sqlColumn, nil
	}

	for _, sqlMappingPart := range strings.Split(sqlMappingTag, sqlMappingTagPartSeparator) {
		// "name:value" -> ("name", "value"). A part without ":" gets an empty
		// value, so the checks below report an error instead of indexing past
		// the end of a one-element split result.
		partName, partValue, _ := strings.Cut(sqlMappingPart, sqlMappingTagPartKeyValueSeparator)
		partName = strings.TrimSpace(partName)
		// Strip surrounding spaces first, then optional single quotes, so a
		// quoted value such as ' ' keeps its inner whitespace.
		partValue = strings.Trim(strings.TrimSpace(partValue), "'")

		switch partName {
		case sqlMappingColumn:
			if strutils.IsStringEmpty(partValue) {
				return nil, errors.New("column没有赋值列名")
			}
			sqlColumn.Name = partValue
		case sqlMappingKey:
			sqlColumn.IsKey = true
			sqlColumn.CanUpdate = false
		case sqlMappingNotUpdate:
			sqlColumn.CanUpdate = false
		case sqlMappingUpdateClear:
			// Reconciled with CanUpdate after the loop.
			sqlColumn.CanUpdateClear = true
		case sqlMappingAes:
			// 32 bytes is the AES-256 key size.
			if len(partValue) != 32 {
				return nil, errors.New("AES密钥长度应该为32个字节")
			}
			sqlColumn.AESKey = partValue
		case sqlMappingJoinWith:
			if strutils.IsStringEmpty(partValue) {
				return nil, errors.New(sqlMappingJoinWith + "没有赋值分隔符")
			}
			if fieldTypeElem.Kind() != reflect.Slice || fieldTypeElem.Elem().Kind() != reflect.String {
				return nil, errors.New(sqlMappingJoinWith + "应该添加在[]string字段上")
			}
			sqlColumn.JoinWith = partValue
		default:
			// Unknown parts (including empty ones from a trailing ";") are ignored.
		}
	}

	// Apply after all parts so the result does not depend on part order:
	// a column that cannot be updated can never be cleared by an update
	// (e.g. "updateClear;key" and "key;updateClear" now agree).
	if !sqlColumn.CanUpdate {
		sqlColumn.CanUpdateClear = false
	}

	return sqlColumn, nil
}