1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/stretchr/testify/require"
10)
11
12// EchoTool is a simple tool that echoes back the input message
13type EchoTool struct{}
14
15// Info returns the tool information
16func (e *EchoTool) Info() ToolInfo {
17 return ToolInfo{
18 Name: "echo",
19 Description: "Echo back the provided message",
20 Parameters: map[string]any{
21 "message": map[string]any{
22 "type": "string",
23 "description": "The message to echo back",
24 },
25 },
26 Required: []string{"message"},
27 }
28}
29
30// Run executes the echo tool
31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
32 var input struct {
33 Message string `json:"message"`
34 }
35
36 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
37 return NewTextErrorResponse("Invalid input: " + err.Error()), nil
38 }
39
40 if input.Message == "" {
41 return NewTextErrorResponse("Message cannot be empty"), nil
42 }
43
44 return NewTextResponse("Echo: " + input.Message), nil
45}
46
47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
48func TestStreamingAgentCallbacks(t *testing.T) {
49 t.Parallel()
50
51 // Track which callbacks were called
52 callbacks := make(map[string]bool)
53
54 // Create a mock language model that returns various stream parts
55 mockModel := &mockLanguageModel{
56 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
57 return func(yield func(StreamPart) bool) {
58 // Test all stream part types
59 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
60 return
61 }
62 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
63 return
64 }
65 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
66 return
67 }
68 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
69 return
70 }
71 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
72 return
73 }
74 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
75 return
76 }
77 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
78 return
79 }
80 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
81 return
82 }
83 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
84 return
85 }
86 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
87 return
88 }
89 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
90 return
91 }
92 yield(StreamPart{
93 Type: StreamPartTypeFinish,
94 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
95 FinishReason: FinishReasonStop,
96 })
97 }, nil
98 },
99 }
100
101 // Create agent
102 agent := NewAgent(mockModel)
103
104 ctx := context.Background()
105
106 // Create streaming call with all callbacks
107 streamCall := AgentStreamCall{
108 Prompt: "Test all callbacks",
109 OnAgentStart: func() {
110 callbacks["OnAgentStart"] = true
111 },
112 OnAgentFinish: func(result *AgentResult) error {
113 callbacks["OnAgentFinish"] = true
114 return nil
115 },
116 OnStepStart: func(stepNumber int) error {
117 callbacks["OnStepStart"] = true
118 return nil
119 },
120 OnStepFinish: func(stepResult StepResult) error {
121 callbacks["OnStepFinish"] = true
122 return nil
123 },
124 OnFinish: func(result *AgentResult) {
125 callbacks["OnFinish"] = true
126 },
127 OnError: func(err error) {
128 callbacks["OnError"] = true
129 },
130 OnChunk: func(part StreamPart) error {
131 callbacks["OnChunk"] = true
132 return nil
133 },
134 OnWarnings: func(warnings []CallWarning) error {
135 callbacks["OnWarnings"] = true
136 return nil
137 },
138 OnTextStart: func(id string) error {
139 callbacks["OnTextStart"] = true
140 return nil
141 },
142 OnTextDelta: func(id, text string) error {
143 callbacks["OnTextDelta"] = true
144 return nil
145 },
146 OnTextEnd: func(id string) error {
147 callbacks["OnTextEnd"] = true
148 return nil
149 },
150 OnReasoningStart: func(id string) error {
151 callbacks["OnReasoningStart"] = true
152 return nil
153 },
154 OnReasoningDelta: func(id, text string) error {
155 callbacks["OnReasoningDelta"] = true
156 return nil
157 },
158 OnReasoningEnd: func(id string, content ReasoningContent) error {
159 callbacks["OnReasoningEnd"] = true
160 return nil
161 },
162 OnToolInputStart: func(id, toolName string) error {
163 callbacks["OnToolInputStart"] = true
164 return nil
165 },
166 OnToolInputDelta: func(id, delta string) error {
167 callbacks["OnToolInputDelta"] = true
168 return nil
169 },
170 OnToolInputEnd: func(id string) error {
171 callbacks["OnToolInputEnd"] = true
172 return nil
173 },
174 OnToolCall: func(toolCall ToolCallContent) error {
175 callbacks["OnToolCall"] = true
176 return nil
177 },
178 OnToolResult: func(result ToolResultContent) error {
179 callbacks["OnToolResult"] = true
180 return nil
181 },
182 OnSource: func(source SourceContent) error {
183 callbacks["OnSource"] = true
184 return nil
185 },
186 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
187 callbacks["OnStreamFinish"] = true
188 return nil
189 },
190 }
191
192 // Execute streaming agent
193 result, err := agent.Stream(ctx, streamCall)
194 require.NoError(t, err)
195 require.NotNil(t, result)
196
197 // Verify that expected callbacks were called
198 expectedCallbacks := []string{
199 "OnAgentStart",
200 "OnAgentFinish",
201 "OnStepStart",
202 "OnStepFinish",
203 "OnFinish",
204 "OnChunk",
205 "OnWarnings",
206 "OnTextStart",
207 "OnTextDelta",
208 "OnTextEnd",
209 "OnReasoningStart",
210 "OnReasoningDelta",
211 "OnReasoningEnd",
212 "OnToolInputStart",
213 "OnToolInputDelta",
214 "OnToolInputEnd",
215 "OnSource",
216 "OnStreamFinish",
217 }
218
219 for _, callback := range expectedCallbacks {
220 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
221 }
222
223 // Verify that error callbacks were not called
224 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
225 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
226 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
227}
228
229// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
230func TestStreamingAgentWithTools(t *testing.T) {
231 t.Parallel()
232
233 stepCount := 0
234 // Create a mock language model that makes a tool call then finishes
235 mockModel := &mockLanguageModel{
236 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
237 stepCount++
238 return func(yield func(StreamPart) bool) {
239 if stepCount == 1 {
240 // First step: make tool call
241 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
242 return
243 }
244 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
245 return
246 }
247 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
248 return
249 }
250 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
251 return
252 }
253 if !yield(StreamPart{
254 Type: StreamPartTypeToolCall,
255 ID: "tool-1",
256 ToolCallName: "echo",
257 ToolCallInput: `{"message": "test"}`,
258 }) {
259 return
260 }
261 yield(StreamPart{
262 Type: StreamPartTypeFinish,
263 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
264 FinishReason: FinishReasonToolCalls,
265 })
266 } else {
267 // Second step: finish after tool execution
268 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
269 return
270 }
271 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
272 return
273 }
274 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
275 return
276 }
277 yield(StreamPart{
278 Type: StreamPartTypeFinish,
279 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
280 FinishReason: FinishReasonStop,
281 })
282 }
283 }, nil
284 },
285 }
286
287 // Create agent with echo tool
288 agent := NewAgent(
289 mockModel,
290 WithSystemPrompt("You are a helpful assistant."),
291 WithTools(&EchoTool{}),
292 )
293
294 ctx := context.Background()
295
296 // Track callback invocations
297 var toolInputStartCalled bool
298 var toolInputDeltaCalled bool
299 var toolInputEndCalled bool
300 var toolCallCalled bool
301 var toolResultCalled bool
302
303 // Create streaming call with callbacks
304 streamCall := AgentStreamCall{
305 Prompt: "Echo 'test'",
306 OnToolInputStart: func(id, toolName string) error {
307 toolInputStartCalled = true
308 require.Equal(t, "tool-1", id)
309 require.Equal(t, "echo", toolName)
310 return nil
311 },
312 OnToolInputDelta: func(id, delta string) error {
313 toolInputDeltaCalled = true
314 require.Equal(t, "tool-1", id)
315 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
316 return nil
317 },
318 OnToolInputEnd: func(id string) error {
319 toolInputEndCalled = true
320 require.Equal(t, "tool-1", id)
321 return nil
322 },
323 OnToolCall: func(toolCall ToolCallContent) error {
324 toolCallCalled = true
325 require.Equal(t, "echo", toolCall.ToolName)
326 require.Equal(t, `{"message": "test"}`, toolCall.Input)
327 return nil
328 },
329 OnToolResult: func(result ToolResultContent) error {
330 toolResultCalled = true
331 require.Equal(t, "echo", result.ToolName)
332 return nil
333 },
334 }
335
336 // Execute streaming agent
337 result, err := agent.Stream(ctx, streamCall)
338 require.NoError(t, err)
339
340 // Verify results
341 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
342 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
343 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
344 require.True(t, toolCallCalled, "OnToolCall should have been called")
345 require.True(t, toolResultCalled, "OnToolResult should have been called")
346 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
347
348 // Check that tool was executed in first step
349 firstStep := result.Steps[0]
350 toolCalls := firstStep.Content.ToolCalls()
351 require.Equal(t, 1, len(toolCalls))
352 require.Equal(t, "echo", toolCalls[0].ToolName)
353
354 toolResults := firstStep.Content.ToolResults()
355 require.Equal(t, 1, len(toolResults))
356 require.Equal(t, "echo", toolResults[0].ToolName)
357}
358
359// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
360func TestStreamingAgentTextDeltas(t *testing.T) {
361 t.Parallel()
362
363 // Create a mock language model that returns text deltas
364 mockModel := &mockLanguageModel{
365 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
366 return func(yield func(StreamPart) bool) {
367 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
368 return
369 }
370 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
371 return
372 }
373 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
374 return
375 }
376 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
377 return
378 }
379 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
380 return
381 }
382 yield(StreamPart{
383 Type: StreamPartTypeFinish,
384 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
385 FinishReason: FinishReasonStop,
386 })
387 }, nil
388 },
389 }
390
391 agent := NewAgent(mockModel)
392 ctx := context.Background()
393
394 // Track text deltas
395 var textDeltas []string
396
397 streamCall := AgentStreamCall{
398 Prompt: "Say hello",
399 OnTextDelta: func(id, text string) error {
400 if text != "" {
401 textDeltas = append(textDeltas, text)
402 }
403 return nil
404 },
405 }
406
407 result, err := agent.Stream(ctx, streamCall)
408 require.NoError(t, err)
409
410 // Verify text deltas match expected pattern
411 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
412 require.Equal(t, "Hello, world!", result.Response.Content.Text())
413 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
414}
415
416// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
417func TestStreamingAgentReasoning(t *testing.T) {
418 t.Parallel()
419
420 mockModel := &mockLanguageModel{
421 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
422 return func(yield func(StreamPart) bool) {
423 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
424 return
425 }
426 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
427 return
428 }
429 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
430 return
431 }
432 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
433 return
434 }
435 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
436 return
437 }
438 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
439 return
440 }
441 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
442 return
443 }
444 yield(StreamPart{
445 Type: StreamPartTypeFinish,
446 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
447 FinishReason: FinishReasonStop,
448 })
449 }, nil
450 },
451 }
452
453 agent := NewAgent(mockModel)
454 ctx := context.Background()
455
456 var reasoningDeltas []string
457 var textDeltas []string
458
459 streamCall := AgentStreamCall{
460 Prompt: "Think and respond",
461 OnReasoningDelta: func(id, text string) error {
462 reasoningDeltas = append(reasoningDeltas, text)
463 return nil
464 },
465 OnTextDelta: func(id, text string) error {
466 textDeltas = append(textDeltas, text)
467 return nil
468 },
469 }
470
471 result, err := agent.Stream(ctx, streamCall)
472 require.NoError(t, err)
473
474 // Verify reasoning and text are separate
475 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
476 require.Equal(t, []string{"Hi there!"}, textDeltas)
477 require.Equal(t, "Hi there!", result.Response.Content.Text())
478 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
479}
480
481// TestStreamingAgentError tests error handling (mirrors TS error tests)
482func TestStreamingAgentError(t *testing.T) {
483 t.Parallel()
484
485 // Create a mock language model that returns an error
486 mockModel := &mockLanguageModel{
487 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
488 return func(yield func(StreamPart) bool) {
489 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
490 }, nil
491 },
492 }
493
494 agent := NewAgent(mockModel)
495 ctx := context.Background()
496
497 // Track error callbacks
498 var errorOccurred bool
499 var errorMessage string
500
501 streamCall := AgentStreamCall{
502 Prompt: "This will fail",
503
504 OnError: func(err error) {
505 errorOccurred = true
506 errorMessage = err.Error()
507 },
508 }
509
510 // Execute streaming agent
511 result, err := agent.Stream(ctx, streamCall)
512 require.Error(t, err)
513 require.Nil(t, result)
514 require.True(t, errorOccurred, "OnError should have been called")
515 require.Contains(t, errorMessage, "mock stream error")
516}
517
518// TestStreamingAgentSources tests source handling (mirrors TS source tests)
519func TestStreamingAgentSources(t *testing.T) {
520 t.Parallel()
521
522 mockModel := &mockLanguageModel{
523 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
524 return func(yield func(StreamPart) bool) {
525 if !yield(StreamPart{
526 Type: StreamPartTypeSource,
527 ID: "source-1",
528 SourceType: SourceTypeURL,
529 URL: "https://example.com",
530 Title: "Example",
531 }) {
532 return
533 }
534 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
535 return
536 }
537 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
538 return
539 }
540 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
541 return
542 }
543 if !yield(StreamPart{
544 Type: StreamPartTypeSource,
545 ID: "source-2",
546 SourceType: SourceTypeDocument,
547 Title: "Document Example",
548 }) {
549 return
550 }
551 yield(StreamPart{
552 Type: StreamPartTypeFinish,
553 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
554 FinishReason: FinishReasonStop,
555 })
556 }, nil
557 },
558 }
559
560 agent := NewAgent(mockModel)
561 ctx := context.Background()
562
563 var sources []SourceContent
564
565 streamCall := AgentStreamCall{
566 Prompt: "Search and respond",
567 OnSource: func(source SourceContent) error {
568 sources = append(sources, source)
569 return nil
570 },
571 }
572
573 result, err := agent.Stream(ctx, streamCall)
574 require.NoError(t, err)
575
576 // Verify sources were captured
577 require.Equal(t, 2, len(sources))
578 require.Equal(t, SourceTypeURL, sources[0].SourceType)
579 require.Equal(t, "https://example.com", sources[0].URL)
580 require.Equal(t, "Example", sources[0].Title)
581 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
582 require.Equal(t, "Document Example", sources[1].Title)
583
584 // Verify sources are in final result
585 resultSources := result.Response.Content.Sources()
586 require.Equal(t, 2, len(resultSources))
587}