sql_mapping.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package sql
  2. import (
  3. "errors"
  4. "git.sxidc.com/go-tools/utils/strutils"
  5. "github.com/iancoleman/strcase"
  6. "reflect"
  7. "strings"
  8. "time"
  9. )
  10. const (
  11. sqlMappingDefaultKeyColumnName = "id"
  12. sqlMappingDefaultJoinWith = "::"
  13. sqlMappingTagPartSeparator = ";"
  14. sqlMappingTagPartKeyValueSeparator = ":"
  15. )
  16. const (
  17. sqlMappingTagKey = "sqlmapping"
  18. sqlMappingIgnore = "-"
  19. sqlMappingColumn = "column"
  20. sqlMappingKey = "key"
  21. sqlMappingNotUpdate = "notUpdate"
  22. sqlMappingUpdateClear = "updateClear"
  23. sqlMappingAes = "aes"
  24. sqlMappingJoinWith = "joinWith"
  25. )
// Mapping is the result of ParseSqlMapping for one entity.
type Mapping struct {
	// MappingElement is keyed by the Go struct field name. Each value is either
	// a *MappingStruct (nested struct field) or a *MappingColumn (column field).
	// Fields for which parseSqlMappingElement produces no element are omitted.
	MappingElement map[string]any
}
  29. func ParseSqlMapping(e any) (*Mapping, error) {
  30. if e == nil {
  31. return nil, errors.New("没有传递实体")
  32. }
  33. entityType := reflect.TypeOf(e)
  34. if entityType.Kind() == reflect.Ptr {
  35. entityType = entityType.Elem()
  36. }
  37. if entityType.Kind() != reflect.Struct {
  38. return nil, errors.New("传递的实体不是结构类型")
  39. }
  40. entityValue := reflect.ValueOf(e)
  41. if entityValue.Kind() == reflect.Ptr {
  42. entityValue = entityValue.Elem()
  43. }
  44. sqlMapping := new(Mapping)
  45. sqlMapping.MappingElement = make(map[string]any)
  46. fieldNum := entityType.NumField()
  47. for i := 0; i < fieldNum; i++ {
  48. field := entityType.Field(i)
  49. fieldValue := entityValue.Field(i)
  50. element, err := parseSqlMappingElement(field, fieldValue)
  51. if err != nil {
  52. return nil, err
  53. }
  54. if element == nil {
  55. continue
  56. }
  57. sqlMapping.MappingElement[field.Name] = element
  58. }
  59. return sqlMapping, nil
  60. }
