yjp hace 1 año
padre
commit
2d15797d33

+ 8 - 7
convenient/domain/auth/auth.go

@@ -115,8 +115,8 @@ func (simple *Simple) init(urlPrefix string, i *infrastructure.Infrastructure) {
 	adminUserID := strutils.SimpleUUID()
 	adminRoleID := strutils.SimpleUUID()
 
-	permissionGroupEntities := make([]any, 0)
-	permissionEntities := make([]any, 0)
+	permissionGroupEntities := make([]permission_group.Entity, 0)
+	permissionEntities := make([]permission.Entity, 0)
 	permissionIDs := make([]string, 0)
 	permissionInGroup := make(map[string][]string)
 
@@ -179,13 +179,13 @@ func (simple *Simple) init(urlPrefix string, i *infrastructure.Infrastructure) {
 
 	err = database.Transaction(dbExecutor, func(tx database.Executor) error {
 		// 创建权限
-		err := database.InsertEntityBatch(tx, domain.TableName(simple.Schema, &permission.Entity{}), permissionEntities)
+		err := database.InsertEntity(tx, domain.TableName(simple.Schema, &permission.Entity{}), permissionEntities)
 		if err != nil {
 			return err
 		}
 
 		// 创建权限组
-		err = database.InsertEntityBatch(tx, domain.TableName(simple.Schema, &permission_group.Entity{}), permissionGroupEntities)
+		err = database.InsertEntity(tx, domain.TableName(simple.Schema, &permission_group.Entity{}), permissionGroupEntities)
 		if err != nil {
 			return err
 		}
@@ -388,9 +388,10 @@ func (simple *Simple) bind(binder *binding.Binder) {
 				}, nil
 			}
 
-			roleIDs := make([]string, len(roleIDResults))
-			for index, roleIDResult := range roleIDResults {
-				roleIDs[index] = roleIDResult.ColumnValueString(domain.RelationColumnName(&role.Entity{}))
+			roleIDs := make([]string, 0)
+			err = sql.ParseSqlResult(roleIDResults, &roleIDs)
+			if err != nil {
+				return errInfo, errors.New(err.Error())
 			}
 
 			roleResults, totalCount, err := database.Query(dbExecutor, &sql.QueryExecuteParams{

+ 5 - 2
convenient/domain/auth/middlewares/middlewares.go

@@ -92,8 +92,11 @@ func Authentication(dbSchema string, jwtSecretKey string) binding.Middleware {
 		}
 
 		roleIDs := make([]string, 0)
-		for _, roleIDResult := range roleIDResults {
-			roleIDs = append(roleIDs, roleIDResult.ColumnValueString(domain.RelationColumnName(&role.Entity{})))
+		err = sql.ParseSqlResult(roleIDResults, &roleIDs)
+		if err != nil {
+			respFunc(c, http.StatusUnauthorized, nil, errors.New(err.Error()))
+			c.Abort()
+			return
 		}
 
 		// 查找权限

+ 5 - 2
convenient/domain/configuration/api.go

@@ -83,8 +83,11 @@ func (simple *Simple) bind(binder *binding.Binder) {
 				}
 
 				values := make([]string, 0)
-				for _, result := range results {
-					values = append(values, result.ColumnValueString(ColumnValue))
+				err = sql.ParseSqlResult(results, &values)
+				if err != nil {
+					return map[string]any{
+						"values": make([]string, 0),
+					}, err
 				}
 
 				return map[string]any{

+ 3 - 2
convenient/relation/many2many/service.go

@@ -185,8 +185,9 @@ func Query[TI any](middleTableName string,
 		}
 
 		toIDs := make([]string, 0)
-		for _, toIDResult := range toIDResults {
-			toIDs = append(toIDs, toIDResult.ColumnValueString(toRelationColumnName))
+		err = sql.ParseSqlResult(toIDResults, &toIDs)
+		if err != nil {
+			return errResponse, err
 		}
 
 		toResults, _, err := database.Query(dbExecutor, &sql.QueryExecuteParams{

+ 6 - 4
convenient/relation/remote/service.go

@@ -191,8 +191,9 @@ func QueryToExist[TI any](middleTableName string,
 		}
 
 		toIDs := make([]string, 0)
-		for _, toIDResult := range toIDResults {
-			toIDs = append(toIDs, toIDResult.ColumnValueString(toRelationColumnName))
+		err = sql.ParseSqlResult(toIDResults, &toIDs)
+		if err != nil {
+			return errResponse, err
 		}
 
 		toResults, _, err := database.Query(dbExecutor, &sql.QueryExecuteParams{
@@ -281,8 +282,9 @@ func QueryToRemote(middleTableName string, fromRemote bool, fromTableName string
 		}
 
 		toIDs := make([]string, 0)
-		for _, toIDResult := range toIDResults {
-			toIDs = append(toIDs, toIDResult.ColumnValueString(toRelationColumnName))
+		err = sql.ParseSqlResult(toIDResults, &toIDs)
+		if err != nil {
+			return errResponse, err
 		}
 
 		return response.InfosData[string]{

+ 79 - 45
framework/core/infrastructure/database/sql/result.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"git.sxidc.com/go-framework/baize/framework/core/tag/sql/sql_result"
 	"git.sxidc.com/go-tools/utils/reflectutils"
+	"git.sxidc.com/go-tools/utils/strutils"
 	"github.com/pkg/errors"
 	"reflect"
 	"strings"
@@ -16,79 +17,112 @@ const (
 	resultTimeSecFormat   = "2006-01-02T15:04:05+08:00"
 )
 
-func ParseSqlResult(input any, e any) error {
+func ParseSqlResult(input any, output any) error {
+	return ParseSqlResultWithColumn(input, output, "")
+}
+
+func ParseSqlResultWithColumn(input any, output any, columnName string) error {
 	if input == nil {
 		return nil
 	}
 
-	if e == nil {
+	if output == nil {
 		return nil
 	}
 
-	results, ok := input.([]Result)
-	if !ok {
-		tableRow, ok := input.(Result)
-		if !ok {
-			return errors.New("输入数据应该为sdk.SqlResult或[]sdk.SqlResult")
-		}
+	typeCheckErr := errors.New("可以接受的输出类型为指针或者slice的指针")
+	outputType := reflect.TypeOf(output)
 
-		results = []Result{tableRow}
+	if outputType.Kind() != reflect.Pointer {
+		return typeCheckErr
 	}
 
-	typeCheckErr := errors.New("可以接受的输出类型为结构指针或者结构slice的指针")
-	outputValue := reflect.ValueOf(e)
-
-	if outputValue.Kind() != reflect.Pointer {
-		return typeCheckErr
-	} else {
-		outputElemValue := reflectutils.PointerValueElem(outputValue)
+	outputElemType := reflectutils.PointerTypeElem(outputType)
 
-		if outputElemValue.Kind() != reflect.Struct && outputElemValue.Kind() != reflect.Slice {
-			return typeCheckErr
+	// 输出不是slice,直接用result赋值即可
+	if outputElemType.Kind() != reflect.Slice {
+		result, ok := input.(Result)
+		if !ok {
+			return errors.New("输出不是slice,输入需要是sql.Result")
 		}
 
-		if outputElemValue.Kind() == reflect.Slice && !reflectutils.IsSliceValueOf(outputElemValue, reflect.Struct) {
-			return typeCheckErr
-		}
+		return parseSqlSingle(result, output, columnName)
 	}
 
-	outputElemValue := reflectutils.PointerValueElem(outputValue)
-
-	// 构造输出实体slice
-	var outputEntities reflect.Value
-	if outputElemValue.Kind() == reflect.Struct {
-		outputEntities = reflect.MakeSlice(reflect.SliceOf(outputElemValue.Type()), 0, 0)
-	} else {
-		outputEntities = reflect.MakeSlice(outputElemValue.Type(), 0, 0)
+	// 输出是slice,需要遍历处理
+	results, ok := input.([]Result)
+	if !ok {
+		return errors.New("输出是slice,输入需要是[]sql.Result")
 	}
 
+	outputEntities := reflect.MakeSlice(outputElemType, 0, 0)
+
 	for _, result := range results {
 		var outputEntityValue reflect.Value
-		if outputElemValue.Kind() == reflect.Struct {
-			outputEntityValue = reflect.New(outputElemValue.Type()).Elem()
-		} else {
-			outputEntityValue = reflect.New(outputElemValue.Type().Elem()).Elem()
-		}
-
-		outputEntity := outputEntityValue.Addr().Interface()
 
-		err := sql_result.DefaultUsage(result, outputEntity)
-		if err != nil {
-			return err
+		// slice子类型判断
+		if outputElemType.Elem().Kind() == reflect.Pointer {
+			outputEntityValue = reflect.New(outputElemType.Elem().Elem())
+			err := parseSqlSingle(result, outputEntityValue.Interface(), columnName)
+			if err != nil {
+				return err
+			}
+		} else {
+			outputEntityValue = reflect.New(outputElemType.Elem()).Elem()
+			err := parseSqlSingle(result, outputEntityValue.Addr().Interface(), columnName)
+			if err != nil {
+				return err
+			}
 		}
 
-		// 保存输出实体
 		outputEntities = reflect.Append(outputEntities, outputEntityValue)
 	}
 
-	// 将输出实体赋值给输出指针变量
-	if outputElemValue.Kind() == reflect.Slice {
-		outputElemValue.Set(outputEntities)
+	outputElemValue := reflectutils.PointerValueElem(reflect.ValueOf(output))
+	outputElemValue.Set(outputEntities)
+
+	return nil
+}
+
+func parseSqlSingle(result Result, output any, columnName string) error {
+	outputValue := reflectutils.PointerValueElem(reflect.ValueOf(output))
+	outputKind := reflectutils.GroupValueKind(outputValue)
+
+	var oneResultValue any
+	if strutils.IsStringEmpty(columnName) {
+		for _, value := range result {
+			oneResultValue = value
+			break
+		}
 	} else {
-		outputElemValue.Set(outputEntities.Index(0))
+		value, ok := result[columnName]
+		if !ok {
+			return errors.New("列不存在")
+		}
+
+		oneResultValue = value
 	}
 
-	return nil
+	switch outputKind {
+	case reflect.String:
+		return reflectutils.AssignStringValue(oneResultValue, outputValue)
+	case reflect.Bool:
+		return reflectutils.AssignBoolValue(oneResultValue, outputValue)
+	case reflect.Int64:
+		return reflectutils.AssignInt64Value(oneResultValue, outputValue)
+	case reflect.Uint64:
+		return reflectutils.AssignUint64Value(oneResultValue, outputValue)
+	case reflect.Float64:
+		return reflectutils.AssignFloat64Value(oneResultValue, outputValue)
+	case reflect.Struct:
+		if outputValue.Type().Name() == "time.Time" {
+			return errors.New("不支持转换到time.Time,请使用Result.ColumnValueStringAsTime")
+		}
+
+		return sql_result.DefaultUsage(result, output, columnName)
+	default:
+		return errors.New("不支持的类型")
+	}
 }
 
 type Result map[string]any

+ 7 - 3
framework/core/tag/sql/sql_result/usage.go

@@ -17,12 +17,12 @@ const (
 	timeSecFormat   = "2006-01-02T15:04:05"
 )
 
-func DefaultUsage(result map[string]any, e any) error {
+func DefaultUsage(result map[string]any, e any, columnName string) error {
 	if result == nil || len(result) == 0 {
 		return nil
 	}
 
-	err := UseTag(e, defaultCallback(result))
+	err := UseTag(e, defaultCallback(result, columnName))
 	if err != nil {
 		return err
 	}
@@ -30,8 +30,12 @@ func DefaultUsage(result map[string]any, e any) error {
 	return nil
 }
 
-func defaultCallback(result map[string]any) OnParsedFieldTagFunc {
+func defaultCallback(result map[string]any, columnName string) OnParsedFieldTagFunc {
 	return func(fieldName string, entityFieldElemValue reflect.Value, tag *Tag) error {
+		if strutils.IsStringNotEmpty(columnName) && columnName != tag.Name {
+			return nil
+		}
+
 		resultValue, ok := result[tag.Name]
 		if !ok {
 			return nil