Pārlūkot izejas kodu

修改bug,添加批量插入接口

yjp 11 mēneši atpakaļ
vecāks
revīzija
61a414ba43
5 mainītis faili ar 340 papildinājumiem un 66 dzēšanām
  1. 4 2
      sql/parse_result.go
  2. 96 20
      sql/sql.go
  3. 69 17
      sql/sql_tpl/sql_tpl.go
  4. 18 18
      sql/sql_tpl/table_row.go
  5. 153 9
      test/sdk_test.go

+ 4 - 2
sql/parse_result.go

@@ -34,8 +34,6 @@ func ParseSqlResult[T any](input any) (T, error) {
 	typeCheckErr := errors.New("可以接受的类型为struct, *struct, []struct, []*struct")
 	outputType := reflect.TypeOf(zero)
 
-	fmt.Println("Output Type:", outputType.String())
-
 	if outputType.Kind() != reflect.Struct && outputType.Kind() != reflect.Ptr && outputType.Kind() != reflect.Slice {
 		return zero, typeCheckErr
 	} else if outputType.Kind() == reflect.Ptr && outputType.Elem().Kind() != reflect.Struct {
@@ -115,6 +113,10 @@ func formOutputEntity(tableRow sdk.SqlResult, sqlResult *Result) error {
 				continue
 			}
 
+			if tableRowValue == nil {
+				continue
+			}
+
 			// 构造结构字段,如果结构字段是指针且为nil,需要构造元素
 			fieldTypeElem := element.FieldTypeElem
 			fieldValueElem := element.FieldValueElem

+ 96 - 20
sql/sql.go

@@ -29,30 +29,84 @@ func InsertEntity[T any](executor Executor, tableName string, e T) error {
 		return errors.New("没有传递表名")
 	}
 
-	if reflect.TypeOf(e) == nil {
+	entityType := reflect.TypeOf(e)
+
+	if entityType == nil {
 		return errors.New("没有传递实体")
 	}
 
-	sqlMapping, err := ParseSqlMappingTag(e)
-	if err != nil {
-		return err
-	}
+	typeCheckErr := errors.New("可以接受的类型为struct, *struct, []struct, []*struct")
 
-	tableRows := sql_tpl.NewTableRows()
-	err = formInsertTableRow(sqlMapping, tableRows)
-	if err != nil {
-		return err
+	if entityType.Kind() != reflect.Struct && entityType.Kind() != reflect.Ptr && entityType.Kind() != reflect.Slice {
+		return typeCheckErr
+	} else if entityType.Kind() == reflect.Ptr && entityType.Elem().Kind() != reflect.Struct {
+		return typeCheckErr
+	} else if entityType.Kind() == reflect.Slice &&
+		(entityType.Elem().Kind() != reflect.Struct && entityType.Elem().Kind() != reflect.Ptr) {
+		return typeCheckErr
+	} else if entityType.Kind() == reflect.Slice &&
+		entityType.Elem().Kind() == reflect.Ptr && entityType.Elem().Elem().Kind() != reflect.Struct {
+		return typeCheckErr
 	}
 
-	executeParamsMap, err := sql_tpl.InsertExecuteParams{
-		TableName: tableName,
-		TableRows: tableRows,
-	}.Map()
-	if err != nil {
-		return err
+	var executeParamsMap map[string]any
+
+	if entityType.Kind() == reflect.Struct || entityType.Kind() == reflect.Ptr {
+		sqlMapping, err := ParseSqlMappingTag(reflect.ValueOf(e).Interface())
+		if err != nil {
+			return err
+		}
+
+		tableRows := sql_tpl.NewTableRow()
+		err = formInsertTableRow(sqlMapping, tableRows)
+		if err != nil {
+			return err
+		}
+
+		innerExecuteParamsMap, err := sql_tpl.InsertExecuteParams{
+			TableName: tableName,
+			TableRow:  tableRows,
+		}.Map()
+		if err != nil {
+			return err
+		}
+
+		executeParamsMap = innerExecuteParamsMap
+	} else {
+		entitySliceValue := reflect.ValueOf(e)
+		if entitySliceValue.Len() == 0 {
+			return nil
+		}
+
+		tableRowsBatch := make([]sql_tpl.TableRow, 0)
+
+		for i := 0; i < entitySliceValue.Len(); i++ {
+			sqlMapping, err := ParseSqlMappingTag(entitySliceValue.Index(i).Interface())
+			if err != nil {
+				return err
+			}
+
+			tableRows := sql_tpl.NewTableRow()
+			err = formInsertTableRow(sqlMapping, tableRows)
+			if err != nil {
+				return err
+			}
+
+			tableRowsBatch = append(tableRowsBatch, *tableRows)
+		}
+
+		innerExecuteParamsMap, err := sql_tpl.InsertBatchExecuteParams{
+			TableName:      tableName,
+			TableRowsBatch: tableRowsBatch,
+		}.Map()
+		if err != nil {
+			return err
+		}
+
+		executeParamsMap = innerExecuteParamsMap
 	}
 
-	_, err = executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap)
+	_, err := executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap)
 	if err != nil {
 		if strings.Contains(err.Error(), "SQLSTATE 23505") {
 			return sdk.ErrDBRecordHasExist
@@ -64,7 +118,7 @@ func InsertEntity[T any](executor Executor, tableName string, e T) error {
 	return nil
 }
 
-func formInsertTableRow(sqlMapping *Mapping, tableRows *sql_tpl.TableRows) error {
+func formInsertTableRow(sqlMapping *Mapping, tableRows *sql_tpl.TableRow) error {
 	now := time.Now()
 
 	for fieldName, mappingElement := range sqlMapping.MappingElement {
@@ -216,7 +270,7 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 		return err
 	}
 
-	tableRows := sql_tpl.NewTableRows()
+	tableRows := sql_tpl.NewTableRow()
 	conditions := sql_tpl.NewConditions()
 	err = formUpdateTableRowsAndConditions(sqlMapping, tableRows, conditions)
 	if err != nil {
@@ -225,7 +279,7 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 
 	executeParamsMap, err := sql_tpl.UpdateExecuteParams{
 		TableName:  tableName,
-		TableRows:  tableRows,
+		TableRow:   tableRows,
 		Conditions: conditions,
 	}.Map()
 	if err != nil {
@@ -240,7 +294,7 @@ func UpdateEntity[T any](executor Executor, tableName string, e T) error {
 	return nil
 }
 
-func formUpdateTableRowsAndConditions(sqlMapping *Mapping, tableRows *sql_tpl.TableRows, conditions *sql_tpl.Conditions) error {
+func formUpdateTableRowsAndConditions(sqlMapping *Mapping, tableRows *sql_tpl.TableRow, conditions *sql_tpl.Conditions) error {
 	now := time.Now()
 
 	for fieldName, mappingElement := range sqlMapping.MappingElement {
@@ -342,6 +396,28 @@ func Insert(executor Executor, executeParams *sql_tpl.InsertExecuteParams) error
 	return nil
 }
 
+func InsertBatch(executor Executor, executeParams *sql_tpl.InsertBatchExecuteParams) error {
+	if executor == nil {
+		return errors.New("没有传递执行器")
+	}
+
+	if executeParams == nil {
+		return errors.New("没有传递执行参数")
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return err
+	}
+
+	_, err = executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func Delete(executor Executor, executeParams *sql_tpl.DeleteExecuteParams) error {
 	if executor == nil {
 		return errors.New("没有传递执行器")

+ 69 - 17
sql/sql_tpl/sql_tpl.go

@@ -6,35 +6,87 @@ const InsertTpl = `
 INSERT INTO
     {{ .table_name }} ({{ .columns | join "," }})
 VALUES
-    ({{ .values | join "," }})
+{{- $valuesClauses := list }}
+{{- range .values_list }}
+{{- $valuesClause := printf "(%s)" ( . | join "," ) }}
+{{- $valuesClauses = append $valuesClauses $valuesClause }}
+{{- end }}
+    {{ $valuesClauses | join "," }}
 `
 
 type InsertExecuteParams struct {
 	TableName string
-	*TableRows
+	*TableRow
 }
 
 func (params InsertExecuteParams) Map() (map[string]any, error) {
-	if params.TableRows == nil {
+	if params.TableRow == nil {
 		return nil, nil
 	}
 
-	if params.TableRows.err != nil {
-		return nil, params.TableRows.err
+	if params.TableRow.err != nil {
+		return nil, params.TableRow.err
 	}
 
 	columns := make([]string, 0)
 	values := make([]any, 0)
 
-	for _, row := range params.TableRows.Rows {
-		columns = append(columns, row.Column)
-		values = append(values, row.Value)
+	for _, cv := range params.TableRow.columnValues {
+		columns = append(columns, cv.column)
+		values = append(values, cv.value)
 	}
 
 	return map[string]any{
-		"table_name": params.TableName,
-		"columns":    columns,
-		"values":     values,
+		"table_name":  params.TableName,
+		"columns":     columns,
+		"values_list": []any{values},
+	}, nil
+}
+
+type InsertBatchExecuteParams struct {
+	TableName      string
+	TableRowsBatch []TableRow
+}
+
+func (params InsertBatchExecuteParams) Map() (map[string]any, error) {
+	if params.TableRowsBatch == nil || len(params.TableRowsBatch) == 0 {
+		return nil, nil
+	}
+
+	columns := make([]string, 0)
+	for _, cv := range params.TableRowsBatch[0].columnValues {
+		columns = append(columns, cv.column)
+	}
+
+	valuesList := make([]any, 0)
+
+	for _, tableRows := range params.TableRowsBatch {
+		if tableRows.err != nil {
+			return nil, tableRows.err
+		}
+
+		if len(columns) != len(tableRows.columnValues) {
+			return nil, errors.New("列数不匹配,保证每个TableRow的Add数量一致")
+		}
+
+		columnAndValueMap := make(map[string]any, 0)
+
+		for _, cv := range tableRows.columnValues {
+			columnAndValueMap[cv.column] = cv.value
+		}
+
+		values := make([]any, len(columnAndValueMap))
+		for _, column := range columns {
+			values = append(values, columnAndValueMap[column])
+		}
+
+		valuesList = append(valuesList, values)
+	}
+
+	return map[string]any{
+		"table_name":  params.TableName,
+		"columns":     columns,
+		"values_list": valuesList,
 	}, nil
 }
 
@@ -76,22 +128,22 @@ WHERE
 
 type UpdateExecuteParams struct {
 	TableName string
-	*TableRows
+	*TableRow
 	*Conditions
 }
 
 func (params UpdateExecuteParams) Map() (map[string]any, error) {
-	if params.TableRows == nil {
+	if params.TableRow == nil {
 		return nil, nil
 	}
 
-	if params.TableRows.err != nil {
-		return nil, params.TableRows.err
+	if params.TableRow.err != nil {
+		return nil, params.TableRow.err
 	}
 
 	setList := make([]string, 0)
-	for _, row := range params.TableRows.Rows {
-		setList = append(setList, row.Column+" = "+row.Value)
+	for _, cv := range params.TableRow.columnValues {
+		setList = append(setList, cv.column+" = "+cv.value)
 	}
 
 	conditions := make([]string, 0)

+ 18 - 18
sql/sql_tpl/table_row.go

@@ -1,36 +1,36 @@
 package sql_tpl
 
-type TableRows struct {
-	Rows []TableRow
-	err  error
+type TableRow struct {
+	columnValues []columnValue
+	err          error
 }
 
-type TableRow struct {
-	Column string
-	Value  string
+type columnValue struct {
+	column string
+	value  string
 }
 
-func NewTableRows() *TableRows {
-	return &TableRows{
-		Rows: make([]TableRow, 0),
+func NewTableRow() *TableRow {
+	return &TableRow{
+		columnValues: make([]columnValue, 0),
 	}
 }
 
-func (tableRows *TableRows) Add(column string, value any, opts ...AfterParsedStrValueOption) *TableRows {
-	if tableRows.err != nil {
-		return tableRows
+func (tableRow *TableRow) Add(column string, value any, opts ...AfterParsedStrValueOption) *TableRow {
+	if tableRow.err != nil {
+		return tableRow
 	}
 
 	parsedValue, err := parseValue(value, opts...)
 	if err != nil {
-		tableRows.err = err
-		return tableRows
+		tableRow.err = err
+		return tableRow
 	}
 
-	tableRows.Rows = append(tableRows.Rows, TableRow{
-		Column: column,
-		Value:  parsedValue,
+	tableRow.columnValues = append(tableRow.columnValues, columnValue{
+		column: column,
+		value:  parsedValue,
 	})
 
-	return tableRows
+	return tableRow
 }

+ 153 - 9
test/sdk_test.go

@@ -81,7 +81,7 @@ func TestBasic(t *testing.T) {
 
 	insertExecuteParams, err := sql_tpl.InsertExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().Add("id", classID).
+		TableRow: sql_tpl.NewTableRow().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("student_ids", strings.Join(studentIDs, "\n")).
@@ -180,7 +180,7 @@ func TestRawSqlTemplate(t *testing.T) {
 
 	insertExecuteParams, err := sql_tpl.InsertExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().Add("id", classID).
+		TableRow: sql_tpl.NewTableRow().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("student_ids", strings.Join(studentIDs, "\n")).
@@ -202,7 +202,7 @@ func TestRawSqlTemplate(t *testing.T) {
 
 	updateExecuteParams, err := sql_tpl.UpdateExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().
+		TableRow: sql_tpl.NewTableRow().
 			Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", newStudentNum).
 			Add("student_ids", strings.Join(newStudentIDs, "\n")),
@@ -508,7 +508,7 @@ func TestSql(t *testing.T) {
 
 	insertExecuteParams, err := sql_tpl.InsertExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().Add("id", classID).
+		TableRow: sql_tpl.NewTableRow().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("student_ids", strings.Join(studentIDs, "\n")).
@@ -530,7 +530,7 @@ func TestSql(t *testing.T) {
 		Name:          className,
 		StudentNum:    int(studentNum),
 		StudentIDs:    studentIDs,
-		GraduatedTime: &newNow,
+		GraduatedTime: &now,
 		Ignored:       "",
 	}
 
@@ -584,7 +584,7 @@ func TestSql(t *testing.T) {
 
 	err = sql.Insert(sdk.GetInstance(), &sql_tpl.InsertExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().Add("id", classID).
+		TableRow: sql_tpl.NewTableRow().Add("id", classID).
 			Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_num", studentNum).
 			Add("student_ids", strings.Join(studentIDs, "\n")).
@@ -598,7 +598,7 @@ func TestSql(t *testing.T) {
 
 	err = sql.Update(sdk.GetInstance(), &sql_tpl.UpdateExecuteParams{
 		TableName: tableName,
-		TableRows: sql_tpl.NewTableRows().Add("id", classID).
+		TableRow: sql_tpl.NewTableRow().Add("id", classID).
 			Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 			Add("student_ids", strings.Join(newStudentIDs, "\n")).
 			Add("student_num", newStudentNum),
@@ -646,7 +646,7 @@ func TestSql(t *testing.T) {
 
 		err = sql.Insert(tx, &sql_tpl.InsertExecuteParams{
 			TableName: tableName,
-			TableRows: sql_tpl.NewTableRows().Add("id", classID).
+			TableRow: sql_tpl.NewTableRow().Add("id", classID).
 				Add("name", className, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 				Add("student_num", studentNum).
 				Add("student_ids", strings.Join(studentIDs, "\n")).
@@ -660,7 +660,7 @@ func TestSql(t *testing.T) {
 
 		err = sql.Update(tx, &sql_tpl.UpdateExecuteParams{
 			TableName: tableName,
-			TableRows: sql_tpl.NewTableRows().Add("id", classID).
+			TableRow: sql_tpl.NewTableRow().Add("id", classID).
 				Add("name", newClassName, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
 				Add("student_num", newStudentNum).
 				Add("student_ids", strings.Join(newStudentIDs, "\n")),
@@ -811,3 +811,147 @@ func TestSql(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestInsertBatch(t *testing.T) {
+	classID1 := strutils.SimpleUUID()
+	className1 := strutils.SimpleUUID()
+	studentNum1 := rand.Int31n(100)
+	classID2 := strutils.SimpleUUID()
+	className2 := strutils.SimpleUUID()
+	studentNum2 := rand.Int31n(100)
+	now := time.Now()
+
+	insertBatchExecuteParams := &sql_tpl.InsertBatchExecuteParams{
+		TableName: tableName,
+		TableRowsBatch: []sql_tpl.TableRow{
+			*(sql_tpl.NewTableRow().Add("id", classID1).
+				Add("name", className1, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
+				Add("student_num", studentNum1).
+				Add("student_ids", "").
+				Add("created_time", now).
+				Add("last_updated_time", now)),
+			*(sql_tpl.NewTableRow().Add("id", classID2).
+				Add("name", className2, sql_tpl.WithAESKey("@MKU^AHYCN$:j76J<TAHCVD#$XZSWQ@L")).
+				Add("student_num", studentNum2).
+				Add("student_ids", "").
+				Add("created_time", now).
+				Add("last_updated_time", now)),
+		},
+	}
+
+	class1 := &Class{
+		IDField:    IDField{ID: classID1},
+		Name:       className1,
+		StudentNum: int(studentNum1),
+		StudentIDs: make([]string, 0),
+	}
+
+	class2 := &Class{
+		IDField:    IDField{ID: classID2},
+		Name:       className2,
+		StudentNum: int(studentNum2),
+		StudentIDs: make([]string, 0),
+	}
+
+	err := sdk.InitInstance(token, address, httpPort, grpcPort, namespace, dataSource)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err := sdk.DestroyInstance()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	err = sql.InsertBatch(sdk.GetInstance(), insertBatchExecuteParams)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	results, _, err := sql.Query(sdk.GetInstance(), &sql_tpl.QueryExecuteParams{
+		TableName:     tableName,
+		SelectColumns: []string{"id", "name", "student_num as student_num_alias"},
+		Conditions:    sql_tpl.NewConditions().In("id", []string{classID1, classID2}),
+		PageNo:        0,
+		PageSize:      0,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	classInfos, err := sql.ParseSqlResult[[]Class](results)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, classInfo := range classInfos {
+		if classInfo.ID != classID1 && classInfo.ID != classID2 {
+			t.Fatal("id不正确")
+		}
+
+		if classInfo.ID == classID1 &&
+			(classInfo.Name != className1 || classInfo.StudentNum != int(studentNum1)) {
+			t.Fatal("数据不正确")
+		}
+
+		if classInfo.ID == classID2 &&
+			(classInfo.Name != className2 || classInfo.StudentNum != int(studentNum2)) {
+			t.Fatal("数据不正确")
+		}
+	}
+
+	err = sql.Delete(sdk.GetInstance(), &sql_tpl.DeleteExecuteParams{
+		TableName:  tableName,
+		Conditions: sql_tpl.NewConditions().In("id", []string{classID1, classID2}),
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = sql.InsertEntity(sdk.GetInstance(), tableName, []*Class{class1, class2})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	results, _, err = sql.Query(sdk.GetInstance(), &sql_tpl.QueryExecuteParams{
+		TableName:     tableName,
+		SelectColumns: []string{"id", "name", "student_num as student_num_alias"},
+		Conditions:    sql_tpl.NewConditions().In("id", []string{classID1, classID2}),
+		PageNo:        0,
+		PageSize:      0,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	classInfos, err = sql.ParseSqlResult[[]Class](results)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, classInfo := range classInfos {
+		if classInfo.ID != classID1 && classInfo.ID != classID2 {
+			t.Fatal("id不正确")
+		}
+
+		if classInfo.ID == classID1 &&
+			(classInfo.Name != className1 || classInfo.StudentNum != int(studentNum1)) {
+			t.Fatal("数据不正确")
+		}
+
+		if classInfo.ID == classID2 &&
+			(classInfo.Name != className2 || classInfo.StudentNum != int(studentNum2)) {
+			t.Fatal("数据不正确")
+		}
+	}
+
+	err = sql.Delete(sdk.GetInstance(), &sql_tpl.DeleteExecuteParams{
+		TableName:  tableName,
+		Conditions: sql_tpl.NewConditions().In("id", []string{classID1, classID2}),
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+}