sql_mapping.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package sql
  2. import (
  3. "errors"
  4. "github.com/iancoleman/strcase"
  5. "reflect"
  6. "strings"
  7. )
  // Defaults and separators used when parsing `sqlmapping` struct tags.
const (
	// sqlMappingTagPartSeparator separates the options inside one tag value,
	// e.g. `sqlmapping:"column:user_id;notUpdate"`.
	sqlMappingTagPartSeparator = ";"

	// sqlMappingTagPartKeyValueSeparator separates an option's name from its
	// value, e.g. "column" from "user_id".
	sqlMappingTagPartKeyValueSeparator = ":"

	// defaultKeyColumnName is compared against a field's snake_case name;
	// a field whose default column name equals it is marked as the key.
	defaultKeyColumnName = "id"
)
  // Struct tag key and the option names recognised inside its value.
const (
	// sqlMappingTagKey is the struct tag key looked up on every field.
	sqlMappingTagKey = "sqlmapping"

	// sqlMappingIgnore, when it is the entire tag value, drops the field
	// from the mapping.
	sqlMappingIgnore = "-"

	// sqlMappingColumn overrides the column name: "column:<name>".
	sqlMappingColumn = "column"

	// sqlMappingKey marks the column as the key and makes it non-updatable.
	sqlMappingKey = "key"

	// sqlMappingNotUpdate makes the column non-updatable.
	sqlMappingNotUpdate = "notUpdate"

	// sqlMappingUpdateClear sets MappingColumn.CanUpdateClear.
	sqlMappingUpdateClear = "updateClear"
)
  21. type Mapping struct {
  22. ColumnMap map[string]MappingColumn
  23. }
  24. func ParseSqlMapping(e any) (*Mapping, error) {
  25. if e == nil {
  26. return nil, errors.New("没有传递实体")
  27. }
  28. entityType := reflect.TypeOf(e)
  29. if entityType.Kind() == reflect.Ptr {
  30. entityType = entityType.Elem()
  31. }
  32. if entityType.Kind() != reflect.Struct {
  33. return nil, errors.New("传递的不是实体结构")
  34. }
  35. entityValue := reflect.ValueOf(e)
  36. if entityValue.Kind() == reflect.Ptr {
  37. entityValue = entityValue.Elem()
  38. }
  39. sqlMapping := new(Mapping)
  40. sqlMapping.ColumnMap = make(map[string]MappingColumn)
  41. fieldNum := entityType.NumField()
  42. for i := 0; i < fieldNum; i++ {
  43. field := entityType.Field(i)
  44. fieldValue := entityValue.Field(i)
  45. column, err := parseSqlMappingColumn(field, fieldValue)
  46. if err != nil {
  47. return nil, err
  48. }
  49. if column == nil {
  50. continue
  51. }
  52. sqlMapping.ColumnMap[field.Name] = *column
  53. }
  54. return sqlMapping, nil
  55. }
  // MappingColumn describes how one struct field maps to a column.
type MappingColumn struct {
	// Name is the column name: the snake_case field name unless a
	// "column:<name>" tag option overrides it.
	Name string
	// IsKey reports whether this is the key column.
	IsKey bool
	// CanUpdate reports whether the column may be updated; the "key" and
	// "notUpdate" tag options set it to false.
	CanUpdate bool
	// CanUpdateClear is set by the "updateClear" tag option.
	// NOTE(review): presumably lets an update write an empty/zero value — confirm.
	CanUpdateClear bool

	// Reflection data of the field exactly as declared on the struct.
	OriginFieldType  reflect.Type
	OriginFieldValue reflect.Value

	// Reflection data of the underlying value: for a pointer field, the
	// pointed-to type and value (that type's zero value when the pointer is
	// nil); otherwise the same as the Origin* fields.
	ValueFieldType  reflect.Type
	ValueFieldValue reflect.Value
}
  68. func parseSqlMappingColumn(field reflect.StructField, fieldValue reflect.Value) (*MappingColumn, error) {
  69. valueFieldType := field.Type
  70. valueFieldValue := fieldValue
  71. if valueFieldType.Kind() == reflect.Ptr {
  72. valueFieldType = valueFieldType.Elem()
  73. if valueFieldValue.IsZero() {
  74. valueFieldValue = reflect.Zero(valueFieldType)
  75. } else {
  76. valueFieldValue = fieldValue.Elem()
  77. }
  78. }
  79. sqlColumn := &MappingColumn{
  80. Name: strcase.ToSnake(field.Name),
  81. IsKey: false,
  82. CanUpdate: true,
  83. CanUpdateClear: false,
  84. OriginFieldType: field.Type,
  85. OriginFieldValue: fieldValue,
  86. ValueFieldType: valueFieldType,
  87. ValueFieldValue: valueFieldValue,
  88. }
  89. if sqlColumn.Name == defaultKeyColumnName {
  90. sqlColumn.IsKey = true
  91. }
  92. sqlMappingTag, ok := field.Tag.Lookup(sqlMappingTagKey)
  93. if !ok {
  94. return sqlColumn, nil
  95. }
  96. if sqlMappingTag == sqlMappingIgnore {
  97. return nil, nil
  98. }
  99. sqlMappingParts := strings.Split(sqlMappingTag, sqlMappingTagPartSeparator)
  100. if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
  101. for _, sqlMappingPart := range sqlMappingParts {
  102. sqlPartKeyValue := strings.Split(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator)
  103. switch sqlPartKeyValue[0] {
  104. case sqlMappingColumn:
  105. sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
  106. case sqlMappingKey:
  107. sqlColumn.IsKey = true
  108. sqlColumn.CanUpdate = false
  109. case sqlMappingNotUpdate:
  110. sqlColumn.CanUpdate = false
  111. case sqlMappingUpdateClear:
  112. sqlColumn.CanUpdateClear = true
  113. default:
  114. continue
  115. }
  116. }
  117. }
  118. return sqlColumn, nil
  119. }