1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/stretchr/testify/require"
10)
11
12// EchoTool is a simple tool that echoes back the input message
13type EchoTool struct{}
14
15// Info returns the tool information
16func (e *EchoTool) Info() ToolInfo {
17 return ToolInfo{
18 Name: "echo",
19 Description: "Echo back the provided message",
20 Parameters: map[string]any{
21 "message": map[string]any{
22 "type": "string",
23 "description": "The message to echo back",
24 },
25 },
26 Required: []string{"message"},
27 }
28}
29
30// Run executes the echo tool
31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
32 var input struct {
33 Message string `json:"message"`
34 }
35
36 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
37 return NewTextErrorResponse("Invalid input: " + err.Error()), nil
38 }
39
40 if input.Message == "" {
41 return NewTextErrorResponse("Message cannot be empty"), nil
42 }
43
44 return NewTextResponse("Echo: " + input.Message), nil
45}
46
47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
48func TestStreamingAgentCallbacks(t *testing.T) {
49 t.Parallel()
50
51 // Track which callbacks were called
52 callbacks := make(map[string]bool)
53
54 // Create a mock language model that returns various stream parts
55 mockModel := &mockLanguageModel{
56 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
57 return func(yield func(StreamPart) bool) {
58 // Test all stream part types
59 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
60 return
61 }
62 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
63 return
64 }
65 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
66 return
67 }
68 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
69 return
70 }
71 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
72 return
73 }
74 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
75 return
76 }
77 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
78 return
79 }
80 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
81 return
82 }
83 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
84 return
85 }
86 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
87 return
88 }
89 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
90 return
91 }
92 yield(StreamPart{
93 Type: StreamPartTypeFinish,
94 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
95 FinishReason: FinishReasonStop,
96 })
97 }, nil
98 },
99 }
100
101 // Create agent
102 agent := NewAgent(mockModel)
103
104 ctx := context.Background()
105
106 // Create streaming call with all callbacks
107 streamCall := AgentStreamCall{
108 Prompt: "Test all callbacks",
109 OnAgentStart: func() {
110 callbacks["OnAgentStart"] = true
111 },
112 OnAgentFinish: func(result *AgentResult) {
113 callbacks["OnAgentFinish"] = true
114 },
115 OnStepStart: func(stepNumber int) {
116 callbacks["OnStepStart"] = true
117 },
118 OnStepFinish: func(stepResult StepResult) {
119 callbacks["OnStepFinish"] = true
120 },
121 OnFinish: func(result *AgentResult) {
122 callbacks["OnFinish"] = true
123 },
124 OnError: func(err error) {
125 callbacks["OnError"] = true
126 },
127 OnChunk: func(part StreamPart) error {
128 callbacks["OnChunk"] = true
129 return nil
130 },
131 OnWarnings: func(warnings []CallWarning) error {
132 callbacks["OnWarnings"] = true
133 return nil
134 },
135 OnTextStart: func(id string) error {
136 callbacks["OnTextStart"] = true
137 return nil
138 },
139 OnTextDelta: func(id, text string) error {
140 callbacks["OnTextDelta"] = true
141 return nil
142 },
143 OnTextEnd: func(id string) error {
144 callbacks["OnTextEnd"] = true
145 return nil
146 },
147 OnReasoningStart: func(id string) error {
148 callbacks["OnReasoningStart"] = true
149 return nil
150 },
151 OnReasoningDelta: func(id, text string) error {
152 callbacks["OnReasoningDelta"] = true
153 return nil
154 },
155 OnReasoningEnd: func(id string, content ReasoningContent) error {
156 callbacks["OnReasoningEnd"] = true
157 return nil
158 },
159 OnToolInputStart: func(id, toolName string) error {
160 callbacks["OnToolInputStart"] = true
161 return nil
162 },
163 OnToolInputDelta: func(id, delta string) error {
164 callbacks["OnToolInputDelta"] = true
165 return nil
166 },
167 OnToolInputEnd: func(id string) error {
168 callbacks["OnToolInputEnd"] = true
169 return nil
170 },
171 OnToolCall: func(toolCall ToolCallContent) error {
172 callbacks["OnToolCall"] = true
173 return nil
174 },
175 OnToolResult: func(result ToolResultContent) error {
176 callbacks["OnToolResult"] = true
177 return nil
178 },
179 OnSource: func(source SourceContent) error {
180 callbacks["OnSource"] = true
181 return nil
182 },
183 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
184 callbacks["OnStreamFinish"] = true
185 return nil
186 },
187 }
188
189 // Execute streaming agent
190 result, err := agent.Stream(ctx, streamCall)
191 require.NoError(t, err)
192 require.NotNil(t, result)
193
194 // Verify that expected callbacks were called
195 expectedCallbacks := []string{
196 "OnAgentStart",
197 "OnAgentFinish",
198 "OnStepStart",
199 "OnStepFinish",
200 "OnFinish",
201 "OnChunk",
202 "OnWarnings",
203 "OnTextStart",
204 "OnTextDelta",
205 "OnTextEnd",
206 "OnReasoningStart",
207 "OnReasoningDelta",
208 "OnReasoningEnd",
209 "OnToolInputStart",
210 "OnToolInputDelta",
211 "OnToolInputEnd",
212 "OnSource",
213 "OnStreamFinish",
214 }
215
216 for _, callback := range expectedCallbacks {
217 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
218 }
219
220 // Verify that error callbacks were not called
221 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
222 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
223 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
224}
225
226// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
227func TestStreamingAgentWithTools(t *testing.T) {
228 t.Parallel()
229
230 stepCount := 0
231 // Create a mock language model that makes a tool call then finishes
232 mockModel := &mockLanguageModel{
233 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
234 stepCount++
235 return func(yield func(StreamPart) bool) {
236 if stepCount == 1 {
237 // First step: make tool call
238 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
239 return
240 }
241 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
242 return
243 }
244 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
245 return
246 }
247 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
248 return
249 }
250 if !yield(StreamPart{
251 Type: StreamPartTypeToolCall,
252 ID: "tool-1",
253 ToolCallName: "echo",
254 ToolCallInput: `{"message": "test"}`,
255 }) {
256 return
257 }
258 yield(StreamPart{
259 Type: StreamPartTypeFinish,
260 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
261 FinishReason: FinishReasonToolCalls,
262 })
263 } else {
264 // Second step: finish after tool execution
265 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
266 return
267 }
268 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
269 return
270 }
271 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
272 return
273 }
274 yield(StreamPart{
275 Type: StreamPartTypeFinish,
276 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
277 FinishReason: FinishReasonStop,
278 })
279 }
280 }, nil
281 },
282 }
283
284 // Create agent with echo tool
285 agent := NewAgent(
286 mockModel,
287 WithSystemPrompt("You are a helpful assistant."),
288 WithTools(&EchoTool{}),
289 )
290
291 ctx := context.Background()
292
293 // Track callback invocations
294 var toolInputStartCalled bool
295 var toolInputDeltaCalled bool
296 var toolInputEndCalled bool
297 var toolCallCalled bool
298 var toolResultCalled bool
299
300 // Create streaming call with callbacks
301 streamCall := AgentStreamCall{
302 Prompt: "Echo 'test'",
303 OnToolInputStart: func(id, toolName string) error {
304 toolInputStartCalled = true
305 require.Equal(t, "tool-1", id)
306 require.Equal(t, "echo", toolName)
307 return nil
308 },
309 OnToolInputDelta: func(id, delta string) error {
310 toolInputDeltaCalled = true
311 require.Equal(t, "tool-1", id)
312 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
313 return nil
314 },
315 OnToolInputEnd: func(id string) error {
316 toolInputEndCalled = true
317 require.Equal(t, "tool-1", id)
318 return nil
319 },
320 OnToolCall: func(toolCall ToolCallContent) error {
321 toolCallCalled = true
322 require.Equal(t, "echo", toolCall.ToolName)
323 require.Equal(t, `{"message": "test"}`, toolCall.Input)
324 return nil
325 },
326 OnToolResult: func(result ToolResultContent) error {
327 toolResultCalled = true
328 require.Equal(t, "echo", result.ToolName)
329 return nil
330 },
331 }
332
333 // Execute streaming agent
334 result, err := agent.Stream(ctx, streamCall)
335 require.NoError(t, err)
336
337 // Verify results
338 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
339 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
340 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
341 require.True(t, toolCallCalled, "OnToolCall should have been called")
342 require.True(t, toolResultCalled, "OnToolResult should have been called")
343 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
344
345 // Check that tool was executed in first step
346 firstStep := result.Steps[0]
347 toolCalls := firstStep.Content.ToolCalls()
348 require.Equal(t, 1, len(toolCalls))
349 require.Equal(t, "echo", toolCalls[0].ToolName)
350
351 toolResults := firstStep.Content.ToolResults()
352 require.Equal(t, 1, len(toolResults))
353 require.Equal(t, "echo", toolResults[0].ToolName)
354}
355
356// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
357func TestStreamingAgentTextDeltas(t *testing.T) {
358 t.Parallel()
359
360 // Create a mock language model that returns text deltas
361 mockModel := &mockLanguageModel{
362 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
363 return func(yield func(StreamPart) bool) {
364 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
365 return
366 }
367 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
368 return
369 }
370 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
371 return
372 }
373 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
374 return
375 }
376 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
377 return
378 }
379 yield(StreamPart{
380 Type: StreamPartTypeFinish,
381 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
382 FinishReason: FinishReasonStop,
383 })
384 }, nil
385 },
386 }
387
388 agent := NewAgent(mockModel)
389 ctx := context.Background()
390
391 // Track text deltas
392 var textDeltas []string
393
394 streamCall := AgentStreamCall{
395 Prompt: "Say hello",
396 OnTextDelta: func(id, text string) error {
397 if text != "" {
398 textDeltas = append(textDeltas, text)
399 }
400 return nil
401 },
402 }
403
404 result, err := agent.Stream(ctx, streamCall)
405 require.NoError(t, err)
406
407 // Verify text deltas match expected pattern
408 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
409 require.Equal(t, "Hello, world!", result.Response.Content.Text())
410 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
411}
412
413// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
414func TestStreamingAgentReasoning(t *testing.T) {
415 t.Parallel()
416
417 mockModel := &mockLanguageModel{
418 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
419 return func(yield func(StreamPart) bool) {
420 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
421 return
422 }
423 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
424 return
425 }
426 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
427 return
428 }
429 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
430 return
431 }
432 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
433 return
434 }
435 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
436 return
437 }
438 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
439 return
440 }
441 yield(StreamPart{
442 Type: StreamPartTypeFinish,
443 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
444 FinishReason: FinishReasonStop,
445 })
446 }, nil
447 },
448 }
449
450 agent := NewAgent(mockModel)
451 ctx := context.Background()
452
453 var reasoningDeltas []string
454 var textDeltas []string
455
456 streamCall := AgentStreamCall{
457 Prompt: "Think and respond",
458 OnReasoningDelta: func(id, text string) error {
459 reasoningDeltas = append(reasoningDeltas, text)
460 return nil
461 },
462 OnTextDelta: func(id, text string) error {
463 textDeltas = append(textDeltas, text)
464 return nil
465 },
466 }
467
468 result, err := agent.Stream(ctx, streamCall)
469 require.NoError(t, err)
470
471 // Verify reasoning and text are separate
472 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
473 require.Equal(t, []string{"Hi there!"}, textDeltas)
474 require.Equal(t, "Hi there!", result.Response.Content.Text())
475 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
476}
477
478// TestStreamingAgentError tests error handling (mirrors TS error tests)
479func TestStreamingAgentError(t *testing.T) {
480 t.Parallel()
481
482 // Create a mock language model that returns an error
483 mockModel := &mockLanguageModel{
484 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
485 return func(yield func(StreamPart) bool) {
486 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
487 }, nil
488 },
489 }
490
491 agent := NewAgent(mockModel)
492 ctx := context.Background()
493
494 // Track error callbacks
495 var errorOccurred bool
496 var errorMessage string
497
498 streamCall := AgentStreamCall{
499 Prompt: "This will fail",
500
501 OnError: func(err error) {
502 errorOccurred = true
503 errorMessage = err.Error()
504 },
505 }
506
507 // Execute streaming agent
508 result, err := agent.Stream(ctx, streamCall)
509 require.Error(t, err)
510 require.Nil(t, result)
511 require.True(t, errorOccurred, "OnError should have been called")
512 require.Contains(t, errorMessage, "mock stream error")
513}
514
515// TestStreamingAgentSources tests source handling (mirrors TS source tests)
516func TestStreamingAgentSources(t *testing.T) {
517 t.Parallel()
518
519 mockModel := &mockLanguageModel{
520 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
521 return func(yield func(StreamPart) bool) {
522 if !yield(StreamPart{
523 Type: StreamPartTypeSource,
524 ID: "source-1",
525 SourceType: SourceTypeURL,
526 URL: "https://example.com",
527 Title: "Example",
528 }) {
529 return
530 }
531 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
532 return
533 }
534 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
535 return
536 }
537 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
538 return
539 }
540 if !yield(StreamPart{
541 Type: StreamPartTypeSource,
542 ID: "source-2",
543 SourceType: SourceTypeDocument,
544 Title: "Document Example",
545 }) {
546 return
547 }
548 yield(StreamPart{
549 Type: StreamPartTypeFinish,
550 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
551 FinishReason: FinishReasonStop,
552 })
553 }, nil
554 },
555 }
556
557 agent := NewAgent(mockModel)
558 ctx := context.Background()
559
560 var sources []SourceContent
561
562 streamCall := AgentStreamCall{
563 Prompt: "Search and respond",
564 OnSource: func(source SourceContent) error {
565 sources = append(sources, source)
566 return nil
567 },
568 }
569
570 result, err := agent.Stream(ctx, streamCall)
571 require.NoError(t, err)
572
573 // Verify sources were captured
574 require.Equal(t, 2, len(sources))
575 require.Equal(t, SourceTypeURL, sources[0].SourceType)
576 require.Equal(t, "https://example.com", sources[0].URL)
577 require.Equal(t, "Example", sources[0].Title)
578 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
579 require.Equal(t, "Document Example", sources[1].Title)
580
581 // Verify sources are in final result
582 resultSources := result.Response.Content.Sources()
583 require.Equal(t, 2, len(resultSources))
584}