浏览代码

修改tag的实现机制

yjp 4 月之前
父节点
当前提交
286b71bbdc
共有 1 个文件被更改,包括 42 次插入23 次删除
  1. 42 23
      framework/core/tag/check/validate.go

+ 42 - 23
framework/core/tag/check/validate.go

@@ -2,6 +2,7 @@ package check
 
 import (
 	"fmt"
+	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
 	"git.sxidc.com/go-tools/utils/reflectutils"
 	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/go-playground/locales/zh"
@@ -9,8 +10,8 @@ import (
 	"github.com/go-playground/validator/v10"
 	zhTranslations "github.com/go-playground/validator/v10/translations/zh"
 	"github.com/pkg/errors"
+	"reflect"
 	"strings"
-	"sync"
 	"time"
 )
 
@@ -22,9 +23,6 @@ const (
 var validate = validator.New(validator.WithRequiredStructEnabled())
 var translator ut.Translator
 
-var whenMapMutex = sync.RWMutex{}
-var whenMap = make(map[string][]string)
-
 func init() {
 	validate.SetTagName("check")
 
@@ -68,19 +66,6 @@ func registerCustomTags() {
 	}
 
 	if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool {
-		whenMapMutex.Lock()
-		defer whenMapMutex.Unlock()
-
-		param := strings.Trim(fl.Param(), "/")
-
-		if strutils.IsStringEmpty(param) {
-			return true
-		}
-
-		structName := reflectutils.PointerValueElem(fl.Top()).Type().Name()
-		structFieldName := structName + "." + fl.StructFieldName()
-		whenMap[structFieldName] = strings.Split(param, "/")
-
 		return true
 	}); err != nil {
 		fmt.Println(err)
@@ -89,20 +74,55 @@ func registerCustomTags() {
 }
 
 func Struct(obj any, fieldNameMap map[string]string) Result {
-	return newResult(validate.Struct(obj), nil, fieldNameMap)
+	objectType := reflectutils.PointerTypeElem(reflect.TypeOf(obj))
+	whenMap := parseWhenMap(objectType)
+	return newResult(validate.Struct(obj), nil, fieldNameMap, whenMap)
+}
+
+func parseWhenMap(objectType reflect.Type) map[string][]string {
+	whenMap := make(map[string][]string)
+
+	for i := range objectType.NumField() {
+		field := objectType.Field(i)
+
+		if (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) ||
+			field.Type.Kind() == reflect.Struct {
+			fieldWhenMap := parseWhenMap(reflectutils.PointerTypeElem(field.Type))
+			for k, v := range fieldWhenMap {
+				whenMap[objectType.Name()+"."+k] = v
+			}
+		}
+
+		fieldCheckTag := field.Tag.Get("check")
+		if strutils.IsStringEmpty(fieldCheckTag) {
+			continue
+		}
+
+		checkTags := strings.Split(fieldCheckTag, ",")
+		for _, checkTag := range checkTags {
+			if strings.HasPrefix(checkTag, "when=") {
+				whens := strings.Split(strings.Split(checkTag, "=")[1], "/")
+				whenMap[objectType.Name()+"."+field.Name] = whens
+			}
+		}
+	}
+
+	return whenMap
 }
 
 type Result struct {
 	err            error
 	translatedErrs map[string]string
 	fieldNameMap   map[string]string
+	whenMap        map[string][]string
 }
 
-func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string) Result {
+func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenMap map[string][]string) Result {
 	return Result{
 		err:            err,
 		translatedErrs: translatedErrs,
 		fieldNameMap:   fieldNameMap,
+		whenMap:        whenMap,
 	}
 }
 
@@ -126,9 +146,7 @@ func (result Result) CheckFieldWhen(when string, check func(fieldName string) bo
 	}
 
 	for errStructNamespace, translatedErr := range translatedResult.translatedErrs {
-		whenMapMutex.RLock()
-		whens := whenMap[errStructNamespace]
-		whenMapMutex.RUnlock()
+		whens := result.whenMap[errStructNamespace]
 
 		if strutils.IsStringNotEmpty(when) {
 			find := false
@@ -140,6 +158,7 @@ func (result Result) CheckFieldWhen(when string, check func(fieldName string) bo
 			}
 
 			if !find {
+				logger.GetInstance().Error(errors.New("使用的when不存在"))
 				continue
 			}
 		}
@@ -183,5 +202,5 @@ func (result Result) translation() Result {
 		}
 	}
 
-	return newResult(result.err, translatedErrors, result.fieldNameMap)
+	return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenMap)
 }