orchestrator.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package saga
  2. import (
  3. "fmt"
  4. "git.sxidc.com/service-supports/dapr_api/state"
  5. "sync"
  6. "time"
  7. )
  8. var runningCountMutex sync.Mutex
  9. var runningCount uint64
  10. type StepFunc func() error
  11. type StepRollbackFunc func() error
  12. type OutputFunc func() (interface{}, error)
  13. type Step struct {
  14. StepFunc StepFunc
  15. StepRollbackFunc StepRollbackFunc
  16. RollbackContextData string
  17. rollbackDone chan interface{}
  18. }
  19. type Orchestrator struct {
  20. stateAPI *state.API
  21. stateStoreName string
  22. sagaName string
  23. name string
  24. rollbackRetryPeriodSec time.Duration
  25. steps []*Step
  26. OutputFunc OutputFunc
  27. }
  28. func (orchestrator *Orchestrator) AddStep(step *Step) *Orchestrator {
  29. orchestrator.steps = append(orchestrator.steps, step)
  30. return orchestrator
  31. }
  32. func (orchestrator *Orchestrator) Output(outputFunc OutputFunc) *Orchestrator {
  33. orchestrator.OutputFunc = outputFunc
  34. return orchestrator
  35. }
  36. func (orchestrator *Orchestrator) Run() (interface{}, error) {
  37. runningCountMutex.Lock()
  38. runningCount++
  39. runningCountMutex.Unlock()
  40. for index, step := range orchestrator.steps {
  41. if step.StepFunc == nil {
  42. continue
  43. }
  44. err := step.StepFunc()
  45. if err == nil {
  46. continue
  47. }
  48. if err != nil {
  49. err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName, orchestrator.sagaName,
  50. orchestrator.name, step.RollbackContextData, index)
  51. if err != nil {
  52. return nil, err
  53. }
  54. go orchestrator.Rollback(index)
  55. runningCountMutex.Lock()
  56. runningCount--
  57. runningCountMutex.Unlock()
  58. return nil, err
  59. }
  60. }
  61. if orchestrator.OutputFunc != nil {
  62. result, err := orchestrator.OutputFunc()
  63. runningCountMutex.Lock()
  64. runningCount--
  65. runningCountMutex.Unlock()
  66. return result, err
  67. }
  68. runningCountMutex.Lock()
  69. runningCount--
  70. runningCountMutex.Unlock()
  71. return nil, nil
  72. }
  73. func (orchestrator *Orchestrator) Rollback(startIndex int) {
  74. for i := startIndex; i >= 0; i-- {
  75. rollbackStep := orchestrator.steps[i]
  76. err := orchestrator.rollbackStep(rollbackStep)
  77. if err != nil {
  78. fmt.Println("Rollback", "orchestrator.rollbackStep", err)
  79. }
  80. err = saveOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
  81. orchestrator.sagaName, orchestrator.name,
  82. rollbackStep.RollbackContextData, i-1)
  83. if err != nil {
  84. fmt.Println("Rollback", "saveOrchestratorState", err)
  85. return
  86. }
  87. }
  88. err := deleteOrchestratorState(orchestrator.stateAPI, orchestrator.stateStoreName,
  89. orchestrator.sagaName, orchestrator.name)
  90. if err != nil {
  91. fmt.Println("Rollback", "deleteOrchestratorState", err)
  92. return
  93. }
  94. }
  95. func (orchestrator *Orchestrator) rollbackStep(rollbackStep *Step) error {
  96. if rollbackStep.StepRollbackFunc == nil {
  97. return nil
  98. }
  99. err := rollbackStep.StepRollbackFunc()
  100. if err == nil {
  101. return nil
  102. }
  103. orchestrator.rollbackRetry(rollbackStep)
  104. return nil
  105. }
  106. func (orchestrator *Orchestrator) rollbackRetry(step *Step) {
  107. ticker := time.NewTicker(orchestrator.rollbackRetryPeriodSec)
  108. for {
  109. select {
  110. case <-step.rollbackDone:
  111. ticker.Stop()
  112. ticker = nil
  113. case <-ticker.C:
  114. err := step.StepRollbackFunc()
  115. if err == nil {
  116. return
  117. }
  118. }
  119. }
  120. }
  121. func (orchestrator *Orchestrator) stop() {
  122. runningCountMutex.Lock()
  123. for runningCount != 0 {
  124. runningCountMutex.Unlock()
  125. time.Sleep(500 * time.Millisecond)
  126. runningCountMutex.Lock()
  127. continue
  128. }
  129. runningCountMutex.Unlock()
  130. for _, step := range orchestrator.steps {
  131. if step.rollbackDone != nil {
  132. step.rollbackDone <- true
  133. close(step.rollbackDone)
  134. step.rollbackDone = nil
  135. }
  136. }
  137. }