sql_parser.go 6.2 KB


  1. package dpsapi
  2. import (
  3. "errors"
  4. "github.com/auxten/postgresql-parser/pkg/sql/parser"
  5. "github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
  6. "github.com/auxten/postgresql-parser/pkg/walk"
  7. "go/constant"
  8. "strconv"
  9. "strings"
  10. "time"
  11. )
  // Clause types for the statement kinds that parseSql dispatches on.
type (
	// insertClause is what insertWalk extracts from an INSERT statement.
	insertClause struct {
		table      string         // target table, formatted with tree.FmtSimple
		keyColumns []string       // key columns chosen by insertWalk ("**"-marked, else "id", else all)
		tableRows  map[string]any // column name -> value taken from the VALUES row
	}

	// deleteClause is currently empty; deleteWalk does not produce it yet.
	deleteClause struct{}

	// updateClause is currently empty; updateWalk does not produce it yet.
	updateClause struct{}

	// selectClause is not populated anywhere in this file yet (selectWalk is a stub).
	selectClause struct {
		selectExpr []string
		from       string
		where      []string
		limit      int
		offset     int
	}
)
  28. func parseSql(sqlStr string) ([]any, error) {
  29. sqls := strings.Split(sqlStr, ";")
  30. sqlClauses := make([]any, 0)
  31. for _, sql := range sqls {
  32. trimSQL := strings.TrimSpace(sql)
  33. upperTrimSQL := strings.ToUpper(trimSQL)
  34. var clause any
  35. if strings.HasPrefix(upperTrimSQL, "INSERT") {
  36. innerClause, err := insertWalk(sql)
  37. if err != nil {
  38. return nil, err
  39. }
  40. clause = innerClause
  41. } else if strings.HasPrefix(upperTrimSQL, "DELETE") {
  42. innerClause, err := deleteWalk(sql)
  43. if err != nil {
  44. return nil, err
  45. }
  46. clause = innerClause
  47. } else if strings.HasPrefix(upperTrimSQL, "UPDATE") {
  48. innerClause, err := updateWalk(sql)
  49. if err != nil {
  50. return nil, err
  51. }
  52. clause = innerClause
  53. } else if strings.HasPrefix(upperTrimSQL, "SELECT") {
  54. innerClause, err := selectWalk(sql)
  55. if err != nil {
  56. return nil, err
  57. }
  58. clause = innerClause
  59. }
  60. sqlClauses = append(sqlClauses, clause)
  61. }
  62. return sqlClauses, nil
  63. }
  64. func insertWalk(sql string) (*insertClause, error) {
  65. clause := new(insertClause)
  66. stmts, err := parser.Parse(sql)
  67. if err != nil {
  68. return nil, err
  69. }
  70. var walkFuncErr error
  71. w := &walk.AstWalker{
  72. Fn: func(ctx interface{}, node interface{}) (stop bool) {
  73. realNode := node.(*tree.Insert)
  74. // 获取table
  75. tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
  76. realNode.Table.Format(tableFmtCtx)
  77. clause.table = tableFmtCtx.String()
  78. // 解析rows
  79. values := make([]any, 0)
  80. rowsStatement, err := parser.Parse(realNode.Rows.String())
  81. if err != nil {
  82. walkFuncErr = err
  83. return true
  84. }
  85. var rowsWalkFuncErr error
  86. funcStringParams := make([]string, 0)
  87. rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) {
  88. switch rowsNode := node.(type) {
  89. case *tree.FuncExpr:
  90. // 函数类型
  91. value, err := evaluateFuncExpr(rowsNode, nil, &funcStringParams)
  92. if err != nil {
  93. rowsWalkFuncErr = err
  94. return true
  95. }
  96. values = append(values, value)
  97. case *tree.DBool:
  98. // 布尔类型
  99. stringValue := rowsNode.String()
  100. if !canSkipValueInThisContext(&funcStringParams, stringValue) {
  101. if stringValue == "false" {
  102. values = append(values, false)
  103. } else if stringValue == "true" {
  104. values = append(values, true)
  105. } else {
  106. rowsWalkFuncErr = errors.New("不支持的bool值")
  107. return true
  108. }
  109. }
  110. case *tree.UnresolvedName:
  111. // 字符串类型或者函数参数类型,这里通过比较字符串value,排除了函数参数类型
  112. stringValue := rowsNode.String()
  113. if !canSkipValueInThisContext(&funcStringParams, stringValue) {
  114. values = append(values, stringValue)
  115. }
  116. case *tree.NumVal:
  117. // 数值类型,可以是整形或浮点型
  118. stringValue := rowsNode.String()
  119. if !canSkipValueInThisContext(&funcStringParams, stringValue) {
  120. numKind := rowsNode.Kind()
  121. if numKind == constant.Int {
  122. valueUint64, err := strconv.ParseUint(rowsNode.String(), 10, 64)
  123. if err != nil {
  124. rowsWalkFuncErr = err
  125. return true
  126. }
  127. values = append(values, valueUint64)
  128. } else if numKind == constant.Float {
  129. valueFloat64, err := strconv.ParseFloat(rowsNode.String(), 64)
  130. if err != nil {
  131. rowsWalkFuncErr = err
  132. return true
  133. }
  134. values = append(values, valueFloat64)
  135. } else {
  136. rowsWalkFuncErr = errors.New("不支持的数值类型")
  137. return true
  138. }
  139. }
  140. }
  141. return false
  142. }}
  143. _, err = rowsWalker.Walk(rowsStatement, nil)
  144. if err != nil {
  145. walkFuncErr = err
  146. return true
  147. }
  148. if rowsWalkFuncErr != nil {
  149. walkFuncErr = rowsWalkFuncErr
  150. return true
  151. }
  152. // 组装columnValues
  153. hasIDColumn := false
  154. keyColumns := make([]string, 0)
  155. allColumns := make([]string, 0)
  156. tableRows := make(map[string]any)
  157. for i, column := range realNode.Columns.ToStrings() {
  158. if column == "id" {
  159. hasIDColumn = true
  160. }
  161. if strings.HasPrefix(column, "**") {
  162. column = strings.TrimPrefix(column, "**")
  163. keyColumns = append(keyColumns, column)
  164. }
  165. allColumns = append(allColumns, column)
  166. tableRows[column] = values[i]
  167. }
  168. if keyColumns != nil && len(keyColumns) != 0 {
  169. clause.keyColumns = keyColumns
  170. } else if hasIDColumn {
  171. clause.keyColumns = []string{"id"}
  172. } else {
  173. clause.keyColumns = allColumns
  174. }
  175. clause.tableRows = tableRows
  176. return false
  177. },
  178. }
  179. _, err = w.Walk(stmts, nil)
  180. if err != nil {
  181. return nil, err
  182. }
  183. if walkFuncErr != nil {
  184. return nil, walkFuncErr
  185. }
  186. return clause, nil
  187. }
  188. func deleteWalk(sql string) (*insertClause, error) {
  189. return nil, nil
  190. }
  191. func updateWalk(sql string) (*insertClause, error) {
  192. return nil, nil
  193. }
  194. func selectWalk(sql string) (*insertClause, error) {
  195. return nil, nil
  196. }
  197. func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, error) {
  198. if strings.HasPrefix(expr.String(), "parse_time") {
  199. if expr.Exprs == nil || len(expr.Exprs) != 2 {
  200. return nil, errors.New("parse_time(time_str, time_format)")
  201. }
  202. timeStr := expr.Exprs[0].String()
  203. timeFormat := expr.Exprs[1].String()
  204. if funcName != nil {
  205. *funcName = "parse_time"
  206. }
  207. if stringParams != nil {
  208. *stringParams = append(*stringParams, timeStr, timeFormat)
  209. }
  210. return time.ParseInLocation(timeFormat, timeStr, time.Local)
  211. } else {
  212. return nil, errors.New("不支持的函数")
  213. }
  214. }
  // canSkipValueInThisContext reports whether stringValue equals the next
// pending function-argument text at the head of *stringParams. On a match
// that head entry is consumed (removed) and true is returned. A nil pointer,
// an empty list, or a non-matching head leaves *stringParams unchanged and
// returns false.
func canSkipValueInThisContext(stringParams *[]string, stringValue string) bool {
	if stringParams == nil {
		return false
	}
	pending := *stringParams
	if len(pending) == 0 || pending[0] != stringValue {
		return false
	}
	*stringParams = pending[1:]
	return true
}