Browse Source

完成check tag扩展

yjp 4 months ago
parent
commit
01c67a7fd0
2 changed files with 127 additions and 11 deletions
  1. 75 7
      framework/core/tag/check/validate.go
  2. 52 4
      test/check_tag_test.go

+ 75 - 7
framework/core/tag/check/validate.go

@@ -2,22 +2,30 @@ package check
 
 import (
 	"fmt"
+	"git.sxidc.com/go-tools/utils/reflectutils"
+	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/go-playground/locales/zh"
 	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
 	zhTranslations "github.com/go-playground/validator/v10/translations/zh"
 	"github.com/pkg/errors"
+	"reflect"
 	"strings"
+	"sync"
 	"time"
 )
 
 const (
 	timeNotZeroTag = "timenotzero"
+	whenTag        = "when"
 )
 
 var validate = validator.New(validator.WithRequiredStructEnabled())
 var translator ut.Translator
 
+var whenMapMutex = sync.RWMutex{}
+var whenMap = make(map[string]map[string][]string)
+
 func init() {
 	validate.SetTagName("check")
 
@@ -59,23 +67,57 @@ func registerCustomTags() {
 		fmt.Println(err)
 		return
 	}
+
+	if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool {
+		whenMapMutex.Lock()
+		defer whenMapMutex.Unlock()
+
+		topStructName := reflectutils.PointerValueElem(fl.Top()).Type().String()
+		topStructMap, ok := whenMap[topStructName]
+		if !ok {
+			topStructMap = make(map[string][]string)
+			whenMap[topStructName] = topStructMap
+		}
+
+		param := strings.Trim(fl.Param(), "/")
+
+		if strutils.IsStringEmpty(param) {
+			return true
+		}
+
+		topStructMap[fl.FieldName()] = strings.Split(param, "/")
+
+		return true
+	}); err != nil {
+		fmt.Println(err)
+		return
+	}
 }
 
 func Struct(obj any, fieldNameMap map[string]string) Result {
-	return newResult(validate.Struct(obj), nil, fieldNameMap)
+	validateErr := validate.Struct(obj)
+
+	whenMapMutex.RLock()
+	typeName := reflectutils.PointerTypeElem(reflect.TypeOf(obj)).String()
+	whenFieldMap := whenMap[typeName]
+	whenMapMutex.RUnlock()
+
+	return newResult(validateErr, nil, fieldNameMap, whenFieldMap)
 }
 
 type Result struct {
 	err            error
 	translatedErrs map[string]string
 	fieldNameMap   map[string]string
+	whenFieldMap   map[string][]string
 }
 
-func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string) Result {
+func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenFieldMap map[string][]string) Result {
 	return Result{
 		err:            err,
 		translatedErrs: translatedErrs,
 		fieldNameMap:   fieldNameMap,
+		whenFieldMap:   whenFieldMap,
 	}
 }
 
@@ -88,14 +130,26 @@ func (result Result) CheckFields(fieldNames ...string) error {
 		return nil
 	}
 
-	return result.checkFields(fieldNames...)
+	return result.checkFields("", fieldNames...)
 }
 
 func (result Result) CheckStruct() error {
-	return result.checkStruct()
+	return result.checkStruct("")
+}
+
+func (result Result) CheckFieldsWhen(when string, fieldNames ...string) error {
+	if len(fieldNames) == 0 {
+		return nil
+	}
+
+	return result.checkFields(when, fieldNames...)
 }
 
-func (result Result) checkFields(fieldNames ...string) error {
+func (result Result) CheckStructWhen(when string) error {
+	return result.checkStruct(when)
+}
+
+func (result Result) checkFields(when string, fieldNames ...string) error {
 	if result.err == nil {
 		return nil
 	}
@@ -113,6 +167,20 @@ func (result Result) checkFields(fieldNames ...string) error {
 			errStructFieldNameParts := strings.Split(errStructFieldName, ".")
 			errFieldName := errStructFieldNameParts[len(errStructFieldNameParts)-1]
 			if fieldName == errFieldName {
+				if strutils.IsStringNotEmpty(when) && result.whenFieldMap != nil && len(result.whenFieldMap) != 0 {
+					find := false
+					for _, fieldWhen := range result.whenFieldMap[fieldName] {
+						if fieldWhen == when {
+							find = true
+							break
+						}
+					}
+
+					if !find {
+						continue
+					}
+				}
+
 				if translatedResult.fieldNameMap != nil {
 					fieldCNName, ok := translatedResult.fieldNameMap[fieldName]
 					if ok {
@@ -132,7 +200,7 @@ func (result Result) checkFields(fieldNames ...string) error {
 	return nil
 }
 
-func (result Result) checkStruct() error {
+func (result Result) checkStruct(when string) error {
 	if result.err == nil {
 		return nil
 	}
@@ -187,5 +255,5 @@ func (result Result) translation() Result {
 		}
 	}
 
-	return newResult(result.err, translatedErrors, result.fieldNameMap)
+	return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenFieldMap)
 }

+ 52 - 4
test/check_tag_test.go

@@ -9,11 +9,11 @@ import (
 )
 
 type CustomCheckTagStruct struct {
-	Time time.Time `check:"timenotzero"`
+	Time time.Time `check:"timenotzero,when=create/delete/update/foo"`
 }
 
 type CustomCheckTagPointerStruct struct {
-	Time *time.Time `check:"timenotzero"`
+	Time *time.Time `check:"timenotzero,when=create/delete/update/foo"`
 }
 
 var fieldMap = map[string]string{
@@ -73,12 +73,60 @@ func TestCustomCheckTag(t *testing.T) {
 	err = check.Struct(&customCheckTagStructZero, fieldMap).
 		CheckFields("Time")
 	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
-		t.Fatalf("%+v\n", errors.Errorf(err.Error()))
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
 	}
 
 	err = check.Struct(&customCheckTagPointerStructZero, fieldMap).
 		CheckFields("Time")
 	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
-		t.Fatalf("%+v\n", errors.Errorf(err.Error()))
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
+	}
+
+	err = check.Struct(customCheckTagStruct, fieldMap).
+		CheckFieldsWhen("create", "Time")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = check.Struct(customCheckTagPointerStruct, fieldMap).
+		CheckFieldsWhen("create", "Time")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = check.Struct(&customCheckTagStruct, fieldMap).
+		CheckFieldsWhen("delete", "Time")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = check.Struct(&customCheckTagPointerStruct, fieldMap).
+		CheckFieldsWhen("delete", "Time")
+	if err != nil {
+		t.Fatalf("%+v\n", err)
+	}
+
+	err = check.Struct(customCheckTagStructZero, fieldMap).
+		CheckFieldsWhen("update", "Time")
+	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
+	}
+
+	err = check.Struct(customCheckTagPointerStructZero, fieldMap).
+		CheckFieldsWhen("update", "Time")
+	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
+	}
+
+	err = check.Struct(&customCheckTagStructZero, fieldMap).
+		CheckFieldsWhen("foo", "Time")
+	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
+	}
+
+	err = check.Struct(&customCheckTagPointerStructZero, fieldMap).
+		CheckFieldsWhen("foo", "Time")
+	if err == nil || !strings.Contains(err.Error(), "使用了时间零值") {
+		t.Fatalf("%+v\n", errors.New("没有检测出使用了时间零值"))
 	}
 }