|
@@ -2,6 +2,7 @@ package check
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
+ "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
|
|
|
"git.sxidc.com/go-tools/utils/reflectutils"
|
|
|
"git.sxidc.com/go-tools/utils/strutils"
|
|
|
"github.com/go-playground/locales/zh"
|
|
@@ -9,8 +10,8 @@ import (
|
|
|
"github.com/go-playground/validator/v10"
|
|
|
zhTranslations "github.com/go-playground/validator/v10/translations/zh"
|
|
|
"github.com/pkg/errors"
|
|
|
+ "reflect"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
@@ -22,9 +23,6 @@ const (
|
|
|
var validate = validator.New(validator.WithRequiredStructEnabled())
|
|
|
var translator ut.Translator
|
|
|
|
|
|
-var whenMapMutex = sync.RWMutex{}
|
|
|
-var whenMap = make(map[string][]string)
|
|
|
-
|
|
|
func init() {
|
|
|
validate.SetTagName("check")
|
|
|
|
|
@@ -68,19 +66,6 @@ func registerCustomTags() {
|
|
|
}
|
|
|
|
|
|
if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool {
|
|
|
- whenMapMutex.Lock()
|
|
|
- defer whenMapMutex.Unlock()
|
|
|
-
|
|
|
- param := strings.Trim(fl.Param(), "/")
|
|
|
-
|
|
|
- if strutils.IsStringEmpty(param) {
|
|
|
- return true
|
|
|
- }
|
|
|
-
|
|
|
- structName := reflectutils.PointerValueElem(fl.Top()).Type().Name()
|
|
|
- structFieldName := structName + "." + fl.StructFieldName()
|
|
|
- whenMap[structFieldName] = strings.Split(param, "/")
|
|
|
-
|
|
|
return true
|
|
|
}); err != nil {
|
|
|
fmt.Println(err)
|
|
@@ -89,20 +74,55 @@ func registerCustomTags() {
|
|
|
}
|
|
|
|
|
|
func Struct(obj any, fieldNameMap map[string]string) Result {
|
|
|
- return newResult(validate.Struct(obj), nil, fieldNameMap)
|
|
|
+ objectType := reflectutils.PointerTypeElem(reflect.TypeOf(obj))
|
|
|
+ whenMap := parseWhenMap(objectType)
|
|
|
+ return newResult(validate.Struct(obj), nil, fieldNameMap, whenMap)
|
|
|
+}
|
|
|
+
|
|
|
+func parseWhenMap(objectType reflect.Type) map[string][]string {
|
|
|
+ whenMap := make(map[string][]string)
|
|
|
+
|
|
|
+ for i := range objectType.NumField() {
|
|
|
+ field := objectType.Field(i)
|
|
|
+
|
|
|
+ if (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) ||
|
|
|
+ field.Type.Kind() == reflect.Struct {
|
|
|
+ fieldWhenMap := parseWhenMap(reflectutils.PointerTypeElem(field.Type))
|
|
|
+ for k, v := range fieldWhenMap {
|
|
|
+ whenMap[objectType.Name()+"."+k] = v
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fieldCheckTag := field.Tag.Get("check")
|
|
|
+ if strutils.IsStringEmpty(fieldCheckTag) {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ checkTags := strings.Split(fieldCheckTag, ",")
|
|
|
+ for _, checkTag := range checkTags {
|
|
|
+ if strings.HasPrefix(checkTag, "when=") {
|
|
|
+ whens := strings.Split(strings.Split(checkTag, "=")[1], "/")
|
|
|
+ whenMap[objectType.Name()+"."+field.Name] = whens
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return whenMap
|
|
|
}
|
|
|
|
|
|
type Result struct {
|
|
|
err error
|
|
|
translatedErrs map[string]string
|
|
|
fieldNameMap map[string]string
|
|
|
+ whenMap map[string][]string
|
|
|
}
|
|
|
|
|
|
-func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string) Result {
|
|
|
+func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenMap map[string][]string) Result {
|
|
|
return Result{
|
|
|
err: err,
|
|
|
translatedErrs: translatedErrs,
|
|
|
fieldNameMap: fieldNameMap,
|
|
|
+ whenMap: whenMap,
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -126,9 +146,7 @@ func (result Result) CheckFieldWhen(when string, check func(fieldName string) bo
|
|
|
}
|
|
|
|
|
|
for errStructNamespace, translatedErr := range translatedResult.translatedErrs {
|
|
|
- whenMapMutex.RLock()
|
|
|
- whens := whenMap[errStructNamespace]
|
|
|
- whenMapMutex.RUnlock()
|
|
|
+ whens := result.whenMap[errStructNamespace]
|
|
|
|
|
|
if strutils.IsStringNotEmpty(when) {
|
|
|
find := false
|
|
@@ -140,6 +158,7 @@ func (result Result) CheckFieldWhen(when string, check func(fieldName string) bo
|
|
|
}
|
|
|
|
|
|
if !find {
|
|
|
+ logger.GetInstance().Error(errors.New("使用的when不存在"))
|
|
|
continue
|
|
|
}
|
|
|
}
|
|
@@ -183,5 +202,5 @@ func (result Result) translation() Result {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- return newResult(result.err, translatedErrors, result.fieldNameMap)
|
|
|
+ return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenMap)
|
|
|
}
|