package saga

import (
	"fmt"
	"sync"
	"time"
)

var (
	// runningCountMutex guards runningCount.
	runningCountMutex sync.Mutex
	// runningCount is the number of Orchestrator.Run calls currently in
	// progress (across all orchestrators). stop waits for it to reach zero.
	runningCount uint64
)

// StepFunc performs the forward action of a saga step.
type StepFunc func() error

// StepRollbackFunc undoes the action of a saga step.
type StepRollbackFunc func() error

// OutputFunc produces the value returned by Run once every step succeeded.
type OutputFunc func() (interface{}, error)

// Step is a single unit of work in a saga together with its rollback.
//
// StepFunc is skipped by Run when nil; StepRollbackFunc is skipped during
// rollback when nil. RollbackContextData is persisted with the orchestrator
// state when this step fails or is rolled back.
type Step struct {
	StepFunc            StepFunc
	StepRollbackFunc    StepRollbackFunc
	RollbackContextData string

	// rollbackDone is closed by stop to abort an in-progress rollback retry
	// loop for this step. NOTE(review): it is not created in this file —
	// presumably initialized elsewhere in the package; while nil it never
	// fires.
	rollbackDone chan interface{}
}

// Orchestrator runs a sequence of Steps and, when one fails, rolls back the
// failed step and every step before it in reverse order, persisting its
// progress via saveOrchestratorState / deleteOrchestratorState.
//
// rollbackRetryPeriodSec is the interval between retries of a failed
// rollback. It is passed to time.NewTicker unchanged, so despite its name it
// is a time.Duration (not a count of seconds) and must be positive.
type Orchestrator struct {
	stateStoreName         string
	sagaName               string
	name                   string
	rollbackRetryPeriodSec time.Duration
	steps                  []*Step
	OutputFunc             OutputFunc
}

// AddStep appends step to the orchestrator and returns the orchestrator so
// calls can be chained.
func (o *Orchestrator) AddStep(step *Step) *Orchestrator {
	o.steps = append(o.steps, step)
	return o
}

// Output sets the function whose result Run returns after all steps succeed,
// and returns the orchestrator so calls can be chained.
func (o *Orchestrator) Output(outputFunc OutputFunc) *Orchestrator {
	o.OutputFunc = outputFunc
	return o
}

// Run executes the steps in order, skipping steps whose StepFunc is nil.
//
// If a step fails, Run persists the orchestrator state (the failed step's
// RollbackContextData and its index), starts Rollback from that step in a new
// goroutine, and returns the step's error. If persisting the state fails, no
// rollback is started and the returned error wraps the persistence error
// (and mentions the step error).
//
// If every step succeeds, Run returns the result of OutputFunc, or (nil, nil)
// when no OutputFunc is set.
func (o *Orchestrator) Run() (interface{}, error) {
	runningCountMutex.Lock()
	runningCount++
	runningCountMutex.Unlock()
	// Decrement on every exit path so stop() can never wait forever on a Run
	// that returned early.
	defer func() {
		runningCountMutex.Lock()
		runningCount--
		runningCountMutex.Unlock()
	}()

	for index, step := range o.steps {
		if step.StepFunc == nil {
			continue
		}
		stepErr := step.StepFunc()
		if stepErr == nil {
			continue
		}
		if err := saveOrchestratorState(o.stateStoreName, o.sagaName, o.name, step.RollbackContextData, index); err != nil {
			return nil, fmt.Errorf("step %d failed (%v); saving orchestrator state: %w", index, stepErr, err)
		}
		go o.Rollback(index)
		// Report the step failure itself; previously this was overwritten by
		// the (nil) result of saveOrchestratorState.
		return nil, stepErr
	}

	if o.OutputFunc == nil {
		return nil, nil
	}
	return o.OutputFunc()
}

// Rollback undoes the steps from startIndex down to 0 in reverse order.
//
// After undoing step i it persists the state with index i-1 (so -1 after step
// 0), and once every step is undone it deletes the persisted state. Errors are
// printed rather than returned; a failure to persist or delete the state ends
// the rollback. startIndex must be a valid index into the steps.
func (o *Orchestrator) Rollback(startIndex int) {
	for i := startIndex; i >= 0; i-- {
		step := o.steps[i]
		if err := o.rollbackStep(step); err != nil {
			fmt.Println("Rollback", "orchestrator.rollbackStep", err)
		}
		if err := saveOrchestratorState(o.stateStoreName, o.sagaName, o.name, step.RollbackContextData, i-1); err != nil {
			fmt.Println("Rollback", "saveOrchestratorState", err)
			return
		}
	}
	if err := deleteOrchestratorState(o.stateStoreName, o.sagaName, o.name); err != nil {
		fmt.Println("Rollback", "deleteOrchestratorState", err)
		return
	}
}

// rollbackStep runs the step's StepRollbackFunc, if any. When it fails, the
// rollback is retried by rollbackRetry until it succeeds or stop cancels it;
// rollbackStep blocks for that whole time and always returns nil.
func (o *Orchestrator) rollbackStep(rollbackStep *Step) error {
	if rollbackStep.StepRollbackFunc == nil {
		return nil
	}
	if err := rollbackStep.StepRollbackFunc(); err == nil {
		return nil
	}
	o.rollbackRetry(rollbackStep)
	return nil
}

// rollbackRetry re-runs step.StepRollbackFunc every rollbackRetryPeriodSec
// until it succeeds or the step's rollbackDone channel fires (stop closes it).
func (o *Orchestrator) rollbackRetry(step *Step) {
	// NOTE(review): time.NewTicker panics if the period is not positive.
	ticker := time.NewTicker(o.rollbackRetryPeriodSec)
	defer ticker.Stop()

	// Read the channel once: stop() sets the field to nil right after closing
	// it, and a nil channel would never fire.
	done := step.rollbackDone
	for {
		select {
		case <-done:
			// Return instead of looping on: the old code nil-ed the ticker here
			// and then dereferenced it on the next iteration.
			return
		case <-ticker.C:
			if err := step.StepRollbackFunc(); err == nil {
				return
			}
		}
	}
}

// stop waits, polling every 500ms, until no Run call is in progress, then
// closes each step's rollbackDone channel to cancel any rollback retry loop
// still running for that step.
func (o *Orchestrator) stop() {
	for {
		runningCountMutex.Lock()
		running := runningCount
		runningCountMutex.Unlock()
		if running == 0 {
			break
		}
		time.Sleep(500 * time.Millisecond)
	}

	for _, step := range o.steps {
		if step.rollbackDone == nil {
			continue
		}
		// Close without first sending: a blocking send hangs forever when no
		// retry loop is receiving, while close wakes every receiver and never
		// blocks.
		close(step.rollbackDone)
		step.rollbackDone = nil
	}
}