package sdk

import (
	"errors"
	"reflect"
	"strconv"
	"strings"
	"time"

	"git.sxidc.com/go-tools/utils/strutils"
	"git.sxidc.com/service-supports/ds-sdk/sdk/raw_sql_tpl"
	"git.sxidc.com/service-supports/ds-sdk/sdk/tag"
)

// SqlExecutor runs the SQL produced by the helpers in this package.
type SqlExecutor interface {
	// ExecuteRawSql executes sql (the helpers here pass raw_sql_tpl templates)
	// with executeParams and returns the resulting rows.
	ExecuteRawSql(sql string, executeParams map[string]any) ([]map[string]any, error)
	// ExecuteSql executes the SQL resource identified by name with executeParams.
	ExecuteSql(name string, executeParams map[string]any) ([]map[string]any, error)
}

const (
	// timeWriteFormat renders time.Time values as SQL literals. ".000000" is a
	// microsecond fractional-second layout element, but "+08:00" is NOT a Go
	// layout token and is emitted verbatim, so values must be converted to
	// timeWriteLocation before formatting for that offset to be truthful.
	timeWriteFormat = time.DateTime + ".000000 +08:00"

	// Entity fields that are stamped with the current time when they are zero.
	createdTimeFieldName     = "CreatedTime"
	lastUpdatedTimeFieldName = "LastUpdatedTime"
)

// timeWriteLocation is the fixed UTC+08:00 zone matching the literal offset
// embedded in timeWriteFormat.
var timeWriteLocation = time.FixedZone("UTC+8", 8*60*60)

// ValueCallback lets the caller replace the value written for a field whose
// sql mapping enables an insert/update callback. The returned value must be
// non-nil and must not be a pointer.
type ValueCallback[T any] func(e T, fieldName string, value any) (retValue any, err error)

// ConditionCallback lets the caller replace the operator and value of the
// condition generated for a field whose sql mapping enables the relevant
// callback. The returned value must be non-nil and must not be a pointer.
type ConditionCallback[T any] func(e T, fieldName string, columnName string, value any) (retConditionOp string, retConditionValue any, err error)

// Insert writes e as a new row of tableName.
//
// Every column of e's sql mapping is written; zero fields are written as the
// zero value of their type. Columns whose mapping sets InsertCallback are
// passed through callback first (callback must then be non-nil). Zero
// time.Time values in CreatedTime/LastUpdatedTime are replaced by the current
// time. An execution error whose message contains "SQLSTATE 23505"
// (unique violation) is reported as ErrDBRecordHasExist.
func Insert[T any](executor SqlExecutor, tableName string, e T, callback ValueCallback[T]) error {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return err
	}

	executeParams := raw_sql_tpl.InsertExecuteParams{
		TableName: tableName,
	}

	now := time.Now()
	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		fieldType := sqlMappingColumn.ValueFieldType
		value := sqlColumnValue(fieldType, sqlMappingColumn.ValueFieldValue)

		if sqlMappingColumn.InsertCallback {
			retValue, err := applyValueCallback(callback, e, fieldName, value)
			if err != nil {
				return err
			}
			value = retValue
		}

		if (fieldName == createdTimeFieldName || fieldName == lastUpdatedTimeFieldName) &&
			fieldType.String() == "time.Time" {
			value = stampIfZeroTime(value, now)
		}

		tableRowValue, err := parseValue(value)
		if err != nil {
			return err
		}

		executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
			Column: sqlMappingColumn.Name,
			Value:  tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(raw_sql_tpl.InsertTpl, executeParamsMap); err != nil {
		if strings.Contains(err.Error(), "SQLSTATE 23505") {
			return ErrDBRecordHasExist
		}
		return err
	}

	return nil
}

