feat: streaming agent

Kujtim Hoxha created

Change summary

agent.go                                          | 429 ++++++++++++
agent_stream_test.go                              | 569 +++++++++++++++++
agent_test.go                                     |   6 
providers/examples/streaming-agent-simple/main.go |  95 ++
providers/examples/streaming-agent/main.go        | 274 ++++++++
providers/openai.go                               |   8 
6 files changed, 1,370 insertions(+), 11 deletions(-)

Detailed changes

agent.go 🔗

@@ -157,6 +157,53 @@ type AgentCall struct {
 	OnStepFinished OnStepFinishedFunction
 }
 
+type AgentStreamCall struct {
+	Prompt           string     `json:"prompt"`
+	Files            []FilePart `json:"files"`
+	Messages         []Message  `json:"messages"`
+	MaxOutputTokens  *int64
+	Temperature      *float64 `json:"temperature"`
+	TopP             *float64 `json:"top_p"`
+	TopK             *int64   `json:"top_k"`
+	PresencePenalty  *float64 `json:"presence_penalty"`
+	FrequencyPenalty *float64 `json:"frequency_penalty"`
+	ActiveTools      []string `json:"active_tools"`
+	Headers          map[string]string
+	ProviderOptions  ProviderOptions
+	OnRetry          OnRetryCallback
+	MaxRetries       *int
+
+	StopWhen       []StopCondition
+	PrepareStep    PrepareStepFunction
+	RepairToolCall RepairToolCallFunction
+
+	// Agent-level callbacks
+	OnAgentStart  func()                      // Called when agent starts
+	OnAgentFinish func(result *AgentResult)   // Called when agent finishes
+	OnStepStart   func(stepNumber int)        // Called when a step starts
+	OnStepFinish  func(stepResult StepResult) // Called when a step finishes
+	OnFinish      func(result *AgentResult)   // Called when entire agent completes
+	OnError       func(error)                 // Called when an error occurs
+
+	// Stream part callbacks - called for each corresponding stream part type
+	OnChunk          func(StreamPart)                                                               // Called for each stream part (catch-all)
+	OnWarnings       func(warnings []CallWarning)                                                   // Called for warnings
+	OnTextStart      func(id string)                                                                // Called when text starts
+	OnTextDelta      func(id, text string)                                                          // Called for text deltas
+	OnTextEnd        func(id string)                                                                // Called when text ends
+	OnReasoningStart func(id string)                                                                // Called when reasoning starts
+	OnReasoningDelta func(id, text string)                                                          // Called for reasoning deltas
+	OnReasoningEnd   func(id string)                                                                // Called when reasoning ends
+	OnToolInputStart func(id, toolName string)                                                      // Called when tool input starts
+	OnToolInputDelta func(id, delta string)                                                         // Called for tool input deltas
+	OnToolInputEnd   func(id string)                                                                // Called when tool input ends
+	OnToolCall       func(toolCall ToolCallContent)                                                 // Called when tool call is complete
+	OnToolResult     func(result ToolResultContent)                                                 // Called when tool execution completes
+	OnSource         func(source SourceContent)                                                     // Called for source references
+	OnStreamFinish   func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes
+	OnStreamError    func(error)                                                                    // Called when stream error occurs
+}
+
 type AgentResult struct {
 	Steps []StepResult
 	// Final response
@@ -166,7 +213,7 @@ type AgentResult struct {
 
 type Agent interface {
 	Generate(context.Context, AgentCall) (*AgentResult, error)
-	Stream(context.Context, AgentCall) (StreamResponse, error)
+	Stream(context.Context, AgentStreamCall) (*AgentResult, error)
 }
 
 type agentOption = func(*AgentSettings)
@@ -343,7 +390,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
 			}
 		}
 
-		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls)
+		toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil)
 
 		// Build step content with validated tool calls and tool results
 		stepContent := []Content{}
@@ -501,7 +548,7 @@ func toResponseMessages(content []Content) []Message {
 	return messages
 }
 
-func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) {
+func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) {
 	if len(toolCalls) == 0 {
 		return nil, nil
 	}
@@ -532,6 +579,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					},
 					ProviderExecuted: false,
 				}
+				if toolResultCallback != nil {
+					toolResultCallback(results[index])
+				}
+
 				return
 			}
 
@@ -545,6 +596,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					},
 					ProviderExecuted: false,
 				}
+
+				if toolResultCallback != nil {
+					toolResultCallback(results[index])
+				}
 				return
 			}
 
@@ -563,6 +618,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					},
 					ProviderExecuted: false,
 				}
+				if toolResultCallback != nil {
+					toolResultCallback(results[index])
+				}
 				toolExecutionError = err
 				return
 			}
@@ -576,6 +634,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					},
 					ProviderExecuted: false,
 				}
+
+				if toolResultCallback != nil {
+					toolResultCallback(results[index])
+				}
 			} else {
 				results[index] = ToolResultContent{
 					ToolCallID: call.ToolCallID,
@@ -585,6 +647,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 					},
 					ProviderExecuted: false,
 				}
+				if toolResultCallback != nil {
+					toolResultCallback(results[index])
+				}
 			}
 		}(i, toolCall)
 	}
@@ -596,9 +661,164 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too
 }
 
 // Stream implements Agent.
