package api import ( "errors" "fmt" "github.com/auxten/postgresql-parser/pkg/sql/parser" "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" "github.com/auxten/postgresql-parser/pkg/walk" "go/constant" "strconv" "strings" "time" ) const ( clauseTableRowValueKindTime int = iota + 1 clauseTableRowValueKindBool clauseTableRowValueKindString clauseTableRowValueKindUint64 clauseTableRowValueKindFloat64 ) type clauseTableRowValue struct { kind int value any } type insertClause struct { table string tableRow map[string]clauseTableRowValue } type deleteClause struct { table string where string } type updateClause struct { table string where string newTableRow map[string]clauseTableRowValue } type selectClause struct { table string fromSubQuery string selectClause string where string orderBy string groupBy string having string pageNo int pageSize int } // 调试很重要的函数,可以看到一个节点的实际类型以及包含字段的类型 func printNode(node any) { fmt.Printf("%+#v\n", node) } func parseSql(sqlStr string) ([]any, error) { sqls := strings.Split(sqlStr, ";") sqlClauses := make([]any, 0) for _, sql := range sqls { trimSQL := strings.TrimSpace(sql) upperTrimSQL := strings.ToUpper(trimSQL) var clause any if strings.HasPrefix(upperTrimSQL, "INSERT") { innerClause, err := insertWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "DELETE") { innerClause, err := deleteWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "UPDATE") { innerClause, err := updateWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "SELECT") { innerClause, err := selectWalk(sql) if err != nil { return nil, err } clause = innerClause } sqlClauses = append(sqlClauses, clause) } return sqlClauses, nil } func insertWalk(sql string) (*insertClause, error) { clause := new(insertClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Insert) // 获取table tableName, err := parseTableExpr(realNode.Table) if err != nil { walkFuncErr = err return true } clause.table = tableName // 获取table row clause.tableRow = make(map[string]clauseTableRowValue) columns := realNode.Columns.ToStrings() valuesClause := realNode.Rows.Select.(*tree.ValuesClause) for _, row := range valuesClause.Rows { for i, column := range row { columnValue, err := parseExpr(column) if err != nil { walkFuncErr = err return true } clause.tableRow[columns[i]] = *columnValue } } return true }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func deleteWalk(sql string) (*deleteClause, error) { clause := new(deleteClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Delete) // 获取table tableName, err := parseTableExpr(realNode.Table) if err != nil { walkFuncErr = err return true } clause.table = tableName // 获取where clause.where = parseWhere(realNode.Where) return true }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func updateWalk(sql string) (*updateClause, error) { clause := new(updateClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Update) // 获取table tableName, err := parseTableExpr(realNode.Table) if err != nil { walkFuncErr = err return true } clause.table = tableName // 获取where clause.where = parseWhere(realNode.Where) // 获取table row clause.newTableRow = make(map[string]clauseTableRowValue) for _, expr := range realNode.Exprs { value, err := parseExpr(expr.Expr) if err != nil { walkFuncErr = err return true } clause.newTableRow[fmt.Sprint(expr.Names[0])] = *value } return true }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func selectWalk(sql string) (*selectClause, error) { clause := new(selectClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Select) nodeSelectClause := realNode.Select.(*tree.SelectClause) // select clause.selectClause = parseSelect(nodeSelectClause.Exprs) if clause.selectClause == "*" { clause.selectClause = "" } // from asFromSubQuery, from, err := parseFrom(&nodeSelectClause.From) if err != nil { walkFuncErr = err return true } if asFromSubQuery { clause.fromSubQuery = from } else { clause.table = from } // where if nodeSelectClause.Where != nil { clause.where = parseWhere(nodeSelectClause.Where) } // order by if realNode.OrderBy != nil { clause.orderBy = parseOrderBy(realNode.OrderBy) } // limit if realNode.Limit != nil { pageNo, pageSize, err := parseLimit(realNode.Limit) if err != nil { walkFuncErr = err return true } clause.pageNo = pageNo clause.pageSize = pageSize } // group by if nodeSelectClause.GroupBy != nil { clause.groupBy = parseGroupBy(nodeSelectClause.GroupBy) } // having if nodeSelectClause.Having != nil { clause.having = parseWhere(nodeSelectClause.Having) } return true }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func parseTableExpr(tableExpr tree.TableExpr) (string, error) { switch table := tableExpr.(type) { case *tree.TableName: return table.String(), nil case *tree.AliasedTableExpr: return table.String(), nil default: return "", errors.New("不支持的TableExpr") } } func parseSelect(selectExprs tree.SelectExprs) string { selectFmtCtx := tree.NewFmtCtx(tree.FmtBareStrings) selectExprs.Format(selectFmtCtx) return selectFmtCtx.String() } func parseFrom(from *tree.From) (bool, string, error) { switch fromTable := from.Tables[0].(type) { case *tree.JoinTableExpr: return true, fromTable.String(), nil case *tree.AliasedTableExpr: _, ok := fromTable.Expr.(*tree.Subquery) if ok { return true, fromTable.String(), nil } else { return false, fromTable.String(), nil } default: return false, "", errors.New("不支持的From类型") } } func parseWhere(where *tree.Where) string { return where.Expr.String() } func parseOrderBy(orderBy tree.OrderBy) string { orderByFmtCtx := tree.NewFmtCtx(tree.FmtBareStrings) orderBy[0].Format(orderByFmtCtx) return orderByFmtCtx.String() } func parseLimit(limit *tree.Limit) (int, int, error) { pageNo, err := strconv.Atoi(limit.Offset.String()) if err != nil { return 0, 0, err } pageSize, err := strconv.Atoi(limit.Count.String()) if err != nil { return 0, 0, err } return pageNo, pageSize, nil } func parseGroupBy(groupBy tree.GroupBy) string { groupBySlice := make([]string, 0) for _, groupExpr := range groupBy { groupBySlice = append(groupBySlice, groupExpr.String()) } return strings.Join(groupBySlice, ",") } func parseExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) { switch realColumn := valueExpr.(type) { case *tree.FuncExpr: // 函数类型 value, kind, err := evaluateFuncExpr(realColumn) if err != nil { return nil, err } return &clauseTableRowValue{ kind: kind, value: value, }, nil case *tree.DBool: // 布尔类型 var boolValue bool stringValue := realColumn.String() if stringValue == "true" { boolValue = true } return &clauseTableRowValue{ kind: clauseTableRowValueKindBool, value: boolValue, }, nil case *tree.StrVal: // 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型 return &clauseTableRowValue{ kind: clauseTableRowValueKindString, value: realColumn.RawString(), }, nil case *tree.NumVal: // 数值类型,可以是整形或浮点型 numKind := realColumn.Kind() if numKind == constant.Int { valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64) if err != nil { return nil, err } return &clauseTableRowValue{ kind: clauseTableRowValueKindUint64, value: valueUint64, }, nil } else if numKind == constant.Float { valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64) if err != nil { return nil, err } return &clauseTableRowValue{ kind: clauseTableRowValueKindFloat64, value: valueFloat64, }, nil } else { return nil, errors.New("不支持的数值类型") } case *tree.UnresolvedName: return nil, errors.New("存在无法解析的名字,是否应该使用单引号") default: return nil, errors.New("未支持的数据类型") } } func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) { if strings.HasPrefix(expr.String(), "parse_time") { if expr.Exprs == nil || len(expr.Exprs) != 2 { return nil, 0, errors.New("parse_time(time_str, time_format)") } timeStrValue, err := parseExpr(expr.Exprs[0]) if err != nil { return nil, 0, err } if timeStrValue.kind != clauseTableRowValueKindString { return nil, 0, errors.New("时间字符串不是字符串类型") } timeFormatValue, err := parseExpr(expr.Exprs[1]) if err != nil { return nil, 0, err } if timeFormatValue.kind != clauseTableRowValueKindString { return nil, 0, errors.New("时间格式不是字符串类型") } parsedTime, err := time.ParseInLocation(timeFormatValue.value.(string), timeStrValue.value.(string), time.Local) if err != nil { return nil, 0, err } return parsedTime, clauseTableRowValueKindTime, nil } else { return nil, 0, errors.New("不支持的函数") } }