瀏覽代碼

完成更新

yjp 1 年之前
父節點
當前提交
d57afb0f59
共有 2 個文件被更改,包括 111 次插入0 次删除
  1. 94 0
      sdk/sql.go
  2. 17 0
      test/sdk_test.go

+ 94 - 0
sdk/sql.go

@@ -160,6 +160,100 @@ func Delete[T any](executor RawSqlExecutor, tableName string, e T) error {
 	return nil
 }
 
+type UpdateCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
+
+func Update[T any](executor RawSqlExecutor, tableName string, e T, callback UpdateCallback[T]) error {
+	if executor == nil {
+		return errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return errors.New("没有传递表名")
+	}
+
+	if reflect.TypeOf(e) == nil {
+		return errors.New("没有传递实体")
+	}
+
+	sqlMapping, err := tag.ParseSqlMapping(e)
+	if err != nil {
+		return err
+	}
+
+	executeParams := raw_sql_tpl.UpdateExecuteParams{
+		TableName: tableName,
+	}
+
+	now := time.Now()
+
+	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
+		fieldType := sqlMappingColumn.ValueFieldType
+
+		value := reflect.Zero(fieldType).Interface()
+		if !sqlMappingColumn.ValueFieldValue.IsZero() {
+			value = sqlMappingColumn.ValueFieldValue.Interface()
+		}
+
+		if sqlMappingColumn.InsertCallback {
+			if callback == nil {
+				return errors.New("需要使用回调函数但是没有传递回调函数")
+			}
+
+			retValue, err := callback(e, fieldName, value)
+			if err != nil {
+				return err
+			}
+
+			retValueType := reflect.TypeOf(retValue)
+			if retValueType == nil || retValueType.Kind() == reflect.Ptr {
+				return errors.New("返回应当为值类型")
+			}
+
+			value = retValue
+		}
+
+		if fieldName == lastUpdatedTimeFieldName &&
+			fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
+			value = now
+		}
+
+		// 字段为空不更新
+		if reflect.ValueOf(value).IsZero() && !sqlMappingColumn.CanUpdateClear {
+			continue
+		}
+
+		tableRowValue, err := parseValue(value)
+		if err != nil {
+			return err
+		}
+
+		executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
+			Column: sqlMappingColumn.Name,
+			Value:  tableRowValue,
+		})
+
+		if sqlMappingColumn.IsKey {
+			executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
+				Column:   sqlMappingColumn.Name,
+				Operator: "=",
+				Value:    tableRowValue,
+			})
+		}
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return err
+	}
+
+	_, err = executor.ExecuteRawSql(raw_sql_tpl.UpdateTpl, executeParamsMap)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func parseValue(value any) (string, error) {
 	switch v := value.(type) {
 	case string:

+ 17 - 0
test/sdk_test.go

@@ -497,6 +497,8 @@ func TestSql(t *testing.T) {
 	classID := strutils.SimpleUUID()
 	className := strutils.SimpleUUID()
 	studentNum := rand.Int31n(100)
+	newClassName := strutils.SimpleUUID()
+	newStudentNum := rand.Int31n(100)
 
 	class := &Class{
 		ID:            classID,
@@ -506,6 +508,14 @@ func TestSql(t *testing.T) {
 		Ignored:       "",
 	}
 
+	newClass := &Class{
+		ID:            classID,
+		Name:          newClassName,
+		StudentNum:    int(newStudentNum),
+		GraduatedTime: time.Now(),
+		Ignored:       "",
+	}
+
 	err = sdk.Insert(sdk.GetInstance(), tableName, class, func(e *Class, fieldName string, value any) (retValue any, err error) {
 		return value, nil
 	})
@@ -513,6 +523,13 @@ func TestSql(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	err = sdk.Update(sdk.GetInstance(), tableName, newClass, func(e *Class, fieldName string, value any) (retValue any, err error) {
+		return value, nil
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+
 	err = sdk.Delete(sdk.GetInstance(), tableName, class)
 	if err != nil {
 		t.Fatal(err)