// MappingStruct is the mapping element produced for a field whose type, after
// one level of pointer dereference, is a struct other than time.Time. It holds
// only the field's reflection data; its `sqlmapping` tag parts are not parsed.
// NOTE(review): callers presumably recurse into it to map the nested fields —
// confirm against the users of Mapping.
type MappingStruct struct {
	MappingTypesAndValues
}
// MappingColumn is the mapping element produced for a field that maps to a
// single column.
type MappingColumn struct {
	// Name is the column name: the field name converted with strcase.ToSnake,
	// unless overridden by a `column:<name>` tag part.
	Name string
	// IsKey is true when the snake_case field name is "id" or the tag contains
	// a `key` part.
	IsKey bool
	// CanUpdate is false for key columns and for fields tagged `notUpdate`;
	// true otherwise.
	CanUpdate bool
	// CanUpdateClear is set by an `updateClear` tag part (default false).
	// NOTE(review): presumably lets an update clear the column — confirm with
	// the code that consumes it.
	CanUpdateClear bool
	// AESKey comes from an `aes:<key>` tag part and is validated to be exactly
	// 32 bytes; empty when the part is absent.
	// NOTE(review): presumably used to encrypt/decrypt the column value — confirm.
	AESKey string
	// JoinWith is the separator for []string fields; defaults to "::" and can
	// be overridden with a `joinWith:<sep>` tag part, which is only accepted on
	// fields whose (dereferenced) type is a slice of strings.
	JoinWith string
	MappingTypesAndValues
}
// MappingTypesAndValues carries the reflection data of one mapped struct field.
type MappingTypesAndValues struct {
	// Reflection data of the field exactly as declared (may be a pointer type).
	OriginFieldType  reflect.Type
	OriginFieldValue reflect.Value
	// Reflection data of the value type: the pointed-to type/value when the
	// field is a pointer, otherwise the same as the Origin* fields.
	FieldTypeElem  reflect.Type
	FieldValueElem reflect.Value
}
  81. func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value) (any, error) {
  82. fieldValueTypeElem := field.Type
  83. if field.Type.Kind() == reflect.Ptr {
  84. fieldValueTypeElem = field.Type.Elem()
  85. }
  86. fieldValueElem := fieldValue
  87. if fieldValue.Kind() == reflect.Ptr {
  88. if !fieldValue.IsValid() || fieldValue.IsNil() {
  89. fieldValue.Set(reflect.New(fieldValueTypeElem).Elem().Addr())
  90. }
  91. fieldValueElem = fieldValue.Elem()
  92. }
  93. if fieldValueTypeElem.Kind() == reflect.Struct && fieldValueTypeElem != reflect.TypeOf(time.Time{}) {
  94. return &MappingStruct{
  95. MappingTypesAndValues: MappingTypesAndValues{
  96. OriginFieldType: field.Type,
  97. OriginFieldValue: fieldValue,
  98. FieldTypeElem: fieldValueTypeElem,
  99. FieldValueElem: fieldValueElem,
  100. },
  101. }, nil
  102. }
  103. sqlColumn := &MappingColumn{
  104. Name: strcase.ToSnake(field.Name),
  105. IsKey: false,
  106. CanUpdate: true,
  107. CanUpdateClear: false,
  108. AESKey: "",
  109. JoinWith: sqlMappingDefaultJoinWith,
  110. MappingTypesAndValues: MappingTypesAndValues{
  111. OriginFieldType: field.Type,
  112. OriginFieldValue: fieldValue,
  113. FieldTypeElem: fieldValueTypeElem,
  114. FieldValueElem: fieldValueElem,
  115. },
  116. }
  117. if sqlColumn.Name == sqlMappingDefaultKeyColumnName {
  118. sqlColumn.IsKey = true
  119. sqlColumn.CanUpdate = false
  120. }
  121. sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
  122. if !ok {
  123. return sqlColumn, nil
  124. }
  125. if sqlMappingTag == sqlMappingIgnore {
  126. return nil, nil
  127. }
  128. sqlMappingParts := strings.Split(sqlMappingTag, sqlMappingTagPartSeparator)
  129. if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
  130. for _, sqlMappingPart := range sqlMappingParts {
  131. sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
  132. if sqlPartKeyValue != nil && len(sqlPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlPartKeyValue[1]) {
  133. sqlPartKeyValue[1] = strings.Trim(sqlPartKeyValue[1], "'")
  134. }
  135. switch sqlPartKeyValue[0] {
  136. case sqlMappingColumn:
  137. if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
  138. return nil, errors.New("column没有赋值列名")
  139. }
  140. sqlColumn.Name = sqlPartKeyValue[1]
  141. case sqlMappingKey:
  142. sqlColumn.IsKey = true
  143. sqlColumn.CanUpdate = false
  144. case sqlMappingNotUpdate:
  145. sqlColumn.CanUpdate = false
  146. case sqlMappingUpdateClear:
  147. sqlColumn.CanUpdateClear = true
  148. case sqlMappingAes:
  149. if len(sqlPartKeyValue[1]) != 32 {
  150. return nil, errors.New("AES密钥长度应该为32个字节")
  151. }
  152. sqlColumn.AESKey = sqlPartKeyValue[1]
  153. case sqlMappingJoinWith:
  154. if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
  155. return nil, errors.New(sqlMappingJoinWith + "没有赋值分隔符")
  156. }
  157. if fieldValueTypeElem.Kind() != reflect.Slice || fieldValueTypeElem.Elem().Kind() != reflect.String {
  158. return nil, errors.New(sqlMappingJoinWith + "应该添加在[]string字段上")
  159. }
  160. sqlColumn.JoinWith = sqlPartKeyValue[1]
  161. default:
  162. continue
  163. }
  164. }
  165. }
  166. return sqlColumn, nil
  167. }