Explorar el Código

完成saga调整

yjp hace 3 años
padre
commit
2139282a0b
Se han modificado 6 ficheros con 529 adiciones y 9 borrados
  1. 137 0
      saga/orchestrator.go
  2. 76 0
      saga/saga.go
  3. 194 0
      saga/state_store.go
  4. 6 0
      test/const.go
  5. 112 0
      test/saga_test.go
  6. 4 9
      test/state_test.go

+ 137 - 0
saga/orchestrator.go

@@ -0,0 +1,137 @@
+package saga
+
+import (
+	"fmt"
+	"git.sxidc.com/service-supports/dapr_api/state"
+	"time"
+)
+
+type StepFunc func() error
+type StepRollbackFunc func() error
+type OutputFunc func() (interface{}, error)
+
+type Step struct {
+	StepFunc            StepFunc
+	StepRollbackFunc    StepRollbackFunc
+	RollbackContextData string
+	rollbackDone        chan interface{}
+}
+
+type Orchestrator struct {
+	stateAPI               *state.API
+	stateStoreName         string
+	sagaName               string
+	name                   string
+	rollbackRetryPeriodSec time.Duration
+	steps                  []*Step
+	OutputFunc             OutputFunc
+}
+
+func (orchestrator *Orchestrator) AddStep(step *Step) *Orchestrator {
+	orchestrator.steps = append(orchestrator.steps, step)
+
+	return orchestrator
+}
+
+func (orchestrator *Orchestrator) Output(outputFunc OutputFunc) *Orchestrator {
+	orchestrator.OutputFunc = outputFunc
+
+	return orchestrator
+}
+
+func (orchestrator *Orchestrator) Run() (interface{}, error) {
+	for index, step := range orchestrator.steps {
+		if step.StepFunc == nil {
+			continue
+		}
+
+		err := step.StepFunc()
+		if err == nil {
+			continue
+		}
+
+		if err != nil {
+			err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName, orchestrator.sagaName,
+				orchestrator.name, step.RollbackContextData, index)
+			if err != nil {
+				return nil, err
+			}
+
+			go orchestrator.Rollback(index)
+			return nil, err
+		}
+	}
+
+	if orchestrator.OutputFunc != nil {
+		return orchestrator.OutputFunc()
+	}
+
+	return nil, nil
+}
+
+func (orchestrator *Orchestrator) Rollback(startIndex int) {
+	for i := startIndex; i >= 0; i-- {
+		rollbackStep := orchestrator.steps[i]
+		err := orchestrator.rollbackStep(rollbackStep)
+		if err != nil {
+			fmt.Println("Rollback", "orchestrator.rollbackStep", err)
+		}
+
+		err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
+			orchestrator.sagaName, orchestrator.name,
+			rollbackStep.RollbackContextData, i-1)
+		if err != nil {
+			fmt.Println("Rollback", "saveOrchestratorState", err)
+			return
+		}
+	}
+
+	err := deleteOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
+		orchestrator.sagaName, orchestrator.name)
+	if err != nil {
+		fmt.Println("Rollback", "deleteOrchestratorState", err)
+		return
+	}
+}
+
+func (orchestrator *Orchestrator) rollbackStep(rollbackStep *Step) error {
+	if rollbackStep.StepRollbackFunc == nil {
+		return nil
+	}
+
+	err := rollbackStep.StepRollbackFunc()
+	if err == nil {
+		return nil
+	}
+
+	orchestrator.rollbackRetry(rollbackStep)
+
+	return nil
+}
+
+func (orchestrator *Orchestrator) rollbackRetry(step *Step) {
+	ticker := time.NewTicker(orchestrator.rollbackRetryPeriodSec)
+
+	for {
+		select {
+		case <-step.rollbackDone:
+			ticker.Stop()
+			ticker = nil
+		case <-ticker.C:
+			err := step.StepRollbackFunc()
+			if err == nil {
+				return
+			}
+		}
+	}
+}
+
+func (orchestrator *Orchestrator) stop() {
+	for _, step := range orchestrator.steps {
+		if step.rollbackDone != nil {
+			step.rollbackDone <- true
+			close(step.rollbackDone)
+			step.rollbackDone = nil
+		}
+	}
+}

