@@ -16,6 +16,12 @@ type StepResult struct {
Messages []Message
}
+// stepExecutionResult encapsulates the result of executing a step with stream processing.
+type stepExecutionResult struct {
+ StepResult StepResult
+ ShouldContinue bool
+}
+
// StopCondition defines a function that determines when an agent should stop executing.
type StopCondition = func(steps []StepResult) bool
@@ -736,6 +742,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
ActiveTools: opts.ActiveTools,
ProviderOptions: opts.ProviderOptions,
MaxRetries: opts.MaxRetries,
+ OnRetry: opts.OnRetry,
StopWhen: opts.StopWhen,
PrepareStep: opts.PrepareStep,
RepairToolCall: opts.RepairToolCall,
@@ -829,26 +836,29 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
ProviderOptions: call.ProviderOptions,
}
- // Get streaming response with retry logic
+ // Execute step with retry logic wrapping both stream creation and processing
retryOptions := DefaultRetryOptions()
if call.MaxRetries != nil {
retryOptions.MaxRetries = *call.MaxRetries
}
retryOptions.OnRetry = call.OnRetry
- retry := RetryWithExponentialBackoffRespectingRetryHeaders[StreamResponse](retryOptions)
+ retry := RetryWithExponentialBackoffRespectingRetryHeaders[stepExecutionResult](retryOptions)
- stream, err := retry(ctx, func() (StreamResponse, error) {
- return stepModel.Stream(ctx, streamCall)
- })
- if err != nil {
- if opts.OnError != nil {
- opts.OnError(err)
+ result, err := retry(ctx, func() (stepExecutionResult, error) {
+ // Create the stream
+ stream, err := stepModel.Stream(ctx, streamCall)
+ if err != nil {
+ return stepExecutionResult{}, err
+ }
+
+ // Process the stream
+ result, err := a.processStepStream(ctx, stream, opts, steps)
+ if err != nil {
+ return stepExecutionResult{}, err
}
- return nil, err
- }
- // Process stream with tool execution
- stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
+ return result, nil
+ })
if err != nil {
if opts.OnError != nil {
opts.OnError(err)
@@ -856,21 +866,21 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
return nil, err
}
- steps = append(steps, stepResult)
- totalUsage = addUsage(totalUsage, stepResult.Usage)
+ steps = append(steps, result.StepResult)
+ totalUsage = addUsage(totalUsage, result.StepResult.Usage)
// Call step finished callback
if opts.OnStepFinish != nil {
- _ = opts.OnStepFinish(stepResult)
+ _ = opts.OnStepFinish(result.StepResult)
}
// Add step messages to response messages
- stepMessages := toResponseMessages(stepResult.Content)
+ stepMessages := toResponseMessages(result.StepResult.Content)
responseMessages = append(responseMessages, stepMessages...)
// Check stop conditions
shouldStop := isStopConditionMet(call.StopWhen, steps)
- if shouldStop || !shouldContinue {
+ if shouldStop || !result.ShouldContinue {
break
}
}
@@ -1088,7 +1098,7 @@ func WithOnRetry(callback OnRetryCallback) AgentOption {
}
// processStepStream processes a single step's stream and returns the step result.
-func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
+func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (stepExecutionResult, error) {
var stepContent []Content
var stepToolCalls []ToolCallContent
var stepUsage Usage
@@ -1110,7 +1120,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnChunk != nil {
err := opts.OnChunk(part)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1120,7 +1130,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnWarnings != nil {
err := opts.OnWarnings(part.Warnings)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1129,7 +1139,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnTextStart != nil {
err := opts.OnTextStart(part.ID)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1140,7 +1150,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnTextDelta != nil {
err := opts.OnTextDelta(part.ID, part.Delta)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1155,7 +1165,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnTextEnd != nil {
err := opts.OnTextEnd(part.ID)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1168,7 +1178,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
err := opts.OnReasoningStart(part.ID, content)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1181,7 +1191,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnReasoningDelta != nil {
err := opts.OnReasoningDelta(part.ID, part.Delta)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1198,7 +1208,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnReasoningEnd != nil {
err := opts.OnReasoningEnd(part.ID, content)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
delete(activeReasoningContent, part.ID)
@@ -1214,7 +1224,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnToolInputStart != nil {
err := opts.OnToolInputStart(part.ID, part.ToolCallName)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1225,7 +1235,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnToolInputDelta != nil {
err := opts.OnToolInputDelta(part.ID, part.Delta)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1233,7 +1243,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnToolInputEnd != nil {
err := opts.OnToolInputEnd(part.ID)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1254,7 +1264,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnToolCall != nil {
err := opts.OnToolCall(validatedToolCall)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1273,7 +1283,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnSource != nil {
err := opts.OnSource(sourceContent)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
@@ -1284,15 +1294,12 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
if opts.OnStreamFinish != nil {
err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
}
case StreamPartTypeError:
- if opts.OnError != nil {
- opts.OnError(part.Error)
- }
- return StepResult{}, false, part.Error
+ return stepExecutionResult{}, part.Error
}
}
@@ -1302,7 +1309,7 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
var err error
toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
if err != nil {
- return StepResult{}, false, err
+ return stepExecutionResult{}, err
}
// Add tool results to content
for _, result := range toolResults {
@@ -1324,7 +1331,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
// Determine if we should continue (has tool calls and not stopped)
shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
- return stepResult, shouldContinue, nil
+ return stepExecutionResult{
+ StepResult: stepResult,
+ ShouldContinue: shouldContinue,
+ }, nil
}
func addUsage(a, b Usage) Usage {