1package multierr
2
3import (
4 "context"
5 "fmt"
6 "sync"
7)
8
9type token struct{}
10
11// A ErrWaitGroup is a collection of goroutines working on subtasks that are part of
12// the same overall task.
13//
14// A zero ErrWaitGroup is valid, has no limit on the number of active goroutines,
15// and does not cancel on error.
16type ErrWaitGroup struct {
17 cancel func()
18
19 wg sync.WaitGroup
20
21 sem chan token
22
23 mu sync.Mutex
24 err error
25}
26
27func (g *ErrWaitGroup) done() {
28 if g.sem != nil {
29 <-g.sem
30 }
31 g.wg.Done()
32}
33
34// WithContext returns a new ErrWaitGroup and an associated Context derived from ctx.
35//
36// The derived Context is canceled the first time Wait returns.
37func WithContext(ctx context.Context) (*ErrWaitGroup, context.Context) {
38 ctx, cancel := context.WithCancel(ctx)
39 return &ErrWaitGroup{cancel: cancel}, ctx
40}
41
42// Wait blocks until all function calls from the Go method have returned, then
43// returns the combined non-nil errors (if any) from them.
44func (g *ErrWaitGroup) Wait() error {
45 g.wg.Wait()
46 if g.cancel != nil {
47 g.cancel()
48 }
49 return g.err
50}
51
52// Go calls the given function in a new goroutine.
53// It blocks until the new goroutine can be added without the number of
54// active goroutines in the group exceeding the configured limit.
55func (g *ErrWaitGroup) Go(f func() error) {
56 if g.sem != nil {
57 g.sem <- token{}
58 }
59
60 g.wg.Add(1)
61 go func() {
62 defer g.done()
63
64 if err := f(); err != nil {
65 g.mu.Lock()
66 err = Join(g.err, err)
67 g.mu.Unlock()
68 }
69 }()
70}
71
72// TryGo calls the given function in a new goroutine only if the number of
73// active goroutines in the group is currently below the configured limit.
74//
75// The return value reports whether the goroutine was started.
76func (g *ErrWaitGroup) TryGo(f func() error) bool {
77 if g.sem != nil {
78 select {
79 case g.sem <- token{}:
80 // Note: this allows barging iff channels in general allow barging.
81 default:
82 return false
83 }
84 }
85
86 g.wg.Add(1)
87 go func() {
88 defer g.done()
89
90 if err := f(); err != nil {
91 g.mu.Lock()
92 err = Join(g.err, err)
93 g.mu.Unlock()
94 }
95 }()
96 return true
97}
98
99// SetLimit limits the number of active goroutines in this group to at most n.
100// A negative value indicates no limit.
101//
102// Any subsequent call to the Go method will block until it can add an active
103// goroutine without exceeding the configured limit.
104//
105// The limit must not be modified while any goroutines in the group are active.
106func (g *ErrWaitGroup) SetLimit(n int) {
107 if n < 0 {
108 g.sem = nil
109 return
110 }
111 if len(g.sem) != 0 {
112 panic(fmt.Errorf("errwaitgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
113 }
114 g.sem = make(chan token, n)
115}