Browse Source

封装函数

yjp 1 year ago
parent
commit
06dfaab4cb
1 changed files with 43 additions and 14 deletions
  1. 43 14
      sql_parser.go

+ 43 - 14
sql_parser.go

@@ -116,8 +116,13 @@ func insertWalk(sql string) (*insertClause, error) {
 			realNode := node.(*tree.Insert)
 
 			// 获取table
-			tableName := realNode.Table.(*tree.TableName)
-			clause.table = tableName.Table()
+			tableName, err := parseTableExpr(realNode.Table)
+			if err != nil {
+				walkFuncErr = err
+				return true
+			}
+
+			clause.table = tableName
 
 			// 获取table row
 			clause.tableRow = make(map[string]clauseTableRowValue)
@@ -125,7 +130,7 @@ func insertWalk(sql string) (*insertClause, error) {
 			valuesClause := realNode.Rows.Select.(*tree.ValuesClause)
 			for _, row := range valuesClause.Rows {
 				for i, column := range row {
-					columnValue, err := parseValueExpr(column)
+					columnValue, err := parseExpr(column)
 					if err != nil {
 						walkFuncErr = err
 						return true
@@ -166,11 +171,16 @@ func deleteWalk(sql string) (*deleteClause, error) {
 			realNode := node.(*tree.Delete)
 
 			// 获取table
-			tableName := realNode.Table.(*tree.AliasedTableExpr)
-			clause.table = tableName.String()
+			tableName, err := parseTableExpr(realNode.Table)
+			if err != nil {
+				walkFuncErr = err
+				return true
+			}
+
+			clause.table = tableName
 
 			// 获取where
-			clause.where = realNode.Where.Expr.String()
+			clause.where = parseWhere(realNode.Where)
 
 			return false
 		},
@@ -203,17 +213,21 @@ func updateWalk(sql string) (*updateClause, error) {
 			realNode := node.(*tree.Update)
 
 			// 获取table
-			tableName := realNode.Table.(*tree.AliasedTableExpr)
-			clause.table = tableName.String()
+			tableName, err := parseTableExpr(realNode.Table)
+			if err != nil {
+				walkFuncErr = err
+				return true
+			}
+
+			clause.table = tableName
 
 			// 获取where
-			clause.where = realNode.Where.Expr.String()
+			clause.where = parseWhere(realNode.Where)
 
 			// 获取table row
 			clause.newTableRow = make(map[string]clauseTableRowValue)
-
 			for _, expr := range realNode.Exprs {
-				value, err := parseValueExpr(expr.Expr)
+				value, err := parseExpr(expr.Expr)
 				if err != nil {
 					walkFuncErr = err
 					return true
@@ -242,7 +256,22 @@ func selectWalk(sql string) (*insertClause, error) {
 	return nil, nil
 }
 
-func parseValueExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
+func parseTableExpr(tableExpr tree.TableExpr) (string, error) {
+	switch table := tableExpr.(type) {
+	case *tree.TableName:
+		return table.Table(), nil
+	case *tree.AliasedTableExpr:
+		return table.String(), nil
+	default:
+		return "", errors.New("不支持的TableExpr")
+	}
+}
+
+func parseWhere(where *tree.Where) string {
+	return where.Expr.String()
+}
+
+func parseExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
 	switch realColumn := valueExpr.(type) {
 	case *tree.FuncExpr:
 		// 函数类型
@@ -312,7 +341,7 @@ func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
 			return nil, 0, errors.New("parse_time(time_str, time_format)")
 		}
 
-		timeStrValue, err := parseValueExpr(expr.Exprs[0])
+		timeStrValue, err := parseExpr(expr.Exprs[0])
 		if err != nil {
 			return nil, 0, err
 		}
@@ -321,7 +350,7 @@ func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
 			return nil, 0, errors.New("时间字符串不是字符串类型")
 		}
 
-		timeFormatValue, err := parseValueExpr(expr.Exprs[1])
+		timeFormatValue, err := parseExpr(expr.Exprs[1])
 		if err != nil {
 			return nil, 0, err
 		}