package sql

import (
	"errors"
	"reflect"
	"strings"
	"time"

	"git.sxidc.com/go-tools/utils/strutils"
	"git.sxidc.com/service-supports/ds-sdk/sdk"
	"git.sxidc.com/service-supports/ds-sdk/sql/sql_tpl"
)

// Executor runs SQL on behalf of the helpers in this package.
type Executor interface {
	// ExecuteRawSql executes raw SQL (in this package, one of the sql_tpl templates) with executeParams.
	ExecuteRawSql(sql string, executeParams map[string]any) ([]sdk.SqlResult, error)
	// ExecuteSql executes the SQL resource registered under name with executeParams.
	ExecuteSql(name string, executeParams map[string]any) ([]sdk.SqlResult, error)
}

const (
	// createdTimeFieldName is the struct field that is set to the current time on insert when left zero.
	createdTimeFieldName = "CreatedTime"
	// lastUpdatedTimeFieldName is the struct field that is set to the current time on insert and update when left zero.
	lastUpdatedTimeFieldName = "LastUpdatedTime"

	// uniqueViolationSQLState appears in executor errors caused by a duplicate key (SQLSTATE 23505, unique_violation).
	uniqueViolationSQLState = "SQLSTATE 23505"
)

// timeType identifies time.Time fields exactly. Comparing Type.String() with "time.Time" would also
// match an unrelated type printed the same way, and the later value.(time.Time) assertion would then panic.
var timeType = reflect.TypeOf(time.Time{})

// InsertEntity inserts e into tableName.
//
// e may be a struct, *struct, []struct or []*struct. Columns come from the mapping returned by
// ParseSqlMappingTag. A slice is inserted as a single batch, and an empty slice is a no-op.
// Key columns must carry a non-zero value. Zero-valued CreatedTime/LastUpdatedTime time.Time
// fields are filled with the current time. A duplicate-key executor error is reported as
// sdk.ErrDBRecordHasExist.
func InsertEntity[T any](executor Executor, tableName string, e T) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(tableName) {
		return errors.New("没有传递表名")
	}

	entityType := reflect.TypeOf(e)
	if entityType == nil {
		return errors.New("没有传递实体")
	}

	if err := checkInsertEntityType(entityType); err != nil {
		return err
	}

	var executeParamsMap map[string]any
	if entityType.Kind() == reflect.Slice {
		entitySliceValue := reflect.ValueOf(e)
		if entitySliceValue.Len() == 0 {
			return nil
		}

		tableRowBatch := make([]sql_tpl.TableRow, 0, entitySliceValue.Len())
		for i := 0; i < entitySliceValue.Len(); i++ {
			tableRow, err := buildInsertTableRow(entitySliceValue.Index(i).Interface())
			if err != nil {
				return err
			}

			tableRowBatch = append(tableRowBatch, *tableRow)
		}

		batchParamsMap, err := sql_tpl.InsertBatchExecuteParams{
			TableName:     tableName,
			TableRowBatch: tableRowBatch,
		}.Map()
		if err != nil {
			return err
		}

		executeParamsMap = batchParamsMap
	} else {
		tableRow, err := buildInsertTableRow(reflect.ValueOf(e).Interface())
		if err != nil {
			return err
		}

		rowParamsMap, err := sql_tpl.InsertExecuteParams{
			TableName: tableName,
			TableRow:  tableRow,
		}.Map()
		if err != nil {
			return err
		}

		executeParamsMap = rowParamsMap
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap); err != nil {
		if strings.Contains(err.Error(), uniqueViolationSQLState) {
			return sdk.ErrDBRecordHasExist
		}

		return err
	}

	return nil
}

// checkInsertEntityType accepts struct, *struct, []struct and []*struct and rejects everything else.
func checkInsertEntityType(entityType reflect.Type) error {
	switch entityType.Kind() {
	case reflect.Struct:
		return nil
	case reflect.Ptr:
		if entityType.Elem().Kind() == reflect.Struct {
			return nil
		}
	case reflect.Slice:
		elemType := entityType.Elem()
		if elemType.Kind() == reflect.Struct {
			return nil
		}

		if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct {
			return nil
		}
	}

	return errors.New("可以接受的类型为struct, *struct, []struct, []*struct")
}

