From ddf3ee09696961bc26a5a49ac6412b970e1d3051 Mon Sep 17 00:00:00 2001 From: Cristian Date: Fri, 21 Nov 2025 05:49:00 -0800 Subject: [PATCH] fix: retry logic for stream processing errors in agent `Stream` function (#75) --- agent.go | 86 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/agent.go b/agent.go index 1f62626725d1ce81b7b6a099344634501ccf21f1..0a81daa3e9b81454ec7297de2b63d3c54ac4b77a 100644 --- a/agent.go +++ b/agent.go @@ -16,6 +16,12 @@ type StepResult struct { Messages []Message } +// stepExecutionResult encapsulates the result of executing a step with stream processing. +type stepExecutionResult struct { + StepResult StepResult + ShouldContinue bool +} + // StopCondition defines a function that determines when an agent should stop executing. type StopCondition = func(steps []StepResult) bool @@ -736,6 +742,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, ActiveTools: opts.ActiveTools, ProviderOptions: opts.ProviderOptions, MaxRetries: opts.MaxRetries, + OnRetry: opts.OnRetry, StopWhen: opts.StopWhen, PrepareStep: opts.PrepareStep, RepairToolCall: opts.RepairToolCall, @@ -829,26 +836,29 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, ProviderOptions: call.ProviderOptions, } - // Get streaming response with retry logic + // Execute step with retry logic wrapping both stream creation and processing retryOptions := DefaultRetryOptions() if call.MaxRetries != nil { retryOptions.MaxRetries = *call.MaxRetries } retryOptions.OnRetry = call.OnRetry - retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions) + retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions) - stream, err := retry(ctx, func() (StreamResponse, error) { - return stepModel.Stream(ctx, streamCall) - }) - if err != nil { - if opts.OnError != nil { - opts.OnError(err) + result, err := retry(ctx, func() (stepExecutionResult, error) { + // Create the stream + stream, err := stepModel.Stream(ctx, streamCall) + if err != nil { + return stepExecutionResult{}, err + } + + // Process the stream + result, err := a.processStepStream(ctx, stream, opts, steps) + if err != nil { + return stepExecutionResult{}, err } - return nil, err - } - // Process stream with tool execution - stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps) + return result, nil + }) if err != nil { if opts.OnError != nil { opts.OnError(err) @@ -856,21 +866,21 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, return nil, err } - steps = append(steps, stepResult) - totalUsage = addUsage(totalUsage, stepResult.Usage) + steps = append(steps, result.StepResult) + totalUsage = addUsage(totalUsage, result.StepResult.Usage) // Call step finished callback if opts.OnStepFinish != nil { - _ = opts.OnStepFinish(stepResult) + _ = opts.OnStepFinish(result.StepResult) } // Add step messages to response messages - stepMessages := toResponseMessages(stepResult.Content) + stepMessages := toResponseMessages(result.StepResult.Content) responseMessages = append(responseMessages, stepMessages...) // Check stop conditions shouldStop := isStopConditionMet(call.StopWhen, steps) - if shouldStop || !shouldContinue { + if shouldStop || !result.ShouldContinue { break } } @@ -1088,7 +1098,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption { } // processStepStream processes a single step's stream and returns the step result. -func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) { +func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) { var stepContent []Content var stepToolCalls []ToolCallContent var stepUsage Usage @@ -1110,7 +1120,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnChunk != nil { err := opts.OnChunk(part) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1120,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnWarnings != nil { err := opts.OnWarnings(part.Warnings) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1129,7 +1139,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnTextStart != nil { err := opts.OnTextStart(part.ID) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1140,7 +1150,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnTextDelta != nil { err := opts.OnTextDelta(part.ID, part.Delta) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1155,7 +1165,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnTextEnd != nil { err := opts.OnTextEnd(part.ID) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1168,7 +1178,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } err := opts.OnReasoningStart(part.ID, content) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1181,7 +1191,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnReasoningDelta != nil { err := opts.OnReasoningDelta(part.ID, part.Delta) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1198,7 +1208,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnReasoningEnd != nil { err := opts.OnReasoningEnd(part.ID, content) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } delete(activeReasoningContent, part.ID) @@ -1214,7 +1224,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnToolInputStart != nil { err := opts.OnToolInputStart(part.ID, part.ToolCallName) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1225,7 +1235,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnToolInputDelta != nil { err := opts.OnToolInputDelta(part.ID, part.Delta) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1233,7 +1243,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnToolInputEnd != nil { err := opts.OnToolInputEnd(part.ID) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1254,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnToolCall != nil { err := opts.OnToolCall(validatedToolCall) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1273,7 +1283,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnSource != nil { err := opts.OnSource(sourceContent) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } @@ -1284,15 +1294,12 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op if opts.OnStreamFinish != nil { err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } } case StreamPartTypeError: - if opts.OnError != nil { - opts.OnError(part.Error) - } - return StepResult{}, false, part.Error + return stepExecutionResult{}, part.Error } } @@ -1302,7 +1309,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op var err error toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) if err != nil { - return StepResult{}, false, err + return stepExecutionResult{}, err } // Add tool results to content for _, result := range toolResults { @@ -1324,7 +1331,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op // Determine if we should continue (has tool calls and not stopped) shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls - return stepResult, shouldContinue, nil + return stepExecutionResult{ + StepResult: stepResult, + ShouldContinue: shouldContinue, + }, nil } func addUsage(a, b Usage) Usage {