Jelajahi Sumber

完成事务

yjp 1 tahun lalu
induk
melakukan
35e4f06f6c
5 mengubah file dengan 167 tambahan dan 3 penghapusan
  1. 1 3
      dpsv1/client.go
  2. 1 0
      ports/client.go
  3. 2 0
      ports/transaction.go
  4. 9 0
      test/v1/sdk.go
  5. 154 0
      test/v1/v1_test.go

+ 1 - 3
dpsv1/client.go

@@ -71,9 +71,7 @@ func (c *Client) AutoMigrate(req *ports.AutoMigrateRequest) error {
 	return nil
 }
 
-type TransactionFunc func(tx ports.Transaction) error
-
-func (c *Client) Transaction(databaseID string, txFunc TransactionFunc) error {
+func (c *Client) Transaction(databaseID string, txFunc ports.TransactionFunc) error {
 	stream, err := c.commandServiceClient.Transaction(context.Background())
 	if err != nil {
 		return err

+ 1 - 0
ports/client.go

@@ -14,6 +14,7 @@ const (
 
 type Client interface {
 	AutoMigrate(request *AutoMigrateRequest) error
+	Transaction(databaseID string, txFunc TransactionFunc) error
 	Insert(request *InsertRequest) (string, error)
 	InsertBatch(request *InsertBatchRequest) (string, error)
 	Delete(request *DeleteRequest) (string, error)

+ 2 - 0
ports/transaction.go

@@ -1,5 +1,7 @@
 package ports
 
+type TransactionFunc func(tx Transaction) error
+
 type Transaction interface {
 	InsertTx(request *InsertRequest) (string, error)
 	InsertBatchTx(request *InsertBatchRequest) (string, error)

+ 9 - 0
test/v1/sdk.go

@@ -45,6 +45,15 @@ func (toolKit *ToolKit) autoMigrate(req *ports.AutoMigrateRequest) *ToolKit {
 	return toolKit
 }
 
+func (toolKit *ToolKit) transaction(databaseID string, txFunc ports.TransactionFunc) *ToolKit {
+	err := clientInstance.Transaction(databaseID, txFunc)
+	if err != nil {
+		toolKit.t.Fatal(err)
+	}
+
+	return toolKit
+}
+
 func (toolKit *ToolKit) insert(req *ports.InsertRequest) *ToolKit {
 	statement, err := clientInstance.Insert(req)
 	if err != nil {

+ 154 - 0
test/v1/v1_test.go

@@ -1,6 +1,7 @@
 package v1
 
 import (
+	"fmt"
 	"git.sxidc.com/service-supports/dps-sdk/ports"
 	"math/rand"
 	"testing"
@@ -28,6 +29,159 @@ func TestAutoMigrate(t *testing.T) {
 	})
 }
 
+func TestTransaction(t *testing.T) {
+	initClient(t, "localhost:30170")
+	defer destroyClient(t)
+
+	tablePrefix := "test." + simpleUUID()[0:8]
+
+	id := simpleUUID()
+	name := simpleUUID()
+	now := time.Now().Local()
+	tableNum := rand.New(rand.NewSource(now.Unix())).Intn(10)
+	newName := simpleUUID()
+	newNow := time.Now().Local()
+	newTableNum := rand.New(rand.NewSource(now.Unix())).Intn(10)
+
+	var count int64
+	resultMap := make(map[string]any)
+
+	newToolKit(t).
+		autoMigrate(&ports.AutoMigrateRequest{
+			DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+			TablePrefixWithSchema: tablePrefix,
+			Version:               "v1",
+			TableModelDescribe:    tableModelDescribe,
+		}).
+		transaction("2b78141779ee432295ca371b91c5cac7", func(tx ports.Transaction) error {
+			statement, err := tx.InsertTx(&ports.InsertRequest{
+				DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+				TablePrefixWithSchema: tablePrefix,
+				Version:               "v1",
+				KeyColumns:            []string{"id"},
+				TableRow: map[string]any{
+					"id":        id,
+					"name":      name,
+					"time":      now,
+					"table_num": tableNum,
+				},
+				UserID: "test",
+			})
+			if err != nil {
+				return err
+			}
+
+			fmt.Println(statement)
+
+			err = tx.End()
+			if err != nil {
+				return err
+			}
+
+			return nil
+		}).
+		queryByKeys(&ports.QueryByKeysRequest{
+			DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+			TablePrefixWithSchema: tablePrefix,
+			Version:               "v1",
+			KeyValues:             map[string]string{"id": id},
+		}, &resultMap).
+		assertEqual(id, resultMap["id"], "ID不一致").
+		assertEqual(name, resultMap["name"], "名称不一致").
+		assertEqual(now.UnixMilli(), resultMap["time"].(time.Time).UnixMilli(), "时间不一致").
+		assertEqual(tableNum, resultMap["table_num"], "表数量不一致").
+		transaction("2b78141779ee432295ca371b91c5cac7", func(tx ports.Transaction) error {
+			statement, err := tx.UpdateTx(&ports.UpdateRequest{
+				DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+				TablePrefixWithSchema: tablePrefix,
+				Version:               "v1",
+				KeyValues:             map[string]string{"id": id},
+				NewTableRow: map[string]any{
+					"id":        id,
+					"name":      newName,
+					"time":      newNow,
+					"table_num": newTableNum,
+				},
+				UserID: "test",
+			})
+			if err != nil {
+				return err
+			}
+
+			fmt.Println(statement)
+
+			err = tx.End()
+			if err != nil {
+				return err
+			}
+
+			return nil
+		}).
+		queryByKeys(&ports.QueryByKeysRequest{
+			DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+			TablePrefixWithSchema: tablePrefix,
+			Version:               "v1",
+			KeyValues:             map[string]string{"id": id},
+		}, &resultMap).
+		assertEqual(id, resultMap["id"], "ID不一致").
+		assertEqual(newName, resultMap["name"], "名称不一致").
+		assertEqual(newNow.UnixMilli(), resultMap["time"].(time.Time).UnixMilli(), "时间不一致").
+		assertEqual(newTableNum, resultMap["table_num"], "表数量不一致").
+		transaction("2b78141779ee432295ca371b91c5cac7", func(tx ports.Transaction) error {
+			statement, err := tx.UpdateTx(&ports.UpdateRequest{
+				DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+				TablePrefixWithSchema: tablePrefix,
+				Version:               "v1",
+				KeyValues:             map[string]string{"id": id},
+				NewTableRow: map[string]any{
+					"id":        id,
+					"name":      name,
+					"time":      now,
+					"table_num": tableNum,
+				},
+				UserID: "test",
+			})
+			if err != nil {
+				return err
+			}
+
+			fmt.Println(statement)
+
+			statement, err = tx.DeleteTx(&ports.DeleteRequest{
+				DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+				TablePrefixWithSchema: tablePrefix,
+				Version:               "v1",
+				KeyValues:             map[string]string{"id": id},
+				UserID:                "test",
+			})
+			if err != nil {
+				return err
+			}
+
+			fmt.Println(statement)
+
+			err = tx.End()
+			if err != nil {
+				return err
+			}
+
+			return nil
+		}).
+		countWhere(&ports.CountWhereRequest{
+			DatabaseID:            "2b78141779ee432295ca371b91c5cac7",
+			TablePrefixWithSchema: tablePrefix,
+			Version:               "v1",
+			Where: []ports.ColumnCompare{
+				{
+					Column:  "id",
+					Value:   id,
+					Compare: ports.CompareEqual,
+				},
+			},
+		}, &count).
+		assertEqual(int64(0), count, "数量不一致")
+}
+
 func TestInsert(t *testing.T) {
 	initClient(t, "localhost:30170")
 	defer destroyClient(t)