// buildInsertTableRow parses the mapping tags of one entity and converts it into an insert row.
func buildInsertTableRow(entity any) (*sql_tpl.TableRow, error) {
	sqlMapping, err := ParseSqlMappingTag(entity)
	if err != nil {
		return nil, err
	}

	tableRow := sql_tpl.NewTableRow()
	if err := formInsertTableRow(sqlMapping, tableRow); err != nil {
		return nil, err
	}

	return tableRow, nil
}

// formInsertTableRow adds every column of sqlMapping, recursing into nested mappings, to tableRow.
// Key columns must have a non-zero value. Zero-valued CreatedTime/LastUpdatedTime time.Time fields
// are set to the current time.
func formInsertTableRow(sqlMapping *Mapping, tableRow *sql_tpl.TableRow) error {
	now := time.Now()

	// NOTE(review): MappingElement is ranged as a map, so column order is unspecified and may differ
	// between rows; confirm sql_tpl does not depend on column order (e.g. for batch inserts).
	for fieldName, mappingElement := range sqlMapping.MappingElement {
		switch element := mappingElement.(type) {
		case *Mapping:
			if err := formInsertTableRow(element, tableRow); err != nil {
				return err
			}
		case *MappingColumn:
			if element.IsKey && isZeroValue(element.FieldValueElem) {
				return errors.New("键字段没有传值")
			}

			value, err := columnValue(element)
			if err != nil {
				return err
			}

			// Fill creation / last-update timestamps the caller left unset.
			if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
				isZeroTime(element.FieldTypeElem, value) {
				value = now
			}

			tableRow.Add(element.Name, value, columnOptions(element)...)
		default:
			return errors.New("不支持的元素类型")
		}
	}

	return nil
}

// DeleteEntity deletes from tableName the rows whose key columns equal the key fields of e.
// Every key field of e must carry a non-zero value.
func DeleteEntity[T any](executor Executor, tableName string, e T) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(tableName) {
		return errors.New("没有传递表名")
	}

	if reflect.TypeOf(e) == nil {
		return errors.New("没有传递实体")
	}

	sqlMapping, err := ParseSqlMappingTag(e)
	if err != nil {
		return err
	}

	// NOTE(review): an entity with no key columns produces no conditions; confirm
	// DeleteExecuteParams.Map rejects empty conditions instead of deleting every row.
	conditions := sql_tpl.NewConditions()
	if err := formDeleteConditions(sqlMapping, conditions); err != nil {
		return err
	}

	executeParamsMap, err := sql_tpl.DeleteExecuteParams{
		TableName:  tableName,
		Conditions: conditions,
	}.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.DeleteTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// formDeleteConditions adds an equality condition for every key column of sqlMapping, recursing
// into nested mappings. Non-key columns are ignored, and a key without a value is an error.
func formDeleteConditions(sqlMapping *Mapping, conditions *sql_tpl.Conditions) error {
	for _, mappingElement := range sqlMapping.MappingElement {
		switch element := mappingElement.(type) {
		case *Mapping:
			if err := formDeleteConditions(element, conditions); err != nil {
				return err
			}
		case *MappingColumn:
			// Only key columns take part in the WHERE clause.
			if !element.IsKey {
				continue
			}

			if isZeroValue(element.FieldValueElem) {
				return errors.New("键字段没有传值")
			}

			conditions.Equal(element.Name, element.FieldValueElem.Interface(), columnOptions(element)...)
		default:
			return errors.New("不支持的元素类型")
		}
	}

	return nil
}

// UpdateEntity updates the row of tableName identified by the key fields of e.
// Only updatable columns are written. An empty value is written only when the column allows
// clearing. LastUpdatedTime is always written and defaults to the current time when left zero.
func UpdateEntity[T any](executor Executor, tableName string, e T) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(tableName) {
		return errors.New("没有传递表名")
	}

	if reflect.TypeOf(e) == nil {
		return errors.New("没有传递实体")
	}

	sqlMapping, err := ParseSqlMappingTag(e)
	if err != nil {
		return err
	}

	tableRow := sql_tpl.NewTableRow()
	conditions := sql_tpl.NewConditions()
	if err := formUpdateTableRowAndConditions(sqlMapping, tableRow, conditions); err != nil {
		return err
	}

	executeParamsMap, err := sql_tpl.UpdateExecuteParams{
		TableName:  tableName,
		TableRow:   tableRow,
		Conditions: conditions,
	}.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.UpdateTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// formUpdateTableRowAndConditions turns key columns into equality conditions and the remaining
