chore: change callback signature

Kujtim Hoxha created

Change summary

agent.go                                | 144 +++++++++++++++++++-------
agent_stream_test.go                    |  82 +++++++++------
examples/streaming-agent-simple/main.go |   9 +
examples/streaming-agent/main.go        |  45 +++++---
4 files changed, 185 insertions(+), 95 deletions(-)

Detailed changes

agent.go 🔗

@@ -182,22 +182,21 @@ type AgentStreamCall struct {
 	OnError       func(error)                 // Called when an error occurs
 
 	// Stream part callbacks - called for each corresponding stream part type
-	OnChunk          func(StreamPart)                                                                // Called for each stream part (catch-all)
-	OnWarnings       func(warnings []CallWarning)                                                    // Called for warnings
-	OnTextStart      func(id string)                                                                 // Called when text starts
-	OnTextDelta      func(id, text string)                                                           // Called for text deltas
-	OnTextEnd        func(id string)                                                                 // Called when text ends
-	OnReasoningStart func(id string)                                                                 // Called when reasoning starts
-	OnReasoningDelta func(id, text string)                                                           // Called for reasoning deltas
-	OnReasoningEnd   func(id string, reasoning ReasoningContent)                                     // Called when reasoning ends
-	OnToolInputStart func(id, toolName string)                                                       // Called when tool input starts
-	OnToolInputDelta func(id, delta string)                                                          // Called for tool input deltas
-	OnToolInputEnd   func(id string)                                                                 // Called when tool input ends
-	OnToolCall       func(toolCall ToolCallContent)                                                  // Called when tool call is complete
-	OnToolResult     func(result ToolResultContent)                                                  // Called when tool execution completes
-	OnSource         func(source SourceContent)                                                      // Called for source references
-	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes
-	OnStreamError    func(error)                                                                     // Called when stream error occurs
+	OnChunk          func(StreamPart) error                                                                // Called for each stream part (catch-all)
+	OnWarnings       func(warnings []CallWarning) error                                                    // Called for warnings
+	OnTextStart      func(id string) error                                                                 // Called when text starts
+	OnTextDelta      func(id, text string) error                                                           // Called for text deltas
+	OnTextEnd        func(id string) error                                                                 // Called when text ends
+	OnReasoningStart func(id string) error                                                                 // Called when reasoning starts
+	OnReasoningDelta func(id, text string) error                                                           // Called for reasoning deltas
+	OnReasoningEnd   func(id string, reasoning ReasoningContent) error                                     // Called when reasoning ends
+	OnToolInputStart func(id, toolName string) error                                                       // Called when tool input starts
+	OnToolInputDelta func(id, delta string) error                                                          // Called for tool input deltas
+	OnToolInputEnd   func(id string) error                                                                 // Called when tool input ends
+	OnToolCall       func(toolCall ToolCallContent) error                                                  // Called when tool call is complete
+	OnToolResult     func(result ToolResultContent) error                                                  // Called when tool execution completes
+	OnSource         func(source SourceContent) error                                                      // Called for source references
+	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes
 }
 
 type AgentResult struct {
@@ -540,7 +539,7 @@ func toResponseMessages(content []Content) []Message {
 	return messages
 }
 
-func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) {
 	if len(toolCalls) == 0 {
 		return nil, nil
 	}
@@ -553,7 +552,8 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 
 	// Execute all tool calls in parallel
 	results := make([]ToolResultContent, len(toolCalls))
-	var toolExecutionError error
+	executeErrors := make([]error, len(toolCalls))
+
 	var wg sync.WaitGroup
 
 	for i, toolCall := range toolCalls {
@@ -572,7 +572,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 					ProviderExecuted: false,
 				}
 				if toolResultCallback != nil {
-					toolResultCallback(results[index])
+					err := toolResultCallback(results[index])
+					if err != nil {
+						executeErrors[index] = err
+					}
 				}
 
 				return
@@ -590,7 +593,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 				}
 
 				if toolResultCallback != nil {
-					toolResultCallback(results[index])
+					err := toolResultCallback(results[index])
+					if err != nil {
+						executeErrors[index] = err
+					}
 				}
 				return
 			}
