Browse Source

完成tag添加

yjp 11 months ago
parent
commit
80ff3bc74f
5 changed files with 113 additions and 35 deletions
  1. 16 14
      README.md
  2. 28 5
      sql/sql_mapping.go
  3. 23 4
      sql/sql_result.go
  4. 4 0
      test/resources.yaml
  5. 42 12
      test/sdk_test.go

+ 16 - 14
README.md

@@ -257,14 +257,15 @@ column和notUpdate两个sqlmapping Tag,column指定了该字段对应的数据
 
 下面是sqlmapping Tag支持的所有Tag:
 
-| Tag         | 说明                                                                            |
-|-------------|-------------------------------------------------------------------------------|
-| -           | 忽略该字段,不进行持久化(不对应任何数据库表列)                                                      |
-| column      | 显式指定该字段对应的数据库表列,如column:foo                                                   |
-| key         | 该列是否作为逻辑键(实际到底哪个字段为键,是由DataContainer定义确定的)使用,如果一个结构的多个字段使用了key,这几个字段将被作为联合键使用 |
-| notUpdate   | 不对该列进行更新操作                                                                    |
-| updateClear | 允许将该列清空为零值                                                                    |
-| aes         | 进行aes加密并传递aes的密钥,密钥长度为32字节,不能包含';'                                            |
+| Tag         | 说明                                                                                   |
+|-------------|--------------------------------------------------------------------------------------|
+| -           | 忽略该字段,不进行持久化(不对应任何数据库表列)                                                             |
+| column      | 显式指定该字段对应的数据库表列,如column:foo                                                          |
+| key         | 该列是否作为逻辑键(实际到底哪个字段为键,是由DataContainer定义确定的)使用,如果一个结构的多个字段使用了key,这几个字段将被作为联合键使用        |
+| notUpdate   | 不对该列进行更新操作                                                                           |
+| updateClear | 允许将该列清空为零值                                                                           |
+| aes         | 进行aes加密并传递aes的密钥,密钥长度为32字节,不能包含';',也不能以'作为开始和结尾字符                                    |
+| joinWith    | 字段如果是[]string,可以指定join使用的分隔符,默认是'::',如果使用特殊字符,如'\n','\t'等,需要使用''包含分隔符,也就是说,分隔符不能使用'' |
 
 ### sqlresult Tag
 
@@ -290,12 +291,13 @@ GraduatedTimeTest字段通过添加'column'和'parseTime'两个Tag,将graduate
 
 下面是sqlresult Tag支持的所有Tag:
 
-| Tag         | 说明                                 |
-|-------------|------------------------------------|
-| -           | 忽略该字段,不进行持久化(不对应任何数据库表列)           |
-| column      | 显式指定该字段对应的数据库表列,如column:foo        |
-| parseTime   | 按照给定的时间格式化字符串格式化时间                 |
-| aes         | 进行aes加密并传递aes的密钥,密钥长度为32字节,不能包含';' |
+| Tag       | 说明                                                                                    |
+|-----------|---------------------------------------------------------------------------------------|
+| -         | 忽略该字段,不进行持久化(不对应任何数据库表列)                                                              |
+| column    | 显式指定该字段对应的数据库表列,如column:foo                                                           |
+| parseTime | 按照给定的时间格式化字符串格式化时间                                                                    |
+| aes       | 进行aes加密并传递aes的密钥,密钥长度为32字节,不能包含';',也不能以'作为开始和结尾字符                                     |
+| splitWith | 字段如果是[]string,可以指定split使用的分隔符,默认是'::',如果使用特殊字符,如'\n','\t'等,需要使用''包含分隔符,也就是说,分隔符不能使用'' |
 
 ### SQL语句执行
 

+ 28 - 5
sql/sql_mapping.go

