agent_stream_test.go

  1package fantasy
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"strings"
  8	"sync"
  9	"testing"
 10
 11	"github.com/stretchr/testify/require"
 12)
 13
 14// EchoTool is a simple tool that echoes back the input message
 15type EchoTool struct {
 16	providerOptions ProviderOptions
 17}
 18
 19func (e *EchoTool) SetProviderOptions(opts ProviderOptions) {
 20	e.providerOptions = opts
 21}
 22
 23func (e *EchoTool) ProviderOptions() ProviderOptions {
 24	return e.providerOptions
 25}
 26
 27// Info returns the tool information
 28func (e *EchoTool) Info() ToolInfo {
 29	return ToolInfo{
 30		Name:        "echo",
 31		Description: "Echo back the provided message",
 32		Parameters: map[string]any{
 33			"message": map[string]any{
 34				"type":        "string",
 35				"description": "The message to echo back",
 36			},
 37		},
 38		Required: []string{"message"},
 39	}
 40}
 41
 42// Run executes the echo tool
 43func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 44	var input struct {
 45		Message string `json:"message"`
 46	}
 47
 48	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 49		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 50	}
 51
 52	if input.Message == "" {
 53		return NewTextErrorResponse("Message cannot be empty"), nil
 54	}
 55
 56	return NewTextResponse("Echo: " + input.Message), nil
 57}
 58
 59// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 60func TestStreamingAgentCallbacks(t *testing.T) {
 61	t.Parallel()
 62
 63	// Track which callbacks were called
 64	callbacks := make(map[string]bool)
 65
 66	// Create a mock language model that returns various stream parts
 67	mockModel := &mockLanguageModel{
 68		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 69			return func(yield func(StreamPart) bool) {
 70				// Test all stream part types
 71				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 72					return
 73				}
 74				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 75					return
 76				}
 77				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 78					return
 79				}
 80				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 81					return
 82				}
 83				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 84					return
 85				}
 86				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 87					return
 88				}
 89				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 90					return
 91				}
 92				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 93					return
 94				}
 95				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 96					return
 97				}
 98				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 99					return
