Detailed changes
@@ -182,22 +182,21 @@ type AgentStreamCall struct {
OnError func(error) // Called when an error occurs
// Stream part callbacks - called for each corresponding stream part type
- OnChunk func(StreamPart) // Called for each stream part (catch-all)
- OnWarnings func(warnings []CallWarning) // Called for warnings
- OnTextStart func(id string) // Called when text starts
- OnTextDelta func(id, text string) // Called for text deltas
- OnTextEnd func(id string) // Called when text ends
- OnReasoningStart func(id string) // Called when reasoning starts
- OnReasoningDelta func(id, text string) // Called for reasoning deltas
- OnReasoningEnd func(id string, reasoning ReasoningContent) // Called when reasoning ends
- OnToolInputStart func(id, toolName string) // Called when tool input starts
- OnToolInputDelta func(id, delta string) // Called for tool input deltas
- OnToolInputEnd func(id string) // Called when tool input ends
- OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete
- OnToolResult func(result ToolResultContent) // Called when tool execution completes
- OnSource func(source SourceContent) // Called for source references
- OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes
- OnStreamError func(error) // Called when stream error occurs
+ OnChunk func(StreamPart) error // Called for each stream part (catch-all)
+ OnWarnings func(warnings []CallWarning) error // Called for warnings
+ OnTextStart func(id string) error // Called when text starts
+ OnTextDelta func(id, text string) error // Called for text deltas
+ OnTextEnd func(id string) error // Called when text ends
+ OnReasoningStart func(id string) error // Called when reasoning starts
+ OnReasoningDelta func(id, text string) error // Called for reasoning deltas
+ OnReasoningEnd func(id string, reasoning ReasoningContent) error // Called when reasoning ends
+ OnToolInputStart func(id, toolName string) error // Called when tool input starts
+ OnToolInputDelta func(id, delta string) error // Called for tool input deltas
+ OnToolInputEnd func(id string) error // Called when tool input ends
+ OnToolCall func(toolCall ToolCallContent) error // Called when tool call is complete
+ OnToolResult func(result ToolResultContent) error // Called when tool execution completes
+ OnSource func(source SourceContent) error // Called for source references
+ OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
}
type AgentResult struct {
@@ -540,7 +539,7 @@ func toResponseMessages(content []Content) []Message {
return messages
}
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
@@ -553,7 +552,8 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
// Execute all tool calls in parallel
results := make([]ToolResultContent, len(toolCalls))
- var toolExecutionError error
+ executeErrors := make([]error, len(toolCalls))
+
var wg sync.WaitGroup
for i, toolCall := range toolCalls {
@@ -572,7 +572,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
ProviderExecuted: false,
}
if toolResultCallback != nil {
- toolResultCallback(results[index])
+ err := toolResultCallback(results[index])
+ if err != nil {
+ executeErrors[index] = err
+ }
}
return
@@ -590,7 +593,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
}
if toolResultCallback != nil {
- toolResultCallback(results[index])
+ err := toolResultCallback(results[index])
+ if err != nil {
+ executeErrors[index] = err
+ }
}
return
}
@@ -612,9 +618,12 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
ProviderExecuted: false,
}
if toolResultCallback != nil {
- toolResultCallback(results[index])
+ cbErr := toolResultCallback(results[index])
+ if cbErr != nil {
+ executeErrors[index] = cbErr
+ }
}
- toolExecutionError = err
+ executeErrors[index] = err
return
}
@@ -630,7 +639,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
}
if toolResultCallback != nil {
- toolResultCallback(results[index])
+ err := toolResultCallback(results[index])
+ if err != nil {
+ executeErrors[index] = err
+ }
}
} else {
results[index] = ToolResultContent{
@@ -643,7 +655,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
ProviderExecuted: false,
}
if toolResultCallback != nil {
- toolResultCallback(results[index])
+ err := toolResultCallback(results[index])
+ if err != nil {
+ executeErrors[index] = err
+ }
}
}
}(i, toolCall)
@@ -652,7 +667,13 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
// Wait for all tool executions to complete
wg.Wait()
- return results, toolExecutionError
+ for _, err := range executeErrors {
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return results, nil
}
// Stream implements Agent.
@@ -1004,20 +1025,29 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
for part := range stream {
// Forward all parts to chunk callback
if opts.OnChunk != nil {
- opts.OnChunk(part)
+ err := opts.OnChunk(part)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
switch part.Type {
case StreamPartTypeWarnings:
stepWarnings = part.Warnings
if opts.OnWarnings != nil {
- opts.OnWarnings(part.Warnings)
+ err := opts.OnWarnings(part.Warnings)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeTextStart:
activeTextContent[part.ID] = ""
if opts.OnTextStart != nil {
- opts.OnTextStart(part.ID)
+ err := opts.OnTextStart(part.ID)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeTextDelta:
@@ -1025,7 +1055,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
activeTextContent[part.ID] += part.Delta
}
if opts.OnTextDelta != nil {
- opts.OnTextDelta(part.ID, part.Delta)
+ err := opts.OnTextDelta(part.ID, part.Delta)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeTextEnd:
@@ -1037,13 +1070,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
delete(activeTextContent, part.ID)
}
if opts.OnTextEnd != nil {
- opts.OnTextEnd(part.ID)
+ err := opts.OnTextEnd(part.ID)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeReasoningStart:
activeTextContent[part.ID] = ""
if opts.OnReasoningStart != nil {
- opts.OnReasoningStart(part.ID)
+ err := opts.OnReasoningStart(part.ID)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeReasoningDelta:
@@ -1051,7 +1090,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
activeTextContent[part.ID] += part.Delta
}
if opts.OnReasoningDelta != nil {
- opts.OnReasoningDelta(part.ID, part.Delta)
+ err := opts.OnReasoningDelta(part.ID, part.Delta)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeReasoningEnd:
@@ -1061,10 +1103,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
})
if opts.OnReasoningEnd != nil {
- opts.OnReasoningEnd(part.ID, ReasoningContent{
+ err := opts.OnReasoningEnd(part.ID, ReasoningContent{
Text: text,
ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
})
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
delete(activeTextContent, part.ID)
}
@@ -1077,7 +1122,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
ProviderExecuted: part.ProviderExecuted,
}
if opts.OnToolInputStart != nil {
- opts.OnToolInputStart(part.ID, part.ToolCallName)
+ err := opts.OnToolInputStart(part.ID, part.ToolCallName)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeToolInputDelta:
@@ -1085,12 +1133,18 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
toolCall.Input += part.Delta
}
if opts.OnToolInputDelta != nil {
- opts.OnToolInputDelta(part.ID, part.Delta)
+ err := opts.OnToolInputDelta(part.ID, part.Delta)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeToolInputEnd:
if opts.OnToolInputEnd != nil {
- opts.OnToolInputEnd(part.ID)
+ err := opts.OnToolInputEnd(part.ID)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeToolCall:
@@ -1108,7 +1162,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
stepContent = append(stepContent, validatedToolCall)
if opts.OnToolCall != nil {
- opts.OnToolCall(validatedToolCall)
+ err := opts.OnToolCall(validatedToolCall)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
// Clean up active tool call
@@ -1124,7 +1181,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
}
stepContent = append(stepContent, sourceContent)
if opts.OnSource != nil {
- opts.OnSource(sourceContent)
+ err := opts.OnSource(sourceContent)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeFinish:
@@ -1132,13 +1192,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
stepFinishReason = part.FinishReason
stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
if opts.OnStreamFinish != nil {
- opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
+ err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
+ if err != nil {
+ return StepResult{}, false, err
+ }
}
case StreamPartTypeError:
- if opts.OnStreamError != nil {
- opts.OnStreamError(part.Error)
- }
if opts.OnError != nil {
opts.OnError(part.Error)
}
@@ -124,53 +124,65 @@ func TestStreamingAgentCallbacks(t *testing.T) {
OnError: func(err error) {
callbacks["OnError"] = true
},
- OnChunk: func(part StreamPart) {
+ OnChunk: func(part StreamPart) error {
callbacks["OnChunk"] = true
+ return nil
},
- OnWarnings: func(warnings []CallWarning) {
+ OnWarnings: func(warnings []CallWarning) error {
callbacks["OnWarnings"] = true
+ return nil
},
- OnTextStart: func(id string) {
+ OnTextStart: func(id string) error {
callbacks["OnTextStart"] = true
+ return nil
},
- OnTextDelta: func(id, text string) {
+ OnTextDelta: func(id, text string) error {
callbacks["OnTextDelta"] = true
+ return nil
},
- OnTextEnd: func(id string) {
+ OnTextEnd: func(id string) error {
callbacks["OnTextEnd"] = true
+ return nil
},
- OnReasoningStart: func(id string) {
+ OnReasoningStart: func(id string) error {
callbacks["OnReasoningStart"] = true
+ return nil
},
- OnReasoningDelta: func(id, text string) {
+ OnReasoningDelta: func(id, text string) error {
callbacks["OnReasoningDelta"] = true
+ return nil
},
- OnReasoningEnd: func(id string, content ReasoningContent) {
+ OnReasoningEnd: func(id string, content ReasoningContent) error {
callbacks["OnReasoningEnd"] = true
+ return nil
},
- OnToolInputStart: func(id, toolName string) {
+ OnToolInputStart: func(id, toolName string) error {
callbacks["OnToolInputStart"] = true
+ return nil
},
- OnToolInputDelta: func(id, delta string) {
+ OnToolInputDelta: func(id, delta string) error {
callbacks["OnToolInputDelta"] = true
+ return nil
},
- OnToolInputEnd: func(id string) {
+ OnToolInputEnd: func(id string) error {
callbacks["OnToolInputEnd"] = true
+ return nil
},
- OnToolCall: func(toolCall ToolCallContent) {
+ OnToolCall: func(toolCall ToolCallContent) error {
callbacks["OnToolCall"] = true
+ return nil
},
- OnToolResult: func(result ToolResultContent) {
+ OnToolResult: func(result ToolResultContent) error {
callbacks["OnToolResult"] = true
+ return nil
},
- OnSource: func(source SourceContent) {
+ OnSource: func(source SourceContent) error {
callbacks["OnSource"] = true
+ return nil
},
- OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) {
+ OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
callbacks["OnStreamFinish"] = true
- },
- OnStreamError: func(err error) {
- callbacks["OnStreamError"] = true
+ return nil
},
}
@@ -207,7 +219,6 @@ func TestStreamingAgentCallbacks(t *testing.T) {
// Verify that error callbacks were not called
require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
- require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
}
@@ -289,28 +300,33 @@ func TestStreamingAgentWithTools(t *testing.T) {
// Create streaming call with callbacks
streamCall := AgentStreamCall{
Prompt: "Echo 'test'",
- OnToolInputStart: func(id, toolName string) {
+ OnToolInputStart: func(id, toolName string) error {
toolInputStartCalled = true
require.Equal(t, "tool-1", id)
require.Equal(t, "echo", toolName)
+ return nil
},
- OnToolInputDelta: func(id, delta string) {
+ OnToolInputDelta: func(id, delta string) error {
toolInputDeltaCalled = true
require.Equal(t, "tool-1", id)
require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
+ return nil
},
- OnToolInputEnd: func(id string) {
+ OnToolInputEnd: func(id string) error {
toolInputEndCalled = true
require.Equal(t, "tool-1", id)
+ return nil
},
- OnToolCall: func(toolCall ToolCallContent) {
+ OnToolCall: func(toolCall ToolCallContent) error {
toolCallCalled = true
require.Equal(t, "echo", toolCall.ToolName)
require.Equal(t, `{"message": "test"}`, toolCall.Input)
+ return nil
},
- OnToolResult: func(result ToolResultContent) {
+ OnToolResult: func(result ToolResultContent) error {
toolResultCalled = true
require.Equal(t, "echo", result.ToolName)
+ return nil
},
}
@@ -377,10 +393,11 @@ func TestStreamingAgentTextDeltas(t *testing.T) {
streamCall := AgentStreamCall{
Prompt: "Say hello",
- OnTextDelta: func(id, text string) {
+ OnTextDelta: func(id, text string) error {
if text != "" {
textDeltas = append(textDeltas, text)
}
+ return nil
},
}
@@ -438,11 +455,13 @@ func TestStreamingAgentReasoning(t *testing.T) {
streamCall := AgentStreamCall{
Prompt: "Think and respond",
- OnReasoningDelta: func(id, text string) {
+ OnReasoningDelta: func(id, text string) error {
reasoningDeltas = append(reasoningDeltas, text)
+ return nil
},
- OnTextDelta: func(id, text string) {
+ OnTextDelta: func(id, text string) error {
textDeltas = append(textDeltas, text)
+ return nil
},
}
@@ -473,15 +492,12 @@ func TestStreamingAgentError(t *testing.T) {
ctx := context.Background()
// Track error callbacks
- var streamErrorOccurred bool
var errorOccurred bool
var errorMessage string
streamCall := AgentStreamCall{
Prompt: "This will fail",
- OnStreamError: func(err error) {
- streamErrorOccurred = true
- },
+
OnError: func(err error) {
errorOccurred = true
errorMessage = err.Error()
@@ -492,7 +508,6 @@ func TestStreamingAgentError(t *testing.T) {
result, err := agent.Stream(ctx, streamCall)
require.Error(t, err)
require.Nil(t, result)
- require.True(t, streamErrorOccurred, "OnStreamError should have been called")
require.True(t, errorOccurred, "OnError should have been called")
require.Contains(t, errorMessage, "mock stream error")
}
@@ -546,8 +561,9 @@ func TestStreamingAgentSources(t *testing.T) {
streamCall := AgentStreamCall{
Prompt: "Search and respond",
- OnSource: func(source SourceContent) {
+ OnSource: func(source SourceContent) error {
sources = append(sources, source)
+ return nil
},
}
@@ -58,18 +58,21 @@ func main() {
Prompt: "Please echo back 'Hello, streaming world!'",
// Show real-time text as it streams
- OnTextDelta: func(id, text string) {
+ OnTextDelta: func(id, text string) error {
fmt.Print(text)
+ return nil
},
// Show when tools are called
- OnToolCall: func(toolCall ai.ToolCallContent) {
+ OnToolCall: func(toolCall ai.ToolCallContent) error {
fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName)
+ return nil
},
// Show tool results
- OnToolResult: func(result ai.ToolResultContent) {
+ OnToolResult: func(result ai.ToolResultContent) error {
fmt.Printf("[Tool result received]\n")
+ return nil
},
// Show when each step completes
@@ -144,47 +144,58 @@ func main() {
},
// Stream part callbacks
- OnWarnings: func(warnings []ai.CallWarning) {
+ OnWarnings: func(warnings []ai.CallWarning) error {
for _, warning := range warnings {
fmt.Printf("⚠️ Warning: %s\n", warning.Message)
}
+ return nil
},
- OnTextStart: func(id string) {
+ OnTextStart: func(id string) error {
fmt.Print("💭 Assistant: ")
+ return nil
},
- OnTextDelta: func(id, text string) {
+ OnTextDelta: func(id, text string) error {
fmt.Print(text)
textBuffer.WriteString(text)
+ return nil
},
- OnTextEnd: func(id string) {
+ OnTextEnd: func(id string) error {
fmt.Println()
+ return nil
},
- OnReasoningStart: func(id string) {
+ OnReasoningStart: func(id string) error {
fmt.Print("🤔 Thinking: ")
+ return nil
},
- OnReasoningDelta: func(id, text string) {
+ OnReasoningDelta: func(id, text string) error {
reasoningBuffer.WriteString(text)
+ return nil
},
- OnReasoningEnd: func(id string, content ai.ReasoningContent) {
+ OnReasoningEnd: func(id string, content ai.ReasoningContent) error {
if reasoningBuffer.Len() > 0 {
fmt.Printf("%s\n", reasoningBuffer.String())
reasoningBuffer.Reset()
}
+ return nil
},
- OnToolInputStart: func(id, toolName string) {
+ OnToolInputStart: func(id, toolName string) error {
fmt.Printf("🔧 Calling tool: %s\n", toolName)
+ return nil
},
- OnToolInputDelta: func(id, delta string) {
+ OnToolInputDelta: func(id, delta string) error {
// Could show tool input being built, but it's often noisy
+ return nil
},
- OnToolInputEnd: func(id string) {
+ OnToolInputEnd: func(id string) error {
// Tool input complete
+ return nil
},
- OnToolCall: func(toolCall ai.ToolCallContent) {
+ OnToolCall: func(toolCall ai.ToolCallContent) error {
fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
fmt.Printf(" Input: %s\n", toolCall.Input)
+ return nil
},
- OnToolResult: func(result ai.ToolResultContent) {
+ OnToolResult: func(result ai.ToolResultContent) error {
fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
switch output := result.Result.(type) {
case ai.ToolResultOutputContentText:
@@ -192,15 +203,15 @@ func main() {
case ai.ToolResultOutputContentError:
fmt.Printf(" Error: %s\n", output.Error.Error())
}
+ return nil
},
- OnSource: func(source ai.SourceContent) {
+ OnSource: func(source ai.SourceContent) error {
fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
+ return nil
},
- OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) {
+ OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error {
fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
- },
- OnStreamError: func(err error) {
- fmt.Printf("💥 Stream error: %v\n", err)
+ return nil
},
}