errwaitgroup.go

  1package multierr
  2
  3import (
  4	"context"
  5	"fmt"
  6	"sync"
  7)
  8
  9type token struct{}
 10
 11// A ErrWaitGroup is a collection of goroutines working on subtasks that are part of
 12// the same overall task.
 13//
 14// A zero ErrWaitGroup is valid, has no limit on the number of active goroutines,
 15// and does not cancel on error.
 16type ErrWaitGroup struct {
 17	cancel func()
 18
 19	wg sync.WaitGroup
 20
 21	sem chan token
 22
 23	mu  sync.Mutex
 24	err error
 25}
 26
 27func (g *ErrWaitGroup) done() {
 28	if g.sem != nil {
 29		<-g.sem
 30	}
 31	g.wg.Done()
 32}
 33
 34// WithContext returns a new ErrWaitGroup and an associated Context derived from ctx.
 35//
 36// The derived Context is canceled the first time Wait returns.
 37func WithContext(ctx context.Context) (*ErrWaitGroup, context.Context) {
 38	ctx, cancel := context.WithCancel(ctx)
 39	return &ErrWaitGroup{cancel: cancel}, ctx
 40}
 41
 42// Wait blocks until all function calls from the Go method have returned, then
 43// returns the combined non-nil errors (if any) from them.
 44func (g *ErrWaitGroup) Wait() error {
 45	g.wg.Wait()
 46	if g.cancel != nil {
 47		g.cancel()
 48	}
 49	return g.err
 50}
 51
 52// Go calls the given function in a new goroutine.
 53// It blocks until the new goroutine can be added without the number of
 54// active goroutines in the group exceeding the configured limit.
 55func (g *ErrWaitGroup) Go(f func() error) {
 56	if g.sem != nil {
 57		g.sem <- token{}
 58	}
 59
 60	g.wg.Add(1)
 61	go func() {
 62		defer g.done()
 63
 64		if err := f(); err != nil {
 65			g.mu.Lock()
 66			g.err = Join(g.err, err)
 67			g.mu.Unlock()
 68		}
 69	}()
 70}
 71
 72// TryGo calls the given function in a new goroutine only if the number of
 73// active goroutines in the group is currently below the configured limit.
 74//
 75// The return value reports whether the goroutine was started.
 76func (g *ErrWaitGroup) TryGo(f func() error) bool {
 77	if g.sem != nil {
 78		select {
 79		case g.sem <- token{}:
 80			// Note: this allows barging iff channels in general allow barging.
 81		default:
 82			return false
 83		}
 84	}
 85
 86	g.wg.Add(1)
 87	go func() {
 88		defer g.done()
 89
 90		if err := f(); err != nil {
 91			g.mu.Lock()
 92			err = Join(g.err, err)
 93			g.mu.Unlock()
 94		}
 95	}()
 96	return true
 97}
 98
 99// SetLimit limits the number of active goroutines in this group to at most n.
100// A negative value indicates no limit.
101//
102// Any subsequent call to the Go method will block until it can add an active
103// goroutine without exceeding the configured limit.
104//
105// The limit must not be modified while any goroutines in the group are active.
106func (g *ErrWaitGroup) SetLimit(n int) {
107	if n < 0 {
108		g.sem = nil
109		return
110	}
111	if len(g.sem) != 0 {
112		panic(fmt.Errorf("errwaitgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
113	}
114	g.sem = make(chan token, n)
115}