sql_mapping.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. package sql
  2. import (
  3. "errors"
  4. "github.com/iancoleman/strcase"
  5. "reflect"
  6. "strings"
  7. "time"
  8. )
  // Constants describing the `sqlmapping` struct-tag syntax and its defaults.
//
// A tag is a list of options separated by ";". Each option is either a bare
// flag ("key", "notUpdate", "updateClear") or a "name:value" pair
// ("column:xxx", "aes:xxx"). For example: `sqlmapping:"column:user_id;notUpdate"`.
const (
	// defaultKeyColumnName: a field whose snake_case name equals this value is
	// treated as the key column even without the "key" option.
	defaultKeyColumnName = "id"
	// sqlMappingTagPartSeparator separates the options inside one tag.
	sqlMappingTagPartSeparator = ";"
	// sqlMappingTagPartKeyValueSeparator separates an option name from its value.
	sqlMappingTagPartKeyValueSeparator = ":"

	// sqlMappingTagKey is the struct-tag key that is looked up on each field.
	sqlMappingTagKey = "sqlmapping"
	// sqlMappingIgnore, used as the whole tag value, marks the field as ignored.
	sqlMappingIgnore = "-"
	// sqlMappingColumn overrides the column name ("column:name").
	sqlMappingColumn = "column"
	// sqlMappingKey marks the column as the key (and therefore not updatable).
	sqlMappingKey = "key"
	// sqlMappingNotUpdate marks the column as not updatable.
	sqlMappingNotUpdate = "notUpdate"
	// sqlMappingUpdateClear sets MappingColumn.CanUpdateClear.
	sqlMappingUpdateClear = "updateClear"
	// sqlMappingAes supplies a 32-byte AES key for the column ("aes:key").
	sqlMappingAes = "aes"
)
  // Mapping is the parsed SQL mapping of one entity struct.
//
// MappingElement is keyed by the Go field name; each value is the element
// built for that field, either a *MappingColumn (a single column) or a
// *MappingStruct (a nested struct field). Ignored fields have no entry.
type Mapping struct {
	MappingElement map[string]any
}
  26. func ParseSqlMapping(e any) (*Mapping, error) {
  27. if e == nil {
  28. return nil, errors.New("没有传递实体")
  29. }
  30. entityType := reflect.TypeOf(e)
  31. if entityType.Kind() == reflect.Ptr {
  32. entityType = entityType.Elem()
  33. }
  34. if entityType.Kind() != reflect.Struct {
  35. return nil, errors.New("传递的实体不是结构类型")
  36. }
  37. entityValue := reflect.ValueOf(e)
  38. if entityValue.Kind() == reflect.Ptr {
  39. entityValue = entityValue.Elem()
  40. }
  41. sqlMapping := new(Mapping)
  42. sqlMapping.MappingElement = make(map[string]any)
  43. fieldNum := entityType.NumField()
  44. for i := 0; i < fieldNum; i++ {
  45. field := entityType.Field(i)
  46. fieldValue := entityValue.Field(i)
  47. element, err := parseSqlMappingElement(field, fieldValue)
  48. if err != nil {
  49. return nil, err
  50. }
  51. if element == nil {
  52. continue
  53. }
  54. sqlMapping.MappingElement[field.Name] = element
  55. }
  56. return sqlMapping, nil
  57. }
  // MappingTypesAndValues holds the reflection handles for one struct field.
type MappingTypesAndValues struct {
	// OriginFieldType and OriginFieldValue describe the field exactly as
	// declared, which may be a pointer type.
	OriginFieldType  reflect.Type
	OriginFieldValue reflect.Value
	// FieldTypeElem and FieldValueElem describe the value after one pointer
	// dereference; for non-pointer fields they are the same as the Origin* pair.
	FieldTypeElem  reflect.Type
	FieldValueElem reflect.Value
}

// MappingStruct is the mapping element for a field whose (dereferenced) type
// is a struct other than time.Time, i.e. a nested struct rather than a column.
type MappingStruct struct {
	MappingTypesAndValues
}

// MappingColumn is the mapping element for a field stored as a single column.
type MappingColumn struct {
	// Name is the column name: the snake_case field name unless overridden
	// by the "column" tag option.
	Name string
	// IsKey marks the key column; the parser also clears CanUpdate for it.
	IsKey bool
	// CanUpdate is false for key columns and for fields tagged "notUpdate".
	CanUpdate bool
	// CanUpdateClear is set by the "updateClear" tag option.
	CanUpdateClear bool
	// AESKey is the 32-byte key from the "aes" tag option, or "" when unset.
	AESKey string
	MappingTypesAndValues
}
  77. func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value) (any, error) {
  78. fieldValueTypeElem := field.Type
  79. if field.Type.Kind() == reflect.Ptr {
  80. fieldValueTypeElem = field.Type.Elem()
  81. }
  82. fieldValueElem := fieldValue
  83. if fieldValue.Kind() == reflect.Ptr {
  84. if !fieldValue.IsValid() || fieldValue.IsNil() {
  85. fieldValue.Set(reflect.New(fieldValueTypeElem).Elem().Addr())
  86. }
  87. fieldValueElem = fieldValue.Elem()
  88. }
  89. if fieldValueTypeElem.Kind() == reflect.Struct && fieldValueTypeElem != reflect.TypeOf(time.Time{}) {
  90. return &MappingStruct{
  91. MappingTypesAndValues: MappingTypesAndValues{
  92. OriginFieldType: field.Type,
  93. OriginFieldValue: fieldValue,
  94. FieldTypeElem: fieldValueTypeElem,
  95. FieldValueElem: fieldValueElem,
  96. },
  97. }, nil
  98. }
  99. sqlColumn := &MappingColumn{
  100. Name: strcase.ToSnake(field.Name),
  101. IsKey: false,
  102. CanUpdate: true,
  103. CanUpdateClear: false,
  104. AESKey: "",
  105. MappingTypesAndValues: MappingTypesAndValues{
  106. OriginFieldType: field.Type,
  107. OriginFieldValue: fieldValue,
  108. FieldTypeElem: fieldValueTypeElem,
  109. FieldValueElem: fieldValueElem,
  110. },
  111. }
  112. if sqlColumn.Name == defaultKeyColumnName {
  113. sqlColumn.IsKey = true
  114. sqlColumn.CanUpdate = false
  115. }
  116. sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
  117. if !ok {
  118. return sqlColumn, nil
  119. }
  120. if sqlMappingTag == sqlMappingIgnore {
  121. return nil, nil
  122. }
  123. sqlMappingParts := strings.Split(sqlMappingTag, sqlMappingTagPartSeparator)
  124. if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
  125. for _, sqlMappingPart := range sqlMappingParts {
  126. sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
  127. switch sqlPartKeyValue[0] {
  128. case sqlMappingColumn:
  129. sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
  130. case sqlMappingKey:
  131. sqlColumn.IsKey = true
  132. sqlColumn.CanUpdate = false
  133. case sqlMappingNotUpdate:
  134. sqlColumn.CanUpdate = false
  135. case sqlMappingUpdateClear:
  136. sqlColumn.CanUpdateClear = true
  137. case sqlMappingAes:
  138. if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
  139. return nil, errors.New("AES密钥长度应该为32个字节")
  140. }
  141. sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
  142. default:
  143. continue
  144. }
  145. }
  146. }
  147. return sqlColumn, nil
  148. }