test.go

 1package test
 2
 3import (
 4	"errors"
 5	"math/rand"
 6	"testing"
 7	"time"
 8)
 9
// flaky re-runs a test body that is known to fail intermittently, sleeping
// with exponential backoff (plus optional jitter) between failed attempts.
type flaky struct {
	t testing.TB    // outer test that receives the final verdict
	o *FlakyOptions // resolved options; never nil once built by NewFlaky
}

// FlakyOptions configures the retry policy used by NewFlaky. The zero value is
// valid: every non-positive field falls back to a default.
type FlakyOptions struct {
	// InitialBackoff is the delay after the first failed attempt; it doubles
	// after each subsequent failure. Defaults to 500ms when <= 0.
	InitialBackoff time.Duration

	// MaxAttempts is the total number of attempts, including the first.
	// Defaults to 3 when <= 0.
	MaxAttempts int

	// Jitter is the fraction of each backoff by which the actual delay is
	// randomly perturbed in either direction (0.1 means up to ±10%).
	// Negative values are treated as 0 (no jitter).
	Jitter float64
}

// NewFlaky returns a retry helper bound to t.
//
// o may be nil, in which case every default applies. The caller's options are
// copied rather than modified in place, so a single FlakyOptions value can be
// shared across several tests without picking up another test's defaults.
func NewFlaky(t testing.TB, o *FlakyOptions) *flaky {
	var opts FlakyOptions
	if o != nil {
		opts = *o
	}

	if opts.InitialBackoff <= 0 {
		opts.InitialBackoff = 500 * time.Millisecond
	}
	if opts.MaxAttempts <= 0 {
		opts.MaxAttempts = 3
	}
	if opts.Jitter < 0 {
		opts.Jitter = 0
	}

	return &flaky{t: t, o: &opts}
}
36
37func (f *flaky) Run(fn func(t testing.TB)) {
38	var last error
39
40	for attempt := 1; attempt <= f.o.MaxAttempts; attempt++ {
41		f.t.Logf("attempt %d of %d", attempt, f.o.MaxAttempts)
42
43		r := &recorder{
44			TB:    f.t,
45			fail:  func(s string) { last = errors.New(s) },
46			fatal: func(s string) { last = errors.New(s) },
47		}
48
49		func() {
50			defer func() {
51				if v := recover(); v != nil {
52					if code, ok := v.(int); ok && code != RecorderFailNow {
53						panic(v)
54					}
55				}
56			}()
57			fn(r)
58		}()
59
60		if !r.Failed() {
61			return
62		}
63
64		if attempt < f.o.MaxAttempts {
65			backoff := f.o.InitialBackoff * time.Duration(1<<uint(attempt-1))
66			time.Sleep(applyJitter(backoff, f.o.Jitter))
67		}
68	}
69
70	f.t.Fatalf("[%s] test failed after %d attempts: %v", f.t.Name(), f.o.MaxAttempts, last)
71}
72
// applyJitter randomly perturbs d by up to |jitter|*d in either direction.
// For a positive jitter, the offset is drawn uniformly from
// [-jitter*d, +jitter*d). A jitter of exactly 0 returns d unchanged.
//
// The jittered result is clamped at zero. A jitter above 1 can therefore
// shorten the delay to nothing, but it never produces a negative Duration.
func applyJitter(d time.Duration, jitter float64) time.Duration {
	if jitter == 0 {
		return d
	}
	maxJitter := float64(d) * jitter
	delta := maxJitter * (rand.Float64()*2 - 1)
	if jittered := time.Duration(float64(d) + delta); jittered > 0 {
		return jittered
	}
	return 0
}