Browse Source

修改bug

yjp 1 year ago
parent
commit
38aad92675
3 changed files with 168 additions and 31 deletions
  1. 136 8
      sql_parser.go
  2. 30 21
      v1.go
  3. 2 2
      v1_test.go

+ 136 - 8
sql_parser.go

@@ -42,11 +42,15 @@ type updateClause struct {
 }
 
 type selectClause struct {
-	selectExpr []string
-	from       string
-	where      []string
-	limit      int
-	offset     int
+	table        string
+	fromSubQuery string
+	selectClause string
+	where        string
+	orderBy      []string
+	groupBy      []string
+	having       string
+	pageNo       int
+	pageSize     int
 }
 
 // 调试很重要的函数,可以看到一个节点的实际类型以及包含字段的类型
@@ -252,14 +256,84 @@ func updateWalk(sql string) (*updateClause, error) {
 	return clause, nil
 }
 
-func selectWalk(sql string) (*insertClause, error) {
-	return nil, nil
+func selectWalk(sql string) (*selectClause, error) {
+	clause := new(selectClause)
+
+	stmts, err := parser.Parse(sql)
+	if err != nil {
+		return nil, err
+	}
+
+	var walkFuncErr error
+
+	w := &walk.AstWalker{
+		Fn: func(ctx interface{}, node interface{}) (stop bool) {
+			realNode := node.(*tree.Select)
+			nodeSelectClause := realNode.Select.(*tree.SelectClause)
+
+			// select
+			clause.selectClause = parseSelect(nodeSelectClause.Exprs)
+
+			// from
+			asFromSubQuery, from := parseFrom(&nodeSelectClause.From)
+			if asFromSubQuery {
+				clause.fromSubQuery = from
+			} else {
+				clause.table = from
+			}
+
+			// where
+			if nodeSelectClause.Where != nil {
+				clause.where = parseWhere(nodeSelectClause.Where)
+			}
+
+			// order by
+			if realNode.OrderBy != nil {
+				clause.orderBy = parseOrderBy(realNode.OrderBy)
+			}
+
+			// limit
+			if realNode.Limit != nil {
+				pageNo, pageSize, err := parseLimit(realNode.Limit)
+				if err != nil {
+					walkFuncErr = err
+					return true
+				}
+
+				clause.pageNo = pageNo
+				clause.pageSize = pageSize
+			}
+
+			// group by
+			if nodeSelectClause.GroupBy != nil {
+				clause.groupBy = parseGroupBy(nodeSelectClause.GroupBy)
+			}
+
+			// having
+			if nodeSelectClause.Having != nil {
+				clause.having = parseWhere(nodeSelectClause.Having)
+			}
+
+			return false
+		},
+	}
+
+	_, err = w.Walk(stmts, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if walkFuncErr != nil {
+		return nil, walkFuncErr
+	}
+
+	return clause, nil
 }
 
 func parseTableExpr(tableExpr tree.TableExpr) (string, error) {
 	switch table := tableExpr.(type) {
 	case *tree.TableName:
-		return table.Table(), nil
+		return table.String(), nil
 	case *tree.AliasedTableExpr:
 		return table.String(), nil
 	default:
@@ -267,10 +341,64 @@ func parseTableExpr(tableExpr tree.TableExpr) (string, error) {
 	}
 }
 
+func parseSelect(selectExprs tree.SelectExprs) string {
+	selectFmtCtx := tree.NewFmtCtx(tree.FmtBareStrings)
+	selectExprs.Format(selectFmtCtx)
+	return selectFmtCtx.String()
+}
+
+func parseFrom(from *tree.From) (bool, string) {
+	asFromSubQuery := false
+
+	switch fromTable := from.Tables[0].(type) {
+	case *tree.JoinTableExpr:
+		asFromSubQuery = true
+	case *tree.AliasedTableExpr:
+		_, ok := fromTable.Expr.(*tree.Subquery)
+		asFromSubQuery = ok
+	}
+
+	return asFromSubQuery, fmt.Sprint(from.Tables)
+}
+
 func parseWhere(where *tree.Where) string {
 	return where.Expr.String()
 }
 
+func parseOrderBy(orderBy tree.OrderBy) []string {
+	orderBySlice := make([]string, 0)
+
+	for _, o := range orderBy {
+		orderBySlice = append(orderBySlice, o.Expr.String())
+	}
+
+	return orderBySlice
+}
+
+func parseLimit(limit *tree.Limit) (int, int, error) {
+	pageNo, err := strconv.Atoi(limit.Offset.String())
+	if err != nil {
+		return 0, 0, err
+	}
+
+	pageSize, err := strconv.Atoi(limit.Count.String())
+	if err != nil {
+		return 0, 0, err
+	}
+
+	return pageNo, pageSize, nil
+}
+
+func parseGroupBy(groupBy tree.GroupBy) []string {
+	groupBySlice := make([]string, 0)
+
+	for _, groupExpr := range groupBy {
+		groupBySlice = append(groupBySlice, groupExpr.String())
+	}
+
+	return groupBySlice
+}
+
 func parseExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
 	switch realColumn := valueExpr.(type) {
 	case *tree.FuncExpr:

+ 30 - 21
v1.go

@@ -77,17 +77,17 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string, operatorIDFunc Oper
 
 			operatorID, err := operatorIDFunc(c)
 			if err != nil {
-				return map[string]any{"table_rows": make([]map[string]any, 0)}, err
+				return map[string]any{"tableRows": make([]map[string]any, 0)}, err
 			}
 
 			parsedClauses, err := parseSql(inputModel.SQL)
 			if err != nil {
-				return map[string]any{"table_rows": make([]map[string]any, 0)}, err
+				return map[string]any{"tableRows": make([]map[string]any, 0)}, err
 			}
 
 			dpsClient, err := dps.NewClient(dpsAddress, "v1", inputModel.DatabaseID)
 			if err != nil {
-				return map[string]any{"table_rows": make([]map[string]any, 0)}, err
+				return map[string]any{"tableRows": make([]map[string]any, 0)}, err
 			}
 
 			result := make([]map[string]any, 0)
@@ -117,10 +117,10 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string, operatorIDFunc Oper
 				return nil
 			})
 			if err != nil {
-				return map[string]any{"table_rows": make([]map[string]any, 0)}, err
+				return map[string]any{"tableRows": make([]map[string]any, 0)}, err
 			}
 
-			return map[string]any{"table_rows": result}, nil
+			return map[string]any{"tableRows": result}, nil
 		},
 	})
 }
