yjp 1 жил өмнө
parent
commit
186a43e49f
1 өөрчлөгдсөн 119 нэмэгдсэн , 102 устгасан
  1. 119 102
      sql_parser.go

+ 119 - 102
sql_parser.go

@@ -115,8 +115,6 @@ func insertWalk(sql string) (*insertClause, error) {
 			clause.table = tableFmtCtx.String()
 
 			// 解析rows
-			values := make([]clauseTableRowValue, 0)
-
 			rowsStatement, err := parser.Parse(realNode.Rows.String())
 			if err != nil {
 				walkFuncErr = err
@@ -124,90 +122,20 @@ func insertWalk(sql string) (*insertClause, error) {
 			}
 
 			var rowsWalkFuncErr error
-			funcStringParams := make([]string, 0)
+			values := make([]clauseTableRowValue, 0)
 
 			rowsWalker := &walk.AstWalker{Fn: func(ctx interface{}, node interface{}) (stop bool) {
 				switch rowsNode := node.(type) {
 				case *tree.ValuesClause:
 					for _, row := range rowsNode.Rows {
 						for _, column := range row {
-							switch realColumn := column.(type) {
-							case *tree.FuncExpr:
-								// 函数类型
-								value, kind, err := evaluateFuncExpr(realColumn, nil, &funcStringParams)
-								if err != nil {
-									rowsWalkFuncErr = err
-									return true
-								}
-
-								values = append(values, clauseTableRowValue{
-									kind:  kind,
-									value: value,
-								})
-							case *tree.DBool:
-								// 布尔类型
-								stringValue := realColumn.String()
-								if !canSkipValueInThisContext(&funcStringParams, stringValue) {
-									if stringValue == "false" {
-										values = append(values, clauseTableRowValue{
-											kind:  clauseTableRowValueKindBool,
-											value: false,
-										})
-									} else if stringValue == "true" {
-										values = append(values, clauseTableRowValue{
-											kind:  clauseTableRowValueKindBool,
-											value: true,
-										})
-									} else {
-										rowsWalkFuncErr = errors.New("不支持的bool值")
-										return true
-									}
-								}
-							case *tree.StrVal:
-								// 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型
-								stringValue := realColumn.String()
-								if !canSkipValueInThisContext(&funcStringParams, stringValue) {
-									values = append(values, clauseTableRowValue{
-										kind:  clauseTableRowValueKindString,
-										value: realColumn.RawString(),
-									})
-								}
-							case *tree.NumVal:
-								// 数值类型,可以是整形或浮点型
-								stringValue := realColumn.String()
-								if !canSkipValueInThisContext(&funcStringParams, stringValue) {
-									numKind := realColumn.Kind()
-									if numKind == constant.Int {
-										valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64)
-										if err != nil {
-											rowsWalkFuncErr = err
-											return true
-										}
-
-										values = append(values, clauseTableRowValue{
-											kind:  clauseTableRowValueKindUint64,
-											value: valueUint64,
-										})
-									} else if numKind == constant.Float {
-										valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64)
-										if err != nil {
-											rowsWalkFuncErr = err
-											return true
-										}
-
-										values = append(values, clauseTableRowValue{
-											kind:  clauseTableRowValueKindFloat64,
-											value: valueFloat64,
-										})
-									} else {
-										rowsWalkFuncErr = errors.New("不支持的数值类型")
-										return true
-									}
-								}
-							case *tree.UnresolvedName:
-								rowsWalkFuncErr = errors.New("存在无法解析的名字,是否应该使用单引号")
+							columnValue, err := parseValueExpr(column)
+							if err != nil {
+								rowsWalkFuncErr = err
 								return true
 							}
+
+							values = append(values, *columnValue)
 						}
 					}
 				}
@@ -288,15 +216,125 @@ func deleteWalk(sql string) (*deleteClause, error) {
 	return clause, nil
 }
 
-func updateWalk(sql string) (*insertClause, error) {
-	return nil, nil
+func updateWalk(sql string) (*updateClause, error) {
+	clause := new(updateClause)
+
+	stmts, err := parser.Parse(sql)
+	if err != nil {
+		return nil, err
+	}
+
+	var walkFuncErr error
+
+	w := &walk.AstWalker{
+		Fn: func(ctx interface{}, node interface{}) (stop bool) {
+			realNode := node.(*tree.Update)
+
+			// 获取table
+			tableFmtCtx := tree.NewFmtCtx(tree.FmtSimple)
+			realNode.Table.Format(tableFmtCtx)
+			clause.table = tableFmtCtx.String()
+
+			// 获取where
+			clause.where = realNode.Where.Expr.String()
+
+			// 获取table row
+			newTableRow := make(map[string]clauseTableRowValue)
+			for _, expr := range realNode.Exprs {
+				value, err := parseValueExpr(expr.Expr)
+				if err != nil {
+					walkFuncErr = err
+					return true
+				}
+
+				newTableRow[expr.Names[0].String()] = *value
+			}
+
+			return false
+		},
+	}
+
+	_, err = w.Walk(stmts, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if walkFuncErr != nil {
+		return nil, walkFuncErr
+	}
+
+	return clause, nil
 }
 
 func selectWalk(sql string) (*insertClause, error) {
 	return nil, nil
 }
 
-func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]string) (any, int, error) {
+func parseValueExpr(valueExpr tree.Expr) (*clauseTableRowValue, error) {
+	switch realColumn := valueExpr.(type) {
+	case *tree.FuncExpr:
+		// 函数类型
+		value, kind, err := evaluateFuncExpr(realColumn)
+		if err != nil {
+			return nil, err
+		}
+
+		return &clauseTableRowValue{
+			kind:  kind,
+			value: value,
+		}, nil
+	case *tree.DBool:
+		// 布尔类型
+		var boolValue bool
+		stringValue := realColumn.String()
+		if stringValue == "true" {
+			boolValue = true
+		}
+
+		return &clauseTableRowValue{
+			kind:  clauseTableRowValueKindBool,
+			value: boolValue,
+		}, nil
+	case *tree.StrVal:
+		// 字符串类型或者函数参数是字符串的类型,这里通过比较字符串value,排除了函数参数类型
+		return &clauseTableRowValue{
+			kind:  clauseTableRowValueKindString,
+			value: realColumn.RawString(),
+		}, nil
+	case *tree.NumVal:
+		// 数值类型,可以是整形或浮点型
+		numKind := realColumn.Kind()
+		if numKind == constant.Int {
+			valueUint64, err := strconv.ParseUint(realColumn.String(), 10, 64)
+			if err != nil {
+				return nil, err
+			}
+
+			return &clauseTableRowValue{
+				kind:  clauseTableRowValueKindUint64,
+				value: valueUint64,
+			}, nil
+		} else if numKind == constant.Float {
+			valueFloat64, err := strconv.ParseFloat(realColumn.String(), 64)
+			if err != nil {
+				return nil, err
+			}
+
+			return &clauseTableRowValue{
+				kind:  clauseTableRowValueKindFloat64,
+				value: valueFloat64,
+			}, nil
+		} else {
+			return nil, errors.New("不支持的数值类型")
+		}
+	case *tree.UnresolvedName:
+		return nil, errors.New("存在无法解析的名字,是否应该使用单引号")
+	default:
+		return nil, errors.New("未支持的数据类型")
+	}
+}
+
+func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
 	if strings.HasPrefix(expr.String(), "parse_time") {
 		if expr.Exprs == nil || len(expr.Exprs) != 2 {
 			return nil, 0, errors.New("parse_time(time_str, time_format)")
@@ -305,14 +343,6 @@ func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]str
 		timeStr := expr.Exprs[0].String()
 		timeFormat := expr.Exprs[1].String()
 
-		if funcName != nil {
-			*funcName = "parse_time"
-		}
-
-		if stringParams != nil {
-			*stringParams = append(*stringParams, timeStr, timeFormat)
-		}
-
 		parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local)
 		if err != nil {
 			return nil, 0, err
@@ -323,16 +353,3 @@ func evaluateFuncExpr(expr *tree.FuncExpr, funcName *string, stringParams *[]str
 		return nil, 0, errors.New("不支持的函数")
 	}
 }
-
-func canSkipValueInThisContext(stringParams *[]string, stringValue string) bool {
-	if stringParams == nil || *stringParams == nil || len(*stringParams) == 0 {
-		return false
-	}
-
-	if (*stringParams)[0] == stringValue {
-		*stringParams = (*stringParams)[1:]
-		return true
-	}
-
-	return false
-}