parallel.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package workflow
import (
	"context"
	"fmt"
	"runtime"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
)
  10. // ParallelErrorStrategy defines how to handle errors in parallel execution
  11. type ParallelErrorStrategy string
  12. const (
  13. // ParallelErrorStrategyFailFast stops all parallel work on first error
  14. ParallelErrorStrategyFailFast ParallelErrorStrategy = "failFast"
  15. // ParallelErrorStrategyCollectAll continues all parallel work, collects all errors
  16. ParallelErrorStrategyCollectAll ParallelErrorStrategy = "collectAll"
  17. // ParallelErrorStrategyPartialSuccess continues all work, succeeds if any succeed
  18. ParallelErrorStrategyPartialSuccess ParallelErrorStrategy = "partialSuccess"
  19. )
  20. // BranchError represents an error from a single parallel branch
  21. type BranchError struct {
  22. BranchID string // Child step ID or iteration index
  23. BranchIndex int // Numeric index for ordering
  24. Error error // The actual error
  25. }
  26. // ParallelError aggregates errors from parallel executions
  27. type ParallelError struct {
  28. Errors []BranchError
  29. Strategy ParallelErrorStrategy
  30. TotalBranches int
  31. SuccessCount int
  32. }
  33. // Error implements the error interface
  34. func (e *ParallelError) Error() string {
  35. failureCount := len(e.Errors)
  36. if failureCount == 0 {
  37. return "no errors"
  38. }
  39. if failureCount == 1 {
  40. return fmt.Sprintf("parallel execution failed: %s: %v",
  41. e.Errors[0].BranchID, e.Errors[0].Error)
  42. }
  43. var sb strings.Builder
  44. sb.WriteString(fmt.Sprintf("%d/%d parallel branches failed:\n",
  45. failureCount, e.TotalBranches))
  46. for i, branchErr := range e.Errors {
  47. if i >= 5 {
  48. sb.WriteString(fmt.Sprintf(" ... and %d more errors\n", failureCount-5))
  49. break
  50. }
  51. sb.WriteString(fmt.Sprintf(" - %s: %v\n", branchErr.BranchID, branchErr.Error))
  52. }
  53. return sb.String()
  54. }
  55. // ParallelBranch represents a single parallel execution branch
  56. type ParallelBranch struct {
  57. ID string // Branch identifier
  58. Fn func(context.Context) error // Function to execute
  59. }
  60. // ParallelCoordinator manages parallel execution with cancellation and error handling
  61. type ParallelCoordinator struct {
  62. ctx context.Context
  63. cancel context.CancelFunc
  64. wg sync.WaitGroup
  65. errorStrategy ParallelErrorStrategy
  66. errorMutex sync.Mutex
  67. errors []BranchError
  68. totalBranches int
  69. successCount int32 // atomic
  70. }
  71. // NewParallelCoordinator creates a new parallel execution coordinator
  72. func NewParallelCoordinator(
  73. parentCtx context.Context,
  74. strategy ParallelErrorStrategy,
  75. totalBranches int,
  76. ) *ParallelCoordinator {
  77. ctx, cancel := context.WithCancel(parentCtx)
  78. return &ParallelCoordinator{
  79. ctx: ctx,
  80. cancel: cancel,
  81. errorStrategy: strategy,
  82. errors: make([]BranchError, 0),
  83. totalBranches: totalBranches,
  84. successCount: 0,
  85. }
  86. }
  87. // ExecuteBranch executes a single branch in a goroutine
  88. func (pc *ParallelCoordinator) ExecuteBranch(
  89. branchID string,
  90. branchIndex int,
  91. fn func(ctx context.Context) error,
  92. ) {
  93. pc.wg.Add(1)
  94. go func() {
  95. defer pc.wg.Done()
  96. // Check if already cancelled before starting
  97. select {
  98. case <-pc.ctx.Done():
  99. return
  100. default:
  101. }
  102. // Execute branch function
  103. err := fn(pc.ctx)
  104. if err != nil {
  105. pc.recordError(BranchError{
  106. BranchID: branchID,
  107. BranchIndex: branchIndex,
  108. Error: err,
  109. })
  110. // Cancel other branches if fail-fast
  111. if pc.errorStrategy == ParallelErrorStrategyFailFast {
  112. pc.cancel()
  113. }
  114. } else {
  115. // Increment success count
  116. atomic.AddInt32(&pc.successCount, 1)
  117. }
  118. }()
  119. }
  120. // Wait waits for all branches to complete and returns aggregated error
  121. func (pc *ParallelCoordinator) Wait() error {
  122. pc.wg.Wait()
  123. pc.cancel() // Clean up context
  124. successCount := int(atomic.LoadInt32(&pc.successCount))
  125. // Handle different error strategies
  126. if len(pc.errors) == 0 {
  127. return nil
  128. }
  129. // For partial success strategy, succeed if at least one branch succeeded
  130. if pc.errorStrategy == ParallelErrorStrategyPartialSuccess && successCount > 0 {
  131. return nil
  132. }
  133. // Return aggregated error
  134. return &ParallelError{
  135. Errors: pc.errors,
  136. Strategy: pc.errorStrategy,
  137. TotalBranches: pc.totalBranches,
  138. SuccessCount: successCount,
  139. }
  140. }
  141. // recordError records an error from a branch (thread-safe)
  142. func (pc *ParallelCoordinator) recordError(err BranchError) {
  143. pc.errorMutex.Lock()
  144. defer pc.errorMutex.Unlock()
  145. pc.errors = append(pc.errors, err)
  146. }
  147. // ParallelExecutor manages parallel execution with resource limits
  148. type ParallelExecutor struct {
  149. maxConcurrency int
  150. semaphore chan struct{}
  151. }
  152. // NewParallelExecutor creates a new parallel executor with concurrency limit
  153. func NewParallelExecutor(maxConcurrency int) *ParallelExecutor {
  154. if maxConcurrency <= 0 {
  155. maxConcurrency = runtime.NumCPU() * 2 // Sensible default
  156. }
  157. return &ParallelExecutor{
  158. maxConcurrency: maxConcurrency,
  159. semaphore: make(chan struct{}, maxConcurrency),
  160. }
  161. }
  162. // Execute executes branches in parallel with concurrency limit
  163. func (pe *ParallelExecutor) Execute(
  164. ctx context.Context,
  165. branches []ParallelBranch,
  166. strategy ParallelErrorStrategy,
  167. ) error {
  168. if len(branches) == 0 {
  169. return nil
  170. }
  171. coordinator := NewParallelCoordinator(ctx, strategy, len(branches))
  172. for i, branch := range branches {
  173. branchID := branch.ID
  174. branchIndex := i
  175. branchFn := branch.Fn
  176. // Acquire semaphore (blocks if at max concurrency)
  177. select {
  178. case pe.semaphore <- struct{}{}:
  179. // Got token, proceed
  180. case <-ctx.Done():
  181. // Context cancelled while waiting
  182. coordinator.cancel()
  183. return ctx.Err()
  184. }
  185. // Execute branch with semaphore release
  186. coordinator.ExecuteBranch(branchID, branchIndex, func(ctx context.Context) error {
  187. defer func() {
  188. <-pe.semaphore // Release semaphore
  189. }()
  190. return branchFn(ctx)
  191. })
  192. }
  193. return coordinator.Wait()
  194. }
  195. // EngineOptions configures the workflow engine
  196. type EngineOptions struct {
  197. MaxConcurrency int // Max concurrent goroutines (default: runtime.NumCPU() * 2)
  198. ParallelErrorStrategy ParallelErrorStrategy // Error handling strategy (default: fail-fast)
  199. EventBufferSize int // Event stream buffer size (default: 1000)
  200. }
  201. // DefaultEngineOptions provides sensible defaults
  202. var DefaultEngineOptions = EngineOptions{
  203. MaxConcurrency: runtime.NumCPU() * 2,
  204. ParallelErrorStrategy: ParallelErrorStrategyFailFast,
  205. EventBufferSize: 1000,
  206. }