// Delete removes the rows of tableName whose key columns equal e's key fields.
//
// Every key column becomes an "=" condition, even when the field is zero.
// If e's mapping has no key columns an error is returned instead of running
// the delete template without any condition.
func Delete[T any](executor SqlExecutor, tableName string, e T) error {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return err
	}

	executeParams := raw_sql_tpl.DeleteExecuteParams{
		TableName: tableName,
	}

	for _, sqlMappingColumn := range sqlMapping.ColumnMap {
		if !sqlMappingColumn.IsKey {
			continue
		}

		value := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)

		tableRowValue, err := parseValue(value)
		if err != nil {
			return err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: "=",
			Value:    tableRowValue,
		})
	}

	// Guard against wiping the whole table when the entity declares no keys.
	if len(executeParams.Conditions) == 0 {
		return errors.New("没有删除条件")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(raw_sql_tpl.DeleteTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// Update writes e's updatable columns to the rows of tableName matched by e's
// key columns.
//
// Only columns whose mapping sets CanUpdate are considered. Columns with
// UpdateCallback are passed through callback first (callback must then be
// non-nil). A zero time.Time in LastUpdatedTime is replaced by the current
// time. Zero values are not written unless the column sets CanUpdateClear.
// Key columns always become "=" conditions, even when zero. If no condition
// is produced, an error is returned instead of running an unconditioned update.
func Update[T any](executor SqlExecutor, tableName string, e T, callback ValueCallback[T]) error {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return err
	}

	executeParams := raw_sql_tpl.UpdateExecuteParams{
		TableName: tableName,
	}

	now := time.Now()
	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		if !sqlMappingColumn.CanUpdate {
			continue
		}

		fieldType := sqlMappingColumn.ValueFieldType
		value := sqlColumnValue(fieldType, sqlMappingColumn.ValueFieldValue)

		if sqlMappingColumn.UpdateCallback {
			retValue, err := applyValueCallback(callback, e, fieldName, value)
			if err != nil {
				return err
			}
			value = retValue
		}

		if fieldName == lastUpdatedTimeFieldName && fieldType.String() == "time.Time" {
			value = stampIfZeroTime(value, now)
		}

		// Empty fields are not written unless the column allows clearing; key
		// columns still have to produce their WHERE condition, otherwise a
		// zero key would silently widen the update to other rows.
		skipSet := reflect.ValueOf(value).IsZero() && !sqlMappingColumn.CanUpdateClear
		if skipSet && !sqlMappingColumn.IsKey {
			continue
		}

		tableRowValue, err := parseValue(value)
		if err != nil {
			return err
		}

		if !skipSet {
			executeParams.TableRows = append(executeParams.TableRows, raw_sql_tpl.TableRow{
				Column: sqlMappingColumn.Name,
				Value:  tableRowValue,
			})
		}

		if sqlMappingColumn.IsKey {
			executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
				Column:   sqlMappingColumn.Name,
				Operator: "=",
				Value:    tableRowValue,
			})
		}
	}

	if len(executeParams.Conditions) == 0 {
		return errors.New("没有更新条件")
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return err
	}

	if _, err := executor.ExecuteRawSql(raw_sql_tpl.UpdateTpl, executeParamsMap); err != nil {
		return err
	}

	return nil
}

// Query selects rows of tableName filtered by e and also returns the number of
// matching rows, counted with the same conditions but without paging.
//
// Every column whose mapping sets CanQuery becomes a condition: "=" with the
// field value (the zero value when the field is zero), or the operator/value
// returned by callback when QueryCallback is set. Paging is applied only when
// both pageNo (1-based) and pageSize are positive; otherwise Limit and Offset
// stay 0.
func Query[T any](executor SqlExecutor, tableName string, e T, pageNo int, pageSize int, callback ConditionCallback[T]) ([]map[string]any, int64, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return nil, 0, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return nil, 0, err
	}

	var offset int
	var limit int
	if pageNo > 0 && pageSize > 0 {
		offset = (pageNo - 1) * pageSize
		limit = pageSize
	}

	executeParams := raw_sql_tpl.QueryExecuteParams{
		TableName: tableName,
		Limit:     limit,
		Offset:    offset,
	}

	countParams := raw_sql_tpl.CountExecuteParams{
		TableName: tableName,
	}

	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		if !sqlMappingColumn.CanQuery {
			continue
		}

		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)
		conditionOp := "="

		if sqlMappingColumn.QueryCallback {
			retOp, retValue, err := applyConditionCallback(callback, e, fieldName, sqlMappingColumn.Name, conditionValue)
			if err != nil {
				return nil, 0, err
			}
			conditionOp, conditionValue = retOp, retValue
		}

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return nil, 0, err
		}

		condition := raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: conditionOp,
			Value:    tableRowValue,
		}
		executeParams.Conditions = append(executeParams.Conditions, condition)
		countParams.Conditions = append(countParams.Conditions, condition)
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return nil, 0, err
	}

	countParamsMap, err := countParams.Map()
	if err != nil {
		return nil, 0, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.QueryTpl, executeParamsMap)
	if err != nil {
		return nil, 0, err
	}

	countTableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, countParamsMap)
	if err != nil {
		return nil, 0, err
	}

	count, err := parseCountResult(countTableRows)
	if err != nil {
		return nil, 0, err
	}

	return tableRows, count, nil
}