-func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) {
-	// TODO: implement the agentic stuff
-	panic("not implemented")
+func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) {
+	// Convert AgentStreamCall to AgentCall for preparation
+	call := AgentCall{
+		Prompt:           opts.Prompt,
+		Files:            opts.Files,
+		Messages:         opts.Messages,
+		MaxOutputTokens:  opts.MaxOutputTokens,
+		Temperature:      opts.Temperature,
+		TopP:             opts.TopP,
+		TopK:             opts.TopK,
+		PresencePenalty:  opts.PresencePenalty,
+		FrequencyPenalty: opts.FrequencyPenalty,
+		ActiveTools:      opts.ActiveTools,
+		Headers:          opts.Headers,
+		ProviderOptions:  opts.ProviderOptions,
+		MaxRetries:       opts.MaxRetries,
+		StopWhen:         opts.StopWhen,
+		PrepareStep:      opts.PrepareStep,
+		RepairToolCall:   opts.RepairToolCall,
+	}
+
+	call = a.prepareCall(call)
+
+	initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...)
+	if err != nil {
+		return nil, err
+	}
+
+	var responseMessages []Message
+	var steps []StepResult
+	var totalUsage Usage
+
+	// Start agent stream
+	if opts.OnAgentStart != nil {
+		opts.OnAgentStart()
+	}
+
+	for stepNumber := 0; ; stepNumber++ {
+		stepInputMessages := append(initialPrompt, responseMessages...)
+		stepModel := a.settings.model
+		stepSystemPrompt := a.settings.systemPrompt
+		stepActiveTools := call.ActiveTools
+		stepToolChoice := ToolChoiceAuto
+		disableAllTools := false
+
+		// Apply step preparation if provided
+		if call.PrepareStep != nil {
+			prepared := call.PrepareStep(PrepareStepFunctionOptions{
+				Model:      stepModel,
+				Steps:      steps,
+				StepNumber: stepNumber,
+				Messages:   stepInputMessages,
+			})
+
+			if prepared.Messages != nil {
+				stepInputMessages = prepared.Messages
+			}
+			if prepared.Model != nil {
+				stepModel = prepared.Model
+			}
+			if prepared.System != nil {
+				stepSystemPrompt = *prepared.System
+			}
+			if prepared.ToolChoice != nil {
+				stepToolChoice = *prepared.ToolChoice
+			}
+			if len(prepared.ActiveTools) > 0 {
+				stepActiveTools = prepared.ActiveTools
+			}
+			disableAllTools = prepared.DisableAllTools
+		}
+
+		// Recreate prompt with potentially modified system prompt
+		if stepSystemPrompt != a.settings.systemPrompt {
+			stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...)
+			if err != nil {
+				return nil, err
+			}
+			if len(stepInputMessages) > 0 && len(stepPrompt) > 0 {
+				stepInputMessages[0] = stepPrompt[0]
+			}
+		}
+
+		preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools)
+
+		// Start step stream
+		if opts.OnStepStart != nil {
+			opts.OnStepStart(stepNumber)
+		}
+
+		// Create streaming call
+		streamCall := Call{
+			Prompt:           stepInputMessages,
+			MaxOutputTokens:  call.MaxOutputTokens,
+			Temperature:      call.Temperature,
+			TopP:             call.TopP,
+			TopK:             call.TopK,
+			PresencePenalty:  call.PresencePenalty,
+			FrequencyPenalty: call.FrequencyPenalty,
+			Tools:            preparedTools,
+			ToolChoice:       &stepToolChoice,
+			Headers:          call.Headers,
+			ProviderOptions:  call.ProviderOptions,
+		}
+
+		// Get streaming response
+		stream, err := stepModel.Stream(ctx, streamCall)
+		if err != nil {
+			if opts.OnError != nil {
+				opts.OnError(err)
+			}
+			return nil, err
+		}
+
+		// Process stream with tool execution
+		stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps)
+		if err != nil {
+			if opts.OnError != nil {
+				opts.OnError(err)
+			}
+			return nil, err
+		}
+
+		steps = append(steps, stepResult)
+		totalUsage = addUsage(totalUsage, stepResult.Usage)
+
+		// Call step finished callback
+		if opts.OnStepFinish != nil {
+			opts.OnStepFinish(stepResult)
+		}
+
+		// Add step messages to response messages
+		stepMessages := toResponseMessages(stepResult.Content)
+		responseMessages = append(responseMessages, stepMessages...)
+
+		// Check stop conditions
+		shouldStop := isStopConditionMet(call.StopWhen, steps)
+		if shouldStop || !shouldContinue {
+			break
+		}
+	}
+
+	// Finish agent stream
+	agentResult := &AgentResult{
+		Steps:      steps,
+		Response:   steps[len(steps)-1].Response,
+		TotalUsage: totalUsage,
+	}
+
+	if opts.OnFinish != nil {
+		opts.OnFinish(agentResult)
+	}
+
+	if opts.OnAgentFinish != nil {
+		opts.OnAgentFinish(agentResult)
+	}
+
+	return agentResult, nil
 }
 
 func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool {
@@ -776,6 +996,203 @@ func WithOnStepFinished(fn OnStepFinishedFunction) agentOption {
 	}
 }
 
