123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505 |
- package api
- import (
- "errors"
- "fmt"
- "github.com/auxten/postgresql-parser/pkg/sql/parser"
- "github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
- "github.com/auxten/postgresql-parser/pkg/walk"
- "go/constant"
- "strconv"
- "strings"
- "time"
- )
// Kinds of value a clauseTableRowValue can hold. Numbering starts at 1 so the
// zero value of an int never matches a valid kind.
const (
	clauseTableRowValueKindTime    int = 1 // time.Time produced by parse_time(...)
	clauseTableRowValueKindBool    int = 2 // bool literal
	clauseTableRowValueKindString  int = 3 // string literal
	clauseTableRowValueKindUint64  int = 4 // integer literal
	clauseTableRowValueKindFloat64 int = 5 // floating-point literal
)
// clauseTableRowValue is one column value extracted from a SQL statement.
// kind is one of the clauseTableRowValueKind* constants; value holds the
// matching Go value (time.Time, bool, string, uint64 or float64).
type clauseTableRowValue struct {
	kind  int // which clauseTableRowValueKind* the value is
	value any // the decoded Go value
}
- type insertClause struct {
- table string
- tableRow map[string]clauseTableRowValue
- }
// deleteClause is the parsed form of a DELETE statement.
type deleteClause struct {
	table, where string // target table text and WHERE condition text
}
- type updateClause struct {
- table string
- where string
- newTableRow map[string]clauseTableRowValue
- }
// selectClause is the parsed form of a simple SELECT statement. Each string
// field holds the SQL text of the corresponding clause; an empty string means
// the clause was absent (an "*" projection is also stored as empty).
type selectClause struct {
	// table is set for a plain FROM table; fromSubQuery for a join or subquery.
	table, fromSubQuery string
	// Text of the individual clauses, without their leading keywords.
	selectClause, where, orderBy, groupBy, having string
	// Values read from the LIMIT clause (pageNo comes from OFFSET, pageSize from LIMIT).
	pageNo, pageSize int
}
// printNode is an important debugging aid: it prints node with the "%+#v"
// verb, which shows the node's concrete type together with the types of the
// fields it contains.
func printNode(node any) {
	const verb = "%+#v\n"
	fmt.Printf(verb, node)
}
- func parseSql(sqlStr string) ([]any, error) {
- sqls := strings.Split(sqlStr, ";")
- sqlClauses := make([]any, 0)
- for _, sql := range sqls {
- trimSQL := strings.TrimSpace(sql)
- upperTrimSQL := strings.ToUpper(trimSQL)
- var clause any
- if strings.HasPrefix(upperTrimSQL, "INSERT") {
- innerClause, err := insertWalk(sql)
- if err != nil {
- return nil, err
- }
- clause = innerClause
- } else if strings.HasPrefix(upperTrimSQL, "DELETE") {
- innerClause, err := deleteWalk(sql)
- if err != nil {
- return nil, err
- }
- clause = innerClause
- } else if strings.HasPrefix(upperTrimSQL, "UPDATE") {
- innerClause, err := updateWalk(sql)
- if err != nil {
- return nil, err
- }
- clause = innerClause
- } else if strings.HasPrefix(upperTrimSQL, "SELECT") {
- innerClause, err := selectWalk(sql)
- if err != nil {
- return nil, err
- }
- clause = innerClause
- }
- sqlClauses = append(sqlClauses, clause)
- }
- return sqlClauses, nil
- }
- func insertWalk(sql string) (*insertClause, error) {
- clause := new(insertClause)
- stmts, err := parser.Parse(sql)
- if err != nil {
- return nil, err
- }
- var walkFuncErr error
- w := &walk.AstWalker{
- Fn: func(ctx interface{}, node interface{}) (stop bool) {
- realNode := node.(*tree.Insert)
- // 获取table
- tableName, err := parseTableExpr(realNode.Table)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.table = tableName
- // 获取table row
- clause.tableRow = make(map[string]clauseTableRowValue)
- columns := realNode.Columns.ToStrings()
- valuesClause := realNode.Rows.Select.(*tree.ValuesClause)
- for _, row := range valuesClause.Rows {
- for i, column := range row {
- columnValue, err := parseExpr(column)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.tableRow[columns[i]] = *columnValue
- }
- }
- return true
- },
- }
- _, err = w.Walk(stmts, nil)
- if err != nil {
- return nil, err
- }
- if walkFuncErr != nil {
- return nil, walkFuncErr
- }
- return clause, nil
- }
- func deleteWalk(sql string) (*deleteClause, error) {
- clause := new(deleteClause)
- stmts, err := parser.Parse(sql)
- if err != nil {
- return nil, err
- }
- var walkFuncErr error
- w := &walk.AstWalker{
- Fn: func(ctx interface{}, node interface{}) (stop bool) {
- realNode := node.(*tree.Delete)
- // 获取table
- tableName, err := parseTableExpr(realNode.Table)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.table = tableName
- // 获取where
- clause.where = parseWhere(realNode.Where)
- return true
- },
- }
- _, err = w.Walk(stmts, nil)
- if err != nil {
- return nil, err
- }
- if walkFuncErr != nil {
- return nil, walkFuncErr
- }
- return clause, nil
- }
- func updateWalk(sql string) (*updateClause, error) {
- clause := new(updateClause)
- stmts, err := parser.Parse(sql)
- if err != nil {
- return nil, err
- }
- var walkFuncErr error
- w := &walk.AstWalker{
- Fn: func(ctx interface{}, node interface{}) (stop bool) {
- realNode := node.(*tree.Update)
- // 获取table
- tableName, err := parseTableExpr(realNode.Table)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.table = tableName
- // 获取where
- clause.where = parseWhere(realNode.Where)
- // 获取table row
- clause.newTableRow = make(map[string]clauseTableRowValue)
- for _, expr := range realNode.Exprs {
- value, err := parseExpr(expr.Expr)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.newTableRow[fmt.Sprint(expr.Names[0])] = *value
- }
- return true
- },
- }
- _, err = w.Walk(stmts, nil)
- if err != nil {
- return nil, err
- }
- if walkFuncErr != nil {
- return nil, walkFuncErr
- }
- return clause, nil
- }
- func selectWalk(sql string) (*selectClause, error) {
- clause := new(selectClause)
- stmts, err := parser.Parse(sql)
- if err != nil {
- return nil, err
- }
- var walkFuncErr error
- w := &walk.AstWalker{
- Fn: func(ctx interface{}, node interface{}) (stop bool) {
- realNode := node.(*tree.Select)
- nodeSelectClause := realNode.Select.(*tree.SelectClause)
- // select
- clause.selectClause = parseSelect(nodeSelectClause.Exprs)
- if clause.selectClause == "*" {
- clause.selectClause = ""
- }
- // from
- asFromSubQuery, from, err := parseFrom(&nodeSelectClause.From)
- if err != nil {
- walkFuncErr = err
- return true
- }
- if asFromSubQuery {
- clause.fromSubQuery = from
- } else {
- clause.table = from
- }
- // where
- if nodeSelectClause.Where != nil {
- clause.where = parseWhere(nodeSelectClause.Where)
- }
- // order by
- if realNode.OrderBy != nil {
- clause.orderBy = parseOrderBy(realNode.OrderBy)
- }
- // limit
- if realNode.Limit != nil {
- pageNo, pageSize, err := parseLimit(realNode.Limit)
- if err != nil {
- walkFuncErr = err
- return true
- }
- clause.pageNo = pageNo
- clause.pageSize = pageSize
- }
- // group by
- if nodeSelectClause.GroupBy != nil {
- clause.groupBy = parseGroupBy(nodeSelectClause.GroupBy)
- }
- // having
- if nodeSelectClause.Having != nil {
- clause.having = parseWhere(nodeSelectClause.Having)
- }
- return true
- },
- }
- _, err = w.Walk(stmts, nil)
- if err != nil {
- return nil, err
- }
- if walkFuncErr != nil {
- return nil, walkFuncErr
- }
- return clause, nil
- }
- func parseTableExpr(tableExpr tree.TableExpr) (string, error) {
- switch table := tableExpr.(type) {
- case *tree.TableName:
- return table.String(), nil
- case *tree.AliasedTableExpr:
- return table.String(), nil
- default:
- return "", errors.New("不支持的TableExpr")
- }
- }
- func parseSelect(selectExprs tree.SelectExprs) string {
- selectFmtCtx := tree.NewFmtCtx(tree.FmtBareStrings)
- selectExprs.Format(selectFmtCtx)
- return selectFmtCtx.String()
- }
- func parseFrom(from *tree.From) (bool, string, error) {
- switch fromTable := from.Tables[0].(type) {
- case *tree.JoinTableExpr:
- return true, fromTable.String(), nil
- case *tree.AliasedTableExpr:
- _, ok := fromTable.Expr.(*tree.Subquery)
- if ok {
- return true, fromTable.String(), nil
- } else {
- return false, fromTable.String(), nil
- }
- default:
- return false, "", errors.New("不支持的From类型")
- }
- }
- func parseWhere(where *tree.Where) string {
- return where.Expr.String()
- }
- func parseOrderBy(orderBy tree.OrderBy) string {
- orderByFmtCtx := tree.NewFmtCtx(tree.FmtBareStrings)
- orderBy[0].Format(orderByFmtCtx)
- return orderByFmtCtx.String()
- }
- func parseLimit(limit *tree.Limit) (int, int, error) {
- pageNo, err := strconv.Atoi(limit.Offset.String())
- if err != nil {
- return 0, 0, err
- }
- pageSize, err := strconv.Atoi(limit.Count.String())
- if err != nil {
- return 0, 0, err
- }
- return pageNo, pageSize, nil
- }
- func parseGroupBy(groupBy tree.GroupBy) string {
- groupBySlice := make([]string, 0)
- for _, groupExpr := range groupBy {
- groupBySlice = append(groupBySlice, groupExpr.String())
- }
- return strings.Join(groupBySlice, ",")
- }
- func parseExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
- switch realColumn := valueExpr.(type) {
- case *tree.FuncExpr:
- // 函数类型
- value, kind, err := evaluateFuncExpr(realColumn)
- if err != nil {
- return nil, err
- }
- return &clauseTableRowValue{
- kind: kind,
- value: value,
- }, nil
- case *tree.DBool:
- // 布尔类型
- var boolValue bool
- stringValue := realColumn.String()
- if stringValue == "true" {
- boolValue = true
- }
- return &clauseTableRowValue{
- kind: clauseTableRowValueKindBool,
- value: boolValue,
- }, nil
- case *tree.StrVal:
- // 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型
- return &clauseTableRowValue{
- kind: clauseTableRowValueKindString,
- value: realColumn.RawString(),
- }, nil
- case *tree.NumVal:
- // 数值类型,可以是整形或浮点型
- numKind := realColumn.Kind()
- if numKind == constant.Int {
- valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64)
- if err != nil {
- return nil, err
- }
- return &clauseTableRowValue{
- kind: clauseTableRowValueKindUint64,
- value: valueUint64,
- }, nil
- } else if numKind == constant.Float {
- valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64)
- if err != nil {
- return nil, err
- }
- return &clauseTableRowValue{
- kind: clauseTableRowValueKindFloat64,
- value: valueFloat64,
- }, nil
- } else {
- return nil, errors.New("不支持的数值类型")
- }
- case *tree.UnresolvedName:
- return nil, errors.New("存在无法解析的名字,是否应该使用单引号")
- default:
- return nil, errors.New("未支持的数据类型")
- }
- }
- func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
- if strings.HasPrefix(expr.String(), "parse_time") {
- if expr.Exprs == nil || len(expr.Exprs) != 2 {
- return nil, 0, errors.New("parse_time(time_str, time_format)")
- }
- timeStrValue, err := parseExpr(expr.Exprs[0])
- if err != nil {
- return nil, 0, err
- }
- if timeStrValue.kind != clauseTableRowValueKindString {
- return nil, 0, errors.New("时间字符串不是字符串类型")
- }
- timeFormatValue, err := parseExpr(expr.Exprs[1])
- if err != nil {
- return nil, 0, err
- }
- if timeFormatValue.kind != clauseTableRowValueKindString {
- return nil, 0, errors.New("时间格式不是字符串类型")
- }
- parsedTime, err := time.ParseInLocation(timeFormatValue.value.(string), timeStrValue.value.(string), time.Local)
- if err != nil {
- return nil, 0, err
- }
- return parsedTime, clauseTableRowValueKindTime, nil
- } else {
- return nil, 0, errors.New("不支持的函数")
- }
- }
|