validate.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package check
  2. import (
  3. "fmt"
  4. "github.com/go-playground/locales/zh"
  5. ut "github.com/go-playground/universal-translator"
  6. "github.com/go-playground/validator/v10"
  7. zhTranslations "github.com/go-playground/validator/v10/translations/zh"
  8. "github.com/pkg/errors"
  9. "strings"
  10. "time"
  11. )
  12. const (
  13. timeNotZeroTag = "timenotzero"
  14. )
  15. var validate = validator.New(validator.WithRequiredStructEnabled())
  16. var translator ut.Translator
  17. func init() {
  18. validate.SetTagName("check")
  19. zhLocale := zh.New()
  20. zhTranslator := ut.New(zhLocale)
  21. trans, _ := zhTranslator.GetTranslator("zh")
  22. err := zhTranslations.RegisterDefaultTranslations(validate, trans)
  23. if err != nil {
  24. panic(err)
  25. }
  26. translator = trans
  27. registerCustomTags()
  28. }
  29. func registerCustomTags() {
  30. if err := validate.RegisterValidation(timeNotZeroTag, func(fl validator.FieldLevel) bool {
  31. fieldValue := fl.Field()
  32. if !fieldValue.IsValid() {
  33. return true
  34. }
  35. switch value := fieldValue.Interface().(type) {
  36. case time.Time:
  37. return !value.IsZero()
  38. case *time.Time:
  39. if value != nil {
  40. return true
  41. }
  42. return !value.IsZero()
  43. default:
  44. return false
  45. }
  46. }); err != nil {
  47. fmt.Println(err)
  48. return
  49. }
  50. }
  51. func Struct(obj any, fieldNameMap map[string]string) Result {
  52. return newResult(validate.Struct(obj), nil, fieldNameMap)
  53. }
  54. type Result struct {
  55. err error
  56. translatedErrs map[string]string
  57. fieldNameMap map[string]string
  58. }
  59. func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string) Result {
  60. return Result{
  61. err: err,
  62. translatedErrs: translatedErrs,
  63. fieldNameMap: fieldNameMap,
  64. }
  65. }
  66. func (result Result) IsError() bool {
  67. return result.err != nil
  68. }
  69. func (result Result) CheckFields(fieldNames ...string) error {
  70. if len(fieldNames) == 0 {
  71. return nil
  72. }
  73. return result.checkFields(fieldNames...)
  74. }
  75. func (result Result) CheckStruct() error {
  76. return result.checkStruct()
  77. }
  78. func (result Result) checkFields(fieldNames ...string) error {
  79. if result.err == nil {
  80. return nil
  81. }
  82. translatedResult := result.translation()
  83. if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 {
  84. return translatedResult.err
  85. }
  86. errMsg := strings.Builder{}
  87. for _, fieldName := range fieldNames {
  88. for errStructFieldName, translatedErr := range translatedResult.translatedErrs {
  89. errStructFieldNameParts := strings.Split(errStructFieldName, ".")
  90. errFieldName := errStructFieldNameParts[len(errStructFieldNameParts)-1]
  91. if fieldName == errFieldName {
  92. if translatedResult.fieldNameMap != nil {
  93. fieldCNName, ok := translatedResult.fieldNameMap[fieldName]
  94. if ok {
  95. translatedErr = strings.ReplaceAll(translatedErr, fieldName, fieldCNName)
  96. }
  97. }
  98. errMsg.WriteString(translatedErr + "\n")
  99. }
  100. }
  101. }
  102. if errMsg.Len() > 0 {
  103. return errors.New(errMsg.String())
  104. }
  105. return nil
  106. }
  107. func (result Result) checkStruct() error {
  108. if result.err == nil {
  109. return nil
  110. }
  111. translatedResult := result.translation()
  112. if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 {
  113. return translatedResult.err
  114. }
  115. errMsg := strings.Builder{}
  116. for errStructFieldName, translatedErr := range translatedResult.translatedErrs {
  117. if translatedResult.fieldNameMap != nil {
  118. errStructFieldNameParts := strings.Split(errStructFieldName, ".")
  119. errFieldName := errStructFieldNameParts[len(errStructFieldNameParts)-1]
  120. fieldCNName, ok := translatedResult.fieldNameMap[errFieldName]
  121. if ok {
  122. translatedErr = strings.ReplaceAll(translatedErr, errFieldName, fieldCNName)
  123. }
  124. }
  125. errMsg.WriteString(translatedErr + "\n")
  126. }
  127. if errMsg.Len() > 0 {
  128. return errors.New(errMsg.String())
  129. }
  130. return nil
  131. }
  132. func (result Result) translation() Result {
  133. if result.err == nil {
  134. return result
  135. }
  136. if result.translatedErrs != nil {
  137. return result
  138. }
  139. var validationErrors validator.ValidationErrors
  140. if !errors.As(result.err, &validationErrors) {
  141. return result
  142. }
  143. translatedErrors := validationErrors.Translate(translator)
  144. for _, validationError := range validationErrors {
  145. if validationError.Tag() == timeNotZeroTag {
  146. translatedErrors[validationError.Field()] = validationError.Field() + "使用了时间零值"
  147. }
  148. }
  149. return newResult(result.err, translatedErrors, result.fieldNameMap)
  150. }