1package test
2
3import (
4 "errors"
5 "math/rand"
6 "slices"
7 "testing"
8 "time"
9)
10
// flaky re-runs a test body until it passes or the configured number of
// attempts is exhausted. Values are built by NewFlaky and driven by Run.
type flaky struct {
	// t is the real test handle; Run reports the final failure on it.
	t testing.TB
	// o holds the retry settings (attempt count, backoff, jitter).
	o *FlakyOptions
}
15
// FlakyOptions configures how a test body is retried by NewFlaky / Run.
type FlakyOptions struct {
	// InitialBackoff is the delay after the first failed attempt; the delay
	// doubles after each further failure. NewFlaky substitutes 500ms when
	// this is not positive.
	InitialBackoff time.Duration
	// MaxAttempts is the total number of times the body is run before the
	// test is failed. NewFlaky substitutes 3 when this is not positive.
	MaxAttempts int
	// Jitter randomly perturbs each delay by up to ±Jitter times its length
	// (e.g. 0.1 for ±10%). 0 disables jitter; NewFlaky treats negative
	// values as 0.
	Jitter float64
}
21
22func NewFlaky(t testing.TB, o *FlakyOptions) *flaky {
23 if o.InitialBackoff <= 0 {
24 o.InitialBackoff = 500 * time.Millisecond
25 }
26
27 if o.MaxAttempts <= 0 {
28 o.MaxAttempts = 3
29 }
30
31 if o.Jitter < 0 {
32 o.Jitter = 0
33 }
34
35 return &flaky{t: t, o: o}
36}
37
38func (f *flaky) Run(fn func(t testing.TB)) {
39 var last error
40
41 for attempt := 1; attempt <= f.o.MaxAttempts; attempt++ {
42 r := &recorder{
43 TB: f.t,
44 fail: func(s string) { last = errors.New(s) },
45 fatal: func(s string) { last = errors.New(s) },
46 }
47
48 func() {
49 failCodes := []int{RecorderFailNow, RecorderFatalf, RecorderFatal}
50 defer func() {
51 if rec := recover(); rec != nil {
52 if code, ok := rec.(int); ok && !slices.Contains(failCodes, code) {
53 panic(rec)
54 }
55 }
56 }()
57
58 fn(r)
59 }()
60
61 if !r.Failed() {
62 return
63 }
64
65 if attempt < f.o.MaxAttempts {
66 backoff := f.o.InitialBackoff * time.Duration(1<<uint(attempt-1))
67 time.Sleep(applyJitter(backoff, f.o.Jitter))
68 }
69 }
70
71 f.t.Fatalf("[%s] test failed after %d attempts: %v", f.t.Name(), f.o.MaxAttempts, last)
72}
73
// applyJitter returns d shifted by a random offset in [-jitter*d, +jitter*d),
// drawn from math/rand. A jitter of 0 leaves d unchanged.
//
// The result is never negative. A jitter above 1 could otherwise push the
// offset past -d, and an overflowed backoff may arrive here negative; both
// are clamped to 0 so callers always get a valid sleep duration.
func applyJitter(d time.Duration, jitter float64) time.Duration {
	if jitter != 0 {
		delta := float64(d) * jitter * (rand.Float64()*2 - 1)
		d = time.Duration(float64(d) + delta)
	}
	return max(d, 0)
}