// updatable columns into SET values, recursing into nested mappings.
func formUpdateTableRowAndConditions(sqlMapping *Mapping, tableRow *sql_tpl.TableRow, conditions *sql_tpl.Conditions) error {
	now := time.Now()

	for fieldName, mappingElement := range sqlMapping.MappingElement {
		switch element := mappingElement.(type) {
		case *Mapping:
			if err := formUpdateTableRowAndConditions(element, tableRow, conditions); err != nil {
				return err
			}
		case *MappingColumn:
			if element.IsKey {
				if isZeroValue(element.FieldValueElem) {
					return errors.New("键字段没有传值")
				}
			} else if fieldName != lastUpdatedTimeFieldName &&
				(!element.CanUpdate || (isZeroValue(element.FieldValueElem) && !element.CanUpdateClear)) {
				// Skip columns that are not updatable, and empty values that may not clear the column.
				// LastUpdatedTime is never skipped.
				continue
			}

			value, err := columnValue(element)
			if err != nil {
				return err
			}

			if fieldName == lastUpdatedTimeFieldName && isZeroTime(element.FieldTypeElem, value) {
				value = now
			}

			opts := columnOptions(element)
			if element.IsKey {
				conditions.Equal(element.Name, value, opts...)
			} else {
				tableRow.Add(element.Name, value, opts...)
			}
		default:
			return errors.New("不支持的元素类型")
		}
	}

	return nil
}

// isZeroValue reports whether v is invalid or holds its type's zero value.
// Unlike reflect.Value.IsZero it does not panic on an invalid Value.
func isZeroValue(v reflect.Value) bool {
	return !v.IsValid() || v.IsZero()
}

// isZeroTime reports whether a column of type fieldType holds the zero time.Time.
func isZeroTime(fieldType reflect.Type, value any) bool {
	return fieldType == timeType && value.(time.Time).IsZero()
}

// columnValue converts a column's field into the value written to the database.
// Non-slice fields yield their value, or the type's zero value when unset. []string fields are
// joined with element.JoinWith, and an empty slice becomes "". Other slice types are rejected.
func columnValue(element *MappingColumn) (any, error) {
	fieldType := element.FieldTypeElem
	fieldValue := element.FieldValueElem

	if fieldType.Kind() != reflect.Slice {
		if isZeroValue(fieldValue) {
			return reflect.Zero(fieldType).Interface(), nil
		}

		return fieldValue.Interface(), nil
	}

	if fieldType.Elem().Kind() != reflect.String {
		return nil, errors.New("slice仅支持[]string")
	}

	if !fieldValue.IsValid() || fieldValue.Len() == 0 {
		return "", nil
	}

	strValues := make([]string, 0, fieldValue.Len())
	for i := 0; i < fieldValue.Len(); i++ {
		strValues = append(strValues, fieldValue.Index(i).String())
	}

	return strings.Join(strValues, element.JoinWith), nil
}

// columnOptions returns the value options for a column, which currently means AES encryption
// when the column declares an AES key.
func columnOptions(element *MappingColumn) []sql_tpl.AfterParsedStrValueOption {
	var opts []sql_tpl.AfterParsedStrValueOption
	if strutils.IsStringNotEmpty(element.AESKey) {
		opts = append(opts, sql_tpl.WithAESKey(element.AESKey))
	}

	return opts
}

// Insert executes sql_tpl.InsertTpl with executeParams.
// Unlike InsertEntity, executor errors are returned unchanged.
func Insert(executor Executor, executeParams *sql_tpl.InsertExecuteParams) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// InsertBatch executes sql_tpl.InsertTpl with batch executeParams.
// Unlike InsertEntity, executor errors are returned unchanged.
func InsertBatch(executor Executor, executeParams *sql_tpl.InsertBatchExecuteParams) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.InsertTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// Delete executes sql_tpl.DeleteTpl with executeParams.
func Delete(executor Executor, executeParams *sql_tpl.DeleteExecuteParams) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.DeleteTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// Update executes sql_tpl.UpdateTpl with executeParams.
func Update(executor Executor, executeParams *sql_tpl.UpdateExecuteParams) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(sql_tpl.UpdateTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// Query runs the query described by executeParams. It also runs a count over the same table
// and conditions, and returns the rows together with that total count.
func Query(executor Executor, executeParams *sql_tpl.QueryExecuteParams) ([]sdk.SqlResult, int64, error) {
	if executor == nil {
		return nil, 0, errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return nil, 0, errors.New("没有传递执行参数")
	}

	queryExecuteParamsMap, err := executeParams.Map()
	if err != nil {
		return nil, 0, err
	}

	countExecuteParamsMap, err := sql_tpl.CountExecuteParams{
		TableName:  executeParams.TableName,
		Conditions: executeParams.Conditions,
	}.Map()
	if err != nil {
		return nil, 0, err
	}

	tableRows, err := executor.ExecuteRawSql(sql_tpl.QueryTpl, queryExecuteParamsMap)
	if err != nil {
		return nil, 0, err
	}

	countTableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, countExecuteParamsMap)
	if err != nil {
		return nil, 0, err
	}

	totalCount, err := countFromRows(countTableRows)
	if err != nil {
		return nil, 0, err
	}

	// Always return a non-nil slice, even when no rows matched.
	results := make([]sdk.SqlResult, len(tableRows))
	copy(results, tableRows)

	return results, totalCount, nil
}

