retry.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strconv"
  8	"time"
  9)
 10
// RetryFn is the operation that a retry strategy invokes. It takes no
// arguments and returns a value of type T together with an error; a nil error
// means the attempt succeeded.
type RetryFn[T any] func() (T, error)
 13
// RetryFunction is a retry strategy: given a context and the operation fn, it
// runs fn (possibly several times) and returns the resulting value and error.
type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
 16
// RetryReason represents the reason why a retry operation failed. It is the
// type of the Reason field of RetryError.
type RetryReason string

const (
	// RetryReasonMaxRetriesExceeded indicates the maximum number of retries was exceeded.
	RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
	// RetryReasonErrorNotRetryable indicates that retrying stopped because an
	// attempt failed with an error that is not retryable.
	RetryReasonErrorNotRetryable RetryReason = "errorNotRetryable"
)
 26
// RetryError represents an error that occurred during retry operations.
//
// It embeds *AIError and adds two fields: Reason records why retrying
// stopped, and Errors holds the errors supplied to NewRetryError (the retry
// logic in this file passes the error of every failed attempt, in order).
type RetryError struct {
	*AIError
	Reason RetryReason
	Errors []error
}
 33
 34// NewRetryError creates a new retry error.
 35func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
 36	return &RetryError{
 37		AIError: NewAIError("AI_RetryError", message, nil),
 38		Reason:  reason,
 39		Errors:  errors,
 40	}
 41}
 42
 43// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
 44func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
 45	var apiErr *APICallError
 46	if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
 47		return exponentialBackoffDelay
 48	}
 49
 50	headers := apiErr.ResponseHeaders
 51	var ms time.Duration
 52
 53	// retry-ms is more precise than retry-after and used by e.g. OpenAI
 54	if retryAfterMs, exists := headers["retry-after-ms"]; exists {
 55		if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
 56			ms = time.Duration(timeoutMs) * time.Millisecond
 57		}
 58	}
 59
 60	// About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
 61	if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
 62		if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
 63			ms = time.Duration(timeoutSeconds) * time.Second
 64		} else {
 65			// Try parsing as HTTP date
 66			if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
 67				ms = time.Until(t)
 68			}
 69		}
 70	}
 71
 72	// Check that the delay is reasonable:
 73	// 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
 74	if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
 75		return ms
 76	}
 77
 78	return exponentialBackoffDelay
 79}
 80
 81// isAbortError checks if the error is a context cancellation error.
 82func isAbortError(err error) bool {
 83	return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
 84}
 85
 86// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
 87// a failed operation with exponential backoff, while respecting rate limit headers
 88// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
 89func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
 90	return func(ctx context.Context, fn RetryFn[T]) (T, error) {
 91		return retryWithExponentialBackoff(ctx, fn, options, nil)
 92	}
 93}
 94
// RetryOptions configures the retry behavior.
//
//   - MaxRetries is the number of retries allowed after the first attempt.
//     When it is 0, retrying is disabled and the first error is returned
//     unwrapped.
//   - InitialDelayIn is the backoff delay used for the first retry when the
//     failed response carries no usable retry header.
//   - BackoffFactor multiplies the backoff delay after every retry.
//   - OnRetry, when non-nil, is called before each retry wait with the
//     retryable error and the delay about to be waited.
type RetryOptions struct {
	MaxRetries     int
	InitialDelayIn time.Duration
	BackoffFactor  float64
	OnRetry        OnRetryCallback
}
102
// OnRetryCallback defines a function that is called when a retry occurs.
// It receives the retryable *APICallError that triggered the retry and the
// delay that will be waited before the next attempt.
type OnRetryCallback = func(err *APICallError, delay time.Duration)
105
106// DefaultRetryOptions returns the default retry options.
107// DefaultRetryOptions returns the default retry options.
108func DefaultRetryOptions() RetryOptions {
109	return RetryOptions{
110		MaxRetries:     2,
111		InitialDelayIn: 2000 * time.Millisecond,
112		BackoffFactor:  2.0,
113	}
114}
115
116// retryWithExponentialBackoff implements the retry logic with exponential backoff.
117func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
118	var zero T
119	result, err := fn()
120	if err == nil {
121		return result, nil
122	}
123
124	if isAbortError(err) {
125		return zero, err // don't retry when the request was aborted
126	}
127
128	if options.MaxRetries == 0 {
129		return zero, err // don't wrap the error when retries are disabled
130	}
131
132	errorMessage := GetErrorMessage(err)
133	newErrors := append(allErrors, err)
134	tryNumber := len(newErrors)
135
136	if tryNumber > options.MaxRetries {
137		return zero, NewRetryError(
138			fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
139			RetryReasonMaxRetriesExceeded,
140			newErrors,
141		)
142	}
143
144	var apiErr *APICallError
145	if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
146		delay := getRetryDelayInMs(err, options.InitialDelayIn)
147		if options.OnRetry != nil {
148			options.OnRetry(apiErr, delay)
149		}
150
151		select {
152		case <-time.After(delay):
153			// Continue with retry
154		case <-ctx.Done():
155			return zero, ctx.Err()
156		}
157
158		newOptions := options
159		newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
160
161		return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
162	}
163
164	if tryNumber == 1 {
165		return zero, err // don't wrap the error when a non-retryable error occurs on the first try
166	}
167
168	return zero, NewRetryError(
169		fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
170		RetryReasonErrorNotRetryable,
171		newErrors,
172	)
173}