@@ -2,6 +2,7 @@ package sql
 
 import (
 	"errors"
+	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/iancoleman/strcase"
 	"reflect"
 	"strings"
@@ -9,7 +10,8 @@ import (
 )
 
 const (
-	defaultKeyColumnName               = "id"
+	sqlMappingDefaultKeyColumnName     = "id"
+	sqlMappingDefaultJoinWith          = "::"
 	sqlMappingTagPartSeparator         = ";"
 	sqlMappingTagPartKeyValueSeparator = ":"
 )
@@ -22,6 +24,7 @@ const (
 	sqlMappingNotUpdate   = "notUpdate"
 	sqlMappingUpdateClear = "updateClear"
 	sqlMappingAes         = "aes"
+	sqlMappingJoinWith    = "joinWith"
 )
 
 type Mapping struct {
@@ -80,6 +83,7 @@ type MappingColumn struct {
 	CanUpdate      bool
 	CanUpdateClear bool
 	AESKey         string
+	JoinWith       string
 
 	MappingTypesAndValues
 }
@@ -126,6 +130,7 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 		CanUpdate:      true,
 		CanUpdateClear: false,
 		AESKey:         "",
+		JoinWith:       sqlMappingDefaultJoinWith,
 
 		MappingTypesAndValues: MappingTypesAndValues{
 			OriginFieldType:  field.Type,
@@ -135,7 +140,7 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 		},
 	}
 
-	if sqlColumn.Name == defaultKeyColumnName {
+	if sqlColumn.Name == sqlMappingDefaultKeyColumnName {
 		sqlColumn.IsKey = true
 		sqlColumn.CanUpdate = false
 	}
@@ -153,9 +158,17 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 	if sqlMappingParts != nil || len(sqlMappingParts) != 0 {
 		for _, sqlMappingPart := range sqlMappingParts {
 			sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlMappingPart), sqlMappingTagPartKeyValueSeparator, 2)
+			if sqlPartKeyValue != nil && len(sqlPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlPartKeyValue[1]) {
+				sqlPartKeyValue[1] = strings.Trim(sqlPartKeyValue[1], "'")
+			}
+
 			switch sqlPartKeyValue[0] {
 			case sqlMappingColumn:
-				sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
+				if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
+					return nil, errors.New("column没有赋值列名")
+				}
+
+				sqlColumn.Name = sqlPartKeyValue[1]
 			case sqlMappingKey:
 				sqlColumn.IsKey = true
 				sqlColumn.CanUpdate = false
@@ -164,11 +177,21 @@ func parseSqlMappingElement(field reflect.StructField, fieldValue reflect.Value)
 			case sqlMappingUpdateClear:
 				sqlColumn.CanUpdateClear = true
 			case sqlMappingAes:
-				if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
+				if len(sqlPartKeyValue[1]) != 32 {
 					return nil, errors.New("AES密钥长度应该为32个字节")
 				}
 
-				sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
+				sqlColumn.AESKey = sqlPartKeyValue[1]
+			case sqlMappingJoinWith:
+				if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
+					return nil, errors.New(sqlMappingJoinWith + "没有赋值分隔符")
+				}
+
+				if fieldValueTypeElem.Kind() != reflect.Slice || fieldValueTypeElem.Elem().Kind() != reflect.String {
+					return nil, errors.New(sqlMappingJoinWith + "应该添加在[]string字段上")
+				}
+
+				sqlColumn.JoinWith = sqlPartKeyValue[1]
 			default:
 				continue
 			}

+ 23 - 4
sql/sql_result.go