+// processStepStream processes a single step's stream and returns the step result
+func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) {
+	var stepContent []Content
+	var stepToolCalls []ToolCallContent
+	var stepUsage Usage
+	var stepFinishReason FinishReason = FinishReasonUnknown
+	var stepWarnings []CallWarning
+	var stepProviderMetadata ProviderMetadata
+
+	activeToolCalls := make(map[string]*ToolCallContent)
+	activeTextContent := make(map[string]string)
+
+	// Process stream parts
+	for part := range stream {
+		// Forward all parts to chunk callback
+		if opts.OnChunk != nil {
+			opts.OnChunk(part)
+		}
+
+		switch part.Type {
+		case StreamPartTypeWarnings:
+			stepWarnings = part.Warnings
+			if opts.OnWarnings != nil {
+				opts.OnWarnings(part.Warnings)
+			}
+
+		case StreamPartTypeTextStart:
+			activeTextContent[part.ID] = ""
+			if opts.OnTextStart != nil {
+				opts.OnTextStart(part.ID)
+			}
+
+		case StreamPartTypeTextDelta:
+			if _, exists := activeTextContent[part.ID]; exists {
+				activeTextContent[part.ID] += part.Delta
+			}
+			if opts.OnTextDelta != nil {
+				opts.OnTextDelta(part.ID, part.Delta)
+			}
+
+		case StreamPartTypeTextEnd:
+			if text, exists := activeTextContent[part.ID]; exists {
+				stepContent = append(stepContent, TextContent{
+					Text:             text,
+					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+				})
+				delete(activeTextContent, part.ID)
+			}
+			if opts.OnTextEnd != nil {
+				opts.OnTextEnd(part.ID)
+			}
+
+		case StreamPartTypeReasoningStart:
+			activeTextContent[part.ID] = ""
+			if opts.OnReasoningStart != nil {
+				opts.OnReasoningStart(part.ID)
+			}
+
+		case StreamPartTypeReasoningDelta:
+			if _, exists := activeTextContent[part.ID]; exists {
+				activeTextContent[part.ID] += part.Delta
+			}
+			if opts.OnReasoningDelta != nil {
+				opts.OnReasoningDelta(part.ID, part.Delta)
+			}
+
+		case StreamPartTypeReasoningEnd:
+			if text, exists := activeTextContent[part.ID]; exists {
+				stepContent = append(stepContent, ReasoningContent{
+					Text:             text,
+					ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+				})
+				delete(activeTextContent, part.ID)
+			}
+			if opts.OnReasoningEnd != nil {
+				opts.OnReasoningEnd(part.ID)
+			}
+
+		case StreamPartTypeToolInputStart:
+			activeToolCalls[part.ID] = &ToolCallContent{
+				ToolCallID:       part.ID,
+				ToolName:         part.ToolCallName,
+				Input:            "",
+				ProviderExecuted: part.ProviderExecuted,
+			}
+			if opts.OnToolInputStart != nil {
+				opts.OnToolInputStart(part.ID, part.ToolCallName)
+			}
+
+		case StreamPartTypeToolInputDelta:
+			if toolCall, exists := activeToolCalls[part.ID]; exists {
+				toolCall.Input += part.Delta
+			}
+			if opts.OnToolInputDelta != nil {
+				opts.OnToolInputDelta(part.ID, part.Delta)
+			}
+
+		case StreamPartTypeToolInputEnd:
+			if opts.OnToolInputEnd != nil {
+				opts.OnToolInputEnd(part.ID)
+			}
+
+		case StreamPartTypeToolCall:
+			toolCall := ToolCallContent{
+				ToolCallID:       part.ID,
+				ToolName:         part.ToolCallName,
+				Input:            part.ToolCallInput,
+				ProviderExecuted: part.ProviderExecuted,
+				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+			}
+
+			// Validate and potentially repair the tool call
+			validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall)
+			stepToolCalls = append(stepToolCalls, validatedToolCall)
+			stepContent = append(stepContent, validatedToolCall)
+
+			if opts.OnToolCall != nil {
+				opts.OnToolCall(validatedToolCall)
+			}
+
+			// Clean up active tool call
+			delete(activeToolCalls, part.ID)
+
+		case StreamPartTypeSource:
+			sourceContent := SourceContent{
+				SourceType:       part.SourceType,
+				ID:               part.ID,
+				URL:              part.URL,
+				Title:            part.Title,
+				ProviderMetadata: ProviderMetadata(part.ProviderMetadata),
+			}
+			stepContent = append(stepContent, sourceContent)
+			if opts.OnSource != nil {
+				opts.OnSource(sourceContent)
+			}
+
+		case StreamPartTypeFinish:
+			stepUsage = part.Usage
+			stepFinishReason = part.FinishReason
+			stepProviderMetadata = ProviderMetadata(part.ProviderMetadata)
+			if opts.OnStreamFinish != nil {
+				opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata)
+			}
+
+		case StreamPartTypeError:
+			if opts.OnStreamError != nil {
+				opts.OnStreamError(part.Error)
+			}
+			if opts.OnError != nil {
+				opts.OnError(part.Error)
+			}
+			return StepResult{}, false, part.Error
+		}
+	}
+
+	// Execute tools if any
+	var toolResults []ToolResultContent
+	if len(stepToolCalls) > 0 {
+		var err error
+		toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult)
+		if err != nil {
+			return StepResult{}, false, err
+		}
+		// Add tool results to content
+		for _, result := range toolResults {
+			stepContent = append(stepContent, result)
+		}
+	}
+
+	stepResult := StepResult{
+		Response: Response{
+			Content:          stepContent,
+			FinishReason:     stepFinishReason,
+			Usage:            stepUsage,
+			Warnings:         stepWarnings,
+			ProviderMetadata: stepProviderMetadata,
+		},
+		Messages: toResponseMessages(stepContent),
+	}
+
+	// Determine if we should continue (has tool calls and not stopped)
+	shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls
+
+	return stepResult, shouldContinue, nil
+}
+
+func addUsage(a, b Usage) Usage {
+	return Usage{
+		InputTokens:         a.InputTokens + b.InputTokens,
+		OutputTokens:        a.OutputTokens + b.OutputTokens,
+		TotalTokens:         a.TotalTokens + b.TotalTokens,
+		ReasoningTokens:     a.ReasoningTokens + b.ReasoningTokens,
+		CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens,
+		CacheReadTokens:     a.CacheReadTokens + b.CacheReadTokens,
+	}
+}
+
 func WithHeaders(headers map[string]string) agentOption {
 	return func(s *AgentSettings) {
 		s.headers = headers

agent_stream_test.go 🔗

@@ -0,0 +1,569 @@
+package ai
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/llm/tools"
+	"github.com/stretchr/testify/require"
+)
+
+// EchoTool is a simple tool that echoes back the input message
+type EchoTool struct{}
+
+// Info returns the tool information
+func (e *EchoTool) Info() tools.ToolInfo {
+	return tools.ToolInfo{
+		Name:        "echo",
+		Description: "Echo back the provided message",
+		Parameters: map[string]any{
+			"message": map[string]any{
+				"type":        "string",
+				"description": "The message to echo back",
+			},
+		},
+		Required: []string{"message"},
+	}
+}
+
+// Run executes the echo tool
+func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+	var input struct {
+		Message string `json:"message"`
+	}
+
+	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
+		return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
+	}
+
+	if input.Message == "" {
+		return tools.NewTextErrorResponse("Message cannot be empty"), nil
+	}
+
+	return tools.NewTextResponse("Echo: " + input.Message), nil
+}
+
+// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
+func TestStreamingAgentCallbacks(t *testing.T) {
+	t.Parallel()
+
+	// Track which callbacks were called
+	callbacks := make(map[string]bool)
+
+	// Create a mock language model that returns various stream parts
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			return func(yield func(StreamPart) bool) {
+				// Test all stream part types
+				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
+					return
+				}
+				yield(StreamPart{
+					Type:         StreamPartTypeFinish,
+					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
+					FinishReason: FinishReasonStop,
+				})
+			}, nil
+		},
+	}
+
+	// Create agent
+	agent := NewAgent(mockModel)
+
+	ctx := context.Background()
+
+	// Create streaming call with all callbacks
+	streamCall := AgentStreamCall{
+		Prompt: "Test all callbacks",
+		OnAgentStart: func() {
+			callbacks["OnAgentStart"] = true
+		},
+		OnAgentFinish: func(result *AgentResult) {
+			callbacks["OnAgentFinish"] = true
+		},
+		OnStepStart: func(stepNumber int) {
+			callbacks["OnStepStart"] = true
+		},
+		OnStepFinish: func(stepResult StepResult) {
+			callbacks["OnStepFinish"] = true
+		},
+		OnFinish: func(result *AgentResult) {
+			callbacks["OnFinish"] = true
+		},
+		OnError: func(err error) {
+			callbacks["OnError"] = true
+		},
+		OnChunk: func(part StreamPart) {
+			callbacks["OnChunk"] = true
+		},
+		OnWarnings: func(warnings []CallWarning) {
+			callbacks["OnWarnings"] = true
+		},
+		OnTextStart: func(id string) {
+			callbacks["OnTextStart"] = true
+		},
+		OnTextDelta: func(id, text string) {
+			callbacks["OnTextDelta"] = true
+		},
+		OnTextEnd: func(id string) {
+			callbacks["OnTextEnd"] = true
+		},
+		OnReasoningStart: func(id string) {
+			callbacks["OnReasoningStart"] = true
+		},
+		OnReasoningDelta: func(id, text string) {
+			callbacks["OnReasoningDelta"] = true
+		},
+		OnReasoningEnd: func(id string) {
+			callbacks["OnReasoningEnd"] = true
+		},
+		OnToolInputStart: func(id, toolName string) {
+			callbacks["OnToolInputStart"] = true
+		},
+		OnToolInputDelta: func(id, delta string) {
+			callbacks["OnToolInputDelta"] = true
+		},
+		OnToolInputEnd: func(id string) {
+			callbacks["OnToolInputEnd"] = true
+		},
+		OnToolCall: func(toolCall ToolCallContent) {
+			callbacks["OnToolCall"] = true
+		},
+		OnToolResult: func(result ToolResultContent) {
+			callbacks["OnToolResult"] = true
+		},
+		OnSource: func(source SourceContent) {
+			callbacks["OnSource"] = true
+		},
+		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
+			callbacks["OnStreamFinish"] = true
+		},
+		OnStreamError: func(err error) {
+			callbacks["OnStreamError"] = true
+		},
+	}
+
+	// Execute streaming agent
+	result, err := agent.Stream(ctx, streamCall)
+	require.NoError(t, err)
+	require.NotNil(t, result)
+
+	// Verify that expected callbacks were called
+	expectedCallbacks := []string{
+		"OnAgentStart",
+		"OnAgentFinish",
+		"OnStepStart",
+		"OnStepFinish",
+		"OnFinish",
+		"OnChunk",
+		"OnWarnings",
+		"OnTextStart",
+		"OnTextDelta",
+		"OnTextEnd",
+		"OnReasoningStart",
+		"OnReasoningDelta",
+		"OnReasoningEnd",
+		"OnToolInputStart",
+		"OnToolInputDelta",
+		"OnToolInputEnd",
+		"OnSource",
+		"OnStreamFinish",
+	}
+
+	for _, callback := range expectedCallbacks {
+		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
+	}
+
+	// Verify that error callbacks were not called
+	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
+	require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
+	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
+	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
+}
+
+// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
+func TestStreamingAgentWithTools(t *testing.T) {
+	t.Parallel()
+
+	stepCount := 0
+	// Create a mock language model that makes a tool call then finishes
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			stepCount++
+			return func(yield func(StreamPart) bool) {
+				if stepCount == 1 {
+					// First step: make tool call
+					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
+						return
+					}
+					if !yield(StreamPart{
+						Type:          StreamPartTypeToolCall,
+						ID:            "tool-1",
+						ToolCallName:  "echo",
+						ToolCallInput: `{"message": "test"}`,
+					}) {
+						return
+					}
+					yield(StreamPart{
+						Type:         StreamPartTypeFinish,
+						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
+						FinishReason: FinishReasonToolCalls,
+					})
+				} else {
+					// Second step: finish after tool execution
+					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
+						return
+					}
+					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+						return
+					}
+					yield(StreamPart{
+						Type:         StreamPartTypeFinish,
+						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
+						FinishReason: FinishReasonStop,
+					})
+				}
+			}, nil
+		},
+	}
+
+	// Create agent with echo tool
+	agent := NewAgent(
+		mockModel,
+		WithSystemPrompt("You are a helpful assistant."),
+		WithTools(&EchoTool{}),
+	)
+
+	ctx := context.Background()
+
+	// Track callback invocations
+	var toolInputStartCalled bool
+	var toolInputDeltaCalled bool
+	var toolInputEndCalled bool
+	var toolCallCalled bool
+	var toolResultCalled bool
+
+	// Create streaming call with callbacks
+	streamCall := AgentStreamCall{
+		Prompt: "Echo 'test'",
+		OnToolInputStart: func(id, toolName string) {
+			toolInputStartCalled = true
+			require.Equal(t, "tool-1", id)
+			require.Equal(t, "echo", toolName)
+		},
+		OnToolInputDelta: func(id, delta string) {
+			toolInputDeltaCalled = true
+			require.Equal(t, "tool-1", id)
+			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
+		},
+		OnToolInputEnd: func(id string) {
+			toolInputEndCalled = true
+			require.Equal(t, "tool-1", id)
+		},
+		OnToolCall: func(toolCall ToolCallContent) {
+			toolCallCalled = true
+			require.Equal(t, "echo", toolCall.ToolName)
+			require.Equal(t, `{"message": "test"}`, toolCall.Input)
+		},
+		OnToolResult: func(result ToolResultContent) {
+			toolResultCalled = true
+			require.Equal(t, "echo", result.ToolName)
+		},
+	}
+
+	// Execute streaming agent
+	result, err := agent.Stream(ctx, streamCall)
+	require.NoError(t, err)
+
+	// Verify results
+	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
+	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
+	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
+	require.True(t, toolCallCalled, "OnToolCall should have been called")
+	require.True(t, toolResultCalled, "OnToolResult should have been called")
+	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
+
+	// Check that tool was executed in first step
+	firstStep := result.Steps[0]
+	toolCalls := firstStep.Content.ToolCalls()
+	require.Equal(t, 1, len(toolCalls))
+	require.Equal(t, "echo", toolCalls[0].ToolName)
+
+	toolResults := firstStep.Content.ToolResults()
+	require.Equal(t, 1, len(toolResults))
+	require.Equal(t, "echo", toolResults[0].ToolName)
+}
+
+// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
+func TestStreamingAgentTextDeltas(t *testing.T) {
+	t.Parallel()
+
+	// Create a mock language model that returns text deltas
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			return func(yield func(StreamPart) bool) {
+				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+					return
+				}
+				yield(StreamPart{
+					Type:         StreamPartTypeFinish,
+					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
+					FinishReason: FinishReasonStop,
+				})
+			}, nil
+		},
+	}
+
+	agent := NewAgent(mockModel)
+	ctx := context.Background()
+
+	// Track text deltas
+	var textDeltas []string
+
+	streamCall := AgentStreamCall{
+		Prompt: "Say hello",
+		OnTextDelta: func(id, text string) {
+			if text != "" {
+				textDeltas = append(textDeltas, text)
+			}
+		},
+	}
+
+	result, err := agent.Stream(ctx, streamCall)
+	require.NoError(t, err)
+
+	// Verify text deltas match expected pattern
+	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
+	require.Equal(t, "Hello, world!", result.Response.Content.Text())
+	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
+}
+
+// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
+func TestStreamingAgentReasoning(t *testing.T) {
+	t.Parallel()
+
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			return func(yield func(StreamPart) bool) {
+				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+					return
+				}
+				yield(StreamPart{
+					Type:         StreamPartTypeFinish,
+					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
+					FinishReason: FinishReasonStop,
+				})
+			}, nil
+		},
+	}
+
+	agent := NewAgent(mockModel)
+	ctx := context.Background()
+
+	var reasoningDeltas []string
+	var textDeltas []string
+
+	streamCall := AgentStreamCall{
+		Prompt: "Think and respond",
+		OnReasoningDelta: func(id, text string) {
+			reasoningDeltas = append(reasoningDeltas, text)
+		},
+		OnTextDelta: func(id, text string) {
+			textDeltas = append(textDeltas, text)
+		},
+	}
+
+	result, err := agent.Stream(ctx, streamCall)
+	require.NoError(t, err)
+
+	// Verify reasoning and text are separate
+	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
+	require.Equal(t, []string{"Hi there!"}, textDeltas)
+	require.Equal(t, "Hi there!", result.Response.Content.Text())
+	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
+}
+
+// TestStreamingAgentError tests error handling (mirrors TS error tests)
+func TestStreamingAgentError(t *testing.T) {
+	t.Parallel()
+
+	// Create a mock language model that returns an error
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			return func(yield func(StreamPart) bool) {
+				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
+			}, nil
+		},
+	}
+
+	agent := NewAgent(mockModel)
+	ctx := context.Background()
+
+	// Track error callbacks
+	var streamErrorOccurred bool
+	var errorOccurred bool
+	var errorMessage string
+
+	streamCall := AgentStreamCall{
+		Prompt: "This will fail",
+		OnStreamError: func(err error) {
+			streamErrorOccurred = true
+		},
+		OnError: func(err error) {
+			errorOccurred = true
+			errorMessage = err.Error()
+		},
+	}
+
+	// Execute streaming agent
+	result, err := agent.Stream(ctx, streamCall)
+	require.Error(t, err)
+	require.Nil(t, result)
+	require.True(t, streamErrorOccurred, "OnStreamError should have been called")
+	require.True(t, errorOccurred, "OnError should have been called")
+	require.Contains(t, errorMessage, "mock stream error")
+}
+
+// TestStreamingAgentSources tests source handling (mirrors TS source tests)
+func TestStreamingAgentSources(t *testing.T) {
+	t.Parallel()
+
+	mockModel := &mockLanguageModel{
+		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
+			return func(yield func(StreamPart) bool) {
+				if !yield(StreamPart{
+					Type:       StreamPartTypeSource,
+					ID:         "source-1",
+					SourceType: SourceTypeURL,
+					URL:        "https://example.com",
+					Title:      "Example",
+				}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
+					return
+				}
+				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
+					return
+				}
+				if !yield(StreamPart{
+					Type:       StreamPartTypeSource,
+					ID:         "source-2",
+					SourceType: SourceTypeDocument,
+					Title:      "Document Example",
+				}) {
+					return
+				}
+				yield(StreamPart{
+					Type:         StreamPartTypeFinish,
+					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
+					FinishReason: FinishReasonStop,
+				})
+			}, nil
+		},
+	}
+
+	agent := NewAgent(mockModel)
+	ctx := context.Background()
+
+	var sources []SourceContent
+
+	streamCall := AgentStreamCall{
+		Prompt: "Search and respond",
+		OnSource: func(source SourceContent) {
+			sources = append(sources, source)
+		},
+	}
+
+	result, err := agent.Stream(ctx, streamCall)
+	require.NoError(t, err)
+
+	// Verify sources were captured
+	require.Equal(t, 2, len(sources))
+	require.Equal(t, SourceTypeURL, sources[0].SourceType)
+	require.Equal(t, "https://example.com", sources[0].URL)
+	require.Equal(t, "Example", sources[0].Title)
+	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
+	require.Equal(t, "Document Example", sources[1].Title)
+
+	// Verify sources are in final result
+	resultSources := result.Response.Content.Sources()
+	require.Equal(t, 2, len(resultSources))
+}