// QueryOne returns the first row produced by executeParams, or sdk.ErrDBRecordNotExist when no
// row matches.
func QueryOne(executor Executor, executeParams *sql_tpl.QueryOneExecuteParams) (sdk.SqlResult, error) {
	if executor == nil {
		return nil, errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return nil, errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return nil, err
	}

	tableRows, err := executor.ExecuteRawSql(sql_tpl.QueryTpl, executeParamsMap)
	if err != nil {
		return nil, err
	}

	if len(tableRows) == 0 {
		return nil, sdk.ErrDBRecordNotExist
	}

	return tableRows[0], nil
}

// Count returns the number of rows matched by executeParams.
func Count(executor Executor, executeParams *sql_tpl.CountExecuteParams) (int64, error) {
	if executor == nil {
		return 0, errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return 0, errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return 0, err
	}

	tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return 0, err
	}

	return countFromRows(tableRows)
}

// CheckExist reports whether at least one row matches executeParams.
func CheckExist(executor Executor, executeParams *sql_tpl.CheckExistExecuteParams) (bool, error) {
	if executor == nil {
		return false, errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return false, errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return false, err
	}

	tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return false, err
	}

	count, err := countFromRows(tableRows)
	if err != nil {
		return false, err
	}

	return count > 0, nil
}

// CheckHasOnlyOne reports whether exactly one row matches executeParams.
func CheckHasOnlyOne(executor Executor, executeParams *sql_tpl.CheckHasOnlyOneExecuteParams) (bool, error) {
	if executor == nil {
		return false, errors.New("没有传递执行器")
	}

	if executeParams == nil {
		return false, errors.New("没有传递执行参数")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return false, err
	}

	tableRows, err := executor.ExecuteRawSql(sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return false, err
	}

	count, err := countFromRows(tableRows)
	if err != nil {
		return false, err
	}

	return count == 1, nil
}

// countFromRows reads the "count" column of the first row returned by sql_tpl.CountTpl.
// It returns an error, instead of panicking, when no row is returned or the column has an
// unexpected type. float64 is the type the original code assumed; int64/int are also accepted.
func countFromRows(tableRows []sdk.SqlResult) (int64, error) {
	if len(tableRows) == 0 {
		return 0, errors.New("count查询没有返回结果")
	}

	switch count := tableRows[0]["count"].(type) {
	case float64:
		return int64(count), nil
	case int64:
		return count, nil
	case int:
		return int64(count), nil
	default:
		return 0, errors.New("count查询结果类型错误")
	}
}

// ExecuteRawSql executes raw SQL with executeParams after validating the arguments.
func ExecuteRawSql(executor Executor, sql string, executeParams map[string]any) ([]sdk.SqlResult, error) {
	if executor == nil {
		return nil, errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(sql) {
		return nil, errors.New("没有sql")
	}

	tableRows, err := executor.ExecuteRawSql(sql, executeParams)
	if err != nil {
		return nil, err
	}

	return tableRows, nil
}

// ExecuteSql executes the named SQL resource with executeParams after validating the arguments.
func ExecuteSql(executor Executor, name string, executeParams map[string]any) ([]sdk.SqlResult, error) {
	if executor == nil {
		return nil, errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(name) {
		return nil, errors.New("没有sql资源名称")
	}

	tableRows, err := executor.ExecuteSql(name, executeParams)
	if err != nil {
		return nil, err
	}

	return tableRows, nil
}