package dpsapi

import (
	"errors"
	"fmt"
	"go/constant"
	"strconv"
	"strings"
	"time"

	"github.com/auxten/postgresql-parser/pkg/sql/parser"
	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
	"github.com/auxten/postgresql-parser/pkg/walk"
)

// Kinds of values a clauseTableRowValue can carry. The zero value is
// deliberately not a valid kind (the enumeration starts at 1).
const (
	clauseTableRowValueKindTime int = iota + 1
	clauseTableRowValueKindBool
	clauseTableRowValueKindString
	clauseTableRowValueKindUint64
	clauseTableRowValueKindFloat64
)

// clauseTableRowValue is a single literal extracted from a statement.
// kind is one of the clauseTableRowValueKind* constants and tells which
// dynamic type value holds (time.Time, bool, string, uint64 or float64).
type clauseTableRowValue struct {
	kind  int
	value any
}

// insertClause is the extracted form of an INSERT ... VALUES statement.
type insertClause struct {
	table    string
	tableRow map[string]clauseTableRowValue
}

// deleteClause is the extracted form of a DELETE statement. where is the
// formatted WHERE expression, or "" when the statement has no WHERE clause.
type deleteClause struct {
	table string
	where string
}

// updateClause is the extracted form of an UPDATE statement. where is the
// formatted WHERE expression, or "" when the statement has no WHERE clause.
type updateClause struct {
	table       string
	where       string
	newTableRow map[string]clauseTableRowValue
}

// selectClause is the intended extracted form of a SELECT statement.
// Nothing in this file populates it yet.
type selectClause struct {
	selectExpr []string
	from       string
	where      []string
	limit      int
	offset     int
}

// printNode is a debugging aid: it prints a node in Go syntax, which shows
// its concrete type together with the types and values of its fields.
func printNode(node any) {
	fmt.Printf("%+#v\n", node)
}

// parseSql splits sqlStr into individual statements and converts each one
// into its clause representation (*insertClause, *deleteClause or
// *updateClause), in input order.
//
// Empty fragments (for example the one after a trailing semicolon) are
// skipped. A non-empty statement of an unsupported kind, and SELECT (whose
// walker is still a stub), contributes an untyped nil entry so that callers
// can detect it with a plain `clause == nil` check.
//
// The first statement that fails to parse or convert aborts the whole call
// and its error is returned.
func parseSql(sqlStr string) ([]any, error) {
	fragments := splitStatements(sqlStr)
	sqlClauses := make([]any, 0, len(fragments))
	for _, fragment := range fragments {
		sql := strings.TrimSpace(fragment)
		if sql == "" {
			continue
		}
		upperSQL := strings.ToUpper(sql)

		var clause any
		switch {
		case strings.HasPrefix(upperSQL, "INSERT"):
			innerClause, err := insertWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperSQL, "DELETE"):
			innerClause, err := deleteWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperSQL, "UPDATE"):
			innerClause, err := updateWalk(sql)
			if err != nil {
				return nil, err
			}
			clause = innerClause
		case strings.HasPrefix(upperSQL, "SELECT"):
			innerClause, err := selectWalk(sql)
			if err != nil {
				return nil, err
			}
			// Storing a nil pointer in an `any` yields a non-nil interface;
			// only keep the result when there really is one.
			if innerClause != nil {
				clause = innerClause
			}
		}
		sqlClauses = append(sqlClauses, clause)
	}
	return sqlClauses, nil
}

// splitStatements cuts sqlStr at every semicolon that is not inside a
// single-quoted literal or a double-quoted identifier, so a value such as
// 'a;b' stays within its statement. A doubled quote ('') inside a literal
// is handled naturally because it closes and immediately reopens the quoted
// region.
//
// It does not understand SQL comments, dollar-quoted strings or
// backslash-escaped quotes. An unterminated quote makes the remainder of
// the input a single fragment.
func splitStatements(sqlStr string) []string {
	var (
		fragments []string
		quote     byte // 0 outside quotes, otherwise the opening quote byte
		start     int
	)
	// Byte-wise iteration is safe: the bytes of interest are ASCII and never
	// occur inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(sqlStr); i++ {
		c := sqlStr[i]
		switch {
		case quote != 0:
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"':
			quote = c
		case c == ';':
			fragments = append(fragments, sqlStr[start:i])
			start = i + 1
		}
	}
	return append(fragments, sqlStr[start:])
}

