agent_stream_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"testing"
  8
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/stretchr/testify/require"
 11)
 12
// EchoTool is a stateless test tool registered under the name "echo". It
// replies with the "message" field of its JSON input, prefixed with "Echo: ".
type EchoTool struct{}
 15
 16// Info returns the tool information
 17func (e *EchoTool) Info() tools.ToolInfo {
 18	return tools.ToolInfo{
 19		Name:        "echo",
 20		Description: "Echo back the provided message",
 21		Parameters: map[string]any{
 22			"message": map[string]any{
 23				"type":        "string",
 24				"description": "The message to echo back",
 25			},
 26		},
 27		Required: []string{"message"},
 28	}
 29}
 30
 31// Run executes the echo tool
 32func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 33	var input struct {
 34		Message string `json:"message"`
 35	}
 36
 37	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 38		return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
 39	}
 40
 41	if input.Message == "" {
 42		return tools.NewTextErrorResponse("Message cannot be empty"), nil
 43	}
 44
 45	return tools.NewTextResponse("Echo: " + input.Message), nil
 46}
 47
 48// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 49func TestStreamingAgentCallbacks(t *testing.T) {
 50	t.Parallel()
 51
 52	// Track which callbacks were called
 53	callbacks := make(map[string]bool)
 54
 55	// Create a mock language model that returns various stream parts
 56	mockModel := &mockLanguageModel{
 57		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 58			return func(yield func(StreamPart) bool) {
 59				// Test all stream part types
 60				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 61					return
 62				}
 63				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 64					return
 65				}
 66				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 67					return
 68				}
 69				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 70					return
 71				}
 72				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 73					return
 74				}
 75				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 76					return
 77				}
 78				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 79					return
 80				}
 81				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 82					return
 83				}
 84				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 85					return
 86				}
 87				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 88					return
 89				}
 90				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
 91					return
 92				}
 93				yield(StreamPart{
 94					Type:         StreamPartTypeFinish,
 95					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
 96					FinishReason: FinishReasonStop,
 97				})
 98			}, nil
 99		},
