|
@@ -9,7 +9,6 @@ import (
|
|
|
"github.com/go-playground/validator/v10"
|
|
|
zhTranslations "github.com/go-playground/validator/v10/translations/zh"
|
|
|
"github.com/pkg/errors"
|
|
|
- "reflect"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"time"
|
|
@@ -24,7 +23,7 @@ var validate = validator.New(validator.WithRequiredStructEnabled())
|
|
|
var translator ut.Translator
|
|
|
|
|
|
var whenMapMutex = sync.RWMutex{}
|
|
|
-var whenMap = make(map[string]map[string][]string)
|
|
|
+var whenMap = make(map[string][]string)
|
|
|
|
|
|
func init() {
|
|
|
validate.SetTagName("check")
|
|
@@ -72,20 +71,15 @@ func registerCustomTags() {
|
|
|
whenMapMutex.Lock()
|
|
|
defer whenMapMutex.Unlock()
|
|
|
|
|
|
- topStructName := reflectutils.PointerValueElem(fl.Top()).Type().String()
|
|
|
- topStructMap, ok := whenMap[topStructName]
|
|
|
- if !ok {
|
|
|
- topStructMap = make(map[string][]string)
|
|
|
- whenMap[topStructName] = topStructMap
|
|
|
- }
|
|
|
-
|
|
|
param := strings.Trim(fl.Param(), "/")
|
|
|
|
|
|
if strutils.IsStringEmpty(param) {
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
- topStructMap[fl.FieldName()] = strings.Split(param, "/")
|
|
|
+ structName := reflectutils.PointerValueElem(fl.Top()).Type().Name()
|
|
|
+ structFieldName := structName + "." + fl.StructFieldName()
|
|
|
+ whenMap[structFieldName] = strings.Split(param, "/")
|
|
|
|
|
|
return true
|
|
|
}); err != nil {
|
|
@@ -95,29 +89,20 @@ func registerCustomTags() {
|
|
|
}
|
|
|
|
|
|
func Struct(obj any, fieldNameMap map[string]string) Result {
|
|
|
- validateErr := validate.Struct(obj)
|
|
|
-
|
|
|
- whenMapMutex.RLock()
|
|
|
- typeName := reflectutils.PointerTypeElem(reflect.TypeOf(obj)).String()
|
|
|
- whenFieldMap := whenMap[typeName]
|
|
|
- whenMapMutex.RUnlock()
|
|
|
-
|
|
|
- return newResult(validateErr, nil, fieldNameMap, whenFieldMap)
|
|
|
+ return newResult(validate.Struct(obj), nil, fieldNameMap)
|
|
|
}
|
|
|
|
|
|
type Result struct {
|
|
|
err error
|
|
|
translatedErrs map[string]string
|
|
|
fieldNameMap map[string]string
|
|
|
- whenFieldMap map[string][]string
|
|
|
}
|
|
|
|
|
|
-func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenFieldMap map[string][]string) Result {
|
|
|
+func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string) Result {
|
|
|
return Result{
|
|
|
err: err,
|
|
|
translatedErrs: translatedErrs,
|
|
|
fieldNameMap: fieldNameMap,
|
|
|
- whenFieldMap: whenFieldMap,
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -125,31 +110,11 @@ func (result Result) IsError() bool {
|
|
|
return result.err != nil
|
|
|
}
|
|
|
|
|
|
-func (result Result) CheckFields(fieldNames ...string) error {
|
|
|
- if len(fieldNames) == 0 {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- return result.checkFields("", fieldNames...)
|
|
|
-}
|
|
|
-
|
|
|
-func (result Result) CheckStruct() error {
|
|
|
- return result.checkStruct("")
|
|
|
+func (result Result) CheckWhen(when string) error {
|
|
|
+ return result.CheckFieldWhen(when, nil)
|
|
|
}
|
|
|
|
|
|
-func (result Result) CheckFieldsWhen(when string, fieldNames ...string) error {
|
|
|
- if len(fieldNames) == 0 {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- return result.checkFields(when, fieldNames...)
|
|
|
-}
|
|
|
-
|
|
|
-func (result Result) CheckStructWhen(when string) error {
|
|
|
- return result.checkStruct(when)
|
|
|
-}
|
|
|
-
|
|
|
-func (result Result) checkFields(when string, fieldNames ...string) error {
|
|
|
+func (result Result) CheckFieldWhen(when string, check func(fieldName string) bool) error {
|
|
|
if result.err == nil {
|
|
|
return nil
|
|
|
}
|
|
@@ -160,74 +125,37 @@ func (result Result) checkFields(when string, fieldNames ...string) error {
|
|
|
return translatedResult.err
|
|
|
}
|
|
|
|
|
|
- errMsg := strings.Builder{}
|
|
|
-
|
|
|
- for _, fieldName := range fieldNames {
|
|
|
- for errStructFieldName, translatedErr := range translatedResult.translatedErrs {
|
|
|
- errStructFieldNameParts := strings.Split(errStructFieldName, ".")
|
|
|
- errFieldName := errStructFieldNameParts[len(errStructFieldNameParts)-1]
|
|
|
- if fieldName == errFieldName {
|
|
|
- if strutils.IsStringNotEmpty(when) && result.whenFieldMap != nil && len(result.whenFieldMap) != 0 {
|
|
|
- find := false
|
|
|
- for _, fieldWhen := range result.whenFieldMap[fieldName] {
|
|
|
- if fieldWhen == when {
|
|
|
- find = true
|
|
|
- break
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if !find {
|
|
|
- continue
|
|
|
- }
|
|
|
- }
|
|
|
+ for errStructNamespace, translatedErr := range translatedResult.translatedErrs {
|
|
|
+ whenMapMutex.RLock()
|
|
|
+ whens := whenMap[errStructNamespace]
|
|
|
+ whenMapMutex.RUnlock()
|
|
|
|
|
|
- if translatedResult.fieldNameMap != nil {
|
|
|
- fieldCNName, ok := translatedResult.fieldNameMap[fieldName]
|
|
|
- if ok {
|
|
|
- translatedErr = strings.ReplaceAll(translatedErr, fieldName, fieldCNName)
|
|
|
- }
|
|
|
+ if strutils.IsStringNotEmpty(when) {
|
|
|
+ find := false
|
|
|
+ for _, fieldWhen := range whens {
|
|
|
+ if fieldWhen == when {
|
|
|
+ find = true
|
|
|
+ break
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- errMsg.WriteString(translatedErr + "\n")
|
|
|
+ if !find {
|
|
|
+ continue
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
|
|
|
- if errMsg.Len() > 0 {
|
|
|
- return errors.New(errMsg.String())
|
|
|
- }
|
|
|
-
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (result Result) checkStruct(when string) error {
|
|
|
- if result.err == nil {
|
|
|
- return nil
|
|
|
- }
|
|
|
-
|
|
|
- translatedResult := result.translation()
|
|
|
-
|
|
|
- if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 {
|
|
|
- return translatedResult.err
|
|
|
- }
|
|
|
-
|
|
|
- errMsg := strings.Builder{}
|
|
|
+ if check != nil && !check(strings.Split(errStructNamespace, ".")[1]) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
|
|
|
- for errStructFieldName, translatedErr := range translatedResult.translatedErrs {
|
|
|
if translatedResult.fieldNameMap != nil {
|
|
|
- errStructFieldNameParts := strings.Split(errStructFieldName, ".")
|
|
|
- errFieldName := errStructFieldNameParts[len(errStructFieldNameParts)-1]
|
|
|
- fieldCNName, ok := translatedResult.fieldNameMap[errFieldName]
|
|
|
+ fieldCNName, ok := translatedResult.fieldNameMap[errStructNamespace]
|
|
|
if ok {
|
|
|
- translatedErr = strings.ReplaceAll(translatedErr, errFieldName, fieldCNName)
|
|
|
+ translatedErr = strings.ReplaceAll(translatedErr, errStructNamespace, fieldCNName)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- errMsg.WriteString(translatedErr + "\n")
|
|
|
- }
|
|
|
-
|
|
|
- if errMsg.Len() > 0 {
|
|
|
- return errors.New(errMsg.String())
|
|
|
+ return errors.New(translatedErr)
|
|
|
}
|
|
|
|
|
|
return nil
|
|
@@ -251,9 +179,9 @@ func (result Result) translation() Result {
|
|
|
|
|
|
for _, validationError := range validationErrors {
|
|
|
if validationError.Tag() == timeNotZeroTag {
|
|
|
- translatedErrors[validationError.Field()] = validationError.Field() + "使用了时间零值"
|
|
|
+ translatedErrors[validationError.StructNamespace()] = validationError.Field() + "使用了时间零值"
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenFieldMap)
|
|
|
+ return newResult(result.err, translatedErrors, result.fieldNameMap)
|
|
|
}
|