Detailed changes
@@ -157,6 +157,53 @@ type AgentCall struct {
OnStepFinished OnStepFinishedFunction
}
// AgentStreamCall describes a single streaming agent invocation: the
// prompt/message inputs, sampling parameters, multi-step controls, and the
// callbacks that fire while the stream is processed. It mirrors AgentCall
// and adds per-stream-part callbacks.
type AgentStreamCall struct {
	Prompt           string     `json:"prompt"`
	Files            []FilePart `json:"files"`
	Messages         []Message  `json:"messages"`
	MaxOutputTokens  *int64     // NOTE(review): no json tag, unlike neighboring fields — confirm intended
	Temperature      *float64   `json:"temperature"`
	TopP             *float64   `json:"top_p"`
	TopK             *int64     `json:"top_k"`
	PresencePenalty  *float64   `json:"presence_penalty"`
	FrequencyPenalty *float64   `json:"frequency_penalty"`
	ActiveTools      []string   `json:"active_tools"`
	Headers          map[string]string
	ProviderOptions  ProviderOptions
	OnRetry          OnRetryCallback
	MaxRetries       *int

	// Multi-step agent controls.
	StopWhen       []StopCondition        // conditions that terminate the step loop
	PrepareStep    PrepareStepFunction    // optional per-step mutation of model/messages/tools
	RepairToolCall RepairToolCallFunction // optional repair of invalid tool calls

	// Agent-level callbacks
	OnAgentStart  func()                      // Called when agent starts
	OnAgentFinish func(result *AgentResult)   // Called when agent finishes
	OnStepStart   func(stepNumber int)        // Called when a step starts
	OnStepFinish  func(stepResult StepResult) // Called when a step finishes
	OnFinish      func(result *AgentResult)   // Called when entire agent completes
	OnError       func(error)                 // Called when an error occurs

	// Stream part callbacks - called for each corresponding stream part type
	OnChunk          func(StreamPart)               // Called for each stream part (catch-all)
	OnWarnings       func(warnings []CallWarning)   // Called for warnings
	OnTextStart      func(id string)                // Called when text starts
	OnTextDelta      func(id, text string)          // Called for text deltas
	OnTextEnd        func(id string)                // Called when text ends
	OnReasoningStart func(id string)                // Called when reasoning starts
	OnReasoningDelta func(id, text string)          // Called for reasoning deltas
	OnReasoningEnd   func(id string)                // Called when reasoning ends
	OnToolInputStart func(id, toolName string)      // Called when tool input starts
	OnToolInputDelta func(id, delta string)         // Called for tool input deltas
	OnToolInputEnd   func(id string)                // Called when tool input ends
	OnToolCall       func(toolCall ToolCallContent) // Called when tool call is complete
	OnToolResult     func(result ToolResultContent) // Called when tool execution completes
	OnSource         func(source SourceContent)     // Called for source references
	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
	OnStreamError    func(error) // Called when stream error occurs
}
+
type AgentResult struct {
Steps []StepResult
// Final response
@@ -166,7 +213,7 @@ type AgentResult struct {
// Agent runs a language model in a multi-step loop, executing tools between
// steps until a stop condition is met.
type Agent interface {
	// Generate runs the agent to completion and returns the aggregate result.
	Generate(context.Context, AgentCall) (*AgentResult, error)
	// Stream runs the agent while delivering stream parts through the
	// callbacks configured on AgentStreamCall; the final aggregate result is
	// still returned when the loop terminates.
	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
}
type agentOption = func(*AgentSettings)
@@ -343,7 +390,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}
- toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
+ toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
// Build step content with validated tool calls and tool results
stepContent := []Content{}
@@ -501,7 +548,7 @@ func toResponseMessages(content []Content) []Message {
return messages
}
-func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
if len(toolCalls) == 0 {
return nil, nil
}
@@ -532,6 +579,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
},
ProviderExecuted: false,
}
+ if toolResultCallback != nil {
+ toolResultCallback(results[index])
+ }
+
return
}
@@ -545,6 +596,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
},
ProviderExecuted: false,
}
+
+ if toolResultCallback != nil {
+ toolResultCallback(results[index])
+ }
return
}
@@ -563,6 +618,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
},
ProviderExecuted: false,
}
+ if toolResultCallback != nil {
+ toolResultCallback(results[index])
+ }
toolExecutionError = err
return
}
@@ -576,6 +634,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
},
ProviderExecuted: false,
}
+
+ if toolResultCallback != nil {
+ toolResultCallback(results[index])
+ }
} else {
results[index] = ToolResultContent{
ToolCallID: call.ToolCallID,
@@ -585,6 +647,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
},
ProviderExecuted: false,
}
+ if toolResultCallback != nil {
+ toolResultCallback(results[index])
+ }
}
}(i, toolCall)
}
@@ -596,9 +661,164 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
}
// Stream implements Agent. It runs the same multi-step loop as Generate, but
// drives the model's streaming API and fires the callbacks configured on
// opts as each stream part arrives. The aggregated result of all steps is
// returned once the loop terminates.
func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
	// Convert AgentStreamCall to AgentCall for preparation.
	// NOTE(review): opts.OnRetry is not copied over — confirm whether the
	// retry callback should participate in streaming calls.
	call := AgentCall{
		Prompt:           opts.Prompt,
		Files:            opts.Files,
		Messages:         opts.Messages,
		MaxOutputTokens:  opts.MaxOutputTokens,
		Temperature:      opts.Temperature,
		TopP:             opts.TopP,
		TopK:             opts.TopK,
		PresencePenalty:  opts.PresencePenalty,
		FrequencyPenalty: opts.FrequencyPenalty,
		ActiveTools:      opts.ActiveTools,
		Headers:          opts.Headers,
		ProviderOptions:  opts.ProviderOptions,
		MaxRetries:       opts.MaxRetries,
		StopWhen:         opts.StopWhen,
		PrepareStep:      opts.PrepareStep,
		RepairToolCall:   opts.RepairToolCall,
	}

	call = a.prepareCall(call)

	// Build the initial prompt (system prompt + user prompt + messages + files).
	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
	if err != nil {
		return nil, err
	}

	var responseMessages []Message
	var steps []StepResult
	var totalUsage Usage

	// Start agent stream
	if opts.OnAgentStart != nil {
		opts.OnAgentStart()
	}

	// Step loop: each iteration performs one model stream, executes any tool
	// calls, and appends the step's messages to the running conversation.
	for stepNumber := 0; ; stepNumber++ {
		stepInputMessages := append(initialPrompt, responseMessages...)
		stepModel := a.settings.model
		stepSystemPrompt := a.settings.systemPrompt
		stepActiveTools := call.ActiveTools
		stepToolChoice := ToolChoiceAuto
		disableAllTools := false

		// Apply step preparation if provided
		if call.PrepareStep != nil {
			prepared := call.PrepareStep(PrepareStepFunctionOptions{
				Model:      stepModel,
				Steps:      steps,
				StepNumber: stepNumber,
				Messages:   stepInputMessages,
			})

			// Each override is applied only when the callback set it.
			if prepared.Messages != nil {
				stepInputMessages = prepared.Messages
			}
			if prepared.Model != nil {
				stepModel = prepared.Model
			}
			if prepared.System != nil {
				stepSystemPrompt = *prepared.System
			}
			if prepared.ToolChoice != nil {
				stepToolChoice = *prepared.ToolChoice
			}
			if len(prepared.ActiveTools) > 0 {
				stepActiveTools = prepared.ActiveTools
			}
			disableAllTools = prepared.DisableAllTools
		}

		// Recreate prompt with potentially modified system prompt
		if stepSystemPrompt != a.settings.systemPrompt {
			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
			if err != nil {
				return nil, err
			}
			// Only the leading (system) message is swapped in; the rest of
			// the conversation is left untouched.
			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
				stepInputMessages[0] = stepPrompt[0]
			}
		}

		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)

		// Start step stream
		if opts.OnStepStart != nil {
			opts.OnStepStart(stepNumber)
		}

		// Create streaming call
		streamCall := Call{
			Prompt:           stepInputMessages,
			MaxOutputTokens:  call.MaxOutputTokens,
			Temperature:      call.Temperature,
			TopP:             call.TopP,
			TopK:             call.TopK,
			PresencePenalty:  call.PresencePenalty,
			FrequencyPenalty: call.FrequencyPenalty,
			Tools:            preparedTools,
			ToolChoice:       &stepToolChoice,
			Headers:          call.Headers,
			ProviderOptions:  call.ProviderOptions,
		}

		// Get streaming response
		stream, err := stepModel.Stream(ctx, streamCall)
		if err != nil {
			if opts.OnError != nil {
				opts.OnError(err)
			}
			return nil, err
		}

		// Process stream with tool execution
		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
		if err != nil {
			if opts.OnError != nil {
				opts.OnError(err)
			}
			return nil, err
		}

		steps = append(steps, stepResult)
		totalUsage = addUsage(totalUsage, stepResult.Usage)

		// Call step finished callback
		if opts.OnStepFinish != nil {
			opts.OnStepFinish(stepResult)
		}

		// Add step messages to response messages
		stepMessages := toResponseMessages(stepResult.Content)
		responseMessages = append(responseMessages, stepMessages...)

		// Check stop conditions
		shouldStop := isStopConditionMet(call.StopWhen, steps)
		if shouldStop || !shouldContinue {
			break
		}
	}

	// Finish agent stream. The loop above always appends a step before it can
	// break, so indexing the last step here is safe.
	agentResult := &AgentResult{
		Steps:      steps,
		Response:   steps[len(steps)-1].Response,
		TotalUsage: totalUsage,
	}

	if opts.OnFinish != nil {
		opts.OnFinish(agentResult)
	}

	if opts.OnAgentFinish != nil {
		opts.OnAgentFinish(agentResult)
	}

	return agentResult, nil
}
func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
@@ -776,6 +996,203 @@ func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
}
}
// processStepStream consumes one step's stream, forwarding every part to the
// matching opts callback, accumulating text/reasoning/tool-call/source
// content, and executing any completed tool calls afterwards. It returns the
// assembled StepResult, whether the agent loop should continue (the step
// ended in tool calls), and any stream or tool-execution error.
func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
	var stepContent []Content
	var stepToolCalls []ToolCallContent
	var stepUsage Usage
	var stepFinishReason FinishReason = FinishReasonUnknown
	var stepWarnings []CallWarning
	var stepProviderMetadata ProviderMetadata

	// In-flight accumulators, keyed by stream-part ID.
	// NOTE(review): text and reasoning deltas share activeTextContent; this
	// assumes a provider never reuses the same ID for a concurrent text and
	// reasoning block — confirm.
	activeToolCalls := make(map[string]*ToolCallContent)
	activeTextContent := make(map[string]string)

	// Process stream parts
	for part := range stream {
		// Forward all parts to chunk callback
		if opts.OnChunk != nil {
			opts.OnChunk(part)
		}

		switch part.Type {
		case StreamPartTypeWarnings:
			stepWarnings = part.Warnings
			if opts.OnWarnings != nil {
				opts.OnWarnings(part.Warnings)
			}

		case StreamPartTypeTextStart:
			activeTextContent[part.ID] = ""
			if opts.OnTextStart != nil {
				opts.OnTextStart(part.ID)
			}

		case StreamPartTypeTextDelta:
			// Deltas for IDs that never saw a start are ignored.
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnTextDelta != nil {
				opts.OnTextDelta(part.ID, part.Delta)
			}

		case StreamPartTypeTextEnd:
			// Flush the accumulated text into step content.
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, TextContent{
					Text:             text,
					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
				})
				delete(activeTextContent, part.ID)
			}
			if opts.OnTextEnd != nil {
				opts.OnTextEnd(part.ID)
			}

		case StreamPartTypeReasoningStart:
			activeTextContent[part.ID] = ""
			if opts.OnReasoningStart != nil {
				opts.OnReasoningStart(part.ID)
			}

		case StreamPartTypeReasoningDelta:
			if _, exists := activeTextContent[part.ID]; exists {
				activeTextContent[part.ID] += part.Delta
			}
			if opts.OnReasoningDelta != nil {
				opts.OnReasoningDelta(part.ID, part.Delta)
			}

		case StreamPartTypeReasoningEnd:
			if text, exists := activeTextContent[part.ID]; exists {
				stepContent = append(stepContent, ReasoningContent{
					Text:             text,
					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
				})
				delete(activeTextContent, part.ID)
			}
			if opts.OnReasoningEnd != nil {
				opts.OnReasoningEnd(part.ID)
			}

		case StreamPartTypeToolInputStart:
			activeToolCalls[part.ID] = &ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            "",
				ProviderExecuted: part.ProviderExecuted,
			}
			if opts.OnToolInputStart != nil {
				opts.OnToolInputStart(part.ID, part.ToolCallName)
			}

		case StreamPartTypeToolInputDelta:
			if toolCall, exists := activeToolCalls[part.ID]; exists {
				toolCall.Input += part.Delta
			}
			if opts.OnToolInputDelta != nil {
				opts.OnToolInputDelta(part.ID, part.Delta)
			}

		case StreamPartTypeToolInputEnd:
			if opts.OnToolInputEnd != nil {
				opts.OnToolInputEnd(part.ID)
			}

		case StreamPartTypeToolCall:
			// The complete tool-call part carries the full input, so the
			// incrementally accumulated entry is discarded below.
			toolCall := ToolCallContent{
				ToolCallID:       part.ID,
				ToolName:         part.ToolCallName,
				Input:            part.ToolCallInput,
				ProviderExecuted: part.ProviderExecuted,
				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
			}

			// Validate and potentially repair the tool call
			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
			stepToolCalls = append(stepToolCalls, validatedToolCall)
			stepContent = append(stepContent, validatedToolCall)

			if opts.OnToolCall != nil {
				opts.OnToolCall(validatedToolCall)
			}

			// Clean up active tool call
			delete(activeToolCalls, part.ID)

		case StreamPartTypeSource:
			sourceContent := SourceContent{
				SourceType:       part.SourceType,
				ID:               part.ID,
				URL:              part.URL,
				Title:            part.Title,
				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
			}
			stepContent = append(stepContent, sourceContent)
			if opts.OnSource != nil {
				opts.OnSource(sourceContent)
			}

		case StreamPartTypeFinish:
			stepUsage = part.Usage
			stepFinishReason = part.FinishReason
			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
			if opts.OnStreamFinish != nil {
				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
			}

		case StreamPartTypeError:
			// Both the stream-level and agent-level error callbacks fire
			// before the error aborts the step.
			if opts.OnStreamError != nil {
				opts.OnStreamError(part.Error)
			}
			if opts.OnError != nil {
				opts.OnError(part.Error)
			}
			return StepResult{}, false, part.Error
		}
	}

	// Execute tools if any
	var toolResults []ToolResultContent
	if len(stepToolCalls) > 0 {
		var err error
		// opts.OnToolResult fires per result from inside executeTools.
		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
		if err != nil {
			return StepResult{}, false, err
		}
		// Add tool results to content
		for _, result := range toolResults {
			stepContent = append(stepContent, result)
		}
	}

	stepResult := StepResult{
		Response: Response{
			Content:          stepContent,
			FinishReason:     stepFinishReason,
			Usage:            stepUsage,
			Warnings:         stepWarnings,
			ProviderMetadata: stepProviderMetadata,
		},
		Messages: toResponseMessages(stepContent),
	}

	// Determine if we should continue (has tool calls and not stopped)
	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls

	return stepResult, shouldContinue, nil
}
+
+func addUsage(a, b Usage) Usage {
+ return Usage{
+ InputTokens: a.InputTokens + b.InputTokens,
+ OutputTokens: a.OutputTokens + b.OutputTokens,
+ TotalTokens: a.TotalTokens + b.TotalTokens,
+ ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens,
+ CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
+ CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens,
+ }
+}
+
func WithHeaders(headers map[string]string) agentOption {
return func(s *AgentSettings) {
s.headers = headers
@@ -0,0 +1,569 @@
+package ai
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/stretchr/testify/require"
+)
+
+// EchoTool is a simple tool that echoes back the input message
+type EchoTool struct{}
+
+// Info returns the tool information
+func (e *EchoTool) Info() tools.ToolInfo {
+ return tools.ToolInfo{
+ Name: "echo",
+ Description: "Echo back the provided message",
+ Parameters: map[string]any{
+ "message": map[string]any{
+ "type": "string",
+ "description": "The message to echo back",
+ },
+ },
+ Required: []string{"message"},
+ }
+}
+
+// Run executes the echo tool
+func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+ var input struct {
+ Message string `json:"message"`
+ }
+
+ if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
+ return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
+ }
+
+ if input.Message == "" {
+ return tools.NewTextErrorResponse("Message cannot be empty"), nil
+ }
+
+ return tools.NewTextResponse("Echo: " + input.Message), nil
+}
+
// TestStreamingAgentCallbacks tests that all streaming callbacks are called
// correctly: the mock model emits one part of every type (except a complete
// tool call), and the test asserts which callbacks fired and which did not.
func TestStreamingAgentCallbacks(t *testing.T) {
	t.Parallel()

	// Track which callbacks were called
	callbacks := make(map[string]bool)

	// Create a mock language model that returns various stream parts
	mockModel := &mockLanguageModel{
		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
			return func(yield func(StreamPart) bool) {
				// Test all stream part types. Note: tool input is started but
				// never completed with a StreamPartTypeToolCall, so OnToolCall
				// and OnToolResult must not fire.
				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
					return
				}
				yield(StreamPart{
					Type:         StreamPartTypeFinish,
					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
					FinishReason: FinishReasonStop,
				})
			}, nil
		},
	}

	// Create agent
	agent := NewAgent(mockModel)

	ctx := context.Background()

	// Create streaming call with all callbacks; each one records its name.
	streamCall := AgentStreamCall{
		Prompt: "Test all callbacks",
		OnAgentStart: func() {
			callbacks["OnAgentStart"] = true
		},
		OnAgentFinish: func(result *AgentResult) {
			callbacks["OnAgentFinish"] = true
		},
		OnStepStart: func(stepNumber int) {
			callbacks["OnStepStart"] = true
		},
		OnStepFinish: func(stepResult StepResult) {
			callbacks["OnStepFinish"] = true
		},
		OnFinish: func(result *AgentResult) {
			callbacks["OnFinish"] = true
		},
		OnError: func(err error) {
			callbacks["OnError"] = true
		},
		OnChunk: func(part StreamPart) {
			callbacks["OnChunk"] = true
		},
		OnWarnings: func(warnings []CallWarning) {
			callbacks["OnWarnings"] = true
		},
		OnTextStart: func(id string) {
			callbacks["OnTextStart"] = true
		},
		OnTextDelta: func(id, text string) {
			callbacks["OnTextDelta"] = true
		},
		OnTextEnd: func(id string) {
			callbacks["OnTextEnd"] = true
		},
		OnReasoningStart: func(id string) {
			callbacks["OnReasoningStart"] = true
		},
		OnReasoningDelta: func(id, text string) {
			callbacks["OnReasoningDelta"] = true
		},
		OnReasoningEnd: func(id string) {
			callbacks["OnReasoningEnd"] = true
		},
		OnToolInputStart: func(id, toolName string) {
			callbacks["OnToolInputStart"] = true
		},
		OnToolInputDelta: func(id, delta string) {
			callbacks["OnToolInputDelta"] = true
		},
		OnToolInputEnd: func(id string) {
			callbacks["OnToolInputEnd"] = true
		},
		OnToolCall: func(toolCall ToolCallContent) {
			callbacks["OnToolCall"] = true
		},
		OnToolResult: func(result ToolResultContent) {
			callbacks["OnToolResult"] = true
		},
		OnSource: func(source SourceContent) {
			callbacks["OnSource"] = true
		},
		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
			callbacks["OnStreamFinish"] = true
		},
		OnStreamError: func(err error) {
			callbacks["OnStreamError"] = true
		},
	}

	// Execute streaming agent
	result, err := agent.Stream(ctx, streamCall)
	require.NoError(t, err)
	require.NotNil(t, result)

	// Verify that expected callbacks were called
	expectedCallbacks := []string{
		"OnAgentStart",
		"OnAgentFinish",
		"OnStepStart",
		"OnStepFinish",
		"OnFinish",
		"OnChunk",
		"OnWarnings",
		"OnTextStart",
		"OnTextDelta",
		"OnTextEnd",
		"OnReasoningStart",
		"OnReasoningDelta",
		"OnReasoningEnd",
		"OnToolInputStart",
		"OnToolInputDelta",
		"OnToolInputEnd",
		"OnSource",
		"OnStreamFinish",
	}

	for _, callback := range expectedCallbacks {
		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
	}

	// Verify that error callbacks were not called
	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
	require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
}
+
// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors
// TS test patterns): step 1 streams an echo tool call and finishes with
// FinishReasonToolCalls, step 2 streams final text, and the test checks the
// tool-related callbacks plus the per-step contents of the result.
func TestStreamingAgentWithTools(t *testing.T) {
	t.Parallel()

	stepCount := 0
	// Create a mock language model that makes a tool call then finishes
	mockModel := &mockLanguageModel{
		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
			stepCount++
			return func(yield func(StreamPart) bool) {
				if stepCount == 1 {
					// First step: make tool call
					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
						return
					}
					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
						return
					}
					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
						return
					}
					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
						return
					}
					if !yield(StreamPart{
						Type:          StreamPartTypeToolCall,
						ID:            "tool-1",
						ToolCallName:  "echo",
						ToolCallInput: `{"message": "test"}`,
					}) {
						return
					}
					// FinishReasonToolCalls makes the agent loop run a second step.
					yield(StreamPart{
						Type:         StreamPartTypeFinish,
						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
						FinishReason: FinishReasonToolCalls,
					})
				} else {
					// Second step: finish after tool execution
					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
						return
					}
					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
						return
					}
					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
						return
					}
					yield(StreamPart{
						Type:         StreamPartTypeFinish,
						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
						FinishReason: FinishReasonStop,
					})
				}
			}, nil
		},
	}

	// Create agent with echo tool
	agent := NewAgent(
		mockModel,
		WithSystemPrompt("You are a helpful assistant."),
		WithTools(&EchoTool{}),
	)

	ctx := context.Background()

	// Track callback invocations
	var toolInputStartCalled bool
	var toolInputDeltaCalled bool
	var toolInputEndCalled bool
	var toolCallCalled bool
	var toolResultCalled bool

	// Create streaming call with callbacks; each asserts the IDs/payloads it
	// receives in addition to recording that it ran.
	streamCall := AgentStreamCall{
		Prompt: "Echo 'test'",
		OnToolInputStart: func(id, toolName string) {
			toolInputStartCalled = true
			require.Equal(t, "tool-1", id)
			require.Equal(t, "echo", toolName)
		},
		OnToolInputDelta: func(id, delta string) {
			toolInputDeltaCalled = true
			require.Equal(t, "tool-1", id)
			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
		},
		OnToolInputEnd: func(id string) {
			toolInputEndCalled = true
			require.Equal(t, "tool-1", id)
		},
		OnToolCall: func(toolCall ToolCallContent) {
			toolCallCalled = true
			require.Equal(t, "echo", toolCall.ToolName)
			require.Equal(t, `{"message": "test"}`, toolCall.Input)
		},
		OnToolResult: func(result ToolResultContent) {
			toolResultCalled = true
			require.Equal(t, "echo", result.ToolName)
		},
	}

	// Execute streaming agent
	result, err := agent.Stream(ctx, streamCall)
	require.NoError(t, err)

	// Verify results
	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
	require.True(t, toolCallCalled, "OnToolCall should have been called")
	require.True(t, toolResultCalled, "OnToolResult should have been called")
	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response

	// Check that tool was executed in first step
	firstStep := result.Steps[0]
	toolCalls := firstStep.Content.ToolCalls()
	require.Equal(t, 1, len(toolCalls))
	require.Equal(t, "echo", toolCalls[0].ToolName)

	toolResults := firstStep.Content.ToolResults()
	require.Equal(t, 1, len(toolResults))
	require.Equal(t, "echo", toolResults[0].ToolName)
}
+
+// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
+func TestStreamingAgentTextDeltas(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock language model that returns text deltas
+ mockModel := &mockLanguageModel{
+ streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+ return func(yield func(StreamPart) bool) {
+ if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
+ return
+ }
+ if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+ return
+ }
+ yield(StreamPart{
+ Type: StreamPartTypeFinish,
+ Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
+ FinishReason: FinishReasonStop,
+ })
+ }, nil
+ },
+ }
+
+ agent := NewAgent(mockModel)
+ ctx := context.Background()
+
+ // Track text deltas
+ var textDeltas []string
+
+ streamCall := AgentStreamCall{
+ Prompt: "Say hello",
+ OnTextDelta: func(id, text string) {
+ if text != "" {
+ textDeltas = append(textDeltas, text)
+ }
+ },
+ }
+
+ result, err := agent.Stream(ctx, streamCall)
+ require.NoError(t, err)
+
+ // Verify text deltas match expected pattern
+ require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
+ require.Equal(t, "Hello, world!", result.Response.Content.Text())
+ require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
+}
+
// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning
// tests): the mock model emits a reasoning block followed by a text block,
// and the test verifies the two are accumulated separately in the result.
func TestStreamingAgentReasoning(t *testing.T) {
	t.Parallel()

	mockModel := &mockLanguageModel{
		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
			return func(yield func(StreamPart) bool) {
				// Reasoning block first...
				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
					return
				}
				// ...then the visible text reply.
				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
					return
				}
				yield(StreamPart{
					Type:         StreamPartTypeFinish,
					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
					FinishReason: FinishReasonStop,
				})
			}, nil
		},
	}

	agent := NewAgent(mockModel)
	ctx := context.Background()

	var reasoningDeltas []string
	var textDeltas []string

	streamCall := AgentStreamCall{
		Prompt: "Think and respond",
		OnReasoningDelta: func(id, text string) {
			reasoningDeltas = append(reasoningDeltas, text)
		},
		OnTextDelta: func(id, text string) {
			textDeltas = append(textDeltas, text)
		},
	}

	result, err := agent.Stream(ctx, streamCall)
	require.NoError(t, err)

	// Verify reasoning and text are separate
	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
	require.Equal(t, []string{"Hi there!"}, textDeltas)
	require.Equal(t, "Hi there!", result.Response.Content.Text())
	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
}
+
+// TestStreamingAgentError tests error handling (mirrors TS error tests)
+func TestStreamingAgentError(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock language model that returns an error
+ mockModel := &mockLanguageModel{
+ streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+ return func(yield func(StreamPart) bool) {
+ yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
+ }, nil
+ },
+ }
+
+ agent := NewAgent(mockModel)
+ ctx := context.Background()
+
+ // Track error callbacks
+ var streamErrorOccurred bool
+ var errorOccurred bool
+ var errorMessage string
+
+ streamCall := AgentStreamCall{
+ Prompt: "This will fail",
+ OnStreamError: func(err error) {
+ streamErrorOccurred = true
+ },
+ OnError: func(err error) {
+ errorOccurred = true
+ errorMessage = err.Error()
+ },
+ }
+
+ // Execute streaming agent
+ result, err := agent.Stream(ctx, streamCall)
+ require.Error(t, err)
+ require.Nil(t, result)
+ require.True(t, streamErrorOccurred, "OnStreamError should have been called")
+ require.True(t, errorOccurred, "OnError should have been called")
+ require.Contains(t, errorMessage, "mock stream error")
+}
+
// TestStreamingAgentSources tests source handling (mirrors TS source tests):
// two source parts (a URL source before the text, a document source after)
// must reach the OnSource callback in order and appear in the final content.
func TestStreamingAgentSources(t *testing.T) {
	t.Parallel()

	mockModel := &mockLanguageModel{
		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
			return func(yield func(StreamPart) bool) {
				if !yield(StreamPart{
					Type:       StreamPartTypeSource,
					ID:         "source-1",
					SourceType: SourceTypeURL,
					URL:        "https://example.com",
					Title:      "Example",
				}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
					return
				}
				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
					return
				}
				if !yield(StreamPart{
					Type:       StreamPartTypeSource,
					ID:         "source-2",
					SourceType: SourceTypeDocument,
					Title:      "Document Example",
				}) {
					return
				}
				yield(StreamPart{
					Type:         StreamPartTypeFinish,
					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
					FinishReason: FinishReasonStop,
				})
			}, nil
		},
	}

	agent := NewAgent(mockModel)
	ctx := context.Background()

	var sources []SourceContent

	streamCall := AgentStreamCall{
		Prompt: "Search and respond",
		OnSource: func(source SourceContent) {
			sources = append(sources, source)
		},
	}

	result, err := agent.Stream(ctx, streamCall)
	require.NoError(t, err)

	// Verify sources were captured
	require.Equal(t, 2, len(sources))
	require.Equal(t, SourceTypeURL, sources[0].SourceType)
	require.Equal(t, "https://example.com", sources[0].URL)
	require.Equal(t, "Example", sources[0].Title)
	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
	require.Equal(t, "Document Example", sources[1].Title)

	// Verify sources are in final result
	resultSources := result.Response.Content.Sources()
	require.Equal(t, 2, len(resultSources))
}
@@ -39,6 +39,7 @@ func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResp
// Mock language model for testing. Each behavior is supplied as an optional
// function field; a nil field means that method reports "not implemented".
type mockLanguageModel struct {
	generateFunc func(ctx context.Context, call Call) (*Response, error)        // stub backing Generate
	streamFunc   func(ctx context.Context, call Call) (StreamResponse, error)   // stub backing Stream
}
func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
@@ -59,7 +60,10 @@ func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response,
}
func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
- panic("not implemented")
+ if m.streamFunc != nil {
+ return m.streamFunc(ctx, call)
+ }
+ return nil, fmt.Errorf("mock stream not implemented")
}
func (m *mockLanguageModel) Provider() string {
@@ -0,0 +1,95 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/charmbracelet/crush/internal/ai/providers"
+ "github.com/charmbracelet/crush/internal/llm/tools"
+)
+
// EchoTool is a minimal demo tool that echoes its input back to the model.
type EchoTool struct{}
+
+func (e *EchoTool) Info() tools.ToolInfo {
+ return tools.ToolInfo{
+ Name: "echo",
+ Description: "Echo back the provided message",
+ Parameters: map[string]any{
+ "message": map[string]any{
+ "type": "string",
+ "description": "The message to echo back",
+ },
+ },
+ Required: []string{"message"},
+ }
+}
+
+func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+ return tools.NewTextResponse("Echo: " + params.Input), nil
+}
+
+func main() {
+ // Check for API key
+ apiKey := os.Getenv("OPENAI_API_KEY")
+ if apiKey == "" {
+ fmt.Println("Please set OPENAI_API_KEY environment variable")
+ os.Exit(1)
+ }
+
+ // Create provider and model
+ provider := providers.NewOpenAIProvider(
+ providers.WithOpenAIApiKey(apiKey),
+ )
+ model := provider.LanguageModel("gpt-4o-mini")
+
+ // Create streaming agent
+ agent := ai.NewAgent(
+ model,
+ ai.WithSystemPrompt("You are a helpful assistant."),
+ ai.WithTools(&EchoTool{}),
+ )
+
+ ctx := context.Background()
+
+ fmt.Println("Simple Streaming Agent Example")
+ fmt.Println("==============================")
+ fmt.Println()
+
+ // Basic streaming with key callbacks
+ streamCall := ai.AgentStreamCall{
+ Prompt: "Please echo back 'Hello, streaming world!'",
+
+ // Show real-time text as it streams
+ OnTextDelta: func(id, text string) {
+ fmt.Print(text)
+ },
+
+ // Show when tools are called
+ OnToolCall: func(toolCall ai.ToolCallContent) {
+ fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName)
+ },
+
+ // Show tool results
+ OnToolResult: func(result ai.ToolResultContent) {
+ fmt.Printf("[Tool result received]\n")
+ },
+
+ // Show when each step completes
+ OnStepFinish: func(step ai.StepResult) {
+ fmt.Printf("\n[Step completed: %s]\n", step.FinishReason)
+ },
+ }
+
+ fmt.Println("Assistant response:")
+ result, err := agent.Stream(ctx, streamCall)
+ if err != nil {
+ fmt.Printf("Error: %v\n", err)
+ os.Exit(1)
+ }
+
+ fmt.Printf("\n\nFinal result: %s\n", result.Response.Content.Text())
+ fmt.Printf("Steps: %d, Total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
+}
@@ -0,0 +1,274 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/charmbracelet/crush/internal/ai"
+ "github.com/charmbracelet/crush/internal/ai/providers"
+ "github.com/charmbracelet/crush/internal/llm/tools"
+)
+
// WeatherTool is a demo tool that simulates a weather lookup with canned
// responses for a few known cities.
type WeatherTool struct{}
+
+func (w *WeatherTool) Info() tools.ToolInfo {
+ return tools.ToolInfo{
+ Name: "get_weather",
+ Description: "Get the current weather for a specific location",
+ Parameters: map[string]any{
+ "location": map[string]any{
+ "type": "string",
+ "description": "The city and country, e.g. 'London, UK'",
+ },
+ "unit": map[string]any{
+ "type": "string",
+ "description": "Temperature unit (celsius or fahrenheit)",
+ "enum": []string{"celsius", "fahrenheit"},
+ "default": "celsius",
+ },
+ },
+ Required: []string{"location"},
+ }
+}
+
+func (w *WeatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+ // Simulate weather lookup with some fake data
+ location := "Unknown"
+ if strings.Contains(params.Input, "pristina") || strings.Contains(params.Input, "Pristina") {
+ location = "Pristina, Kosovo"
+ } else if strings.Contains(params.Input, "london") || strings.Contains(params.Input, "London") {
+ location = "London, UK"
+ } else if strings.Contains(params.Input, "new york") || strings.Contains(params.Input, "New York") {
+ location = "New York, USA"
+ }
+
+ unit := "celsius"
+ if strings.Contains(params.Input, "fahrenheit") {
+ unit = "fahrenheit"
+ }
+
+ var temp string
+ if unit == "fahrenheit" {
+ temp = "72°F"
+ } else {
+ temp = "22°C"
+ }
+
+ weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
+ return tools.NewTextResponse(weather), nil
+}
+
// CalculatorTool is a second demo tool (for multi-tool scenarios) that
// recognizes a tiny fixed set of arithmetic expressions.
type CalculatorTool struct{}
+
+func (c *CalculatorTool) Info() tools.ToolInfo {
+ return tools.ToolInfo{
+ Name: "calculate",
+ Description: "Perform basic mathematical calculations",
+ Parameters: map[string]any{
+ "expression": map[string]any{
+ "type": "string",
+ "description": "Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')",
+ },
+ },
+ Required: []string{"expression"},
+ }
+}
+
+func (c *CalculatorTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+ // Simple calculator simulation
+ expr := strings.ReplaceAll(params.Input, "\"", "")
+ if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
+ return tools.NewTextResponse("2 + 2 = 4"), nil
+ } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
+ return tools.NewTextResponse("10 * 5 = 50"), nil
+ }
+ return tools.NewTextResponse("I can calculate simple expressions like '2 + 2' or '10 * 5'"), nil
+}
+
+func main() {
+ // Check for API key
+ apiKey := os.Getenv("OPENAI_API_KEY")
+ if apiKey == "" {
+ fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
+ fmt.Println(" export OPENAI_API_KEY=your_api_key_here")
+ os.Exit(1)
+ }
+
+ fmt.Println("🚀 Streaming Agent Example")
+ fmt.Println("==========================")
+ fmt.Println()
+
+ // Create OpenAI provider and model
+ provider := providers.NewOpenAIProvider(
+ providers.WithOpenAIApiKey(apiKey),
+ )
+ model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
+
+ // Create agent with tools
+ agent := ai.NewAgent(
+ model,
+ ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
+ ai.WithTools(&WeatherTool{}, &CalculatorTool{}),
+ )
+
+ ctx := context.Background()
+
+ // Demonstrate streaming with comprehensive callbacks
+ fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
+ fmt.Println()
+
+ // Track streaming events
+ var stepCount int
+ var textBuffer strings.Builder
+ var reasoningBuffer strings.Builder
+
+ // Create streaming call with all callbacks
+ streamCall := ai.AgentStreamCall{
+ Prompt: "What's the weather in Pristina and what's 2 + 2?",
+
+ // Agent-level callbacks
+ OnAgentStart: func() {
+ fmt.Println("🎬 Agent started")
+ },
+ OnAgentFinish: func(result *ai.AgentResult) {
+ fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
+ },
+ OnStepStart: func(stepNumber int) {
+ stepCount++
+ fmt.Printf("📝 Step %d started\n", stepNumber+1)
+ },
+ OnStepFinish: func(stepResult ai.StepResult) {
+ fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
+ },
+ OnFinish: func(result *ai.AgentResult) {
+ fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
+ },
+ OnError: func(err error) {
+ fmt.Printf("❌ Error: %v\n", err)
+ },
+
+ // Stream part callbacks
+ OnWarnings: func(warnings []ai.CallWarning) {
+ for _, warning := range warnings {
+ fmt.Printf("⚠️ Warning: %s\n", warning.Message)
+ }
+ },
+ OnTextStart: func(id string) {
+ fmt.Print("💭 Assistant: ")
+ },
+ OnTextDelta: func(id, text string) {
+ fmt.Print(text)
+ textBuffer.WriteString(text)
+ },
+ OnTextEnd: func(id string) {
+ fmt.Println()
+ },
+ OnReasoningStart: func(id string) {
+ fmt.Print("🤔 Thinking: ")
+ },
+ OnReasoningDelta: func(id, text string) {
+ reasoningBuffer.WriteString(text)
+ },
+ OnReasoningEnd: func(id string) {
+ if reasoningBuffer.Len() > 0 {
+ fmt.Printf("%s\n", reasoningBuffer.String())
+ reasoningBuffer.Reset()
+ }
+ },
+ OnToolInputStart: func(id, toolName string) {
+ fmt.Printf("🔧 Calling tool: %s\n", toolName)
+ },
+ OnToolInputDelta: func(id, delta string) {
+ // Could show tool input being built, but it's often noisy
+ },
+ OnToolInputEnd: func(id string) {
+ // Tool input complete
+ },
+ OnToolCall: func(toolCall ai.ToolCallContent) {
+ fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName)
+ fmt.Printf(" Input: %s\n", toolCall.Input)
+ },
+ OnToolResult: func(result ai.ToolResultContent) {
+ fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
+ switch output := result.Result.(type) {
+ case ai.ToolResultOutputContentText:
+ fmt.Printf(" %s\n", output.Text)
+ case ai.ToolResultOutputContentError:
+ fmt.Printf(" Error: %s\n", output.Error.Error())
+ }
+ },
+ OnSource: func(source ai.SourceContent) {
+ fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
+ },
+ OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
+ fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
+ },
+ OnStreamError: func(err error) {
+ fmt.Printf("💥 Stream error: %v\n", err)
+ },
+ }
+
+ // Execute streaming agent
+ result, err := agent.Stream(ctx, streamCall)
+ if err != nil {
+ fmt.Printf("❌ Agent failed: %v\n", err)
+ os.Exit(1)
+ }
+
+ // Display final results
+ fmt.Println()
+ fmt.Println("📋 Final Summary")
+ fmt.Println("================")
+ fmt.Printf("Steps executed: %d\n", len(result.Steps))
+ fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n",
+ result.TotalUsage.TotalTokens,
+ result.TotalUsage.InputTokens,
+ result.TotalUsage.OutputTokens)
+
+ if result.TotalUsage.ReasoningTokens > 0 {
+ fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
+ }
+
+ fmt.Printf("Final response: %s\n", result.Response.Content.Text())
+
+ // Show step details
+ fmt.Println()
+ fmt.Println("🔍 Step Details")
+ fmt.Println("===============")
+ for i, step := range result.Steps {
+ fmt.Printf("Step %d:\n", i+1)
+ fmt.Printf(" Finish reason: %s\n", step.FinishReason)
+ fmt.Printf(" Content types: ")
+
+ var contentTypes []string
+ for _, content := range step.Content {
+ contentTypes = append(contentTypes, string(content.GetType()))
+ }
+ fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
+
+ // Show tool calls and results
+ toolCalls := step.Content.ToolCalls()
+ if len(toolCalls) > 0 {
+ fmt.Printf(" Tool calls: ")
+ var toolNames []string
+ for _, tc := range toolCalls {
+ toolNames = append(toolNames, tc.ToolName)
+ }
+ fmt.Printf("%s\n", strings.Join(toolNames, ", "))
+ }
+
+ toolResults := step.Content.ToolResults()
+ if len(toolResults) > 0 {
+ fmt.Printf(" Tool results: %d\n", len(toolResults))
+ }
+
+ fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens)
+ fmt.Println()
+ }
+
+ fmt.Println("✨ Example completed successfully!")
+}
@@ -1134,17 +1134,17 @@ func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []op
var annotations []openai.ChatCompletionMessageAnnotation
// Parse the raw JSON to extract annotations
- var deltaData map[string]interface{}
+ var deltaData map[string]any
if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
return annotations
}
// Check if annotations exist in the delta
- if annotationsData, ok := deltaData["annotations"].([]interface{}); ok {
+ if annotationsData, ok := deltaData["annotations"].([]any); ok {
for _, annotationData := range annotationsData {
- if annotationMap, ok := annotationData.(map[string]interface{}); ok {
+ if annotationMap, ok := annotationData.(map[string]any); ok {
if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
- if urlCitationData, ok := annotationMap["url_citation"].(map[string]interface{}); ok {
+ if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
annotation := openai.ChatCompletionMessageAnnotation{
Type: "url_citation",
URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{