100	}
101
102	// Create agent
103	agent := NewAgent(mockModel)
104
105	ctx := context.Background()
106
107	// Create streaming call with all callbacks
108	streamCall := AgentStreamCall{
109		Prompt: "Test all callbacks",
110		OnAgentStart: func() {
111			callbacks["OnAgentStart"] = true
112		},
113		OnAgentFinish: func(result *AgentResult) {
114			callbacks["OnAgentFinish"] = true
115		},
116		OnStepStart: func(stepNumber int) {
117			callbacks["OnStepStart"] = true
118		},
119		OnStepFinish: func(stepResult StepResult) {
120			callbacks["OnStepFinish"] = true
121		},
122		OnFinish: func(result *AgentResult) {
123			callbacks["OnFinish"] = true
124		},
125		OnError: func(err error) {
126			callbacks["OnError"] = true
127		},
128		OnChunk: func(part StreamPart) {
129			callbacks["OnChunk"] = true
130		},
131		OnWarnings: func(warnings []CallWarning) {
132			callbacks["OnWarnings"] = true
133		},
134		OnTextStart: func(id string) {
135			callbacks["OnTextStart"] = true
136		},
137		OnTextDelta: func(id, text string) {
138			callbacks["OnTextDelta"] = true
139		},
140		OnTextEnd: func(id string) {
141			callbacks["OnTextEnd"] = true
142		},
143		OnReasoningStart: func(id string) {
144			callbacks["OnReasoningStart"] = true
145		},
146		OnReasoningDelta: func(id, text string) {
147			callbacks["OnReasoningDelta"] = true
148		},
149		OnReasoningEnd: func(id string) {
150			callbacks["OnReasoningEnd"] = true
151		},
152		OnToolInputStart: func(id, toolName string) {
153			callbacks["OnToolInputStart"] = true
154		},
155		OnToolInputDelta: func(id, delta string) {
156			callbacks["OnToolInputDelta"] = true
157		},
158		OnToolInputEnd: func(id string) {
159			callbacks["OnToolInputEnd"] = true
160		},
161		OnToolCall: func(toolCall ToolCallContent) {
162			callbacks["OnToolCall"] = true
163		},
164		OnToolResult: func(result ToolResultContent) {
165			callbacks["OnToolResult"] = true
166		},
167		OnSource: func(source SourceContent) {
168			callbacks["OnSource"] = true
169		},
170		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
171			callbacks["OnStreamFinish"] = true
172		},
173		OnStreamError: func(err error) {
174			callbacks["OnStreamError"] = true
175		},
176	}
177
178	// Execute streaming agent
179	result, err := agent.Stream(ctx, streamCall)
180	require.NoError(t, err)
181	require.NotNil(t, result)
182
183	// Verify that expected callbacks were called
184	expectedCallbacks := []string{
185		"OnAgentStart",
186		"OnAgentFinish",
187		"OnStepStart",
188		"OnStepFinish",
189		"OnFinish",
190		"OnChunk",
191		"OnWarnings",
192		"OnTextStart",
193		"OnTextDelta",
194		"OnTextEnd",
195		"OnReasoningStart",
196		"OnReasoningDelta",
197		"OnReasoningEnd",
198		"OnToolInputStart",
199		"OnToolInputDelta",
200		"OnToolInputEnd",
201		"OnSource",
202		"OnStreamFinish",
203	}
204
205	for _, callback := range expectedCallbacks {
206		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
207	}
208
209	// Verify that error callbacks were not called
210	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
211	require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
212	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
213	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
214}
215
216// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
217func TestStreamingAgentWithTools(t *testing.T) {
218	t.Parallel()
219
220	stepCount := 0
221	// Create a mock language model that makes a tool call then finishes
222	mockModel := &mockLanguageModel{
223		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
224			stepCount++
225			return func(yield func(StreamPart) bool) {
226				if stepCount == 1 {
227					// First step: make tool call
228					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
229						return
230					}
231					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
232						return
233					}
234					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
235						return
236					}
237					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
238						return
239					}
240					if !yield(StreamPart{
241						Type:          StreamPartTypeToolCall,
242						ID:            "tool-1",
243						ToolCallName:  "echo",
244						ToolCallInput: `{"message": "test"}`,
245					}) {
246						return
247					}
248					yield(StreamPart{
249						Type:         StreamPartTypeFinish,
250						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
251						FinishReason: FinishReasonToolCalls,
252					})
253				} else {
254					// Second step: finish after tool execution
255					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
256						return
257					}
258					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
259						return
260					}
261					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
262						return
263					}
264					yield(StreamPart{
265						Type:         StreamPartTypeFinish,
266						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
267						FinishReason: FinishReasonStop,
268					})
269				}
270			}, nil
271		},
272	}
273
274	// Create agent with echo tool
275	agent := NewAgent(
276		mockModel,
277		WithSystemPrompt("You are a helpful assistant."),
278		WithTools(&EchoTool{}),
279	)
280
281	ctx := context.Background()
282
283	// Track callback invocations
284	var toolInputStartCalled bool
285	var toolInputDeltaCalled bool
286	var toolInputEndCalled bool
287	var toolCallCalled bool
288	var toolResultCalled bool
289
290	// Create streaming call with callbacks
291	streamCall := AgentStreamCall{
292		Prompt: "Echo 'test'",
293		OnToolInputStart: func(id, toolName string) {
294			toolInputStartCalled = true
295			require.Equal(t, "tool-1", id)
296			require.Equal(t, "echo", toolName)
297		},
298		OnToolInputDelta: func(id, delta string) {
299			toolInputDeltaCalled = true
300			require.Equal(t, "tool-1", id)
301			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
302		},
303		OnToolInputEnd: func(id string) {
304			toolInputEndCalled = true
305			require.Equal(t, "tool-1", id)
306		},
307		OnToolCall: func(toolCall ToolCallContent) {
308			toolCallCalled = true
309			require.Equal(t, "echo", toolCall.ToolName)
310			require.Equal(t, `{"message": "test"}`, toolCall.Input)
311		},
312		OnToolResult: func(result ToolResultContent) {
313			toolResultCalled = true
314			require.Equal(t, "echo", result.ToolName)
315		},
316	}
317
318	// Execute streaming agent
319	result, err := agent.Stream(ctx, streamCall)
320	require.NoError(t, err)
321
322	// Verify results
323	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
324	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
325	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
326	require.True(t, toolCallCalled, "OnToolCall should have been called")
327	require.True(t, toolResultCalled, "OnToolResult should have been called")
328	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
329
330	// Check that tool was executed in first step
331	firstStep := result.Steps[0]
332	toolCalls := firstStep.Content.ToolCalls()
333	require.Equal(t, 1, len(toolCalls))
334	require.Equal(t, "echo", toolCalls[0].ToolName)
335
336	toolResults := firstStep.Content.ToolResults()
337	require.Equal(t, 1, len(toolResults))
338	require.Equal(t, "echo", toolResults[0].ToolName)
339}
340
341// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
342func TestStreamingAgentTextDeltas(t *testing.T) {
343	t.Parallel()
344
345	// Create a mock language model that returns text deltas
346	mockModel := &mockLanguageModel{
347		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
348			return func(yield func(StreamPart) bool) {
349				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
350					return
351				}
352				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
353					return
354				}
355				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
356					return
357				}
358				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
359					return
360				}
361				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
362					return
363				}
364				yield(StreamPart{
365					Type:         StreamPartTypeFinish,
366					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
367					FinishReason: FinishReasonStop,
368				})
369			}, nil
370		},
371	}
372
373	agent := NewAgent(mockModel)
374	ctx := context.Background()
375
376	// Track text deltas
377	var textDeltas []string
378
379	streamCall := AgentStreamCall{
380		Prompt: "Say hello",
381		OnTextDelta: func(id, text string) {
382			if text != "" {
383				textDeltas = append(textDeltas, text)
384			}
385		},
386	}
387
388	result, err := agent.Stream(ctx, streamCall)
389	require.NoError(t, err)
390
391	// Verify text deltas match expected pattern
392	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
393	require.Equal(t, "Hello, world!", result.Response.Content.Text())
394	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
395}
396
397// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
398func TestStreamingAgentReasoning(t *testing.T) {
399	t.Parallel()
400
401	mockModel := &mockLanguageModel{
402		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
403			return func(yield func(StreamPart) bool) {
404				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
405					return
406				}
407				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
408					return
409				}
410				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
411					return
412				}
413				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
414					return
415				}
416				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
417					return
418				}
419				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
420					return
421				}
422				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
423					return
424				}
425				yield(StreamPart{
426					Type:         StreamPartTypeFinish,
427					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
428					FinishReason: FinishReasonStop,
429				})
430			}, nil
431		},
432	}
433
434	agent := NewAgent(mockModel)
435	ctx := context.Background()
436
437	var reasoningDeltas []string
438	var textDeltas []string
439
440	streamCall := AgentStreamCall{
441		Prompt: "Think and respond",
442		OnReasoningDelta: func(id, text string) {
443			reasoningDeltas = append(reasoningDeltas, text)
444		},
445		OnTextDelta: func(id, text string) {
446			textDeltas = append(textDeltas, text)
447		},
448	}
449
450	result, err := agent.Stream(ctx, streamCall)
451	require.NoError(t, err)
452
453	// Verify reasoning and text are separate
454	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
455	require.Equal(t, []string{"Hi there!"}, textDeltas)
456	require.Equal(t, "Hi there!", result.Response.Content.Text())
457	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
458}
459
460// TestStreamingAgentError tests error handling (mirrors TS error tests)
461func TestStreamingAgentError(t *testing.T) {
462	t.Parallel()
463
464	// Create a mock language model that returns an error
465	mockModel := &mockLanguageModel{
466		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
467			return func(yield func(StreamPart) bool) {
468				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
469			}, nil
470		},
471	}
472
473	agent := NewAgent(mockModel)
474	ctx := context.Background()
475
476	// Track error callbacks
477	var streamErrorOccurred bool
478	var errorOccurred bool
479	var errorMessage string
480
481	streamCall := AgentStreamCall{
482		Prompt: "This will fail",
483		OnStreamError: func(err error) {
484			streamErrorOccurred = true
485		},
486		OnError: func(err error) {
487			errorOccurred = true
488			errorMessage = err.Error()
489		},
490	}
491
492	// Execute streaming agent
493	result, err := agent.Stream(ctx, streamCall)
494	require.Error(t, err)
495	require.Nil(t, result)
496	require.True(t, streamErrorOccurred, "OnStreamError should have been called")
497	require.True(t, errorOccurred, "OnError should have been called")
498	require.Contains(t, errorMessage, "mock stream error")
499}
500
501// TestStreamingAgentSources tests source handling (mirrors TS source tests)
502func TestStreamingAgentSources(t *testing.T) {
503	t.Parallel()
504
505	mockModel := &mockLanguageModel{
506		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
507			return func(yield func(StreamPart) bool) {
508				if !yield(StreamPart{
509					Type:       StreamPartTypeSource,
510					ID:         "source-1",
511					SourceType: SourceTypeURL,
512					URL:        "https://example.com",
513					Title:      "Example",
514				}) {
515					return
516				}
517				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
518					return
519				}
520				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
521					return
522				}
523				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
524					return
525				}
526				if !yield(StreamPart{
527					Type:       StreamPartTypeSource,
528					ID:         "source-2",
529					SourceType: SourceTypeDocument,
530					Title:      "Document Example",
531				}) {
532					return
533				}
534				yield(StreamPart{
535					Type:         StreamPartTypeFinish,
536					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
537					FinishReason: FinishReasonStop,
538				})
539			}, nil
540		},
541	}
542
543	agent := NewAgent(mockModel)
544	ctx := context.Background()
545
546	var sources []SourceContent
547
548	streamCall := AgentStreamCall{
549		Prompt: "Search and respond",
550		OnSource: func(source SourceContent) {
551			sources = append(sources, source)
552		},
553	}
554
555	result, err := agent.Stream(ctx, streamCall)
556	require.NoError(t, err)
557
558	// Verify sources were captured
559	require.Equal(t, 2, len(sources))
560	require.Equal(t, SourceTypeURL, sources[0].SourceType)
561	require.Equal(t, "https://example.com", sources[0].URL)
562	require.Equal(t, "Example", sources[0].Title)
563	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
564	require.Equal(t, "Document Example", sources[1].Title)
565
566	// Verify sources are in final result
567	resultSources := result.Response.Content.Sources()
568	require.Equal(t, 2, len(resultSources))
569}