package sql import ( "errors" "git.sxidc.com/go-tools/utils/strutils" "git.sxidc.com/service-supports/ds-sdk/sdk" "git.sxidc.com/service-supports/ds-sdk/sql/sql_tpl" "github.com/mitchellh/mapstructure" "reflect" "strings" "time" ) // TODO Key字段校验零值 type Executor interface { ExecuteRawSql(sql string, executeParams map[string]any) ([]map[string]any, error) ExecuteSql(name string, executeParams map[string]any) ([]map[string]any, error) } const ( createdTimeFieldName = "CreatedTime" lastUpdatedTimeFieldName = "LastUpdatedTime" ) func InsertEntity[T any](executor Executor, tableName string, e T) error { if executor == nil { return errors.New("没有传递执行器") } if strutils.IsStringEmpty(tableName) { return errors.New("没有传递表名") } if reflect.TypeOf(e) == nil { return errors.New("没有传递实体") } sqlMapping, err := ParseSqlMapping(e) if err != nil { return err } tableRow := sql_tpl.NewTableRows() now := time.Now() for fieldName, sqlColumn := range sqlMapping.ColumnMap { fieldType := sqlColumn.ValueFieldType value := reflect.Zero(fieldType).Interface() if !sqlColumn.ValueFieldValue.IsZero() { value = sqlColumn.ValueFieldValue.Interface() } if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) && fieldType.String() == "time.Time" && value.(time.Time).IsZero() { value = now } tableRow.Add(sqlColumn.Name, value) } executeParamsMap, err := sql_tpl.InsertExecuteParams{ TableName: tableName, TableRows: tableRow, }.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap) if err != nil { if strings.Contains(err.Error(), "SQLSTATE 23505") { return sdk.ErrDBRecordHasExist } return err } return nil } func DeleteEntity[T any](executor Executor, tableName string, e T) error { if executor == nil { return errors.New("没有传递执行器") } if strutils.IsStringEmpty(tableName) { return errors.New("没有传递表名") } if reflect.TypeOf(e) == nil { return errors.New("没有传递实体") } sqlMapping, err := ParseSqlMapping(e) if err != nil { return err } conditions := sql_tpl.NewConditions() for _, sqlColumn := range sqlMapping.ColumnMap { if !sqlColumn.IsKey { continue } fieldType := sqlColumn.ValueFieldType value := reflect.Zero(fieldType).Interface() if !sqlColumn.ValueFieldValue.IsZero() { value = sqlColumn.ValueFieldValue.Interface() } conditions.Equal(sqlColumn.Name, value) } executeParamsMap, err := sql_tpl.DeleteExecuteParams{ TableName: tableName, Conditions: conditions, }.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.DeleteTpl, executeParamsMap) if err != nil { return err } return nil } func UpdateEntity[T any](executor Executor, tableName string, e T) error { if executor == nil { return errors.New("没有传递执行器") } if strutils.IsStringEmpty(tableName) { return errors.New("没有传递表名") } if reflect.TypeOf(e) == nil { return errors.New("没有传递实体") } sqlMapping, err := ParseSqlMapping(e) if err != nil { return err } tableRows := sql_tpl.NewTableRows() conditions := sql_tpl.NewConditions() now := time.Now() for fieldName, sqlColumn := range sqlMapping.ColumnMap { if !sqlColumn.IsKey && !sqlColumn.CanUpdate { continue } fieldType := sqlColumn.ValueFieldType value := reflect.Zero(fieldType).Interface() if !sqlColumn.ValueFieldValue.IsZero() { value = sqlColumn.ValueFieldValue.Interface() } if fieldName == lastUpdatedTimeFieldName && fieldType.String() == "time.Time" && value.(time.Time).IsZero() { value = now } // 字段为空且不能清空,不更新 if reflect.ValueOf(value).IsZero() && !sqlColumn.CanUpdateClear { continue } if !sqlColumn.IsKey { tableRows.Add(sqlColumn.Name, value) } if sqlColumn.IsKey { conditions.Equal(sqlColumn.Name, value) } } executeParamsMap, err := sql_tpl.UpdateExecuteParams{ TableName: tableName, TableRows: tableRows, Conditions: conditions, }.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.UpdateTpl, executeParamsMap) if err != nil { return err } return nil } func Insert(executor Executor, executeParams *sql_tpl.InsertExecuteParams) error { if executor == nil { return errors.New("没有传递执行器") } if executeParams == nil { return errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap) if err != nil { return err } return nil } func Delete(executor Executor, executeParams *sql_tpl.DeleteExecuteParams) error { if executor == nil { return errors.New("没有传递执行器") } if executeParams == nil { return errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.DeleteTpl, executeParamsMap) if err != nil { return err } return nil } func Update(executor Executor, executeParams *sql_tpl.UpdateExecuteParams) error { if executor == nil { return errors.New("没有传递执行器") } if executeParams == nil { return errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return err } _, err = executor.ExecuteRawSql(sql_tpl.UpdateTpl, executeParamsMap) if err != nil { return err } return nil } func Query(executor Executor, executeParams *sql_tpl.QueryExecuteParams) ([]map[string]any, int64, error) { if executor == nil { return nil, 0, errors.New("没有传递执行器") } if executeParams == nil { return nil, 0, errors.New("没有传递执行参数") } queryExecuteParamsMap, err := executeParams.Map() if err != nil { return nil, 0, err } countExecuteParamsMap, err := sql_tpl.CountExecuteParams{ TableName: executeParams.TableName, Conditions: executeParams.Conditions, }.Map() if err != nil { return nil, 0, err } tableRows, err := executor.ExecuteRawSql(sql_tpl.QueryTpl, queryExecuteParamsMap) if err != nil { return nil, 0, err } countTableRow, err := executor.ExecuteRawSql(sql_tpl.CountTpl, countExecuteParamsMap) if err != nil { return nil, 0, err } return tableRows, int64(countTableRow[0]["count"].(float64)), nil } func QueryOne(executor Executor, executeParams *sql_tpl.QueryOneExecuteParams) (map[string]any, error) { if executor == nil { return nil, errors.New("没有传递执行器") } if executeParams == nil { return nil, errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return nil, err } tableRows, err := executor.ExecuteRawSql(sql_tpl.QueryTpl, executeParamsMap) if err != nil { return nil, err } if tableRows == nil || len(tableRows) == 0 { return nil, sdk.ErrDBRecordNotExist } return tableRows[0], nil } func Count(executor Executor, executeParams *sql_tpl.CountExecuteParams) (int64, error) { if executor == nil { return 0, errors.New("没有传递执行器") } if executeParams == nil { return 0, errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return 0, err } tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap) if err != nil { return 0, err } return int64(tableRows[0]["count"].(float64)), nil } func CheckExist(executor Executor, executeParams *sql_tpl.CheckExistExecuteParams) (bool, error) { if executor == nil { return false, errors.New("没有传递执行器") } if executeParams == nil { return false, errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return false, err } tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap) if err != nil { return false, err } return int64(tableRows[0]["count"].(float64)) > 0, nil } func CheckHasOnlyOne(executor Executor, executeParams *sql_tpl.CheckHasOnlyOneExecuteParams) (bool, error) { if executor == nil { return false, errors.New("没有传递执行器") } if executeParams == nil { return false, errors.New("没有传递执行参数") } executeParamsMap, err := executeParams.Map() if err != nil { return false, err } tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap) if err != nil { return false, err } return int64(tableRows[0]["count"].(float64)) == 1, nil } func ExecuteRawSql(executor Executor, sql string, executeParams map[string]any) ([]map[string]any, error) { if executor == nil { return nil, errors.New("没有传递执行器") } if strutils.IsStringEmpty(sql) { return nil, errors.New("没有sql") } tableRows, err := executor.ExecuteRawSql(sql, executeParams) if err != nil { return nil, err } return tableRows, nil } func ExecuteSql(executor Executor, name string, executeParams map[string]any) ([]map[string]any, error) { if executor == nil { return nil, errors.New("没有传递执行器") } if strutils.IsStringEmpty(name) { return nil, errors.New("没有sql资源名称") } tableRows, err := executor.ExecuteSql(name, executeParams) if err != nil { return nil, err } return tableRows, nil } const ( sqlResultTimeMicroFormat = "2006-01-02T15:04:05.000000+08:00" sqlResultTimeMilliFormat = "2006-01-02T15:04:05.000+08:00" sqlResultTimeSecFormat = "2006-01-02T15:04:05+08:00" ) // TODO 添加DecodeHook func ParseSqlResults(results any, e any) error { decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { if f.Kind() != reflect.String { return data, nil } if t != reflect.TypeOf(time.Time{}) { return data, nil } var layout string timeStr := data.(string) if strings.HasSuffix(timeStr, ".000000+08:00") { layout = sqlResultTimeMicroFormat } else if strings.HasSuffix(timeStr, ".000+08:00") { layout = sqlResultTimeMilliFormat } else { layout = sqlResultTimeSecFormat } return time.ParseInLocation(layout, data.(string), time.Local) }, Result: e, }) if err != nil { return err } err = decoder.Decode(results) if err != nil { return err } return nil }