1package test
 2
 3import (
 4	"errors"
 5	"math/rand"
 6	"slices"
 7	"testing"
 8	"time"
 9)
10
// flaky pairs a test handle with retry settings so that a test body can be
// executed several times before a failure is reported. Values are created
// with NewFlaky and driven through Run.
type flaky struct {
	// t is the test handle; Run hands it to each attempt's recorder and
	// reports the final failure on it.
	t testing.TB
	// o holds the retry settings (attempt limit, backoff, jitter) read by Run.
	o *FlakyOptions
}
15
// FlakyOptions configures how a flaky helper retries a failing test body.
type FlakyOptions struct {
	// InitialBackoff is the base delay slept after the first failed attempt;
	// the delay doubles after every further failed attempt. NewFlaky replaces
	// values <= 0 with 500ms.
	InitialBackoff time.Duration
	// MaxAttempts is the total number of times the test body may be run.
	// NewFlaky replaces values <= 0 with 3.
	MaxAttempts    int
	// Jitter is the fraction of each backoff by which the actual sleep may
	// randomly deviate in either direction (e.g. 0.2 means up to ±20%).
	// Zero disables jitter; NewFlaky resets negative values to 0.
	Jitter         float64
}
21
22func NewFlaky(t testing.TB, o *FlakyOptions) *flaky {
23	if o.InitialBackoff <= 0 {
24		o.InitialBackoff = 500 * time.Millisecond
25	}
26
27	if o.MaxAttempts <= 0 {
28		o.MaxAttempts = 3
29	}
30
31	if o.Jitter < 0 {
32		o.Jitter = 0
33	}
34
35	return &flaky{t: t, o: o}
36}
37
38func (f *flaky) Run(fn func(t testing.TB)) {
39	var last error
40
41	for attempt := 1; attempt <= f.o.MaxAttempts; attempt++ {
42		r := &recorder{
43			TB:    f.t,
44			fail:  func(s string) { last = errors.New(s) },
45			fatal: func(s string) { last = errors.New(s) },
46		}
47
48		func() {
49			failCodes := []int{RecorderFailNow, RecorderFatalf, RecorderFatal}
50			defer func() {
51				if rec := recover(); rec != nil {
52					if code, ok := rec.(int); ok && !slices.Contains(failCodes, code) {
53						panic(rec)
54					}
55				}
56			}()
57
58			fn(r)
59		}()
60
61		if !r.Failed() {
62			return
63		}
64
65		if attempt < f.o.MaxAttempts {
66			backoff := f.o.InitialBackoff * time.Duration(1<<uint(attempt-1))
67			time.Sleep(applyJitter(backoff, f.o.Jitter))
68		}
69	}
70
71	f.t.Fatalf("[%s] test failed after %d attempts: %v", f.t.Name(), f.o.MaxAttempts, last)
72}
73
// applyJitter randomises d by up to ±jitter (a fraction of d) and returns the
// adjusted duration. A jitter of exactly 0 returns d unchanged without
// consuming any randomness.
func applyJitter(d time.Duration, jitter float64) time.Duration {
	if jitter != 0 {
		base := float64(d)
		spread := base * jitter
		// Uniform factor in [-1, 1) picks where inside the spread we land.
		offset := spread * (rand.Float64()*2 - 1)
		return time.Duration(base + offset)
	}

	return d
}