// QueryByKeys returns the first row of tableName whose key columns equal e's
// key fields (zero key fields are matched against their zero value), or
// ErrDBRecordNotExist when no row matches.
func QueryByKeys[T any](executor SqlExecutor, tableName string, e T) (map[string]any, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return nil, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return nil, err
	}

	executeParams := raw_sql_tpl.QueryExecuteParams{
		TableName: tableName,
		Limit:     0,
		Offset:    0,
	}

	for _, sqlMappingColumn := range sqlMapping.ColumnMap {
		if !sqlMappingColumn.IsKey {
			continue
		}

		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return nil, err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: "=",
			Value:    tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return nil, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.QueryTpl, executeParamsMap)
	if err != nil {
		return nil, err
	}

	if len(tableRows) == 0 {
		return nil, ErrDBRecordNotExist
	}

	return tableRows[0], nil
}

// Count returns the number of rows of tableName matching e.
//
// Every mapped column (including zero-valued ones) becomes a condition: "="
// with the field value, or the operator/value returned by callback when
// CountCallback is set.
func Count[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (int64, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return 0, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return 0, err
	}

	executeParams := raw_sql_tpl.CountExecuteParams{
		TableName: tableName,
	}

	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)
		conditionOp := "="

		if sqlMappingColumn.CountCallback {
			retOp, retValue, err := applyConditionCallback(callback, e, fieldName, sqlMappingColumn.Name, conditionValue)
			if err != nil {
				return 0, err
			}
			conditionOp, conditionValue = retOp, retValue
		}

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return 0, err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: conditionOp,
			Value:    tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return 0, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return 0, err
	}

	return parseCountResult(tableRows)
}

// CheckExist reports whether at least one row of tableName matches e.
//
// Conditions are built like Count's, but the callback is consulted for columns
// whose mapping sets CheckExistCallback.
func CheckExist[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (bool, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return false, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return false, err
	}

	executeParams := raw_sql_tpl.CountExecuteParams{
		TableName: tableName,
	}

	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)
		conditionOp := "="

		if sqlMappingColumn.CheckExistCallback {
			retOp, retValue, err := applyConditionCallback(callback, e, fieldName, sqlMappingColumn.Name, conditionValue)
			if err != nil {
				return false, err
			}
			conditionOp, conditionValue = retOp, retValue
		}

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return false, err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: conditionOp,
			Value:    tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return false, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return false, err
	}

	count, err := parseCountResult(tableRows)
	if err != nil {
		return false, err
	}

	return count > 0, nil
}

// CheckExistByKey reports whether a row of tableName exists whose key columns
// equal e's key fields (zero key fields are matched against their zero value).
func CheckExistByKey[T any](executor SqlExecutor, tableName string, e T) (bool, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return false, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return false, err
	}

	executeParams := raw_sql_tpl.CountExecuteParams{
		TableName: tableName,
	}

	for _, sqlMappingColumn := range sqlMapping.ColumnMap {
		if !sqlMappingColumn.IsKey {
			continue
		}

		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return false, err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: "=",
			Value:    tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return false, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return false, err
	}

	count, err := parseCountResult(tableRows)
	if err != nil {
		return false, err
	}

	return count > 0, nil
}

