1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/stretchr/testify/require"
10)
11
12// EchoTool is a simple tool that echoes back the input message
13type EchoTool struct{}
14
15// Info returns the tool information
16func (e *EchoTool) Info() ToolInfo {
17 return ToolInfo{
18 Name: "echo",
19 Description: "Echo back the provided message",
20 Parameters: map[string]any{
21 "message": map[string]any{
22 "type": "string",
23 "description": "The message to echo back",
24 },
25 },
26 Required: []string{"message"},
27 }
28}
29
30// Run executes the echo tool
31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
32 var input struct {
33 Message string `json:"message"`
34 }
35
36 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
37 return NewTextErrorResponse("Invalid input: " + err.Error()), nil
38 }
39
40 if input.Message == "" {
41 return NewTextErrorResponse("Message cannot be empty"), nil
42 }
43
44 return NewTextResponse("Echo: " + input.Message), nil
45}
46
47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
48func TestStreamingAgentCallbacks(t *testing.T) {
49 t.Parallel()
50
51 // Track which callbacks were called
52 callbacks := make(map[string]bool)
53
54 // Create a mock language model that returns various stream parts
55 mockModel := &mockLanguageModel{
56 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
57 return func(yield func(StreamPart) bool) {
58 // Test all stream part types
59 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
60 return
61 }
62 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
63 return
64 }
65 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
66 return
67 }
68 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
69 return
70 }
71 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
72 return
73 }
74 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
75 return
76 }
77 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
78 return
79 }
80 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
81 return
82 }
83 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
84 return
85 }
86 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
87 return
88 }
89 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
90 return
91 }
92 yield(StreamPart{
93 Type: StreamPartTypeFinish,
94 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
95 FinishReason: FinishReasonStop,
96 })
97 }, nil
98 },
99 }
100
101 // Create agent
102 agent := NewAgent(mockModel)
103
104 ctx := context.Background()
105
106 // Create streaming call with all callbacks
107 streamCall := AgentStreamCall{
108 Prompt: "Test all callbacks",
109 OnAgentStart: func() {
110 callbacks["OnAgentStart"] = true
111 },
112 OnAgentFinish: func(result *AgentResult) {
113 callbacks["OnAgentFinish"] = true
114 },
115 OnStepStart: func(stepNumber int) {
116 callbacks["OnStepStart"] = true
117 },
118 OnStepFinish: func(stepResult StepResult) {
119 callbacks["OnStepFinish"] = true
120 },
121 OnFinish: func(result *AgentResult) {
122 callbacks["OnFinish"] = true
123 },
124 OnError: func(err error) {
125 callbacks["OnError"] = true
126 },
127 OnChunk: func(part StreamPart) {
128 callbacks["OnChunk"] = true
129 },
130 OnWarnings: func(warnings []CallWarning) {
131 callbacks["OnWarnings"] = true
132 },
133 OnTextStart: func(id string) {
134 callbacks["OnTextStart"] = true
135 },
136 OnTextDelta: func(id, text string) {
137 callbacks["OnTextDelta"] = true
138 },
139 OnTextEnd: func(id string) {
140 callbacks["OnTextEnd"] = true
141 },
142 OnReasoningStart: func(id string) {
143 callbacks["OnReasoningStart"] = true
144 },
145 OnReasoningDelta: func(id, text string) {
146 callbacks["OnReasoningDelta"] = true
147 },
148 OnReasoningEnd: func(id string) {
149 callbacks["OnReasoningEnd"] = true
150 },
151 OnToolInputStart: func(id, toolName string) {
152 callbacks["OnToolInputStart"] = true
153 },
154 OnToolInputDelta: func(id, delta string) {
155 callbacks["OnToolInputDelta"] = true
156 },
157 OnToolInputEnd: func(id string) {
158 callbacks["OnToolInputEnd"] = true
159 },
160 OnToolCall: func(toolCall ToolCallContent) {
161 callbacks["OnToolCall"] = true
162 },
163 OnToolResult: func(result ToolResultContent) {
164 callbacks["OnToolResult"] = true
165 },
166 OnSource: func(source SourceContent) {
167 callbacks["OnSource"] = true
168 },
169 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
170 callbacks["OnStreamFinish"] = true
171 },
172 OnStreamError: func(err error) {
173 callbacks["OnStreamError"] = true
174 },
175 }
176
177 // Execute streaming agent
178 result, err := agent.Stream(ctx, streamCall)
179 require.NoError(t, err)
180 require.NotNil(t, result)
181
182 // Verify that expected callbacks were called
183 expectedCallbacks := []string{
184 "OnAgentStart",
185 "OnAgentFinish",
186 "OnStepStart",
187 "OnStepFinish",
188 "OnFinish",
189 "OnChunk",
190 "OnWarnings",
191 "OnTextStart",
192 "OnTextDelta",
193 "OnTextEnd",
194 "OnReasoningStart",
195 "OnReasoningDelta",
196 "OnReasoningEnd",
197 "OnToolInputStart",
198 "OnToolInputDelta",
199 "OnToolInputEnd",
200 "OnSource",
201 "OnStreamFinish",
202 }
203
204 for _, callback := range expectedCallbacks {
205 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
206 }
207
208 // Verify that error callbacks were not called
209 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
210 require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
211 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
212 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
213}
214
215// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
216func TestStreamingAgentWithTools(t *testing.T) {
217 t.Parallel()
218
219 stepCount := 0
220 // Create a mock language model that makes a tool call then finishes
221 mockModel := &mockLanguageModel{
222 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
223 stepCount++
224 return func(yield func(StreamPart) bool) {
225 if stepCount == 1 {
226 // First step: make tool call
227 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
228 return
229 }
230 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
231 return
232 }
233 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
234 return
235 }
236 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
237 return
238 }
239 if !yield(StreamPart{
240 Type: StreamPartTypeToolCall,
241 ID: "tool-1",
242 ToolCallName: "echo",
243 ToolCallInput: `{"message": "test"}`,
244 }) {
245 return
246 }
247 yield(StreamPart{
248 Type: StreamPartTypeFinish,
249 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
250 FinishReason: FinishReasonToolCalls,
251 })
252 } else {
253 // Second step: finish after tool execution
254 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
255 return
256 }
257 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
258 return
259 }
260 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
261 return
262 }
263 yield(StreamPart{
264 Type: StreamPartTypeFinish,
265 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
266 FinishReason: FinishReasonStop,
267 })
268 }
269 }, nil
270 },
271 }
272
273 // Create agent with echo tool
274 agent := NewAgent(
275 mockModel,
276 WithSystemPrompt("You are a helpful assistant."),
277 WithTools(&EchoTool{}),
278 )
279
280 ctx := context.Background()
281
282 // Track callback invocations
283 var toolInputStartCalled bool
284 var toolInputDeltaCalled bool
285 var toolInputEndCalled bool
286 var toolCallCalled bool
287 var toolResultCalled bool
288
289 // Create streaming call with callbacks
290 streamCall := AgentStreamCall{
291 Prompt: "Echo 'test'",
292 OnToolInputStart: func(id, toolName string) {
293 toolInputStartCalled = true
294 require.Equal(t, "tool-1", id)
295 require.Equal(t, "echo", toolName)
296 },
297 OnToolInputDelta: func(id, delta string) {
298 toolInputDeltaCalled = true
299 require.Equal(t, "tool-1", id)
300 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
301 },
302 OnToolInputEnd: func(id string) {
303 toolInputEndCalled = true
304 require.Equal(t, "tool-1", id)
305 },
306 OnToolCall: func(toolCall ToolCallContent) {
307 toolCallCalled = true
308 require.Equal(t, "echo", toolCall.ToolName)
309 require.Equal(t, `{"message": "test"}`, toolCall.Input)
310 },
311 OnToolResult: func(result ToolResultContent) {
312 toolResultCalled = true
313 require.Equal(t, "echo", result.ToolName)
314 },
315 }
316
317 // Execute streaming agent
318 result, err := agent.Stream(ctx, streamCall)
319 require.NoError(t, err)
320
321 // Verify results
322 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
323 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
324 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
325 require.True(t, toolCallCalled, "OnToolCall should have been called")
326 require.True(t, toolResultCalled, "OnToolResult should have been called")
327 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
328
329 // Check that tool was executed in first step
330 firstStep := result.Steps[0]
331 toolCalls := firstStep.Content.ToolCalls()
332 require.Equal(t, 1, len(toolCalls))
333 require.Equal(t, "echo", toolCalls[0].ToolName)
334
335 toolResults := firstStep.Content.ToolResults()
336 require.Equal(t, 1, len(toolResults))
337 require.Equal(t, "echo", toolResults[0].ToolName)
338}
339
340// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
341func TestStreamingAgentTextDeltas(t *testing.T) {
342 t.Parallel()
343
344 // Create a mock language model that returns text deltas
345 mockModel := &mockLanguageModel{
346 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
347 return func(yield func(StreamPart) bool) {
348 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
349 return
350 }
351 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
352 return
353 }
354 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
355 return
356 }
357 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
358 return
359 }
360 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
361 return
362 }
363 yield(StreamPart{
364 Type: StreamPartTypeFinish,
365 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
366 FinishReason: FinishReasonStop,
367 })
368 }, nil
369 },
370 }
371
372 agent := NewAgent(mockModel)
373 ctx := context.Background()
374
375 // Track text deltas
376 var textDeltas []string
377
378 streamCall := AgentStreamCall{
379 Prompt: "Say hello",
380 OnTextDelta: func(id, text string) {
381 if text != "" {
382 textDeltas = append(textDeltas, text)
383 }
384 },
385 }
386
387 result, err := agent.Stream(ctx, streamCall)
388 require.NoError(t, err)
389
390 // Verify text deltas match expected pattern
391 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
392 require.Equal(t, "Hello, world!", result.Response.Content.Text())
393 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
394}
395
396// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
397func TestStreamingAgentReasoning(t *testing.T) {
398 t.Parallel()
399
400 mockModel := &mockLanguageModel{
401 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
402 return func(yield func(StreamPart) bool) {
403 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
404 return
405 }
406 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
407 return
408 }
409 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
410 return
411 }
412 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
413 return
414 }
415 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
416 return
417 }
418 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
419 return
420 }
421 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
422 return
423 }
424 yield(StreamPart{
425 Type: StreamPartTypeFinish,
426 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
427 FinishReason: FinishReasonStop,
428 })
429 }, nil
430 },
431 }
432
433 agent := NewAgent(mockModel)
434 ctx := context.Background()
435
436 var reasoningDeltas []string
437 var textDeltas []string
438
439 streamCall := AgentStreamCall{
440 Prompt: "Think and respond",
441 OnReasoningDelta: func(id, text string) {
442 reasoningDeltas = append(reasoningDeltas, text)
443 },
444 OnTextDelta: func(id, text string) {
445 textDeltas = append(textDeltas, text)
446 },
447 }
448
449 result, err := agent.Stream(ctx, streamCall)
450 require.NoError(t, err)
451
452 // Verify reasoning and text are separate
453 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
454 require.Equal(t, []string{"Hi there!"}, textDeltas)
455 require.Equal(t, "Hi there!", result.Response.Content.Text())
456 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
457}
458
459// TestStreamingAgentError tests error handling (mirrors TS error tests)
460func TestStreamingAgentError(t *testing.T) {
461 t.Parallel()
462
463 // Create a mock language model that returns an error
464 mockModel := &mockLanguageModel{
465 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
466 return func(yield func(StreamPart) bool) {
467 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
468 }, nil
469 },
470 }
471
472 agent := NewAgent(mockModel)
473 ctx := context.Background()
474
475 // Track error callbacks
476 var streamErrorOccurred bool
477 var errorOccurred bool
478 var errorMessage string
479
480 streamCall := AgentStreamCall{
481 Prompt: "This will fail",
482 OnStreamError: func(err error) {
483 streamErrorOccurred = true
484 },
485 OnError: func(err error) {
486 errorOccurred = true
487 errorMessage = err.Error()
488 },
489 }
490
491 // Execute streaming agent
492 result, err := agent.Stream(ctx, streamCall)
493 require.Error(t, err)
494 require.Nil(t, result)
495 require.True(t, streamErrorOccurred, "OnStreamError should have been called")
496 require.True(t, errorOccurred, "OnError should have been called")
497 require.Contains(t, errorMessage, "mock stream error")
498}
499
500// TestStreamingAgentSources tests source handling (mirrors TS source tests)
501func TestStreamingAgentSources(t *testing.T) {
502 t.Parallel()
503
504 mockModel := &mockLanguageModel{
505 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
506 return func(yield func(StreamPart) bool) {
507 if !yield(StreamPart{
508 Type: StreamPartTypeSource,
509 ID: "source-1",
510 SourceType: SourceTypeURL,
511 URL: "https://example.com",
512 Title: "Example",
513 }) {
514 return
515 }
516 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
517 return
518 }
519 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
520 return
521 }
522 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
523 return
524 }
525 if !yield(StreamPart{
526 Type: StreamPartTypeSource,
527 ID: "source-2",
528 SourceType: SourceTypeDocument,
529 Title: "Document Example",
530 }) {
531 return
532 }
533 yield(StreamPart{
534 Type: StreamPartTypeFinish,
535 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
536 FinishReason: FinishReasonStop,
537 })
538 }, nil
539 },
540 }
541
542 agent := NewAgent(mockModel)
543 ctx := context.Background()
544
545 var sources []SourceContent
546
547 streamCall := AgentStreamCall{
548 Prompt: "Search and respond",
549 OnSource: func(source SourceContent) {
550 sources = append(sources, source)
551 },
552 }
553
554 result, err := agent.Stream(ctx, streamCall)
555 require.NoError(t, err)
556
557 // Verify sources were captured
558 require.Equal(t, 2, len(sources))
559 require.Equal(t, SourceTypeURL, sources[0].SourceType)
560 require.Equal(t, "https://example.com", sources[0].URL)
561 require.Equal(t, "Example", sources[0].Title)
562 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
563 require.Equal(t, "Document Example", sources[1].Title)
564
565 // Verify sources are in final result
566 resultSources := result.Response.Content.Sources()
567 require.Equal(t, 2, len(resultSources))
568}