sql_result.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package tag
  2. import (
  3. "errors"
  4. "git.sxidc.com/go-tools/utils/reflectutils"
  5. "git.sxidc.com/go-tools/utils/strutils"
  6. "git.sxidc.com/service-supports/fserr"
  7. "github.com/iancoleman/strcase"
  8. "reflect"
  9. "strings"
  10. )
  11. type OnSqlResultParsedFieldTagFunc func(fieldName string, entityFieldElemValue reflect.Value, sqlResult *SqlResultTag) error
  12. func UseSqlResultTag(e any, onParsedFieldTagFunc OnSqlResultParsedFieldTagFunc) error {
  13. if e == nil {
  14. return nil
  15. }
  16. entityValue := reflect.ValueOf(e)
  17. // 类型校验
  18. if !reflectutils.IsValueStructPointer(entityValue) {
  19. return fserr.New("参数不是结构指针")
  20. }
  21. entityElemValue := reflectutils.PointerValueElem(entityValue)
  22. err := parseEntitySqlResultTag(entityElemValue, onParsedFieldTagFunc)
  23. if err != nil {
  24. return err
  25. }
  26. return nil
  27. }
  28. func parseEntitySqlResultTag(entityElemValue reflect.Value, onParsedFieldTagFunc OnSqlResultParsedFieldTagFunc) error {
  29. for i := 0; i < entityElemValue.NumField(); i++ {
  30. entityField := entityElemValue.Type().Field(i)
  31. entityFieldValue := entityElemValue.Field(i)
  32. // 无效值,不进行映射
  33. if !entityFieldValue.IsValid() {
  34. continue
  35. }
  36. if entityFieldValue.Kind() == reflect.Pointer && entityFieldValue.IsNil() {
  37. entityFieldValue.Set(reflect.New(entityField.Type.Elem()))
  38. }
  39. entityFieldElemValue := reflectutils.PointerValueElem(entityFieldValue)
  40. tagStr := entityField.Tag.Get(sqlResultTagKey)
  41. tag, err := parseSqlResultTag(entityField, tagStr)
  42. if err != nil {
  43. return err
  44. }
  45. if tag == nil {
  46. continue
  47. }
  48. // 结构类型的字段,解析结构内部
  49. if entityFieldElemValue.Kind() == reflect.Struct &&
  50. !reflectutils.IsValueTime(entityFieldElemValue) {
  51. err := parseEntitySqlResultTag(entityFieldElemValue, onParsedFieldTagFunc)
  52. if err != nil {
  53. return err
  54. }
  55. continue
  56. }
  57. err = onParsedFieldTagFunc(entityField.Name, entityFieldElemValue, tag)
  58. if err != nil {
  59. return err
  60. }
  61. }
  62. return nil
  63. }
  // Separators and defaults used while parsing a sqlresult struct tag.
  const (
  	// sqlResultDefaultSplitWith is the SplitWith value used when the tag sets no splitWith.
  	sqlResultDefaultSplitWith = "::"
  	// sqlResultTagPartSeparator separates the key:value parts inside one tag.
  	sqlResultTagPartSeparator = ";"
  	// sqlResultTagPartKeyValueSeparator separates a part's key from its value.
  	sqlResultTagPartKeyValueSeparator = ":"
  )

  // The struct tag name and the part keys recognized inside a sqlresult tag.
  const (
  	sqlResultTagKey    = "sqlresult" // tag name read with reflect.StructTag.Get
  	sqlResultIgnore    = "-"         // whole-tag value that excludes the field
  	sqlResultColumn    = "column"    // overrides the snake_case column name
  	sqlResultParseTime = "parseTime" // stored verbatim in SqlResultTag.ParseTime
  	sqlResultAes       = "aes"       // AES key; must be 32 bytes long
  	sqlResultSplitWith = "splitWith" // overrides the default "::" separator
  )
  // SqlResultTag is the parsed form of a field's sqlresult struct tag. It is
  // produced by parseSqlResultTag and handed to OnSqlResultParsedFieldTagFunc
  // callbacks.
  type SqlResultTag struct {
  	// Name is the result column name: the "column" tag value, or the Go field
  	// name converted to snake_case when no column is given.
  	Name string

  	// ParseTime is the raw "parseTime" tag value, empty when absent. In this
  	// file it is only stored; its interpretation is left to the callback.
  	ParseTime string

  	// AESKey is the "aes" tag value (validated to be 32 bytes), empty when absent.
  	AESKey string

  	// SplitWith is the "splitWith" tag value; it defaults to "::".
  	SplitWith string
  }
  83. func parseSqlResultTag(field reflect.StructField, tagStr string) (*SqlResultTag, error) {
  84. if tagStr == sqlResultIgnore {
  85. return nil, nil
  86. }
  87. sqlResultTag := &SqlResultTag{
  88. Name: strcase.ToSnake(field.Name),
  89. ParseTime: "",
  90. AESKey: "",
  91. SplitWith: sqlResultDefaultSplitWith,
  92. }
  93. if strutils.IsStringEmpty(tagStr) {
  94. return sqlResultTag, nil
  95. }
  96. sqlResultParts := strings.Split(tagStr, sqlResultTagPartSeparator)
  97. if sqlResultParts != nil || len(sqlResultParts) != 0 {
  98. for _, sqlResultPart := range sqlResultParts {
  99. sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlResultPart), sqlResultTagPartKeyValueSeparator, 2)
  100. if sqlPartKeyValue != nil && len(sqlPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlPartKeyValue[1]) {
  101. sqlPartKeyValue[1] = strings.Trim(sqlPartKeyValue[1], "'")
  102. }
  103. switch sqlPartKeyValue[0] {
  104. case sqlResultColumn:
  105. sqlResultTag.Name = sqlPartKeyValue[1]
  106. case sqlResultParseTime:
  107. sqlResultTag.ParseTime = sqlPartKeyValue[1]
  108. case sqlResultAes:
  109. if len(sqlPartKeyValue[1]) != 32 {
  110. return nil, errors.New("AES密钥长度应该为32个字节")
  111. }
  112. sqlResultTag.AESKey = sqlPartKeyValue[1]
  113. case sqlResultSplitWith:
  114. if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
  115. return nil, errors.New(sqlResultDefaultSplitWith + "没有赋值分隔符")
  116. }
  117. sqlResultTag.SplitWith = sqlPartKeyValue[1]
  118. default:
  119. continue
  120. }
  121. }
  122. }
  123. return sqlResultTag, nil
  124. }