package check import ( "fmt" "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger" "git.sxidc.com/go-tools/utils/reflectutils" "git.sxidc.com/go-tools/utils/strutils" "github.com/go-playground/locales/zh" ut "github.com/go-playground/universal-translator" "github.com/go-playground/validator/v10" zhTranslations "github.com/go-playground/validator/v10/translations/zh" "github.com/pkg/errors" "reflect" "strings" "time" ) const ( timeNotZeroTag = "timenotzero" whenTag = "when" ) var validate = validator.New(validator.WithRequiredStructEnabled()) var translator ut.Translator func init() { validate.SetTagName("check") zhLocale := zh.New() zhTranslator := ut.New(zhLocale) trans, _ := zhTranslator.GetTranslator("zh") err := zhTranslations.RegisterDefaultTranslations(validate, trans) if err != nil { panic(err) } translator = trans registerCustomTags() } func registerCustomTags() { if err := validate.RegisterValidation(timeNotZeroTag, func(fl validator.FieldLevel) bool { fieldValue := fl.Field() if !fieldValue.IsValid() { return true } switch value := fieldValue.Interface().(type) { case time.Time: return !value.IsZero() case *time.Time: if value != nil { return true } return !value.IsZero() default: return false } }); err != nil { fmt.Println(err) return } if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool { return true }); err != nil { fmt.Println(err) return } } func Struct(obj any, fieldNameMap map[string]string) Result { objectType := reflectutils.PointerTypeElem(reflect.TypeOf(obj)) whenMap := parseWhenMap(objectType) return newResult(validate.Struct(obj), nil, fieldNameMap, whenMap) } func parseWhenMap(objectType reflect.Type) map[string][]string { whenMap := make(map[string][]string) for i := range objectType.NumField() { field := objectType.Field(i) if (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) || field.Type.Kind() == reflect.Struct { fieldWhenMap := parseWhenMap(reflectutils.PointerTypeElem(field.Type)) for k, v := range fieldWhenMap { whenMap[objectType.Name()+"."+k] = v } } fieldCheckTag := field.Tag.Get("check") if strutils.IsStringEmpty(fieldCheckTag) { continue } checkTags := strings.Split(fieldCheckTag, ",") for _, checkTag := range checkTags { if strings.HasPrefix(checkTag, "when=") { whens := strings.Split(strings.Split(checkTag, "=")[1], "/") whenMap[objectType.Name()+"."+field.Name] = whens } } } return whenMap } type Result struct { err error translatedErrs map[string]string fieldNameMap map[string]string whenMap map[string][]string } func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenMap map[string][]string) Result { return Result{ err: err, translatedErrs: translatedErrs, fieldNameMap: fieldNameMap, whenMap: whenMap, } } func (result Result) IsError() bool { return result.err != nil } func (result Result) CheckWhen(when string) error { return result.CheckFieldWhen(when, nil) } func (result Result) CheckFieldWhen(when string, check func(fieldName string) bool) error { if result.err == nil { return nil } translatedResult := result.translation() if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 { return translatedResult.err } for errStructNamespace, translatedErr := range translatedResult.translatedErrs { whens := result.whenMap[errStructNamespace] if strutils.IsStringNotEmpty(when) { find := false for _, fieldWhen := range whens { if fieldWhen == when { find = true break } } if !find { logger.GetInstance().Info("使用的when不存在: when %v, whens %v\n", when, whens) continue } } if check != nil && !check(strings.Split(errStructNamespace, ".")[1]) { continue } if translatedResult.fieldNameMap != nil { fieldCNName, ok := translatedResult.fieldNameMap[errStructNamespace] if ok { translatedErr = strings.ReplaceAll(translatedErr, errStructNamespace, fieldCNName) } } return errors.New(translatedErr) } return nil } func (result Result) translation() Result { if result.err == nil { return result } if result.translatedErrs != nil { return result } var validationErrors validator.ValidationErrors if !errors.As(result.err, &validationErrors) { return result } translatedErrors := validationErrors.Translate(translator) for _, validationError := range validationErrors { if validationError.Tag() == timeNotZeroTag { translatedErrors[validationError.StructNamespace()] = validationError.Field() + "使用了时间零值" } } return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenMap) }