fix: retry logic for stream processing errors in agent `Stream` function (#75)

Cristian created

Change summary

agent.go | 86 ++++++++++++++++++++++++++++++++-------------------------
1 file changed, 48 insertions(+), 38 deletions(-)

Detailed changes

agent.go 🔗

@@ -16,6 +16,12 @@ type StepResult struct {
 	Messages []Message
 }
 
+// stepExecutionResult encapsulates the result of executing a step with stream processing.
+type stepExecutionResult struct {
+	StepResult     StepResult
+	ShouldContinue bool
+}
+
 // StopCondition defines a function that determines when an agent should stop executing.
 type StopCondition = func(steps []StepResult) bool
 
@@ -736,6 +742,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 		ActiveTools:      opts.ActiveTools,
 		ProviderOptions:  opts.ProviderOptions,
 		MaxRetries:       opts.MaxRetries,
+		OnRetry:          opts.OnRetry,
 		StopWhen:         opts.StopWhen,
 		PrepareStep:      opts.PrepareStep,
 		RepairToolCall:   opts.RepairToolCall,
@@ -829,26 +836,29 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			ProviderOptions:  call.ProviderOptions,
 		}
 
-		// Get streaming response with retry logic
+		// Execute step with retry logic wrapping both stream creation and processing
 		retryOptions := DefaultRetryOptions()
 		if call.MaxRetries != nil {
 			retryOptions.MaxRetries = *call.MaxRetries
 		}
 		retryOptions.OnRetry = call.OnRetry
-		retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
+		retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
 
-		stream, err := retry(ctx, func() (StreamResponse, error) {
-			return stepModel.Stream(ctx, streamCall)
-		})
-		if err != nil {
-			if opts.OnError != nil {
-				opts.OnError(err)
+		result, err := retry(ctx, func() (stepExecutionResult, error) {
+			// Create the stream
+			stream, err := stepModel.Stream(ctx, streamCall)
+			if err != nil {
+				return stepExecutionResult{}, err
+			}
+
+			// Process the stream
+			result, err := a.processStepStream(ctx, stream, opts, steps)
+			if err != nil {
+				return stepExecutionResult{}, err
 			}
-			return nil, err
-		}
 
-		// Process stream with tool execution
-		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
+			return result, nil
+		})
 		if err != nil {
 			if opts.OnError != nil {
 				opts.OnError(err)
@@ -856,21 +866,21 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
 			return nil, err
 		}
 
-		steps = append(steps, stepResult)
-		totalUsage = addUsage(totalUsage, stepResult.Usage)
+		steps = append(steps, result.StepResult)
+		totalUsage = addUsage(totalUsage, result.StepResult.Usage)
 
 		// Call step finished callback
 		if opts.OnStepFinish != nil {
-			_ = opts.OnStepFinish(stepResult)
+			_ = opts.OnStepFinish(result.StepResult)
 		}
 
 		// Add step messages to response messages
-		stepMessages := toResponseMessages(stepResult.Content)
+		stepMessages := toResponseMessages(result.StepResult.Content)
 		responseMessages = append(responseMessages, stepMessages...)
 
 		// Check stop conditions
 		shouldStop := isStopConditionMet(call.StopWhen, steps)
-		if shouldStop || !shouldContinue {
+		if shouldStop || !result.ShouldContinue {
 			break
 		}
 	}
@@ -1088,7 +1098,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
 }
 
 // processStepStream processes a single step's stream and returns the step result.
-func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
+func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
 	var stepContent []Content
 	var stepToolCalls []ToolCallContent
 	var stepUsage Usage
@@ -1110,7 +1120,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		if opts.OnChunk != nil {
 			err := opts.OnChunk(part)
 			if err != nil {
-				return StepResult{}, false, err
+				return stepExecutionResult{}, err
 			}
 		}
 
@@ -1120,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnWarnings != nil {
 				err := opts.OnWarnings(part.Warnings)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1129,7 +1139,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnTextStart != nil {
 				err := opts.OnTextStart(part.ID)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1140,7 +1150,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnTextDelta != nil {
 				err := opts.OnTextDelta(part.ID, part.Delta)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1155,7 +1165,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnTextEnd != nil {
 				err := opts.OnTextEnd(part.ID)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1168,7 +1178,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				}
 				err := opts.OnReasoningStart(part.ID, content)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1181,7 +1191,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnReasoningDelta != nil {
 				err := opts.OnReasoningDelta(part.ID, part.Delta)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1198,7 +1208,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				if opts.OnReasoningEnd != nil {
 					err := opts.OnReasoningEnd(part.ID, content)
 					if err != nil {
-						return StepResult{}, false, err
+						return stepExecutionResult{}, err
 					}
 				}
 				delete(activeReasoningContent, part.ID)
@@ -1214,7 +1224,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnToolInputStart != nil {
 				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1225,7 +1235,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnToolInputDelta != nil {
 				err := opts.OnToolInputDelta(part.ID, part.Delta)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1233,7 +1243,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnToolInputEnd != nil {
 				err := opts.OnToolInputEnd(part.ID)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1254,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnToolCall != nil {
 				err := opts.OnToolCall(validatedToolCall)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1273,7 +1283,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnSource != nil {
 				err := opts.OnSource(sourceContent)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
@@ -1284,15 +1294,12 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			if opts.OnStreamFinish != nil {
 				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
 				if err != nil {
-					return StepResult{}, false, err
+					return stepExecutionResult{}, err
 				}
 			}
 
 		case StreamPartTypeError:
-			if opts.OnError != nil {
-				opts.OnError(part.Error)
-			}
-			return StepResult{}, false, part.Error
+			return stepExecutionResult{}, part.Error
 		}
 	}
 
@@ -1302,7 +1309,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 		var err error
 		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
 		if err != nil {
-			return StepResult{}, false, err
+			return stepExecutionResult{}, err
 		}
 		// Add tool results to content
 		for _, result := range toolResults {
@@ -1324,7 +1331,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	// Determine if we should continue (has tool calls and not stopped)
 	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
 
-	return stepResult, shouldContinue, nil
+	return stepExecutionResult{
+		StepResult:     stepResult,
+		ShouldContinue: shouldContinue,
+	}, nil
 }
 
 func addUsage(a, b Usage) Usage {