// insertWalk parses sql and extracts the target table and the column/value
// pairs of an INSERT ... VALUES statement.
//
// It returns an error when the SQL does not parse, when the INSERT is not of
// the plain `INSERT INTO t (cols...) VALUES (...)` shape, when a VALUES row
// has more entries than the column list, or when a value is of an
// unsupported type. Nodes visited by the walker that are not *tree.Insert
// are ignored.
func insertWalk(sql string) (*insertClause, error) {
	stmts, err := parser.Parse(sql)
	if err != nil {
		return nil, err
	}

	clause := new(insertClause)
	var walkFuncErr error
	w := &walk.AstWalker{
		Fn: func(ctx any, node any) (stop bool) {
			insertNode, ok := node.(*tree.Insert)
			if !ok {
				return false
			}
			if err := fillInsertClause(clause, insertNode); err != nil {
				walkFuncErr = err
				return true
			}
			return false
		},
	}
	if _, err = w.Walk(stmts, nil); err != nil {
		return nil, err
	}
	if walkFuncErr != nil {
		return nil, walkFuncErr
	}
	return clause, nil
}

// fillInsertClause copies the table name and the column/value pairs of node
// into clause. With a multi-row VALUES list every row is written into the
// same map, so values from later rows overwrite those of earlier rows.
func fillInsertClause(clause *insertClause, node *tree.Insert) error {
	// Target table.
	tableName, ok := node.Table.(*tree.TableName)
	if !ok {
		return fmt.Errorf("unsupported INSERT target %T", node.Table)
	}
	clause.table = tableName.Table()

	// Row values. Anything other than a VALUES list (INSERT ... SELECT,
	// DEFAULT VALUES) cannot be turned into literal column values.
	if node.Rows == nil {
		return errors.New("INSERT statement has no VALUES clause")
	}
	valuesClause, ok := node.Rows.Select.(*tree.ValuesClause)
	if !ok {
		return fmt.Errorf("only INSERT ... VALUES is supported, got %T", node.Rows.Select)
	}

	clause.tableRow = make(map[string]clauseTableRowValue)
	columns := node.Columns.ToStrings()
	for _, row := range valuesClause.Rows {
		// Values are matched to columns by position, so every value needs a
		// named column; this also rejects INSERTs without a column list.
		if len(row) > len(columns) {
			return fmt.Errorf("INSERT has %d values but only %d named columns", len(row), len(columns))
		}
		for i, column := range row {
			columnValue, err := parseValueExpr(column)
			if err != nil {
				return err
			}
			clause.tableRow[columns[i]] = *columnValue
		}
	}
	return nil
}

// deleteWalk parses sql and extracts the target table and the WHERE
// expression of a DELETE statement. The where field is left empty when the
// statement has no WHERE clause. Nodes visited by the walker that are not
// *tree.Delete are ignored.
func deleteWalk(sql string) (*deleteClause, error) {
	stmts, err := parser.Parse(sql)
	if err != nil {
		return nil, err
	}

	clause := new(deleteClause)
	var walkFuncErr error
	w := &walk.AstWalker{
		Fn: func(ctx any, node any) (stop bool) {
			deleteNode, ok := node.(*tree.Delete)
			if !ok {
				return false
			}

			// Target table.
			tableExpr, ok := deleteNode.Table.(*tree.AliasedTableExpr)
			if !ok {
				walkFuncErr = fmt.Errorf("unsupported DELETE target %T", deleteNode.Table)
				return true
			}
			clause.table = tableExpr.String()

			// WHERE is optional; the parser leaves it nil when absent.
			if deleteNode.Where != nil && deleteNode.Where.Expr != nil {
				clause.where = deleteNode.Where.Expr.String()
			}
			return false
		},
	}
	if _, err = w.Walk(stmts, nil); err != nil {
		return nil, err
	}
	if walkFuncErr != nil {
		return nil, walkFuncErr
	}
	return clause, nil
}

// updateWalk parses sql and extracts the target table, the WHERE expression
// and the assigned column values of an UPDATE statement. The where field is
// left empty when the statement has no WHERE clause. Nodes visited by the
// walker that are not *tree.Update are ignored.
func updateWalk(sql string) (*updateClause, error) {
	stmts, err := parser.Parse(sql)
	if err != nil {
		return nil, err
	}

	clause := new(updateClause)
	var walkFuncErr error
	w := &walk.AstWalker{
		Fn: func(ctx any, node any) (stop bool) {
			updateNode, ok := node.(*tree.Update)
			if !ok {
				return false
			}

			// Target table.
			tableExpr, ok := updateNode.Table.(*tree.AliasedTableExpr)
			if !ok {
				walkFuncErr = fmt.Errorf("unsupported UPDATE target %T", updateNode.Table)
				return true
			}
			clause.table = tableExpr.String()

			// WHERE is optional; the parser leaves it nil when absent.
			if updateNode.Where != nil && updateNode.Where.Expr != nil {
				clause.where = updateNode.Where.Expr.String()
			}

			// Assigned values, keyed by the first name of each SET item.
			clause.newTableRow = make(map[string]clauseTableRowValue)
			for _, expr := range updateNode.Exprs {
				if len(expr.Names) == 0 {
					walkFuncErr = errors.New("UPDATE assignment without a column name")
					return true
				}
				value, err := parseValueExpr(expr.Expr)
				if err != nil {
					walkFuncErr = err
					return true
				}
				clause.newTableRow[fmt.Sprint(expr.Names[0])] = *value
			}
			return false
		},
	}
	if _, err = w.Walk(stmts, nil); err != nil {
		return nil, err
	}
	if walkFuncErr != nil {
		return nil, walkFuncErr
	}
	return clause, nil
}