+ 76 - 0
saga/saga.go

@@ -0,0 +1,76 @@
+package saga
+
+import (
+	"git.sxidc.com/service-supports/dapr_api/state"
+	"sync"
+	"time"
+)
+
+var sagaInstance *Saga
+var sagaInstanceMutex sync.Mutex
+
+type Saga struct {
+	stateAPI       *state.API
+	stateStoreName string
+	name           string
+	orchestrators  []*Orchestrator
+}
+
+func Init(stateAPI *state.API, stateStoreName string, sagaName string) {
+	sagaInstanceMutex.Lock()
+	defer sagaInstanceMutex.Unlock()
+
+	if sagaInstance != nil {
+		return
+	}
+
+	sagaInstance = new(Saga)
+	sagaInstance.stateAPI = stateAPI
+	sagaInstance.stateStoreName = stateStoreName
+	sagaInstance.name = sagaName
+	sagaInstance.orchestrators = make([]*Orchestrator, 0)
+}
+
+func Destroy() {
+	sagaInstanceMutex.Lock()
+	defer sagaInstanceMutex.Unlock()
+
+	if sagaInstance == nil {
+		return
+	}
+
+	for _, orchestrator := range sagaInstance.orchestrators {
+		orchestrator.stop()
+	}
+
+	sagaInstance.orchestrators = nil
+	sagaInstance.name = ""
+	sagaInstance.stateStoreName = ""
+	sagaInstance.stateAPI = nil
+	sagaInstance = nil
+}
+
+func GetInstance() *Saga {
+	return sagaInstance
+}
+
+func (s *Saga) BuildOrchestrator(orchestratorName string, rollbackRetryPeriodSec time.Duration) *Orchestrator {
+	sagaInstanceMutex.Lock()
+	defer sagaInstanceMutex.Unlock()
+
+	orchestrator := &Orchestrator{
+		stateAPI:               s.stateAPI,
+		stateStoreName:         s.stateStoreName,
+		sagaName:               s.name,
+		name:                   orchestratorName,
+		rollbackRetryPeriodSec: rollbackRetryPeriodSec,
+	}
+
+	sagaInstance.orchestrators = append(sagaInstance.orchestrators, orchestrator)
+
+	return orchestrator
+}
+
+func (s *Saga) GetOrchestratorNeedRollback() ([]OrchestratorState, error) {
+	return getSagaOrchestratorStates(s.stateAPI, s.stateStoreName, s.name)
+}

+ 194 - 0
saga/state_store.go

