1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/stretchr/testify/require"
11)
12
// EchoTool is a stateless tool used by the streaming agent tests in this
// file. Its Run method answers with the received message prefixed by "Echo: ".
type EchoTool struct{}
15
16// Info returns the tool information
17func (e *EchoTool) Info() tools.ToolInfo {
18 return tools.ToolInfo{
19 Name: "echo",
20 Description: "Echo back the provided message",
21 Parameters: map[string]any{
22 "message": map[string]any{
23 "type": "string",
24 "description": "The message to echo back",
25 },
26 },
27 Required: []string{"message"},
28 }
29}
30
31// Run executes the echo tool
32func (e *EchoTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
33 var input struct {
34 Message string `json:"message"`
35 }
36
37 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
38 return tools.NewTextErrorResponse("Invalid input: " + err.Error()), nil
39 }
40
41 if input.Message == "" {
42 return tools.NewTextErrorResponse("Message cannot be empty"), nil
43 }
44
45 return tools.NewTextResponse("Echo: " + input.Message), nil
46}
47
48// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
49func TestStreamingAgentCallbacks(t *testing.T) {
50 t.Parallel()
51
52 // Track which callbacks were called
53 callbacks := make(map[string]bool)
54
55 // Create a mock language model that returns various stream parts
56 mockModel := &mockLanguageModel{
57 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
58 return func(yield func(StreamPart) bool) {
59 // Test all stream part types
60 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
61 return
62 }
63 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
64 return
65 }
66 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
67 return
68 }
69 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
70 return
71 }
72 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
73 return
74 }
75 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
76 return
77 }
78 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
79 return
80 }
81 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
82 return
83 }
84 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
85 return
86 }
87 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
88 return
89 }
90 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
91 return
92 }
93 yield(StreamPart{
94 Type: StreamPartTypeFinish,
95 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
96 FinishReason: FinishReasonStop,
97 })
98 }, nil
99 },
100 }
101
102 // Create agent
103 agent := NewAgent(mockModel)
104
105 ctx := context.Background()
106
107 // Create streaming call with all callbacks
108 streamCall := AgentStreamCall{
109 Prompt: "Test all callbacks",
110 OnAgentStart: func() {
111 callbacks["OnAgentStart"] = true
112 },
113 OnAgentFinish: func(result *AgentResult) {
114 callbacks["OnAgentFinish"] = true
115 },
116 OnStepStart: func(stepNumber int) {
117 callbacks["OnStepStart"] = true
118 },
119 OnStepFinish: func(stepResult StepResult) {
120 callbacks["OnStepFinish"] = true
121 },
122 OnFinish: func(result *AgentResult) {
123 callbacks["OnFinish"] = true
124 },
125 OnError: func(err error) {
126 callbacks["OnError"] = true
127 },
128 OnChunk: func(part StreamPart) {
129 callbacks["OnChunk"] = true
130 },
131 OnWarnings: func(warnings []CallWarning) {
132 callbacks["OnWarnings"] = true
133 },
134 OnTextStart: func(id string) {
135 callbacks["OnTextStart"] = true
136 },
137 OnTextDelta: func(id, text string) {
138 callbacks["OnTextDelta"] = true
139 },
140 OnTextEnd: func(id string) {
141 callbacks["OnTextEnd"] = true
142 },
143 OnReasoningStart: func(id string) {
144 callbacks["OnReasoningStart"] = true
145 },
146 OnReasoningDelta: func(id, text string) {
147 callbacks["OnReasoningDelta"] = true
148 },
149 OnReasoningEnd: func(id string) {
150 callbacks["OnReasoningEnd"] = true
151 },
152 OnToolInputStart: func(id, toolName string) {
153 callbacks["OnToolInputStart"] = true
154 },
155 OnToolInputDelta: func(id, delta string) {
156 callbacks["OnToolInputDelta"] = true
157 },
158 OnToolInputEnd: func(id string) {
159 callbacks["OnToolInputEnd"] = true
160 },
161 OnToolCall: func(toolCall ToolCallContent) {
162 callbacks["OnToolCall"] = true
163 },
164 OnToolResult: func(result ToolResultContent) {
165 callbacks["OnToolResult"] = true
166 },
167 OnSource: func(source SourceContent) {
168 callbacks["OnSource"] = true
169 },
170 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
171 callbacks["OnStreamFinish"] = true
172 },
173 OnStreamError: func(err error) {
174 callbacks["OnStreamError"] = true
175 },
176 }
177
178 // Execute streaming agent
179 result, err := agent.Stream(ctx, streamCall)
180 require.NoError(t, err)
181 require.NotNil(t, result)
182
183 // Verify that expected callbacks were called
184 expectedCallbacks := []string{
185 "OnAgentStart",
186 "OnAgentFinish",
187 "OnStepStart",
188 "OnStepFinish",
189 "OnFinish",
190 "OnChunk",
191 "OnWarnings",
192 "OnTextStart",
193 "OnTextDelta",
194 "OnTextEnd",
195 "OnReasoningStart",
196 "OnReasoningDelta",
197 "OnReasoningEnd",
198 "OnToolInputStart",
199 "OnToolInputDelta",
200 "OnToolInputEnd",
201 "OnSource",
202 "OnStreamFinish",
203 }
204
205 for _, callback := range expectedCallbacks {
206 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
207 }
208
209 // Verify that error callbacks were not called
210 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
211 require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
212 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
213 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
214}
215
216// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
217func TestStreamingAgentWithTools(t *testing.T) {
218 t.Parallel()
219
220 stepCount := 0
221 // Create a mock language model that makes a tool call then finishes
222 mockModel := &mockLanguageModel{
223 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
224 stepCount++
225 return func(yield func(StreamPart) bool) {
226 if stepCount == 1 {
227 // First step: make tool call
228 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
229 return
230 }
231 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
232 return
233 }
234 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
235 return
236 }
237 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
238 return
239 }
240 if !yield(StreamPart{
241 Type: StreamPartTypeToolCall,
242 ID: "tool-1",
243 ToolCallName: "echo",
244 ToolCallInput: `{"message": "test"}`,
245 }) {
246 return
247 }
248 yield(StreamPart{
249 Type: StreamPartTypeFinish,
250 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
251 FinishReason: FinishReasonToolCalls,
252 })
253 } else {
254 // Second step: finish after tool execution
255 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
256 return
257 }
258 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
259 return
260 }
261 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
262 return
263 }
264 yield(StreamPart{
265 Type: StreamPartTypeFinish,
266 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
267 FinishReason: FinishReasonStop,
268 })
269 }
270 }, nil
271 },
272 }
273
274 // Create agent with echo tool
275 agent := NewAgent(
276 mockModel,
277 WithSystemPrompt("You are a helpful assistant."),
278 WithTools(&EchoTool{}),
279 )
280
281 ctx := context.Background()
282
283 // Track callback invocations
284 var toolInputStartCalled bool
285 var toolInputDeltaCalled bool
286 var toolInputEndCalled bool
287 var toolCallCalled bool
288 var toolResultCalled bool
289
290 // Create streaming call with callbacks
291 streamCall := AgentStreamCall{
292 Prompt: "Echo 'test'",
293 OnToolInputStart: func(id, toolName string) {
294 toolInputStartCalled = true
295 require.Equal(t, "tool-1", id)
296 require.Equal(t, "echo", toolName)
297 },
298 OnToolInputDelta: func(id, delta string) {
299 toolInputDeltaCalled = true
300 require.Equal(t, "tool-1", id)
301 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
302 },
303 OnToolInputEnd: func(id string) {
304 toolInputEndCalled = true
305 require.Equal(t, "tool-1", id)
306 },
307 OnToolCall: func(toolCall ToolCallContent) {
308 toolCallCalled = true
309 require.Equal(t, "echo", toolCall.ToolName)
310 require.Equal(t, `{"message": "test"}`, toolCall.Input)
311 },
312 OnToolResult: func(result ToolResultContent) {
313 toolResultCalled = true
314 require.Equal(t, "echo", result.ToolName)
315 },
316 }
317
318 // Execute streaming agent
319 result, err := agent.Stream(ctx, streamCall)
320 require.NoError(t, err)
321
322 // Verify results
323 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
324 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
325 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
326 require.True(t, toolCallCalled, "OnToolCall should have been called")
327 require.True(t, toolResultCalled, "OnToolResult should have been called")
328 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
329
330 // Check that tool was executed in first step
331 firstStep := result.Steps[0]
332 toolCalls := firstStep.Content.ToolCalls()
333 require.Equal(t, 1, len(toolCalls))
334 require.Equal(t, "echo", toolCalls[0].ToolName)
335
336 toolResults := firstStep.Content.ToolResults()
337 require.Equal(t, 1, len(toolResults))
338 require.Equal(t, "echo", toolResults[0].ToolName)
339}
340
341// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
342func TestStreamingAgentTextDeltas(t *testing.T) {
343 t.Parallel()
344
345 // Create a mock language model that returns text deltas
346 mockModel := &mockLanguageModel{
347 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
348 return func(yield func(StreamPart) bool) {
349 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
350 return
351 }
352 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
353 return
354 }
355 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
356 return
357 }
358 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
359 return
360 }
361 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
362 return
363 }
364 yield(StreamPart{
365 Type: StreamPartTypeFinish,
366 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
367 FinishReason: FinishReasonStop,
368 })
369 }, nil
370 },
371 }
372
373 agent := NewAgent(mockModel)
374 ctx := context.Background()
375
376 // Track text deltas
377 var textDeltas []string
378
379 streamCall := AgentStreamCall{
380 Prompt: "Say hello",
381 OnTextDelta: func(id, text string) {
382 if text != "" {
383 textDeltas = append(textDeltas, text)
384 }
385 },
386 }
387
388 result, err := agent.Stream(ctx, streamCall)
389 require.NoError(t, err)
390
391 // Verify text deltas match expected pattern
392 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
393 require.Equal(t, "Hello, world!", result.Response.Content.Text())
394 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
395}
396
397// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
398func TestStreamingAgentReasoning(t *testing.T) {
399 t.Parallel()
400
401 mockModel := &mockLanguageModel{
402 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
403 return func(yield func(StreamPart) bool) {
404 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
405 return
406 }
407 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
408 return
409 }
410 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
411 return
412 }
413 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
414 return
415 }
416 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
417 return
418 }
419 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
420 return
421 }
422 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
423 return
424 }
425 yield(StreamPart{
426 Type: StreamPartTypeFinish,
427 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
428 FinishReason: FinishReasonStop,
429 })
430 }, nil
431 },
432 }
433
434 agent := NewAgent(mockModel)
435 ctx := context.Background()
436
437 var reasoningDeltas []string
438 var textDeltas []string
439
440 streamCall := AgentStreamCall{
441 Prompt: "Think and respond",
442 OnReasoningDelta: func(id, text string) {
443 reasoningDeltas = append(reasoningDeltas, text)
444 },
445 OnTextDelta: func(id, text string) {
446 textDeltas = append(textDeltas, text)
447 },
448 }
449
450 result, err := agent.Stream(ctx, streamCall)
451 require.NoError(t, err)
452
453 // Verify reasoning and text are separate
454 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
455 require.Equal(t, []string{"Hi there!"}, textDeltas)
456 require.Equal(t, "Hi there!", result.Response.Content.Text())
457 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
458}
459
460// TestStreamingAgentError tests error handling (mirrors TS error tests)
461func TestStreamingAgentError(t *testing.T) {
462 t.Parallel()
463
464 // Create a mock language model that returns an error
465 mockModel := &mockLanguageModel{
466 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
467 return func(yield func(StreamPart) bool) {
468 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
469 }, nil
470 },
471 }
472
473 agent := NewAgent(mockModel)
474 ctx := context.Background()
475
476 // Track error callbacks
477 var streamErrorOccurred bool
478 var errorOccurred bool
479 var errorMessage string
480
481 streamCall := AgentStreamCall{
482 Prompt: "This will fail",
483 OnStreamError: func(err error) {
484 streamErrorOccurred = true
485 },
486 OnError: func(err error) {
487 errorOccurred = true
488 errorMessage = err.Error()
489 },
490 }
491
492 // Execute streaming agent
493 result, err := agent.Stream(ctx, streamCall)
494 require.Error(t, err)
495 require.Nil(t, result)
496 require.True(t, streamErrorOccurred, "OnStreamError should have been called")
497 require.True(t, errorOccurred, "OnError should have been called")
498 require.Contains(t, errorMessage, "mock stream error")
499}
500
501// TestStreamingAgentSources tests source handling (mirrors TS source tests)
502func TestStreamingAgentSources(t *testing.T) {
503 t.Parallel()
504
505 mockModel := &mockLanguageModel{
506 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
507 return func(yield func(StreamPart) bool) {
508 if !yield(StreamPart{
509 Type: StreamPartTypeSource,
510 ID: "source-1",
511 SourceType: SourceTypeURL,
512 URL: "https://example.com",
513 Title: "Example",
514 }) {
515 return
516 }
517 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
518 return
519 }
520 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
521 return
522 }
523 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
524 return
525 }
526 if !yield(StreamPart{
527 Type: StreamPartTypeSource,
528 ID: "source-2",
529 SourceType: SourceTypeDocument,
530 Title: "Document Example",
531 }) {
532 return
533 }
534 yield(StreamPart{
535 Type: StreamPartTypeFinish,
536 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
537 FinishReason: FinishReasonStop,
538 })
539 }, nil
540 },
541 }
542
543 agent := NewAgent(mockModel)
544 ctx := context.Background()
545
546 var sources []SourceContent
547
548 streamCall := AgentStreamCall{
549 Prompt: "Search and respond",
550 OnSource: func(source SourceContent) {
551 sources = append(sources, source)
552 },
553 }
554
555 result, err := agent.Stream(ctx, streamCall)
556 require.NoError(t, err)
557
558 // Verify sources were captured
559 require.Equal(t, 2, len(sources))
560 require.Equal(t, SourceTypeURL, sources[0].SourceType)
561 require.Equal(t, "https://example.com", sources[0].URL)
562 require.Equal(t, "Example", sources[0].Title)
563 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
564 require.Equal(t, "Document Example", sources[1].Title)
565
566 // Verify sources are in final result
567 resultSources := result.Response.Content.Sources()
568 require.Equal(t, 2, len(resultSources))
569}