errgroup.go

  1// Copyright 2016 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5// Package errgroup provides synchronization, error propagation, and Context
  6// cancelation for groups of goroutines working on subtasks of a common task.
  7//
  8// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
  9// returning errors.
 10package errgroup
 11
 12import (
 13	"context"
 14	"fmt"
 15	"runtime"
 16	"runtime/debug"
 17	"sync"
 18)
 19
 20type token struct{}
 21
 22// A Group is a collection of goroutines working on subtasks that are part of
 23// the same overall task. A Group should not be reused for different tasks.
 24//
 25// A zero Group is valid, has no limit on the number of active goroutines,
 26// and does not cancel on error.
 27type Group struct {
 28	cancel func(error)
 29
 30	wg sync.WaitGroup
 31
 32	sem chan token
 33
 34	errOnce sync.Once
 35	err     error
 36
 37	mu         sync.Mutex
 38	panicValue any  // = PanicError | PanicValue; non-nil if some Group.Go goroutine panicked.
 39	abnormal   bool // some Group.Go goroutine terminated abnormally (panic or goexit).
 40}
 41
 42func (g *Group) done() {
 43	if g.sem != nil {
 44		<-g.sem
 45	}
 46	g.wg.Done()
 47}
 48
 49// WithContext returns a new Group and an associated Context derived from ctx.
 50//
 51// The derived Context is canceled the first time a function passed to Go
 52// returns a non-nil error or the first time Wait returns, whichever occurs
 53// first.
 54func WithContext(ctx context.Context) (*Group, context.Context) {
 55	ctx, cancel := context.WithCancelCause(ctx)
 56	return &Group{cancel: cancel}, ctx
 57}
 58
 59// Wait blocks until all function calls from the Go method have returned
 60// normally, then returns the first non-nil error (if any) from them.
 61//
 62// If any of the calls panics, Wait panics with a [PanicValue];
 63// and if any of them calls [runtime.Goexit], Wait calls runtime.Goexit.
 64func (g *Group) Wait() error {
 65	g.wg.Wait()
 66	if g.cancel != nil {
 67		g.cancel(g.err)
 68	}
 69	if g.panicValue != nil {
 70		panic(g.panicValue)
 71	}
 72	if g.abnormal {
 73		runtime.Goexit()
 74	}
 75	return g.err
 76}
 77
 78// Go calls the given function in a new goroutine.
 79//
 80// The first call to Go must happen before a Wait.
 81// It blocks until the new goroutine can be added without the number of
 82// goroutines in the group exceeding the configured limit.
 83//
 84// The first goroutine in the group that returns a non-nil error, panics, or
 85// invokes [runtime.Goexit] will cancel the associated Context, if any.
 86func (g *Group) Go(f func() error) {
 87	if g.sem != nil {
 88		g.sem <- token{}
 89	}
 90
 91	g.add(f)
 92}
 93
 94func (g *Group) add(f func() error) {
 95	g.wg.Add(1)
 96	go func() {
 97		defer g.done()
 98		normalReturn := false
 99		defer func() {
100			if normalReturn {
101				return
102			}
103			v := recover()
104			g.mu.Lock()
105			defer g.mu.Unlock()
106			if !g.abnormal {
107				if g.cancel != nil {
108					g.cancel(g.err)
109				}
110				g.abnormal = true
111			}
112			if v != nil && g.panicValue == nil {
113				switch v := v.(type) {
114				case error:
115					g.panicValue = PanicError{
116						Recovered: v,
117						Stack:     debug.Stack(),
118					}
119				default:
120					g.panicValue = PanicValue{
121						Recovered: v,
122						Stack:     debug.Stack(),
123					}
124				}
125			}
126		}()
127
128		err := f()
129		normalReturn = true
130		if err != nil {
131			g.errOnce.Do(func() {
132				g.err = err
133				if g.cancel != nil {
134					g.cancel(g.err)
135				}
136			})
137		}
138	}()
139}
140
141// TryGo calls the given function in a new goroutine only if the number of
142// active goroutines in the group is currently below the configured limit.
143//
144// The return value reports whether the goroutine was started.
145func (g *Group) TryGo(f func() error) bool {
146	if g.sem != nil {
147		select {
148		case g.sem <- token{}:
149			// Note: this allows barging iff channels in general allow barging.
150		default:
151			return false
152		}
153	}
154
155	g.add(f)
156	return true
157}
158
159// SetLimit limits the number of active goroutines in this group to at most n.
160// A negative value indicates no limit.
161// A limit of zero will prevent any new goroutines from being added.
162//
163// Any subsequent call to the Go method will block until it can add an active
164// goroutine without exceeding the configured limit.
165//
166// The limit must not be modified while any goroutines in the group are active.
167func (g *Group) SetLimit(n int) {
168	if n < 0 {
169		g.sem = nil
170		return
171	}
172	if len(g.sem) != 0 {
173		panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
174	}
175	g.sem = make(chan token, n)
176}
177
// PanicError wraps an error recovered from an unhandled panic when calling a
// function passed to Go or TryGo. Wait panics with a PanicError when the
// recovered value implements error.
type PanicError struct {
	Recovered error
	Stack     []byte // result of call to [debug.Stack]
}

var _ error = PanicError{}

// Error formats the recovered error, followed by the captured stack trace
// when one is present.
func (p PanicError) Error() string {
	if len(p.Stack) == 0 {
		return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
	}
	return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack)
}

// Unwrap returns the recovered error so errors.Is and errors.As can see it.
func (p PanicError) Unwrap() error {
	return p.Recovered
}
193
// PanicValue wraps a value that does not implement the error interface,
// recovered from an unhandled panic when calling a function passed to Go or
// TryGo. Wait panics with a PanicValue in that case.
type PanicValue struct {
	Recovered any
	Stack     []byte // result of call to [debug.Stack]
}

var _ fmt.Stringer = PanicValue{}

// String formats the recovered value, followed by the captured stack trace
// when one is present.
func (p PanicValue) String() string {
	if len(p.Stack) == 0 {
		return fmt.Sprintf("recovered from errgroup.Group: %v", p.Recovered)
	}
	return fmt.Sprintf("recovered from errgroup.Group: %v\n%s", p.Recovered, p.Stack)
}