@@ -0,0 +1,194 @@
+package saga
+
+import (
+	"encoding/json"
+	"git.sxidc.com/service-supports/dapr_api/state"
+	"git.sxidc.com/service-supports/dapr_api/utils"
+)
+
+const (
+	sagaIndexes = "indexes"
+)
+
+type OrchestratorState struct {
+	Name                string
+	RollbackContextData string
+	RollbackStartIndex  int
+}
+
+func deleteOrchestratorState(stateAPI *state.API, stateStoreName string, sagaName string, orchestratorName string) error {
+	orchestratorNames, err := getSagaIndexState(stateAPI, stateStoreName, sagaName)
+	if err != nil {
+		return err
+	}
+
+	for i, savedOrchestratorName := range orchestratorNames {
+		if savedOrchestratorName == orchestratorName {
+			if i == len(orchestratorNames)-1 {
+				orchestratorNames = orchestratorNames[:i]
+			} else {
+				orchestratorNames = append(orchestratorNames[0:i], orchestratorNames[i+1:]...)
+			}
+		}
+	}
+
+	if len(orchestratorNames) == 0 {
+		return stateAPI.Transaction(stateStoreName, state.TransactionRequest{
+			Operations: []state.TransactionOperation{
+				{
+					Operation: state.TransactionDelete,
+					Request: state.TransactionOperationRequest{
+						Key: sagaName,
+					},
+				},
+				{
+					Operation: state.TransactionDelete,
+					Request: state.TransactionOperationRequest{
+						Key: orchestratorName,
+					},
+				},
+			},
+		})
+	}
+
+	orchestratorNames = append(orchestratorNames, orchestratorName)
+	sagaIndexStateJsonData, err := json.Marshal(&map[string]interface{}{sagaIndexes: orchestratorNames})
+	if err != nil {
+		return err
+	}
+
+	return stateAPI.Transaction(stateStoreName, state.TransactionRequest{
+		Operations: []state.TransactionOperation{
+			{
+				Operation: state.TransactionUpsert,
+				Request: state.TransactionOperationRequest{
+					Key:   sagaName,
+					Value: string(sagaIndexStateJsonData),
+				},
+			},
+			{
+				Operation: state.TransactionDelete,
+				Request: state.TransactionOperationRequest{
+					Key: orchestratorName,
+				},
+			},
+		},
+	})
+}
+
+func getSagaIndexState(stateAPI *state.API, stateStoreName string, sagaName string) ([]string, error) {
+	data, _, err := stateAPI.GetState(stateStoreName, sagaName, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if utils.HasBlank(data) {
+		return make([]string, 0), nil
+	}
+
+	orchestratorNames := make([]string, 0)
+	err = json.Unmarshal([]byte(data), &map[string]interface{}{sagaIndexes: orchestratorNames})
+	if err != nil {
+		return nil, err
+	}
+
+	return orchestratorNames, nil
+}
+
+func saveOrchestratorState(stateAPI *state.API, stateStoreName string, sagaName string, orchestratorName string,
+	rollbackContextData string, rollbackStartIndex int) error {
+	orchestratorStateJsonData, err := json.Marshal(&OrchestratorState{
+		Name:                orchestratorName,
+		RollbackContextData: rollbackContextData,
+		RollbackStartIndex:  rollbackStartIndex,
+	})
+	if err != nil {
+		return err
+	}
+
+	orchestratorNames, err := getSagaIndexState(stateAPI, stateStoreName, sagaName)
+	if err != nil {
+		return err
+	}
+
+	find := false
+	for _, savedOrchestratorName := range orchestratorNames {
+		if savedOrchestratorName == orchestratorName {
+			find = true
+		}
+	}
+
+	if find {
+		err := stateAPI.SaveState(stateStoreName, []state.SaveStateRequest{
+			{
+				Key:   orchestratorName,
+				Value: string(orchestratorStateJsonData),
+			},
+		})
+		if err != nil {
+			return err
+		}
+
+		return nil
+	}
+
+	orchestratorNames = append(orchestratorNames, orchestratorName)
+	sagaIndexStateJsonData, err := json.Marshal(&map[string]interface{}{sagaIndexes: orchestratorNames})
+	if err != nil {
+		return err
+	}
+
+	return stateAPI.Transaction(stateStoreName, state.TransactionRequest{
+		Operations: []state.TransactionOperation{
+			{
+				Operation: state.TransactionUpsert,
+				Request: state.TransactionOperationRequest{
+					Key:   sagaName,
+					Value: string(sagaIndexStateJsonData),
+				},
+			},
+			{
+				Operation: state.TransactionUpsert,
+				Request: state.TransactionOperationRequest{
+					Key:   orchestratorName,
+					Value: string(orchestratorStateJsonData),
+				},
+			},
+		},
+	})
+}
+
+func getSagaOrchestratorStates(stateAPI *state.API, stateStoreName string, sagaName string) ([]OrchestratorState, error) {
+	data, _, err := stateAPI.GetState(stateStoreName, sagaName, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	if utils.HasBlank(data) {
+		return make([]OrchestratorState, 0), nil
+	}
+
+	indexes := make(map[string][]string)
+	err = json.Unmarshal([]byte(data), &indexes)
+	if err != nil {
+		return nil, err
+	}
+
+	orchestratorStates := make([]OrchestratorState, 0)
+	for _, orchestratorName := range indexes[sagaIndexes] {
+		data, _, err := stateAPI.GetState(stateStoreName, orchestratorName, nil)
+		if err != nil {
+			return nil, err
+		}
+
+		orchestratorState := new(OrchestratorState)
+		err = json.Unmarshal([]byte(data), orchestratorState)
+		if err != nil {
+			return nil, err
+		}
+
+		orchestratorStates = append(orchestratorStates, *orchestratorState)
+	}
+
+	return orchestratorStates, nil
+}

+ 6 - 0
test/const.go

@@ -0,0 +1,6 @@
+package test
+
+const (
+	daprHttpPort   = 10080
+	stateStoreName = "dapr_api"
+)

+ 112 - 0
test/saga_test.go

@@ -0,0 +1,112 @@
+package test
+
+import (
+	"errors"
+	"fmt"
+	"git.sxidc.com/service-supports/dapr_api/saga"
+	"git.sxidc.com/service-supports/dapr_api/state"
+	"testing"
+	"time"
+)
+
+func TestSagaOK(t *testing.T) {
+	stateAPI := state.NewAPI(daprHttpPort, 10*time.Second)
+	defer state.DestroyAPI(stateAPI)
+
+	saga.Init(stateAPI, stateStoreName, "sage_test")
+	defer saga.Destroy()
+
+	orchestrator := saga.GetInstance().BuildOrchestrator("test-ok", 5*time.Second).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("one")
+				return nil
+			},
+			StepRollbackFunc: func() error {
+				fmt.Println("one rollback")
+				return nil
+			},
+		}).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("two")
+				return nil
+			},
+			StepRollbackFunc: func() error {
+				fmt.Println("two rollback")
+				return nil
+			},
+		}).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("three")
+				return nil
+			},
+			StepRollbackFunc: nil,
+		}).
+		Output(func() (interface{}, error) {
+			return "ok result", nil
+		})
+
+	result, err := orchestrator.Run()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	fmt.Println("Result:", result)
+}
+
+func TestSagaRollback(t *testing.T) {
+	stateAPI := state.NewAPI(daprHttpPort, 10*time.Second)
+	defer state.DestroyAPI(stateAPI)
+
+	saga.Init(stateAPI, stateStoreName, "sage_test")
+	defer saga.Destroy()
+
+	orchestrator := saga.GetInstance().BuildOrchestrator("test-rollback", 5*time.Second).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("one")
+				return nil
+			},
+			StepRollbackFunc: func() error {
+				fmt.Println("one rollback")
+				return nil
+			},
+		}).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("two")
+				return nil
+			},
+			StepRollbackFunc: func() error {
+				fmt.Println("two rollback")
+				return nil
+			},
+		}).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				fmt.Println("three")
+				return nil
+			},
+			StepRollbackFunc: nil,
+		}).
+		AddStep(&saga.Step{
+			StepFunc: func() error {
+				return errors.New("four error")
+			},
+			StepRollbackFunc: nil,
+		}).
+		Output(func() (interface{}, error) {
+			return "should not come here", nil
+		})
+
+	result, err := orchestrator.Run()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	fmt.Println("Result:", result)
+
+	time.Sleep(1 * time.Second)
+}