agent_test.go 🔗

@@ -39,6 +39,7 @@ func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResp
 // Mock language model for testing
 type mockLanguageModel struct {
 	generateFunc func(ctx context.Context, call Call) (*Response, error)
+	streamFunc   func(ctx context.Context, call Call) (StreamResponse, error)
 }
 
 func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
@@ -59,7 +60,10 @@ func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response,
 }
 
 func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
-	panic("not implemented")
+	if m.streamFunc != nil {
+		return m.streamFunc(ctx, call)
+	}
+	return nil, fmt.Errorf("mock stream not implemented")
 }
 
 func (m *mockLanguageModel) Provider() string {

providers/examples/streaming-agent-simple/main.go 🔗

@@ -0,0 +1,95 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/ai/providers"
+	"github.com/charmbracelet/crush/internal/llm/tools"
+)
+
+// Simple echo tool for demonstration
+type EchoTool struct{}
+
+func (e *EchoTool) Info() tools.ToolInfo {
+	return tools.ToolInfo{
+		Name:        "echo",
+		Description: "Echo back the provided message",
+		Parameters: map[string]any{
+			"message": map[string]any{
+				"type":        "string",
+				"description": "The message to echo back",
+			},
+		},
+		Required: []string{"message"},
+	}
+}
+
+func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+	return tools.NewTextResponse("Echo: " + params.Input), nil
+}
+
+func main() {
+	// Check for API key
+	apiKey := os.Getenv("OPENAI_API_KEY")
+	if apiKey == "" {
+		fmt.Println("Please set OPENAI_API_KEY environment variable")
+		os.Exit(1)
+	}
+
+	// Create provider and model
+	provider := providers.NewOpenAIProvider(
+		providers.WithOpenAIApiKey(apiKey),
+	)
+	model := provider.LanguageModel("gpt-4o-mini")
+
+	// Create streaming agent
+	agent := ai.NewAgent(
+		model,
+		ai.WithSystemPrompt("You are a helpful assistant."),
+		ai.WithTools(&EchoTool{}),
+	)
+
+	ctx := context.Background()
+
+	fmt.Println("Simple Streaming Agent Example")
+	fmt.Println("==============================")
+	fmt.Println()
+
+	// Basic streaming with key callbacks
+	streamCall := ai.AgentStreamCall{
+		Prompt: "Please echo back 'Hello, streaming world!'",
+		
+		// Show real-time text as it streams
+		OnTextDelta: func(id, text string) {
+			fmt.Print(text)
+		},
+		
+		// Show when tools are called
+		OnToolCall: func(toolCall ai.ToolCallContent) {
+			fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName)
+		},
+		
+		// Show tool results
+		OnToolResult: func(result ai.ToolResultContent) {
+			fmt.Printf("[Tool result received]\n")
+		},
+		
+		// Show when each step completes
+		OnStepFinish: func(step ai.StepResult) {
+			fmt.Printf("\n[Step completed: %s]\n", step.FinishReason)
+		},
+	}
+
+	fmt.Println("Assistant response:")
+	result, err := agent.Stream(ctx, streamCall)
+	if err != nil {
+		fmt.Printf("Error: %v\n", err)
+		os.Exit(1)
+	}
+
+	fmt.Printf("\n\nFinal result: %s\n", result.Response.Content.Text())
+	fmt.Printf("Steps: %d, Total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
+}

