Browse Source

完成select解析

yjp 1 year ago
parent
commit
048b237d0d
7 changed files with 365 additions and 48 deletions
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 5 1
      request.go
  4. 217 35
      sql_parser.go
  5. 4 0
      test_sql.go
  6. 65 6
      v1.go
  7. 71 3
      v1_test.go

+ 1 - 1
go.mod

@@ -4,7 +4,7 @@ go 1.21.3
 
 require (
 	git.sxidc.com/go-tools/api_binding v1.3.22
-	git.sxidc.com/service-supports/dps-sdk v1.10.0
+	git.sxidc.com/service-supports/dps-sdk v1.10.1
 	github.com/auxten/postgresql-parser v1.0.1
 	github.com/go-resty/resty/v2 v2.11.0
 )

+ 2 - 2
go.sum

@@ -2,8 +2,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 git.sxidc.com/go-tools/api_binding v1.3.22 h1:4LPdcClfqM2bmCSrCe6HceSmAwr3rGhfxjVWn/4rliY=
 git.sxidc.com/go-tools/api_binding v1.3.22/go.mod h1:JoPU2jtPwbsAEjAuiSedKxuwu3bK4rrkZxyQ3mkU0XI=
-git.sxidc.com/service-supports/dps-sdk v1.10.0 h1:kNmeGD54NiTfxRFLdOaaAMhCH+a3NY3CMq1+FkmBP9E=
-git.sxidc.com/service-supports/dps-sdk v1.10.0/go.mod h1:bR7PtL4x4QKc8ZRbszn8hLBaK6G/uZl4ZbU7/TZcJ94=
+git.sxidc.com/service-supports/dps-sdk v1.10.1 h1:Wk6Ruatn5GHgBYvaNzr27GQaLuAgHB2TD9tFTgz0lpE=
+git.sxidc.com/service-supports/dps-sdk v1.10.1/go.mod h1:bR7PtL4x4QKc8ZRbszn8hLBaK6G/uZl4ZbU7/TZcJ94=
 git.sxidc.com/service-supports/fserr v0.3.2 h1:5/FCr8o2jd1kNsp5tH/ADjB9fr091JZXMMZ15ZvNZzs=
 git.sxidc.com/service-supports/fserr v0.3.2/go.mod h1:W54RoA71mfex+zARuH/iMnQPMnBXQ23qXXOkwUh2sVQ=
 git.sxidc.com/service-supports/fslog v0.5.9 h1:q2XIK2o/fk/qmByy4x5kKLC+k7kolT5LrXHcWRSffXQ=

+ 5 - 1
request.go

@@ -1,6 +1,10 @@
 package dpsapi
 
