package dpsapi import ( "errors" "github.com/auxten/postgresql-parser/pkg/sql/parser" "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" "github.com/auxten/postgresql-parser/pkg/walk" "go/constant" "strconv" "strings" "time" ) const ( clauseTableRowValueKindTime int = iota + 1 clauseTableRowValueKindBool clauseTableRowValueKindString clauseTableRowValueKindUint64 clauseTableRowValueKindFloat64 ) type clauseTableRowValue struct { kind int value any } type insertClause struct { table string tableRow map[string]clauseTableRowValue } type deleteClause struct { table string where string } type updateClause struct { table string where string newTableRow map[string]clauseTableRowValue } type selectClause struct { selectExpr []string from string where []string limit int offset int } func parseSql(sqlStr string) ([]any, error) { sqls := strings.Split(sqlStr, ";") sqlClauses := make([]any, 0) for _, sql := range sqls { trimSQL := strings.TrimSpace(sql) upperTrimSQL := strings.ToUpper(trimSQL) var clause any if strings.HasPrefix(upperTrimSQL, "INSERT") { innerClause, err := insertWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "DELETE") { innerClause, err := deleteWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "UPDATE") { innerClause, err := updateWalk(sql) if err != nil { return nil, err } clause = innerClause } else if strings.HasPrefix(upperTrimSQL, "SELECT") { innerClause, err := selectWalk(sql) if err != nil { return nil, err } clause = innerClause } sqlClauses = append(sqlClauses, clause) } return sqlClauses, nil } func insertWalk(sql string) (*insertClause, error) { clause := new(insertClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Insert) // 获取table tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple) realNode.Table.Format(tableFmtCtx) clause.table = tableFmtCtx.String() // 解析rows rowsStatement, err := parser.Parse(realNode.Rows.String()) if err != nil { walkFuncErr = err return true } var rowsWalkFuncErr error values := make([]clauseTableRowValue, 0) rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) { switch rowsNode := node.(type) { case *tree.ValuesClause: for _, row := range rowsNode.Rows { for _, column := range row { columnValue, err := parseValueExpr(column) if err != nil { rowsWalkFuncErr = err return true } values = append(values, *columnValue) } } } return false }} _, err = rowsWalker.Walk(rowsStatement, nil) if err != nil { walkFuncErr = err return true } if rowsWalkFuncErr != nil { walkFuncErr = rowsWalkFuncErr return true } // 组装columnValues tableRows := make(map[string]clauseTableRowValue) for i, column := range realNode.Columns.ToStrings() { tableRows[column] = values[i] } clause.tableRow = tableRows return false }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func deleteWalk(sql string) (*deleteClause, error) { clause := new(deleteClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Delete) // 获取table tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple) realNode.Table.Format(tableFmtCtx) clause.table = tableFmtCtx.String() // 获取where clause.where = realNode.Where.Expr.String() return false }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func updateWalk(sql string) (*updateClause, error) { clause := new(updateClause) stmts, err := parser.Parse(sql) if err != nil { return nil, err } var walkFuncErr error w := &walk.AstWalker{ Fn: func(ctx interface{}, node interface{}) (stop bool) { realNode := node.(*tree.Update) // 获取table tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple) realNode.Table.Format(tableFmtCtx) clause.table = tableFmtCtx.String() // 获取where clause.where = realNode.Where.Expr.String() // 获取table row newTableRow := make(map[string]clauseTableRowValue) for _, expr := range realNode.Exprs { value, err := parseValueExpr(expr.Expr) if err != nil { walkFuncErr = err return true } newTableRow[expr.Names[0].String()] = *value } return false }, } _, err = w.Walk(stmts, nil) if err != nil { return nil, err } if walkFuncErr != nil { return nil, walkFuncErr } return clause, nil } func selectWalk(sql string) (*insertClause, error) { return nil, nil } func parseValueExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) { switch realColumn := valueExpr.(type) { case *tree.FuncExpr: // 函数类型 value, kind, err := evaluateFuncExpr(realColumn) if err != nil { return nil, err } return &clauseTableRowValue{ kind: kind, value: value, }, nil case *tree.DBool: // 布尔类型 var boolValue bool stringValue := realColumn.String() if stringValue == "true" { boolValue = true } return &clauseTableRowValue{ kind: clauseTableRowValueKindBool, value: boolValue, }, nil case *tree.StrVal: // 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型 return &clauseTableRowValue{ kind: clauseTableRowValueKindString, value: realColumn.RawString(), }, nil case *tree.NumVal: // 数值类型,可以是整形或浮点型 numKind := realColumn.Kind() if numKind == constant.Int { valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64) if err != nil { return nil, err } return &clauseTableRowValue{ kind: clauseTableRowValueKindUint64, value: valueUint64, }, nil } else if numKind == constant.Float { valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64) if err != nil { return nil, err } return &clauseTableRowValue{ kind: clauseTableRowValueKindFloat64, value: valueFloat64, }, nil } else { return nil, errors.New("不支持的数值类型") } case *tree.UnresolvedName: return nil, errors.New("存在无法解析的名字,是否应该使用单引号") default: return nil, errors.New("未支持的数据类型") } } func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) { if strings.HasPrefix(expr.String(), "parse_time") { if expr.Exprs == nil || len(expr.Exprs) != 2 { return nil, 0, errors.New("parse_time(time_str, time_format)") } timeStr := expr.Exprs[0].String() timeFormat := expr.Exprs[1].String() parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local) if err != nil { return nil, 0, err } return parsedTime, clauseTableRowValueKindTime, nil } else { return nil, 0, errors.New("不支持的函数") } }