123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- package check
- import (
- "fmt"
- "git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
- "git.sxidc.com/go-tools/utils/reflectutils"
- "git.sxidc.com/go-tools/utils/strutils"
- "github.com/go-playground/locales/zh"
- ut "github.com/go-playground/universal-translator"
- "github.com/go-playground/validator/v10"
- zhTranslations "github.com/go-playground/validator/v10/translations/zh"
- "github.com/pkg/errors"
- "reflect"
- "strings"
- "time"
- )
- const (
- timeNotZeroTag = "timenotzero"
- whenTag = "when"
- )
- var validate = validator.New(validator.WithRequiredStructEnabled())
- var translator ut.Translator
- func init() {
- validate.SetTagName("check")
- zhLocale := zh.New()
- zhTranslator := ut.New(zhLocale)
- trans, _ := zhTranslator.GetTranslator("zh")
- err := zhTranslations.RegisterDefaultTranslations(validate, trans)
- if err != nil {
- panic(err)
- }
- translator = trans
- registerCustomTags()
- }
- func registerCustomTags() {
- if err := validate.RegisterValidation(timeNotZeroTag, func(fl validator.FieldLevel) bool {
- fieldValue := fl.Field()
- if !fieldValue.IsValid() {
- return true
- }
- switch value := fieldValue.Interface().(type) {
- case time.Time:
- return !value.IsZero()
- case *time.Time:
- if value != nil {
- return true
- }
- return !value.IsZero()
- default:
- return false
- }
- }); err != nil {
- fmt.Println(err)
- return
- }
- if err := validate.RegisterValidation(whenTag, func(fl validator.FieldLevel) bool {
- return true
- }); err != nil {
- fmt.Println(err)
- return
- }
- }
- func Struct(obj any, fieldNameMap map[string]string) Result {
- objectType := reflectutils.PointerTypeElem(reflect.TypeOf(obj))
- whenMap := parseWhenMap(objectType)
- return newResult(validate.Struct(obj), nil, fieldNameMap, whenMap)
- }
- func parseWhenMap(objectType reflect.Type) map[string][]string {
- whenMap := make(map[string][]string)
- for i := range objectType.NumField() {
- field := objectType.Field(i)
- if (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) ||
- field.Type.Kind() == reflect.Struct {
- fieldWhenMap := parseWhenMap(reflectutils.PointerTypeElem(field.Type))
- for k, v := range fieldWhenMap {
- whenMap[objectType.Name()+"."+k] = v
- }
- }
- fieldCheckTag := field.Tag.Get("check")
- if strutils.IsStringEmpty(fieldCheckTag) {
- continue
- }
- checkTags := strings.Split(fieldCheckTag, ",")
- for _, checkTag := range checkTags {
- if strings.HasPrefix(checkTag, "when=") {
- whens := strings.Split(strings.Split(checkTag, "=")[1], "/")
- whenMap[objectType.Name()+"."+field.Name] = whens
- }
- }
- }
- return whenMap
- }
// Result is the outcome of Struct. Its zero value represents a successful
// validation.
type Result struct {
	// err is the raw error returned by the validator; nil on success.
	err error
	// translatedErrs, once filled by translation, maps a field namespace to its
	// localized message. fieldNameMap maps a field's struct namespace to the
	// display name substituted into messages.
	translatedErrs, fieldNameMap map[string]string
	// whenMap maps a field's struct namespace to the scenarios from its `when=`
	// tag.
	whenMap map[string][]string
}
- func newResult(err error, translatedErrs map[string]string, fieldNameMap map[string]string, whenMap map[string][]string) Result {
- return Result{
- err: err,
- translatedErrs: translatedErrs,
- fieldNameMap: fieldNameMap,
- whenMap: whenMap,
- }
- }
- func (result Result) IsError() bool {
- return result.err != nil
- }
- func (result Result) CheckWhen(when string) error {
- return result.CheckFieldWhen(when, nil)
- }
- func (result Result) CheckFieldWhen(when string, check func(fieldName string) bool) error {
- if result.err == nil {
- return nil
- }
- translatedResult := result.translation()
- if translatedResult.translatedErrs == nil || len(translatedResult.translatedErrs) == 0 {
- return translatedResult.err
- }
- for errStructNamespace, translatedErr := range translatedResult.translatedErrs {
- whens := result.whenMap[errStructNamespace]
- if strutils.IsStringNotEmpty(when) {
- find := false
- for _, fieldWhen := range whens {
- if fieldWhen == when {
- find = true
- break
- }
- }
- if !find {
- logger.GetInstance().Info("使用的when不存在: when %v, whens %v\n", when, whens)
- continue
- }
- }
- if check != nil && !check(strings.Split(errStructNamespace, ".")[1]) {
- continue
- }
- if translatedResult.fieldNameMap != nil {
- fieldCNName, ok := translatedResult.fieldNameMap[errStructNamespace]
- if ok {
- translatedErr = strings.ReplaceAll(translatedErr, errStructNamespace, fieldCNName)
- }
- }
- return errors.New(translatedErr)
- }
- return nil
- }
- func (result Result) translation() Result {
- if result.err == nil {
- return result
- }
- if result.translatedErrs != nil {
- return result
- }
- var validationErrors validator.ValidationErrors
- if !errors.As(result.err, &validationErrors) {
- return result
- }
- translatedErrors := validationErrors.Translate(translator)
- for _, validationError := range validationErrors {
- if validationError.Tag() == timeNotZeroTag {
- translatedErrors[validationError.StructNamespace()] = validationError.Field() + "使用了时间零值"
- }
- }
- return newResult(result.err, translatedErrors, result.fieldNameMap, result.whenMap)
- }
|