1package test
2
3import (
4 "errors"
5 "math/rand"
6 "testing"
7 "time"
8)
9
10type flaky struct {
11 t testing.TB
12 o *FlakyOptions
13}
14
// FlakyOptions configures how NewFlaky's runner retries a failing test body.
// NewFlaky fills in defaults for fields that are left unset (see each field).
type FlakyOptions struct {
	// InitialBackoff is the pause after the first failed attempt; the pause
	// doubles after every further failure. NewFlaky uses 500ms when it is <= 0.
	InitialBackoff time.Duration

	// MaxAttempts is the maximum number of times the test body is run,
	// counting the first run. NewFlaky uses 3 when it is <= 0.
	MaxAttempts int

	// Jitter perturbs each pause by up to ±Jitter times the pause, expressed as
	// a fraction (0.1 means ±10%). Zero disables jitter; NewFlaky treats
	// negative values as zero.
	Jitter float64
}
20
21func NewFlaky(t testing.TB, o *FlakyOptions) *flaky {
22 if o.InitialBackoff <= 0 {
23 o.InitialBackoff = 500 * time.Millisecond
24 }
25
26 if o.MaxAttempts <= 0 {
27 o.MaxAttempts = 3
28 }
29
30 if o.Jitter < 0 {
31 o.Jitter = 0
32 }
33
34 return &flaky{t: t, o: o}
35}
36
37func (f *flaky) Run(fn func(t testing.TB)) {
38 var last error
39
40 for attempt := 1; attempt <= f.o.MaxAttempts; attempt++ {
41 f.t.Logf("attempt %d of %d", attempt, f.o.MaxAttempts)
42
43 r := &recorder{
44 TB: f.t,
45 fail: func(s string) { last = errors.New(s) },
46 fatal: func(s string) { last = errors.New(s) },
47 }
48
49 func() {
50 defer func() {
51 if v := recover(); v != nil {
52 if code, ok := v.(int); ok && code != RecorderFailNow {
53 panic(v)
54 }
55 }
56 }()
57 fn(r)
58 }()
59
60 if !r.Failed() {
61 return
62 }
63
64 if attempt < f.o.MaxAttempts {
65 backoff := f.o.InitialBackoff * time.Duration(1<<uint(attempt-1))
66 time.Sleep(applyJitter(backoff, f.o.Jitter))
67 }
68 }
69
70 f.t.Fatalf("[%s] test failed after %d attempts: %v", f.t.Name(), f.o.MaxAttempts, last)
71}
72
// applyJitter returns d randomly perturbed by up to ±jitter*d. The value is
// drawn uniformly from [d-jitter*d, d+jitter*d).
//
// A jitter that is zero, negative, or NaN disables randomization, and d is
// returned unchanged. This matches NewFlaky, which normalizes negative jitter
// to 0.
//
// When jitter is applied, the result is clamped to [0, max time.Duration].
// A jitter above 1 therefore cannot produce a negative delay, and a very
// large d cannot overflow in the float64 -> Duration conversion.
func applyJitter(d time.Duration, jitter float64) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)

	// !(jitter > 0) is also true for NaN, which would otherwise poison the math.
	if !(jitter > 0) {
		return d
	}

	spread := float64(d) * jitter
	jittered := float64(d) + spread*(rand.Float64()*2-1)

	switch {
	case jittered <= 0:
		return 0
	case jittered >= float64(maxDuration):
		// Converting an out-of-range float64 to an integer type is
		// implementation-defined in Go, so saturate explicitly.
		return maxDuration
	}
	return time.Duration(jittered)
}