1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "testing"
8
9 "github.com/stretchr/testify/require"
10)
11
12// EchoTool is a simple tool that echoes back the input message
13type EchoTool struct {
14 providerOptions ProviderOptions
15}
16
17func (e *EchoTool) SetProviderOptions(opts ProviderOptions) {
18 e.providerOptions = opts
19}
20
21func (e *EchoTool) ProviderOptions() ProviderOptions {
22 return e.providerOptions
23}
24
25// Info returns the tool information
26func (e *EchoTool) Info() ToolInfo {
27 return ToolInfo{
28 Name: "echo",
29 Description: "Echo back the provided message",
30 Parameters: map[string]any{
31 "message": map[string]any{
32 "type": "string",
33 "description": "The message to echo back",
34 },
35 },
36 Required: []string{"message"},
37 }
38}
39
40// Run executes the echo tool
41func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
42 var input struct {
43 Message string `json:"message"`
44 }
45
46 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
47 return NewTextErrorResponse("Invalid input: " + err.Error()), nil
48 }
49
50 if input.Message == "" {
51 return NewTextErrorResponse("Message cannot be empty"), nil
52 }
53
54 return NewTextResponse("Echo: " + input.Message), nil
55}
56
57// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
58func TestStreamingAgentCallbacks(t *testing.T) {
59 t.Parallel()
60
61 // Track which callbacks were called
62 callbacks := make(map[string]bool)
63
64 // Create a mock language model that returns various stream parts
65 mockModel := &mockLanguageModel{
66 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
67 return func(yield func(StreamPart) bool) {
68 // Test all stream part types
69 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
70 return
71 }
72 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
73 return
74 }
75 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
76 return
77 }
78 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
79 return
80 }
81 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
82 return
83 }
84 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
85 return
86 }
87 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
88 return
89 }
90 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
91 return
92 }
93 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
94 return
95 }
96 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
97 return
98 }
99 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
100 return
101 }
102 yield(StreamPart{
103 Type: StreamPartTypeFinish,
104 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
105 FinishReason: FinishReasonStop,
106 })
107 }, nil
108 },
109 }
110
111 // Create agent
112 agent := NewAgent(mockModel)
113
114 ctx := context.Background()
115
116 // Create streaming call with all callbacks
117 streamCall := AgentStreamCall{
118 Prompt: "Test all callbacks",
119 OnAgentStart: func() {
120 callbacks["OnAgentStart"] = true
121 },
122 OnAgentFinish: func(result *AgentResult) error {
123 callbacks["OnAgentFinish"] = true
124 return nil
125 },
126 OnStepStart: func(stepNumber int) error {
127 callbacks["OnStepStart"] = true
128 return nil
129 },
130 OnStepFinish: func(stepResult StepResult) error {
131 callbacks["OnStepFinish"] = true
132 return nil
133 },
134 OnFinish: func(result *AgentResult) {
135 callbacks["OnFinish"] = true
136 },
137 OnError: func(err error) {
138 callbacks["OnError"] = true
139 },
140 OnChunk: func(part StreamPart) error {
141 callbacks["OnChunk"] = true
142 return nil
143 },
144 OnWarnings: func(warnings []CallWarning) error {
145 callbacks["OnWarnings"] = true
146 return nil
147 },
148 OnTextStart: func(id string) error {
149 callbacks["OnTextStart"] = true
150 return nil
151 },
152 OnTextDelta: func(id, text string) error {
153 callbacks["OnTextDelta"] = true
154 return nil
155 },
156 OnTextEnd: func(id string) error {
157 callbacks["OnTextEnd"] = true
158 return nil
159 },
160 OnReasoningStart: func(id string) error {
161 callbacks["OnReasoningStart"] = true
162 return nil
163 },
164 OnReasoningDelta: func(id, text string) error {
165 callbacks["OnReasoningDelta"] = true
166 return nil
167 },
168 OnReasoningEnd: func(id string, content ReasoningContent) error {
169 callbacks["OnReasoningEnd"] = true
170 return nil
171 },
172 OnToolInputStart: func(id, toolName string) error {
173 callbacks["OnToolInputStart"] = true
174 return nil
175 },
176 OnToolInputDelta: func(id, delta string) error {
177 callbacks["OnToolInputDelta"] = true
178 return nil
179 },
180 OnToolInputEnd: func(id string) error {
181 callbacks["OnToolInputEnd"] = true
182 return nil
183 },
184 OnToolCall: func(toolCall ToolCallContent) error {
185 callbacks["OnToolCall"] = true
186 return nil
187 },
188 OnToolResult: func(result ToolResultContent) error {
189 callbacks["OnToolResult"] = true
190 return nil
191 },
192 OnSource: func(source SourceContent) error {
193 callbacks["OnSource"] = true
194 return nil
195 },
196 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
197 callbacks["OnStreamFinish"] = true
198 return nil
199 },
200 }
201
202 // Execute streaming agent
203 result, err := agent.Stream(ctx, streamCall)
204 require.NoError(t, err)
205 require.NotNil(t, result)
206
207 // Verify that expected callbacks were called
208 expectedCallbacks := []string{
209 "OnAgentStart",
210 "OnAgentFinish",
211 "OnStepStart",
212 "OnStepFinish",
213 "OnFinish",
214 "OnChunk",
215 "OnWarnings",
216 "OnTextStart",
217 "OnTextDelta",
218 "OnTextEnd",
219 "OnReasoningStart",
220 "OnReasoningDelta",
221 "OnReasoningEnd",
222 "OnToolInputStart",
223 "OnToolInputDelta",
224 "OnToolInputEnd",
225 "OnSource",
226 "OnStreamFinish",
227 }
228
229 for _, callback := range expectedCallbacks {
230 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
231 }
232
233 // Verify that error callbacks were not called
234 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
235 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
236 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
237}
238
239// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
240func TestStreamingAgentWithTools(t *testing.T) {
241 t.Parallel()
242
243 stepCount := 0
244 // Create a mock language model that makes a tool call then finishes
245 mockModel := &mockLanguageModel{
246 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
247 stepCount++
248 return func(yield func(StreamPart) bool) {
249 if stepCount == 1 {
250 // First step: make tool call
251 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
252 return
253 }
254 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
255 return
256 }
257 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
258 return
259 }
260 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
261 return
262 }
263 if !yield(StreamPart{
264 Type: StreamPartTypeToolCall,
265 ID: "tool-1",
266 ToolCallName: "echo",
267 ToolCallInput: `{"message": "test"}`,
268 }) {
269 return
270 }
271 yield(StreamPart{
272 Type: StreamPartTypeFinish,
273 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
274 FinishReason: FinishReasonToolCalls,
275 })
276 } else {
277 // Second step: finish after tool execution
278 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
279 return
280 }
281 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
282 return
283 }
284 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
285 return
286 }
287 yield(StreamPart{
288 Type: StreamPartTypeFinish,
289 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
290 FinishReason: FinishReasonStop,
291 })
292 }
293 }, nil
294 },
295 }
296
297 // Create agent with echo tool
298 agent := NewAgent(
299 mockModel,
300 WithSystemPrompt("You are a helpful assistant."),
301 WithTools(&EchoTool{}),
302 )
303
304 ctx := context.Background()
305
306 // Track callback invocations
307 var toolInputStartCalled bool
308 var toolInputDeltaCalled bool
309 var toolInputEndCalled bool
310 var toolCallCalled bool
311 var toolResultCalled bool
312
313 // Create streaming call with callbacks
314 streamCall := AgentStreamCall{
315 Prompt: "Echo 'test'",
316 OnToolInputStart: func(id, toolName string) error {
317 toolInputStartCalled = true
318 require.Equal(t, "tool-1", id)
319 require.Equal(t, "echo", toolName)
320 return nil
321 },
322 OnToolInputDelta: func(id, delta string) error {
323 toolInputDeltaCalled = true
324 require.Equal(t, "tool-1", id)
325 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
326 return nil
327 },
328 OnToolInputEnd: func(id string) error {
329 toolInputEndCalled = true
330 require.Equal(t, "tool-1", id)
331 return nil
332 },
333 OnToolCall: func(toolCall ToolCallContent) error {
334 toolCallCalled = true
335 require.Equal(t, "echo", toolCall.ToolName)
336 require.Equal(t, `{"message": "test"}`, toolCall.Input)
337 return nil
338 },
339 OnToolResult: func(result ToolResultContent) error {
340 toolResultCalled = true
341 require.Equal(t, "echo", result.ToolName)
342 return nil
343 },
344 }
345
346 // Execute streaming agent
347 result, err := agent.Stream(ctx, streamCall)
348 require.NoError(t, err)
349
350 // Verify results
351 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
352 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
353 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
354 require.True(t, toolCallCalled, "OnToolCall should have been called")
355 require.True(t, toolResultCalled, "OnToolResult should have been called")
356 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
357
358 // Check that tool was executed in first step
359 firstStep := result.Steps[0]
360 toolCalls := firstStep.Content.ToolCalls()
361 require.Equal(t, 1, len(toolCalls))
362 require.Equal(t, "echo", toolCalls[0].ToolName)
363
364 toolResults := firstStep.Content.ToolResults()
365 require.Equal(t, 1, len(toolResults))
366 require.Equal(t, "echo", toolResults[0].ToolName)
367}
368
369// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
370func TestStreamingAgentTextDeltas(t *testing.T) {
371 t.Parallel()
372
373 // Create a mock language model that returns text deltas
374 mockModel := &mockLanguageModel{
375 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
376 return func(yield func(StreamPart) bool) {
377 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
378 return
379 }
380 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
381 return
382 }
383 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
384 return
385 }
386 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
387 return
388 }
389 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
390 return
391 }
392 yield(StreamPart{
393 Type: StreamPartTypeFinish,
394 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
395 FinishReason: FinishReasonStop,
396 })
397 }, nil
398 },
399 }
400
401 agent := NewAgent(mockModel)
402 ctx := context.Background()
403
404 // Track text deltas
405 var textDeltas []string
406
407 streamCall := AgentStreamCall{
408 Prompt: "Say hello",
409 OnTextDelta: func(id, text string) error {
410 if text != "" {
411 textDeltas = append(textDeltas, text)
412 }
413 return nil
414 },
415 }
416
417 result, err := agent.Stream(ctx, streamCall)
418 require.NoError(t, err)
419
420 // Verify text deltas match expected pattern
421 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
422 require.Equal(t, "Hello, world!", result.Response.Content.Text())
423 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
424}
425
426// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
427func TestStreamingAgentReasoning(t *testing.T) {
428 t.Parallel()
429
430 mockModel := &mockLanguageModel{
431 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
432 return func(yield func(StreamPart) bool) {
433 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
434 return
435 }
436 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
437 return
438 }
439 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
440 return
441 }
442 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
443 return
444 }
445 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
446 return
447 }
448 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
449 return
450 }
451 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
452 return
453 }
454 yield(StreamPart{
455 Type: StreamPartTypeFinish,
456 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
457 FinishReason: FinishReasonStop,
458 })
459 }, nil
460 },
461 }
462
463 agent := NewAgent(mockModel)
464 ctx := context.Background()
465
466 var reasoningDeltas []string
467 var textDeltas []string
468
469 streamCall := AgentStreamCall{
470 Prompt: "Think and respond",
471 OnReasoningDelta: func(id, text string) error {
472 reasoningDeltas = append(reasoningDeltas, text)
473 return nil
474 },
475 OnTextDelta: func(id, text string) error {
476 textDeltas = append(textDeltas, text)
477 return nil
478 },
479 }
480
481 result, err := agent.Stream(ctx, streamCall)
482 require.NoError(t, err)
483
484 // Verify reasoning and text are separate
485 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
486 require.Equal(t, []string{"Hi there!"}, textDeltas)
487 require.Equal(t, "Hi there!", result.Response.Content.Text())
488 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
489}
490
491// TestStreamingAgentError tests error handling (mirrors TS error tests)
492func TestStreamingAgentError(t *testing.T) {
493 t.Parallel()
494
495 // Create a mock language model that returns an error
496 mockModel := &mockLanguageModel{
497 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
498 return func(yield func(StreamPart) bool) {
499 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
500 }, nil
501 },
502 }
503
504 agent := NewAgent(mockModel)
505 ctx := context.Background()
506
507 // Track error callbacks
508 var errorOccurred bool
509 var errorMessage string
510
511 streamCall := AgentStreamCall{
512 Prompt: "This will fail",
513
514 OnError: func(err error) {
515 errorOccurred = true
516 errorMessage = err.Error()
517 },
518 }
519
520 // Execute streaming agent
521 result, err := agent.Stream(ctx, streamCall)
522 require.Error(t, err)
523 require.Nil(t, result)
524 require.True(t, errorOccurred, "OnError should have been called")
525 require.Contains(t, errorMessage, "mock stream error")
526}
527
528// TestStreamingAgentSources tests source handling (mirrors TS source tests)
529func TestStreamingAgentSources(t *testing.T) {
530 t.Parallel()
531
532 mockModel := &mockLanguageModel{
533 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
534 return func(yield func(StreamPart) bool) {
535 if !yield(StreamPart{
536 Type: StreamPartTypeSource,
537 ID: "source-1",
538 SourceType: SourceTypeURL,
539 URL: "https://example.com",
540 Title: "Example",
541 }) {
542 return
543 }
544 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
545 return
546 }
547 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
548 return
549 }
550 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
551 return
552 }
553 if !yield(StreamPart{
554 Type: StreamPartTypeSource,
555 ID: "source-2",
556 SourceType: SourceTypeDocument,
557 Title: "Document Example",
558 }) {
559 return
560 }
561 yield(StreamPart{
562 Type: StreamPartTypeFinish,
563 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
564 FinishReason: FinishReasonStop,
565 })
566 }, nil
567 },
568 }
569
570 agent := NewAgent(mockModel)
571 ctx := context.Background()
572
573 var sources []SourceContent
574
575 streamCall := AgentStreamCall{
576 Prompt: "Search and respond",
577 OnSource: func(source SourceContent) error {
578 sources = append(sources, source)
579 return nil
580 },
581 }
582
583 result, err := agent.Stream(ctx, streamCall)
584 require.NoError(t, err)
585
586 // Verify sources were captured
587 require.Equal(t, 2, len(sources))
588 require.Equal(t, SourceTypeURL, sources[0].SourceType)
589 require.Equal(t, "https://example.com", sources[0].URL)
590 require.Equal(t, "Example", sources[0].Title)
591 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
592 require.Equal(t, "Document Example", sources[1].Title)
593
594 // Verify sources are in final result
595 resultSources := result.Response.Content.Sources()
596 require.Equal(t, 2, len(resultSources))
597}