Browse Source

添加更新返回影响行数的接口

yjp 3 months ago
parent
commit
e733e457e3
1 changed files with 90 additions and 0 deletions
  1. 90 0
      framework/core/infrastructure/database/database.go

+ 90 - 0
framework/core/infrastructure/database/database.go

@@ -373,6 +373,96 @@ func UpdateEntity(executor Executor, tableName string, e any) error {
 	return nil
 }
 
+// UpdateEntityWithRowsAffected 通过结构更新数据,返回影响行数
+// 参数:
+// - executor: 数据库基础设施接口
+// - tableName: 表名
+// - e: 结构,结构字段需要使用sqlmapping标注
+// 返回值:
+// - 影响行数
+// - 错误
+func UpdateEntityWithRowsAffected(executor Executor, tableName string, e any) (int64, error) {
+	if executor == nil {
+		return 0, errors.New("没有传递执行器")
+	}
+
+	if strutils.IsStringEmpty(tableName) {
+		return 0, errors.New("没有传递表名")
+	}
+
+	if e == nil {
+		return 0, nil
+	}
+
+	entityType := reflect.TypeOf(e)
+	if !reflectutils.IsTypeStructOrStructPointer(entityType) {
+		return 0, errors.New("实体参数不是结构或结构指针")
+	}
+
+	fields, err := sql_mapping.DefaultUsage(e)
+	if err != nil {
+		return 0, err
+	}
+
+	now := time.Now().Local()
+	tableRow := sql.NewTableRow()
+	conditions := sql.NewConditions()
+
+	for _, field := range fields {
+		// 不是键字段
+		// 不是更新时间字段
+		// 不更新的字段或者字段为零值且不能清空,跳过
+		if !field.IsKey && field.FieldName != lastUpdatedTimeFieldName &&
+			(!field.CanUpdate || (reflect.ValueOf(field.Value).IsZero() && !field.CanUpdateClear)) {
+			continue
+		}
+
+		fieldValue := reflect.ValueOf(field.Value)
+
+		if field.FieldName == lastUpdatedTimeFieldName &&
+			reflectutils.IsValueTime(fieldValue) && fieldValue.IsZero() {
+			field.Value = now
+		}
+
+		if field.FieldName != lastUpdatedTimeFieldName &&
+			reflectutils.IsValueTime(fieldValue) && fieldValue.IsZero() {
+			field.Value = nil
+		}
+
+		if field.IsKey {
+			conditions.Equal(field.ColumnName, field.Value)
+		} else {
+			if (field.Value == nil || reflect.ValueOf(field.Value).IsZero()) && !field.CanUpdateClear {
+				continue
+			}
+
+			tableRow.Add(field.ColumnName, field.Value)
+		}
+	}
+
+	executeParams := sql.UpdateExecuteParams{
+		TableName:  tableName,
+		TableRow:   tableRow,
+		Conditions: conditions,
+	}
+
+	executeParamsMap, err := executeParams.Map()
+	if err != nil {
+		return 0, err
+	}
+
+	args := make([]any, 0)
+	args = append(args, executeParams.TableRow.Values()...)
+	args = append(args, executeParams.Conditions.Args()...)
+
+	_, rowsAffected, err := executor.ExecuteRawSqlTemplateWithRowsAffected(sql.UpdateTpl, executeParamsMap, args...)
+	if err != nil {
+		return 0, err
+	}
+
+	return rowsAffected, nil
+}
+
 // Insert 插入数据
 // 参数:
 // - executor: 数据库基础设施接口