From eeb40f6ba40afbad21f622bda7694c7e864d8af3 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Tue, 26 Aug 2025 12:33:16 -0400 Subject: [PATCH] chore: change callback signature --- internal/ai/agent.go | 144 +++++++++++++----- internal/ai/agent_stream_test.go | 82 ++++++---- .../examples/streaming-agent-simple/main.go | 9 +- internal/ai/examples/streaming-agent/main.go | 45 +++--- 4 files changed, 185 insertions(+), 95 deletions(-) diff --git a/internal/ai/agent.go b/internal/ai/agent.go index d5a8b5a6444e50f5c56da8ed80c3977f849d7aae..07b73b6c72a1a21e40bcb4da4b5cbf3521600c47 100644 --- a/internal/ai/agent.go +++ b/internal/ai/agent.go @@ -182,22 +182,21 @@ type AgentStreamCall struct { OnError func(error) // Called when an error occurs // Stream part callbacks - called for each corresponding stream part type - OnChunk func(StreamPart) // Called for each stream part (catch-all) - OnWarnings func(warnings []CallWarning) // Called for warnings - OnTextStart func(id string) // Called when text starts - OnTextDelta func(id, text string) // Called for text deltas - OnTextEnd func(id string) // Called when text ends - OnReasoningStart func(id string) // Called when reasoning starts - OnReasoningDelta func(id, text string) // Called for reasoning deltas - OnReasoningEnd func(id string, reasoning ReasoningContent) // Called when reasoning ends - OnToolInputStart func(id, toolName string) // Called when tool input starts - OnToolInputDelta func(id, delta string) // Called for tool input deltas - OnToolInputEnd func(id string) // Called when tool input ends - OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete - OnToolResult func(result ToolResultContent) // Called when tool execution completes - OnSource func(source SourceContent) // Called for source references - OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) // Called when stream finishes - OnStreamError func(error) // Called when stream error occurs + OnChunk func(StreamPart) error // Called for each stream part (catch-all) + OnWarnings func(warnings []CallWarning) error // Called for warnings + OnTextStart func(id string) error // Called when text starts + OnTextDelta func(id, text string) error // Called for text deltas + OnTextEnd func(id string) error // Called when text ends + OnReasoningStart func(id string) error // Called when reasoning starts + OnReasoningDelta func(id, text string) error // Called for reasoning deltas + OnReasoningEnd func(id string, reasoning ReasoningContent) error // Called when reasoning ends + OnToolInputStart func(id, toolName string) error // Called when tool input starts + OnToolInputDelta func(id, delta string) error // Called for tool input deltas + OnToolInputEnd func(id string) error // Called when tool input ends + OnToolCall func(toolCall ToolCallContent) error // Called when tool call is complete + OnToolResult func(result ToolResultContent) error // Called when tool execution completes + OnSource func(source SourceContent) error // Called for source references + OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error // Called when stream finishes } type AgentResult struct { @@ -540,7 +539,7 @@ func toResponseMessages(content []Content) []Message { return messages } -func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) { +func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent) error) ([]ToolResultContent, error) { if len(toolCalls) == 0 { return nil, nil } @@ -553,7 +552,8 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall // Execute all tool calls in parallel results := make([]ToolResultContent, len(toolCalls)) - var toolExecutionError error + executeErrors := make([]error, len(toolCalls)) + var wg sync.WaitGroup for i, toolCall := range toolCalls { @@ -572,7 +572,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall ProviderExecuted: false, } if toolResultCallback != nil { - toolResultCallback(results[index]) + err := toolResultCallback(results[index]) + if err != nil { + executeErrors[index] = err + } } return @@ -590,7 +593,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall } if toolResultCallback != nil { - toolResultCallback(results[index]) + err := toolResultCallback(results[index]) + if err != nil { + executeErrors[index] = err + } } return } @@ -612,9 +618,12 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall ProviderExecuted: false, } if toolResultCallback != nil { - toolResultCallback(results[index]) + cbErr := toolResultCallback(results[index]) + if cbErr != nil { + executeErrors[index] = cbErr + } } - toolExecutionError = err + executeErrors[index] = err return } @@ -630,7 +639,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall } if toolResultCallback != nil { - toolResultCallback(results[index]) + err := toolResultCallback(results[index]) + if err != nil { + executeErrors[index] = err + } } } else { results[index] = ToolResultContent{ @@ -643,7 +655,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall ProviderExecuted: false, } if toolResultCallback != nil { - toolResultCallback(results[index]) + err := toolResultCallback(results[index]) + if err != nil { + executeErrors[index] = err + } } } }(i, toolCall) @@ -652,7 +667,13 @@ func (a *agent) executeTools(ctx context.Context, allTools []AgentTool, toolCall // Wait for all tool executions to complete wg.Wait() - return results, toolExecutionError + for _, err := range executeErrors { + if err != nil { + return nil, err + } + } + + return results, nil } // Stream implements Agent. @@ -1004,20 +1025,29 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op for part := range stream { // Forward all parts to chunk callback if opts.OnChunk != nil { - opts.OnChunk(part) + err := opts.OnChunk(part) + if err != nil { + return StepResult{}, false, err + } } switch part.Type { case StreamPartTypeWarnings: stepWarnings = part.Warnings if opts.OnWarnings != nil { - opts.OnWarnings(part.Warnings) + err := opts.OnWarnings(part.Warnings) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeTextStart: activeTextContent[part.ID] = "" if opts.OnTextStart != nil { - opts.OnTextStart(part.ID) + err := opts.OnTextStart(part.ID) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeTextDelta: @@ -1025,7 +1055,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op activeTextContent[part.ID] += part.Delta } if opts.OnTextDelta != nil { - opts.OnTextDelta(part.ID, part.Delta) + err := opts.OnTextDelta(part.ID, part.Delta) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeTextEnd: @@ -1037,13 +1070,19 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op delete(activeTextContent, part.ID) } if opts.OnTextEnd != nil { - opts.OnTextEnd(part.ID) + err := opts.OnTextEnd(part.ID) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeReasoningStart: activeTextContent[part.ID] = "" if opts.OnReasoningStart != nil { - opts.OnReasoningStart(part.ID) + err := opts.OnReasoningStart(part.ID) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeReasoningDelta: @@ -1051,7 +1090,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op activeTextContent[part.ID] += part.Delta } if opts.OnReasoningDelta != nil { - opts.OnReasoningDelta(part.ID, part.Delta) + err := opts.OnReasoningDelta(part.ID, part.Delta) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeReasoningEnd: @@ -1061,10 +1103,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op ProviderMetadata: ProviderMetadata(part.ProviderMetadata), }) if opts.OnReasoningEnd != nil { - opts.OnReasoningEnd(part.ID, ReasoningContent{ + err := opts.OnReasoningEnd(part.ID, ReasoningContent{ Text: text, ProviderMetadata: ProviderMetadata(part.ProviderMetadata), }) + if err != nil { + return StepResult{}, false, err + } } delete(activeTextContent, part.ID) } @@ -1077,7 +1122,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op ProviderExecuted: part.ProviderExecuted, } if opts.OnToolInputStart != nil { - opts.OnToolInputStart(part.ID, part.ToolCallName) + err := opts.OnToolInputStart(part.ID, part.ToolCallName) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeToolInputDelta: @@ -1085,12 +1133,18 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op toolCall.Input += part.Delta } if opts.OnToolInputDelta != nil { - opts.OnToolInputDelta(part.ID, part.Delta) + err := opts.OnToolInputDelta(part.ID, part.Delta) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeToolInputEnd: if opts.OnToolInputEnd != nil { - opts.OnToolInputEnd(part.ID) + err := opts.OnToolInputEnd(part.ID) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeToolCall: @@ -1108,7 +1162,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op stepContent = append(stepContent, validatedToolCall) if opts.OnToolCall != nil { - opts.OnToolCall(validatedToolCall) + err := opts.OnToolCall(validatedToolCall) + if err != nil { + return StepResult{}, false, err + } } // Clean up active tool call @@ -1124,7 +1181,10 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op } stepContent = append(stepContent, sourceContent) if opts.OnSource != nil { - opts.OnSource(sourceContent) + err := opts.OnSource(sourceContent) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeFinish: @@ -1132,13 +1192,13 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op stepFinishReason = part.FinishReason stepProviderMetadata = ProviderMetadata(part.ProviderMetadata) if opts.OnStreamFinish != nil { - opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata) + err := opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata) + if err != nil { + return StepResult{}, false, err + } } case StreamPartTypeError: - if opts.OnStreamError != nil { - opts.OnStreamError(part.Error) - } if opts.OnError != nil { opts.OnError(part.Error) } diff --git a/internal/ai/agent_stream_test.go b/internal/ai/agent_stream_test.go index 9bd1477a04d97e0fb5d49c6c2deebe1f2952a969..d3c0846dcd2a8d18d948f1cf65770f66941da9d6 100644 --- a/internal/ai/agent_stream_test.go +++ b/internal/ai/agent_stream_test.go @@ -124,53 +124,65 @@ func TestStreamingAgentCallbacks(t *testing.T) { OnError: func(err error) { callbacks["OnError"] = true }, - OnChunk: func(part StreamPart) { + OnChunk: func(part StreamPart) error { callbacks["OnChunk"] = true + return nil }, - OnWarnings: func(warnings []CallWarning) { + OnWarnings: func(warnings []CallWarning) error { callbacks["OnWarnings"] = true + return nil }, - OnTextStart: func(id string) { + OnTextStart: func(id string) error { callbacks["OnTextStart"] = true + return nil }, - OnTextDelta: func(id, text string) { + OnTextDelta: func(id, text string) error { callbacks["OnTextDelta"] = true + return nil }, - OnTextEnd: func(id string) { + OnTextEnd: func(id string) error { callbacks["OnTextEnd"] = true + return nil }, - OnReasoningStart: func(id string) { + OnReasoningStart: func(id string) error { callbacks["OnReasoningStart"] = true + return nil }, - OnReasoningDelta: func(id, text string) { + OnReasoningDelta: func(id, text string) error { callbacks["OnReasoningDelta"] = true + return nil }, - OnReasoningEnd: func(id string, content ReasoningContent) { + OnReasoningEnd: func(id string, content ReasoningContent) error { callbacks["OnReasoningEnd"] = true + return nil }, - OnToolInputStart: func(id, toolName string) { + OnToolInputStart: func(id, toolName string) error { callbacks["OnToolInputStart"] = true + return nil }, - OnToolInputDelta: func(id, delta string) { + OnToolInputDelta: func(id, delta string) error { callbacks["OnToolInputDelta"] = true + return nil }, - OnToolInputEnd: func(id string) { + OnToolInputEnd: func(id string) error { callbacks["OnToolInputEnd"] = true + return nil }, - OnToolCall: func(toolCall ToolCallContent) { + OnToolCall: func(toolCall ToolCallContent) error { callbacks["OnToolCall"] = true + return nil }, - OnToolResult: func(result ToolResultContent) { + OnToolResult: func(result ToolResultContent) error { callbacks["OnToolResult"] = true + return nil }, - OnSource: func(source SourceContent) { + OnSource: func(source SourceContent) error { callbacks["OnSource"] = true + return nil }, - OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) { + OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error { callbacks["OnStreamFinish"] = true - }, - OnStreamError: func(err error) { - callbacks["OnStreamError"] = true + return nil }, } @@ -207,7 +219,6 @@ func TestStreamingAgentCallbacks(t *testing.T) { // Verify that error callbacks were not called require.False(t, callbacks["OnError"], "OnError should not be called in successful case") - require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case") require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls") require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results") } @@ -289,28 +300,33 @@ func TestStreamingAgentWithTools(t *testing.T) { // Create streaming call with callbacks streamCall := AgentStreamCall{ Prompt: "Echo 'test'", - OnToolInputStart: func(id, toolName string) { + OnToolInputStart: func(id, toolName string) error { toolInputStartCalled = true require.Equal(t, "tool-1", id) require.Equal(t, "echo", toolName) + return nil }, - OnToolInputDelta: func(id, delta string) { + OnToolInputDelta: func(id, delta string) error { toolInputDeltaCalled = true require.Equal(t, "tool-1", id) require.Contains(t, []string{`{"message"`, `: "test"}`}, delta) + return nil }, - OnToolInputEnd: func(id string) { + OnToolInputEnd: func(id string) error { toolInputEndCalled = true require.Equal(t, "tool-1", id) + return nil }, - OnToolCall: func(toolCall ToolCallContent) { + OnToolCall: func(toolCall ToolCallContent) error { toolCallCalled = true require.Equal(t, "echo", toolCall.ToolName) require.Equal(t, `{"message": "test"}`, toolCall.Input) + return nil }, - OnToolResult: func(result ToolResultContent) { + OnToolResult: func(result ToolResultContent) error { toolResultCalled = true require.Equal(t, "echo", result.ToolName) + return nil }, } @@ -377,10 +393,11 @@ func TestStreamingAgentTextDeltas(t *testing.T) { streamCall := AgentStreamCall{ Prompt: "Say hello", - OnTextDelta: func(id, text string) { + OnTextDelta: func(id, text string) error { if text != "" { textDeltas = append(textDeltas, text) } + return nil }, } @@ -438,11 +455,13 @@ func TestStreamingAgentReasoning(t *testing.T) { streamCall := AgentStreamCall{ Prompt: "Think and respond", - OnReasoningDelta: func(id, text string) { + OnReasoningDelta: func(id, text string) error { reasoningDeltas = append(reasoningDeltas, text) + return nil }, - OnTextDelta: func(id, text string) { + OnTextDelta: func(id, text string) error { textDeltas = append(textDeltas, text) + return nil }, } @@ -473,15 +492,12 @@ func TestStreamingAgentError(t *testing.T) { ctx := context.Background() // Track error callbacks - var streamErrorOccurred bool var errorOccurred bool var errorMessage string streamCall := AgentStreamCall{ Prompt: "This will fail", - OnStreamError: func(err error) { - streamErrorOccurred = true - }, + OnError: func(err error) { errorOccurred = true errorMessage = err.Error() @@ -492,7 +508,6 @@ func TestStreamingAgentError(t *testing.T) { result, err := agent.Stream(ctx, streamCall) require.Error(t, err) require.Nil(t, result) - require.True(t, streamErrorOccurred, "OnStreamError should have been called") require.True(t, errorOccurred, "OnError should have been called") require.Contains(t, errorMessage, "mock stream error") } @@ -546,8 +561,9 @@ func TestStreamingAgentSources(t *testing.T) { streamCall := AgentStreamCall{ Prompt: "Search and respond", - OnSource: func(source SourceContent) { + OnSource: func(source SourceContent) error { sources = append(sources, source) + return nil }, } diff --git a/internal/ai/examples/streaming-agent-simple/main.go b/internal/ai/examples/streaming-agent-simple/main.go index 72a8ac2d0aae94f330fd1ed71c9ceddaa41627dc..bd0e1bc971868baa69e17535bea7f5007ccd5bf9 100644 --- a/internal/ai/examples/streaming-agent-simple/main.go +++ b/internal/ai/examples/streaming-agent-simple/main.go @@ -58,18 +58,21 @@ func main() { Prompt: "Please echo back 'Hello, streaming world!'", // Show real-time text as it streams - OnTextDelta: func(id, text string) { + OnTextDelta: func(id, text string) error { fmt.Print(text) + return nil }, // Show when tools are called - OnToolCall: func(toolCall ai.ToolCallContent) { + OnToolCall: func(toolCall ai.ToolCallContent) error { fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName) + return nil }, // Show tool results - OnToolResult: func(result ai.ToolResultContent) { + OnToolResult: func(result ai.ToolResultContent) error { fmt.Printf("[Tool result received]\n") + return nil }, // Show when each step completes diff --git a/internal/ai/examples/streaming-agent/main.go b/internal/ai/examples/streaming-agent/main.go index efb80d57035c54e886cf35f1dcac8785670d3419..0aa2846c78985749122e2b9aa4167808d8792220 100644 --- a/internal/ai/examples/streaming-agent/main.go +++ b/internal/ai/examples/streaming-agent/main.go @@ -144,47 +144,58 @@ func main() { }, // Stream part callbacks - OnWarnings: func(warnings []ai.CallWarning) { + OnWarnings: func(warnings []ai.CallWarning) error { for _, warning := range warnings { fmt.Printf("⚠️ Warning: %s\n", warning.Message) } + return nil }, - OnTextStart: func(id string) { + OnTextStart: func(id string) error { fmt.Print("💭 Assistant: ") + return nil }, - OnTextDelta: func(id, text string) { + OnTextDelta: func(id, text string) error { fmt.Print(text) textBuffer.WriteString(text) + return nil }, - OnTextEnd: func(id string) { + OnTextEnd: func(id string) error { fmt.Println() + return nil }, - OnReasoningStart: func(id string) { + OnReasoningStart: func(id string) error { fmt.Print("🤔 Thinking: ") + return nil }, - OnReasoningDelta: func(id, text string) { + OnReasoningDelta: func(id, text string) error { reasoningBuffer.WriteString(text) + return nil }, - OnReasoningEnd: func(id string, content ai.ReasoningContent) { + OnReasoningEnd: func(id string, content ai.ReasoningContent) error { if reasoningBuffer.Len() > 0 { fmt.Printf("%s\n", reasoningBuffer.String()) reasoningBuffer.Reset() } + return nil }, - OnToolInputStart: func(id, toolName string) { + OnToolInputStart: func(id, toolName string) error { fmt.Printf("🔧 Calling tool: %s\n", toolName) + return nil }, - OnToolInputDelta: func(id, delta string) { + OnToolInputDelta: func(id, delta string) error { // Could show tool input being built, but it's often noisy + return nil }, - OnToolInputEnd: func(id string) { + OnToolInputEnd: func(id string) error { // Tool input complete + return nil }, - OnToolCall: func(toolCall ai.ToolCallContent) { + OnToolCall: func(toolCall ai.ToolCallContent) error { fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName) fmt.Printf(" Input: %s\n", toolCall.Input) + return nil }, - OnToolResult: func(result ai.ToolResultContent) { + OnToolResult: func(result ai.ToolResultContent) error { fmt.Printf("🎯 Tool result from %s:\n", result.ToolName) switch output := result.Result.(type) { case ai.ToolResultOutputContentText: @@ -192,15 +203,15 @@ func main() { case ai.ToolResultOutputContentError: fmt.Printf(" Error: %s\n", output.Error.Error()) } + return nil }, - OnSource: func(source ai.SourceContent) { + OnSource: func(source ai.SourceContent) error { fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL) + return nil }, - OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) { + OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderMetadata) error { fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens) - }, - OnStreamError: func(err error) { - fmt.Printf("💥 Stream error: %v\n", err) + return nil }, }