1package fantasy
2
import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"
)
9
type (
	// RetryFn is an operation that can be attempted repeatedly: each call
	// either produces a value of type T or fails with an error.
	RetryFn[T any] func() (T, error)

	// RetryFunction drives a RetryFn, re-invoking it according to some retry
	// policy, and returns the final value or error. ctx is passed through so
	// an implementation can stop retrying once it is cancelled.
	RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
)
15
16// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
17func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
18 var providerErr *ProviderError
19 if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil {
20 return exponentialBackoffDelay
21 }
22
23 headers := providerErr.ResponseHeaders
24 var ms time.Duration
25
26 // retry-ms is more precise than retry-after and used by e.g. OpenAI
27 if retryAfterMs, exists := headers["retry-after-ms"]; exists {
28 if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
29 ms = time.Duration(timeoutMs) * time.Millisecond
30 }
31 }
32
33 // About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
34 if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
35 if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
36 ms = time.Duration(timeoutSeconds) * time.Second
37 } else {
38 // Try parsing as HTTP date
39 if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
40 ms = time.Until(t)
41 }
42 }
43 }
44
45 // Check that the delay is reasonable:
46 // 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
47 if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
48 return ms
49 }
50
51 return exponentialBackoffDelay
52}
53
// isAbortError reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded: the caller gave up, so the failed call must not be
// retried.
func isAbortError(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
58
59// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
60// a failed operation with exponential backoff, while respecting rate limit headers
61// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
62func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
63 return func(ctx context.Context, fn RetryFn[T]) (T, error) {
64 return retryWithExponentialBackoff(ctx, fn, options, nil)
65 }
66}
67
68// RetryOptions configures the retry behavior.
69type RetryOptions struct {
70 MaxRetries int
71 InitialDelayIn time.Duration
72 BackoffFactor float64
73 OnRetry OnRetryCallback
74}
75
76// OnRetryCallback defines a function that is called when a retry occurs.
77type OnRetryCallback = func(err *ProviderError, delay time.Duration)
78
79// DefaultRetryOptions returns the default retry options.
80// DefaultRetryOptions returns the default retry options.
81func DefaultRetryOptions() RetryOptions {
82 return RetryOptions{
83 MaxRetries: 2,
84 InitialDelayIn: 2000 * time.Millisecond,
85 BackoffFactor: 2.0,
86 }
87}
88
89// retryWithExponentialBackoff implements the retry logic with exponential backoff.
90func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
91 var zero T
92 result, err := fn()
93 if err == nil {
94 return result, nil
95 }
96
97 if isAbortError(err) {
98 return zero, err // don't retry when the request was aborted
99 }
100
101 if options.MaxRetries == 0 {
102 return zero, err // don't wrap the error when retries are disabled
103 }
104
105 newErrors := append(allErrors, err)
106 tryNumber := len(newErrors)
107
108 if tryNumber > options.MaxRetries {
109 return zero, &RetryError{newErrors}
110 }
111
112 var providerErr *ProviderError
113 if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries {
114 delay := getRetryDelayInMs(err, options.InitialDelayIn)
115 if options.OnRetry != nil {
116 options.OnRetry(providerErr, delay)
117 }
118
119 select {
120 case <-time.After(delay):
121 // Continue with retry
122 case <-ctx.Done():
123 return zero, ctx.Err()
124 }
125
126 newOptions := options
127 newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
128
129 return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
130 }
131
132 if tryNumber == 1 {
133 return zero, err // don't wrap the error when a non-retryable error occurs on the first try
134 }
135
136 return zero, &RetryError{newErrors}
137}