// CheckHasOnlyOne reports whether exactly one row of tableName matches e.
//
// Conditions are built like Count's; the callback is consulted for columns
// whose mapping sets QueryCallback.
// NOTE(review): sibling helpers use a function-specific flag (CountCallback,
// CheckExistCallback); confirm QueryCallback is intended here.
func CheckHasOnlyOne[T any](executor SqlExecutor, tableName string, e T, callback ConditionCallback[T]) (bool, error) {
	if err := validateSqlArgs(executor, tableName, e); err != nil {
		return false, err
	}

	sqlMapping, err := tag.ParseSqlMapping(e)
	if err != nil {
		return false, err
	}

	executeParams := raw_sql_tpl.CountExecuteParams{
		TableName: tableName,
	}

	for fieldName, sqlMappingColumn := range sqlMapping.ColumnMap {
		conditionValue := sqlColumnValue(sqlMappingColumn.ValueFieldType, sqlMappingColumn.ValueFieldValue)
		conditionOp := "="

		if sqlMappingColumn.QueryCallback {
			retOp, retValue, err := applyConditionCallback(callback, e, fieldName, sqlMappingColumn.Name, conditionValue)
			if err != nil {
				return false, err
			}
			conditionOp, conditionValue = retOp, retValue
		}

		tableRowValue, err := parseValue(conditionValue)
		if err != nil {
			return false, err
		}

		executeParams.Conditions = append(executeParams.Conditions, raw_sql_tpl.Condition{
			Column:   sqlMappingColumn.Name,
			Operator: conditionOp,
			Value:    tableRowValue,
		})
	}

	executeParamsMap, err := executeParams.Map()
	if err != nil {
		return false, err
	}

	tableRows, err := executor.ExecuteRawSql(raw_sql_tpl.CountTpl, executeParamsMap)
	if err != nil {
		return false, err
	}

	count, err := parseCountResult(tableRows)
	if err != nil {
		return false, err
	}

	return count == 1, nil
}

