Browse Source

修改key columns由界面上传递

yjp 1 year ago
parent
commit
d2f018db94
4 changed files with 24 additions and 48 deletions
  1. 4 3
      request.go
  2. 2 26
      sql_parser.go
  3. 14 10
      v1.go
  4. 4 9
      v1_test.go

+ 4 - 3
request.go

@@ -9,7 +9,8 @@ type OperateParseRequest struct {
 }
 
 type OperateRequest struct {
-	DatabaseID string `json:"databaseId" binding:"required"`
-	Version    string `json:"version"`
-	SQL        string `json:"sql" binding:"required"`
+	DatabaseID string   `json:"databaseId" binding:"required"`
+	Version    string   `json:"version"`
+	KeyColumns []string `json:"keyColumns"`
+	SQL        string   `json:"sql" binding:"required"`
 }

+ 2 - 26
sql_parser.go

@@ -25,9 +25,8 @@ type clauseTableRowValue struct {
 }
 
 type insertClause struct {
-	table      string
-	keyColumns []string
-	tableRows  map[string]clauseTableRowValue
+	table     string
+	tableRows map[string]clauseTableRowValue
 }
 
 type deleteClause struct {
@@ -218,34 +217,11 @@ func insertWalk(sql string) (*insertClause, error) {
 			}
 
 			// 组装columnValues
-			hasIDColumn := false
-			keyColumns := make([]string, 0)
-			allColumns := make([]string, 0)
 			tableRows := make(map[string]clauseTableRowValue)
-
 			for i, column := range realNode.Columns.ToStrings() {
-				if column == "id" {
-					hasIDColumn = true
-				}
-
-				if strings.HasPrefix(column, "**") {
-					column = strings.TrimPrefix(column, "**")
-					keyColumns = append(keyColumns, column)
-				}
-
-				allColumns = append(allColumns, column)
-
 				tableRows[column] = values[i]
 			}
 
-			if keyColumns != nil && len(keyColumns) != 0 {
-				clause.keyColumns = keyColumns
-			} else if hasIDColumn {
-				clause.keyColumns = []string{"id"}
-			} else {
-				clause.keyColumns = allColumns
-			}
-
 			clause.tableRows = tableRows
 
 			return false

+ 14 - 10
v1.go

@@ -70,6 +70,11 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string, operatorIDFunc Oper
 				version = inputModel.Version
 			}
 
+			keyColumns := []string{"id"}
+			if inputModel.KeyColumns != nil && len(inputModel.KeyColumns) != 0 {
+				keyColumns = inputModel.KeyColumns
+			}
+
 			operatorID, err := operatorIDFunc(c)
 			if err != nil {
 				return nil, err
@@ -89,11 +94,11 @@ func ApiV1(binding *http_binding.Binding, dpsAddress string, operatorIDFunc Oper
 				for _, parsedClause := range parsedClauses {
 					switch clause := parsedClause.(type) {
 					case *insertClause:
-						return doInsert(tx, version, clause, operatorID)
+						return doInsert(tx, version, keyColumns, clause, operatorID)
 					case *deleteClause:
-						return doDelete(tx, version, clause, operatorID)
+						return doDelete(tx, version, keyColumns, clause, operatorID)
 					case *updateClause:
-						return doUpdate(tx, version, clause, operatorID)
+						return doUpdate(tx, version, keyColumns, clause, operatorID)
 					case *selectClause:
 						return doSelect(dpsClient, version, clause)
 					default:
@@ -119,9 +124,8 @@ func insertMap(clause *insertClause) map[string]any {
 	}
 
 	return map[string]any{
-		"table":       clause.table,
-		"key_columns": clause.keyColumns,
-		"table_rows":  tableRows,
+		"table":      clause.table,
+		"table_rows": tableRows,
 	}
 }
 
@@ -140,7 +144,7 @@ func selectMap(clause *selectClause) map[string]any {
 	return map[string]any{}
 }
 
-func doInsert(tx client.Transaction, version string, clause *insertClause, operatorID string) error {
+func doInsert(tx client.Transaction, version string, keyColumns []string, clause *insertClause, operatorID string) error {
 	tableRow := client.NewTableRow()
 	for columnName, value := range clause.tableRows {
 		switch value.kind {
@@ -162,7 +166,7 @@ func doInsert(tx client.Transaction, version string, clause *insertClause, opera
 	statement, err := tx.InsertTx(&client.InsertRequest{
 		TablePrefixWithSchema: clause.table,
 		Version:               version,
-		KeyColumns:            clause.keyColumns,
+		KeyColumns:            keyColumns,
 		TableRow:              tableRow,
 		UserID:                operatorID,
 	})
@@ -174,11 +178,11 @@ func doInsert(tx client.Transaction, version string, clause *insertClause, opera
 	return nil
 }
 
-func doDelete(tx client.Transaction, version string, clause *deleteClause, operatorID string) error {
+func doDelete(tx client.Transaction, version string, keyColumns []string, clause *deleteClause, operatorID string) error {
 	return nil
 }
 
-func doUpdate(tx client.Transaction, version string, clause *updateClause, operatorID string) error {
+func doUpdate(tx client.Transaction, version string, keyColumns []string, clause *updateClause, operatorID string) error {
 	return nil
 }
 

+ 4 - 9
v1_test.go

@@ -19,7 +19,6 @@ func TestApiV1OperateParse(t *testing.T) {
 
 	operatorID := simpleUUID()
 	tableName := "students"
-	exceptedKeyColumns := []string{"id", "name", "age", "rate", "time", "is_right"}
 	exceptedTableRows := map[string]any{
 		"id":       "aaa",
 		"name":     "yjp",
@@ -40,12 +39,6 @@ func TestApiV1OperateParse(t *testing.T) {
 		t.Fatal("表名不正确")
 	}
 
-	for i, keyColumn := range parsed["key_columns"].([]any) {
-		if exceptedKeyColumns[i] != keyColumn {
-			t.Fatal("没有关键列数值或顺序不正确")
-		}
-	}
-
 	for columnName, value := range parsed["table_rows"].(map[string]any) {
 		if exceptedTableRows[columnName] != value {
 			t.Fatal("行数据不正确")
@@ -63,6 +56,7 @@ func TestApiV1Operate(t *testing.T) {
 	name := simpleUUID()
 	now := time.Now().Local()
 	tableNum := rand.Intn(10) + 1
+	keyColumns := []string{"id"}
 
 	autoMigrate(t, []client.AutoMigrateItem{
 		{
@@ -77,7 +71,7 @@ func TestApiV1Operate(t *testing.T) {
 		return operatorID, nil
 	})
 
-	operate(t, fmt.Sprintf(sqlInsertFormat, tablePrefix, id, name, now.Format(time.DateTime), tableNum))
+	operate(t, fmt.Sprintf(sqlInsertFormat, tablePrefix, id, name, now.Format(time.DateTime), tableNum), keyColumns)
 }
 
 func autoMigrate(t *testing.T, items []client.AutoMigrateItem) {
@@ -119,13 +113,14 @@ func operateParse(t *testing.T, sql string) map[string]any {
 	return result.Parsed
 }
 
-func operate(t *testing.T, sql string) {
+func operate(t *testing.T, sql string, keyColumns []string) {
 	result := new(response.MsgResponse)
 
 	resp, err := resty.New().R().
 		SetBody(&OperateRequest{
 			DatabaseID: testDatabaseID,
 			Version:    "v1",
+			KeyColumns: keyColumns,
 			SQL:        sql,
 		}).
 		SetResult(result).