package dpsapi

import (
	"errors"
	"go/constant"
	"strconv"
	"strings"
	"time"

	"github.com/auxten/postgresql-parser/pkg/sql/parser"
	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
	"github.com/auxten/postgresql-parser/pkg/walk"
)

// Kinds of values a clauseTableRowValue can carry. Numbering starts at 1, so
// the zero value of int never denotes a valid kind.
const (
	clauseTableRowValueKindTime int = iota + 1
	clauseTableRowValueKindBool
	clauseTableRowValueKindString
	clauseTableRowValueKindUint64
	clauseTableRowValueKindFloat64
)

// clauseTableRowValue is a single literal extracted from a statement.
// kind is one of the clauseTableRowValueKind* constants and tells which
// dynamic type value holds (time.Time, bool, string, uint64 or float64).
type clauseTableRowValue struct {
	kind  int
	value any
}

// insertClause describes an INSERT statement: the target table and the
// column-name -> value mapping built from the column list and the row data.
type insertClause struct {
	table    string
	tableRow map[string]clauseTableRowValue
}

// deleteClause describes a DELETE statement: the target table and the
// textual form of its WHERE expression (empty when there is no WHERE).
type deleteClause struct {
	table string
	where string
}

// updateClause describes an UPDATE statement. Nothing in this file
// populates it yet.
type updateClause struct {
	table       string
	where       string
	newTableRow map[string]clauseTableRowValue
}

// selectClause describes a SELECT statement. Nothing in this file
// populates it yet.
type selectClause struct {
	selectExpr []string
	from       string
	where      []string
	limit      int
	offset     int
}

// parseSql splits sqlStr on ';' and converts every non-blank fragment into a
// clause value, dispatching on the (case-insensitive) leading keyword:
// INSERT, DELETE, UPDATE or SELECT.
//
// Blank fragments (for example the one produced by a trailing ';') are
// skipped. A fragment that starts with none of the recognised keywords
// contributes a nil entry to the result. The first error returned by any
// statement handler aborts parsing and is returned as is.
//
// NOTE(review): the split is purely textual, so a ';' inside a string literal
// also splits the statement.
// NOTE(review): updateWalk and selectWalk are stubs returning a nil
// *insertClause, so entries for UPDATE/SELECT hold a typed nil pointer and
// compare unequal to nil.
func parseSql(sqlStr string) ([]any, error) {
	sqls := strings.Split(sqlStr, ";")
	sqlClauses := make([]any, 0, len(sqls))
	for _, sql := range sqls {
		trimSQL := strings.TrimSpace(sql)
		if trimSQL == "" {
			// Nothing to parse; do not emit a placeholder clause for it.
			continue
		}
		upperTrimSQL := strings.ToUpper(trimSQL)

		var clause any
		switch {
		case strings.HasPrefix(upperTrimSQL, "INSERT"):
			innerClause, err := insertWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperTrimSQL, "DELETE"):
			innerClause, err := deleteWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperTrimSQL, "UPDATE"):
			innerClause, err := updateWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperTrimSQL, "SELECT"):
			innerClause, err := selectWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		}
		sqlClauses = append(sqlClauses, clause)
	}
	return sqlClauses, nil
}

