1package fantasy
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strconv"
8 "time"
9)
10
// RetryFn is a single attempt of an operation that may be retried. It returns
// the operation's value on success, or an error describing why the attempt
// failed.
type RetryFn[T any] func() (T, error)
13
// RetryFunction executes fn, possibly more than once, and returns either the
// value produced by fn or an error. The ctx argument is handed to the retry
// implementation alongside fn.
type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
16
// RetryReason identifies why a retried operation gave up. It is stored in
// RetryError.Reason.
type RetryReason string
19
const (
	// RetryReasonMaxRetriesExceeded indicates that the operation kept failing
	// until the configured maximum number of retries was used up.
	RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
	// RetryReasonErrorNotRetryable indicates that an attempt failed with an
	// error that is not eligible for a retry.
	RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable"
)
26
// RetryError is the error reported when a retried operation ultimately fails.
// It embeds *AIError and additionally records why retrying stopped and the
// errors collected from the failed attempts.
type RetryError struct {
	*AIError

	// Reason says why retrying stopped.
	Reason RetryReason

	// Errors holds the errors gathered across attempts, in the order they
	// were passed to NewRetryError.
	Errors []error
}
33
34// NewRetryError creates a new retry error.
35func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
36 return &RetryError{
37 AIError: NewAIError("AI_RetryError", message, nil),
38 Reason: reason,
39 Errors: errors,
40 }
41}
42
43// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
44func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
45 var apiErr *APICallError
46 if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
47 return exponentialBackoffDelay
48 }
49
50 headers := apiErr.ResponseHeaders
51 var ms time.Duration
52
53 // retry-ms is more precise than retry-after and used by e.g. OpenAI
54 if retryAfterMs, exists := headers["retry-after-ms"]; exists {
55 if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
56 ms = time.Duration(timeoutMs) * time.Millisecond
57 }
58 }
59
60 // About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
61 if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
62 if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
63 ms = time.Duration(timeoutSeconds) * time.Second
64 } else {
65 // Try parsing as HTTP date
66 if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
67 ms = time.Until(t)
68 }
69 }
70 }
71
72 // Check that the delay is reasonable:
73 // 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
74 if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
75 return ms
76 }
77
78 return exponentialBackoffDelay
79}
80
// isAbortError reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded.
func isAbortError(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
85
86// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
87// a failed operation with exponential backoff, while respecting rate limit headers
88// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
89func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
90 return func(ctx context.Context, fn RetryFn[T]) (T, error) {
91 return retryWithExponentialBackoff(ctx, fn, options, nil)
92 }
93}
94
// RetryOptions configures the retry behavior.
type RetryOptions struct {
	// MaxRetries is the number of retries allowed after the first attempt.
	// When it is 0 the first error is returned as-is, without wrapping.
	MaxRetries int

	// InitialDelayIn is the backoff delay before the first retry. After each
	// retry it is multiplied by BackoffFactor.
	InitialDelayIn time.Duration

	// BackoffFactor is the multiplier applied to the backoff delay after
	// every retry.
	BackoffFactor float64

	// OnRetry, when non-nil, is called before waiting for a retry.
	OnRetry OnRetryCallback
}
102
// OnRetryCallback is invoked before the wait that precedes a retry. err is the
// retryable API error that triggered the retry, and delay is how long the
// retry logic will wait before calling the operation again.
type OnRetryCallback = func(err *APICallError, delay time.Duration)
105
106// DefaultRetryOptions returns the default retry options.
107// DefaultRetryOptions returns the default retry options.
108func DefaultRetryOptions() RetryOptions {
109 return RetryOptions{
110 MaxRetries: 2,
111 InitialDelayIn: 2000 * time.Millisecond,
112 BackoffFactor: 2.0,
113 }
114}
115
116// retryWithExponentialBackoff implements the retry logic with exponential backoff.
117func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
118 var zero T
119 result, err := fn()
120 if err == nil {
121 return result, nil
122 }
123
124 if isAbortError(err) {
125 return zero, err // don't retry when the request was aborted
126 }
127
128 if options.MaxRetries == 0 {
129 return zero, err // don't wrap the error when retries are disabled
130 }
131
132 errorMessage := GetErrorMessage(err)
133 newErrors := append(allErrors, err)
134 tryNumber := len(newErrors)
135
136 if tryNumber > options.MaxRetries {
137 return zero, NewRetryError(
138 fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
139 RetryReasonMaxRetriesExceeded,
140 newErrors,
141 )
142 }
143
144 var apiErr *APICallError
145 if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
146 delay := getRetryDelayInMs(err, options.InitialDelayIn)
147 if options.OnRetry != nil {
148 options.OnRetry(apiErr, delay)
149 }
150
151 select {
152 case <-time.After(delay):
153 // Continue with retry
154 case <-ctx.Done():
155 return zero, ctx.Err()
156 }
157
158 newOptions := options
159 newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
160
161 return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
162 }
163
164 if tryNumber == 1 {
165 return zero, err // don't wrap the error when a non-retryable error occurs on the first try
166 }
167
168 return zero, NewRetryError(
169 fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
170 RetryReasonErrorNotRetryable,
171 newErrors,
172 )
173}