@@ -2,6 +2,7 @@ package sql
 
 import (
 	"errors"
+	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/iancoleman/strcase"
 	"reflect"
 	"strings"
@@ -9,6 +10,7 @@ import (
 )
 
 const (
+	sqlResultDefaultSplitWith         = "::"
 	sqlResultTagPartSeparator         = ";"
 	sqlResultTagPartKeyValueSeparator = ":"
 )
@@ -19,6 +21,7 @@ const (
 	sqlResultColumn    = "column"
 	sqlResultParseTime = "parseTime"
 	sqlResultAes       = "aes"
+	sqlResultSplitWith = "splitWith"
 )
 
 type Result struct {
@@ -75,6 +78,7 @@ type ResultColumn struct {
 	Name      string
 	ParseTime string
 	AESKey    string
+	SplitWith string
 
 	ResultTypesAndValues
 }
@@ -119,6 +123,7 @@ func parseSqlResultElement(field reflect.StructField, fieldValue reflect.Value)
 		Name:      strcase.ToSnake(field.Name),
 		ParseTime: "",
 		AESKey:    "",
+		SplitWith: sqlResultDefaultSplitWith,
 
 		ResultTypesAndValues: ResultTypesAndValues{
 			OriginFieldType:  field.Type,
@@ -141,17 +146,31 @@ func parseSqlResultElement(field reflect.StructField, fieldValue reflect.Value)
 	if sqlResultParts != nil || len(sqlResultParts) != 0 {
 		for _, sqlResultPart := range sqlResultParts {
 			sqlPartKeyValue := strings.SplitN(strings.TrimSpace(sqlResultPart), sqlResultTagPartKeyValueSeparator, 2)
+			if sqlPartKeyValue != nil && len(sqlPartKeyValue) == 2 && strutils.IsStringNotEmpty(sqlPartKeyValue[1]) {
+				sqlPartKeyValue[1] = strings.Trim(sqlPartKeyValue[1], "'")
+			}
+
 			switch sqlPartKeyValue[0] {
 			case sqlResultColumn:
-				sqlColumn.Name = strings.TrimSpace(sqlPartKeyValue[1])
+				sqlColumn.Name = sqlPartKeyValue[1]
 			case sqlResultParseTime:
-				sqlColumn.ParseTime = strings.TrimSpace(sqlPartKeyValue[1])
+				sqlColumn.ParseTime = sqlPartKeyValue[1]
 			case sqlResultAes:
-				if len(strings.TrimSpace(sqlPartKeyValue[1])) != 32 {
+				if len(sqlPartKeyValue[1]) != 32 {
 					return nil, errors.New("AES密钥长度应该为32个字节")
 				}
 
-				sqlColumn.AESKey = strings.TrimSpace(sqlPartKeyValue[1])
+				sqlColumn.AESKey = sqlPartKeyValue[1]
+			case sqlResultSplitWith:
+				if strutils.IsStringEmpty(sqlPartKeyValue[1]) {
+					return nil, errors.New(sqlResultDefaultSplitWith + "没有赋值分隔符")
+				}
+
+				if fieldValueTypeElem.Kind() != reflect.Slice || fieldValueTypeElem.Elem().Kind() != reflect.String {
+					return nil, errors.New(sqlResultDefaultSplitWith + "应该添加在[]string字段上")
+				}
+
+				sqlColumn.SplitWith = sqlPartKeyValue[1]
 			default:
 				continue
 			}

+ 4 - 0
test/resources.yaml

@@ -41,6 +41,10 @@ spec:
         type: integer
         comment: 学生数量
         default: 60
+      - name: student_ids
+        type: text
+        comment: 学生ID
+        not_null: true
       - name: graduated_time
         type: "timestamp with time zone"
         comment: 毕业时间

+ 42 - 12
test/sdk_test.go

@@ -32,6 +32,7 @@ type Class struct {
 	Name          string `sqlmapping:"updateClear;aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;" sqlresult:"aes:@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L;"`
 	StudentNum    int    `sqlmapping:"column:student_num;notUpdate;" sqlresult:"column:student_num_alias"`
 	GraduatedTime *time.Time
+	StudentIDs    []string `sqlmapping:"column:student_ids;joinWith:'\n'" sqlresult:"column:student_ids;splitWith:'\n'"`
 	TimeFields
 	Ignored string `sqlmapping:"-" sqlresult:"-"`
 	*GraduatedTimeTestField
@@ -75,6 +76,7 @@ func TestBasic(t *testing.T) {
 	classID := strutils.SimpleUUID()
 	className := strutils.SimpleUUID()
 	studentNum := rand.Int31n(100)
+	studentIDs := []string{strutils.SimpleUUID(), strutils.SimpleUUID()}
 	now := time.Now()
 
 	insertExecuteParams, err := sql_tpl.InsertExecuteParams{
@@ -82,6 +84,7 @@ func TestBasic(t *testing.T) {
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
+			Add("student_ids", strings.Join(studentIDs, "\n")).
 			Add("graduated_time", now).
 			Add("created_time", now).
 			Add("last_updated_time", now),
@@ -168,8 +171,10 @@ func TestRawSqlTemplate(t *testing.T) {
 	classID := strutils.SimpleUUID()
 	className := strutils.SimpleUUID()
 	studentNum := rand.Int31n(100)
+	studentIDs := []string{strutils.SimpleUUID(), strutils.SimpleUUID()}
 	newClassName := strutils.SimpleUUID()
 	newStudentNum := rand.Int31n(100)
+	newStudentIDs := []string{strutils.SimpleUUID(), strutils.SimpleUUID()}
 
 	now := time.Now()
 
@@ -178,6 +183,7 @@ func TestRawSqlTemplate(t *testing.T) {
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
+			Add("student_ids", strings.Join(studentIDs, "\n")).
 			Add("graduated_time", now).
 			Add("created_time", now).
 			Add("last_updated_time", now),
@@ -198,7 +204,8 @@ func TestRawSqlTemplate(t *testing.T) {
 		TableName: tableName,
 		TableRows: sql_tpl.NewTableRows().
 			Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
-			Add("student_num", newStudentNum),
+			Add("student_num", newStudentNum).
+			Add("student_ids", strings.Join(newStudentIDs, "\n")),
 		Conditions: sql_tpl.NewConditions().Equal("id", classID),
 	}.Map()
 	if err != nil {
@@ -299,6 +306,7 @@ func TestRawSqlTemplate(t *testing.T) {
 	if classes[0].ID != classID ||
 		classes[0].Name != className ||
 		classes[0].StudentNum != int(studentNum) ||
+		strings.Join(classes[0].StudentIDs, "\n") != strings.Join(studentIDs, "\n") ||
 		classes[0].GraduatedTime.Format(graduatedTimeLayout) != now.Format(graduatedTimeLayout) ||
 		classes[0].CreatedTime.Format(createdTimeLayout) != now.Format(createdTimeLayout) ||
 		classes[0].LastUpdatedTime.Format(lastUpdatedTimeLayout) != now.Format(lastUpdatedTimeLayout) {
@@ -337,6 +345,7 @@ func TestRawSqlTemplate(t *testing.T) {
 	if classes[0].ID != classID ||
 		classes[0].Name != newClassName ||
 		classes[0].StudentNum != int(newStudentNum) ||
+		strings.Join(classes[0].StudentIDs, "\n") != strings.Join(newStudentIDs, "\n") ||
 		classes[0].GraduatedTime.Format(graduatedTimeLayout) != now.Format(graduatedTimeLayout) ||
 		classes[0].CreatedTime.Format(createdTimeLayout) != now.Format(createdTimeLayout) ||
 		classes[0].LastUpdatedTime.Format(lastUpdatedTimeLayout) != now.Format(lastUpdatedTimeLayout) {
@@ -365,18 +374,20 @@ func checkSqlMapping(t *testing.T, e any) {
 			checkSqlMapping(t, element.FieldValueElem.Addr().Interface())
 		case *sql.MappingColumn:
 			if fieldName != "ID" && fieldName != "Name" &&
-				fieldName != "StudentNum" && fieldName != "GraduatedTime" &&
-				fieldName != "CreatedTime" && fieldName != "LastUpdatedTime" {
+				fieldName != "StudentNum" && fieldName != "StudentIDs" &&
+				fieldName != "GraduatedTime" && fieldName != "CreatedTime" &&
+				fieldName != "LastUpdatedTime" {
 				t.Fatal("字段名不正确")
 			}
 
 			if element.Name != "id" && element.Name != "name" &&
-				element.Name != "student_num" && element.Name != "graduated_time" &&
-				element.Name != "created_time" && element.Name != "last_updated_time" {
+				element.Name != "student_num" && element.Name != "student_ids" &&
+				element.Name != "graduated_time" && element.Name != "created_time" &&
+				element.Name != "last_updated_time" {
 				t.Fatal("列名不正确")
 			}
 
-			if element.Name != strcase.ToSnake(fieldName) {
+			if element.Name != "student_ids" && element.Name != strcase.ToSnake(fieldName) {
 				t.Fatal("列名不正确")
 			}
 
@@ -400,6 +411,12 @@ func checkSqlMapping(t *testing.T, e any) {
 					t.Fatal("student_num字段Tag不正确")
 				}
 			}
+
+			if element.Name == "student_ids" {
+				if element.JoinWith != "\n" {
+					t.Fatal("student_ids字段Tag不正确")
+				}
+			}
 		default:
 			t.Fatal("不支持的元素类型")
 		}
@@ -422,21 +439,22 @@ func checkSqlResult(t *testing.T, e any) {
 			checkSqlResult(t, element.FieldValueElem.Addr().Interface())
 		case *sql.ResultColumn:
 			if fieldName != "ID" && fieldName != "Name" &&
-				fieldName != "StudentNum" && fieldName != "GraduatedTime" &&
-				fieldName != "CreatedTime" && fieldName != "LastUpdatedTime" &&
-				fieldName != "GraduatedTimeTest" {
+				fieldName != "StudentNum" && fieldName != "StudentIDs" &&
+				fieldName != "GraduatedTime" && fieldName != "CreatedTime" &&
+				fieldName != "LastUpdatedTime" && fieldName != "GraduatedTimeTest" {
 				t.Fatal("字段名不正确")
 			}
 
 			if element.Name != "id" && element.Name != "name" &&
-				element.Name != "student_num_alias" && element.Name != "graduated_time" &&
-				element.Name != "created_time" && element.Name != "last_updated_time" &&
-				element.Name != "graduated_time_test" {
+				element.Name != "student_num_alias" && element.Name != "student_ids" &&
+				element.Name != "graduated_time" && element.Name != "created_time" &&
+				element.Name != "last_updated_time" && element.Name != "graduated_time_test" {
 				t.Fatal("列名不正确")
 			}
 
 			if element.Name != "student_num_alias" &&
 				element.Name != "graduated_time" &&
+				element.Name != "student_ids" &&
 				element.Name != strcase.ToSnake(fieldName) {
 				t.Fatal("列名不正确")
 			}
@@ -463,6 +481,12 @@ func checkSqlResult(t *testing.T, e any) {
 				}
 			}
 
+			if element.Name == "student_ids" {
+				if element.SplitWith != "\n" {
+					t.Fatal("student_ids字段Tag不正确")
+				}
+			}
+
 			if element.Name == "graduate_time" {
 				if strutils.IsStringEmpty(element.ParseTime) ||
 					strutils.IsStringNotEmpty(element.AESKey) {
@@ -477,8 +501,10 @@ func TestSql(t *testing.T) {
 	classID := strutils.SimpleUUID()
 	className := strutils.SimpleUUID()
 	studentNum := rand.Int31n(100)
+	studentIDs := []string{strutils.SimpleUUID(), strutils.SimpleUUID()}
 	newClassName := strutils.SimpleUUID()
 	newStudentNum := rand.Int31n(100)
+	newStudentIDs := []string{strutils.SimpleUUID(), strutils.SimpleUUID()}
 	now := time.Now()
 	newNow := time.Now()
 
@@ -504,6 +530,7 @@ func TestSql(t *testing.T) {
 		IDField:       IDField{ID: classID},
 		Name:          className,
 		StudentNum:    int(studentNum),
+		StudentIDs:    studentIDs,
 		GraduatedTime: &newNow,
 		Ignored:       "",
 	}
@@ -512,6 +539,7 @@ func TestSql(t *testing.T) {
 		IDField:       IDField{ID: classID},
 		Name:          newClassName,
 		StudentNum:    int(newStudentNum),
+		StudentIDs:    newStudentIDs,
 		GraduatedTime: &newNow,
 		Ignored:       "",
 	}
@@ -699,6 +727,7 @@ func TestSql(t *testing.T) {
 	if queryClasses[0].ID != classID ||
 		queryClasses[0].Name != className ||
 		queryClasses[0].StudentNum != 0 ||
+		strings.Join(queryClasses[0].StudentIDs, "\n") != strings.Join(studentIDs, "\n") ||
 		!queryClasses[0].GraduatedTime.IsZero() ||
 		(queryClasses[0].CreatedTime != nil && !queryClasses[0].CreatedTime.IsZero()) ||
 		!queryClasses[0].LastUpdatedTime.IsZero() {
@@ -725,6 +754,7 @@ func TestSql(t *testing.T) {
 	if queryClass.ID != classID ||
 		queryClass.Name != className ||
 		queryClass.StudentNum != 0 ||
+		strings.Join(queryClasses[0].StudentIDs, "\n") != strings.Join(studentIDs, "\n") ||
 		!queryClass.GraduatedTime.IsZero() ||
 		(queryClass.CreatedTime != nil && !queryClass.CreatedTime.IsZero()) ||
 		!queryClass.LastUpdatedTime.IsZero() {