orchestrator.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package saga
  2. import (
  3. "fmt"
  4. "git.sxidc.com/service-supports/dapr_api/state"
  5. "time"
  6. )
  7. type StepFunc func() error
  8. type StepRollbackFunc func() error
  9. type OutputFunc func() (interface{}, error)
  10. type Step struct {
  11. StepFunc StepFunc
  12. StepRollbackFunc StepRollbackFunc
  13. RollbackContextData string
  14. rollbackDone chan interface{}
  15. }
  16. type Orchestrator struct {
  17. stateAPI *state.API
  18. stateStoreName string
  19. sagaName string
  20. name string
  21. rollbackRetryPeriodSec time.Duration
  22. steps []*Step
  23. OutputFunc OutputFunc
  24. }
  25. func (orchestrator *Orchestrator) AddStep(step *Step) *Orchestrator {
  26. orchestrator.steps = append(orchestrator.steps, step)
  27. return orchestrator
  28. }
  29. func (orchestrator *Orchestrator) Output(outputFunc OutputFunc) *Orchestrator {
  30. orchestrator.OutputFunc = outputFunc
  31. return orchestrator
  32. }
  33. func (orchestrator *Orchestrator) Run() (interface{}, error) {
  34. for index, step := range orchestrator.steps {
  35. if step.StepFunc == nil {
  36. continue
  37. }
  38. err := step.StepFunc()
  39. if err == nil {
  40. continue
  41. }
  42. if err != nil {
  43. err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName, orchestrator.sagaName,
  44. orchestrator.name, step.RollbackContextData, index)
  45. if err != nil {
  46. return nil, err
  47. }
  48. go orchestrator.Rollback(index)
  49. return nil, err
  50. }
  51. }
  52. if orchestrator.OutputFunc != nil {
  53. return orchestrator.OutputFunc()
  54. }
  55. return nil, nil
  56. }
  57. func (orchestrator *Orchestrator) Rollback(startIndex int) {
  58. for i := startIndex; i >= 0; i-- {
  59. rollbackStep := orchestrator.steps[i]
  60. err := orchestrator.rollbackStep(rollbackStep)
  61. if err != nil {
  62. fmt.Println("Rollback", "orchestrator.rollbackStep", err)
  63. }
  64. err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
  65. orchestrator.sagaName, orchestrator.name,
  66. rollbackStep.RollbackContextData, i-1)
  67. if err != nil {
  68. fmt.Println("Rollback", "saveOrchestratorState", err)
  69. return
  70. }
  71. }
  72. err := deleteOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
  73. orchestrator.sagaName, orchestrator.name)
  74. if err != nil {
  75. fmt.Println("Rollback", "deleteOrchestratorState", err)
  76. return
  77. }
  78. }
  79. func (orchestrator *Orchestrator) rollbackStep(rollbackStep *Step) error {
  80. if rollbackStep.StepRollbackFunc == nil {
  81. return nil
  82. }
  83. err := rollbackStep.StepRollbackFunc()
  84. if err == nil {
  85. return nil
  86. }
  87. orchestrator.rollbackRetry(rollbackStep)
  88. return nil
  89. }
  90. func (orchestrator *Orchestrator) rollbackRetry(step *Step) {
  91. ticker := time.NewTicker(orchestrator.rollbackRetryPeriodSec)
  92. for {
  93. select {
  94. case <-step.rollbackDone:
  95. ticker.Stop()
  96. ticker = nil
  97. case <-ticker.C:
  98. err := step.StepRollbackFunc()
  99. if err == nil {
  100. return
  101. }
  102. }
  103. }
  104. }
  105. func (orchestrator *Orchestrator) stop() {
  106. for _, step := range orchestrator.steps {
  107. if step.rollbackDone != nil {
  108. step.rollbackDone <- true
  109. close(step.rollbackDone)
  110. step.rollbackDone = nil
  111. }
  112. }
  113. }