From 13feb855bca8b993176f55c4ef9b17be5c6c6b44 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 22 Aug 2025 12:44:34 +0200 Subject: [PATCH] feat: streaming agent --- agent.go | 429 ++++++++++++- agent_stream_test.go | 569 ++++++++++++++++++ agent_test.go | 6 +- .../examples/streaming-agent-simple/main.go | 95 +++ providers/examples/streaming-agent/main.go | 274 +++++++++ providers/openai.go | 8 +- 6 files changed, 1370 insertions(+), 11 deletions(-) create mode 100644 agent_stream_test.go create mode 100644 providers/examples/streaming-agent-simple/main.go create mode 100644 providers/examples/streaming-agent/main.go diff --git a/agent.go b/agent.go index f26ab18859746d35babc31d05c22630709f2fa23..fb84e79b9752f3d5eafecd267a0e7a85e8da9650 100644 --- a/agent.go +++ b/agent.go @@ -157,6 +157,53 @@ type AgentCall struct { OnStepFinished OnStepFinishedFunction } +type AgentStreamCall struct { + Prompt string `json:"prompt"` + Files []FilePart `json:"files"` + Messages []Message `json:"messages"` + MaxOutputTokens *int64 + Temperature *float64 `json:"temperature"` + TopP *float64 `json:"top_p"` + TopK *int64 `json:"top_k"` + PresencePenalty *float64 `json:"presence_penalty"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + ActiveTools []string `json:"active_tools"` + Headers map[string]string + ProviderOptions ProviderOptions + OnRetry OnRetryCallback + MaxRetries *int + + StopWhen []StopCondition + PrepareStep PrepareStepFunction + RepairToolCall RepairToolCallFunction + + // Agent-level callbacks + OnAgentStart func() // Called when agent starts + OnAgentFinish func(result *AgentResult) // Called when agent finishes + OnStepStart func(stepNumber int) // Called when a step starts + OnStepFinish func(stepResult StepResult) // Called when a step finishes + OnFinish func(result *AgentResult) // Called when entire agent completes + OnError func(error) // Called when an error occurs + + // Stream part callbacks - called for each corresponding stream part type + OnChunk func(StreamPart) // Called for each stream part (catch-all) + OnWarnings func(warnings []CallWarning) // Called for warnings + OnTextStart func(id string) // Called when text starts + OnTextDelta func(id, text string) // Called for text deltas + OnTextEnd func(id string) // Called when text ends + OnReasoningStart func(id string) // Called when reasoning starts + OnReasoningDelta func(id, text string) // Called for reasoning deltas + OnReasoningEnd func(id string) // Called when reasoning ends + OnToolInputStart func(id, toolName string) // Called when tool input starts + OnToolInputDelta func(id, delta string) // Called for tool input deltas + OnToolInputEnd func(id string) // Called when tool input ends + OnToolCall func(toolCall ToolCallContent) // Called when tool call is complete + OnToolResult func(result ToolResultContent) // Called when tool execution completes + OnSource func(source SourceContent) // Called for source references + OnStreamFinish func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) // Called when stream finishes + OnStreamError func(error) // Called when stream error occurs +} + type AgentResult struct { Steps []StepResult // Final response @@ -166,7 +213,7 @@ type AgentResult struct { type Agent interface { Generate(context.Context, AgentCall) (*AgentResult, error) - Stream(context.Context, AgentCall) (StreamResponse, error) + Stream(context.Context, AgentStreamCall) (*AgentResult, error) } type agentOption = func(*AgentSettings) @@ -343,7 +390,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err } } - toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls) + toolResults, err := a.executeTools(ctx, a.settings.tools, stepToolCalls, nil) // Build step content with validated tool calls and tool results stepContent := []Content{} @@ -501,7 +548,7 @@ func toResponseMessages(content []Content) []Message { return messages } -func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent) ([]ToolResultContent, error) { +func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, toolCalls []ToolCallContent, toolResultCallback func(result ToolResultContent)) ([]ToolResultContent, error) { if len(toolCalls) == 0 { return nil, nil } @@ -532,6 +579,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too }, ProviderExecuted: false, } + if toolResultCallback != nil { + toolResultCallback(results[index]) + } + return } @@ -545,6 +596,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too }, ProviderExecuted: false, } + + if toolResultCallback != nil { + toolResultCallback(results[index]) + } return } @@ -563,6 +618,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too }, ProviderExecuted: false, } + if toolResultCallback != nil { + toolResultCallback(results[index]) + } toolExecutionError = err return } @@ -576,6 +634,10 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too }, ProviderExecuted: false, } + + if toolResultCallback != nil { + toolResultCallback(results[index]) + } } else { results[index] = ToolResultContent{ ToolCallID: call.ToolCallID, @@ -585,6 +647,9 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too }, ProviderExecuted: false, } + if toolResultCallback != nil { + toolResultCallback(results[index]) + } } }(i, toolCall) } @@ -596,9 +661,164 @@ func (a *agent) executeTools(ctx context.Context, allTools []tools.BaseTool, too } // Stream implements Agent. -func (a *agent) Stream(ctx context.Context, opts AgentCall) (StreamResponse, error) { - // TODO: implement the agentic stuff - panic("not implemented") +func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult, error) { + // Convert AgentStreamCall to AgentCall for preparation + call := AgentCall{ + Prompt: opts.Prompt, + Files: opts.Files, + Messages: opts.Messages, + MaxOutputTokens: opts.MaxOutputTokens, + Temperature: opts.Temperature, + TopP: opts.TopP, + TopK: opts.TopK, + PresencePenalty: opts.PresencePenalty, + FrequencyPenalty: opts.FrequencyPenalty, + ActiveTools: opts.ActiveTools, + Headers: opts.Headers, + ProviderOptions: opts.ProviderOptions, + MaxRetries: opts.MaxRetries, + StopWhen: opts.StopWhen, + PrepareStep: opts.PrepareStep, + RepairToolCall: opts.RepairToolCall, + } + + call = a.prepareCall(call) + + initialPrompt, err := a.createPrompt(a.settings.systemPrompt, call.Prompt, call.Messages, call.Files...) + if err != nil { + return nil, err + } + + var responseMessages []Message + var steps []StepResult + var totalUsage Usage + + // Start agent stream + if opts.OnAgentStart != nil { + opts.OnAgentStart() + } + + for stepNumber := 0; ; stepNumber++ { + stepInputMessages := append(initialPrompt, responseMessages...) + stepModel := a.settings.model + stepSystemPrompt := a.settings.systemPrompt + stepActiveTools := call.ActiveTools + stepToolChoice := ToolChoiceAuto + disableAllTools := false + + // Apply step preparation if provided + if call.PrepareStep != nil { + prepared := call.PrepareStep(PrepareStepFunctionOptions{ + Model: stepModel, + Steps: steps, + StepNumber: stepNumber, + Messages: stepInputMessages, + }) + + if prepared.Messages != nil { + stepInputMessages = prepared.Messages + } + if prepared.Model != nil { + stepModel = prepared.Model + } + if prepared.System != nil { + stepSystemPrompt = *prepared.System + } + if prepared.ToolChoice != nil { + stepToolChoice = *prepared.ToolChoice + } + if len(prepared.ActiveTools) > 0 { + stepActiveTools = prepared.ActiveTools + } + disableAllTools = prepared.DisableAllTools + } + + // Recreate prompt with potentially modified system prompt + if stepSystemPrompt != a.settings.systemPrompt { + stepPrompt, err := a.createPrompt(stepSystemPrompt, call.Prompt, call.Messages, call.Files...) + if err != nil { + return nil, err + } + if len(stepInputMessages) > 0 && len(stepPrompt) > 0 { + stepInputMessages[0] = stepPrompt[0] + } + } + + preparedTools := a.prepareTools(a.settings.tools, stepActiveTools, disableAllTools) + + // Start step stream + if opts.OnStepStart != nil { + opts.OnStepStart(stepNumber) + } + + // Create streaming call + streamCall := Call{ + Prompt: stepInputMessages, + MaxOutputTokens: call.MaxOutputTokens, + Temperature: call.Temperature, + TopP: call.TopP, + TopK: call.TopK, + PresencePenalty: call.PresencePenalty, + FrequencyPenalty: call.FrequencyPenalty, + Tools: preparedTools, + ToolChoice: &stepToolChoice, + Headers: call.Headers, + ProviderOptions: call.ProviderOptions, + } + + // Get streaming response + stream, err := stepModel.Stream(ctx, streamCall) + if err != nil { + if opts.OnError != nil { + opts.OnError(err) + } + return nil, err + } + + // Process stream with tool execution + stepResult, shouldContinue, err := a.processStepStream(ctx, stream, opts, steps) + if err != nil { + if opts.OnError != nil { + opts.OnError(err) + } + return nil, err + } + + steps = append(steps, stepResult) + totalUsage = addUsage(totalUsage, stepResult.Usage) + + // Call step finished callback + if opts.OnStepFinish != nil { + opts.OnStepFinish(stepResult) + } + + // Add step messages to response messages + stepMessages := toResponseMessages(stepResult.Content) + responseMessages = append(responseMessages, stepMessages...) + + // Check stop conditions + shouldStop := isStopConditionMet(call.StopWhen, steps) + if shouldStop || !shouldContinue { + break + } + } + + // Finish agent stream + agentResult := &AgentResult{ + Steps: steps, + Response: steps[len(steps)-1].Response, + TotalUsage: totalUsage, + } + + if opts.OnFinish != nil { + opts.OnFinish(agentResult) + } + + if opts.OnAgentFinish != nil { + opts.OnAgentFinish(agentResult) + } + + return agentResult, nil } func (a *agent) prepareTools(tools []tools.BaseTool, activeTools []string, disableAllTools bool) []Tool { @@ -776,6 +996,203 @@ func WithOnStepFinished(fn OnStepFinishedFunction) agentOption { } } +// processStepStream processes a single step's stream and returns the step result +func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, opts AgentStreamCall, _ []StepResult) (StepResult, bool, error) { + var stepContent []Content + var stepToolCalls []ToolCallContent + var stepUsage Usage + var stepFinishReason FinishReason = FinishReasonUnknown + var stepWarnings []CallWarning + var stepProviderMetadata ProviderMetadata + + activeToolCalls := make(map[string]*ToolCallContent) + activeTextContent := make(map[string]string) + + // Process stream parts + for part := range stream { + // Forward all parts to chunk callback + if opts.OnChunk != nil { + opts.OnChunk(part) + } + + switch part.Type { + case StreamPartTypeWarnings: + stepWarnings = part.Warnings + if opts.OnWarnings != nil { + opts.OnWarnings(part.Warnings) + } + + case StreamPartTypeTextStart: + activeTextContent[part.ID] = "" + if opts.OnTextStart != nil { + opts.OnTextStart(part.ID) + } + + case StreamPartTypeTextDelta: + if _, exists := activeTextContent[part.ID]; exists { + activeTextContent[part.ID] += part.Delta + } + if opts.OnTextDelta != nil { + opts.OnTextDelta(part.ID, part.Delta) + } + + case StreamPartTypeTextEnd: + if text, exists := activeTextContent[part.ID]; exists { + stepContent = append(stepContent, TextContent{ + Text: text, + ProviderMetadata: ProviderMetadata(part.ProviderMetadata), + }) + delete(activeTextContent, part.ID) + } + if opts.OnTextEnd != nil { + opts.OnTextEnd(part.ID) + } + + case StreamPartTypeReasoningStart: + activeTextContent[part.ID] = "" + if opts.OnReasoningStart != nil { + opts.OnReasoningStart(part.ID) + } + + case StreamPartTypeReasoningDelta: + if _, exists := activeTextContent[part.ID]; exists { + activeTextContent[part.ID] += part.Delta + } + if opts.OnReasoningDelta != nil { + opts.OnReasoningDelta(part.ID, part.Delta) + } + + case StreamPartTypeReasoningEnd: + if text, exists := activeTextContent[part.ID]; exists { + stepContent = append(stepContent, ReasoningContent{ + Text: text, + ProviderMetadata: ProviderMetadata(part.ProviderMetadata), + }) + delete(activeTextContent, part.ID) + } + if opts.OnReasoningEnd != nil { + opts.OnReasoningEnd(part.ID) + } + + case StreamPartTypeToolInputStart: + activeToolCalls[part.ID] = &ToolCallContent{ + ToolCallID: part.ID, + ToolName: part.ToolCallName, + Input: "", + ProviderExecuted: part.ProviderExecuted, + } + if opts.OnToolInputStart != nil { + opts.OnToolInputStart(part.ID, part.ToolCallName) + } + + case StreamPartTypeToolInputDelta: + if toolCall, exists := activeToolCalls[part.ID]; exists { + toolCall.Input += part.Delta + } + if opts.OnToolInputDelta != nil { + opts.OnToolInputDelta(part.ID, part.Delta) + } + + case StreamPartTypeToolInputEnd: + if opts.OnToolInputEnd != nil { + opts.OnToolInputEnd(part.ID) + } + + case StreamPartTypeToolCall: + toolCall := ToolCallContent{ + ToolCallID: part.ID, + ToolName: part.ToolCallName, + Input: part.ToolCallInput, + ProviderExecuted: part.ProviderExecuted, + ProviderMetadata: ProviderMetadata(part.ProviderMetadata), + } + + // Validate and potentially repair the tool call + validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, a.settings.tools, a.settings.systemPrompt, nil, opts.RepairToolCall) + stepToolCalls = append(stepToolCalls, validatedToolCall) + stepContent = append(stepContent, validatedToolCall) + + if opts.OnToolCall != nil { + opts.OnToolCall(validatedToolCall) + } + + // Clean up active tool call + delete(activeToolCalls, part.ID) + + case StreamPartTypeSource: + sourceContent := SourceContent{ + SourceType: part.SourceType, + ID: part.ID, + URL: part.URL, + Title: part.Title, + ProviderMetadata: ProviderMetadata(part.ProviderMetadata), + } + stepContent = append(stepContent, sourceContent) + if opts.OnSource != nil { + opts.OnSource(sourceContent) + } + + case StreamPartTypeFinish: + stepUsage = part.Usage + stepFinishReason = part.FinishReason + stepProviderMetadata = ProviderMetadata(part.ProviderMetadata) + if opts.OnStreamFinish != nil { + opts.OnStreamFinish(part.Usage, part.FinishReason, part.ProviderMetadata) + } + + case StreamPartTypeError: + if opts.OnStreamError != nil { + opts.OnStreamError(part.Error) + } + if opts.OnError != nil { + opts.OnError(part.Error) + } + return StepResult{}, false, part.Error + } + } + + // Execute tools if any + var toolResults []ToolResultContent + if len(stepToolCalls) > 0 { + var err error + toolResults, err = a.executeTools(ctx, a.settings.tools, stepToolCalls, opts.OnToolResult) + if err != nil { + return StepResult{}, false, err + } + // Add tool results to content + for _, result := range toolResults { + stepContent = append(stepContent, result) + } + } + + stepResult := StepResult{ + Response: Response{ + Content: stepContent, + FinishReason: stepFinishReason, + Usage: stepUsage, + Warnings: stepWarnings, + ProviderMetadata: stepProviderMetadata, + }, + Messages: toResponseMessages(stepContent), + } + + // Determine if we should continue (has tool calls and not stopped) + shouldContinue := len(stepToolCalls) > 0 && stepFinishReason == FinishReasonToolCalls + + return stepResult, shouldContinue, nil +} + +func addUsage(a, b Usage) Usage { + return Usage{ + InputTokens: a.InputTokens + b.InputTokens, + OutputTokens: a.OutputTokens + b.OutputTokens, + TotalTokens: a.TotalTokens + b.TotalTokens, + ReasoningTokens: a.ReasoningTokens + b.ReasoningTokens, + CacheCreationTokens: a.CacheCreationTokens + b.CacheCreationTokens, + CacheReadTokens: a.CacheReadTokens + b.CacheReadTokens, + } +} + func WithHeaders(headers map[string]string) agentOption { return func(s *AgentSettings) { s.headers = headers diff --git a/agent_stream_test.go b/agent_stream_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3ddae7431a26f6c721fdd11308ecfb31c83c0b68 --- /dev/null +++ b/agent_stream_test.go @@ -0,0 +1,569 @@ +package ai + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/stretchr/testify/require" +) + +// EchoTool is a simple tool that echoes back the input message +type EchoTool struct{} + +// Info returns the tool information +func (e *EchoTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: "echo", + Description: "Echo back the provided message", + Parameters: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + } +} + +// Run executes the echo tool +func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + var input struct { + Message string `json:"message"` + } + + if err := json.Unmarshal([]byte(params.Input), &input); err != nil { + return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil + } + + if input.Message == "" { + return tools.NewTextErrorResponse("Message cannot be empty"), nil + } + + return tools.NewTextResponse("Echo: " + input.Message), nil +} + +// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly +func TestStreamingAgentCallbacks(t *testing.T) { + t.Parallel() + + // Track which callbacks were called + callbacks := make(map[string]bool) + + // Create a mock language model that returns various stream parts + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + // Test all stream part types + if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + // Create agent + agent := NewAgent(mockModel) + + ctx := context.Background() + + // Create streaming call with all callbacks + streamCall := AgentStreamCall{ + Prompt: "Test all callbacks", + OnAgentStart: func() { + callbacks["OnAgentStart"] = true + }, + OnAgentFinish: func(result *AgentResult) { + callbacks["OnAgentFinish"] = true + }, + OnStepStart: func(stepNumber int) { + callbacks["OnStepStart"] = true + }, + OnStepFinish: func(stepResult StepResult) { + callbacks["OnStepFinish"] = true + }, + OnFinish: func(result *AgentResult) { + callbacks["OnFinish"] = true + }, + OnError: func(err error) { + callbacks["OnError"] = true + }, + OnChunk: func(part StreamPart) { + callbacks["OnChunk"] = true + }, + OnWarnings: func(warnings []CallWarning) { + callbacks["OnWarnings"] = true + }, + OnTextStart: func(id string) { + callbacks["OnTextStart"] = true + }, + OnTextDelta: func(id, text string) { + callbacks["OnTextDelta"] = true + }, + OnTextEnd: func(id string) { + callbacks["OnTextEnd"] = true + }, + OnReasoningStart: func(id string) { + callbacks["OnReasoningStart"] = true + }, + OnReasoningDelta: func(id, text string) { + callbacks["OnReasoningDelta"] = true + }, + OnReasoningEnd: func(id string) { + callbacks["OnReasoningEnd"] = true + }, + OnToolInputStart: func(id, toolName string) { + callbacks["OnToolInputStart"] = true + }, + OnToolInputDelta: func(id, delta string) { + callbacks["OnToolInputDelta"] = true + }, + OnToolInputEnd: func(id string) { + callbacks["OnToolInputEnd"] = true + }, + OnToolCall: func(toolCall ToolCallContent) { + callbacks["OnToolCall"] = true + }, + OnToolResult: func(result ToolResultContent) { + callbacks["OnToolResult"] = true + }, + OnSource: func(source SourceContent) { + callbacks["OnSource"] = true + }, + OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) { + callbacks["OnStreamFinish"] = true + }, + OnStreamError: func(err error) { + callbacks["OnStreamError"] = true + }, + } + + // Execute streaming agent + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify that expected callbacks were called + expectedCallbacks := []string{ + "OnAgentStart", + "OnAgentFinish", + "OnStepStart", + "OnStepFinish", + "OnFinish", + "OnChunk", + "OnWarnings", + "OnTextStart", + "OnTextDelta", + "OnTextEnd", + "OnReasoningStart", + "OnReasoningDelta", + "OnReasoningEnd", + "OnToolInputStart", + "OnToolInputDelta", + "OnToolInputEnd", + "OnSource", + "OnStreamFinish", + } + + for _, callback := range expectedCallbacks { + require.True(t, callbacks[callback], "Expected callback %s to be called", callback) + } + + // Verify that error callbacks were not called + require.False(t, callbacks["OnError"], "OnError should not be called in successful case") + require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case") + require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls") + require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results") +} + +// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns) +func TestStreamingAgentWithTools(t *testing.T) { + t.Parallel() + + stepCount := 0 + // Create a mock language model that makes a tool call then finishes + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + stepCount++ + return func(yield func(StreamPart) bool) { + if stepCount == 1 { + // First step: make tool call + if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) { + return + } + if !yield(StreamPart{ + Type: StreamPartTypeToolCall, + ID: "tool-1", + ToolCallName: "echo", + ToolCallInput: `{"message": "test"}`, + }) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15}, + FinishReason: FinishReasonToolCalls, + }) + } else { + // Second step: finish after tool execution + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8}, + FinishReason: FinishReasonStop, + }) + } + }, nil + }, + } + + // Create agent with echo tool + agent := NewAgent( + mockModel, + WithSystemPrompt("You are a helpful assistant."), + WithTools(&EchoTool{}), + ) + + ctx := context.Background() + + // Track callback invocations + var toolInputStartCalled bool + var toolInputDeltaCalled bool + var toolInputEndCalled bool + var toolCallCalled bool + var toolResultCalled bool + + // Create streaming call with callbacks + streamCall := AgentStreamCall{ + Prompt: "Echo 'test'", + OnToolInputStart: func(id, toolName string) { + toolInputStartCalled = true + require.Equal(t, "tool-1", id) + require.Equal(t, "echo", toolName) + }, + OnToolInputDelta: func(id, delta string) { + toolInputDeltaCalled = true + require.Equal(t, "tool-1", id) + require.Contains(t, []string{`{"message"`, `: "test"}`}, delta) + }, + OnToolInputEnd: func(id string) { + toolInputEndCalled = true + require.Equal(t, "tool-1", id) + }, + OnToolCall: func(toolCall ToolCallContent) { + toolCallCalled = true + require.Equal(t, "echo", toolCall.ToolName) + require.Equal(t, `{"message": "test"}`, toolCall.Input) + }, + OnToolResult: func(result ToolResultContent) { + toolResultCalled = true + require.Equal(t, "echo", result.ToolName) + }, + } + + // Execute streaming agent + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + + // Verify results + require.True(t, toolInputStartCalled, "OnToolInputStart should have been called") + require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called") + require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called") + require.True(t, toolCallCalled, "OnToolCall should have been called") + require.True(t, toolResultCalled, "OnToolResult should have been called") + require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response + + // Check that tool was executed in first step + firstStep := result.Steps[0] + toolCalls := firstStep.Content.ToolCalls() + require.Equal(t, 1, len(toolCalls)) + require.Equal(t, "echo", toolCalls[0].ToolName) + + toolResults := firstStep.Content.ToolResults() + require.Equal(t, 1, len(toolResults)) + require.Equal(t, "echo", toolResults[0].ToolName) +} + +// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests) +func TestStreamingAgentTextDeltas(t *testing.T) { + t.Parallel() + + // Create a mock language model that returns text deltas + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + // Track text deltas + var textDeltas []string + + streamCall := AgentStreamCall{ + Prompt: "Say hello", + OnTextDelta: func(id, text string) { + if text != "" { + textDeltas = append(textDeltas, text) + } + }, + } + + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + + // Verify text deltas match expected pattern + require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas) + require.Equal(t, "Hello, world!", result.Response.Content.Text()) + require.Equal(t, int64(13), result.TotalUsage.TotalTokens) +} + +// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests) +func TestStreamingAgentReasoning(t *testing.T) { + t.Parallel() + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + var reasoningDeltas []string + var textDeltas []string + + streamCall := AgentStreamCall{ + Prompt: "Think and respond", + OnReasoningDelta: func(id, text string) { + reasoningDeltas = append(reasoningDeltas, text) + }, + OnTextDelta: func(id, text string) { + textDeltas = append(textDeltas, text) + }, + } + + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + + // Verify reasoning and text are separate + require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas) + require.Equal(t, []string{"Hi there!"}, textDeltas) + require.Equal(t, "Hi there!", result.Response.Content.Text()) + require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText()) +} + +// TestStreamingAgentError tests error handling (mirrors TS error tests) +func TestStreamingAgentError(t *testing.T) { + t.Parallel() + + // Create a mock language model that returns an error + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")}) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + // Track error callbacks + var streamErrorOccurred bool + var errorOccurred bool + var errorMessage string + + streamCall := AgentStreamCall{ + Prompt: "This will fail", + OnStreamError: func(err error) { + streamErrorOccurred = true + }, + OnError: func(err error) { + errorOccurred = true + errorMessage = err.Error() + }, + } + + // Execute streaming agent + result, err := agent.Stream(ctx, streamCall) + require.Error(t, err) + require.Nil(t, result) + require.True(t, streamErrorOccurred, "OnStreamError should have been called") + require.True(t, errorOccurred, "OnError should have been called") + require.Contains(t, errorMessage, "mock stream error") +} + +// TestStreamingAgentSources tests source handling (mirrors TS source tests) +func TestStreamingAgentSources(t *testing.T) { + t.Parallel() + + mockModel := &mockLanguageModel{ + streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) { + return func(yield func(StreamPart) bool) { + if !yield(StreamPart{ + Type: StreamPartTypeSource, + ID: "source-1", + SourceType: SourceTypeURL, + URL: "https://example.com", + Title: "Example", + }) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) { + return + } + if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) { + return + } + if !yield(StreamPart{ + Type: StreamPartTypeSource, + ID: "source-2", + SourceType: SourceTypeDocument, + Title: "Document Example", + }) { + return + } + yield(StreamPart{ + Type: StreamPartTypeFinish, + Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8}, + FinishReason: FinishReasonStop, + }) + }, nil + }, + } + + agent := NewAgent(mockModel) + ctx := context.Background() + + var sources []SourceContent + + streamCall := AgentStreamCall{ + Prompt: "Search and respond", + OnSource: func(source SourceContent) { + sources = append(sources, source) + }, + } + + result, err := agent.Stream(ctx, streamCall) + require.NoError(t, err) + + // Verify sources were captured + require.Equal(t, 2, len(sources)) + require.Equal(t, SourceTypeURL, sources[0].SourceType) + require.Equal(t, "https://example.com", sources[0].URL) + require.Equal(t, "Example", sources[0].Title) + require.Equal(t, SourceTypeDocument, sources[1].SourceType) + require.Equal(t, "Document Example", sources[1].Title) + + // Verify sources are in final result + resultSources := result.Response.Content.Sources() + require.Equal(t, 2, len(resultSources)) +} \ No newline at end of file diff --git a/agent_test.go b/agent_test.go index bfffc1ec175893030aeff1dd45f42480d1e0ff67..75b4f8948acadc694bc35a142cfa08994a0bf35e 100644 --- a/agent_test.go +++ b/agent_test.go @@ -39,6 +39,7 @@ func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResp // Mock language model for testing type mockLanguageModel struct { generateFunc func(ctx context.Context, call Call) (*Response, error) + streamFunc func(ctx context.Context, call Call) (StreamResponse, error) } func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) { @@ -59,7 +60,10 @@ func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, } func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) { - panic("not implemented") + if m.streamFunc != nil { + return m.streamFunc(ctx, call) + } + return nil, fmt.Errorf("mock stream not implemented") } func (m *mockLanguageModel) Provider() string { diff --git a/providers/examples/streaming-agent-simple/main.go b/providers/examples/streaming-agent-simple/main.go new file mode 100644 index 0000000000000000000000000000000000000000..2f8392e8640f3bff6bffa9338076d407dbaab864 --- /dev/null +++ b/providers/examples/streaming-agent-simple/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/charmbracelet/crush/internal/ai" + "github.com/charmbracelet/crush/internal/ai/providers" + "github.com/charmbracelet/crush/internal/llm/tools" +) + +// Simple echo tool for demonstration +type EchoTool struct{} + +func (e *EchoTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: "echo", + Description: "Echo back the provided message", + Parameters: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + } +} + +func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + return tools.NewTextResponse("Echo: " + params.Input), nil +} + +func main() { + // Check for API key + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + fmt.Println("Please set OPENAI_API_KEY environment variable") + os.Exit(1) + } + + // Create provider and model + provider := providers.NewOpenAIProvider( + providers.WithOpenAIApiKey(apiKey), + ) + model := provider.LanguageModel("gpt-4o-mini") + + // Create streaming agent + agent := ai.NewAgent( + model, + ai.WithSystemPrompt("You are a helpful assistant."), + ai.WithTools(&EchoTool{}), + ) + + ctx := context.Background() + + fmt.Println("Simple Streaming Agent Example") + fmt.Println("==============================") + fmt.Println() + + // Basic streaming with key callbacks + streamCall := ai.AgentStreamCall{ + Prompt: "Please echo back 'Hello, streaming world!'", + + // Show real-time text as it streams + OnTextDelta: func(id, text string) { + fmt.Print(text) + }, + + // Show when tools are called + OnToolCall: func(toolCall ai.ToolCallContent) { + fmt.Printf("\n[Tool: %s called]\n", toolCall.ToolName) + }, + + // Show tool results + OnToolResult: func(result ai.ToolResultContent) { + fmt.Printf("[Tool result received]\n") + }, + + // Show when each step completes + OnStepFinish: func(step ai.StepResult) { + fmt.Printf("\n[Step completed: %s]\n", step.FinishReason) + }, + } + + fmt.Println("Assistant response:") + result, err := agent.Stream(ctx, streamCall) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("\n\nFinal result: %s\n", result.Response.Content.Text()) + fmt.Printf("Steps: %d, Total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens) +} \ No newline at end of file diff --git a/providers/examples/streaming-agent/main.go b/providers/examples/streaming-agent/main.go new file mode 100644 index 0000000000000000000000000000000000000000..722a81dab6050649c9c6f9f8482037e4ddfef454 --- /dev/null +++ b/providers/examples/streaming-agent/main.go @@ -0,0 +1,274 @@ +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/charmbracelet/crush/internal/ai" + "github.com/charmbracelet/crush/internal/ai/providers" + "github.com/charmbracelet/crush/internal/llm/tools" +) + +// WeatherTool is a simple tool that simulates weather lookup +type WeatherTool struct{} + +func (w *WeatherTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: "get_weather", + Description: "Get the current weather for a specific location", + Parameters: map[string]any{ + "location": map[string]any{ + "type": "string", + "description": "The city and country, e.g. 'London, UK'", + }, + "unit": map[string]any{ + "type": "string", + "description": "Temperature unit (celsius or fahrenheit)", + "enum": []string{"celsius", "fahrenheit"}, + "default": "celsius", + }, + }, + Required: []string{"location"}, + } +} + +func (w *WeatherTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + // Simulate weather lookup with some fake data + location := "Unknown" + if strings.Contains(params.Input, "pristina") || strings.Contains(params.Input, "Pristina") { + location = "Pristina, Kosovo" + } else if strings.Contains(params.Input, "london") || strings.Contains(params.Input, "London") { + location = "London, UK" + } else if strings.Contains(params.Input, "new york") || strings.Contains(params.Input, "New York") { + location = "New York, USA" + } + + unit := "celsius" + if strings.Contains(params.Input, "fahrenheit") { + unit = "fahrenheit" + } + + var temp string + if unit == "fahrenheit" { + temp = "72°F" + } else { + temp = "22°C" + } + + weather := fmt.Sprintf("The current weather in %s is %s with partly cloudy skies and light winds.", location, temp) + return tools.NewTextResponse(weather), nil +} + +// CalculatorTool demonstrates a second tool for multi-tool scenarios +type CalculatorTool struct{} + +func (c *CalculatorTool) Info() tools.ToolInfo { + return tools.ToolInfo{ + Name: "calculate", + Description: "Perform basic mathematical calculations", + Parameters: map[string]any{ + "expression": map[string]any{ + "type": "string", + "description": "Mathematical expression to evaluate (e.g., '2 + 2', '10 * 5')", + }, + }, + Required: []string{"expression"}, + } +} + +func (c *CalculatorTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) { + // Simple calculator simulation + expr := strings.ReplaceAll(params.Input, "\"", "") + if strings.Contains(expr, "2 + 2") || strings.Contains(expr, "2+2") { + return tools.NewTextResponse("2 + 2 = 4"), nil + } else if strings.Contains(expr, "10 * 5") || strings.Contains(expr, "10*5") { + return tools.NewTextResponse("10 * 5 = 50"), nil + } + return tools.NewTextResponse("I can calculate simple expressions like '2 + 2' or '10 * 5'"), nil +} + +func main() { + // Check for API key + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + fmt.Println("❌ Please set OPENAI_API_KEY environment variable") + fmt.Println(" export OPENAI_API_KEY=your_api_key_here") + os.Exit(1) + } + + fmt.Println("🚀 Streaming Agent Example") + fmt.Println("==========================") + fmt.Println() + + // Create OpenAI provider and model + provider := providers.NewOpenAIProvider( + providers.WithOpenAIApiKey(apiKey), + ) + model := provider.LanguageModel("gpt-4o-mini") // Using mini for faster/cheaper responses + + // Create agent with tools + agent := ai.NewAgent( + model, + ai.WithSystemPrompt("You are a helpful assistant that can check weather and do calculations. Be concise and friendly."), + ai.WithTools(&WeatherTool{}, &CalculatorTool{}), + ) + + ctx := context.Background() + + // Demonstrate streaming with comprehensive callbacks + fmt.Println("💬 Asking: \"What's the weather in Pristina and what's 2 + 2?\"") + fmt.Println() + + // Track streaming events + var stepCount int + var textBuffer strings.Builder + var reasoningBuffer strings.Builder + + // Create streaming call with all callbacks + streamCall := ai.AgentStreamCall{ + Prompt: "What's the weather in Pristina and what's 2 + 2?", + + // Agent-level callbacks + OnAgentStart: func() { + fmt.Println("🎬 Agent started") + }, + OnAgentFinish: func(result *ai.AgentResult) { + fmt.Printf("🏁 Agent finished with %d steps, total tokens: %d\n", len(result.Steps), result.TotalUsage.TotalTokens) + }, + OnStepStart: func(stepNumber int) { + stepCount++ + fmt.Printf("📝 Step %d started\n", stepNumber+1) + }, + OnStepFinish: func(stepResult ai.StepResult) { + fmt.Printf("✅ Step completed (reason: %s, tokens: %d)\n", stepResult.FinishReason, stepResult.Usage.TotalTokens) + }, + OnFinish: func(result *ai.AgentResult) { + fmt.Printf("🎯 Final result ready with %d steps\n", len(result.Steps)) + }, + OnError: func(err error) { + fmt.Printf("❌ Error: %v\n", err) + }, + + // Stream part callbacks + OnWarnings: func(warnings []ai.CallWarning) { + for _, warning := range warnings { + fmt.Printf("⚠️ Warning: %s\n", warning.Message) + } + }, + OnTextStart: func(id string) { + fmt.Print("💭 Assistant: ") + }, + OnTextDelta: func(id, text string) { + fmt.Print(text) + textBuffer.WriteString(text) + }, + OnTextEnd: func(id string) { + fmt.Println() + }, + OnReasoningStart: func(id string) { + fmt.Print("🤔 Thinking: ") + }, + OnReasoningDelta: func(id, text string) { + reasoningBuffer.WriteString(text) + }, + OnReasoningEnd: func(id string) { + if reasoningBuffer.Len() > 0 { + fmt.Printf("%s\n", reasoningBuffer.String()) + reasoningBuffer.Reset() + } + }, + OnToolInputStart: func(id, toolName string) { + fmt.Printf("🔧 Calling tool: %s\n", toolName) + }, + OnToolInputDelta: func(id, delta string) { + // Could show tool input being built, but it's often noisy + }, + OnToolInputEnd: func(id string) { + // Tool input complete + }, + OnToolCall: func(toolCall ai.ToolCallContent) { + fmt.Printf("🛠️ Tool call: %s\n", toolCall.ToolName) + fmt.Printf(" Input: %s\n", toolCall.Input) + }, + OnToolResult: func(result ai.ToolResultContent) { + fmt.Printf("🎯 Tool result from %s:\n", result.ToolName) + switch output := result.Result.(type) { + case ai.ToolResultOutputContentText: + fmt.Printf(" %s\n", output.Text) + case ai.ToolResultOutputContentError: + fmt.Printf(" Error: %s\n", output.Error.Error()) + } + }, + OnSource: func(source ai.SourceContent) { + fmt.Printf("📚 Source: %s (%s)\n", source.Title, source.URL) + }, + OnStreamFinish: func(usage ai.Usage, finishReason ai.FinishReason, providerMetadata ai.ProviderOptions) { + fmt.Printf("📊 Stream finished (reason: %s, tokens: %d)\n", finishReason, usage.TotalTokens) + }, + OnStreamError: func(err error) { + fmt.Printf("💥 Stream error: %v\n", err) + }, + } + + // Execute streaming agent + result, err := agent.Stream(ctx, streamCall) + if err != nil { + fmt.Printf("❌ Agent failed: %v\n", err) + os.Exit(1) + } + + // Display final results + fmt.Println() + fmt.Println("📋 Final Summary") + fmt.Println("================") + fmt.Printf("Steps executed: %d\n", len(result.Steps)) + fmt.Printf("Total tokens used: %d (input: %d, output: %d)\n", + result.TotalUsage.TotalTokens, + result.TotalUsage.InputTokens, + result.TotalUsage.OutputTokens) + + if result.TotalUsage.ReasoningTokens > 0 { + fmt.Printf("Reasoning tokens: %d\n", result.TotalUsage.ReasoningTokens) + } + + fmt.Printf("Final response: %s\n", result.Response.Content.Text()) + + // Show step details + fmt.Println() + fmt.Println("🔍 Step Details") + fmt.Println("===============") + for i, step := range result.Steps { + fmt.Printf("Step %d:\n", i+1) + fmt.Printf(" Finish reason: %s\n", step.FinishReason) + fmt.Printf(" Content types: ") + + var contentTypes []string + for _, content := range step.Content { + contentTypes = append(contentTypes, string(content.GetType())) + } + fmt.Printf("%s\n", strings.Join(contentTypes, ", ")) + + // Show tool calls and results + toolCalls := step.Content.ToolCalls() + if len(toolCalls) > 0 { + fmt.Printf(" Tool calls: ") + var toolNames []string + for _, tc := range toolCalls { + toolNames = append(toolNames, tc.ToolName) + } + fmt.Printf("%s\n", strings.Join(toolNames, ", ")) + } + + toolResults := step.Content.ToolResults() + if len(toolResults) > 0 { + fmt.Printf(" Tool results: %d\n", len(toolResults)) + } + + fmt.Printf(" Tokens: %d\n", step.Usage.TotalTokens) + fmt.Println() + } + + fmt.Println("✨ Example completed successfully!") +} \ No newline at end of file diff --git a/providers/openai.go b/providers/openai.go index d6ff2ac9fa78d6c21244f151df199c352cc6e18c..300570680f981f66ea58060847d70be19fa9427b 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -1134,17 +1134,17 @@ func parseAnnotationsFromDelta(delta openai.ChatCompletionChunkChoiceDelta) []op var annotations []openai.ChatCompletionMessageAnnotation // Parse the raw JSON to extract annotations - var deltaData map[string]interface{} + var deltaData map[string]any if err := json.Unmarshal([]byte(delta.RawJSON()), &deltaData); err != nil { return annotations } // Check if annotations exist in the delta - if annotationsData, ok := deltaData["annotations"].([]interface{}); ok { + if annotationsData, ok := deltaData["annotations"].([]any); ok { for _, annotationData := range annotationsData { - if annotationMap, ok := annotationData.(map[string]interface{}); ok { + if annotationMap, ok := annotationData.(map[string]any); ok { if annotationType, ok := annotationMap["type"].(string); ok && annotationType == "url_citation" { - if urlCitationData, ok := annotationMap["url_citation"].(map[string]interface{}); ok { + if urlCitationData, ok := annotationMap["url_citation"].(map[string]any); ok { annotation := openai.ChatCompletionMessageAnnotation{ Type: "url_citation", URLCitation: openai.ChatCompletionMessageAnnotationURLCitation{