Browse Source

添加tag

yjp 11 months ago
parent
commit
ac37d6e5a4
3 changed files with 94 additions and 24 deletions
  1. 34 11
      sql/parse_table_row.go
  2. 51 9
      sql/sql.go
  3. 9 4
      test/sdk_test.go

+ 34 - 11
sql/parse_table_row.go

@@ -122,12 +122,13 @@ func formOutputEntity(tableRow map[string]any, outputEntity any) error {
 			}
 
 			// 构造结构字段,如果结构字段是指针且为nil,需要构造元素
-			fieldValue := element.FieldValueElem
-			outputKind := reflectutils.GroupValueKind(fieldValue)
+			fieldTypeElem := element.FieldTypeElem
+			fieldValueElem := element.FieldValueElem
+			outputKind := reflectutils.GroupValueKind(fieldValueElem)
 
 			switch outputKind {
 			case reflect.Bool:
-				err := reflectutils.AssignBoolValue(tableRowValue, fieldValue)
+				err := reflectutils.AssignBoolValue(tableRowValue, fieldValueElem)
 				if err != nil {
 					return err
 				}
@@ -152,41 +153,63 @@ func formOutputEntity(tableRow map[string]any, outputEntity any) error {
 					}
 				}
 
-				err = reflectutils.AssignStringValue(strValue, fieldValue)
+				err = reflectutils.AssignStringValue(strValue, fieldValueElem)
 				if err != nil {
 					return err
 				}
 			case reflect.Int64:
-				err := reflectutils.AssignIntValue(tableRowValue, fieldValue)
+				err := reflectutils.AssignIntValue(tableRowValue, fieldValueElem)
 				if err != nil {
 					return err
 				}
 			case reflect.Uint64:
-				err := reflectutils.AssignUintValue(tableRowValue, fieldValue)
+				err := reflectutils.AssignUintValue(tableRowValue, fieldValueElem)
 				if err != nil {
 					return err
 				}
 			case reflect.Float64:
-				err := reflectutils.AssignFloatValue(tableRowValue, fieldValue)
+				err := reflectutils.AssignFloatValue(tableRowValue, fieldValueElem)
 				if err != nil {
 					return err
 				}
 			case reflect.Struct:
-				if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
+				if fieldValueElem.Type() == reflect.TypeOf(time.Time{}) {
 					parsedTime, err := parseSqlTableRowTimeStr(tableRowValue.(string))
 					if err != nil {
 						return err
 					}
 
-					fieldValue.Set(reflect.ValueOf(parsedTime))
+					fieldValueElem.Set(reflect.ValueOf(parsedTime))
 					continue
 				}
 
 				return fmt.Errorf("字段: %s 列: %s 不支持的类型: %s",
-					fieldName, element.Name, reflect.TypeOf(tableRowValue).String())
+					fieldName, element.Name, fieldTypeElem.String())
+			case reflect.Slice:
+				if fieldTypeElem.Elem().Kind() != reflect.String {
+					return errors.New("slice仅支持[]string")
+				}
+
+				strValue := tableRowValue.(string)
+
+				strParts := strings.Split(strValue, element.SplitWith)
+				if strParts == nil || len(strParts) == 0 {
+					return nil
+				}
+
+				valSlice := fieldValueElem
+				if valSlice.IsNil() {
+					valSlice = reflect.MakeSlice(fieldTypeElem, 0, 0)
+				}
+
+				for _, strPart := range strParts {
+					valSlice = reflect.Append(valSlice, reflect.ValueOf(strPart))
+				}
+
+				fieldValueElem.Set(valSlice)
 			default:
 				return fmt.Errorf("字段: %s 列: %s 不支持的类型: %s",
-					fieldName, element.Name, reflect.TypeOf(tableRowValue).String())
+					fieldName, element.Name, fieldTypeElem.String())
 			}
 		default:
 			return errors.New("不支持的元素类型")

+ 51 - 9
sql/sql.go

@@ -84,14 +84,32 @@ func formInsertTableRow(e any, tableRows *sql_tpl.TableRows) error {
 
 			// 有值取值,没有值构造零值
 			value := reflect.Zero(fieldType).Interface()
-			if !element.FieldValueElem.IsZero() {
-				value = element.FieldValueElem.Interface()
-			}
+			if fieldType.Kind() != reflect.Slice {
+				if !element.FieldValueElem.IsZero() {
+					value = element.FieldValueElem.Interface()
+				}
 
-			// 自动添加创建时间和更新时间
-			if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
-				fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
-				value = now
+				// 自动添加创建时间和更新时间
+				if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
+					fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
+					value = now
+				}
+			} else {
+				sliceElementType := fieldType.Elem()
+				if sliceElementType.Kind() != reflect.String {
+					return errors.New("slice仅支持[]string")
+				}
+
+				if element.FieldValueElem.Len() == 0 {
+					continue
+				}
+
+				strValues := make([]string, 0, 0)
+				for i := 0; i < element.FieldValueElem.Len(); i++ {
+					strValues = append(strValues, element.FieldValueElem.Index(i).String())
+				}
+
+				value = strings.Join(strValues, element.JoinWith)
 			}
 
 			var opts []sql_tpl.AfterParsedStrValueOption
