sql_mapping.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package sql
  2. import (
  3. "errors"
  4. "github.com/iancoleman/strcase"
  5. "reflect"
  6. "strings"
  7. )
  8. const (
  9. defaultKeyColumnName = "id"
  10. sqlMappingTagPartSeparator = ";"
  11. sqlMappingTagPartKeyValueSeparator = ":"
  12. )
  13. const (
  14. sqlMappingTagKey = "sqlmapping"
  15. sqlMappingIgnore = "-"
  16. sqlMappingColumn = "column"
  17. sqlMappingKey = "key"
  18. sqlMappingNotUpdate = "notUpdate"
  19. sqlMappingUpdateClear = "updateClear"
  20. sqlMappingAes = "aes"
  21. )
  22. type Mapping struct {
  23. ColumnMap map[string]MappingColumn
  24. }
  25. func ParseSqlMapping(e any) (*Mapping, error) {
  26. if e == nil {
  27. return nil, errors.New("没有传递实体")
  28. }
  29. entityType := reflect.TypeOf(e)
  30. if entityType.Kind() == reflect.Ptr {
  31. entityType = entityType.Elem()
  32. }
  33. if entityType.Kind() != reflect.Struct {
  34. return nil, errors.New("传递的实体不是结构类型")
  35. }
  36. entityValue := reflect.ValueOf(e)
  37. if entityValue.Kind() == reflect.Ptr {
  38. entityValue = entityValue.Elem()
  39. }
  40. sqlMapping := new(Mapping)
  41. sqlMapping.ColumnMap = make(map[string]MappingColumn)
  42. fieldNum := entityType.NumField()
  43. for i := 0; i < fieldNum; i++ {
  44. field := entityType.Field(i)
  45. fieldValue := entityValue.Field(i)
  46. column, err := parseSqlMappingColumn(field, fieldValue)
  47. if err != nil {
  48. return nil, err
  49. }
  50. if column == nil {
  51. continue
  52. }
  53. sqlMapping.ColumnMap[field.Name] = *column
  54. }
  55. return sqlMapping, nil
  56. }
  57. type MappingColumn struct {
  58. Name string
  59. IsKey bool
  60. CanUpdate bool
  61. CanUpdateClear bool
  62. AESKey string
  63. // 原字段的反射结构
  64. OriginFieldType reflect.Type
  65. OriginFieldValue reflect.Value
  66. // 值类型的反射结构
  67. FieldTypeElem reflect.Type
  68. FieldValueElem reflect.Value
  69. }
  70. func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value) (*MappingColumn, error) {
  71. valueFieldType := field.Type
  72. valueFieldValue := fieldValue
  73. if valueFieldType.Kind() == reflect.Ptr {
  74. valueFieldType = valueFieldType.Elem()
  75. if !valueFieldValue.IsValid() || valueFieldValue.IsNil() || valueFieldValue.IsZero() {
  76. valueFieldValue = reflect.Zero(valueFieldType)
  77. } else {
  78. valueFieldValue = fieldValue.Elem()
  79. }
  80. }
  81. sqlColumn := &MappingColumn{
  82. Name: strcase.ToSnake(field.Name),
  83. IsKey: false,
  84. CanUpdate: true,
  85. CanUpdateClear: false,
  86. AESKey: "",
  87. OriginFieldType: field.Type,
  88. OriginFieldValue: fieldValue,
  89. FieldTypeElem: valueFieldType,
  90. FieldValueElem: valueFieldValue,
  91. }
  92. if sqlColumn.Name == defaultKeyColumnName {
  93. sqlColumn.IsKey = true
  94. sqlColumn.CanUpdate = false
  95. }
  96. sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
  97. if !ok {
  98. return sqlColumn, nil
  99. }
  100. if sqlMappingTag == sqlMappingIgnore {
  101. return nil, nil
  102. }
  103. sqlMappingParts := strings.Split(sqlMappingTag, sqlMappingTagPartSeparator)
  104. if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
  105. for _, sqlMappingPart := range sqlMappingParts {
  106. sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
  107. switch sqlPartKeyValue[0] {
  108. case sqlMappingColumn:
  109. sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
  110. case sqlMappingKey:
  111. sqlColumn.IsKey = true
  112. sqlColumn.CanUpdate = false
  113. case sqlMappingNotUpdate:
  114. sqlColumn.CanUpdate = false
  115. case sqlMappingUpdateClear:
  116. sqlColumn.CanUpdateClear = true
  117. case sqlMappingAes:
  118. if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
  119. return nil, errors.New("AES密钥长度应该为32个字节")
  120. }
  121. sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
  122. default:
  123. continue
  124. }
  125. }
  126. }
  127. return sqlColumn, nil
  128. }