From 72026be82cde1f768dd30202ae5a47f196abfb1e Mon Sep 17 00:00:00 2001 From: Cristian Date: Tue, 11 Nov 2025 01:48:26 -0800 Subject: [PATCH] fix: honor MaxRetries option and add retry logic to Stream method (#68) --- agent.go | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/agent.go b/agent.go index 52015b5fd90ca988cb8e7af763e38693bfd4c9e1..1f62626725d1ce81b7b6a099344634501ccf21f1 100644 --- a/agent.go +++ b/agent.go @@ -418,6 +418,9 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools) retryOptions := DefaultRetryOptions() + if opts.MaxRetries != nil { + retryOptions.MaxRetries = *opts.MaxRetries + } retryOptions.OnRetry = opts.OnRetry retry := RetryWithExponentialBackoffRespectingRetryHeaders[*Response](retryOptions) @@ -826,8 +829,17 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, ProviderOptions: call.ProviderOptions, } - // Get streaming response - stream, err := stepModel.Stream(ctx, streamCall) + // Get streaming response with retry logic + retryOptions := DefaultRetryOptions() + if call.MaxRetries != nil { + retryOptions.MaxRetries = *call.MaxRetries + } + retryOptions.OnRetry = call.OnRetry + retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions) + + stream, err := retry(ctx, func() (StreamResponse, error) { + return stepModel.Stream(ctx, streamCall) + }) if err != nil { if opts.OnError != nil { opts.OnError(err) @@ -1061,6 +1073,20 @@ func WithRepairToolCall(fn RepairToolCallFunction) AgentOption { } } +// WithMaxRetries sets the maximum number of retries for the agent. +func WithMaxRetries(maxRetries int) AgentOption { + return func(s *agentSettings) { + s.maxRetries = &maxRetries + } +} + +// WithOnRetry sets the retry callback for the agent. +func WithOnRetry(callback OnRetryCallback) AgentOption { + return func(s *agentSettings) { + s.onRetry = callback + } +} + // processStepStream processes a single step's stream and returns the step result. func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) { var stepContent []Content