+ 4 - 9
test/state_test.go

@@ -9,13 +9,8 @@ import (
 	"time"
 )
 
-const (
-	httpPort       = 10080
-	stateStoreName = "dapr_api"
-)
-
 func TestStateSaveAndGet(t *testing.T) {
-	api := state.NewAPI(httpPort, 10*time.Second)
+	api := state.NewAPI(daprHttpPort, 10*time.Second)
 	defer state.DestroyAPI(api)
 
 	key := utils.SimpleUUID()
@@ -54,7 +49,7 @@ func TestStateSaveAndGet(t *testing.T) {
 }
 
 func TestStateSaveAndGetJson(t *testing.T) {
-	api := state.NewAPI(httpPort, 10*time.Second)
+	api := state.NewAPI(daprHttpPort, 10*time.Second)
 	defer state.DestroyAPI(api)
 
 	key := utils.SimpleUUID()
@@ -96,7 +91,7 @@ func TestStateSaveAndGetJson(t *testing.T) {
 }
 
 func TestStateGetBulk(t *testing.T) {
-	api := state.NewAPI(httpPort, 10*time.Second)
+	api := state.NewAPI(daprHttpPort, 10*time.Second)
 	defer state.DestroyAPI(api)
 
 	key1 := utils.SimpleUUID()
@@ -169,7 +164,7 @@ func TestStateGetBulk(t *testing.T) {
 }
 
 func TestTransaction(t *testing.T) {
-	api := state.NewAPI(httpPort, 10*time.Second)
+	api := state.NewAPI(daprHttpPort, 10*time.Second)
 	defer state.DestroyAPI(api)
 
 	key1 := utils.SimpleUUID()