@@ -254,8 +272,32 @@ func formUpdateTableRowsAndConditions(e any, tableRows *sql_tpl.TableRows, condi
 			fieldType := element.FieldTypeElem
 
 			value := reflect.Zero(fieldType).Interface()
-			if !element.FieldValueElem.IsZero() {
-				value = element.FieldValueElem.Interface()
+			if fieldType.Kind() != reflect.Slice {
+				if !element.FieldValueElem.IsZero() {
+					value = element.FieldValueElem.Interface()
+				}
+
+				// 自动添加创建时间和更新时间
+				if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
+					fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
+					value = now
+				}
+			} else {
+				sliceElementType := fieldType.Elem()
+				if sliceElementType.Kind() != reflect.String {
+					return errors.New("slice仅支持[]string")
+				}
+
+				if element.FieldValueElem.Len() == 0 {
+					continue
+				}
+
+				strValues := make([]string, 0, 0)
+				for i := 0; i < element.FieldValueElem.Len(); i++ {
+					strValues = append(strValues, element.FieldValueElem.Index(i).String())
+				}
+
+				value = strings.Join(strValues, element.JoinWith)
 			}
 
 			if fieldName == lastUpdatedTimeFieldName &&

+ 9 - 4
test/sdk_test.go

@@ -214,7 +214,7 @@ func TestRawSqlTemplate(t *testing.T) {
 
 	queryExecuteParams, err := sql_tpl.QueryExecuteParams{
 		TableName:     tableName,
-		SelectColumns: []string{"id", "name", "student_num as student_num_alias", "graduated_time", "created_time", "last_updated_time"},
+		SelectColumns: []string{"id", "name", "student_num as student_num_alias", "student_ids", "graduated_time", "created_time", "last_updated_time"},
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
 			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
@@ -228,7 +228,7 @@ func TestRawSqlTemplate(t *testing.T) {
 
 	newQueryExecuteParams, err := sql_tpl.QueryExecuteParams{
 		TableName:     tableName,
-		SelectColumns: []string{"id", "name", "student_num as student_num_alias", "graduated_time", "created_time", "last_updated_time"},
+		SelectColumns: []string{"id", "name", "student_num as student_num_alias", "student_ids", "graduated_time", "created_time", "last_updated_time"},
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
 			Equal("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
@@ -513,6 +513,7 @@ func TestSql(t *testing.T) {
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
+			Add("student_ids", strings.Join(studentIDs, "\n")).
 			Add("graduated_time", now).
 			Add("created_time", now).
 			Add("last_updated_time", now),
@@ -591,6 +592,7 @@ func TestSql(t *testing.T) {
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
+			Add("student_ids", strings.Join(studentIDs, "\n")).
 			Add("graduated_time", now).
 			Add("created_time", now).
 			Add("last_updated_time", now),
@@ -603,6 +605,7 @@ func TestSql(t *testing.T) {
 		TableName: tableName,
 		TableRows: sql_tpl.NewTableRows().Add("id", classID).
 			Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
+			Add("student_ids", strings.Join(newStudentIDs, "\n")).
 			Add("student_num", newStudentNum),
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID),
@@ -651,6 +654,7 @@ func TestSql(t *testing.T) {
 			TableRows: sql_tpl.NewTableRows().Add("id", classID).
 				Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 				Add("student_num", studentNum).
+				Add("student_ids", strings.Join(studentIDs, "\n")).
 				Add("graduated_time", now).
 				Add("created_time", now).
 				Add("last_updated_time", now),
@@ -663,7 +667,8 @@ func TestSql(t *testing.T) {
 			TableName: tableName,
 			TableRows: sql_tpl.NewTableRows().Add("id", classID).
 				Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
-				Add("student_num", newStudentNum),
+				Add("student_num", newStudentNum).
+				Add("student_ids", strings.Join(newStudentIDs, "\n")),
 			Conditions: sql_tpl.NewConditions().
 				Equal("id", classID),
 		})
@@ -703,7 +708,7 @@ func TestSql(t *testing.T) {
 
 	tableRows, totalCount, err := sql.Query(sdk.GetInstance(), &sql_tpl.QueryExecuteParams{
 		TableName:     tableName,
-		SelectColumns: []string{"id", "name"},
+		SelectColumns: []string{"id", "name", "student_ids"},
 		Conditions: sql_tpl.NewConditions().
 			Equal("id", classID).
 			Equal("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).