sql_parser.go 7.3 KB


  1. package dpsapi
  2. import (
  3. "errors"
  4. "github.com/auxten/postgresql-parser/pkg/sql/parser"
  5. "github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
  6. "github.com/auxten/postgresql-parser/pkg/walk"
  7. "go/constant"
  8. "strconv"
  9. "strings"
  10. "time"
  11. )
  // Kinds of value a clauseTableRowValue can carry. In this file each kind is
// paired with a fixed Go type in clauseTableRowValue.value:
//
//	clauseTableRowValueKindTime    -> time.Time (parse_time(...) in evaluateFuncExpr)
//	clauseTableRowValueKindBool    -> bool
//	clauseTableRowValueKindString  -> string
//	clauseTableRowValueKindUint64  -> uint64
//	clauseTableRowValueKindFloat64 -> float64
//
// The blank first entry consumes iota's zero, so the named kinds are 1..5 and
// 0 is not assigned to any of them.
const (
	_ int = iota
	clauseTableRowValueKindTime
	clauseTableRowValueKindBool
	clauseTableRowValueKindString
	clauseTableRowValueKindUint64
	clauseTableRowValueKindFloat64
)
  // Types describing parsed SQL statements and their column values.
type (
	// clauseTableRowValue is a single column value: kind is one of the
	// clauseTableRowValueKind* constants and value holds the Go value that
	// goes with that kind.
	clauseTableRowValue struct {
		kind  int
		value any
	}

	// insertClause describes an INSERT: the target table and a
	// column name -> value map for the inserted row.
	insertClause struct {
		table    string
		tableRow map[string]clauseTableRowValue
	}

	// deleteClause describes a DELETE: the target table and its WHERE
	// condition rendered as SQL text.
	deleteClause struct {
		table, where string
	}

	// updateClause describes an UPDATE: the target table, its WHERE condition
	// rendered as SQL text, and a column name -> new value map from SET.
	updateClause struct {
		table, where string
		newTableRow  map[string]clauseTableRowValue
	}

	// selectClause is meant to describe a SELECT.
	// NOTE(review): nothing in this file fills it yet; presumably selectExpr
	// is the projection list, from the source table, where the conditions and
	// limit/offset the paging values — confirm once SELECT parsing exists.
	selectClause struct {
		selectExpr []string
		from       string
		where      []string
		limit      int
		offset     int
	}
)
  43. func parseSql(sqlStr string) ([]any, error) {
  44. sqls := strings.Split(sqlStr, ";")
  45. sqlClauses := make([]any, 0)
  46. for _, sql := range sqls {
  47. trimSQL := strings.TrimSpace(sql)
  48. upperTrimSQL := strings.ToUpper(trimSQL)
  49. var clause any
  50. if strings.HasPrefix(upperTrimSQL, "INSERT") {
  51. innerClause, err := insertWalk(sql)
  52. if err != nil {
  53. return nil, err
  54. }
  55. clause = innerClause
  56. } else if strings.HasPrefix(upperTrimSQL, "DELETE") {
  57. innerClause, err := deleteWalk(sql)
  58. if err != nil {
  59. return nil, err
  60. }
  61. clause = innerClause
  62. } else if strings.HasPrefix(upperTrimSQL, "UPDATE") {
  63. innerClause, err := updateWalk(sql)
  64. if err != nil {
  65. return nil, err
  66. }
  67. clause = innerClause
  68. } else if strings.HasPrefix(upperTrimSQL, "SELECT") {
  69. innerClause, err := selectWalk(sql)
  70. if err != nil {
  71. return nil, err
  72. }
  73. clause = innerClause
  74. }
  75. sqlClauses = append(sqlClauses, clause)
  76. }
  77. return sqlClauses, nil
  78. }
  79. func insertWalk(sql string) (*insertClause, error) {
  80. clause := new(insertClause)
  81. stmts, err := parser.Parse(sql)
  82. if err != nil {
  83. return nil, err
  84. }
  85. var walkFuncErr error
  86. w := &walk.AstWalker{
  87. Fn: func(ctx interface{}, node interface{}) (stop bool) {
  88. realNode := node.(*tree.Insert)
  89. // 获取table
  90. tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
  91. realNode.Table.Format(tableFmtCtx)
  92. clause.table = tableFmtCtx.String()
  93. // 解析rows
  94. rowsStatement, err := parser.Parse(realNode.Rows.String())
  95. if err != nil {
  96. walkFuncErr = err
  97. return true
  98. }
  99. var rowsWalkFuncErr error
  100. values := make([]clauseTableRowValue, 0)
  101. rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) {
  102. switch rowsNode := node.(type) {
  103. case *tree.ValuesClause:
  104. for _, row := range rowsNode.Rows {
  105. for _, column := range row {
  106. columnValue, err := parseValueExpr(column)
  107. if err != nil {
  108. rowsWalkFuncErr = err
  109. return true
  110. }
  111. values = append(values, *columnValue)
  112. }
  113. }
  114. }
  115. return false
  116. }}
  117. _, err = rowsWalker.Walk(rowsStatement, nil)
  118. if err != nil {
  119. walkFuncErr = err
  120. return true
  121. }
  122. if rowsWalkFuncErr != nil {
  123. walkFuncErr = rowsWalkFuncErr
  124. return true
  125. }
  126. // 组装columnValues
  127. tableRows := make(map[string]clauseTableRowValue)
  128. for i, column := range realNode.Columns.ToStrings() {
  129. tableRows[column] = values[i]
  130. }
  131. clause.tableRow = tableRows
  132. return false
  133. },
  134. }
  135. _, err = w.Walk(stmts, nil)
  136. if err != nil {
  137. return nil, err
  138. }
  139. if walkFuncErr != nil {
  140. return nil, walkFuncErr
  141. }
  142. return clause, nil
  143. }
  144. func deleteWalk(sql string) (*deleteClause, error) {
  145. clause := new(deleteClause)
  146. stmts, err := parser.Parse(sql)
  147. if err != nil {
  148. return nil, err
  149. }
  150. var walkFuncErr error
  151. w := &walk.AstWalker{
  152. Fn: func(ctx interface{}, node interface{}) (stop bool) {
  153. realNode := node.(*tree.Delete)
  154. // 获取table
  155. tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
  156. realNode.Table.Format(tableFmtCtx)
  157. clause.table = tableFmtCtx.String()
  158. // 获取where
  159. clause.where = realNode.Where.Expr.String()
  160. return false
  161. },
  162. }
  163. _, err = w.Walk(stmts, nil)
  164. if err != nil {
  165. return nil, err
  166. }
  167. if walkFuncErr != nil {
  168. return nil, walkFuncErr
  169. }
  170. return clause, nil
  171. }
  172. func updateWalk(sql string) (*updateClause, error) {
  173. clause := new(updateClause)
  174. stmts, err := parser.Parse(sql)
  175. if err != nil {
  176. return nil, err
  177. }
  178. var walkFuncErr error
  179. w := &walk.AstWalker{
  180. Fn: func(ctx interface{}, node interface{}) (stop bool) {
  181. realNode := node.(*tree.Update)
  182. // 获取table
  183. tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
  184. realNode.Table.Format(tableFmtCtx)
  185. clause.table = tableFmtCtx.String()
  186. // 获取where
  187. clause.where = realNode.Where.Expr.String()
  188. // 获取table row
  189. newTableRow := make(map[string]clauseTableRowValue)
  190. for _, expr := range realNode.Exprs {
  191. value, err := parseValueExpr(expr.Expr)
  192. if err != nil {
  193. walkFuncErr = err
  194. return true
  195. }
  196. newTableRow[expr.Names[0].String()] = *value
  197. }
  198. return false
  199. },
  200. }
  201. _, err = w.Walk(stmts, nil)
  202. if err != nil {
  203. return nil, err
  204. }
  205. if walkFuncErr != nil {
  206. return nil, walkFuncErr
  207. }
  208. return clause, nil
  209. }
  210. func selectWalk(sql string) (*insertClause, error) {
  211. return nil, nil
  212. }
  213. func parseValueExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
  214. switch realColumn := valueExpr.(type) {
  215. case *tree.FuncExpr:
  216. // 函数类型
  217. value, kind, err := evaluateFuncExpr(realColumn)
  218. if err != nil {
  219. return nil, err
  220. }
  221. return &clauseTableRowValue{
  222. kind: kind,
  223. value: value,
  224. }, nil
  225. case *tree.DBool:
  226. // 布尔类型
  227. var boolValue bool
  228. stringValue := realColumn.String()
  229. if stringValue == "true" {
  230. boolValue = true
  231. }
  232. return &clauseTableRowValue{
  233. kind: clauseTableRowValueKindBool,
  234. value: boolValue,
  235. }, nil
  236. case *tree.StrVal:
  237. // 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型
  238. return &clauseTableRowValue{
  239. kind: clauseTableRowValueKindString,
  240. value: realColumn.RawString(),
  241. }, nil
  242. case *tree.NumVal:
  243. // 数值类型,可以是整形或浮点型
  244. numKind := realColumn.Kind()
  245. if numKind == constant.Int {
  246. valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64)
  247. if err != nil {
  248. return nil, err
  249. }
  250. return &clauseTableRowValue{
  251. kind: clauseTableRowValueKindUint64,
  252. value: valueUint64,
  253. }, nil
  254. } else if numKind == constant.Float {
  255. valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64)
  256. if err != nil {
  257. return nil, err
  258. }
  259. return &clauseTableRowValue{
  260. kind: clauseTableRowValueKindFloat64,
  261. value: valueFloat64,
  262. }, nil
  263. } else {
  264. return nil, errors.New("不支持的数值类型")
  265. }
  266. case *tree.UnresolvedName:
  267. return nil, errors.New("存在无法解析的名字,是否应该使用单引号")
  268. default:
  269. return nil, errors.New("未支持的数据类型")
  270. }
  271. }
  272. func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
  273. if strings.HasPrefix(expr.String(), "parse_time") {
  274. if expr.Exprs == nil || len(expr.Exprs) != 2 {
  275. return nil, 0, errors.New("parse_time(time_str, time_format)")
  276. }
  277. timeStr := expr.Exprs[0].String()
  278. timeFormat := expr.Exprs[1].String()
  279. parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local)
  280. if err != nil {
  281. return nil, 0, err
  282. }
  283. return parsedTime, clauseTableRowValueKindTime, nil
  284. } else {
  285. return nil, 0, errors.New("不支持的函数")
  286. }
  287. }