yjp 11 mēneši atpakaļ
vecāks
revīzija
ffabc4505f
8 mainītis faili ar 163 papildinājumiem un 138 dzēšanām
  1. 24 7
      demo/entity.go
  2. 15 15
      demo/sql_mapping.go
  3. 40 0
      demo/sql_result.go
  4. 10 9
      sql/parse_result.go
  5. 27 30
      sql/sql.go
  6. 22 39
      sql/sql_mapping.go
  7. 9 22
      sql/sql_result.go
  8. 16 16
      test/sdk_test.go

+ 24 - 7
demo/entity.go

@@ -11,17 +11,34 @@ type TimeFields struct {
 	LastUpdatedTime time.Time
 }
 
-type GraduatedTimeTestField struct {
-	GraduatedTimeTest *string `sqlmapping:"-" sqlresult:"column:graduated_time;parseTime:2006-01-02 15:04:05"`
+type IgnoreStruct struct {
+	IgnoreField *string `sqlmapping:"-" sqlresult:"-"` // 这里如果结构字段上忽略了,结构中的字段可以不加忽略
 }
 
 type Class struct {
 	IDField
-	Name          string `sqlmapping:"updateClear;aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;" sqlresult:"aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;"`
-	StudentNum    int    `sqlmapping:"column:student_num;notUpdate;updateClear;" sqlresult:"column:student_num_alias"`
+	Name          string `sqlmapping:"updateClear;aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;"`
+	StudentNum    int    `sqlmapping:"column:student_num;notUpdate;updateClear;"`
 	GraduatedTime *time.Time
-	StudentIDs    []string `sqlmapping:"column:student_ids;joinWith:'\n'" sqlresult:"column:student_ids;splitWith:'\n'"`
+	StudentIDs    []string `sqlmapping:"column:student_ids;joinWith:'\n'"`
 	TimeFields
-	Ignored string `sqlmapping:"-" sqlresult:"-"`
-	*GraduatedTimeTestField
+	Ignored       string           `sqlmapping:"-"`
+	*IgnoreStruct `sqlmapping:"-"` // 会忽略结构下的所有字段
+}
+
+type GraduatedTimeInfoStruct struct {
+	GraduatedTime *string `sqlresult:"column:graduated_time;parseTime:2006-01-02 15:04:05"`
+}
+
+type ClassInfo struct {
+	IDField
+	Name          string `sqlresult:"aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;"`
+	StudentNum    int    `sqlresult:"column:student_num_alias"`
+	GraduatedTime *time.Time
+	StudentIDs    []string `sqlresult:"column:student_ids;splitWith:'\n'"`
+	TimeFields
+	Ignored           string          `sqlresult:"-"`
+	*IgnoreStruct     `sqlresult:"-"` // 会忽略结构下的所有字段
+	GraduatedTimeTest *string         `sqlresult:"column:graduated_time;parseTime:2006-01-02 15:04:05"`
+	*GraduatedTimeInfoStruct
 }

+ 15 - 15
demo/sql_mapping.go

@@ -6,31 +6,31 @@ import (
 )
 
 func main() {
-	printSqlMapping(&Class{})
-}
-
-func printSqlMapping(entity any) {
-	sqlMapping, err := sql.ParseSqlMappingTag(entity)
+	sqlMapping, err := sql.ParseSqlMappingTag(&Class{})
 	if err != nil {
 		panic(err)
 	}
 
-	for fieldName, element := range sqlMapping.MappingElement {
+	printSqlMapping(sqlMapping)
+}
+
+func printSqlMapping(sqlMapping *sql.Mapping) {
+	for fieldName, mappingElement := range sqlMapping.MappingElement {
 		fmt.Println("---------------------------------------")
 		fmt.Println("Field Name: " + fieldName)
 
-		switch sqlColumn := element.(type) {
-		case *sql.MappingStruct:
+		switch element := mappingElement.(type) {
+		case *sql.Mapping:
 			fmt.Println("Type: Struct")
-			printSqlMapping(sqlColumn.FieldValueElem.Interface())
+			printSqlMapping(element)
 		case *sql.MappingColumn:
 			fmt.Println("Type: Field")
-			fmt.Printf("Name: %s\n", sqlColumn.Name)
-			fmt.Printf("IsKey: %v\n", sqlColumn.IsKey)
-			fmt.Printf("CanUpdate: %v\n", sqlColumn.CanUpdate)
-			fmt.Printf("CanUpdateClear: %v\n", sqlColumn.CanUpdateClear)
-			fmt.Printf("AESKey: %v\n", sqlColumn.AESKey)
-			fmt.Printf("JoinWith: %v\n", sqlColumn.JoinWith)
+			fmt.Printf("Name: \"%s\"\n", element.Name)
+			fmt.Printf("IsKey: %v\n", element.IsKey)
+			fmt.Printf("CanUpdate: %v\n", element.CanUpdate)
+			fmt.Printf("CanUpdateClear: %v\n", element.CanUpdateClear)
+			fmt.Printf("AESKey: \"%s\"\n", element.AESKey)
+			fmt.Printf("JoinWith: \"%s\"\n", element.JoinWith)
 		default:
 			fmt.Println("类型错误")
 		}

+ 40 - 0
demo/sql_result.go

@@ -0,0 +1,40 @@
+package main
+
+import (
+	"fmt"
+	"git.sxidc.com/service-supports/ds-sdk/sql"
+)
+
+func main() {
+	// 必须使用结构指针或者[]指针,要接收返回值
+	sqlResult, err := sql.ParseSqlResultTag(&ClassInfo{})
+	if err != nil {
+		panic(err)
+	}
+
+	printSqlResult(sqlResult)
+}
+
+func printSqlResult(sqlResult *sql.Result) {
+
+	for fieldName, resultElement := range sqlResult.ResultElement {
+		fmt.Println("---------------------------------------")
+		fmt.Println("Field Name: " + fieldName)
+
+		switch element := resultElement.(type) {
+		case *sql.Result:
+			fmt.Println("Type: Struct")
+			printSqlResult(element)
+		case *sql.ResultColumn:
+			fmt.Println("Type: Field")
+			fmt.Printf("Name: \"%s\"\n", element.Name)
+			fmt.Printf("ParseTime: %v\n", element.ParseTime)
+			fmt.Printf("AESKey: \"%s\"\n", element.AESKey)
+			fmt.Printf("SplitWith: \"%s\"\n", element.SplitWith)
+		default:
+			fmt.Println("类型错误")
+		}
+
+		fmt.Println("---------------------------------------")
+	}
+}

+ 10 - 9
sql/parse_result.go

@@ -63,7 +63,12 @@ func ParseSqlResult(input any, output any) error {
 		outputEntityValue := reflect.New(outputElemType).Elem().Addr()
 		outputEntity := outputEntityValue.Interface()
 
-		err := formOutputEntity(tableRow, outputEntity)
+		sqlResult, err := ParseSqlResultTag(outputEntity)
+		if err != nil {
+			return err
+		}
+
+		err = formOutputEntity(tableRow, sqlResult)
 		if err != nil {
 			return err
 		}
@@ -88,16 +93,12 @@ func ParseSqlResult(input any, output any) error {
 	return nil
 }
 
-func formOutputEntity(tableRow sdk.SqlResult, outputEntity any) error {
-	sqlResult, err := ParseSqlResultTag(outputEntity)
-	if err != nil {
-		return err
-	}
+func formOutputEntity(tableRow sdk.SqlResult, sqlResult *Result) error {
 
 	for fieldName, resultElement := range sqlResult.ResultElement {
 		switch element := resultElement.(type) {
-		case *ResultStruct:
-			err := formOutputEntity(tableRow, element.FieldValueElem.Addr().Interface())
+		case *Result:
+			err := formOutputEntity(tableRow, element)
 			if err != nil {
 				return err
 			}
@@ -143,7 +144,7 @@ func formOutputEntity(tableRow sdk.SqlResult, outputEntity any) error {
 					}
 				}
 
-				err = reflectutils.AssignStringValue(strValue, fieldValueElem)
+				err := reflectutils.AssignStringValue(strValue, fieldValueElem)
 				if err != nil {
 					return err
 				}

+ 27 - 30
sql/sql.go

@@ -33,9 +33,13 @@ func InsertEntity[T any](executor Executor, tableName string, e T) error {
 		return errors.New("没有传递实体")
 	}
 
-	tableRows := sql_tpl.NewTableRows()
+	sqlMapping, err := ParseSqlMappingTag(e)
+	if err != nil {
+		return err
+	}
 
-	err := formInsertTableRow(e, tableRows)
+	tableRows := sql_tpl.NewTableRows()
+	err = formInsertTableRow(sqlMapping, tableRows)
 	if err != nil {
 		return err
 	}
@@ -60,18 +64,13 @@ func InsertEntity[T any](executor Executor, tableName string, e T) error {
 	return nil
 }
 
-func formInsertTableRow(e any, tableRows *sql_tpl.TableRows) error {
-	sqlMapping, err := ParseSqlMappingTag(e)
-	if err != nil {
-		return err
-	}
-
+func formInsertTableRow(sqlMapping *Mapping, tableRows *sql_tpl.TableRows) error {
 	now := time.Now()
 
 	for fieldName, mappingElement := range sqlMapping.MappingElement {
 		switch element := mappingElement.(type) {
-		case *MappingStruct:
-			err := formInsertTableRow(element.FieldValueElem.Interface(), tableRows)
+		case *Mapping:
+			err := formInsertTableRow(element, tableRows)
 			if err != nil {
 				return err
 			}
@@ -139,9 +138,13 @@ func DeleteEntity[T any](executor Executor, tableName string, e T) error {
 		return errors.New("没有传递实体")
 	}
 
-	conditions := sql_tpl.NewConditions()
+	sqlMapping, err := ParseSqlMappingTag(e)
+	if err != nil {
+		return err
+	}
 
-	err := formDeleteConditions(e, conditions)
+	conditions := sql_tpl.NewConditions()
+	err = formDeleteConditions(sqlMapping, conditions)
 	if err != nil {
 		return err
 	}
@@ -162,16 +165,11 @@ func DeleteEntity[T any](executor Executor, tableName string, e T) error {
 	return nil
 }
 
-func formDeleteConditions(e any, conditions *sql_tpl.Conditions) error {
-	sqlMapping, err := ParseSqlMappingTag(e)
-	if err != nil {
-		return err
-	}
-
+func formDeleteConditions(sqlMapping *Mapping, conditions *sql_tpl.Conditions) error {
 	for _, mappingElement := range sqlMapping.MappingElement {
 		switch element := mappingElement.(type) {
-		case *MappingStruct:
-			err := formDeleteConditions(element.FieldValueElem.Interface(), conditions)
+		case *Mapping:
+			err := formDeleteConditions(element, conditions)
 			if err != nil {
 				return err
 			}
@@ -213,10 +211,14 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 		return errors.New("没有传递实体")
 	}
 
+	sqlMapping, err := ParseSqlMappingTag(e)
+	if err != nil {
+		return err
+	}
+
 	tableRows := sql_tpl.NewTableRows()
 	conditions := sql_tpl.NewConditions()
-
-	err := formUpdateTableRowsAndConditions(e, tableRows, conditions)
+	err = formUpdateTableRowsAndConditions(sqlMapping, tableRows, conditions)
 	if err != nil {
 		return err
 	}
@@ -238,18 +240,13 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 	return nil
 }
 
-func formUpdateTableRowsAndConditions(e any, tableRows *sql_tpl.TableRows, conditions *sql_tpl.Conditions) error {
-	sqlMapping, err := ParseSqlMappingTag(e)
-	if err != nil {
-		return err
-	}
-
+func formUpdateTableRowsAndConditions(sqlMapping *Mapping, tableRows *sql_tpl.TableRows, conditions *sql_tpl.Conditions) error {
 	now := time.Now()
 
 	for fieldName, mappingElement := range sqlMapping.MappingElement {
 		switch element := mappingElement.(type) {
-		case *MappingStruct:
-			err := formUpdateTableRowsAndConditions(element.FieldValueElem.Interface(), tableRows, conditions)
+		case *Mapping:
+			err := formUpdateTableRowsAndConditions(element, tableRows, conditions)
 			if err != nil {
 				return err
 			}

+ 22 - 39
sql/sql_mapping.go

@@ -31,6 +31,23 @@ type Mapping struct {
 	MappingElement map[string]any
 }
 
+type MappingColumn struct {
+	Name           string
+	IsKey          bool
+	CanUpdate      bool
+	CanUpdateClear bool
+	AESKey         string
+	JoinWith       string
+
+	// 原字段的反射结构
+	OriginFieldType  reflect.Type
+	OriginFieldValue reflect.Value
+
+	// 值类型的反射结构
+	FieldTypeElem  reflect.Type
+	FieldValueElem reflect.Value
+}
+
 func ParseSqlMappingTag(e any) (*Mapping, error) {
 	if e == nil {
 		return nil, errors.New("没有传递实体")
@@ -73,31 +90,6 @@ func ParseSqlMappingTag(e any) (*Mapping, error) {
 	return sqlMapping, nil
 }
 
-type MappingStruct struct {
-	MappingTypesAndValues
-}
-
-type MappingColumn struct {
-	Name           string
-	IsKey          bool
-	CanUpdate      bool
-	CanUpdateClear bool
-	AESKey         string
-	JoinWith       string
-
-	MappingTypesAndValues
-}
-
-type MappingTypesAndValues struct {
-	// 原字段的反射结构
-	OriginFieldType  reflect.Type
-	OriginFieldValue reflect.Value
-
-	// 值类型的反射结构
-	FieldTypeElem  reflect.Type
-	FieldValueElem reflect.Value
-}
-
 func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value) (any, error) {
 	sqlMappingTag := field.Tag.Get(sqlMappingTagKey)
 	if sqlMappingTag == sqlMappingIgnore {
@@ -119,14 +111,7 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 	}
 
 	if fieldValueTypeElem.Kind() == reflect.Struct && fieldValueTypeElem != reflect.TypeOf(time.Time{}) {
-		return &MappingStruct{
-			MappingTypesAndValues: MappingTypesAndValues{
-				OriginFieldType:  field.Type,
-				OriginFieldValue: fieldValue,
-				FieldTypeElem:    fieldValueTypeElem,
-				FieldValueElem:   fieldValueElem,
-			},
-		}, nil
+		return ParseSqlMappingTag(fieldValueElem.Interface())
 	}
 
 	sqlColumn := &MappingColumn{
@@ -137,12 +122,10 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 		AESKey:         "",
 		JoinWith:       sqlMappingDefaultJoinWith,
 
-		MappingTypesAndValues: MappingTypesAndValues{
-			OriginFieldType:  field.Type,
-			OriginFieldValue: fieldValue,
-			FieldTypeElem:    fieldValueTypeElem,
-			FieldValueElem:   fieldValueElem,
-		},
+		OriginFieldType:  field.Type,
+		OriginFieldValue: fieldValue,
+		FieldTypeElem:    fieldValueTypeElem,
+		FieldValueElem:   fieldValueElem,
 	}
 
 	if sqlColumn.Name == sqlMappingDefaultKeyColumnName {

+ 9 - 22
sql/sql_result.go

@@ -70,20 +70,12 @@ func ParseSqlResultTag(e any) (*Result, error) {
 	return sqlResult, nil
 }
 
-type ResultStruct struct {
-	ResultTypesAndValues
-}
-
 type ResultColumn struct {
 	Name      string
 	ParseTime string
 	AESKey    string
 	SplitWith string
 
-	ResultTypesAndValues
-}
-
-type ResultTypesAndValues struct {
 	// 原字段的反射结构
 	OriginFieldType  reflect.Type
 	OriginFieldValue reflect.Value
@@ -118,14 +110,11 @@ func parseSqlResultElement(field reflect.StructField, fieldValue reflect.Value)
 	}
 
 	if fieldValueTypeElem.Kind() == reflect.Struct && fieldValueTypeElem != reflect.TypeOf(time.Time{}) {
-		return &ResultStruct{
-			ResultTypesAndValues: ResultTypesAndValues{
-				OriginFieldType:  field.Type,
-				OriginFieldValue: fieldValue,
-				FieldTypeElem:    fieldValueTypeElem,
-				FieldValueElem:   fieldValueElem,
-			},
-		}, nil
+		if !fieldValueElem.CanAddr() {
+			return nil, errors.New("请使用指针作为变量")
+		}
+
+		return ParseSqlResultTag(fieldValueElem.Addr().Interface())
 	}
 
 	sqlColumn := &ResultColumn{
@@ -134,12 +123,10 @@ func parseSqlResultElement(field reflect.StructField, fieldValue reflect.Value)
 		AESKey:    "",
 		SplitWith: sqlResultDefaultSplitWith,
 
-		ResultTypesAndValues: ResultTypesAndValues{
-			OriginFieldType:  field.Type,
-			OriginFieldValue: fieldValue,
-			FieldTypeElem:    fieldValueTypeElem,
-			FieldValueElem:   fieldValueElem,
-		},
+		OriginFieldType:  field.Type,
+		OriginFieldValue: fieldValue,
+		FieldTypeElem:    fieldValueTypeElem,
+		FieldValueElem:   fieldValueElem,
 	}
 
 	if strutils.IsStringEmpty(sqlResultTag) {

+ 16 - 16
test/sdk_test.go

@@ -359,19 +359,19 @@ func TestRawSqlTemplate(t *testing.T) {
 }
 
 func TestSqlMapping(t *testing.T) {
-	checkSqlMapping(t, &Class{})
-}
-
-func checkSqlMapping(t *testing.T, e any) {
-	sqlMapping, err := sql.ParseSqlMappingTag(e)
+	sqlMapping, err := sql.ParseSqlMappingTag(&Class{})
 	if err != nil {
 		t.Fatal(err)
 	}
 
+	checkSqlMapping(t, sqlMapping)
+}
+
+func checkSqlMapping(t *testing.T, sqlMapping *sql.Mapping) {
 	for fieldName, mappingElement := range sqlMapping.MappingElement {
 		switch element := mappingElement.(type) {
-		case *sql.MappingStruct:
-			checkSqlMapping(t, element.FieldValueElem.Interface())
+		case *sql.Mapping:
+			checkSqlMapping(t, element)
 		case *sql.MappingColumn:
 			if fieldName != "ID" && fieldName != "Name" &&
 				fieldName != "StudentNum" && fieldName != "StudentIDs" &&
@@ -424,19 +424,19 @@ func checkSqlMapping(t *testing.T, e any) {
 }
 
 func TestSqlResult(t *testing.T) {
-	checkSqlResult(t, &Class{})
-}
-
-func checkSqlResult(t *testing.T, e any) {
-	sqlResult, err := sql.ParseSqlResultTag(e)
+	sqlResult, err := sql.ParseSqlResultTag(&Class{})
 	if err != nil {
 		t.Fatal(err)
 	}
 
-	for fieldName, mappingElement := range sqlResult.ResultElement {
-		switch element := mappingElement.(type) {
-		case *sql.ResultStruct:
-			checkSqlResult(t, element.FieldValueElem.Addr().Interface())
+	checkSqlResult(t, sqlResult)
+}
+
+func checkSqlResult(t *testing.T, sqlResult *sql.Result) {
+	for fieldName, resultElement := range sqlResult.ResultElement {
+		switch element := resultElement.(type) {
+		case *sql.Result:
+			checkSqlResult(t, element)
 		case *sql.ResultColumn:
 			if fieldName != "ID" && fieldName != "Name" &&
 				fieldName != "StudentNum" && fieldName != "StudentIDs" &&