Browse Source

添加sql处理函数

yjp 1 ngày trước cách đây
mục cha
commit
c8a25fd90f
1 tập tin đã thay đổi với 65 bổ sung14 xóa
  1. 65 14
      framework/core/infrastructure/database/sql/result.go

+ 65 - 14
framework/core/infrastructure/database/sql/result.go

@@ -1,14 +1,15 @@
 package sql
 
 import (
+	"reflect"
+	"strings"
+	"time"
+
 	"git.sxidc.com/go-framework/baize/framework/core/infrastructure/logger"
 	"git.sxidc.com/go-framework/baize/framework/core/tag/sql/sql_result"
 	"git.sxidc.com/go-tools/utils/reflectutils"
 	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/pkg/errors"
-	"reflect"
-	"strings"
-	"time"
 )
 
 const (
@@ -272,6 +273,8 @@ func (result Result) ColumnValueFloat64(columnName string) float64 {
 	return value
 }
 
+type PostDealFunc func(output any) error
+
 // ParseSqlResult 解析查询结果
 // 参数:
 // - input: sql.Result或者[]sql.Result类型的查询结果
@@ -279,7 +282,17 @@ func (result Result) ColumnValueFloat64(columnName string) float64 {
 // 返回值:
 // - 错误
 func ParseSqlResult(input any, output any) error {
-	return ParseSqlResultWithColumn(input, output, "")
+	return ParseSqlResultFunc(input, output, nil)
+}
+
+// ParseSqlResultFunc 解析查询结果,可传递后处理函数
+// 参数:
+// - input: sql.Result或者[]sql.Result类型的查询结果
+// - output: 接收查询结果的指针,如果是结构,需要使用sqlresult tag标注字段
+// 返回值:
+// - 错误
+func ParseSqlResultFunc(input any, output any, postDealFunc PostDealFunc) error {
+	return ParseSqlResultWithColumnFunc(input, output, "", postDealFunc)
 }
 
 // ParseSqlResultWithColumn 取查询结果列解析结果
@@ -290,6 +303,17 @@ func ParseSqlResult(input any, output any) error {
 // 返回值:
 // - 错误
 func ParseSqlResultWithColumn(input any, output any, columnName string) error {
+	return ParseSqlResultWithColumnFunc(input, output, columnName, nil)
+}
+
+// ParseSqlResultWithColumnFunc 取查询结果列解析结果,可传递后处理函数
+// 参数:
+// - input: sql.Result或者[]sql.Result类型的查询结果
+// - output: 接收查询结果的指针,如果是结构,需要使用sqlresult tag标注字段
+// - columnName: 获取的列名
+// 返回值:
+// - 错误
+func ParseSqlResultWithColumnFunc(input any, output any, columnName string, postDealFunc PostDealFunc) error {
 	if input == nil {
 		return nil
 	}
@@ -314,7 +338,7 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
 			return errors.New("输出不是slice,输入需要是sql.Result")
 		}
 
-		return parseSqlSingle(result, output, columnName)
+		return parseSqlSingleFunc(result, output, columnName, postDealFunc)
 	}
 
 	// 输出是slice,需要遍历处理
@@ -331,13 +355,13 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
 		// slice子类型判断
 		if outputElemType.Elem().Kind() == reflect.Pointer {
 			outputEntityValue = reflect.New(outputElemType.Elem().Elem())
-			err := parseSqlSingle(result, outputEntityValue.Interface(), columnName)
+			err := parseSqlSingleFunc(result, outputEntityValue.Interface(), columnName, postDealFunc)
 			if err != nil {
 				return err
 			}
 		} else {
 			outputEntityValue = reflect.New(outputElemType.Elem()).Elem()
-			err := parseSqlSingle(result, outputEntityValue.Addr().Interface(), columnName)
+			err := parseSqlSingleFunc(result, outputEntityValue.Addr().Interface(), columnName, postDealFunc)
 			if err != nil {
 				return err
 			}
@@ -352,7 +376,7 @@ func ParseSqlResultWithColumn(input any, output any, columnName string) error {
 	return nil
 }
 
-func parseSqlSingle(result Result, output any, columnName string) error {
+func parseSqlSingleFunc(result Result, output any, columnName string, postDealFunc PostDealFunc) error {
 	outputValue := reflectutils.PointerValueElem(reflect.ValueOf(output))
 	outputKind := reflectutils.GroupValueKind(outputValue)
 
@@ -373,22 +397,49 @@ func parseSqlSingle(result Result, output any, columnName string) error {
 
 	switch outputKind {
 	case reflect.String:
-		return reflectutils.AssignStringValue(oneResultValue, outputValue)
+		err := reflectutils.AssignStringValue(oneResultValue, outputValue)
+		if err != nil {
+			return err
+		}
 	case reflect.Bool:
-		return reflectutils.AssignBoolValue(oneResultValue, outputValue)
+		err := reflectutils.AssignBoolValue(oneResultValue, outputValue)
+		if err != nil {
+			return err
+		}
 	case reflect.Int64:
-		return reflectutils.AssignInt64Value(oneResultValue, outputValue)
+		err := reflectutils.AssignInt64Value(oneResultValue, outputValue)
+		if err != nil {
+			return err
+		}
 	case reflect.Uint64:
-		return reflectutils.AssignUint64Value(oneResultValue, outputValue)
+		err := reflectutils.AssignUint64Value(oneResultValue, outputValue)
+		if err != nil {
+			return err
+		}
 	case reflect.Float64:
-		return reflectutils.AssignFloat64Value(oneResultValue, outputValue)
+		err := reflectutils.AssignFloat64Value(oneResultValue, outputValue)
+		if err != nil {
+			return err
+		}
 	case reflect.Struct:
 		if outputValue.Type().Name() == "time.Time" {
 			return errors.New("不支持转换到time.Time,请使用Result.ColumnValueStringAsTime")
 		}
 
-		return sql_result.DefaultUsage(result, output, columnName)
+		err := sql_result.DefaultUsage(result, output, columnName)
+		if err != nil {
+			return err
+		}
 	default:
 		return errors.New("不支持的类型")
 	}
+
+	if postDealFunc != nil {
+		err := postDealFunc(output)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
 }