// insertWalk parses a single INSERT statement and returns its table name and
// a map from each listed column to the literal at the same position in the
// row data.
//
// Supported literals are booleans, strings, integers (stored as uint64),
// floats (stored as float64) and calls to parse_time (stored as time.Time).
// An error is returned when the SQL does not parse, when the statement has no
// row data, when a literal is unsupported, when an unquoted name appears in
// the row data, or when fewer values than columns were collected.
func insertWalk(sql string) (*insertClause, error) {
	stmts, err := parser.Parse(sql)
	if err != nil {
		return nil, err
	}

	clause := new(insertClause)
	var walkFuncErr error
	w := &walk.AstWalker{
		Fn: func(ctx any, node any) (stop bool) {
			if walkFuncErr != nil {
				// Keep the first error; later callbacks must not overwrite it.
				return true
			}
			// NOTE(review): whether the walker also invokes Fn for nodes other
			// than the statement root is not visible from this file; the checked
			// assertion is safe either way.
			realNode, ok := node.(*tree.Insert)
			if !ok {
				return false
			}

			// Extract the table name.
			tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
			realNode.Table.Format(tableFmtCtx)
			clause.table = tableFmtCtx.String()

			// Without row data there is nothing to format and re-parse below.
			if realNode.Rows == nil || realNode.Rows.Select == nil {
				walkFuncErr = errors.New("INSERT语句缺少VALUES数据")
				return true
			}

			// Re-parse the row data on its own and collect its literals in order.
			rowsStatement, err := parser.Parse(realNode.Rows.String())
			if err != nil {
				walkFuncErr = err
				return true
			}

			values := make([]clauseTableRowValue, 0)
			// Textual forms of the arguments of already evaluated function
			// calls; the matching literal nodes are skipped when they are
			// visited so they are not counted as column values.
			funcStringParams := make([]string, 0)
			var rowsWalkFuncErr error
			fail := func(e error) bool {
				rowsWalkFuncErr = e
				return true
			}

			rowsWalker := &walk.AstWalker{Fn: func(ctx any, n any) (stop bool) {
				if rowsWalkFuncErr != nil {
					return true
				}
				switch rowsNode := n.(type) {
				case *tree.FuncExpr:
					// Function call.
					value, kind, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
					if err != nil {
						return fail(err)
					}
					values = append(values, clauseTableRowValue{kind: kind, value: value})

				case *tree.DBool:
					// Boolean literal.
					stringValue := rowsNode.String()
					if canSkipValueInThisContext(&funcStringParams, stringValue) {
						return false
					}
					switch stringValue {
					case "false":
						values = append(values, clauseTableRowValue{kind: clauseTableRowValueKindBool, value: false})
					case "true":
						values = append(values, clauseTableRowValue{kind: clauseTableRowValueKindBool, value: true})
					default:
						return fail(errors.New("不支持的bool值"))
					}

				case *tree.StrVal:
					// Either a plain string value or a string argument of a
					// function call; the latter is recognised by comparing its
					// textual form with the recorded function arguments.
					stringValue := rowsNode.String()
					if canSkipValueInThisContext(&funcStringParams, stringValue) {
						return false
					}
					values = append(values, clauseTableRowValue{
						kind:  clauseTableRowValueKindString,
						value: rowsNode.RawString(),
					})

				case *tree.NumVal:
					// Numeric literal, integer or floating point.
					stringValue := rowsNode.String()
					if canSkipValueInThisContext(&funcStringParams, stringValue) {
						return false
					}
					switch rowsNode.Kind() {
					case constant.Int:
						valueUint64, err := strconv.ParseUint(stringValue, 10, 64)
						if err != nil {
							return fail(err)
						}
						values = append(values, clauseTableRowValue{kind: clauseTableRowValueKindUint64, value: valueUint64})
					case constant.Float:
						valueFloat64, err := strconv.ParseFloat(stringValue, 64)
						if err != nil {
							return fail(err)
						}
						values = append(values, clauseTableRowValue{kind: clauseTableRowValueKindFloat64, value: valueFloat64})
					default:
						return fail(errors.New("不支持的数值类型"))
					}

				case *tree.UnresolvedName:
					// A bare identifier in the row data, most likely a string
					// that is missing its single quotes.
					return fail(errors.New("存在无法解析的名字,是否应该使用单引号"))
				}
				return false
			}}
			if _, err = rowsWalker.Walk(rowsStatement, nil); err != nil {
				walkFuncErr = err
				return true
			}
			if rowsWalkFuncErr != nil {
				walkFuncErr = rowsWalkFuncErr
				return true
			}

			// Pair columns with values by position. Surplus values are ignored,
			// but every column needs a value, otherwise indexing would panic.
			columns := realNode.Columns.ToStrings()
			if len(values) < len(columns) {
				walkFuncErr = errors.New("列的数量多于值的数量")
				return true
			}
			tableRows := make(map[string]clauseTableRowValue, len(columns))
			for i, column := range columns {
				tableRows[column] = values[i]
			}
			clause.tableRow = tableRows
			return false
		},
	}

	if _, err = w.Walk(stmts, nil); err != nil {
		return nil, err
	}
	if walkFuncErr != nil {
		return nil, walkFuncErr
	}
	return clause, nil
}

