yjp hace 11 meses
padre
commit
14654a506d
Se han modificado 3 ficheros con 55 adiciones y 78 borrados
  1. 3 6
      demo/sql_entity.go
  2. 48 63
      sql/parse_result.go
  3. 4 9
      test/sdk_test.go

+ 3 - 6
demo/sql_entity.go

@@ -43,9 +43,6 @@ func main() {
 		StudentIDs:    studentIDs,
 	}
 
-	classInfo := new(ClassInfo)
-	classInfos := make([]*ClassInfo, 0)
-
 	err = sql.InsertEntity(sdk.GetInstance(), tableName, class)
 	if err != nil {
 		panic(err)
@@ -59,7 +56,7 @@ func main() {
 		panic(err)
 	}
 
-	err = sql.ParseSqlResult(result, classInfo)
+	classInfo, err := sql.ParseSqlResult[Class](result)
 	if err != nil {
 		panic(err)
 	}
@@ -67,7 +64,7 @@ func main() {
 	fmt.Println("Class Info:")
 	fmt.Printf("%#+v\n", classInfo)
 
-	err = sql.ParseSqlResult(result, &classInfos)
+	classInfos, err := sql.ParseSqlResult[[]Class](result)
 	if err != nil {
 		panic(err)
 	}
@@ -88,7 +85,7 @@ func main() {
 		panic(err)
 	}
 
-	err = sql.ParseSqlResult(result, classInfo)
+	classInfo, err = sql.ParseSqlResult[Class](result)
 	if err != nil {
 		panic(err)
 	}

+ 48 - 63
sql/parse_result.go

@@ -12,42 +12,11 @@ import (
 	"time"
 )
 
-func ParseSqlResult(input any, output any) error {
-	if input == nil || output == nil {
-		return nil
-	}
-
-	// 输出的Type,可以是slice的指针或者是结构的指针
-	outputType := reflect.TypeOf(output)
-	if outputType.Kind() != reflect.Ptr {
-		return errors.New("输出实体应该为结构的slice或者是结构的指针")
-	}
-
-	// 取元素类型
-	if outputType.Kind() == reflect.Ptr {
-		outputType = outputType.Elem()
-	}
-
-	// 检查元素类型是否为slice或者结构
-	if outputType.Kind() != reflect.Slice && outputType.Kind() != reflect.Struct {
-		return errors.New("输出实体应该为结构的slice或者是结构的指针")
-	}
+func ParseSqlResult[T any](input any) (T, error) {
+	var zero T
 
-	// 如果输出类型为slice,则取slice元素类型
-	outputElemType := outputType
-	if outputElemType.Kind() == reflect.Slice {
-		outputElemType = outputElemType.Elem()
-	}
-
-	// 校验元素类型是否为结构类型
-	if outputElemType.Kind() == reflect.Ptr {
-		if outputElemType.Elem().Kind() != reflect.Struct {
-			return errors.New("输出实体slice元素应该为结构或者结构指针")
-		}
-	} else {
-		if outputElemType.Kind() != reflect.Struct {
-			return errors.New("输出实体slice应该为结构或者结构指针")
-		}
+	if input == nil {
+		return zero, nil
 	}
 
 	// 构造需要遍历的tableRows
@@ -55,59 +24,75 @@ func ParseSqlResult(input any, output any) error {
 	if !ok {
 		tableRow, ok := input.(sdk.SqlResult)
 		if !ok {
-			return errors.New("输入数据应该为[]sdk.SqlResult或[]sdk.SqlResult")
+			return zero, errors.New("输入数据应该为sdk.SqlResult或[]sdk.SqlResult")
 		}
 
 		tableRows = []sdk.SqlResult{tableRow}
 	}
 
-	// 构造输出实体slice
-	outputEntities := reflect.MakeSlice(reflect.SliceOf(outputElemType), 0, 0)
+	// 构造outputValue
+	typeCheckErr := errors.New("可以接受的类型为struct, *struct, []struct, []*struct")
+	outputType := reflect.TypeOf(zero)
+
+	fmt.Println("Output Type:", outputType.String())
+
+	if outputType.Kind() != reflect.Struct && outputType.Kind() != reflect.Ptr && outputType.Kind() != reflect.Slice {
+		return zero, typeCheckErr
+	} else if outputType.Kind() == reflect.Ptr && outputType.Elem().Kind() != reflect.Struct {
+		return zero, typeCheckErr
+	} else if outputType.Kind() == reflect.Slice &&
+		(outputType.Elem().Kind() != reflect.Struct && outputType.Elem().Kind() != reflect.Ptr) {
+		return zero, typeCheckErr
+	} else if outputType.Kind() == reflect.Slice &&
+		outputType.Elem().Kind() == reflect.Ptr && outputType.Elem().Elem().Kind() != reflect.Struct {
+		return zero, typeCheckErr
+	}
+
+	var outputValue reflect.Value
+	if outputType.Kind() == reflect.Struct || outputType.Kind() == reflect.Ptr {
+		outputValue = reflect.New(outputType).Elem()
+	} else {
+		outputValue = reflect.MakeSlice(outputType, 0, 0)
+	}
 
 	for _, tableRow := range tableRows {
 		// 构造输出实体
-		outputEntityValue := reflect.New(outputElemType).Elem()
-
 		var outputEntity any
-		if outputElemType.Kind() == reflect.Ptr {
-			outputEntityValue.Set(reflect.New(outputElemType.Elem()).Elem().Addr())
-			outputEntity = outputEntityValue.Interface()
+		if outputType.Kind() == reflect.Struct {
+			outputEntity = outputValue.Addr().Interface()
+		} else if outputType.Kind() == reflect.Ptr {
+			outputEntity = outputValue.Interface()
 		} else {
-			outputEntity = outputEntityValue.Addr().Interface()
+			if outputType.Elem().Kind() == reflect.Struct {
+				outputEntity = reflect.New(outputType.Elem()).Interface()
+			} else {
+				outputValueElemPtr := reflect.New(outputType.Elem()).Elem()
+				outputValueElemPtr.Set(reflect.New(outputType.Elem().Elem()))
+				outputEntity = outputValueElemPtr.Interface()
+			}
 		}
 
 		sqlResult, err := ParseSqlResultTag(outputEntity)
 		if err != nil {
-			return err
+			return zero, err
 		}
 
 		err = formOutputEntity(tableRow, sqlResult)
 		if err != nil {
-			return err
+			return zero, err
 		}
 
-		// 保存输出实体
-		if outputElemType.Kind() == reflect.Ptr {
-			outputEntities = reflect.Append(outputEntities, outputEntityValue)
-		} else {
-			outputEntities = reflect.Append(outputEntities, outputEntityValue)
+		if outputType.Kind() == reflect.Slice {
+			outputValue = reflect.Append(outputValue, reflect.ValueOf(outputEntity).Elem())
 		}
 	}
 
-	// 将输出实体赋值给输出指针变量
-	outputValue := reflect.Indirect(reflect.ValueOf(output))
-
-	if !outputValue.CanSet() {
-		return nil
-	}
-
-	if outputType.Kind() == reflect.Slice {
-		outputValue.Set(outputEntities)
-	} else {
-		outputValue.Set(outputEntities.Index(0))
+	output, ok := outputValue.Interface().(T)
+	if !ok {
+		return zero, errors.New("输出类型不匹配")
 	}
 
-	return nil
+	return output, nil
 }
 
 func formOutputEntity(tableRow sdk.SqlResult, sqlResult *Result) error {

+ 4 - 9
test/sdk_test.go

@@ -289,8 +289,7 @@ func TestRawSqlTemplate(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	classes := make([]Class, 0)
-	err = sql.ParseSqlResult(queryResults, &classes)
+	classes, err := sql.ParseSqlResult[[]Class](queryResults)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -328,8 +327,7 @@ func TestRawSqlTemplate(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	classes = make([]Class, 0)
-	err = sql.ParseSqlResult(queryResults, &classes)
+	classes, err = sql.ParseSqlResult[[]Class](queryResults)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -545,9 +543,6 @@ func TestSql(t *testing.T) {
 		Ignored:       "",
 	}
 
-	queryClasses := make([]Class, 0)
-	queryClass := new(Class)
-
 	err = sdk.InitInstance(token, address, httpPort, grpcPort, namespace, dataSource)
 	if err != nil {
 		t.Fatal(err)
@@ -724,7 +719,7 @@ func TestSql(t *testing.T) {
 		t.Fatal("总数不正确")
 	}
 
-	err = sql.ParseSqlResult(tableRows, &queryClasses)
+	queryClasses, err := sql.ParseSqlResult[[]Class](tableRows)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -751,7 +746,7 @@ func TestSql(t *testing.T) {
 		t.Fatal(err)
 	}
 
-	err = sql.ParseSqlResult(tableRow, queryClass)
+	queryClass, err := sql.ParseSqlResult[Class](tableRow)
 	if err != nil {
 		t.Fatal(err)
 	}