瀏覽代碼

修改逻辑

yjp 1 年之前
父節點
當前提交
6a329f3fb5
共有 3 個文件被更改,包括 67 次插入8 次删除
  1. 51 0
      examples/sql_mapping_tag/main.go
  2. 3 2
      tag/assign_struct.go
  3. 13 6
      tag/sql_mapping.go

+ 51 - 0
examples/sql_mapping_tag/main.go

@@ -0,0 +1,51 @@
+package main
+
+import (
+	"fmt"
+	"git.sxidc.com/go-framework/baize/tag"
+	"reflect"
+	"time"
+)
+
+type IDField struct {
+	ID string
+}
+
+type TimeFields struct {
+	CreatedTime     *time.Time
+	LastUpdatedTime time.Time
+}
+
+type GraduatedTimeTestStruct struct {
+	Field *string `sqlmapping:"-"`
+}
+
+type Class struct {
+	IDField
+	Name          string `sqlmapping:"updateClear;aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;"`
+	StudentNum    int    `sqlmapping:"column:student_num;notUpdate;"`
+	GraduatedTime *time.Time
+	StudentIDs    []string `sqlmapping:"column:student_ids;joinWith:'\n'"`
+	TimeFields
+	Ignored string `sqlmapping:"-"`
+	*GraduatedTimeTestStruct
+}
+
+func main() {
+	err := tag.BuildExecuteParams(&Class{}, func(fieldName string, entityFieldElemValue reflect.Value, tag *tag.SqlMappingTag) error {
+		fmt.Println("Field Name:", fieldName)
+		fmt.Println("Type:", entityFieldElemValue.Type().String())
+		if entityFieldElemValue.Kind() == reflect.String {
+			fmt.Printf("\"%+v\"\n", entityFieldElemValue.Interface())
+		} else {
+			fmt.Printf("%+v\n", entityFieldElemValue.Interface())
+		}
+		fmt.Printf("%+v\n", tag)
+		fmt.Println()
+
+		return nil
+	})
+	if err != nil {
+		panic(err)
+	}
+}

+ 3 - 2
tag/assign_struct.go

@@ -71,6 +71,8 @@ func assignTo(fromElemValue reflect.Value, retElemValue *reflect.Value) error {
 			if err != nil {
 				return err
 			}
+
+			continue
 		}
 
 		tag, err := parseAssignTag(fromField, tagStr)
@@ -256,8 +258,7 @@ func parseAssignTag(field reflect.StructField, tagStr string) (*assignTag, error
 		for _, assignPart := range assignParts {
 			assignPartKeyValue := strings.SplitN(strings.TrimSpace(assignPart), assignTagPartKeyValueSeparator, 2)
 			if assignPartKeyValue != nil && len(assignPartKeyValue) == 2 && strutils.IsStringNotEmpty(assignPartKeyValue[1]) {
-				// 可以支持' ' ',获取到'字符
-				assignPartKeyValue[1] = strings.TrimSpace(strings.Trim(assignPartKeyValue[1], "'"))
+				assignPartKeyValue[1] = strings.Trim(assignPartKeyValue[1], "'")
 			}
 
 			switch assignPartKeyValue[0] {

+ 13 - 6
tag/sql_mapping.go

@@ -10,7 +10,7 @@ import (
 	"strings"
 )
 
-type OnParsedFieldTagFunc func(entityFieldElemValue reflect.Value, tag *SqlMappingTag) error
+type OnParsedFieldTagFunc func(fieldName string, entityFieldElemValue reflect.Value, tag *SqlMappingTag) error
 
 func BuildExecuteParams(e any, onParsedFieldTagFunc OnParsedFieldTagFunc) error {
 	if e == nil {
@@ -24,7 +24,9 @@ func BuildExecuteParams(e any, onParsedFieldTagFunc OnParsedFieldTagFunc) error
 		return fserr.New("参数不是结构或结构指针")
 	}
 
-	err := parseEntitySqlMappingTag(reflectutils.PointerValueElem(entityValue), onParsedFieldTagFunc)
+	entityElemValue := reflectutils.PointerValueElem(entityValue)
+
+	err := parseEntitySqlMappingTag(entityElemValue, onParsedFieldTagFunc)
 	if err != nil {
 		return err
 	}
@@ -38,10 +40,14 @@ func parseEntitySqlMappingTag(entityElemValue reflect.Value, onParsedFieldTagFun
 		entityFieldValue := entityElemValue.Field(i)
 
 		// 无效零值,不进行映射
-		if !entityFieldValue.IsValid() || entityFieldValue.IsZero() {
+		if !entityFieldValue.IsValid() {
 			continue
 		}
 
+		if entityFieldValue.Kind() == reflect.Pointer && entityFieldValue.IsNil() {
+			entityFieldValue.Set(reflect.New(entityField.Type.Elem()))
+		}
+
 		entityFieldElemValue := reflectutils.PointerValueElem(entityFieldValue)
 
 		tagStr := entityField.Tag.Get(sqlMappingTagKey)
@@ -62,9 +68,11 @@ func parseEntitySqlMappingTag(entityElemValue reflect.Value, onParsedFieldTagFun
 			if err != nil {
 				return err
 			}
+
+			continue
 		}
 
-		err = onParsedFieldTagFunc(entityFieldElemValue, tag)
+		err = onParsedFieldTagFunc(entityField.Name, entityFieldElemValue, tag)
 		if err != nil {
 			return err
 		}
@@ -137,8 +145,7 @@ func parseSqlMappingTag(field reflect.StructField, tagStr string) (*SqlMappingTa
 		for _, sqlMappingPart := range sqlMappingParts {
 			sqlMappingPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
 			if sqlMappingPartKeyValue != nil && len(sqlMappingPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlMappingPartKeyValue[1]) {
-				// 可以支持' ' ',获取到'字符
-				sqlMappingPartKeyValue[1] = strings.TrimSpace(strings.Trim(sqlMappingPartKeyValue[1], "'"))
+				sqlMappingPartKeyValue[1] = strings.Trim(sqlMappingPartKeyValue[1], "'")
 			}
 
 			switch sqlMappingPartKeyValue[0] {