agent_stream_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// EchoTool is a simple tool that echoes back the input message
 13type EchoTool struct{}
 14
 15// Info returns the tool information
 16func (e *EchoTool) Info() ToolInfo {
 17	return ToolInfo{
 18		Name:        "echo",
 19		Description: "Echo back the provided message",
 20		Parameters: map[string]any{
 21			"message": map[string]any{
 22				"type":        "string",
 23				"description": "The message to echo back",
 24			},
 25		},
 26		Required: []string{"message"},
 27	}
 28}
 29
 30// Run executes the echo tool
 31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 32	var input struct {
 33		Message string `json:"message"`
 34	}
 35
 36	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 37		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 38	}
 39
 40	if input.Message == "" {
 41		return NewTextErrorResponse("Message cannot be empty"), nil
 42	}
 43
 44	return NewTextResponse("Echo: " + input.Message), nil
 45}
 46
 47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 48func TestStreamingAgentCallbacks(t *testing.T) {
 49	t.Parallel()
 50
 51	// Track which callbacks were called
 52	callbacks := make(map[string]bool)
 53
 54	// Create a mock language model that returns various stream parts
 55	mockModel := &mockLanguageModel{
 56		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 57			return func(yield func(StreamPart) bool) {
 58				// Test all stream part types
 59				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 60					return
 61				}
 62				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 63					return
 64				}
 65				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 66					return
 67				}
 68				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 69					return
 70				}
 71				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 72					return
 73				}
 74				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 75					return
 76				}
 77				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 78					return
 79				}
 80				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 81					return
 82				}
 83				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 84					return
 85				}
 86				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 87					return
 88				}
 89				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
 90					return
 91				}
 92				yield(StreamPart{
 93					Type:         StreamPartTypeFinish,
 94					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
 95					FinishReason: FinishReasonStop,
 96				})
 97			}, nil
 98		},
 99	}