@@ -612,9 +618,12 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 					ProviderExecuted: false,
 				}
 				if toolResultCallback != nil {
-					toolResultCallback(results[index])
+					cbErr := toolResultCallback(results[index])
+					if cbErr != nil {
+						executeErrors[index] = cbErr
+					}
 				}
-				toolExecutionError = err
+				executeErrors[index] = err
 				return
 			}
 
@@ -630,7 +639,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 				}
 
 				if toolResultCallback != nil {
-					toolResultCallback(results[index])
+					err := toolResultCallback(results[index])
+					if err != nil {
+						executeErrors[index] = err
+					}
 				}
 			} else {
 				results[index] = ToolResultContent{
@@ -643,7 +655,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 					ProviderExecuted: false,
 				}
 				if toolResultCallback != nil {
-					toolResultCallback(results[index])
+					err := toolResultCallback(results[index])
+					if err != nil {
+						executeErrors[index] = err
+					}
 				}
 			}
 		}(i, toolCall)
@@ -652,7 +667,13 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall
 	// Wait for all tool executions to complete
 	wg.Wait()
 
-	return results, toolExecutionError
+	for _, err := range executeErrors {
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return results, nil
 }
 
 // Stream implements Agent.
@@ -1004,20 +1025,29 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 	for part := range stream {
 		// Forward all parts to chunk callback
 		if opts.OnChunk != nil {
-			opts.OnChunk(part)
+			err := opts.OnChunk(part)
+			if err != nil {
+				return StepResult{}, false, err
+			}
 		}
 
 		switch part.Type {
 		case StreamPartTypeWarnings:
 			stepWarnings = part.Warnings
 			if opts.OnWarnings != nil {
-				opts.OnWarnings(part.Warnings)
+				err := opts.OnWarnings(part.Warnings)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeTextStart:
 			activeTextContent[part.ID] = ""
 			if opts.OnTextStart != nil {
-				opts.OnTextStart(part.ID)
+				err := opts.OnTextStart(part.ID)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeTextDelta:
@@ -1025,7 +1055,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				activeTextContent[part.ID] += part.Delta
 			}
 			if opts.OnTextDelta != nil {
-				opts.OnTextDelta(part.ID, part.Delta)
+				err := opts.OnTextDelta(part.ID, part.Delta)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeTextEnd:
@@ -1037,13 +1070,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				delete(activeTextContent, part.ID)
 			}
 			if opts.OnTextEnd != nil {
-				opts.OnTextEnd(part.ID)
+				err := opts.OnTextEnd(part.ID)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeReasoningStart:
 			activeTextContent[part.ID] = ""
 			if opts.OnReasoningStart != nil {
-				opts.OnReasoningStart(part.ID)
+				err := opts.OnReasoningStart(part.ID)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeReasoningDelta:
@@ -1051,7 +1090,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				activeTextContent[part.ID] += part.Delta
 			}
 			if opts.OnReasoningDelta != nil {
-				opts.OnReasoningDelta(part.ID, part.Delta)
+				err := opts.OnReasoningDelta(part.ID, part.Delta)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeReasoningEnd:
@@ -1061,10 +1103,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
 				})
 				if opts.OnReasoningEnd != nil {
-					opts.OnReasoningEnd(part.ID, ReasoningContent{
+					err := opts.OnReasoningEnd(part.ID, ReasoningContent{
 						Text:             text,
 						ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
 					})
+					if err != nil {
+						return StepResult{}, false, err
+					}
 				}
 				delete(activeTextContent, part.ID)
 			}
@@ -1077,7 +1122,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				ProviderExecuted: part.ProviderExecuted,
 			}
 			if opts.OnToolInputStart != nil {
-				opts.OnToolInputStart(part.ID, part.ToolCallName)
+				err := opts.OnToolInputStart(part.ID, part.ToolCallName)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeToolInputDelta:
@@ -1085,12 +1133,18 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 				toolCall.Input += part.Delta
 			}
 			if opts.OnToolInputDelta != nil {
-				opts.OnToolInputDelta(part.ID, part.Delta)
+				err := opts.OnToolInputDelta(part.ID, part.Delta)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeToolInputEnd:
 			if opts.OnToolInputEnd != nil {
-				opts.OnToolInputEnd(part.ID)
+				err := opts.OnToolInputEnd(part.ID)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeToolCall:
@@ -1108,7 +1162,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			stepContent = append(stepContent, validatedToolCall)
 
 			if opts.OnToolCall != nil {
-				opts.OnToolCall(validatedToolCall)
+				err := opts.OnToolCall(validatedToolCall)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 			// Clean up active tool call
@@ -1124,7 +1181,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			}
 			stepContent = append(stepContent, sourceContent)
 			if opts.OnSource != nil {
-				opts.OnSource(sourceContent)
+				err := opts.OnSource(sourceContent)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeFinish:
@@ -1132,13 +1192,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
 			stepFinishReason = part.FinishReason
 			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
 			if opts.OnStreamFinish != nil {
-				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
+				err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
+				if err != nil {
+					return StepResult{}, false, err
+				}
 			}
 
 		case StreamPartTypeError:
-			if opts.OnStreamError != nil {
-				opts.OnStreamError(part.Error)
-			}
 			if opts.OnError != nil {
 				opts.OnError(part.Error)
 			}

agent_stream_test.go 🔗

@@ -124,53 +124,65 @@ func TestStreamingAgentCallbacks(t *testing.T) {
 		OnError: func(err error) {
 			callbacks["OnError"] = true
 		},
-		OnChunk: func(part StreamPart) {
+		OnChunk: func(part StreamPart) error {
 			callbacks["OnChunk"] = true
+			return nil
 		},
-		OnWarnings: func(warnings []CallWarning) {
+		OnWarnings: func(warnings []CallWarning) error {
 			callbacks["OnWarnings"] = true
+			return nil
 		},
-		OnTextStart: func(id string) {
+		OnTextStart: func(id string) error {
 			callbacks["OnTextStart"] = true
+			return nil
 		},
-		OnTextDelta: func(id, text string) {
+		OnTextDelta: func(id, text string) error {
 			callbacks["OnTextDelta"] = true
+			return nil
 		},
-		OnTextEnd: func(id string) {
+		OnTextEnd: func(id string) error {
 			callbacks["OnTextEnd"] = true
+			return nil
 		},
-		OnReasoningStart: func(id string) {
+		OnReasoningStart: func(id string) error {
 			callbacks["OnReasoningStart"] = true
+			return nil
 		},
-		OnReasoningDelta: func(id, text string) {
+		OnReasoningDelta: func(id, text string) error {
 			callbacks["OnReasoningDelta"] = true
+			return nil
 		},
-		OnReasoningEnd: func(id string, content ReasoningContent) {
+		OnReasoningEnd: func(id string, content ReasoningContent) error {
 			callbacks["OnReasoningEnd"] = true
+			return nil
 		},
-		OnToolInputStart: func(id, toolName string) {
+		OnToolInputStart: func(id, toolName string) error {
 			callbacks["OnToolInputStart"] = true
+			return nil
 		},
-		OnToolInputDelta: func(id, delta string) {
+		OnToolInputDelta: func(id, delta string) error {
 			callbacks["OnToolInputDelta"] = true
+			return nil
 		},
-		OnToolInputEnd: func(id string) {
+		OnToolInputEnd: func(id string) error {
 			callbacks["OnToolInputEnd"] = true
+			return nil
 		},
-		OnToolCall: func(toolCall ToolCallContent) {
+		OnToolCall: func(toolCall ToolCallContent) error {
 			callbacks["OnToolCall"] = true
+			return nil
 		},
-		OnToolResult: func(result ToolResultContent) {
+		OnToolResult: func(result ToolResultContent) error {
 			callbacks["OnToolResult"] = true
+			return nil
 		},
-		OnSource: func(source SourceContent) {
+		OnSource: func(source SourceContent) error {
 			callbacks["OnSource"] = true
+			return nil
 		},
-		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) {
+		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
 			callbacks["OnStreamFinish"] = true
-		},
-		OnStreamError: func(err error) {
-			callbacks["OnStreamError"] = true
+			return nil
 		},
 	}
 
@@ -207,7 +219,6 @@ func TestStreamingAgentCallbacks(t *testing.T) {
 
 	// Verify that error callbacks were not called
 	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
-	require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
 	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
 	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
 }
@@ -289,28 +300,33 @@ func TestStreamingAgentWithTools(t *testing.T) {
 	// Create streaming call with callbacks
 	streamCall := AgentStreamCall{
 		Prompt: "Echo 'test'",
-		OnToolInputStart: func(id, toolName string) {
+		OnToolInputStart: func(id, toolName string) error {
 			toolInputStartCalled = true
 			require.Equal(t, "tool-1", id)
 			require.Equal(t, "echo", toolName)
+			return nil
 		},
-		OnToolInputDelta: func(id, delta string) {
+		OnToolInputDelta: func(id, delta string) error {
 			toolInputDeltaCalled = true
 			require.Equal(t, "tool-1", id)
 			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
+			return nil
 		},
-		OnToolInputEnd: func(id string) {
+		OnToolInputEnd: func(id string) error {
 			toolInputEndCalled = true
 			require.Equal(t, "tool-1", id)
+			return nil
 		},
-		OnToolCall: func(toolCall ToolCallContent) {
+		OnToolCall: func(toolCall ToolCallContent) error {
 			toolCallCalled = true
 			require.Equal(t, "echo", toolCall.ToolName)
 			require.Equal(t, `{"message": "test"}`, toolCall.Input)
+			return nil
 		},
-		OnToolResult: func(result ToolResultContent) {
+		OnToolResult: func(result ToolResultContent) error {
 			toolResultCalled = true
 			require.Equal(t, "echo", result.ToolName)
+			return nil
 		},
 	}
 
@@ -377,10 +393,11 @@ func TestStreamingAgentTextDeltas(t *testing.T) {
 
 	streamCall := AgentStreamCall{
 		Prompt: "Say hello",
-		OnTextDelta: func(id, text string) {
+		OnTextDelta: func(id, text string) error {
 			if text != "" {
 				textDeltas = append(textDeltas, text)
 			}
+			return nil
 		},
 	}
 
@@ -438,11 +455,13 @@ func TestStreamingAgentReasoning(t *testing.T) {
 
 	streamCall := AgentStreamCall{
 		Prompt: "Think and respond",
-		OnReasoningDelta: func(id, text string) {
+		OnReasoningDelta: func(id, text string) error {
 			reasoningDeltas = append(reasoningDeltas, text)
+			return nil
 		},
-		OnTextDelta: func(id, text string) {
+		OnTextDelta: func(id, text string) error {
 			textDeltas = append(textDeltas, text)
+			return nil
 		},
 	}
 
@@ -473,15 +492,12 @@ func TestStreamingAgentError(t *testing.T) {
 	ctx := context.Background()
 
 	// Track error callbacks
-	var streamErrorOccurred bool
 	var errorOccurred bool
 	var errorMessage string
 
 	streamCall := AgentStreamCall{
 		Prompt: "This will fail",
-		OnStreamError: func(err error) {
-			streamErrorOccurred = true
-		},
+
 		OnError: func(err error) {
 			errorOccurred = true
 			errorMessage = err.Error()
@@ -492,7 +508,6 @@ func TestStreamingAgentError(t *testing.T) {
 	result, err := agent.Stream(ctx, streamCall)
 	require.Error(t, err)
 	require.Nil(t, result)
-	require.True(t, streamErrorOccurred, "OnStreamError should have been called")
 	require.True(t, errorOccurred, "OnError should have been called")
 	require.Contains(t, errorMessage, "mock stream error")
 }
@@ -546,8 +561,9 @@ func TestStreamingAgentSources(t *testing.T) {
 
 	streamCall := AgentStreamCall{
 		Prompt: "Search and respond",
-		OnSource: func(source SourceContent) {
+		OnSource: func(source SourceContent) error {
 			sources = append(sources, source)
+			return nil
 		},
 	}
 

examples/streaming-agent-simple/main.go 🔗

@@ -58,18 +58,21 @@ func main() {
 		Prompt: "Please echo back 'Hello, streaming world!'",
 
 		// Show real-time text as it streams
-		OnTextDelta: func(id, text string) {
+		OnTextDelta: func(id, text string) error {
 			fmt.Print(text)
+			return nil
 		},
 
 		// Show when tools are called
-		OnToolCall: func(toolCall ai.ToolCallContent) {
+		OnToolCall: func(toolCall ai.ToolCallContent) error {
 			fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName)
+			return nil
 		},
 
 		// Show tool results
-		OnToolResult: func(result ai.ToolResultContent) {
+		OnToolResult: func(result ai.ToolResultContent) error {
 			fmt.Printf("[Tool result received]\n")
+			return nil
 		},
 
 		// Show when each step completes

examples/streaming-agent/main.go 🔗

@@ -144,47 +144,58 @@ func main() {
 		},
 
 		// Stream part callbacks
-		OnWarnings: func(warnings []ai.CallWarning) {
+		OnWarnings: func(warnings []ai.CallWarning) error {
 			for _, warning := range warnings {
 				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
 			}
+			return nil
 		},
-		OnTextStart: func(id string) {
+		OnTextStart: func(id string) error {
 			fmt.Print("💭 Assistant: ")
+			return nil
 		},
-		OnTextDelta: func(id, text string) {
+		OnTextDelta: func(id, text string) error {
 			fmt.Print(text)
 			textBuffer.WriteString(text)
+			return nil
 		},
-		OnTextEnd: func(id string) {
+		OnTextEnd: func(id string) error {
 			fmt.Println()
+			return nil
 		},
-		OnReasoningStart: func(id string) {
+		OnReasoningStart: func(id string) error {
 			fmt.Print("🤔 Thinking: ")
+			return nil
 		},
-		OnReasoningDelta: func(id, text string) {
+		OnReasoningDelta: func(id, text string) error {
 			reasoningBuffer.WriteString(text)
+			return nil
 		},
-		OnReasoningEnd: func(id string, content ai.ReasoningContent) {
+		OnReasoningEnd: func(id string, content ai.ReasoningContent) error {
 			if reasoningBuffer.Len() > 0 {
 				fmt.Printf("%s\n", reasoningBuffer.String())
 				reasoningBuffer.Reset()
 			}
+			return nil
 		},
-		OnToolInputStart: func(id, toolName string) {
+		OnToolInputStart: func(id, toolName string) error {
 			fmt.Printf("🔧 Calling tool: %s\n", toolName)
+			return nil
 		},
-		OnToolInputDelta: func(id, delta string) {
+		OnToolInputDelta: func(id, delta string) error {
 			// Could show tool input being built, but it's often noisy
+			return nil
 		},
-		OnToolInputEnd: func(id string) {
+		OnToolInputEnd: func(id string) error {
 			// Tool input complete
+			return nil
 		},
-		OnToolCall: func(toolCall ai.ToolCallContent) {
+		OnToolCall: func(toolCall ai.ToolCallContent) error {
 			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
 			fmt.Printf("   Input: %s\n", toolCall.Input)
+			return nil
 		},
-		OnToolResult: func(result ai.ToolResultContent) {
+		OnToolResult: func(result ai.ToolResultContent) error {
 			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
 			switch output := result.Result.(type) {
 			case ai.ToolResultOutputContentText:
@@ -192,15 +203,15 @@ func main() {
 			case ai.ToolResultOutputContentError:
 				fmt.Printf("   Error: %s\n", output.Error.Error())
 			}
+			return nil
 		},
-		OnSource: func(source ai.SourceContent) {
+		OnSource: func(source ai.SourceContent) error {
 			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
+			return nil
 		},
-		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) {
+		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error {
 			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
-		},
-		OnStreamError: func(err error) {
-			fmt.Printf("💥 Stream error: %v\n", err)
+			return nil
 		},
 	}