validate.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. package check
  2. import (
  3. "fmt"
  4. "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
  5. "git.sxidc.com/go-tools/utils/reflectutils"
  6. "git.sxidc.com/go-tools/utils/strutils"
  7. "github.com/go-playground/locales/zh"
  8. ut "github.com/go-playground/universal-translator"
  9. "github.com/go-playground/validator/v10"
  10. zhTranslations "github.com/go-playground/validator/v10/translations/zh"
  11. "github.com/pkg/errors"
  12. "reflect"
  13. "strings"
  14. "time"
  15. )
  16. const (
  17. timeNotZeroTag = "timenotzero"
  18. whenTag = "when"
  19. )
  20. var validate = validator.New(validator.WithRequiredStructEnabled())
  21. var translator ut.Translator
  22. func init() {
  23. validate.SetTagName("check")
  24. zhLocale := zh.New()
  25. zhTranslator := ut.New(zhLocale)
  26. trans, _ := zhTranslator.GetTranslator("zh")
  27. err := zhTranslations.RegisterDefaultTranslations(validate, trans)
  28. if err != nil {
  29. panic(err)
  30. }
  31. translator = trans
  32. registerCustomTags()
  33. }
  34. func registerCustomTags() {
  35. if err := validate.RegisterValidation(timeNotZeroTag, func(fl validator.FieldLevel) bool {
  36. fieldValue := fl.Field()
  37. if !fieldValue.IsValid() {
  38. return true
  39. }
  40. switch value := fieldValue.Interface().(type) {
  41. case time.Time:
  42. return !value.IsZero()
  43. case *time.Time:
  44. if value != nil {
  45. return true
  46. }
  47. return !value.IsZero()
  48. default:
  49. return false
  50. }
  51. }); err != nil {
  52. fmt.Println(err)
  53. return
  54. }
  55. if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool {
  56. return true
  57. }); err != nil {
  58. fmt.Println(err)
  59. return
  60. }
  61. }
  62. func Struct(obj any, fieldNameMap map[string]string) Result {
  63. objectType := reflectutils.PointerTypeElem(reflect.TypeOf(obj))
  64. whenMap := parseWhenMap(objectType)
  65. return newResult(validate.Struct(obj), nil, fieldNameMap, whenMap)
  66. }
  67. func parseWhenMap(objectType reflect.Type) map[string][]string {
  68. whenMap := make(map[string][]string)
  69. for i := range objectType.NumField() {
  70. field := objectType.Field(i)
  71. if (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) ||
  72. field.Type.Kind() == reflect.Struct {
  73. fieldWhenMap := parseWhenMap(reflectutils.PointerTypeElem(field.Type))
  74. for k, v := range fieldWhenMap {
  75. whenMap[objectType.Name()+"."+k] = v
  76. }
  77. }
  78. fieldCheckTag := field.Tag.Get("check")
  79. if strutils.IsStringEmpty(fieldCheckTag) {
  80. continue
  81. }
  82. checkTags := strings.Split(fieldCheckTag, ",")
  83. for _, checkTag := range checkTags {
  84. if strings.HasPrefix(checkTag, "when=") {
  85. whens := strings.Split(strings.Split(checkTag, "=")[1], "/")
  86. whenMap[objectType.Name()+"."+field.Name] = whens
  87. }
  88. }
  89. }
  90. return whenMap
  91. }
  92. type Result struct {
  93. err error
  94. translatedErrs map[string]string
  95. fieldNameMap map[string]string
  96. whenMap map[string][]string
  97. }
  98. func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenMap map[string][]string) Result {
  99. return Result{
  100. err: err,
  101. translatedErrs: translatedErrs,
  102. fieldNameMap: fieldNameMap,
  103. whenMap: whenMap,
  104. }
  105. }
  106. func (result Result) IsError() bool {
  107. return result.err != nil
  108. }
  109. func (result Result) CheckWhen(when string) error {
  110. return result.CheckFieldWhen(when, nil)
  111. }
  112. func (result Result) CheckFieldWhen(when string, check func(fieldName string) bool) error {
  113. if result.err == nil {
  114. return nil
  115. }
  116. translatedResult := result.translation()
  117. if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 {
  118. return translatedResult.err
  119. }
  120. for errStructNamespace, translatedErr := range translatedResult.translatedErrs {
  121. whens := result.whenMap[errStructNamespace]
  122. if strutils.IsStringNotEmpty(when) {
  123. find := false
  124. for _, fieldWhen := range whens {
  125. if fieldWhen == when {
  126. find = true
  127. break
  128. }
  129. }
  130. if !find {
  131. logger.GetInstance().Info("使用的when不存在: when %v, whens %v\n", when, whens)
  132. continue
  133. }
  134. }
  135. if check != nil && !check(strings.Split(errStructNamespace, ".")[1]) {
  136. continue
  137. }
  138. if translatedResult.fieldNameMap != nil {
  139. fieldCNName, ok := translatedResult.fieldNameMap[errStructNamespace]
  140. if ok {
  141. translatedErr = strings.ReplaceAll(translatedErr, errStructNamespace, fieldCNName)
  142. }
  143. }
  144. return errors.New(translatedErr)
  145. }
  146. return nil
  147. }
  148. func (result Result) translation() Result {
  149. if result.err == nil {
  150. return result
  151. }
  152. if result.translatedErrs != nil {
  153. return result
  154. }
  155. var validationErrors validator.ValidationErrors
  156. if !errors.As(result.err, &validationErrors) {
  157. return result
  158. }
  159. translatedErrors := validationErrors.Translate(translator)
  160. for _, validationError := range validationErrors {
  161. if validationError.Tag() == timeNotZeroTag {
  162. translatedErrors[validationError.StructNamespace()] = validationError.Field() + "使用了时间零值"
  163. }
  164. }
  165. return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenMap)
  166. }