// selectWalk is a placeholder for SELECT support; it currently ignores its
// input and returns (nil, nil).
//
// NOTE(review): the return type is *insertClause although a selectClause
// type exists; presumably this should become *selectClause once the function
// is implemented. Kept as is so existing callers keep compiling.
func selectWalk(sql string) (*insertClause, error) {
	return nil, nil
}

// parseValueExpr converts a literal (or supported function call) from the
// AST into a clauseTableRowValue.
//
// Supported inputs: function calls accepted by evaluateFuncExpr, booleans,
// string literals, and integer or floating-point numeric literals. Integers
// are parsed as unsigned 64-bit values, so an integer literal that does not
// fit (for instance one that formats with a leading minus sign) yields the
// strconv error. Any other expression yields an error.
func parseValueExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
	switch v := valueExpr.(type) {
	case *tree.FuncExpr:
		// Function call: evaluate it to a concrete value.
		value, kind, err := evaluateFuncExpr(v)
		if err != nil {
			return nil, err
		}
		return &clauseTableRowValue{kind: kind, value: value}, nil

	case *tree.DBool:
		// Boolean literal.
		return &clauseTableRowValue{
			kind:  clauseTableRowValueKindBool,
			value: v.String() == "true",
		}, nil

	case *tree.StrVal:
		// String literal; also reached for string arguments of a function
		// call, because evaluateFuncExpr parses its arguments through here.
		return &clauseTableRowValue{
			kind:  clauseTableRowValueKindString,
			value: v.RawString(),
		}, nil

	case *tree.NumVal:
		// Numeric literal: integer or floating point.
		switch v.Kind() {
		case constant.Int:
			valueUint64, err := strconv.ParseUint(v.String(), 10, 64)
			if err != nil {
				return nil, err
			}
			return &clauseTableRowValue{
				kind:  clauseTableRowValueKindUint64,
				value: valueUint64,
			}, nil
		case constant.Float:
			valueFloat64, err := strconv.ParseFloat(v.String(), 64)
			if err != nil {
				return nil, err
			}
			return &clauseTableRowValue{
				kind:  clauseTableRowValueKindFloat64,
				value: valueFloat64,
			}, nil
		default:
			return nil, errors.New("不支持的数值类型")
		}

	case *tree.UnresolvedName:
		// A bare identifier where a value was expected, typically a string
		// that was not wrapped in single quotes.
		return nil, errors.New("存在无法解析的名字,是否应该使用单引号")

	default:
		return nil, errors.New("未支持的数据类型")
	}
}

// evaluateFuncExpr evaluates a supported SQL function call and returns its
// value together with the matching clauseTableRowValueKind* constant.
//
// The only supported function is parse_time(time_str, time_format): both
// arguments must be string literals, time_format is a Go time layout, and
// the result is a time.Time interpreted in the local time zone.
//
// NOTE(review): the function is recognised by a prefix match on the
// formatted expression, so any call whose text starts with "parse_time" is
// treated as parse_time — confirm whether an exact name match is wanted.
func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
	if !strings.HasPrefix(expr.String(), "parse_time") {
		return nil, 0, errors.New("不支持的函数")
	}
	if len(expr.Exprs) != 2 {
		return nil, 0, errors.New("parse_time(time_str, time_format)")
	}

	timeStrValue, err := parseValueExpr(expr.Exprs[0])
	if err != nil {
		return nil, 0, err
	}
	if timeStrValue.kind != clauseTableRowValueKindString {
		return nil, 0, errors.New("时间字符串不是字符串类型")
	}

	timeFormatValue, err := parseValueExpr(expr.Exprs[1])
	if err != nil {
		return nil, 0, err
	}
	if timeFormatValue.kind != clauseTableRowValueKindString {
		return nil, 0, errors.New("时间格式不是字符串类型")
	}

	// Both kinds were checked above, so the string assertions cannot fail.
	parsedTime, err := time.ParseInLocation(timeFormatValue.value.(string), timeStrValue.value.(string), time.Local)
	if err != nil {
		return nil, 0, err
	}
	return parsedTime, clauseTableRowValueKindTime, nil
}