123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- package saga
- import (
- "fmt"
- "sync"
- "time"
- )
// runningCount is the number of Run calls currently in progress, across all
// orchestrators in this package; stop waits for it to reach zero. Every read
// and write of runningCount must hold runningCountMutex.
var (
	runningCountMutex sync.Mutex
	runningCount      uint64
)
type (
	// StepFunc performs the forward action of a saga Step.
	StepFunc func() error

	// StepRollbackFunc undoes (compensates) a Step. When it returns an error
	// during rollback it is retried periodically.
	StepRollbackFunc func() error

	// OutputFunc produces the value that Orchestrator.Run returns once every
	// step has succeeded.
	OutputFunc func() (interface{}, error)
)
- type Step struct {
- StepFunc StepFunc
- StepRollbackFunc StepRollbackFunc
- RollbackContextData string
- rollbackDone chan interface{}
- }
- type Orchestrator struct {
- stateStoreName string
- sagaName string
- name string
- rollbackRetryPeriodSec time.Duration
- steps []*Step
- OutputFunc OutputFunc
- }
- func (orchestrator *Orchestrator) AddStep(step *Step) *Orchestrator {
- orchestrator.steps = append(orchestrator.steps, step)
- return orchestrator
- }
- func (orchestrator *Orchestrator) Output(outputFunc OutputFunc) *Orchestrator {
- orchestrator.OutputFunc = outputFunc
- return orchestrator
- }
- func (orchestrator *Orchestrator) Run() (interface{}, error) {
- runningCountMutex.Lock()
- runningCount++
- runningCountMutex.Unlock()
- for index, step := range orchestrator.steps {
- if step.StepFunc == nil {
- continue
- }
- err := step.StepFunc()
- if err == nil {
- continue
- }
- if err != nil {
- err = saveOrchestratorState(orchestrator.stateStoreName, orchestrator.sagaName,
- orchestrator.name, step.RollbackContextData, index)
- if err != nil {
- return nil, err
- }
- go orchestrator.Rollback(index)
- runningCountMutex.Lock()
- runningCount--
- runningCountMutex.Unlock()
- return nil, err
- }
- }
- if orchestrator.OutputFunc != nil {
- result, err := orchestrator.OutputFunc()
- runningCountMutex.Lock()
- runningCount--
- runningCountMutex.Unlock()
- return result, err
- }
- runningCountMutex.Lock()
- runningCount--
- runningCountMutex.Unlock()
- return nil, nil
- }
- func (orchestrator *Orchestrator) Rollback(startIndex int) {
- for i := startIndex; i >= 0; i-- {
- rollbackStep := orchestrator.steps[i]
- err := orchestrator.rollbackStep(rollbackStep)
- if err != nil {
- fmt.Println("Rollback", "orchestrator.rollbackStep", err)
- }
- err = saveOrchestratorState(orchestrator.stateStoreName, orchestrator.sagaName, orchestrator.name,
- rollbackStep.RollbackContextData, i-1)
- if err != nil {
- fmt.Println("Rollback", "saveOrchestratorState", err)
- return
- }
- }
- err := deleteOrchestratorState(orchestrator.stateStoreName, orchestrator.sagaName, orchestrator.name)
- if err != nil {
- fmt.Println("Rollback", "deleteOrchestratorState", err)
- return
- }
- }
- func (orchestrator *Orchestrator) rollbackStep(rollbackStep *Step) error {
- if rollbackStep.StepRollbackFunc == nil {
- return nil
- }
- err := rollbackStep.StepRollbackFunc()
- if err == nil {
- return nil
- }
- orchestrator.rollbackRetry(rollbackStep)
- return nil
- }
- func (orchestrator *Orchestrator) rollbackRetry(step *Step) {
- ticker := time.NewTicker(orchestrator.rollbackRetryPeriodSec)
- for {
- select {
- case <-step.rollbackDone:
- ticker.Stop()
- ticker = nil
- case <-ticker.C:
- err := step.StepRollbackFunc()
- if err == nil {
- return
- }
- }
- }
- }
- func (orchestrator *Orchestrator) stop() {
- runningCountMutex.Lock()
- for runningCount != 0 {
- runningCountMutex.Unlock()
- time.Sleep(500 * time.Millisecond)
- runningCountMutex.Lock()
- continue
- }
- runningCountMutex.Unlock()
- for _, step := range orchestrator.steps {
- if step.rollbackDone != nil {
- step.rollbackDone <- true
- close(step.rollbackDone)
- step.rollbackDone = nil
- }
- }
- }
|