100
101	// Create agent
102	agent := NewAgent(mockModel)
103
104	ctx := context.Background()
105
106	// Create streaming call with all callbacks
107	streamCall := AgentStreamCall{
108		Prompt: "Test all callbacks",
109		OnAgentStart: func() {
110			callbacks["OnAgentStart"] = true
111		},
112		OnAgentFinish: func(result *AgentResult) {
113			callbacks["OnAgentFinish"] = true
114		},
115		OnStepStart: func(stepNumber int) {
116			callbacks["OnStepStart"] = true
117		},
118		OnStepFinish: func(stepResult StepResult) {
119			callbacks["OnStepFinish"] = true
120		},
121		OnFinish: func(result *AgentResult) {
122			callbacks["OnFinish"] = true
123		},
124		OnError: func(err error) {
125			callbacks["OnError"] = true
126		},
127		OnChunk: func(part StreamPart) error {
128			callbacks["OnChunk"] = true
129			return nil
130		},
131		OnWarnings: func(warnings []CallWarning) error {
132			callbacks["OnWarnings"] = true
133			return nil
134		},
135		OnTextStart: func(id string) error {
136			callbacks["OnTextStart"] = true
137			return nil
138		},
139		OnTextDelta: func(id, text string) error {
140			callbacks["OnTextDelta"] = true
141			return nil
142		},
143		OnTextEnd: func(id string) error {
144			callbacks["OnTextEnd"] = true
145			return nil
146		},
147		OnReasoningStart: func(id string) error {
148			callbacks["OnReasoningStart"] = true
149			return nil
150		},
151		OnReasoningDelta: func(id, text string) error {
152			callbacks["OnReasoningDelta"] = true
153			return nil
154		},
155		OnReasoningEnd: func(id string, content ReasoningContent) error {
156			callbacks["OnReasoningEnd"] = true
157			return nil
158		},
159		OnToolInputStart: func(id, toolName string) error {
160			callbacks["OnToolInputStart"] = true
161			return nil
162		},
163		OnToolInputDelta: func(id, delta string) error {
164			callbacks["OnToolInputDelta"] = true
165			return nil
166		},
167		OnToolInputEnd: func(id string) error {
168			callbacks["OnToolInputEnd"] = true
169			return nil
170		},
171		OnToolCall: func(toolCall ToolCallContent) error {
172			callbacks["OnToolCall"] = true
173			return nil
174		},
175		OnToolResult: func(result ToolResultContent) error {
176			callbacks["OnToolResult"] = true
177			return nil
178		},
179		OnSource: func(source SourceContent) error {
180			callbacks["OnSource"] = true
181			return nil
182		},
183		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
184			callbacks["OnStreamFinish"] = true
185			return nil
186		},
187	}
188
189	// Execute streaming agent
190	result, err := agent.Stream(ctx, streamCall)
191	require.NoError(t, err)
192	require.NotNil(t, result)
193
194	// Verify that expected callbacks were called
195	expectedCallbacks := []string{
196		"OnAgentStart",
197		"OnAgentFinish",
198		"OnStepStart",
199		"OnStepFinish",
200		"OnFinish",
201		"OnChunk",
202		"OnWarnings",
203		"OnTextStart",
204		"OnTextDelta",
205		"OnTextEnd",
206		"OnReasoningStart",
207		"OnReasoningDelta",
208		"OnReasoningEnd",
209		"OnToolInputStart",
210		"OnToolInputDelta",
211		"OnToolInputEnd",
212		"OnSource",
213		"OnStreamFinish",
214	}
215
216	for _, callback := range expectedCallbacks {
217		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
218	}
219
220	// Verify that error callbacks were not called
221	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
222	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
223	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
224}
225
226// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
227func TestStreamingAgentWithTools(t *testing.T) {
228	t.Parallel()
229
230	stepCount := 0
231	// Create a mock language model that makes a tool call then finishes
232	mockModel := &mockLanguageModel{
233		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
234			stepCount++
235			return func(yield func(StreamPart) bool) {
236				if stepCount == 1 {
237					// First step: make tool call
238					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
239						return
240					}
241					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
242						return
243					}
244					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
245						return
246					}
247					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
248						return
249					}
250					if !yield(StreamPart{
251						Type:          StreamPartTypeToolCall,
252						ID:            "tool-1",
253						ToolCallName:  "echo",
254						ToolCallInput: `{"message": "test"}`,
255					}) {
256						return
257					}
258					yield(StreamPart{
259						Type:         StreamPartTypeFinish,
260						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
261						FinishReason: FinishReasonToolCalls,
262					})
263				} else {
264					// Second step: finish after tool execution
265					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
266						return
267					}
268					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
269						return
270					}
271					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
272						return
273					}
274					yield(StreamPart{
275						Type:         StreamPartTypeFinish,
276						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
277						FinishReason: FinishReasonStop,
278					})
279				}
280			}, nil
281		},
282	}
283
284	// Create agent with echo tool
285	agent := NewAgent(
286		mockModel,
287		WithSystemPrompt("You are a helpful assistant."),
288		WithTools(&EchoTool{}),
289	)
290
291	ctx := context.Background()
292
293	// Track callback invocations
294	var toolInputStartCalled bool
295	var toolInputDeltaCalled bool
296	var toolInputEndCalled bool
297	var toolCallCalled bool
298	var toolResultCalled bool
299
300	// Create streaming call with callbacks
301	streamCall := AgentStreamCall{
302		Prompt: "Echo 'test'",
303		OnToolInputStart: func(id, toolName string) error {
304			toolInputStartCalled = true
305			require.Equal(t, "tool-1", id)
306			require.Equal(t, "echo", toolName)
307			return nil
308		},
309		OnToolInputDelta: func(id, delta string) error {
310			toolInputDeltaCalled = true
311			require.Equal(t, "tool-1", id)
312			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
313			return nil
314		},
315		OnToolInputEnd: func(id string) error {
316			toolInputEndCalled = true
317			require.Equal(t, "tool-1", id)
318			return nil
319		},
320		OnToolCall: func(toolCall ToolCallContent) error {
321			toolCallCalled = true
322			require.Equal(t, "echo", toolCall.ToolName)
323			require.Equal(t, `{"message": "test"}`, toolCall.Input)
324			return nil
325		},
326		OnToolResult: func(result ToolResultContent) error {
327			toolResultCalled = true
328			require.Equal(t, "echo", result.ToolName)
329			return nil
330		},
331	}
332
333	// Execute streaming agent
334	result, err := agent.Stream(ctx, streamCall)
335	require.NoError(t, err)
336
337	// Verify results
338	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
339	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
340	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
341	require.True(t, toolCallCalled, "OnToolCall should have been called")
342	require.True(t, toolResultCalled, "OnToolResult should have been called")
343	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
344
345	// Check that tool was executed in first step
346	firstStep := result.Steps[0]
347	toolCalls := firstStep.Content.ToolCalls()
348	require.Equal(t, 1, len(toolCalls))
349	require.Equal(t, "echo", toolCalls[0].ToolName)
350
351	toolResults := firstStep.Content.ToolResults()
352	require.Equal(t, 1, len(toolResults))
353	require.Equal(t, "echo", toolResults[0].ToolName)
354}
355
356// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
357func TestStreamingAgentTextDeltas(t *testing.T) {
358	t.Parallel()
359
360	// Create a mock language model that returns text deltas
361	mockModel := &mockLanguageModel{
362		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
363			return func(yield func(StreamPart) bool) {
364				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
365					return
366				}
367				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
368					return
369				}
370				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
371					return
372				}
373				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
374					return
375				}
376				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
377					return
378				}
379				yield(StreamPart{
380					Type:         StreamPartTypeFinish,
381					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
382					FinishReason: FinishReasonStop,
383				})
384			}, nil
385		},
386	}
387
388	agent := NewAgent(mockModel)
389	ctx := context.Background()
390
391	// Track text deltas
392	var textDeltas []string
393
394	streamCall := AgentStreamCall{
395		Prompt: "Say hello",
396		OnTextDelta: func(id, text string) error {
397			if text != "" {
398				textDeltas = append(textDeltas, text)
399			}
400			return nil
401		},
402	}
403
404	result, err := agent.Stream(ctx, streamCall)
405	require.NoError(t, err)
406
407	// Verify text deltas match expected pattern
408	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
409	require.Equal(t, "Hello, world!", result.Response.Content.Text())
410	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
411}
412
413// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
414func TestStreamingAgentReasoning(t *testing.T) {
415	t.Parallel()
416
417	mockModel := &mockLanguageModel{
418		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
419			return func(yield func(StreamPart) bool) {
420				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
421					return
422				}
423				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
424					return
425				}
426				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
427					return
428				}
429				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
430					return
431				}
432				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
433					return
434				}
435				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
436					return
437				}
438				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
439					return
440				}
441				yield(StreamPart{
442					Type:         StreamPartTypeFinish,
443					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
444					FinishReason: FinishReasonStop,
445				})
446			}, nil
447		},
448	}
449
450	agent := NewAgent(mockModel)
451	ctx := context.Background()
452
453	var reasoningDeltas []string
454	var textDeltas []string
455
456	streamCall := AgentStreamCall{
457		Prompt: "Think and respond",
458		OnReasoningDelta: func(id, text string) error {
459			reasoningDeltas = append(reasoningDeltas, text)
460			return nil
461		},
462		OnTextDelta: func(id, text string) error {
463			textDeltas = append(textDeltas, text)
464			return nil
465		},
466	}
467
468	result, err := agent.Stream(ctx, streamCall)
469	require.NoError(t, err)
470
471	// Verify reasoning and text are separate
472	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
473	require.Equal(t, []string{"Hi there!"}, textDeltas)
474	require.Equal(t, "Hi there!", result.Response.Content.Text())
475	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
476}
477
478// TestStreamingAgentError tests error handling (mirrors TS error tests)
479func TestStreamingAgentError(t *testing.T) {
480	t.Parallel()
481
482	// Create a mock language model that returns an error
483	mockModel := &mockLanguageModel{
484		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
485			return func(yield func(StreamPart) bool) {
486				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
487			}, nil
488		},
489	}
490
491	agent := NewAgent(mockModel)
492	ctx := context.Background()
493
494	// Track error callbacks
495	var errorOccurred bool
496	var errorMessage string
497
498	streamCall := AgentStreamCall{
499		Prompt: "This will fail",
500
501		OnError: func(err error) {
502			errorOccurred = true
503			errorMessage = err.Error()
504		},
505	}
506
507	// Execute streaming agent
508	result, err := agent.Stream(ctx, streamCall)
509	require.Error(t, err)
510	require.Nil(t, result)
511	require.True(t, errorOccurred, "OnError should have been called")
512	require.Contains(t, errorMessage, "mock stream error")
513}
514
515// TestStreamingAgentSources tests source handling (mirrors TS source tests)
516func TestStreamingAgentSources(t *testing.T) {
517	t.Parallel()
518
519	mockModel := &mockLanguageModel{
520		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
521			return func(yield func(StreamPart) bool) {
522				if !yield(StreamPart{
523					Type:       StreamPartTypeSource,
524					ID:         "source-1",
525					SourceType: SourceTypeURL,
526					URL:        "https://example.com",
527					Title:      "Example",
528				}) {
529					return
530				}
531				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
532					return
533				}
534				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
535					return
536				}
537				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
538					return
539				}
540				if !yield(StreamPart{
541					Type:       StreamPartTypeSource,
542					ID:         "source-2",
543					SourceType: SourceTypeDocument,
544					Title:      "Document Example",
545				}) {
546					return
547				}
548				yield(StreamPart{
549					Type:         StreamPartTypeFinish,
550					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
551					FinishReason: FinishReasonStop,
552				})
553			}, nil
554		},
555	}
556
557	agent := NewAgent(mockModel)
558	ctx := context.Background()
559
560	var sources []SourceContent
561
562	streamCall := AgentStreamCall{
563		Prompt: "Search and respond",
564		OnSource: func(source SourceContent) error {
565			sources = append(sources, source)
566			return nil
567		},
568	}
569
570	result, err := agent.Stream(ctx, streamCall)
571	require.NoError(t, err)
572
573	// Verify sources were captured
574	require.Equal(t, 2, len(sources))
575	require.Equal(t, SourceTypeURL, sources[0].SourceType)
576	require.Equal(t, "https://example.com", sources[0].URL)
577	require.Equal(t, "Example", sources[0].Title)
578	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
579	require.Equal(t, "Document Example", sources[1].Title)
580
581	// Verify sources are in final result
582	resultSources := result.Response.Content.Sources()
583	require.Equal(t, 2, len(resultSources))
584}