// ExecuteRawSql validates its arguments and forwards sql and executeParams to
// executor.ExecuteRawSql. On error the returned rows are always nil.
func ExecuteRawSql(executor SqlExecutor, sql string, executeParams map[string]any) ([]map[string]any, error) {
	if executor == nil {
		return nil, errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(sql) {
		return nil, errors.New("没有sql")
	}

	tableRows, err := executor.ExecuteRawSql(sql, executeParams)
	if err != nil {
		return nil, err
	}

	return tableRows, nil
}

// ExecuteSql validates its arguments and forwards the SQL resource name and
// executeParams to executor.ExecuteSql. On error the returned rows are always nil.
func ExecuteSql(executor SqlExecutor, name string, executeParams map[string]any) ([]map[string]any, error) {
	if executor == nil {
		return nil, errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(name) {
		return nil, errors.New("没有sql资源名称")
	}

	tableRows, err := executor.ExecuteSql(name, executeParams)
	if err != nil {
		return nil, err
	}

	return tableRows, nil
}

// validateSqlArgs performs the argument checks shared by the table helpers:
// a non-nil executor, a non-empty table name and a non-nil entity.
func validateSqlArgs(executor SqlExecutor, tableName string, e any) error {
	if executor == nil {
		return errors.New("没有传递执行器")
	}

	if strutils.IsStringEmpty(tableName) {
		return errors.New("没有传递表名")
	}

	if reflect.TypeOf(e) == nil {
		return errors.New("没有传递实体")
	}

	return nil
}

// sqlColumnValue returns the field's current value, or the zero value of
// fieldType when the field is zero.
func sqlColumnValue(fieldType reflect.Type, fieldValue reflect.Value) any {
	if fieldValue.IsZero() {
		return reflect.Zero(fieldType).Interface()
	}
	return fieldValue.Interface()
}

// checkCallbackRetValue rejects nil and pointer values returned by a callback;
// parseValue only renders plain values.
func checkCallbackRetValue(value any) error {
	valueType := reflect.TypeOf(value)
	if valueType == nil || valueType.Kind() == reflect.Ptr {
		return errors.New("返回应当为值类型")
	}
	return nil
}

// applyValueCallback invokes a required ValueCallback and validates its result.
func applyValueCallback[T any](callback ValueCallback[T], e T, fieldName string, value any) (any, error) {
	if callback == nil {
		return nil, errors.New("需要使用回调函数但是没有传递回调函数")
	}

	retValue, err := callback(e, fieldName, value)
	if err != nil {
		return nil, err
	}

	if err := checkCallbackRetValue(retValue); err != nil {
		return nil, err
	}

	return retValue, nil
}

// applyConditionCallback invokes a required ConditionCallback and validates its result.
func applyConditionCallback[T any](callback ConditionCallback[T], e T, fieldName string, columnName string, value any) (string, any, error) {
	if callback == nil {
		return "", nil, errors.New("需要使用回调函数但是没有传递回调函数")
	}

	retConditionOp, retConditionValue, err := callback(e, fieldName, columnName, value)
	if err != nil {
		return "", nil, err
	}

	if err := checkCallbackRetValue(retConditionValue); err != nil {
		return "", nil, err
	}

	return retConditionOp, retConditionValue, nil
}

// stampIfZeroTime returns now when value is a zero time.Time and value
// unchanged otherwise. The comma-ok assertion keeps a callback that returned
// a non-time value from causing a panic.
func stampIfZeroTime(value any, now time.Time) any {
	if t, ok := value.(time.Time); ok && t.IsZero() {
		return now
	}
	return value
}

// parseCountResult extracts the "count" column from the first row returned by
// raw_sql_tpl.CountTpl. The executor presumably decodes numbers as float64
// (JSON style); other numeric kinds and decimal strings are accepted as well.
// An empty result or an unsupported type yields an error instead of a panic.
func parseCountResult(tableRows []map[string]any) (int64, error) {
	if len(tableRows) == 0 {
		return 0, errors.New("没有查询到计数结果")
	}

	switch count := tableRows[0]["count"].(type) {
	case float64:
		return int64(count), nil
	case float32:
		return int64(count), nil
	case int:
		return int64(count), nil
	case int32:
		return int64(count), nil
	case int64:
		return count, nil
	case uint32:
		return int64(count), nil
	case uint64:
		return int64(count), nil
	case string:
		return strconv.ParseInt(count, 10, 64)
	default:
		return 0, errors.New("计数结果类型不支持")
	}
}

// parseValue renders value as a SQL literal for the raw_sql_tpl templates.
//
// Strings are single-quoted with embedded single quotes doubled (standard SQL
// escaping) so a value cannot terminate the literal early. time.Time values
// are converted to UTC+08:00 (the offset hard-coded in timeWriteFormat) and
// quoted. Bools and integers are written unquoted. Any other type is rejected.
// NOTE(review): backslashes are not escaped; this assumes the database treats
// them literally inside '...' (e.g. PostgreSQL standard_conforming_strings=on).
func parseValue(value any) (string, error) {
	switch v := value.(type) {
	case string:
		return "'" + strings.ReplaceAll(v, "'", "''") + "'", nil
	case bool:
		return strconv.FormatBool(v), nil
	case time.Time:
		return "'" + v.In(timeWriteLocation).Format(timeWriteFormat) + "'", nil
	case int:
		return strconv.Itoa(v), nil
	case int8:
		return strconv.FormatInt(int64(v), 10), nil
	case int16:
		return strconv.FormatInt(int64(v), 10), nil
	case int32:
		return strconv.FormatInt(int64(v), 10), nil
	case int64:
		return strconv.FormatInt(v, 10), nil
	case uint:
		return strconv.FormatUint(uint64(v), 10), nil
	case uint8:
		return strconv.FormatUint(uint64(v), 10), nil
	case uint16:
		return strconv.FormatUint(uint64(v), 10), nil
	case uint32:
		return strconv.FormatUint(uint64(v), 10), nil
	case uint64:
		return strconv.FormatUint(v, 10), nil
	default:
		return "", errors.New("不支持的类型")
	}
}