@@ -132,8 +132,8 @@ func insertMap(clause *insertClause) map[string]any {
 	}
 
 	return map[string]any{
-		"table":     clause.table,
-		"table_row": tableRows,
+		"table":    clause.table,
+		"tableRow": tableRows,
 	}
 }
 
@@ -151,14 +151,24 @@ func updateMap(clause *updateClause) map[string]any {
 	}
 
 	return map[string]any{
-		"table":         clause.table,
-		"where":         clause.where,
-		"new_table_row": newTableRows,
+		"table":       clause.table,
+		"where":       clause.where,
+		"newTableRow": newTableRows,
 	}
 }
 
 func selectMap(clause *selectClause) map[string]any {
-	return map[string]any{}
+	return map[string]any{
+		"table":        clause.table,
+		"fromSubQuery": clause.fromSubQuery,
+		"selectClause": clause.selectClause,
+		"where":        clause.where,
+		"orderBy":      clause.orderBy,
+		"groupBy":      clause.groupBy,
+		"having":       clause.having,
+		"pageNo":       clause.pageNo,
+		"pageSize":     clause.pageSize,
+	}
 }
 
 func doInsert(tx client.Transaction, version string, keyColumns []string, clause *insertClause, operatorID string) error {
@@ -222,17 +232,16 @@ func doUpdate(tx client.Transaction, version string, keyColumns []string, clause
 
 func doSelect(dpsClient client.Client, version string, clause *selectClause) ([]map[string]any, error) {
 	statement, tableRows, err := dpsClient.CommonQueryOnly(&client.CommonQueryRequest{
-		TablePrefixWithSchema: "",
-		Table:                 "",
+		TablePrefixWithSchema: clause.table,
+		Table:                 clause.fromSubQuery,
 		Version:               version,
-		Select:                nil,
-		Where:                 nil,
-		OrderBy:               nil,
-		GroupBy:               nil,
-		Joins:                 nil,
-		Having:                nil,
-		PageNo:                0,
-		PageSize:              0,
+		Select:                client.NewClause().Common(clause.selectClause),
+		Where:                 client.NewClause().Common(clause.where),
+		OrderBy:               clause.orderBy,
+		GroupBy:               clause.groupBy,
+		Having:                client.NewClause().Common(clause.having),
+		PageNo:                clause.pageNo,
+		PageSize:              clause.pageSize,
 	})
 	if err != nil {
 		fmt.Println(statement)

+ 2 - 2
v1_test.go

@@ -39,7 +39,7 @@ func TestApiV1OperateParse(t *testing.T) {
 		t.Fatal("表名不正确")
 	}
 
-	for columnName, value := range parsed["table_row"].(map[string]any) {
+	for columnName, value := range parsed["tableRow"].(map[string]any) {
 		if exceptedTableRows[columnName] != value {
 			t.Fatal(columnName + "行数据不正确")
 		}
@@ -55,7 +55,7 @@ func TestApiV1OperateParse(t *testing.T) {
 		t.Fatal("where不正确")
 	}
 
-	for columnName, value := range parsed["new_table_row"].(map[string]any) {
+	for columnName, value := range parsed["newTableRow"].(map[string]any) {
 		if exceptedTableRows[columnName] != value {
 			t.Fatal(columnName + "行数据不正确")
 		}