retry.go

  1package ai
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strconv"
  8	"time"
  9)
 10
 11// RetryFn is a function that returns a value and an error.
 12type RetryFn[T any] func() (T, error)
 13
 14// RetryFunction is a function that retries another function.
 15type RetryFunction[T any] func(ctx context.Context, fn RetryFn[T]) (T, error)
 16
 17// RetryReason represents the reason why a retry operation failed.
 18type RetryReason string
 19
 20const (
 21	RetryReasonMaxRetriesExceeded RetryReason = "maxRetriesExceeded"
 22	RetryReasonErrorNotRetryable  RetryReason = "errorNotRetryable"
 23)
 24
 25// RetryError represents an error that occurred during retry operations.
 26type RetryError struct {
 27	*AIError
 28	Reason RetryReason
 29	Errors []error
 30}
 31
 32// NewRetryError creates a new retry error.
 33func NewRetryError(message string, reason RetryReason, errors []error) *RetryError {
 34	return &RetryError{
 35		AIError: NewAIError("AI_RetryError", message, nil),
 36		Reason:  reason,
 37		Errors:  errors,
 38	}
 39}
 40
 41// getRetryDelayInMs calculates the retry delay based on error headers and exponential backoff.
 42func getRetryDelayInMs(err error, exponentialBackoffDelay time.Duration) time.Duration {
 43	var apiErr *APICallError
 44	if !errors.As(err, &apiErr) || apiErr.ResponseHeaders == nil {
 45		return exponentialBackoffDelay
 46	}
 47
 48	headers := apiErr.ResponseHeaders
 49	var ms time.Duration
 50
 51	// retry-ms is more precise than retry-after and used by e.g. OpenAI
 52	if retryAfterMs, exists := headers["retry-after-ms"]; exists {
 53		if timeoutMs, err := strconv.ParseFloat(retryAfterMs, 64); err == nil {
 54			ms = time.Duration(timeoutMs) * time.Millisecond
 55		}
 56	}
 57
 58	// About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
 59	if retryAfter, exists := headers["retry-after"]; exists && ms == 0 {
 60		if timeoutSeconds, err := strconv.ParseFloat(retryAfter, 64); err == nil {
 61			ms = time.Duration(timeoutSeconds) * time.Second
 62		} else {
 63			// Try parsing as HTTP date
 64			if t, err := time.Parse(time.RFC1123, retryAfter); err == nil {
 65				ms = time.Until(t)
 66			}
 67		}
 68	}
 69
 70	// Check that the delay is reasonable:
 71	// 0 <= ms < 60 seconds or ms < exponentialBackoffDelay
 72	if ms > 0 && (ms < 60*time.Second || ms < exponentialBackoffDelay) {
 73		return ms
 74	}
 75
 76	return exponentialBackoffDelay
 77}
 78
 79// isAbortError checks if the error is a context cancellation error.
 80func isAbortError(err error) bool {
 81	return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
 82}
 83
 84// RetryWithExponentialBackoffRespectingRetryHeaders creates a retry function that retries
 85// a failed operation with exponential backoff, while respecting rate limit headers
 86// (retry-after-ms and retry-after) if they are provided and reasonable (0-60 seconds).
 87func RetryWithExponentialBackoffRespectingRetryHeaders[T any](options RetryOptions) RetryFunction[T] {
 88	return func(ctx context.Context, fn RetryFn[T]) (T, error) {
 89		return retryWithExponentialBackoff(ctx, fn, options, nil)
 90	}
 91}
 92
 93// RetryOptions configures the retry behavior.
 94type RetryOptions struct {
 95	MaxRetries     int
 96	InitialDelayIn time.Duration
 97	BackoffFactor  float64
 98	OnRetry        OnRetryCallback
 99}
100
101type OnRetryCallback = func(err *APICallError, delay time.Duration)
102
103// DefaultRetryOptions returns the default retry options.
104func DefaultRetryOptions() RetryOptions {
105	return RetryOptions{
106		MaxRetries:     2,
107		InitialDelayIn: 2000 * time.Millisecond,
108		BackoffFactor:  2.0,
109	}
110}
111
112// retryWithExponentialBackoff implements the retry logic with exponential backoff.
113func retryWithExponentialBackoff[T any](ctx context.Context, fn RetryFn[T], options RetryOptions, allErrors []error) (T, error) {
114	var zero T
115	result, err := fn()
116	if err == nil {
117		return result, nil
118	}
119
120	if isAbortError(err) {
121		return zero, err // don't retry when the request was aborted
122	}
123
124	if options.MaxRetries == 0 {
125		return zero, err // don't wrap the error when retries are disabled
126	}
127
128	errorMessage := GetErrorMessage(err)
129	newErrors := append(allErrors, err)
130	tryNumber := len(newErrors)
131
132	if tryNumber > options.MaxRetries {
133		return zero, NewRetryError(
134			fmt.Sprintf("Failed after %d attempts. Last error: %s", tryNumber, errorMessage),
135			RetryReasonMaxRetriesExceeded,
136			newErrors,
137		)
138	}
139
140	var apiErr *APICallError
141	if errors.As(err, &apiErr) && apiErr.IsRetryable && tryNumber <= options.MaxRetries {
142		delay := getRetryDelayInMs(err, options.InitialDelayIn)
143		if options.OnRetry != nil {
144			options.OnRetry(apiErr, delay)
145		}
146
147		select {
148		case <-time.After(delay):
149			// Continue with retry
150		case <-ctx.Done():
151			return zero, ctx.Err()
152		}
153
154		newOptions := options
155		newOptions.InitialDelayIn = time.Duration(float64(options.InitialDelayIn) * options.BackoffFactor)
156
157		return retryWithExponentialBackoff(ctx, fn, newOptions, newErrors)
158	}
159
160	if tryNumber == 1 {
161		return zero, err // don't wrap the error when a non-retryable error occurs on the first try
162	}
163
164	return zero, NewRetryError(
165		fmt.Sprintf("Failed after %d attempts with non-retryable error: '%s'", tryNumber, errorMessage),
166		RetryReasonErrorNotRetryable,
167		newErrors,
168	)
169}
170