Browse Source

修改bug

yjp 11 months ago
parent
commit
30e494532f
1 changed files with 23 additions and 80 deletions
  1. 23 80
      framework/core/infrastructure/database/database.go

+ 23 - 80
framework/core/infrastructure/database/database.go

@@ -292,80 +292,7 @@ func DeleteEntity(executor Executor, tableName string, e any) error {
 // 返回值:
 // - 错误
 func UpdateEntity(executor Executor, tableName string, e any) error {
-	if executor == nil {
-		return errors.New("没有传递执行器")
-	}
-
-	if strutils.IsStringEmpty(tableName) {
-		return errors.New("没有传递表名")
-	}
-
-	if e == nil {
-		return nil
-	}
-
-	entityType := reflect.TypeOf(e)
-	if !reflectutils.IsTypeStructOrStructPointer(entityType) {
-		return errors.New("实体参数不是结构或结构指针")
-	}
-
-	fields, err := sql_mapping.DefaultUsage(e)
-	if err != nil {
-		return err
-	}
-
-	now := time.Now().Local()
-	tableRow := sql.NewTableRow()
-	conditions := sql.NewConditions()
-
-	for _, field := range fields {
-		// 不是键字段
-		// 不是更新时间字段
-		// 不更新的字段或者字段为零值且不能清空,跳过
-		if !field.IsKey && field.FieldName != lastUpdatedTimeFieldName &&
-			(!field.CanUpdate || (reflect.ValueOf(field.Value).IsZero() && !field.CanUpdateClear)) {
-			continue
-		}
-
-		fieldValue := reflect.ValueOf(field.Value)
-
-		if field.FieldName == lastUpdatedTimeFieldName &&
-			reflectutils.IsValueTime(fieldValue) && fieldValue.IsZero() {
-			field.Value = now
-		}
-
-		if field.FieldName != lastUpdatedTimeFieldName &&
-			reflectutils.IsValueTime(fieldValue) && fieldValue.IsZero() {
-			field.Value = nil
-		}
-
-		if field.IsKey {
-			conditions.Equal(field.ColumnName, field.Value)
-		} else {
-			if (field.Value == nil || reflect.ValueOf(field.Value).IsZero()) && !field.CanUpdateClear {
-				continue
-			}
-
-			tableRow.Add(field.ColumnName, field.Value)
-		}
-	}
-
-	executeParams := sql.UpdateExecuteParams{
-		TableName:  tableName,
-		TableRow:   tableRow,
-		Conditions: conditions,
-	}
-
-	executeParamsMap, err := executeParams.Map()
-	if err != nil {
-		return err
-	}
-
-	args := make([]any, 0)
-	args = append(args, executeParams.TableRow.Values()...)
-	args = append(args, executeParams.Conditions.Args()...)
-
-	_, err = executor.ExecuteRawSqlTemplate(sql.UpdateTpl, executeParamsMap, args...)
+	_, err := UpdateEntityWithRowsAffected(executor, tableName, e)
 	if err != nil {
 		return err
 	}
@@ -559,12 +486,28 @@ func Delete(executor Executor, executeParams *sql.DeleteExecuteParams) error {
 // 返回值:
 // - 错误
 func Update(executor Executor, executeParams *sql.UpdateExecuteParams) error {
+	_, err := UpdateWithRowsAffected(executor, executeParams)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// UpdateWithRowsAffected 更新数据,返回影响行数
+// 参数:
+// - executor: 数据库基础设施接口
+// - executeParams: 更新数据参数
+// 返回值:
+// - 影响行数
+// - 错误
+func UpdateWithRowsAffected(executor Executor, executeParams *sql.UpdateExecuteParams) (int64, error) {
 	if executor == nil {
-		return errors.New("没有传递执行器")
+		return 0, errors.New("没有传递执行器")
 	}
 
 	if executeParams == nil {
-		return errors.New("没有传递执行参数")
+		return 0, errors.New("没有传递执行参数")
 	}
 
 	if executeParams.Conditions == nil {
@@ -573,19 +516,19 @@ func Update(executor Executor, executeParams *sql.UpdateExecuteParams) error {
 
 	executeParamsMap, err := executeParams.Map()
 	if err != nil {
-		return err
+		return 0, err
 	}
 
 	args := make([]any, 0)
 	args = append(args, executeParams.TableRow.Values()...)
 	args = append(args, executeParams.Conditions.Args()...)
 
-	_, err = executor.ExecuteRawSqlTemplate(sql.UpdateTpl, executeParamsMap, args...)
+	_, rowsAffected, err := executor.ExecuteRawSqlTemplateWithRowsAffected(sql.UpdateTpl, executeParamsMap, args...)
 	if err != nil {
-		return err
+		return 0, err
 	}
 
-	return nil
+	return rowsAffected, nil
 }
 
 // Query 查询数据