// deleteWalk parses a single DELETE statement and returns its table name and
// the textual form of its WHERE expression. A statement without a WHERE
// clause yields an empty where string. An error is returned when the SQL does
// not parse or the walker reports one.
func deleteWalk(sql string) (*deleteClause, error) {
	stmts, err := parser.Parse(sql)
	if err != nil {
		return nil, err
	}

	clause := new(deleteClause)
	w := &walk.AstWalker{
		Fn: func(ctx any, node any) (stop bool) {
			realNode, ok := node.(*tree.Delete)
			if !ok {
				// Only the DELETE node itself carries the data needed here.
				return false
			}

			// Extract the table name.
			tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
			realNode.Table.Format(tableFmtCtx)
			clause.table = tableFmtCtx.String()

			// Extract the WHERE expression, which is absent for an
			// unconditional DELETE.
			if realNode.Where != nil && realNode.Where.Expr != nil {
				clause.where = realNode.Where.Expr.String()
			}
			return false
		},
	}

	if _, err = w.Walk(stmts, nil); err != nil {
		return nil, err
	}
	return clause, nil
}

// updateWalk is a placeholder for UPDATE support; it ignores sql and always
// returns (nil, nil).
// NOTE(review): the return type is *insertClause although an updateClause
// type exists — confirm the intended signature before implementing.
func updateWalk(sql string) (*insertClause, error) {
	return nil, nil
}

// selectWalk is a placeholder for SELECT support; it ignores sql and always
// returns (nil, nil).
// NOTE(review): the return type is *insertClause although a selectClause
// type exists — confirm the intended signature before implementing.
func selectWalk(sql string) (*insertClause, error) {
	return nil, nil
}

// evaluateFuncExpr evaluates a function call found in statement values.
//
// Only calls whose textual form starts with "parse_time" are supported. Such
// a call must have exactly two arguments, the time string and the layout (in
// that order); their textual forms are handed to time.ParseInLocation using
// time.Local. On success it returns the parsed time together with
// clauseTableRowValueKindTime.
//
// When funcName is non-nil it receives the name "parse_time". When
// stringParams is non-nil the textual forms of both arguments are appended to
// it, before the time is parsed, so a caller walking the same tree can
// recognise and skip the argument nodes.
func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, int, error) {
	if !strings.HasPrefix(expr.String(), "parse_time") {
		return nil, 0, errors.New("不支持的函数")
	}
	if len(expr.Exprs) != 2 {
		return nil, 0, errors.New("parse_time(time_str, time_format)")
	}

	timeStr := expr.Exprs[0].String()
	timeFormat := expr.Exprs[1].String()
	if funcName != nil {
		*funcName = "parse_time"
	}
	if stringParams != nil {
		*stringParams = append(*stringParams, timeStr, timeFormat)
	}

	parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local)
	if err != nil {
		return nil, 0, err
	}
	return parsedTime, clauseTableRowValueKindTime, nil
}

// canSkipValueInThisContext reports whether stringValue equals the first
// pending entry of *stringParams. On a match that entry is removed from the
// front of the slice and true is returned; otherwise the slice is left
// untouched and false is returned. A nil pointer or an empty slice never
// matches.
func canSkipValueInThisContext(stringParams *[]string, stringValue string) bool {
	if stringParams == nil || len(*stringParams) == 0 {
		return false
	}
	if (*stringParams)[0] != stringValue {
		return false
	}
	*stringParams = (*stringParams)[1:]
	return true
}