Browse Source

完成update

yjp 1 year ago
parent
commit
e2f9bbbdd0
4 changed files with 51 additions and 9 deletions
  1. 22 5
      sql_parser.go
  2. 2 0
      test_sql.go
  3. 2 2
      v1.go
  4. 25 2
      v1_test.go

+ 22 - 5
sql_parser.go

@@ -2,6 +2,7 @@ package dpsapi
 
 import (
 	"errors"
+	"fmt"
 	"github.com/auxten/postgresql-parser/pkg/sql/parser"
 	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
 	"github.com/auxten/postgresql-parser/pkg/walk"
@@ -239,7 +240,8 @@ func updateWalk(sql string) (*updateClause, error) {
 			clause.where = realNode.Where.Expr.String()
 
 			// 获取table row
-			newTableRow := make(map[string]clauseTableRowValue)
+			clause.newTableRow = make(map[string]clauseTableRowValue)
+
 			for _, expr := range realNode.Exprs {
 				value, err := parseValueExpr(expr.Expr)
 				if err != nil {
@@ -247,7 +249,7 @@ func updateWalk(sql string) (*updateClause, error) {
 					return true
 				}
 
-				newTableRow[expr.Names[0].String()] = *value
+				clause.newTableRow[fmt.Sprint(expr.Names[0])] = *value
 			}
 
 			return false
@@ -340,10 +342,25 @@ func evaluateFuncExpr(expr *tree.FuncExpr) (any, int, error) {
 			return nil, 0, errors.New("parse_time(time_str, time_format)")
 		}
 
-		timeStr := expr.Exprs[0].String()
-		timeFormat := expr.Exprs[1].String()
+		timeStrValue, err := parseValueExpr(expr.Exprs[0])
+		if err != nil {
+			return nil, 0, err
+		}
+
+		if timeStrValue.kind != clauseTableRowValueKindString {
+			return nil, 0, errors.New("时间字符串不是字符串类型")
+		}
+
+		timeFormatValue, err := parseValueExpr(expr.Exprs[1])
+		if err != nil {
+			return nil, 0, err
+		}
+
+		if timeFormatValue.kind != clauseTableRowValueKindString {
+			return nil, 0, errors.New("时间格式不是字符串类型")
+		}
 
-		parsedTime, err := time.ParseInLocation(timeFormat, timeStr, time.Local)
+		parsedTime, err := time.ParseInLocation(timeFormatValue.value.(string), timeStrValue.value.(string), time.Local)
 		if err != nil {
 			return nil, 0, err
 		}

+ 2 - 0
test_sql.go

@@ -8,11 +8,13 @@ import (
 const (
 	parseSqlInsert = `insert into students (id, name, age, rate, time, is_right) values ('aaa', 'yjp', 5, 92.5, parse_time('2024-01-01 00:00:00', '2006-01-02 15:04:05'), false)`
 	parseSqlDelete = `delete from students where id = 'aaa' AND name = 'yjp' AND age < 100 AND describe IN ('yjp')`
+	parseSqlUpdate = `update students set name = 'yjp', age = 5, age = 5, rate = 92.5, time = parse_time('2024-01-01 00:00:00', '2006-01-02 15:04:05'), is_right = false where id = 'aaa' AND name = 'yjp' AND age < 100 AND describe IN ('yjp')`
 )
 
 const (
 	sqlInsertFormat = `insert into %s (id, name, time, table_num) values ('%s', '%s', parse_time('%s', '2006-01-02 15:04:05'), %d)`
 	sqlDeleteFormat = `delete from %s where id = '%s'`
+	sqlUpdateFormat = `update %s set name = '%s', time = parse_time('%s', '2006-01-02 15:04:05'), table_num = %d where id = '%s'`
 )
 
 const (

+ 2 - 2
v1.go

@@ -132,8 +132,8 @@ func insertMap(clause *insertClause) map[string]any {
 	}
 
 	return map[string]any{
-		"table":      clause.table,
-		"table_rows": tableRows,
+		"table":     clause.table,
+		"table_row": tableRows,
 	}
 }
 

+ 25 - 2
v1_test.go

@@ -39,9 +39,25 @@ func TestApiV1OperateParse(t *testing.T) {
 		t.Fatal("表名不正确")
 	}
 
-	for columnName, value := range parsed["table_rows"].(map[string]any) {
+	for columnName, value := range parsed["table_row"].(map[string]any) {
 		if exceptedTableRows[columnName] != value {
-			t.Fatal("行数据不正确")
+			t.Fatal(columnName + "行数据不正确")
+		}
+	}
+
+	parsed = operateParse(t, parseSqlUpdate)
+
+	if parsed["table"].(string) != tableName {
+		t.Fatal("表名不正确")
+	}
+
+	if parsed["where"] != `(((id = 'aaa') AND (name = 'yjp')) AND (age < 100)) AND (describe IN ('yjp',))` {
+		t.Fatal("where不正确")
+	}
+
+	for columnName, value := range parsed["new_table_row"].(map[string]any) {
+		if exceptedTableRows[columnName] != value {
+			t.Fatal(columnName + "行数据不正确")
 		}
 	}
 
@@ -62,10 +78,16 @@ func TestApiV1Operate(t *testing.T) {
 
 	operatorID := simpleUUID()
 	tablePrefix := "test.a" + simpleUUID()[0:7]
+
 	id := simpleUUID()
 	name := simpleUUID()
 	now := time.Now().Local()
 	tableNum := rand.Intn(10) + 1
+
+	newName := simpleUUID()
+	newNow := time.Now().Local()
+	newTableNum := rand.Intn(10) + 1
+
 	keyColumns := []string{"id"}
 
 	autoMigrate(t, []client.AutoMigrateItem{
@@ -82,6 +104,7 @@ func TestApiV1Operate(t *testing.T) {
 	})
 
 	operate(t, fmt.Sprintf(sqlInsertFormat, tablePrefix, id, name, now.Format(time.DateTime), tableNum), keyColumns)
+	operate(t, fmt.Sprintf(sqlUpdateFormat, tablePrefix, newName, newNow.Format(time.DateTime), newTableNum, id), keyColumns)
 	operate(t, fmt.Sprintf(sqlDeleteFormat, tablePrefix, id), keyColumns)
 }