Browse Source

修改bug

yjp 1 year ago
parent
commit
2e964bd291
7 changed files with 194 additions and 34 deletions
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 4 0
      request.go
  4. 55 16
      sql_parser.go
  5. 27 2
      test_sql.go
  6. 62 11
      v1.go
  7. 43 5
      v1_test.go

+ 1 - 0
go.mod

@@ -7,6 +7,7 @@ require (
 	git.sxidc.com/service-supports/dps-sdk v1.10.1
 	github.com/auxten/postgresql-parser v1.0.1
 	github.com/go-resty/resty/v2 v2.11.0
+	github.com/satori/go.uuid v1.2.0
 )
 
 require (

+ 2 - 0
go.sum

@@ -253,6 +253,8 @@ github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZV
 github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
 github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
 github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
+github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
+github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
 github.com/sclevine/agouti v3.0.0+incompatible/go.mod h1:b4WX9W9L1sfQKXeJf1mUTLZKJ48R1S7H23Ji7oFO5Bw=
 github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
 github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=

+ 4 - 0
request.go

@@ -1,5 +1,9 @@
 package dpsapi
 
+import "git.sxidc.com/go-tools/api_binding/http_binding/binding_context"
+
+type OperatorIDFunc func(c *binding_context.Context) (string, error)
+
 type OperateParseRequest struct {
 	SQL string `json:"sql" binding:"required"`
 }

+ 55 - 16
sql_parser.go

@@ -11,10 +11,23 @@ import (
 	"time"
 )
 
+const (
+	clauseTableRowValueKindTime int = iota + 1
+	clauseTableRowValueKindBool
+	clauseTableRowValueKindString
+	clauseTableRowValueKindUint64
+	clauseTableRowValueKindFloat64
+)
+
+type clauseTableRowValue struct {
+	kind  int
+	value any
+}
+
 type insertClause struct {
 	table      string
 	keyColumns []string
-	tableRows  map[string]any
+	tableRows  map[string]clauseTableRowValue
 }
 
 type deleteClause struct {
@@ -98,7 +111,7 @@ func insertWalk(sql string) (*insertClause, error) {
 			clause.table = tableFmtCtx.String()
 
 			// 解析rows
-			values := make([]any, 0)
+			values := make([]clauseTableRowValue, 0)
 
 			rowsStatement, err := parser.Parse(realNode.Rows.String())
 			if err != nil {
@@ -113,31 +126,43 @@ func insertWalk(sql string) (*insertClause, error) {
 				switch rowsNode := node.(type) {
 				case *tree.FuncExpr:
 					// 函数类型
-					value, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
+					value, kind, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
 					if err != nil {
 						rowsWalkFuncErr = err
 						return true
 					}
 
-					values = append(values, value)
+					values = append(values, clauseTableRowValue{
+						kind:  kind,
+						value: value,
+					})
 				case *tree.DBool:
 					// 布尔类型
 					stringValue := rowsNode.String()
 					if !canSkipValueInThisContext(&funcStringParams, stringValue) {
 						if stringValue == "false" {
-							values = append(values, false)
+							values = append(values, clauseTableRowValue{
+								kind:  clauseTableRowValueKindBool,
+								value: false,
+							})
 						} else if stringValue == "true" {
-							values = append(values, true)
+							values = append(values, clauseTableRowValue{
+								kind:  clauseTableRowValueKindBool,
+								value: true,
+							})
 						} else {
 							rowsWalkFuncErr = errors.New("不支持的bool值")
 							return true
 						}
 					}
-				case *tree.UnresolvedName:
-					// 字符串类型或者函数参数类型,这里通过比较字符串value,排除了函数参数类型
+				case *tree.StrVal:
+					// 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型
 					stringValue := rowsNode.String()
 					if !canSkipValueInThisContext(&funcStringParams, stringValue) {
-						values = append(values, stringValue)
+						values = append(values, clauseTableRowValue{
+							kind:  clauseTableRowValueKindString,
+							value: rowsNode.RawString(),
+						})
 					}
 				case *tree.NumVal:
 					// 数值类型,可以是整形或浮点型
@@ -151,7 +176,10 @@ func insertWalk(sql string) (*insertClause, error) {
 								return true
 							}
 
-							values = append(values, valueUint64)
+							values = append(values, clauseTableRowValue{
+								kind:  clauseTableRowValueKindUint64,
+								value: valueUint64,
+							})
 						} else if numKind == constant.Float {
 							valueFloat64, err := strconv.ParseFloat(rowsNode.String(), 64)
 							if err != nil {
@@ -159,12 +187,18 @@ func insertWalk(sql string) (*insertClause, error) {
 								return true
 							}
 
-							values = append(values, valueFloat64)
+							values = append(values, clauseTableRowValue{
+								kind:  clauseTableRowValueKindFloat64,
+								value: valueFloat64,
+							})
 						} else {
 							rowsWalkFuncErr = errors.New("不支持的数值类型")
 							return true
 						}
 					}
+				case *tree.UnresolvedName:
+					rowsWalkFuncErr = errors.New("存在无法解析的名字,是否应该使用单引号")
+					return true
 				}
 
 				return false
@@ -185,7 +219,7 @@ func insertWalk(sql string) (*insertClause, error) {
 			hasIDColumn := false
 			keyColumns := make([]string, 0)
 			allColumns := make([]string, 0)
-			tableRows := make(map[string]any)
+			tableRows := make(map[string]clauseTableRowValue)
 
 			for i, column := range realNode.Columns.ToStrings() {
 				if column == "id" {
@@ -240,10 +274,10 @@ func selectWalk(sql string) (*insertClause, error) {
 	return nil, nil
 }
 
-func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, error) {
+func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, int, error) {
 	if strings.HasPrefix(expr.String(), "parse_time") {
 		if expr.Exprs == nil || len(expr.Exprs) != 2 {
-			return nil, errors.New("parse_time(time_str, time_format)")
+			return nil, 0, errors.New("parse_time(time_str, time_format)")
 		}
 
 		timeStr := expr.Exprs[0].String()
@@ -257,9 +291,14 @@ func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]str
 			*stringParams = append(*stringParams, timeStr, timeFormat)
 		}
 
-		return time.ParseInLocation(timeFormat, timeStr, time.Local)
+		parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local)
+		if err != nil {
+			return nil, 0, err
+		}
+
+		return parsedTime, clauseTableRowValueKindTime, nil
 	} else {
-		return nil, errors.New("不支持的函数")
+		return nil, 0, errors.New("不支持的函数")
 	}
 }
 

+ 27 - 2
test_sql.go

@@ -1,9 +1,34 @@
 package dpsapi
 
+import (
+	uuid "github.com/satori/go.uuid"
+	"strings"
+)
+
+const (
+	parseSqlInsert = `insert into students (id, name, age, rate, time, is_right) values ('aaa', 'yjp', 5, 92.5, parse_time('2024-01-01 00:00:00', '2006-01-02 15:04:05'), false)`
+)
+
 const (
-	parseSqlSelect = `insert into students (id, name, age, rate, time, is_right) values ("aaa", "yjp", 5, 92.5, parse_time("2024-01-01 00:00:00", "2006-01-02 15:04:05"), false)`
+	sqlInsertFormat = `insert into %s (id, name, time, table_num) values ('%s', '%s', parse_time('%s', '2006-01-02 15:04:05'), %d)`
 )
 
 const (
-	sqlSelect = `select * from t where id = 'aaa'`
+	dpsAddress     = "localhost:30170"
+	testDatabaseID = "ee2d7dabe56646ce835d80873348ee0e"
 )
+
+var tableModelDescribe = map[string]string{
+	"ID":       "gorm:\"primary_key;type:varchar(32);comment:id;\"",
+	"Name":     "gorm:\"not null;type:varchar(128);comment:数据库名称;\"",
+	"Time":     "gorm:\"not null;type:timestamp with time zone;comment:数据库时间;\"",
+	"TableNum": "gorm:\"not null;type:integer;comment:数据库表数量;\"",
+}
+
+func getUUID() string {
+	return uuid.NewV4().String()
+}
+
+func simpleUUID() string {
+	return strings.ReplaceAll(getUUID(), "-", "")
+}

+ 62 - 11
v1.go

@@ -9,9 +9,22 @@ import (
 	"git.sxidc.com/go-tools/api_binding/utils"
 	"git.sxidc.com/service-supports/dps-sdk"
 	"git.sxidc.com/service-supports/dps-sdk/client"
+	"time"
 )
 
-func ApiV1(binding *http_binding.Binding, dpsAddress string) {
+func ApiV1(binding *http_binding.Binding, dpsAddress string, operatorIDFunc OperatorIDFunc) {
+	if binding == nil {
+		panic("没有传递http_binding")
+	}
+
+	if utils.IsStringEmpty(dpsAddress) {
+		panic("没有指定dps地址")
+	}
+
+	if operatorIDFunc == nil {
+		panic("没有传递获取operatorID的回调函数")
+	}
+
 	http_binding.PostBind(binding, &http_binding.SimpleBindItem[OperateParseRequest, map[string]any]{
 		Path:         "/dpsv1/database/operate/parse",
 		ResponseFunc: response.SendMapResponse,
@@ -52,6 +65,11 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string) {
 		Path:         "/dpsv1/database/operate",
 		ResponseFunc: response.SendMsgResponse,
 		BusinessFunc: func(c *binding_context.Context, inputModel OperateRequest) (any, error) {
+			operatorID, err := operatorIDFunc(c)
+			if err != nil {
+				return nil, err
+			}
+
 			parsedClauses, err := parseSql(inputModel.SQL)
 			if err != nil {
 				return nil, err
@@ -66,13 +84,13 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string) {
 				for _, parsedClause := range parsedClauses {
 					switch clause := parsedClause.(type) {
 					case *insertClause:
-						return doInsert(inputModel, clause)
+						return doInsert(tx, inputModel, clause, operatorID)
 					case *deleteClause:
-						return doDelete(inputModel, clause)
+						return doDelete(tx, inputModel, clause, operatorID)
 					case *updateClause:
-						return doUpdate(inputModel, clause)
+						return doUpdate(tx, inputModel, clause, operatorID)
 					case *selectClause:
-						return doSelect(inputModel, clause)
+						return doSelect(dpsClient, inputModel, clause)
 					default:
 						return errors.New("不支持的SQL语句")
 					}
@@ -90,10 +108,15 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string) {
 }
 
 func insertMap(clause *insertClause) map[string]any {
+	tableRows := make(map[string]any)
+	for columnName, value := range clause.tableRows {
+		tableRows[columnName] = value.value
+	}
+
 	return map[string]any{
 		"table":       clause.table,
 		"key_columns": clause.keyColumns,
-		"table_rows":  clause.tableRows,
+		"table_rows":  tableRows,
 	}
 }
 
@@ -109,18 +132,46 @@ func selectMap(clause *selectClause) map[string]any {
 	return map[string]any{}
 }
 
-func doInsert(inputModel OperateRequest, clause *insertClause) error {
+func doInsert(tx client.Transaction, inputModel OperateRequest, clause *insertClause, operatorID string) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
 	}
 
-	fmt.Printf("%+#v\n", clause)
+	tableRow := client.NewTableRow()
+	for columnName, value := range clause.tableRows {
+		switch value.kind {
+		case clauseTableRowValueKindTime:
+			tableRow.AddColumnValueTime(columnName, value.value.(time.Time))
+		case clauseTableRowValueKindBool:
+			tableRow.AddColumnValueBool(columnName, value.value.(bool))
+		case clauseTableRowValueKindString:
+			tableRow.AddColumnValueString(columnName, value.value.(string))
+		case clauseTableRowValueKindUint64:
+			tableRow.AddColumnValueUint64(columnName, value.value.(uint64))
+		case clauseTableRowValueKindFloat64:
+			tableRow.AddColumnValueFloat64(columnName, value.value.(float64))
+		default:
+			return errors.New("不支持的值类型")
+		}
+	}
+
+	statement, err := tx.InsertTx(&client.InsertRequest{
+		TablePrefixWithSchema: clause.table,
+		Version:               inputModel.Version,
+		KeyColumns:            clause.keyColumns,
+		TableRow:              tableRow,
+		UserID:                operatorID,
+	})
+	if err != nil {
+		fmt.Println(statement)
+		return err
+	}
 
 	return nil
 }
 
-func doDelete(inputModel OperateRequest, clause *deleteClause) error {
+func doDelete(tx client.Transaction, inputModel OperateRequest, clause *deleteClause, operatorID string) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
@@ -129,7 +180,7 @@ func doDelete(inputModel OperateRequest, clause *deleteClause) error {
 	return nil
 }
 
-func doUpdate(inputModel OperateRequest, clause *updateClause) error {
+func doUpdate(tx client.Transaction, inputModel OperateRequest, clause *updateClause, operatorID string) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
@@ -138,7 +189,7 @@ func doUpdate(inputModel OperateRequest, clause *updateClause) error {
 	return nil
 }
 
-func doSelect(inputModel OperateRequest, clause *selectClause) error {
+func doSelect(dpsClient client.Client, inputModel OperateRequest, clause *selectClause) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"

+ 43 - 5
v1_test.go

@@ -1,16 +1,23 @@
 package dpsapi
 
 import (
+	"fmt"
 	"git.sxidc.com/go-tools/api_binding/http_binding"
+	"git.sxidc.com/go-tools/api_binding/http_binding/binding_context"
 	"git.sxidc.com/go-tools/api_binding/http_binding/response"
+	"git.sxidc.com/service-supports/dps-sdk"
+	"git.sxidc.com/service-supports/dps-sdk/client"
 	"github.com/go-resty/resty/v2"
+	"math/rand"
 	"testing"
+	"time"
 )
 
 func TestApiV1OperateParse(t *testing.T) {
 	http_binding.Init("test", "10086")
 	defer http_binding.Destroy()
 
+	operatorID := simpleUUID()
 	tableName := "students"
 	exceptedKeyColumns := []string{"id", "name", "age", "rate", "time", "is_right"}
 	exceptedTableRows := map[string]any{
@@ -23,9 +30,11 @@ func TestApiV1OperateParse(t *testing.T) {
 	}
 
 	binding := http_binding.NewBinding("v1")
-	ApiV1(binding, "localhost:30170")
+	ApiV1(binding, dpsAddress, func(c *binding_context.Context) (string, error) {
+		return operatorID, nil
+	})
 
-	parsed := operateParse(t, parseSqlSelect)
+	parsed := operateParse(t, parseSqlInsert)
 
 	if parsed["table"].(string) != tableName {
 		t.Fatal("表名不正确")
@@ -48,10 +57,39 @@ func TestApiV1Operate(t *testing.T) {
 	http_binding.Init("test", "10086")
 	defer http_binding.Destroy()
 
+	operatorID := simpleUUID()
+	tablePrefix := "test.a" + simpleUUID()[0:7]
+	id := simpleUUID()
+	name := simpleUUID()
+	now := time.Now().Local()
+	tableNum := rand.Intn(10) + 1
+
+	autoMigrate(t, []client.AutoMigrateItem{
+		{
+			TablePrefixWithSchema: tablePrefix,
+			Version:               "v1",
+			TableModelDescribe:    tableModelDescribe,
+		},
+	})
+
 	binding := http_binding.NewBinding("v1")
-	ApiV1(binding, "localhost:30170")
+	ApiV1(binding, dpsAddress, func(c *binding_context.Context) (string, error) {
+		return operatorID, nil
+	})
 
-	operate(t, sqlSelect)
+	operate(t, fmt.Sprintf(sqlInsertFormat, tablePrefix, id, name, now.Format(time.DateTime), tableNum))
+}
+
+func autoMigrate(t *testing.T, items []client.AutoMigrateItem) {
+	dpsClient, err := dps.NewClient(dpsAddress, "v1", testDatabaseID)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = dpsClient.AutoMigrate(&client.AutoMigrateRequest{Items: items})
+	if err != nil {
+		t.Fatal(err)
+	}
 }
 
 func operateParse(t *testing.T, sql string) map[string]any {
@@ -86,7 +124,7 @@ func operate(t *testing.T, sql string) {
 
 	resp, err := resty.New().R().
 		SetBody(&OperateRequest{
-			DatabaseID: "ee2d7dabe56646ce835d80873348ee0e",
+			DatabaseID: testDatabaseID,
 			Version:    "v1",
 			SQL:        sql,
 		}).