retry.go

  1package fantasy
  2
import (
	"context"
	"errors"
	"net/http"
	"strconv"
	"time"
)
  9
 10// RetryFn is a function that returns a value and an error.
 11type RetryFn[T any] func() (T, error)
 12
 13// RetryFunction is a function that retries another function.
 14type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
 15
 16// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
 17func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
 18	var providerErr *ProviderError
 19	if !errors.As(err, &providerErr) || providerErr.ResponseHeaders == nil {
 20		return exponentialBackoffDelay
 21	}
 22
 23	headers := providerErr.ResponseHeaders
 24	var ms time.Duration
 25
 26	// retry-ms is more precise than retry-after and used by e.g. OpenAI
 27	if retryAfterMs, exists := headers["retry-after-ms"]; exists {
 28		if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
 29			ms = time.Duration(timeoutMs) * time.Millisecond
 30		}
 31	}
 32
 33	// About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
 34	if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
 35		if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
 36			ms = time.Duration(timeoutSeconds) * time.Second
 37		} else {
 38			// Try parsing as HTTP date
 39			if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
 40				ms = time.Until(t)
 41			}
 42		}
 43	}
 44
 45	// Check that the delay is reasonable:
 46	// 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
 47	if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
 48		return ms
 49	}
 50
 51	return exponentialBackoffDelay
 52}
 53
 54// isAbortError checks if the error is a context cancellation error.
 55func isAbortError(err error) bool {
 56	return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
 57}
 58
 59// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
 60// a failed operation with exponential backoff, while respecting rate limit headers
 61// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
 62func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
 63	return func(ctx context.Context, fn RetryFn[T]) (T, error) {
 64		return retryWithExponentialBackoff(ctx, fn, options, nil)
 65	}
 66}
 67
// RetryOptions configures the retry behavior.
//
// Fields:
//   - MaxRetries: how many times a failed call may be retried after the first
//     attempt, so the operation runs at most MaxRetries+1 times. Zero disables
//     retrying, and the operation's error is returned unwrapped.
//   - InitialDelayIn: backoff delay before the first retry. A usable
//     "retry-after-ms"/"retry-after" response header may override the delay
//     actually waited.
//   - BackoffFactor: multiplier applied to the backoff delay after each retry.
//   - OnRetry: optional. When non-nil, it is called before each wait with the
//     retryable provider error and the delay about to be waited.
type RetryOptions struct {
	MaxRetries     int
	InitialDelayIn time.Duration
	BackoffFactor  float64
	OnRetry        OnRetryCallback
}
 75
 76// OnRetryCallback defines a function that is called when a retry occurs.
 77type OnRetryCallback = func(err *ProviderError, delay time.Duration)
 78
 79// DefaultRetryOptions returns the default retry options.
 80// DefaultRetryOptions returns the default retry options.
 81func DefaultRetryOptions() RetryOptions {
 82	return RetryOptions{
 83		MaxRetries:     2,
 84		InitialDelayIn: 2000 * time.Millisecond,
 85		BackoffFactor:  2.0,
 86	}
 87}
 88
 89// retryWithExponentialBackoff implements the retry logic with exponential backoff.
 90func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
 91	var zero T
 92	result, err := fn()
 93	if err == nil {
 94		return result, nil
 95	}
 96
 97	if isAbortError(err) {
 98		return zero, err // don't retry when the request was aborted
 99	}
100
101	if options.MaxRetries == 0 {
102		return zero, err // don't wrap the error when retries are disabled
103	}
104
105	newErrors := append(allErrors, err)
106	tryNumber := len(newErrors)
107
108	if tryNumber > options.MaxRetries {
109		return zero, &RetryError{newErrors}
110	}
111
112	var providerErr *ProviderError
113	if errors.As(err, &providerErr) && providerErr.IsRetryable() && tryNumber <= options.MaxRetries {
114		delay := getRetryDelayInMs(err, options.InitialDelayIn)
115		if options.OnRetry != nil {
116			options.OnRetry(providerErr, delay)
117		}
118
119		select {
120		case <-time.After(delay):
121			// Continue with retry
122		case <-ctx.Done():
123			return zero, ctx.Err()
124		}
125
126		newOptions := options
127		newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
128
129		return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
130	}
131
132	if tryNumber == 1 {
133		return zero, err // don't wrap the error when a non-retryable error occurs on the first try
134	}
135
136	return zero, &RetryError{newErrors}
137}