-type OperateFromRequest struct {
+type OperateParseRequest struct {
+	SQL string `json:"sql" binding:"required"`
+}
+
+type OperateRequest struct {
 	DatabaseID string `json:"databaseId" binding:"required"`
 	Version    string `json:"version"`
 	SQL        string `json:"sql" binding:"required"`

+ 217 - 35
sql_parser.go

@@ -1,13 +1,20 @@
 package dpsapi
 
 import (
+	"errors"
 	"github.com/auxten/postgresql-parser/pkg/sql/parser"
+	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
 	"github.com/auxten/postgresql-parser/pkg/walk"
+	"go/constant"
+	"strconv"
 	"strings"
+	"time"
 )
 
 type insertClause struct {
-	into string
+	table      string
+	keyColumns []string
+	tableRows  map[string]any
 }
 
 type deleteClause struct {
@@ -30,38 +37,39 @@ func parseSql(sqlStr string) ([]any, error) {
 	sqlClauses := make([]any, 0)
 
 	for _, sql := range sqls {
-		stmts, err := parser.Parse(sql)
-		if err != nil {
-			return nil, err
-		}
-
 		trimSQL := strings.TrimSpace(sql)
 		upperTrimSQL := strings.ToUpper(trimSQL)
 
 		var clause any
-		w := new(walk.AstWalker)
 
 		if strings.HasPrefix(upperTrimSQL, "INSERT") {
-			c := new(insertClause)
-			clause = c
-			w.Fn = insertWalkFunc(c)
+			innerClause, err := insertWalk(sql)
+			if err != nil {
+				return nil, err
+			}
+
+			clause = innerClause
 		} else if strings.HasPrefix(upperTrimSQL, "DELETE") {
-			c := new(deleteClause)
-			clause = c
-			w.Fn = deleteWalkFunc(c)
+			innerClause, err := deleteWalk(sql)
+			if err != nil {
+				return nil, err
+			}
+
+			clause = innerClause
 		} else if strings.HasPrefix(upperTrimSQL, "UPDATE") {
-			c := new(updateClause)
-			clause = c
-			w.Fn = updateWalkFunc(c)
+			innerClause, err := updateWalk(sql)
+			if err != nil {
+				return nil, err
+			}
+
+			clause = innerClause
 		} else if strings.HasPrefix(upperTrimSQL, "SELECT") {
-			c := new(selectClause)
-			clause = c
-			w.Fn = selectWalkFunc(c)
-		}
+			innerClause, err := selectWalk(sql)
+			if err != nil {
+				return nil, err
+			}
 
-		_, err = w.Walk(stmts, nil)
-		if err != nil {
-			return nil, err
+			clause = innerClause
 		}
 
 		sqlClauses = append(sqlClauses, clause)
@@ -70,26 +78,200 @@ func parseSql(sqlStr string) ([]any, error) {
 	return sqlClauses, nil
 }
 
-func insertWalkFunc(clause *insertClause) func(ctx interface{}, node interface{}) (stop bool) {
-	return func(ctx interface{}, node interface{}) (stop bool) {
-		return false
+func insertWalk(sql string) (*insertClause, error) {
+	clause := new(insertClause)
+
+	stmts, err := parser.Parse(sql)
+	if err != nil {
+		return nil, err
 	}
-}
 
-func deleteWalkFunc(clause *deleteClause) func(ctx interface{}, node interface{}) (stop bool) {
-	return func(ctx interface{}, node interface{}) (stop bool) {
-		return false
+	var walkFuncErr error
+
+	w := &walk.AstWalker{
+		Fn: func(ctx interface{}, node interface{}) (stop bool) {
+			realNode := node.(*tree.Insert)
+
+			// 获取table
+			tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
+			realNode.Table.Format(tableFmtCtx)
+			clause.table = tableFmtCtx.String()
+
+			// 解析rows
+			values := make([]any, 0)
+
+			rowsStatement, err := parser.Parse(realNode.Rows.String())
+			if err != nil {
+				walkFuncErr = err
+				return true
+			}
+
+			var rowsWalkFuncErr error
+			funcStringParams := make([]string, 0)
+
+			rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) {
+				switch rowsNode := node.(type) {
+				case *tree.FuncExpr:
+					// 函数类型
+					value, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
+					if err != nil {
+						rowsWalkFuncErr = err
+						return true
+					}
+
+					values = append(values, value)
+				case *tree.DBool:
+					// 布尔类型
+					stringValue := rowsNode.String()
+					if !canSkipValueInThisContext(&funcStringParams, stringValue) {
+						if stringValue == "false" {
+							values = append(values, false)
+						} else if stringValue == "true" {
+							values = append(values, true)
+						} else {
+							rowsWalkFuncErr = errors.New("不支持的bool值")
+							return true
+						}
+					}
+				case *tree.UnresolvedName:
+					// 字符串类型或者函数参数类型,这里通过比较字符串value,排除了函数参数类型
+					stringValue := rowsNode.String()
+					if !canSkipValueInThisContext(&funcStringParams, stringValue) {
+						values = append(values, stringValue)
+					}
+				case *tree.NumVal:
+					// 数值类型,可以是整形或浮点型
+					stringValue := rowsNode.String()
+					if !canSkipValueInThisContext(&funcStringParams, stringValue) {
+						numKind := rowsNode.Kind()
+						if numKind == constant.Int {
+							valueUint64, err := strconv.ParseUint(rowsNode.String(), 10, 64)
+							if err != nil {
+								rowsWalkFuncErr = err
+								return true
+							}
+
+							values = append(values, valueUint64)
+						} else if numKind == constant.Float {
+							valueFloat64, err := strconv.ParseFloat(rowsNode.String(), 64)
+							if err != nil {
+								rowsWalkFuncErr = err
+								return true
+							}
+
+							values = append(values, valueFloat64)
+						} else {
+							rowsWalkFuncErr = errors.New("不支持的数值类型")
+							return true
+						}
+					}
+				}
+
+				return false
+			}}
+
+			_, err = rowsWalker.Walk(rowsStatement, nil)
+			if err != nil {
+				walkFuncErr = err
+				return true
+			}
+
+			if rowsWalkFuncErr != nil {
+				walkFuncErr = rowsWalkFuncErr
+				return true
+			}
+
+			// 组装columnValues
+			hasIDColumn := false
+			keyColumns := make([]string, 0)
+			allColumns := make([]string, 0)
+			tableRows := make(map[string]any)
+
+			for i, column := range realNode.Columns.ToStrings() {
+				if column == "id" {
+					hasIDColumn = true
+				}
+
+				if strings.HasPrefix(column, "**") {
+					column = strings.TrimPrefix(column, "**")
+					keyColumns = append(keyColumns, column)
+				}
+
+				allColumns = append(allColumns, column)
+
+				tableRows[column] = values[i]
+			}
+
+			if keyColumns != nil && len(keyColumns) != 0 {
+				clause.keyColumns = keyColumns
+			} else if hasIDColumn {
+				clause.keyColumns = []string{"id"}
+			} else {
+				clause.keyColumns = allColumns
+			}
+
+			clause.tableRows = tableRows
+
+			return false
+		},
 	}
+
+	_, err = w.Walk(stmts, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if walkFuncErr != nil {
+		return nil, walkFuncErr
+	}
+
+	return clause, nil
 }
 
-func updateWalkFunc(clause *updateClause) func(ctx interface{}, node interface{}) (stop bool) {
-	return func(ctx interface{}, node interface{}) (stop bool) {
-		return false
+func deleteWalk(sql string) (*insertClause, error) {
+	return nil, nil
+}
+
+func updateWalk(sql string) (*insertClause, error) {
+	return nil, nil
+}
+
+func selectWalk(sql string) (*insertClause, error) {
+	return nil, nil
+}
+
+func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, error) {
+	if strings.HasPrefix(expr.String(), "parse_time") {
+		if expr.Exprs == nil || len(expr.Exprs) != 2 {
+			return nil, errors.New("parse_time(time_str, time_format)")
+		}
+
+		timeStr := expr.Exprs[0].String()
+		timeFormat := expr.Exprs[1].String()
+
+		if funcName != nil {
+			*funcName = "parse_time"
+		}
+
+		if stringParams != nil {
+			*stringParams = append(*stringParams, timeStr, timeFormat)
+		}
+
+		return time.ParseInLocation(timeFormat, timeStr, time.Local)
+	} else {
+		return nil, errors.New("不支持的函数")
 	}
 }
 
-func selectWalkFunc(clause *selectClause) func(ctx interface{}, node interface{}) (stop bool) {
-	return func(ctx interface{}, node interface{}) (stop bool) {
+func canSkipValueInThisContext(stringParams *[]string, stringValue string) bool {
+	if stringParams == nil || *stringParams == nil || len(*stringParams) == 0 {
 		return false
 	}
+
+	if (*stringParams)[0] == stringValue {
+		*stringParams = (*stringParams)[1:]
+		return true
+	}
+
+	return false
 }

+ 4 - 0
test_sql.go

@@ -1,5 +1,9 @@
 package dpsapi
 
+const (
+	parseSqlSelect = `insert into students (id, name, age, rate, time, is_right) values ("aaa", "yjp", 5, 92.5, parse_time("2024-01-01 00:00:00", "2006-01-02 15:04:05"), false)`
+)
+
 const (
 	sqlSelect = `select * from t where id = 'aaa'`
 )

+ 65 - 6
v1.go

@@ -2,6 +2,7 @@ package dpsapi
 
 import (
 	"errors"
+	"fmt"
 	"git.sxidc.com/go-tools/api_binding/http_binding"
 	"git.sxidc.com/go-tools/api_binding/http_binding/binding_context"
 	"git.sxidc.com/go-tools/api_binding/http_binding/response"
@@ -11,10 +12,46 @@ import (
 )
 
 func ApiV1(binding *http_binding.Binding, dpsAddress string) {
-	http_binding.PostBind(binding, &http_binding.SimpleBindItem[OperateFromRequest, any]{
+	http_binding.PostBind(binding, &http_binding.SimpleBindItem[OperateParseRequest, map[string]any]{
+		Path:         "/dpsv1/database/operate/parse",
+		ResponseFunc: response.SendMapResponse,
+		BusinessFunc: func(c *binding_context.Context, inputModel OperateParseRequest) (map[string]any, error) {
+			parsedClauses, err := parseSql(inputModel.SQL)
+			if err != nil {
+				return nil, err
+			}
+
+			for _, parsedClause := range parsedClauses {
+				switch clause := parsedClause.(type) {
+				case *insertClause:
+					return map[string]any{
+						"parsed": insertMap(clause),
+					}, nil
+				case *deleteClause:
+					return map[string]any{
+						"parsed": deleteMap(clause),
+					}, nil
+				case *updateClause:
+					return map[string]any{
+						"parsed": updateMap(clause),
+					}, nil
+				case *selectClause:
+					return map[string]any{
+						"parsed": selectMap(clause),
+					}, nil
+				default:
+					return nil, errors.New("不支持的SQL语句")
+				}
+			}
+
+			return nil, nil
+		},
+	})
+
+	http_binding.PostBind(binding, &http_binding.SimpleBindItem[OperateRequest, any]{
 		Path:         "/dpsv1/database/operate",
 		ResponseFunc: response.SendMsgResponse,
-		BusinessFunc: func(c *binding_context.Context, inputModel OperateFromRequest) (any, error) {
+		BusinessFunc: func(c *binding_context.Context, inputModel OperateRequest) (any, error) {
 			parsedClauses, err := parseSql(inputModel.SQL)
 			if err != nil {
 				return nil, err
@@ -52,16 +89,38 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string) {
 	})
 }
 
-func doInsert(inputModel OperateFromRequest, clause *insertClause) error {
+func insertMap(clause *insertClause) map[string]any {
+	return map[string]any{
+		"table":       clause.table,
+		"key_columns": clause.keyColumns,
+		"table_rows":  clause.tableRows,
+	}
+}
+
+func deleteMap(clause *deleteClause) map[string]any {
+	return map[string]any{}
+}
+
+func updateMap(clause *updateClause) map[string]any {
+	return map[string]any{}
+}
+
+func selectMap(clause *selectClause) map[string]any {
+	return map[string]any{}
+}
+
+func doInsert(inputModel OperateRequest, clause *insertClause) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
 	}
 
+	fmt.Printf("%+#v\n", clause)
+
 	return nil
 }
 
-func doDelete(inputModel OperateFromRequest, clause *deleteClause) error {
+func doDelete(inputModel OperateRequest, clause *deleteClause) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
@@ -70,7 +129,7 @@ func doDelete(inputModel OperateFromRequest, clause *deleteClause) error {
 	return nil
 }
 
-func doUpdate(inputModel OperateFromRequest, clause *updateClause) error {
+func doUpdate(inputModel OperateRequest, clause *updateClause) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"
@@ -79,7 +138,7 @@ func doUpdate(inputModel OperateFromRequest, clause *updateClause) error {
 	return nil
 }
 
-func doSelect(inputModel OperateFromRequest, clause *selectClause) error {
+func doSelect(inputModel OperateRequest, clause *selectClause) error {
 	version := inputModel.Version
 	if utils.IsStringEmpty(version) {
 		version = "v1"

+ 71 - 3
v1_test.go

@@ -7,20 +7,88 @@ import (
 	"testing"
 )
 
-func TestApiV1(t *testing.T) {
+func TestApiV1OperateParse(t *testing.T) {
+	http_binding.Init("test", "10086")
+	defer http_binding.Destroy()
+
+	tableName := "students"
+	exceptedKeyColumns := []string{"id", "name", "age", "rate", "time", "is_right"}
+	exceptedTableRows := map[string]any{
+		"id":       "aaa",
+		"name":     "yjp",
+		"age":      float64(5),
+		"rate":     92.5,
+		"time":     "2024-01-01T00:00:00+08:00",
+		"is_right": false,
+	}
+
+	binding := http_binding.NewBinding("v1")
+	ApiV1(binding, "localhost:30170")
+
+	parsed := operateParse(t, parseSqlSelect)
+
+	if parsed["table"].(string) != tableName {
+		t.Fatal("表名不正确")
+	}
+
+	for i, keyColumn := range parsed["key_columns"].([]any) {
+		if exceptedKeyColumns[i] != keyColumn {
+			t.Fatal("没有关键列数值或顺序不正确")
+		}
+	}
+
+	for columnName, value := range parsed["table_rows"].(map[string]any) {
+		if exceptedTableRows[columnName] != value {
+			t.Fatal("行数据不正确")
+		}
+	}
+}
+
+func TestApiV1Operate(t *testing.T) {
 	http_binding.Init("test", "10086")
 	defer http_binding.Destroy()
 
 	binding := http_binding.NewBinding("v1")
 	ApiV1(binding, "localhost:30170")
 
+	operate(t, sqlSelect)
+}
+
+func operateParse(t *testing.T, sql string) map[string]any {
+	result := new(struct {
+		response.MsgResponse
+		Parsed map[string]any `json:"parsed"`
+	})
+
+	resp, err := resty.New().R().
+		SetBody(&OperateParseRequest{
+			SQL: sql,
+		}).
+		SetResult(result).
+		Post("http://localhost:10086/test/api/v1/dpsv1/database/operate/parse")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if resp.IsError() {
+		t.Fatal(resp.Status())
+	}
+
+	if !result.Success {
+		t.Fatal(result.Msg)
+	}
+
+	return result.Parsed
+}
+
+func operate(t *testing.T, sql string) {
 	result := new(response.MsgResponse)
 
 	resp, err := resty.New().R().
-		SetBody(&OperateFromRequest{
+		SetBody(&OperateRequest{
 			DatabaseID: "ee2d7dabe56646ce835d80873348ee0e",
 			Version:    "v1",
-			SQL:        sqlSelect,
+			SQL:        sql,
 		}).
 		SetResult(result).
 		Post("http://localhost:10086/test/api/v1/dpsv1/database/operate")