1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "strings"
8 "sync"
9 "testing"
10
11 "github.com/stretchr/testify/require"
12)
13
14// EchoTool is a simple tool that echoes back the input message
15type EchoTool struct {
16 providerOptions ProviderOptions
17}
18
19func (e *EchoTool) SetProviderOptions(opts ProviderOptions) {
20 e.providerOptions = opts
21}
22
23func (e *EchoTool) ProviderOptions() ProviderOptions {
24 return e.providerOptions
25}
26
27// Info returns the tool information
28func (e *EchoTool) Info() ToolInfo {
29 return ToolInfo{
30 Name: "echo",
31 Description: "Echo back the provided message",
32 Parameters: map[string]any{
33 "message": map[string]any{
34 "type": "string",
35 "description": "The message to echo back",
36 },
37 },
38 Required: []string{"message"},
39 }
40}
41
42// Run executes the echo tool
43func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
44 var input struct {
45 Message string `json:"message"`
46 }
47
48 if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
49 return NewTextErrorResponse("Invalid input: " + err.Error()), nil
50 }
51
52 if input.Message == "" {
53 return NewTextErrorResponse("Message cannot be empty"), nil
54 }
55
56 return NewTextResponse("Echo: " + input.Message), nil
57}
58
59// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
60func TestStreamingAgentCallbacks(t *testing.T) {
61 t.Parallel()
62
63 // Track which callbacks were called
64 callbacks := make(map[string]bool)
65
66 // Create a mock language model that returns various stream parts
67 mockModel := &mockLanguageModel{
68 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
69 return func(yield func(StreamPart) bool) {
70 // Test all stream part types
71 if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
72 return
73 }
74 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
75 return
76 }
77 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
78 return
79 }
80 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
81 return
82 }
83 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
84 return
85 }
86 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
87 return
88 }
89 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
90 return
91 }
92 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
93 return
94 }
95 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
96 return
97 }
98 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
99 return
100 }
101 if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
102 return
103 }
104 yield(StreamPart{
105 Type: StreamPartTypeFinish,
106 Usage: Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
107 FinishReason: FinishReasonStop,
108 })
109 }, nil
110 },
111 }
112
113 // Create agent
114 agent := NewAgent(mockModel)
115
116 ctx := context.Background()
117
118 // Create streaming call with all callbacks
119 streamCall := AgentStreamCall{
120 Prompt: "Test all callbacks",
121 OnAgentStart: func() {
122 callbacks["OnAgentStart"] = true
123 },
124 OnAgentFinish: func(result *AgentResult) error {
125 callbacks["OnAgentFinish"] = true
126 return nil
127 },
128 OnStepStart: func(stepNumber int) error {
129 callbacks["OnStepStart"] = true
130 return nil
131 },
132 OnStepFinish: func(stepResult StepResult) error {
133 callbacks["OnStepFinish"] = true
134 return nil
135 },
136 OnFinish: func(result *AgentResult) {
137 callbacks["OnFinish"] = true
138 },
139 OnError: func(err error) {
140 callbacks["OnError"] = true
141 },
142 OnChunk: func(part StreamPart) error {
143 callbacks["OnChunk"] = true
144 return nil
145 },
146 OnWarnings: func(warnings []CallWarning) error {
147 callbacks["OnWarnings"] = true
148 return nil
149 },
150 OnTextStart: func(id string) error {
151 callbacks["OnTextStart"] = true
152 return nil
153 },
154 OnTextDelta: func(id, text string) error {
155 callbacks["OnTextDelta"] = true
156 return nil
157 },
158 OnTextEnd: func(id string) error {
159 callbacks["OnTextEnd"] = true
160 return nil
161 },
162 OnReasoningStart: func(id string, _ ReasoningContent) error {
163 callbacks["OnReasoningStart"] = true
164 return nil
165 },
166 OnReasoningDelta: func(id, text string) error {
167 callbacks["OnReasoningDelta"] = true
168 return nil
169 },
170 OnReasoningEnd: func(id string, content ReasoningContent) error {
171 callbacks["OnReasoningEnd"] = true
172 return nil
173 },
174 OnToolInputStart: func(id, toolName string) error {
175 callbacks["OnToolInputStart"] = true
176 return nil
177 },
178 OnToolInputDelta: func(id, delta string) error {
179 callbacks["OnToolInputDelta"] = true
180 return nil
181 },
182 OnToolInputEnd: func(id string) error {
183 callbacks["OnToolInputEnd"] = true
184 return nil
185 },
186 OnToolCall: func(toolCall ToolCallContent) error {
187 callbacks["OnToolCall"] = true
188 return nil
189 },
190 OnToolResult: func(result ToolResultContent) error {
191 callbacks["OnToolResult"] = true
192 return nil
193 },
194 OnSource: func(source SourceContent) error {
195 callbacks["OnSource"] = true
196 return nil
197 },
198 OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
199 callbacks["OnStreamFinish"] = true
200 return nil
201 },
202 }
203
204 // Execute streaming agent
205 result, err := agent.Stream(ctx, streamCall)
206 require.NoError(t, err)
207 require.NotNil(t, result)
208
209 // Verify that expected callbacks were called
210 expectedCallbacks := []string{
211 "OnAgentStart",
212 "OnAgentFinish",
213 "OnStepStart",
214 "OnStepFinish",
215 "OnFinish",
216 "OnChunk",
217 "OnWarnings",
218 "OnTextStart",
219 "OnTextDelta",
220 "OnTextEnd",
221 "OnReasoningStart",
222 "OnReasoningDelta",
223 "OnReasoningEnd",
224 "OnToolInputStart",
225 "OnToolInputDelta",
226 "OnToolInputEnd",
227 "OnSource",
228 "OnStreamFinish",
229 }
230
231 for _, callback := range expectedCallbacks {
232 require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
233 }
234
235 // Verify that error callbacks were not called
236 require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
237 require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
238 require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
239}
240
241// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
242func TestStreamingAgentWithTools(t *testing.T) {
243 t.Parallel()
244
245 stepCount := 0
246 // Create a mock language model that makes a tool call then finishes
247 mockModel := &mockLanguageModel{
248 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
249 stepCount++
250 return func(yield func(StreamPart) bool) {
251 if stepCount == 1 {
252 // First step: make tool call
253 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
254 return
255 }
256 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
257 return
258 }
259 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
260 return
261 }
262 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
263 return
264 }
265 if !yield(StreamPart{
266 Type: StreamPartTypeToolCall,
267 ID: "tool-1",
268 ToolCallName: "echo",
269 ToolCallInput: `{"message": "test"}`,
270 }) {
271 return
272 }
273 yield(StreamPart{
274 Type: StreamPartTypeFinish,
275 Usage: Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
276 FinishReason: FinishReasonToolCalls,
277 })
278 } else {
279 // Second step: finish after tool execution
280 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
281 return
282 }
283 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
284 return
285 }
286 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
287 return
288 }
289 yield(StreamPart{
290 Type: StreamPartTypeFinish,
291 Usage: Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
292 FinishReason: FinishReasonStop,
293 })
294 }
295 }, nil
296 },
297 }
298
299 // Create agent with echo tool
300 agent := NewAgent(
301 mockModel,
302 WithSystemPrompt("You are a helpful assistant."),
303 WithTools(&EchoTool{}),
304 )
305
306 ctx := context.Background()
307
308 // Track callback invocations
309 var toolInputStartCalled bool
310 var toolInputDeltaCalled bool
311 var toolInputEndCalled bool
312 var toolCallCalled bool
313 var toolResultCalled bool
314
315 // Create streaming call with callbacks
316 streamCall := AgentStreamCall{
317 Prompt: "Echo 'test'",
318 OnToolInputStart: func(id, toolName string) error {
319 toolInputStartCalled = true
320 require.Equal(t, "tool-1", id)
321 require.Equal(t, "echo", toolName)
322 return nil
323 },
324 OnToolInputDelta: func(id, delta string) error {
325 toolInputDeltaCalled = true
326 require.Equal(t, "tool-1", id)
327 require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
328 return nil
329 },
330 OnToolInputEnd: func(id string) error {
331 toolInputEndCalled = true
332 require.Equal(t, "tool-1", id)
333 return nil
334 },
335 OnToolCall: func(toolCall ToolCallContent) error {
336 toolCallCalled = true
337 require.Equal(t, "echo", toolCall.ToolName)
338 require.Equal(t, `{"message": "test"}`, toolCall.Input)
339 return nil
340 },
341 OnToolResult: func(result ToolResultContent) error {
342 toolResultCalled = true
343 require.Equal(t, "echo", result.ToolName)
344 return nil
345 },
346 }
347
348 // Execute streaming agent
349 result, err := agent.Stream(ctx, streamCall)
350 require.NoError(t, err)
351
352 // Verify results
353 require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
354 require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
355 require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
356 require.True(t, toolCallCalled, "OnToolCall should have been called")
357 require.True(t, toolResultCalled, "OnToolResult should have been called")
358 require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
359
360 // Check that tool was executed in first step
361 firstStep := result.Steps[0]
362 toolCalls := firstStep.Content.ToolCalls()
363 require.Equal(t, 1, len(toolCalls))
364 require.Equal(t, "echo", toolCalls[0].ToolName)
365
366 toolResults := firstStep.Content.ToolResults()
367 require.Equal(t, 1, len(toolResults))
368 require.Equal(t, "echo", toolResults[0].ToolName)
369}
370
371// TestStreamingAgentToolCallBeforeResult verifies that all OnToolCall callbacks
372// complete before any OnToolResult fires. This is the ordering guarantee
373// provided by buffering dispatches until the stream is fully consumed.
374func TestStreamingAgentToolCallBeforeResult(t *testing.T) {
375 t.Parallel()
376
377 stepCount := 0
378 mockModel := &mockLanguageModel{
379 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
380 stepCount++
381 return func(yield func(StreamPart) bool) {
382 if stepCount == 1 {
383 // Emit two tool calls in the same step.
384 for _, id := range []string{"tool-1", "tool-2"} {
385 if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: id, ToolCallName: "echo"}) {
386 return
387 }
388 if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: id, Delta: `{"message": "` + id + `"}`}) {
389 return
390 }
391 if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: id}) {
392 return
393 }
394 if !yield(StreamPart{
395 Type: StreamPartTypeToolCall,
396 ID: id,
397 ToolCallName: "echo",
398 ToolCallInput: `{"message": "` + id + `"}`,
399 }) {
400 return
401 }
402 }
403 yield(StreamPart{
404 Type: StreamPartTypeFinish,
405 FinishReason: FinishReasonToolCalls,
406 })
407 } else {
408 yield(StreamPart{
409 Type: StreamPartTypeFinish,
410 FinishReason: FinishReasonStop,
411 })
412 }
413 }, nil
414 },
415 }
416
417 agent := NewAgent(mockModel, WithTools(&EchoTool{}))
418
419 var mu sync.Mutex
420 var events []string
421
422 _, err := agent.Stream(context.Background(), AgentStreamCall{
423 Prompt: "echo twice",
424 OnToolCall: func(tc ToolCallContent) error {
425 mu.Lock()
426 events = append(events, "call:"+tc.ToolCallID)
427 mu.Unlock()
428 return nil
429 },
430 OnToolResult: func(tr ToolResultContent) error {
431 mu.Lock()
432 events = append(events, "result:"+tr.ToolCallID)
433 mu.Unlock()
434 return nil
435 },
436 })
437 require.NoError(t, err)
438
439 // Both OnToolCall events must appear before any OnToolResult event.
440 lastCallIdx := -1
441 firstResultIdx := len(events)
442 for i, e := range events {
443 if strings.HasPrefix(e, "call:") {
444 lastCallIdx = i
445 }
446 if strings.HasPrefix(e, "result:") && i < firstResultIdx {
447 firstResultIdx = i
448 }
449 }
450 require.Equal(t, 2, stepCount)
451 require.Less(t, lastCallIdx, firstResultIdx,
452 "all OnToolCall events must complete before the first OnToolResult; got %v", events)
453}
454
455// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
456func TestStreamingAgentTextDeltas(t *testing.T) {
457 t.Parallel()
458
459 // Create a mock language model that returns text deltas
460 mockModel := &mockLanguageModel{
461 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
462 return func(yield func(StreamPart) bool) {
463 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
464 return
465 }
466 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
467 return
468 }
469 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
470 return
471 }
472 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
473 return
474 }
475 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
476 return
477 }
478 yield(StreamPart{
479 Type: StreamPartTypeFinish,
480 Usage: Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
481 FinishReason: FinishReasonStop,
482 })
483 }, nil
484 },
485 }
486
487 agent := NewAgent(mockModel)
488 ctx := context.Background()
489
490 // Track text deltas
491 var textDeltas []string
492
493 streamCall := AgentStreamCall{
494 Prompt: "Say hello",
495 OnTextDelta: func(id, text string) error {
496 if text != "" {
497 textDeltas = append(textDeltas, text)
498 }
499 return nil
500 },
501 }
502
503 result, err := agent.Stream(ctx, streamCall)
504 require.NoError(t, err)
505
506 // Verify text deltas match expected pattern
507 require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
508 require.Equal(t, "Hello, world!", result.Response.Content.Text())
509 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
513func TestStreamingAgentReasoning(t *testing.T) {
514 t.Parallel()
515
516 mockModel := &mockLanguageModel{
517 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
518 return func(yield func(StreamPart) bool) {
519 if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
520 return
521 }
522 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
523 return
524 }
525 if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
526 return
527 }
528 if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
529 return
530 }
531 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
532 return
533 }
534 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
535 return
536 }
537 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
538 return
539 }
540 yield(StreamPart{
541 Type: StreamPartTypeFinish,
542 Usage: Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
543 FinishReason: FinishReasonStop,
544 })
545 }, nil
546 },
547 }
548
549 agent := NewAgent(mockModel)
550 ctx := context.Background()
551
552 var reasoningDeltas []string
553 var textDeltas []string
554
555 streamCall := AgentStreamCall{
556 Prompt: "Think and respond",
557 OnReasoningDelta: func(id, text string) error {
558 reasoningDeltas = append(reasoningDeltas, text)
559 return nil
560 },
561 OnTextDelta: func(id, text string) error {
562 textDeltas = append(textDeltas, text)
563 return nil
564 },
565 }
566
567 result, err := agent.Stream(ctx, streamCall)
568 require.NoError(t, err)
569
570 // Verify reasoning and text are separate
571 require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
572 require.Equal(t, []string{"Hi there!"}, textDeltas)
573 require.Equal(t, "Hi there!", result.Response.Content.Text())
574 require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
575}
576
577// TestStreamingAgentError tests error handling (mirrors TS error tests)
578func TestStreamingAgentError(t *testing.T) {
579 t.Parallel()
580
581 // Create a mock language model that returns an error
582 mockModel := &mockLanguageModel{
583 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
584 return func(yield func(StreamPart) bool) {
585 yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
586 }, nil
587 },
588 }
589
590 agent := NewAgent(mockModel)
591 ctx := context.Background()
592
593 // Track error callbacks
594 var errorOccurred bool
595 var errorMessage string
596
597 streamCall := AgentStreamCall{
598 Prompt: "This will fail",
599
600 OnError: func(err error) {
601 errorOccurred = true
602 errorMessage = err.Error()
603 },
604 }
605
606 // Execute streaming agent
607 result, err := agent.Stream(ctx, streamCall)
608 require.Error(t, err)
609 require.Nil(t, result)
610 require.True(t, errorOccurred, "OnError should have been called")
611 require.Contains(t, errorMessage, "mock stream error")
612}
613
614// TestStreamingAgentSources tests source handling (mirrors TS source tests)
615func TestStreamingAgentSources(t *testing.T) {
616 t.Parallel()
617
618 mockModel := &mockLanguageModel{
619 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
620 return func(yield func(StreamPart) bool) {
621 if !yield(StreamPart{
622 Type: StreamPartTypeSource,
623 ID: "source-1",
624 SourceType: SourceTypeURL,
625 URL: "https://example.com",
626 Title: "Example",
627 }) {
628 return
629 }
630 if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
631 return
632 }
633 if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
634 return
635 }
636 if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
637 return
638 }
639 if !yield(StreamPart{
640 Type: StreamPartTypeSource,
641 ID: "source-2",
642 SourceType: SourceTypeDocument,
643 Title: "Document Example",
644 }) {
645 return
646 }
647 yield(StreamPart{
648 Type: StreamPartTypeFinish,
649 Usage: Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
650 FinishReason: FinishReasonStop,
651 })
652 }, nil
653 },
654 }
655
656 agent := NewAgent(mockModel)
657 ctx := context.Background()
658
659 var sources []SourceContent
660
661 streamCall := AgentStreamCall{
662 Prompt: "Search and respond",
663 OnSource: func(source SourceContent) error {
664 sources = append(sources, source)
665 return nil
666 },
667 }
668
669 result, err := agent.Stream(ctx, streamCall)
670 require.NoError(t, err)
671
672 // Verify sources were captured
673 require.Equal(t, 2, len(sources))
674 require.Equal(t, SourceTypeURL, sources[0].SourceType)
675 require.Equal(t, "https://example.com", sources[0].URL)
676 require.Equal(t, "Example", sources[0].Title)
677 require.Equal(t, SourceTypeDocument, sources[1].SourceType)
678 require.Equal(t, "Document Example", sources[1].Title)
679
680 // Verify sources are in final result
681 resultSources := result.Response.Content.Sources()
682 require.Equal(t, 2, len(resultSources))
683}