providers/examples/streaming-agent/main.go 🔗

@@ -0,0 +1,274 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"strings"
+
+	"github.com/charmbracelet/crush/internal/ai"
+	"github.com/charmbracelet/crush/internal/ai/providers"
+	"github.com/charmbracelet/crush/internal/llm/tools"
+)
+
+// WeatherTool is a simple tool that simulates weather lookup
+type WeatherTool struct{}
+
+func (w *WeatherTool) Info() tools.ToolInfo {
+	return tools.ToolInfo{
+		Name:        "get_weather",
+		Description: "Get the current weather for a specific location",
+		Parameters: map[string]any{
+			"location": map[string]any{
+				"type":        "string",
+				"description": "The city and country, e.g. 'London, UK'",
+			},
+			"unit": map[string]any{
+				"type":        "string",
+				"description": "Temperature unit (celsius or fahrenheit)",
+				"enum":        []string{"celsius", "fahrenheit"},
+				"default":     "celsius",
+			},
+		},
+		Required: []string{"location"},
+	}
+}
+
+func (w *WeatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+	// Simulate weather lookup with some fake data
+	location := "Unknown"
+	if strings.Contains(params.Input, "pristina") || strings.Contains(params.Input, "Pristina") {
+		location = "Pristina, Kosovo"
+	} else if strings.Contains(params.Input, "london") || strings.Contains(params.Input, "London") {
+		location = "London, UK"
+	} else if strings.Contains(params.Input, "new york") || strings.Contains(params.Input, "New York") {
+		location = "New York, USA"
+	}
+
+	unit := "celsius"
+	if strings.Contains(params.Input, "fahrenheit") {
+		unit = "fahrenheit"
+	}
+
+	var temp string
+	if unit == "fahrenheit" {
+		temp = "72°F"
+	} else {
+		temp = "22°C"
+	}
+
+	weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp)
+	return tools.NewTextResponse(weather), nil
+}
+
+// CalculatorTool demonstrates a second tool for multi-tool scenarios
+type CalculatorTool struct{}
+
+func (c *CalculatorTool) Info() tools.ToolInfo {
+	return tools.ToolInfo{
+		Name:        "calculate",
+		Description: "Perform basic mathematical calculations",
+		Parameters: map[string]any{
+			"expression": map[string]any{
+				"type":        "string",
+				"description": "Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')",
+			},
+		},
+		Required: []string{"expression"},
+	}
+}
+
+func (c *CalculatorTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
+	// Simple calculator simulation
+	expr := strings.ReplaceAll(params.Input, "\"", "")
+	if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") {
+		return tools.NewTextResponse("2 + 2 = 4"), nil
+	} else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") {
+		return tools.NewTextResponse("10 * 5 = 50"), nil
+	}
+	return tools.NewTextResponse("I can calculate simple expressions like '2 + 2' or '10 * 5'"), nil
+}
+
+func main() {
+	// Check for API key
+	apiKey := os.Getenv("OPENAI_API_KEY")
+	if apiKey == "" {
+		fmt.Println("❌ Please set OPENAI_API_KEY environment variable")
+		fmt.Println("   export OPENAI_API_KEY=your_api_key_here")
+		os.Exit(1)
+	}
+
+	fmt.Println("🚀 Streaming Agent Example")
+	fmt.Println("==========================")
+	fmt.Println()
+
+	// Create OpenAI provider and model
+	provider := providers.NewOpenAIProvider(
+		providers.WithOpenAIApiKey(apiKey),
+	)
+	model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses
+
+	// Create agent with tools
+	agent := ai.NewAgent(
+		model,
+		ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."),
+		ai.WithTools(&WeatherTool{}, &CalculatorTool{}),
+	)
+
+	ctx := context.Background()
+
+	// Demonstrate streaming with comprehensive callbacks
+	fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"")
+	fmt.Println()
+
+	// Track streaming events
+	var stepCount int
+	var textBuffer strings.Builder
+	var reasoningBuffer strings.Builder
+
+	// Create streaming call with all callbacks
+	streamCall := ai.AgentStreamCall{
+		Prompt: "What's the weather in Pristina and what's 2 + 2?",
+
+		// Agent-level callbacks
+		OnAgentStart: func() {
+			fmt.Println("🎬 Agent started")
+		},
+		OnAgentFinish: func(result *ai.AgentResult) {
+			fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens)
+		},
+		OnStepStart: func(stepNumber int) {
+			stepCount++
+			fmt.Printf("📝 Step %d started\n", stepNumber+1)
+		},
+		OnStepFinish: func(stepResult ai.StepResult) {
+			fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens)
+		},
+		OnFinish: func(result *ai.AgentResult) {
+			fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps))
+		},
+		OnError: func(err error) {
+			fmt.Printf("❌ Error: %v\n", err)
+		},
+
+		// Stream part callbacks
+		OnWarnings: func(warnings []ai.CallWarning) {
+			for _, warning := range warnings {
+				fmt.Printf("⚠️  Warning: %s\n", warning.Message)
+			}
+		},
+		OnTextStart: func(id string) {
+			fmt.Print("💭 Assistant: ")
+		},
+		OnTextDelta: func(id, text string) {
+			fmt.Print(text)
+			textBuffer.WriteString(text)
+		},
+		OnTextEnd: func(id string) {
+			fmt.Println()
+		},
+		OnReasoningStart: func(id string) {
+			fmt.Print("🤔 Thinking: ")
+		},
+		OnReasoningDelta: func(id, text string) {
+			reasoningBuffer.WriteString(text)
+		},
+		OnReasoningEnd: func(id string) {
+			if reasoningBuffer.Len() > 0 {
+				fmt.Printf("%s\n", reasoningBuffer.String())
+				reasoningBuffer.Reset()
+			}
+		},
+		OnToolInputStart: func(id, toolName string) {
+			fmt.Printf("🔧 Calling tool: %s\n", toolName)
+		},
+		OnToolInputDelta: func(id, delta string) {
+			// Could show tool input being built, but it's often noisy
+		},
+		OnToolInputEnd: func(id string) {
+			// Tool input complete
+		},
+		OnToolCall: func(toolCall ai.ToolCallContent) {
+			fmt.Printf("🛠️  Tool call: %s\n", toolCall.ToolName)
+			fmt.Printf("   Input: %s\n", toolCall.Input)
+		},
+		OnToolResult: func(result ai.ToolResultContent) {
+			fmt.Printf("🎯 Tool result from %s:\n", result.ToolName)
+			switch output := result.Result.(type) {
+			case ai.ToolResultOutputContentText:
+				fmt.Printf("   %s\n", output.Text)
+			case ai.ToolResultOutputContentError:
+				fmt.Printf("   Error: %s\n", output.Error.Error())
+			}
+		},
+		OnSource: func(source ai.SourceContent) {
+			fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL)
+		},
+		OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) {
+			fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens)
+		},
+		OnStreamError: func(err error) {
+			fmt.Printf("💥 Stream error: %v\n", err)
+		},
+	}
+
+	// Execute streaming agent
+	result, err := agent.Stream(ctx, streamCall)
+	if err != nil {
+		fmt.Printf("❌ Agent failed: %v\n", err)
+		os.Exit(1)
+	}
+
+	// Display final results
+	fmt.Println()
+	fmt.Println("📋 Final Summary")
+	fmt.Println("================")
+	fmt.Printf("Steps executed: %d\n", len(result.Steps))
+	fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n", 
+		result.TotalUsage.TotalTokens, 
+		result.TotalUsage.InputTokens, 
+		result.TotalUsage.OutputTokens)
+	
+	if result.TotalUsage.ReasoningTokens > 0 {
+		fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens)
+	}
+
+	fmt.Printf("Final response: %s\n", result.Response.Content.Text())
+
+	// Show step details
+	fmt.Println()
+	fmt.Println("🔍 Step Details")
+	fmt.Println("===============")
+	for i, step := range result.Steps {
+		fmt.Printf("Step %d:\n", i+1)
+		fmt.Printf("  Finish reason: %s\n", step.FinishReason)
+		fmt.Printf("  Content types: ")
+		
+		var contentTypes []string
+		for _, content := range step.Content {
+			contentTypes = append(contentTypes, string(content.GetType()))
+		}
+		fmt.Printf("%s\n", strings.Join(contentTypes, ", "))
+		
+		// Show tool calls and results
+		toolCalls := step.Content.ToolCalls()
+		if len(toolCalls) > 0 {
+			fmt.Printf("  Tool calls: ")
+			var toolNames []string
+			for _, tc := range toolCalls {
+				toolNames = append(toolNames, tc.ToolName)
+			}
+			fmt.Printf("%s\n", strings.Join(toolNames, ", "))
+		}
+		
+		toolResults := step.Content.ToolResults()
+		if len(toolResults) > 0 {
+			fmt.Printf("  Tool results: %d\n", len(toolResults))
+		}
+		
+		fmt.Printf("  Tokens: %d\n", step.Usage.TotalTokens)
+		fmt.Println()
+	}
+
+	fmt.Println("✨ Example completed successfully!")
+}

providers/openai.go 🔗

@@ -1134,17 +1134,17 @@ func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []op
 	var annotations []openai.ChatCompletionMessageAnnotation
 
 	// Parse the raw JSON to extract annotations
-	var deltaData map[string]interface{}
+	var deltaData map[string]any
 	if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil {
 		return annotations
 	}
 
 	// Check if annotations exist in the delta
-	if annotationsData, ok := deltaData["annotations"].([]interface{}); ok {
+	if annotationsData, ok := deltaData["annotations"].([]any); ok {
 		for _, annotationData := range annotationsData {
-			if annotationMap, ok := annotationData.(map[string]interface{}); ok {
+			if annotationMap, ok := annotationData.(map[string]any); ok {
 				if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" {
-					if urlCitationData, ok := annotationMap["url_citation"].(map[string]interface{}); ok {
+					if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok {
 						annotation := openai.ChatCompletionMessageAnnotation{
 							Type: "url_citation",
 							URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{