100				}
101				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
102					return
103				}
104				yield(StreamPart{
105					Type:         StreamPartTypeFinish,
106					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
107					FinishReason: FinishReasonStop,
108				})
109			}, nil
110		},
111	}
112
113	// Create agent
114	agent := NewAgent(mockModel)
115
116	ctx := context.Background()
117
118	// Create streaming call with all callbacks
119	streamCall := AgentStreamCall{
120		Prompt: "Test all callbacks",
121		OnAgentStart: func() {
122			callbacks["OnAgentStart"] = true
123		},
124		OnAgentFinish: func(result *AgentResult) error {
125			callbacks["OnAgentFinish"] = true
126			return nil
127		},
128		OnStepStart: func(stepNumber int) error {
129			callbacks["OnStepStart"] = true
130			return nil
131		},
132		OnStepFinish: func(stepResult StepResult) error {
133			callbacks["OnStepFinish"] = true
134			return nil
135		},
136		OnFinish: func(result *AgentResult) {
137			callbacks["OnFinish"] = true
138		},
139		OnError: func(err error) {
140			callbacks["OnError"] = true
141		},
142		OnChunk: func(part StreamPart) error {
143			callbacks["OnChunk"] = true
144			return nil
145		},
146		OnWarnings: func(warnings []CallWarning) error {
147			callbacks["OnWarnings"] = true
148			return nil
149		},
150		OnTextStart: func(id string) error {
151			callbacks["OnTextStart"] = true
152			return nil
153		},
154		OnTextDelta: func(id, text string) error {
155			callbacks["OnTextDelta"] = true
156			return nil
157		},
158		OnTextEnd: func(id string) error {
159			callbacks["OnTextEnd"] = true
160			return nil
161		},
162		OnReasoningStart: func(id string, _ ReasoningContent) error {
163			callbacks["OnReasoningStart"] = true
164			return nil
165		},
166		OnReasoningDelta: func(id, text string) error {
167			callbacks["OnReasoningDelta"] = true
168			return nil
169		},
170		OnReasoningEnd: func(id string, content ReasoningContent) error {
171			callbacks["OnReasoningEnd"] = true
172			return nil
173		},
174		OnToolInputStart: func(id, toolName string) error {
175			callbacks["OnToolInputStart"] = true
176			return nil
177		},
178		OnToolInputDelta: func(id, delta string) error {
179			callbacks["OnToolInputDelta"] = true
180			return nil
181		},
182		OnToolInputEnd: func(id string) error {
183			callbacks["OnToolInputEnd"] = true
184			return nil
185		},
186		OnToolCall: func(toolCall ToolCallContent) error {
187			callbacks["OnToolCall"] = true
188			return nil
189		},
190		OnToolResult: func(result ToolResultContent) error {
191			callbacks["OnToolResult"] = true
192			return nil
193		},
194		OnSource: func(source SourceContent) error {
195			callbacks["OnSource"] = true
196			return nil
197		},
198		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
199			callbacks["OnStreamFinish"] = true
200			return nil
201		},
202	}
203
204	// Execute streaming agent
205	result, err := agent.Stream(ctx, streamCall)
206	require.NoError(t, err)
207	require.NotNil(t, result)
208
209	// Verify that expected callbacks were called
210	expectedCallbacks := []string{
211		"OnAgentStart",
212		"OnAgentFinish",
213		"OnStepStart",
214		"OnStepFinish",
215		"OnFinish",
216		"OnChunk",
217		"OnWarnings",
218		"OnTextStart",
219		"OnTextDelta",
220		"OnTextEnd",
221		"OnReasoningStart",
222		"OnReasoningDelta",
223		"OnReasoningEnd",
224		"OnToolInputStart",
225		"OnToolInputDelta",
226		"OnToolInputEnd",
227		"OnSource",
228		"OnStreamFinish",
229	}
230
231	for _, callback := range expectedCallbacks {
232		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
233	}
234
235	// Verify that error callbacks were not called
236	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
237	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
238	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
239}
240
241// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
242func TestStreamingAgentWithTools(t *testing.T) {
243	t.Parallel()
244
245	stepCount := 0
246	// Create a mock language model that makes a tool call then finishes
247	mockModel := &mockLanguageModel{
248		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
249			stepCount++
250			return func(yield func(StreamPart) bool) {
251				if stepCount == 1 {
252					// First step: make tool call
253					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
254						return
255					}
256					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
257						return
258					}
259					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
260						return
261					}
262					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
263						return
264					}
265					if !yield(StreamPart{
266						Type:          StreamPartTypeToolCall,
267						ID:            "tool-1",
268						ToolCallName:  "echo",
269						ToolCallInput: `{"message": "test"}`,
270					}) {
271						return
272					}
273					yield(StreamPart{
274						Type:         StreamPartTypeFinish,
275						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
276						FinishReason: FinishReasonToolCalls,
277					})
278				} else {
279					// Second step: finish after tool execution
280					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
281						return
282					}
283					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
284						return
285					}
286					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
287						return
288					}
289					yield(StreamPart{
290						Type:         StreamPartTypeFinish,
291						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
292						FinishReason: FinishReasonStop,
293					})
294				}
295			}, nil
296		},
297	}
298
299	// Create agent with echo tool
300	agent := NewAgent(
301		mockModel,
302		WithSystemPrompt("You are a helpful assistant."),
303		WithTools(&EchoTool{}),
304	)
305
306	ctx := context.Background()
307
308	// Track callback invocations
309	var toolInputStartCalled bool
310	var toolInputDeltaCalled bool
311	var toolInputEndCalled bool
312	var toolCallCalled bool
313	var toolResultCalled bool
314
315	// Create streaming call with callbacks
316	streamCall := AgentStreamCall{
317		Prompt: "Echo 'test'",
318		OnToolInputStart: func(id, toolName string) error {
319			toolInputStartCalled = true
320			require.Equal(t, "tool-1", id)
321			require.Equal(t, "echo", toolName)
322			return nil
323		},
324		OnToolInputDelta: func(id, delta string) error {
325			toolInputDeltaCalled = true
326			require.Equal(t, "tool-1", id)
327			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
328			return nil
329		},
330		OnToolInputEnd: func(id string) error {
331			toolInputEndCalled = true
332			require.Equal(t, "tool-1", id)
333			return nil
334		},
335		OnToolCall: func(toolCall ToolCallContent) error {
336			toolCallCalled = true
337			require.Equal(t, "echo", toolCall.ToolName)
338			require.Equal(t, `{"message": "test"}`, toolCall.Input)
339			return nil
340		},
341		OnToolResult: func(result ToolResultContent) error {
342			toolResultCalled = true
343			require.Equal(t, "echo", result.ToolName)
344			return nil
345		},
346	}
347
348	// Execute streaming agent
349	result, err := agent.Stream(ctx, streamCall)
350	require.NoError(t, err)
351
352	// Verify results
353	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
354	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
355	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
356	require.True(t, toolCallCalled, "OnToolCall should have been called")
357	require.True(t, toolResultCalled, "OnToolResult should have been called")
358	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
359
360	// Check that tool was executed in first step
361	firstStep := result.Steps[0]
362	toolCalls := firstStep.Content.ToolCalls()
363	require.Equal(t, 1, len(toolCalls))
364	require.Equal(t, "echo", toolCalls[0].ToolName)
365
366	toolResults := firstStep.Content.ToolResults()
367	require.Equal(t, 1, len(toolResults))
368	require.Equal(t, "echo", toolResults[0].ToolName)
369}
370
371// TestStreamingAgentToolCallBeforeResult verifies that all OnToolCall callbacks
372// complete before any OnToolResult fires. This is the ordering guarantee
373// provided by buffering dispatches until the stream is fully consumed.
374func TestStreamingAgentToolCallBeforeResult(t *testing.T) {
375	t.Parallel()
376
377	stepCount := 0
378	mockModel := &mockLanguageModel{
379		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
380			stepCount++
381			return func(yield func(StreamPart) bool) {
382				if stepCount == 1 {
383					// Emit two tool calls in the same step.
384					for _, id := range []string{"tool-1", "tool-2"} {
385						if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: id, ToolCallName: "echo"}) {
386							return
387						}
388						if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: id, Delta: `{"message": "` + id + `"}`}) {
389							return
390						}
391						if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: id}) {
392							return
393						}
394						if !yield(StreamPart{
395							Type:          StreamPartTypeToolCall,
396							ID:            id,
397							ToolCallName:  "echo",
398							ToolCallInput: `{"message": "` + id + `"}`,
399						}) {
400							return
401						}
402					}
403					yield(StreamPart{
404						Type:         StreamPartTypeFinish,
405						FinishReason: FinishReasonToolCalls,
406					})
407				} else {
408					yield(StreamPart{
409						Type:         StreamPartTypeFinish,
410						FinishReason: FinishReasonStop,
411					})
412				}
413			}, nil
414		},
415	}
416
417	agent := NewAgent(mockModel, WithTools(&EchoTool{}))
418
419	var mu sync.Mutex
420	var events []string
421
422	_, err := agent.Stream(context.Background(), AgentStreamCall{
423		Prompt: "echo twice",
424		OnToolCall: func(tc ToolCallContent) error {
425			mu.Lock()
426			events = append(events, "call:"+tc.ToolCallID)
427			mu.Unlock()
428			return nil
429		},
430		OnToolResult: func(tr ToolResultContent) error {
431			mu.Lock()
432			events = append(events, "result:"+tr.ToolCallID)
433			mu.Unlock()
434			return nil
435		},
436	})
437	require.NoError(t, err)
438
439	// Both OnToolCall events must appear before any OnToolResult event.
440	lastCallIdx := -1
441	firstResultIdx := len(events)
442	for i, e := range events {
443		if strings.HasPrefix(e, "call:") {
444			lastCallIdx = i
445		}
446		if strings.HasPrefix(e, "result:") && i < firstResultIdx {
447			firstResultIdx = i
448		}
449	}
450	require.Equal(t, 2, stepCount)
451	require.Less(t, lastCallIdx, firstResultIdx,
452		"all OnToolCall events must complete before the first OnToolResult; got %v", events)
453}
454
455// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
456func TestStreamingAgentTextDeltas(t *testing.T) {
457	t.Parallel()
458
459	// Create a mock language model that returns text deltas
460	mockModel := &mockLanguageModel{
461		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
462			return func(yield func(StreamPart) bool) {
463				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
464					return
465				}
466				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
467					return
468				}
469				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
470					return
471				}
472				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
473					return
474				}
475				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
476					return
477				}
478				yield(StreamPart{
479					Type:         StreamPartTypeFinish,
480					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
481					FinishReason: FinishReasonStop,
482				})
483			}, nil
484		},
485	}
486
487	agent := NewAgent(mockModel)
488	ctx := context.Background()
489
490	// Track text deltas
491	var textDeltas []string
492
493	streamCall := AgentStreamCall{
494		Prompt: "Say hello",
495		OnTextDelta: func(id, text string) error {
496			if text != "" {
497				textDeltas = append(textDeltas, text)
498			}
499			return nil
500		},
501	}
502
503	result, err := agent.Stream(ctx, streamCall)
504	require.NoError(t, err)
505
506	// Verify text deltas match expected pattern
507	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
508	require.Equal(t, "Hello, world!", result.Response.Content.Text())
509	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
513func TestStreamingAgentReasoning(t *testing.T) {
514	t.Parallel()
515
516	mockModel := &mockLanguageModel{
517		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
518			return func(yield func(StreamPart) bool) {
519				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
520					return
521				}
522				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
523					return
524				}
525				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
526					return
527				}
528				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
529					return
530				}
531				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
532					return
533				}
534				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
535					return
536				}
537				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
538					return
539				}
540				yield(StreamPart{
541					Type:         StreamPartTypeFinish,
542					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
543					FinishReason: FinishReasonStop,
544				})
545			}, nil
546		},
547	}
548
549	agent := NewAgent(mockModel)
550	ctx := context.Background()
551
552	var reasoningDeltas []string
553	var textDeltas []string
554
555	streamCall := AgentStreamCall{
556		Prompt: "Think and respond",
557		OnReasoningDelta: func(id, text string) error {
558			reasoningDeltas = append(reasoningDeltas, text)
559			return nil
560		},
561		OnTextDelta: func(id, text string) error {
562			textDeltas = append(textDeltas, text)
563			return nil
564		},
565	}
566
567	result, err := agent.Stream(ctx, streamCall)
568	require.NoError(t, err)
569
570	// Verify reasoning and text are separate
571	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
572	require.Equal(t, []string{"Hi there!"}, textDeltas)
573	require.Equal(t, "Hi there!", result.Response.Content.Text())
574	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
575}
576
577// TestStreamingAgentError tests error handling (mirrors TS error tests)
578func TestStreamingAgentError(t *testing.T) {
579	t.Parallel()
580
581	// Create a mock language model that returns an error
582	mockModel := &mockLanguageModel{
583		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
584			return func(yield func(StreamPart) bool) {
585				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
586			}, nil
587		},
588	}
589
590	agent := NewAgent(mockModel)
591	ctx := context.Background()
592
593	// Track error callbacks
594	var errorOccurred bool
595	var errorMessage string
596
597	streamCall := AgentStreamCall{
598		Prompt: "This will fail",
599
600		OnError: func(err error) {
601			errorOccurred = true
602			errorMessage = err.Error()
603		},
604	}
605
606	// Execute streaming agent
607	result, err := agent.Stream(ctx, streamCall)
608	require.Error(t, err)
609	require.Nil(t, result)
610	require.True(t, errorOccurred, "OnError should have been called")
611	require.Contains(t, errorMessage, "mock stream error")
612}
613
614// TestStreamingAgentSources tests source handling (mirrors TS source tests)
615func TestStreamingAgentSources(t *testing.T) {
616	t.Parallel()
617
618	mockModel := &mockLanguageModel{
619		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
620			return func(yield func(StreamPart) bool) {
621				if !yield(StreamPart{
622					Type:       StreamPartTypeSource,
623					ID:         "source-1",
624					SourceType: SourceTypeURL,
625					URL:        "https://example.com",
626					Title:      "Example",
627				}) {
628					return
629				}
630				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
631					return
632				}
633				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
634					return
635				}
636				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
637					return
638				}
639				if !yield(StreamPart{
640					Type:       StreamPartTypeSource,
641					ID:         "source-2",
642					SourceType: SourceTypeDocument,
643					Title:      "Document Example",
644				}) {
645					return
646				}
647				yield(StreamPart{
648					Type:         StreamPartTypeFinish,
649					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
650					FinishReason: FinishReasonStop,
651				})
652			}, nil
653		},
654	}
655
656	agent := NewAgent(mockModel)
657	ctx := context.Background()
658
659	var sources []SourceContent
660
661	streamCall := AgentStreamCall{
662		Prompt: "Search and respond",
663		OnSource: func(source SourceContent) error {
664			sources = append(sources, source)
665			return nil
666		},
667	}
668
669	result, err := agent.Stream(ctx, streamCall)
670	require.NoError(t, err)
671
672	// Verify sources were captured
673	require.Equal(t, 2, len(sources))
674	require.Equal(t, SourceTypeURL, sources[0].SourceType)
675	require.Equal(t, "https://example.com", sources[0].URL)
676	require.Equal(t, "Example", sources[0].Title)
677	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
678	require.Equal(t, "Document Example", sources[1].Title)
679
680	// Verify sources are in final result
681	resultSources := result.Response.Content.Sources()
682	require.Equal(t, 2, len(resultSources))
683}