|
|
@@ -1,14 +1,15 @@
|
|
|
package sql
|
|
|
|
|
|
import (
|
|
|
+ "reflect"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+
|
|
|
"git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
|
|
|
"git.sxidc.com/go-framework/baize/framework/core/tag/sql/sql_result"
|
|
|
"git.sxidc.com/go-tools/utils/reflectutils"
|
|
|
"git.sxidc.com/go-tools/utils/strutils"
|
|
|
"github.com/pkg/errors"
|
|
|
- "reflect"
|
|
|
- "strings"
|
|
|
- "time"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -272,6 +273,8 @@ func (result Result) ColumnValueFloat64(columnName string) float64 {
|
|
|
return value
|
|
|
}
|
|
|
|
|
|
+type PostDealFunc func(output any) error
|
|
|
+
|
|
|
// ParseSqlResult 解析查询结果
|
|
|
// 参数:
|
|
|
// - input: sql.Result或者[]sql.Result类型的查询结果
|
|
|
@@ -279,7 +282,17 @@ func (result Result) ColumnValueFloat64(columnName string) float64 {
|
|
|
// 返回值:
|
|
|
// - 错误
|
|
|
func ParseSqlResult(input any, output any) error {
|
|
|
- return ParseSqlResultWithColumn(input, output, "")
|
|
|
+ return ParseSqlResultFunc(input, output, nil)
|
|
|
+}
|
|
|
+
|
|
|
+// ParseSqlResultFunc 解析查询结果,可传递后处理函数
|
|
|
+// 参数:
|
|
|
+// - input: sql.Result或者[]sql.Result类型的查询结果
|
|
|
+// - output: 接收查询结果的指针,如果是结构,需要使用sqlresult tag标注字段
|
|
|
+// 返回值:
|
|
|
+// - 错误
|
|
|
+func ParseSqlResultFunc(input any, output any, postDealFunc PostDealFunc) error {
|
|
|
+ return ParseSqlResultWithColumnFunc(input, output, "", postDealFunc)
|
|
|
}
|
|
|
|
|
|
// ParseSqlResultWithColumn 取查询结果列解析结果
|
|
|
@@ -290,6 +303,17 @@ func ParseSqlResult(input any, output any) error {
|
|
|
// 返回值:
|
|
|
// - 错误
|
|
|
func ParseSqlResultWithColumn(input any, output any, columnName string) error {
|
|
|
+ return ParseSqlResultWithColumnFunc(input, output, columnName, nil)
|
|
|
+}
|
|
|
+
|
|
|
+// ParseSqlResultWithColumnFunc 取查询结果列解析结果,可传递后处理函数
|
|
|
+// 参数:
|
|
|
+// - input: sql.Result或者[]sql.Result类型的查询结果
|
|
|
+// - output: 接收查询结果的指针,如果是结构,需要使用sqlresult tag标注字段
|
|
|
+// - columnName: 获取的列名
|
|
|
+// 返回值:
|
|
|
+// - 错误
|
|
|
+func ParseSqlResultWithColumnFunc(input any, output any, columnName string, postDealFunc PostDealFunc) error {
|
|
|
if input == nil {
|
|
|
return nil
|
|
|
}
|
|
|
@@ -314,7 +338,7 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
|
|
|
return errors.New("输出不是slice,输入需要是sql.Result")
|
|
|
}
|
|
|
|
|
|
- return parseSqlSingle(result, output, columnName)
|
|
|
+ return parseSqlSingleFunc(result, output, columnName, postDealFunc)
|
|
|
}
|
|
|
|
|
|
// 输出是slice,需要遍历处理
|
|
|
@@ -331,13 +355,13 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
|
|
|
// slice子类型判断
|
|
|
if outputElemType.Elem().Kind() == reflect.Pointer {
|
|
|
outputEntityValue = reflect.New(outputElemType.Elem().Elem())
|
|
|
- err := parseSqlSingle(result, outputEntityValue.Interface(), columnName)
|
|
|
+ err := parseSqlSingleFunc(result, outputEntityValue.Interface(), columnName, postDealFunc)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
} else {
|
|
|
outputEntityValue = reflect.New(outputElemType.Elem()).Elem()
|
|
|
- err := parseSqlSingle(result, outputEntityValue.Addr().Interface(), columnName)
|
|
|
+ err := parseSqlSingleFunc(result, outputEntityValue.Addr().Interface(), columnName, postDealFunc)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -352,7 +376,7 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func parseSqlSingle(result Result, output any, columnName string) error {
|
|
|
+func parseSqlSingleFunc(result Result, output any, columnName string, postDealFunc PostDealFunc) error {
|
|
|
outputValue := reflectutils.PointerValueElem(reflect.ValueOf(output))
|
|
|
outputKind := reflectutils.GroupValueKind(outputValue)
|
|
|
|
|
|
@@ -373,22 +397,49 @@ func parseSqlSingle(result Result, output any, columnName string) error {
|
|
|
|
|
|
switch outputKind {
|
|
|
case reflect.String:
|
|
|
- return reflectutils.AssignStringValue(oneResultValue, outputValue)
|
|
|
+ err := reflectutils.AssignStringValue(oneResultValue, outputValue)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
case reflect.Bool:
|
|
|
- return reflectutils.AssignBoolValue(oneResultValue, outputValue)
|
|
|
+ err := reflectutils.AssignBoolValue(oneResultValue, outputValue)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
case reflect.Int64:
|
|
|
- return reflectutils.AssignInt64Value(oneResultValue, outputValue)
|
|
|
+ err := reflectutils.AssignInt64Value(oneResultValue, outputValue)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
case reflect.Uint64:
|
|
|
- return reflectutils.AssignUint64Value(oneResultValue, outputValue)
|
|
|
+ err := reflectutils.AssignUint64Value(oneResultValue, outputValue)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
case reflect.Float64:
|
|
|
- return reflectutils.AssignFloat64Value(oneResultValue, outputValue)
|
|
|
+ err := reflectutils.AssignFloat64Value(oneResultValue, outputValue)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
case reflect.Struct:
|
|
|
if outputValue.Type().Name() == "time.Time" {
|
|
|
return errors.New("不支持转换到time.Time,请使用Result.ColumnValueStringAsTime")
|
|
|
}
|
|
|
|
|
|
- return sql_result.DefaultUsage(result, output, columnName)
|
|
|
+ err := sql_result.DefaultUsage(result, output, columnName)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
default:
|
|
|
return errors.New("不支持的类型")
|
|
|
}
|
|
|
+
|
|
|
+ if postDealFunc != nil {
|
|
|
+ err := postDealFunc(output)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
}
|