Browse Source

修改tag

yjp 1 year ago
parent
commit
3dd7833f56
4 changed files with 39 additions and 70 deletions
  1. 2 1
      examples/assign_tag/main.go
  2. 2 3
      go.mod
  3. 2 4
      go.sum
  4. 33 62
      tag/assign_struct.go

+ 2 - 1
examples/assign_tag/main.go

@@ -37,7 +37,8 @@ func main() {
 		CreateTime: &now,
 	}
 
-	class, err := tag.AssignTo[*ClassDomain](jsonBody)
+	class := new(ClassDomain)
+	err := tag.AssignTo(jsonBody, class)
 	if err != nil {
 		panic(err)
 	}

+ 2 - 3
go.mod

@@ -3,11 +3,11 @@ module git.sxidc.com/go-framework/baize
 go 1.22.3
 
 require (
-	git.sxidc.com/go-tools/utils v1.5.7
+	git.sxidc.com/go-tools/utils v1.5.8
 	git.sxidc.com/service-supports/fserr v0.3.5
+	git.sxidc.com/service-supports/fslog v0.5.9
 	git.sxidc.com/service-supports/websocket v1.3.1
 	github.com/gin-gonic/gin v1.10.0
-	github.com/iancoleman/strcase v0.3.0
 	github.com/vrecan/death v3.0.1+incompatible
 	go.uber.org/zap v1.27.0
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1
@@ -15,7 +15,6 @@ require (
 )
 
 require (
-	git.sxidc.com/service-supports/fslog v0.5.9 // indirect
 	github.com/bytedance/sonic v1.11.6 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/cihub/seelog v0.0.0-20170130134532-f561c5e57575 // indirect

+ 2 - 4
go.sum

@@ -1,5 +1,5 @@
-git.sxidc.com/go-tools/utils v1.5.7 h1:vVtOfbyZPdmogyWGti2hqctv3uy+2R6Rqg88Ge3CCUg=
-git.sxidc.com/go-tools/utils v1.5.7/go.mod h1:fkobAXFpOMTvkZ82TQXWcpsayePcyk/MS5TN6GTlRDg=
+git.sxidc.com/go-tools/utils v1.5.8 h1:hjKQ8kcUy2XJ2alkwmJvRE+QqsrCAf8biBIPFkVmjnc=
+git.sxidc.com/go-tools/utils v1.5.8/go.mod h1:fkobAXFpOMTvkZ82TQXWcpsayePcyk/MS5TN6GTlRDg=
 git.sxidc.com/service-supports/fserr v0.3.5 h1:1SDC60r3FIDd2iRq/oHRLK4OMa1gf67h9B7kierKTUE=
 git.sxidc.com/service-supports/fserr v0.3.5/go.mod h1:8U+W/ulZIGVPFojV6cE18shkGXqvaICuzaxIJpOcBqI=
 git.sxidc.com/service-supports/fslog v0.5.9 h1:q2XIK2o/fk/qmByy4x5kKLC+k7kolT5LrXHcWRSffXQ=
@@ -42,8 +42,6 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
 github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
 github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI=
-github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
 github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
 github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
 github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=

+ 33 - 62
tag/assign_struct.go

@@ -10,76 +10,55 @@ import (
 	"time"
 )
 
-func AssignTo[T any](from any) (T, error) {
-	var zero T
-
-	if from == nil {
-		return zero, nil
+func AssignTo(from any, to any) error {
+	if from == nil || to == nil {
+		return nil
 	}
 
 	fromValue := reflect.ValueOf(from)
-	retType := reflect.TypeOf(zero)
-	retValue := reflect.New(retType).Elem()
-	if retValue.Kind() == reflect.Ptr {
-		retValue.Set(reflect.New(retType.Elem()))
-	}
+	retValue := reflect.ValueOf(to)
 
 	// 类型校验
-	if fromValue.Kind() != reflect.Ptr && fromValue.Kind() != reflect.Struct {
-		return zero, fserr.New("参数不是结构或结构指针")
-	}
-
-	if fromValue.Kind() == reflect.Ptr && fromValue.Elem().Kind() != reflect.Struct {
-		return zero, fserr.New("参数不是结构或结构指针")
-	}
-
-	if retValue.Kind() != reflect.Ptr && retValue.Kind() != reflect.Struct {
-		return zero, fserr.New("返回类型不是结构或结构指针")
+	if !reflectutils.IsValueStructOrStructPointer(fromValue) {
+		return fserr.New("参数不是结构或结构指针")
 	}
 
-	if retValue.Kind() == reflect.Ptr && retValue.Elem().Kind() != reflect.Struct {
-		return zero, fserr.New("返回类型不是结构或结构指针")
+	if !reflectutils.IsValueStructPointer(retValue) {
+		return fserr.New("返回类型不是结构指针")
 	}
 
-	fromElemValue := fromValue
-	if fromValue.Kind() == reflect.Ptr {
-		fromElemValue = fromValue.Elem()
-	}
-
-	retElemValue := retValue
-	if retValue.Kind() == reflect.Ptr {
-		retElemValue = retValue.Elem()
-	}
+	fromElemValue := reflectutils.PointerValueElem(fromValue)
+	retElemValue := reflectutils.PointerValueElem(retValue)
 
 	err := assignTo(fromElemValue, &retElemValue)
 	if err != nil {
-		return zero, err
+		return err
 	}
 
-	return retValue.Interface().(T), nil
+	return nil
 }
 
 func assignTo(fromElemValue reflect.Value, retElemValue *reflect.Value) error {
 	for i := 0; i < fromElemValue.NumField(); i++ {
 		fromField := fromElemValue.Type().Field(i)
+		fromFieldValue := fromElemValue.Field(i)
+
+		// 无效零值,不进行赋值
+		if !fromFieldValue.IsValid() || fromFieldValue.IsZero() {
+			continue
+		}
+
 		tagStr := fromField.Tag.Get(assignTagKey)
 
 		// 结构上没有添加Tag, 先尝试直接按照字段赋值结构,如果失败,进一步进入内部尝试
 		if strutils.IsStringEmpty(tagStr) &&
-			((fromField.Type.Kind() == reflect.Struct && fromField.Type.String() != "time.Time") ||
-				(fromField.Type.Kind() == reflect.Ptr &&
-					fromField.Type.Elem().Kind() == reflect.Struct &&
-					fromField.Type.Elem().String() != "time.Time")) {
-			fromStructElemValue := fromElemValue.Field(i)
-			if fromField.Type.Kind() == reflect.Ptr {
-				if !fromStructElemValue.IsValid() || fromStructElemValue.IsZero() {
-					continue
-				}
-
-				fromStructElemValue = fromStructElemValue.Elem()
+			!reflectutils.IsValueTime(fromFieldValue) &&
+			!reflectutils.IsValueTimePointer(fromFieldValue) {
+			if !fromFieldValue.IsValid() || fromFieldValue.IsZero() {
+				continue
 			}
 
-			err := assignTo(fromStructElemValue, retElemValue)
+			err := assignTo(reflectutils.PointerValueElem(fromFieldValue), retElemValue)
 			if err != nil {
 				return err
 			}
@@ -94,18 +73,6 @@ func assignTo(fromElemValue reflect.Value, retElemValue *reflect.Value) error {
 			continue
 		}
 
-		fromFieldValue := fromElemValue.Field(i)
-
-		// 无效零值,不进行赋值
-		if !fromFieldValue.IsValid() || fromFieldValue.IsZero() {
-			continue
-		}
-
-		fromFieldElemValue := fromFieldValue
-		if fromFieldValue.Kind() == reflect.Ptr {
-			fromFieldElemValue = fromFieldValue.Elem()
-		}
-
 		retFieldValue := retElemValue.FieldByName(tag.ToField)
 
 		// 不存在对应的字段
@@ -131,6 +98,7 @@ func assignTo(fromElemValue reflect.Value, retElemValue *reflect.Value) error {
 			retFieldElemValue = retFieldValue.Elem()
 		}
 
+		fromFieldElemValue := reflectutils.PointerValueElem(fromFieldValue)
 		err = assignField(fromFieldElemValue, retFieldElemValue, tag)
 		if err != nil {
 			return err
@@ -147,19 +115,22 @@ func assignField(fromFieldElemValue reflect.Value, retFieldElemValue reflect.Val
 	var fromAny any
 	switch fromKind {
 	case reflect.Struct:
-		if fromFieldElemValue.Type().String() == "time.Time" && retKind == reflect.String {
+		// time.Time类型的结构,接收字段是string类型,使用FormatTime的格式转换
+		if reflectutils.IsValueTime(fromFieldElemValue) && retKind == reflect.String {
 			fromString := fromFieldElemValue.Interface().(time.Time).Format(tag.FormatTime)
 			fromAny = trimFromString(fromString, tag)
 			break
 		}
 
-		if fromFieldElemValue.Type().String() != "time.Time" && retKind == reflect.Struct {
+		// 不是time.Time类型的结构,接收字段是结构,执行结构到结构字段的赋值
+		if !reflectutils.IsValueTime(fromFieldElemValue) && retKind == reflect.Struct {
 			return assignTo(fromFieldElemValue, &retFieldElemValue)
 		}
 
+		// 直接将整个结构进行字段赋值
 		fromAny = fromFieldElemValue.Interface()
 	case reflect.Slice:
-		if fromFieldElemValue.Elem().Kind() == reflect.String && retKind == reflect.String {
+		if reflectutils.IsSliceValueOf(fromFieldElemValue, reflect.String) && retKind == reflect.String {
 			fromString := strings.Join(fromFieldElemValue.Interface().([]string), tag.JoinWith)
 			fromAny = trimFromString(fromString, tag)
 			break
@@ -169,7 +140,7 @@ func assignField(fromFieldElemValue reflect.Value, retFieldElemValue reflect.Val
 	case reflect.String:
 		fromString := fromFieldElemValue.String()
 
-		if retKind == reflect.Struct && retFieldElemValue.Type().String() == "time.Time" {
+		if reflectutils.IsValueTime(retFieldElemValue) {
 			retTimeField, err := time.ParseInLocation(tag.ParseTime, fromString, time.Local)
 			if err != nil {
 				return err
@@ -179,7 +150,7 @@ func assignField(fromFieldElemValue reflect.Value, retFieldElemValue reflect.Val
 			break
 		}
 
-		if retFieldElemValue.Kind() == reflect.Slice && retFieldElemValue.Elem().Kind() == reflect.String {
+		if reflectutils.IsSliceValueOf(retFieldElemValue, reflect.String) {
 			fromAny = strings.Split(fromString, tag.SplitWith)
 			break
 		}