|
|
@@ -160,6 +160,100 @@ func Delete[T any](executor RawSqlExecutor, tableName string, e T) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+type UpdateCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)
|
|
|
+
|
|
|
+func Update[T any](executor RawSqlExecutor, tableName string, e T, callback UpdateCallback[T]) error {
|
|
|
+ if executor == nil {
|
|
|
+ return errors.New("没有传递执行器")
|
|
|
+ }
|
|
|
+
|
|
|
+ if strutils.IsStringEmpty(tableName) {
|
|
|
+ return errors.New("没有传递表名")
|
|
|
+ }
|
|
|
+
|
|
|
+ if reflect.TypeOf(e) == nil {
|
|
|
+ return errors.New("没有传递实体")
|
|
|
+ }
|
|
|
+
|
|
|
+ sqlMapping, err := tag.ParseSqlMapping(e)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ executeParams := raw_sql_tpl.UpdateExecuteParams{
|
|
|
+ TableName: tableName,
|
|
|
+ }
|
|
|
+
|
|
|
+ now := time.Now()
|
|
|
+
|
|
|
+ for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
|
|
|
+ fieldType := sqlMappingColumn.ValueFieldType
|
|
|
+
|
|
|
+ value := reflect.Zero(fieldType).Interface()
|
|
|
+ if !sqlMappingColumn.ValueFieldValue.IsZero() {
|
|
|
+ value = sqlMappingColumn.ValueFieldValue.Interface()
|
|
|
+ }
|
|
|
+
|
|
|
+ if sqlMappingColumn.InsertCallback {
|
|
|
+ if callback == nil {
|
|
|
+ return errors.New("需要使用回调函数但是没有传递回调函数")
|
|
|
+ }
|
|
|
+
|
|
|
+ retValue, err := callback(e, fieldName, value)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ retValueType := reflect.TypeOf(retValue)
|
|
|
+ if retValueType == nil || retValueType.Kind() == reflect.Ptr {
|
|
|
+ return errors.New("返回应当为值类型")
|
|
|
+ }
|
|
|
+
|
|
|
+ value = retValue
|
|
|
+ }
|
|
|
+
|
|
|
+ if fieldName == lastUpdatedTimeFieldName &&
|
|
|
+ fieldType.String() == "time.Time" && value.(time.Time).IsZero() {
|
|
|
+ value = now
|
|
|
+ }
|
|
|
+
|
|
|
+ // 字段为空不更新
|
|
|
+ if reflect.ValueOf(value).IsZero() && !sqlMappingColumn.CanUpdateClear {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ tableRowValue, err := parseValue(value)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
|
|
|
+ Column: sqlMappingColumn.Name,
|
|
|
+ Value: tableRowValue,
|
|
|
+ })
|
|
|
+
|
|
|
+ if sqlMappingColumn.IsKey {
|
|
|
+ executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
|
|
|
+ Column: sqlMappingColumn.Name,
|
|
|
+ Operator: "=",
|
|
|
+ Value: tableRowValue,
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ executeParamsMap, err := executeParams.Map()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ _, err = executor.ExecuteRawSql(raw_sql_tpl.UpdateTpl, executeParamsMap)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func parseValue(value any) (string, error) {
|
|
|
switch v := value.(type) {
|
|
|
case string:
|