1package ai
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strconv"
8 "time"
9)
10
// RetryFn is a single attempt of a retryable operation: it produces a value of
// type T, or a non-nil error when the attempt failed.
type RetryFn[T any] func() (value T, err error)
13
14// RetryFunction is a function that retries another function.
15type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
16
// RetryReason names why a retry operation gave up; it is reported in
// RetryError.Reason.
type RetryReason string

const (
	// RetryReasonMaxRetriesExceeded: every allowed attempt failed.
	RetryReasonMaxRetriesExceeded = RetryReason("maxRetriesExceeded")
	// RetryReasonErrorNotRetryable: an attempt failed with an error that must
	// not be retried.
	RetryReasonErrorNotRetryable = RetryReason("errorNotRetryable")
)
24
// RetryError represents an error that occurred during retry operations.
// It is produced when retrying gives up, and records why (Reason) together
// with the error of every failed attempt (Errors).
type RetryError struct {
	// AIError holds the error name and message; NewRetryError sets the name
	// to "AI_RetryError". NOTE(review): presumably this embedded pointer also
	// supplies the Error() method via promotion — confirm against AIError.
	*AIError
	// Reason tells why retrying stopped.
	Reason RetryReason
	// Errors lists the error of each failed attempt, oldest first.
	Errors []error
}
31
32// NewRetryError creates a new retry error.
33func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
34 return &RetryError{
35 AIError: NewAIError("AI_RetryError", message, nil),
36 Reason: reason,
37 Errors: errors,
38 }
39}
40
41// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
42func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
43 var apiErr *APICallError
44 if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
45 return exponentialBackoffDelay
46 }
47
48 headers := apiErr.ResponseHeaders
49 var ms time.Duration
50
51 // retry-ms is more precise than retry-after and used by e.g. OpenAI
52 if retryAfterMs, exists := headers["retry-after-ms"]; exists {
53 if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
54 ms = time.Duration(timeoutMs) * time.Millisecond
55 }
56 }
57
58 // About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
59 if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
60 if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
61 ms = time.Duration(timeoutSeconds) * time.Second
62 } else {
63 // Try parsing as HTTP date
64 if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
65 ms = time.Until(t)
66 }
67 }
68 }
69
70 // Check that the delay is reasonable:
71 // 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
72 if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
73 return ms
74 }
75
76 return exponentialBackoffDelay
77}
78
// isAbortError reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. The retry loop uses it to avoid retrying requests
// that were aborted.
func isAbortError(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
83
84// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
85// a failed operation with exponential backoff, while respecting rate limit headers
86// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
87func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
88 return func(ctx context.Context, fn RetryFn[T]) (T, error) {
89 return retryWithExponentialBackoff(ctx, fn, options, nil)
90 }
91}
92
// RetryOptions configures the retry behavior.
type RetryOptions struct {
	// MaxRetries is how many additional attempts may follow the first failed
	// one; 0 disables retrying (the first error is then returned unwrapped).
	MaxRetries int
	// InitialDelayIn is the backoff delay used before the first retry when no
	// acceptable retry header delay is available; it is multiplied by
	// BackoffFactor after every retry.
	InitialDelayIn time.Duration
	// BackoffFactor scales the backoff delay after each retry.
	BackoffFactor float64
	// OnRetry, if non-nil, is called with the retryable API error and the
	// chosen delay just before each wait.
	OnRetry OnRetryCallback
}
100
101type OnRetryCallback = func(err *APICallError, delay time.Duration)
102
103// DefaultRetryOptions returns the default retry options.
104func DefaultRetryOptions() RetryOptions {
105 return RetryOptions{
106 MaxRetries: 2,
107 InitialDelayIn: 2000 * time.Millisecond,
108 BackoffFactor: 2.0,
109 }
110}
111
112// retryWithExponentialBackoff implements the retry logic with exponential backoff.
113func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
114 var zero T
115 result, err := fn()
116 if err == nil {
117 return result, nil
118 }
119
120 if isAbortError(err) {
121 return zero, err // don't retry when the request was aborted
122 }
123
124 if options.MaxRetries == 0 {
125 return zero, err // don't wrap the error when retries are disabled
126 }
127
128 errorMessage := GetErrorMessage(err)
129 newErrors := append(allErrors, err)
130 tryNumber := len(newErrors)
131
132 if tryNumber > options.MaxRetries {
133 return zero, NewRetryError(
134 fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
135 RetryReasonMaxRetriesExceeded,
136 newErrors,
137 )
138 }
139
140 var apiErr *APICallError
141 if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
142 delay := getRetryDelayInMs(err, options.InitialDelayIn)
143 if options.OnRetry != nil {
144 options.OnRetry(apiErr, delay)
145 }
146
147 select {
148 case <-time.After(delay):
149 // Continue with retry
150 case <-ctx.Done():
151 return zero, ctx.Err()
152 }
153
154 newOptions := options
155 newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
156
157 return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
158 }
159
160 if tryNumber == 1 {
161 return zero, err // don't wrap the error when a non-retryable error occurs on the first try
162 }
163
164 return zero, NewRetryError(
165 fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
166 RetryReasonErrorNotRetryable,
167 newErrors,
168 )
169}
170