|
@@ -1,13 +1,20 @@
|
|
|
package dpsapi
|
|
|
|
|
|
import (
|
|
|
+ "errors"
|
|
|
"github.com/auxten/postgresql-parser/pkg/sql/parser"
|
|
|
+ "github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
|
|
|
"github.com/auxten/postgresql-parser/pkg/walk"
|
|
|
+ "go/constant"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
+ "time"
|
|
|
)
|
|
|
|
|
|
type insertClause struct {
|
|
|
- into string
|
|
|
+ table string
|
|
|
+ keyColumns []string
|
|
|
+ tableRows map[string]any
|
|
|
}
|
|
|
|
|
|
type deleteClause struct {
|
|
@@ -30,38 +37,39 @@ func parseSql(sqlStr string) ([]any, error) {
|
|
|
sqlClauses := make([]any, 0)
|
|
|
|
|
|
for _, sql := range sqls {
|
|
|
- stmts, err := parser.Parse(sql)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
trimSQL := strings.TrimSpace(sql)
|
|
|
upperTrimSQL := strings.ToUpper(trimSQL)
|
|
|
|
|
|
var clause any
|
|
|
- w := new(walk.AstWalker)
|
|
|
|
|
|
if strings.HasPrefix(upperTrimSQL, "INSERT") {
|
|
|
- c := new(insertClause)
|
|
|
- clause = c
|
|
|
- w.Fn = insertWalkFunc(c)
|
|
|
+ innerClause, err := insertWalk(sql)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ clause = innerClause
|
|
|
} else if strings.HasPrefix(upperTrimSQL, "DELETE") {
|
|
|
- c := new(deleteClause)
|
|
|
- clause = c
|
|
|
- w.Fn = deleteWalkFunc(c)
|
|
|
+ innerClause, err := deleteWalk(sql)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ clause = innerClause
|
|
|
} else if strings.HasPrefix(upperTrimSQL, "UPDATE") {
|
|
|
- c := new(updateClause)
|
|
|
- clause = c
|
|
|
- w.Fn = updateWalkFunc(c)
|
|
|
+ innerClause, err := updateWalk(sql)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ clause = innerClause
|
|
|
} else if strings.HasPrefix(upperTrimSQL, "SELECT") {
|
|
|
- c := new(selectClause)
|
|
|
- clause = c
|
|
|
- w.Fn = selectWalkFunc(c)
|
|
|
- }
|
|
|
+ innerClause, err := selectWalk(sql)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
|
|
|
- _, err = w.Walk(stmts, nil)
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
+ clause = innerClause
|
|
|
}
|
|
|
|
|
|
sqlClauses = append(sqlClauses, clause)
|
|
@@ -70,26 +78,200 @@ func parseSql(sqlStr string) ([]any, error) {
|
|
|
return sqlClauses, nil
|
|
|
}
|
|
|
|
|
|
-func insertWalkFunc(clause *insertClause) func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return false
|
|
|
+func insertWalk(sql string) (*insertClause, error) {
|
|
|
+ clause := new(insertClause)
|
|
|
+
|
|
|
+ stmts, err := parser.Parse(sql)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
}
|
|
|
-}
|
|
|
|
|
|
-func deleteWalkFunc(clause *deleteClause) func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return false
|
|
|
+ var walkFuncErr error
|
|
|
+
|
|
|
+ w := &walk.AstWalker{
|
|
|
+ Fn: func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
+ realNode := node.(*tree.Insert)
|
|
|
+
|
|
|
+
|
|
|
+ tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
|
|
|
+ realNode.Table.Format(tableFmtCtx)
|
|
|
+ clause.table = tableFmtCtx.String()
|
|
|
+
|
|
|
+
|
|
|
+ values := make([]any, 0)
|
|
|
+
|
|
|
+ rowsStatement, err := parser.Parse(realNode.Rows.String())
|
|
|
+ if err != nil {
|
|
|
+ walkFuncErr = err
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ var rowsWalkFuncErr error
|
|
|
+ funcStringParams := make([]string, 0)
|
|
|
+
|
|
|
+ rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
+ switch rowsNode := node.(type) {
|
|
|
+ case *tree.FuncExpr:
|
|
|
+
|
|
|
+ value, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
|
|
|
+ if err != nil {
|
|
|
+ rowsWalkFuncErr = err
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ values = append(values, value)
|
|
|
+ case *tree.DBool:
|
|
|
+
|
|
|
+ stringValue := rowsNode.String()
|
|
|
+ if !canSkipValueInThisContext(&funcStringParams, stringValue) {
|
|
|
+ if stringValue == "false" {
|
|
|
+ values = append(values, false)
|
|
|
+ } else if stringValue == "true" {
|
|
|
+ values = append(values, true)
|
|
|
+ } else {
|
|
|
+ rowsWalkFuncErr = errors.New("不支持的bool值")
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case *tree.UnresolvedName:
|
|
|
+
|
|
|
+ stringValue := rowsNode.String()
|
|
|
+ if !canSkipValueInThisContext(&funcStringParams, stringValue) {
|
|
|
+ values = append(values, stringValue)
|
|
|
+ }
|
|
|
+ case *tree.NumVal:
|
|
|
+
|
|
|
+ stringValue := rowsNode.String()
|
|
|
+ if !canSkipValueInThisContext(&funcStringParams, stringValue) {
|
|
|
+ numKind := rowsNode.Kind()
|
|
|
+ if numKind == constant.Int {
|
|
|
+ valueUint64, err := strconv.ParseUint(rowsNode.String(), 10, 64)
|
|
|
+ if err != nil {
|
|
|
+ rowsWalkFuncErr = err
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ values = append(values, valueUint64)
|
|
|
+ } else if numKind == constant.Float {
|
|
|
+ valueFloat64, err := strconv.ParseFloat(rowsNode.String(), 64)
|
|
|
+ if err != nil {
|
|
|
+ rowsWalkFuncErr = err
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ values = append(values, valueFloat64)
|
|
|
+ } else {
|
|
|
+ rowsWalkFuncErr = errors.New("不支持的数值类型")
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return false
|
|
|
+ }}
|
|
|
+
|
|
|
+ _, err = rowsWalker.Walk(rowsStatement, nil)
|
|
|
+ if err != nil {
|
|
|
+ walkFuncErr = err
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ if rowsWalkFuncErr != nil {
|
|
|
+ walkFuncErr = rowsWalkFuncErr
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ hasIDColumn := false
|
|
|
+ keyColumns := make([]string, 0)
|
|
|
+ allColumns := make([]string, 0)
|
|
|
+ tableRows := make(map[string]any)
|
|
|
+
|
|
|
+ for i, column := range realNode.Columns.ToStrings() {
|
|
|
+ if column == "id" {
|
|
|
+ hasIDColumn = true
|
|
|
+ }
|
|
|
+
|
|
|
+ if strings.HasPrefix(column, "**") {
|
|
|
+ column = strings.TrimPrefix(column, "**")
|
|
|
+ keyColumns = append(keyColumns, column)
|
|
|
+ }
|
|
|
+
|
|
|
+ allColumns = append(allColumns, column)
|
|
|
+
|
|
|
+ tableRows[column] = values[i]
|
|
|
+ }
|
|
|
+
|
|
|
+ if keyColumns != nil && len(keyColumns) != 0 {
|
|
|
+ clause.keyColumns = keyColumns
|
|
|
+ } else if hasIDColumn {
|
|
|
+ clause.keyColumns = []string{"id"}
|
|
|
+ } else {
|
|
|
+ clause.keyColumns = allColumns
|
|
|
+ }
|
|
|
+
|
|
|
+ clause.tableRows = tableRows
|
|
|
+
|
|
|
+ return false
|
|
|
+ },
|
|
|
}
|
|
|
+
|
|
|
+ _, err = w.Walk(stmts, nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if walkFuncErr != nil {
|
|
|
+ return nil, walkFuncErr
|
|
|
+ }
|
|
|
+
|
|
|
+ return clause, nil
|
|
|
}
|
|
|
|
|
|
-func updateWalkFunc(clause *updateClause) func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return false
|
|
|
+func deleteWalk(sql string) (*insertClause, error) {
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func updateWalk(sql string) (*insertClause, error) {
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func selectWalk(sql string) (*insertClause, error) {
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, error) {
|
|
|
+ if strings.HasPrefix(expr.String(), "parse_time") {
|
|
|
+ if expr.Exprs == nil || len(expr.Exprs) != 2 {
|
|
|
+ return nil, errors.New("parse_time(time_str, time_format)")
|
|
|
+ }
|
|
|
+
|
|
|
+ timeStr := expr.Exprs[0].String()
|
|
|
+ timeFormat := expr.Exprs[1].String()
|
|
|
+
|
|
|
+ if funcName != nil {
|
|
|
+ *funcName = "parse_time"
|
|
|
+ }
|
|
|
+
|
|
|
+ if stringParams != nil {
|
|
|
+ *stringParams = append(*stringParams, timeStr, timeFormat)
|
|
|
+ }
|
|
|
+
|
|
|
+ return time.ParseInLocation(timeFormat, timeStr, time.Local)
|
|
|
+ } else {
|
|
|
+ return nil, errors.New("不支持的函数")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func selectWalkFunc(clause *selectClause) func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
- return func(ctx interface{}, node interface{}) (stop bool) {
|
|
|
+func canSkipValueInThisContext(stringParams *[]string, stringValue string) bool {
|
|
|
+ if stringParams == nil || *stringParams == nil || len(*stringParams) == 0 {
|
|
|
return false
|
|
|
}
|
|
|
+
|
|
|
+ if (*stringParams)[0] == stringValue {
|
|
|
+ *stringParams = (*stringParams)[1:]
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ return false
|
|
|
}
|