Explorar el Código

添加sqlresult注解

yjp hace 1 año
padre
commit
a307b78cf7
Se han modificado 3 ficheros con 143 adiciones y 15 borrados
  1. 8 8
      sdk/tag/sql_mapping.go
  2. 94 1
      sdk/tag/sql_result.go
  3. 41 6
      test/sdk_test.go

+ 8 - 8
sdk/tag/sql_mapping.go

@@ -26,7 +26,7 @@ const (
 )
 
 type SqlMapping struct {
-	ColumnMap map[string]SqlColumn
+	ColumnMap map[string]SqlMappingColumn
 }
 
 func ParseSqlMapping(e any) (*SqlMapping, error) {
@@ -44,26 +44,26 @@ func ParseSqlMapping(e any) (*SqlMapping, error) {
 	}
 
 	sqlMapping := new(SqlMapping)
-	sqlMapping.ColumnMap = make(map[string]SqlColumn)
+	sqlMapping.ColumnMap = make(map[string]SqlMappingColumn)
 
 	fieldNum := entityType.NumField()
 	for i := 0; i < fieldNum; i++ {
-		sqlColumn, err := parseSqlColumn(entityType.Field(i))
+		column, err := parseSqlMappingColumn(entityType.Field(i))
 		if err != nil {
 			return nil, err
 		}
 
-		if sqlColumn == nil {
+		if column == nil {
 			continue
 		}
 
-		sqlMapping.ColumnMap[sqlColumn.Name] = *sqlColumn
+		sqlMapping.ColumnMap[column.Name] = *column
 	}
 
 	return sqlMapping, nil
 }
 
-type SqlColumn struct {
+type SqlMappingColumn struct {
 	Name           string
 	IsKey          bool
 	CanUpdate      bool
@@ -74,8 +74,8 @@ type SqlColumn struct {
 	QueryCallback  bool
 }
 
-func parseSqlColumn(field reflect.StructField) (*SqlColumn, error) {
-	sqlColumn := &SqlColumn{
+func parseSqlMappingColumn(field reflect.StructField) (*SqlMappingColumn, error) {
+	sqlColumn := &SqlMappingColumn{
 		Name:           strcase.ToSnake(field.Name),
 		IsKey:          false,
 		CanUpdate:      true,

+ 94 - 1
sdk/tag/sql_result.go

@@ -1,5 +1,98 @@
 package tag
 
+import (
+	"errors"
+	"github.com/iancoleman/strcase"
+	"reflect"
+	"strings"
+)
+
+const (
+	sqlResultTagPartSeparator         = ";"
+	sqlResultTagPartKeyValueSeparator = ":"
+)
+
+const (
+	sqlResultTagKey   = "sqlresult"
+	sqlResultColumn   = "column"
+	sqlResultIgnore   = "-"
+	sqlResultCallback = "callback"
+)
+
 type SqlResult struct {
-	ColumnMap map[string]SqlColumn
+	ColumnMap map[string]SqlResultColumn
+}
+
+func ParseSqlResult(e any) (*SqlResult, error) {
+	if e == nil {
+		return nil, errors.New("没有传递实体")
+	}
+
+	entityType := reflect.TypeOf(e)
+	if entityType.Kind() == reflect.Ptr {
+		entityType = entityType.Elem()
+	}
+
+	if entityType.Kind() != reflect.Struct {
+		return nil, errors.New("传递的不是实体结构")
+	}
+
+	sqlResult := new(SqlResult)
+	sqlResult.ColumnMap = make(map[string]SqlResultColumn)
+
+	fieldNum := entityType.NumField()
+	for i := 0; i < fieldNum; i++ {
+		column, err := parseSqlResultColumn(entityType.Field(i))
+		if err != nil {
+			return nil, err
+		}
+
+		if column == nil {
+			continue
+		}
+
+		sqlResult.ColumnMap[column.Name] = *column
+	}
+
+	return sqlResult, nil
+}
+
+type SqlResultColumn struct {
+	Name             string
+	ResultColumnName string
+	Callback         bool
+}
+
+func parseSqlResultColumn(field reflect.StructField) (*SqlResultColumn, error) {
+	sqlColumn := &SqlResultColumn{
+		Name:             strcase.ToSnake(field.Name),
+		ResultColumnName: strcase.ToSnake(field.Name),
+		Callback:         false,
+	}
+
+	sqlResultTag, ok := field.Tag.Lookup(sqlResultTagKey)
+	if !ok {
+		return sqlColumn, nil
+	}
+
+	if sqlResultTag == sqlResultIgnore {
+		return nil, nil
+	}
+
+	sqlResultParts := strings.Split(sqlResultTag, sqlResultTagPartSeparator)
+	if sqlResultParts != nil || len(sqlResultParts) != 0 {
+		for _, sqlResultPart := range sqlResultParts {
+			sqlPartKeyValue := strings.Split(strings.TrimSpace(sqlResultPart), sqlResultTagPartKeyValueSeparator)
+			switch sqlPartKeyValue[0] {
+			case sqlResultColumn:
+				sqlColumn.ResultColumnName = strings.TrimSpace(sqlPartKeyValue[1])
+			case sqlResultCallback:
+				sqlColumn.Callback = true
+			default:
+				continue
+			}
+		}
+	}
+
+	return sqlColumn, nil
 }

+ 41 - 6
test/sdk_test.go

@@ -14,13 +14,13 @@ import (
 )
 
 type Class struct {
-	ID              string `sqlmapping:"key;"`
-	Name            string `sqlmapping:"update:canClear;notQuery;insertCallback;updateCallback;"`
-	StudentNum      int    `sqlmapping:"column:student_num;notUpdate;queryCallback;"`
-	GraduatedTime   time.Time
+	ID              string    `sqlmapping:"key;"`
+	Name            string    `sqlmapping:"update:canClear;notQuery;insertCallback;updateCallback;"`
+	StudentNum      int       `sqlmapping:"column:student_num;notUpdate;queryCallback;" sqlresult:"column:student_num_alias;"`
+	GraduatedTime   time.Time `sqlresult:"callback"`
 	CreatedTime     *time.Time
 	LastUpdatedTime time.Time
-	Ignored         string `sqlmapping:"-"`
+	Ignored         string `sqlmapping:"-" sqlresult:"-"`
 }
 
 const (
@@ -392,7 +392,7 @@ func TestRawSqlTemplate(t *testing.T) {
 	}
 }
 
-func TestDataMapping(t *testing.T) {
+func TestSqlMapping(t *testing.T) {
 	sqlMapping, err := tag.ParseSqlMapping(&Class{})
 	if err != nil {
 		t.Fatal(err)
@@ -444,3 +444,38 @@ func TestDataMapping(t *testing.T) {
 		}
 	}
 }
+
+func TestSqlResult(t *testing.T) {
+	sqlResult, err := tag.ParseSqlResult(&Class{})
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for columnName, sqlColumn := range sqlResult.ColumnMap {
+		if columnName != "id" && columnName != "name" &&
+			columnName != "student_num" && columnName != "graduated_time" &&
+			columnName != "created_time" && columnName != "last_updated_time" {
+			t.Fatal("列名不正确")
+		}
+
+		if sqlColumn.Name != "id" && sqlColumn.Name != "name" &&
+			sqlColumn.Name != "student_num" && columnName != "graduated_time" &&
+			columnName != "created_time" && columnName != "last_updated_time" {
+			t.Fatal("列名不正确")
+		}
+
+		if sqlColumn.Name != columnName {
+			t.Fatal("列名不正确")
+		}
+
+		if columnName == "student_num" {
+			if sqlColumn.ResultColumnName != "student_num_alias" {
+				t.Fatal("结果列名不正确")
+			}
+		}
+
+		if sqlColumn.Callback && columnName != "graduated_time" {
+			t.Fatal("回调不正确")
+		}
+	}
+}