yjp 11 месяцев назад
Родитель
Сommit
520e8d7c81
3 измененных файлов с 33 добавлено и 9 удалено
  1. 11 2
      demo/sql_entity.go
  2. 0 1
      demo/sql_result.go
  3. 22 6
      sql/parse_result.go

+ 11 - 2
demo/sql_entity.go

@@ -44,6 +44,7 @@ func main() {
 	}
 
 	classInfo := new(ClassInfo)
+	classInfos := make([]*ClassInfo, 0)
 
 	err = sql.InsertEntity(sdk.GetInstance(), tableName, class)
 	if err != nil {
@@ -64,7 +65,15 @@ func main() {
 	}
 
 	fmt.Println("Class Info:")
-	fmt.Printf("%#+v", classInfo)
+	fmt.Printf("%#+v\n", classInfo)
+
+	err = sql.ParseSqlResult(result, &classInfos)
+	if err != nil {
+		panic(err)
+	}
+
+	fmt.Println("Class Info:")
+	fmt.Printf("%#+v\n", classInfos[0])
 
 	err = sql.UpdateEntity(sdk.GetInstance(), tableName, newClass)
 	if err != nil {
@@ -85,7 +94,7 @@ func main() {
 	}
 
 	fmt.Println("Class Info:")
-	fmt.Printf("%#+v", classInfo)
+	fmt.Printf("%#+v\n", classInfo)
 
 	err = sql.DeleteEntity(sdk.GetInstance(), tableName, &Class{IDField: IDField{ID: classID}})
 	if err != nil {

+ 0 - 1
demo/sql_result.go

@@ -16,7 +16,6 @@ func main() {
 }
 
 func printSqlResult(sqlResult *sql.Result) {
-
 	for fieldName, resultElement := range sqlResult.ResultElement {
 		fmt.Println("---------------------------------------")
 		fmt.Println("Field Name: " + fieldName)

+ 22 - 6
sql/parse_result.go

@@ -40,8 +40,14 @@ func ParseSqlResult(input any, output any) error {
 	}
 
 	// 校验元素类型是否为结构类型
-	if outputElemType.Kind() != reflect.Struct {
-		return errors.New("输出实体slice应该为结构的slice指针")
+	if outputElemType.Kind() == reflect.Ptr {
+		if outputElemType.Elem().Kind() != reflect.Struct {
+			return errors.New("输出实体slice元素应该为结构或者结构指针")
+		}
+	} else {
+		if outputElemType.Kind() != reflect.Struct {
+			return errors.New("输出实体slice应该为结构或者结构指针")
+		}
 	}
 
 	// 构造需要遍历的tableRows
@@ -60,8 +66,15 @@ func ParseSqlResult(input any, output any) error {
 
 	for _, tableRow := range tableRows {
 		// 构造输出实体
-		outputEntityValue := reflect.New(outputElemType).Elem().Addr()
-		outputEntity := outputEntityValue.Interface()
+		outputEntityValue := reflect.New(outputElemType).Elem()
+
+		var outputEntity any
+		if outputElemType.Kind() == reflect.Ptr {
+			outputEntityValue.Set(reflect.New(outputElemType.Elem()).Elem().Addr())
+			outputEntity = outputEntityValue.Interface()
+		} else {
+			outputEntity = outputEntityValue.Addr().Interface()
+		}
 
 		sqlResult, err := ParseSqlResultTag(outputEntity)
 		if err != nil {
@@ -74,7 +87,11 @@ func ParseSqlResult(input any, output any) error {
 		}
 
 		// 保存输出实体
-		outputEntities = reflect.Append(outputEntities, outputEntityValue.Elem())
+		if outputElemType.Kind() == reflect.Ptr {
+			outputEntities = reflect.Append(outputEntities, outputEntityValue)
+		} else {
+			outputEntities = reflect.Append(outputEntities, outputEntityValue)
+		}
 	}
 
 	// 将输出实体赋值给输出指针变量
@@ -94,7 +111,6 @@ func ParseSqlResult(input any, output any) error {
 }
 
 func formOutputEntity(tableRow sdk.SqlResult, sqlResult *Result) error {
-
 	for fieldName, resultElement := range sqlResult.ResultElement {
 		switch element := resultElement.(type) {
 		case *Result: