sql_result.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. package sql
  2. import (
  3. "errors"
  4. "github.com/iancoleman/strcase"
  5. "reflect"
  6. "strings"
  7. "time"
  8. )
  9. const (
  10. sqlResultTagPartSeparator = ";"
  11. sqlResultTagPartKeyValueSeparator = ":"
  12. )
  13. const (
  14. sqlResultTagKey = "sqlresult"
  15. sqlResultIgnore = "-"
  16. sqlResultColumn = "column"
  17. sqlResultParseTime = "parseTime"
  18. sqlResultAes = "aes"
  19. )
  20. type Result struct {
  21. ResultElement map[string]any
  22. }
  23. func ParseSqlResult(e any) (*Result, error) {
  24. if e == nil {
  25. return nil, errors.New("没有传递实体")
  26. }
  27. entityType := reflect.TypeOf(e)
  28. if entityType.Kind() == reflect.Ptr {
  29. entityType = entityType.Elem()
  30. }
  31. if entityType.Kind() != reflect.Struct {
  32. return nil, errors.New("传递的实体不是结构类型")
  33. }
  34. entityValue := reflect.ValueOf(e)
  35. if entityValue.Kind() == reflect.Ptr {
  36. entityValue = entityValue.Elem()
  37. }
  38. sqlResult := new(Result)
  39. sqlResult.ResultElement = make(map[string]any)
  40. fieldNum := entityType.NumField()
  41. for i := 0; i < fieldNum; i++ {
  42. field := entityType.Field(i)
  43. fieldValue := entityValue.Field(i)
  44. element, err := parseSqlResultElement(field, fieldValue)
  45. if err != nil {
  46. return nil, err
  47. }
  48. if element == nil {
  49. continue
  50. }
  51. sqlResult.ResultElement[field.Name] = element
  52. }
  53. return sqlResult, nil
  54. }
  55. type ResultStruct struct {
  56. ResultTypesAndValues
  57. }
  58. type ResultColumn struct {
  59. Name string
  60. ParseTime string
  61. AESKey string
  62. ResultTypesAndValues
  63. }
  64. type ResultTypesAndValues struct {
  65. // 原字段的反射结构
  66. OriginFieldType reflect.Type
  67. OriginFieldValue reflect.Value
  68. // 值类型的反射结构
  69. FieldTypeElem reflect.Type
  70. FieldValueElem reflect.Value
  71. }
  72. func parseSqlResultElement(field reflect.StructField, fieldValue reflect.Value) (any, error) {
  73. fieldValueTypeElem := field.Type
  74. if field.Type.Kind() == reflect.Ptr {
  75. fieldValueTypeElem = field.Type.Elem()
  76. }
  77. fieldValueElem := fieldValue
  78. if fieldValue.Kind() == reflect.Ptr {
  79. if !fieldValue.IsValid() || fieldValue.IsNil() {
  80. fieldValue.Set(reflect.New(fieldValueTypeElem).Elem().Addr())
  81. }
  82. fieldValueElem = fieldValue.Elem()
  83. }
  84. if fieldValueTypeElem.Kind() == reflect.Struct && fieldValueTypeElem != reflect.TypeOf(time.Time{}) {
  85. return &ResultStruct{
  86. ResultTypesAndValues: ResultTypesAndValues{
  87. OriginFieldType: field.Type,
  88. OriginFieldValue: fieldValue,
  89. FieldTypeElem: fieldValueTypeElem,
  90. FieldValueElem: fieldValueElem,
  91. },
  92. }, nil
  93. }
  94. sqlColumn := &ResultColumn{
  95. Name: strcase.ToSnake(field.Name),
  96. ParseTime: "",
  97. AESKey: "",
  98. ResultTypesAndValues: ResultTypesAndValues{
  99. OriginFieldType: field.Type,
  100. OriginFieldValue: fieldValue,
  101. FieldTypeElem: fieldValueTypeElem,
  102. FieldValueElem: fieldValueElem,
  103. },
  104. }
  105. sqlResultTag, ok := field.Tag.Lookup(sqlResultTagKey)
  106. if !ok {
  107. return sqlColumn, nil
  108. }
  109. if sqlResultTag == sqlResultIgnore {
  110. return nil, nil
  111. }
  112. sqlResultParts := strings.Split(sqlResultTag, sqlResultTagPartSeparator)
  113. if sqlResultParts != nil || len(sqlResultParts) != 0 {
  114. for _, sqlResultPart := range sqlResultParts {
  115. sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlResultPart), sqlResultTagPartKeyValueSeparator, 2)
  116. switch sqlPartKeyValue[0] {
  117. case sqlResultColumn:
  118. sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
  119. case sqlResultParseTime:
  120. sqlColumn.ParseTime = strings.TrimSpace(sqlPartKeyValue[1])
  121. case sqlResultAes:
  122. if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
  123. return nil, errors.New("AES密钥长度应该为32个字节")
  124. }
  125. sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
  126. default:
  127. continue
  128. }
  129. }
  130. }
  131. return sqlColumn, nil
  132. }