agent_stream_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
// EchoTool is a stateless tool used by the streaming agent tests in this
// file. Its Run method replies with the received message prefixed by "Echo: ".
type EchoTool struct{}
 14
 15// Info returns the tool information
 16func (e *EchoTool) Info() ToolInfo {
 17	return ToolInfo{
 18		Name:        "echo",
 19		Description: "Echo back the provided message",
 20		Parameters: map[string]any{
 21			"message": map[string]any{
 22				"type":        "string",
 23				"description": "The message to echo back",
 24			},
 25		},
 26		Required: []string{"message"},
 27	}
 28}
 29
 30// Run executes the echo tool
 31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 32	var input struct {
 33		Message string `json:"message"`
 34	}
 35
 36	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 37		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 38	}
 39
 40	if input.Message == "" {
 41		return NewTextErrorResponse("Message cannot be empty"), nil
 42	}
 43
 44	return NewTextResponse("Echo: " + input.Message), nil
 45}
 46
 47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 48func TestStreamingAgentCallbacks(t *testing.T) {
 49	t.Parallel()
 50
 51	// Track which callbacks were called
 52	callbacks := make(map[string]bool)
 53
 54	// Create a mock language model that returns various stream parts
 55	mockModel := &mockLanguageModel{
 56		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 57			return func(yield func(StreamPart) bool) {
 58				// Test all stream part types
 59				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 60					return
 61				}
 62				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 63					return
 64				}
 65				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 66					return
 67				}
 68				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 69					return
 70				}
 71				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 72					return
 73				}
 74				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 75					return
 76				}
 77				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 78					return
 79				}
 80				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 81					return
 82				}
 83				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 84					return
 85				}
 86				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 87					return
 88				}
 89				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
 90					return
 91				}
 92				yield(StreamPart{
 93					Type:         StreamPartTypeFinish,
 94					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
 95					FinishReason: FinishReasonStop,
 96				})
 97			}, nil
 98		},
 99	}
100
101	// Create agent
102	agent := NewAgent(mockModel)
103
104	ctx := context.Background()
105
106	// Create streaming call with all callbacks
107	streamCall := AgentStreamCall{
108		Prompt: "Test all callbacks",
109		OnAgentStart: func() {
110			callbacks["OnAgentStart"] = true
111		},
112		OnAgentFinish: func(result *AgentResult) {
113			callbacks["OnAgentFinish"] = true
114		},
115		OnStepStart: func(stepNumber int) {
116			callbacks["OnStepStart"] = true
117		},
118		OnStepFinish: func(stepResult StepResult) {
119			callbacks["OnStepFinish"] = true
120		},
121		OnFinish: func(result *AgentResult) {
122			callbacks["OnFinish"] = true
123		},
124		OnError: func(err error) {
125			callbacks["OnError"] = true
126		},
127		OnChunk: func(part StreamPart) {
128			callbacks["OnChunk"] = true
129		},
130		OnWarnings: func(warnings []CallWarning) {
131			callbacks["OnWarnings"] = true
132		},
133		OnTextStart: func(id string) {
134			callbacks["OnTextStart"] = true
135		},
136		OnTextDelta: func(id, text string) {
137			callbacks["OnTextDelta"] = true
138		},
139		OnTextEnd: func(id string) {
140			callbacks["OnTextEnd"] = true
141		},
142		OnReasoningStart: func(id string) {
143			callbacks["OnReasoningStart"] = true
144		},
145		OnReasoningDelta: func(id, text string) {
146			callbacks["OnReasoningDelta"] = true
147		},
148		OnReasoningEnd: func(id string) {
149			callbacks["OnReasoningEnd"] = true
150		},
151		OnToolInputStart: func(id, toolName string) {
152			callbacks["OnToolInputStart"] = true
153		},
154		OnToolInputDelta: func(id, delta string) {
155			callbacks["OnToolInputDelta"] = true
156		},
157		OnToolInputEnd: func(id string) {
158			callbacks["OnToolInputEnd"] = true
159		},
160		OnToolCall: func(toolCall ToolCallContent) {
161			callbacks["OnToolCall"] = true
162		},
163		OnToolResult: func(result ToolResultContent) {
164			callbacks["OnToolResult"] = true
165		},
166		OnSource: func(source SourceContent) {
167			callbacks["OnSource"] = true
168		},
169		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderOptions) {
170			callbacks["OnStreamFinish"] = true
171		},
172		OnStreamError: func(err error) {
173			callbacks["OnStreamError"] = true
174		},
175	}
176
177	// Execute streaming agent
178	result, err := agent.Stream(ctx, streamCall)
179	require.NoError(t, err)
180	require.NotNil(t, result)
181
182	// Verify that expected callbacks were called
183	expectedCallbacks := []string{
184		"OnAgentStart",
185		"OnAgentFinish",
186		"OnStepStart",
187		"OnStepFinish",
188		"OnFinish",
189		"OnChunk",
190		"OnWarnings",
191		"OnTextStart",
192		"OnTextDelta",
193		"OnTextEnd",
194		"OnReasoningStart",
195		"OnReasoningDelta",
196		"OnReasoningEnd",
197		"OnToolInputStart",
198		"OnToolInputDelta",
199		"OnToolInputEnd",
200		"OnSource",
201		"OnStreamFinish",
202	}
203
204	for _, callback := range expectedCallbacks {
205		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
206	}
207
208	// Verify that error callbacks were not called
209	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
210	require.False(t, callbacks["OnStreamError"], "OnStreamError should not be called in successful case")
211	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
212	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
213}
214
215// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
216func TestStreamingAgentWithTools(t *testing.T) {
217	t.Parallel()
218
219	stepCount := 0
220	// Create a mock language model that makes a tool call then finishes
221	mockModel := &mockLanguageModel{
222		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
223			stepCount++
224			return func(yield func(StreamPart) bool) {
225				if stepCount == 1 {
226					// First step: make tool call
227					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
228						return
229					}
230					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
231						return
232					}
233					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
234						return
235					}
236					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
237						return
238					}
239					if !yield(StreamPart{
240						Type:          StreamPartTypeToolCall,
241						ID:            "tool-1",
242						ToolCallName:  "echo",
243						ToolCallInput: `{"message": "test"}`,
244					}) {
245						return
246					}
247					yield(StreamPart{
248						Type:         StreamPartTypeFinish,
249						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
250						FinishReason: FinishReasonToolCalls,
251					})
252				} else {
253					// Second step: finish after tool execution
254					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
255						return
256					}
257					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
258						return
259					}
260					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
261						return
262					}
263					yield(StreamPart{
264						Type:         StreamPartTypeFinish,
265						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
266						FinishReason: FinishReasonStop,
267					})
268				}
269			}, nil
270		},
271	}
272
273	// Create agent with echo tool
274	agent := NewAgent(
275		mockModel,
276		WithSystemPrompt("You are a helpful assistant."),
277		WithTools(&EchoTool{}),
278	)
279
280	ctx := context.Background()
281
282	// Track callback invocations
283	var toolInputStartCalled bool
284	var toolInputDeltaCalled bool
285	var toolInputEndCalled bool
286	var toolCallCalled bool
287	var toolResultCalled bool
288
289	// Create streaming call with callbacks
290	streamCall := AgentStreamCall{
291		Prompt: "Echo 'test'",
292		OnToolInputStart: func(id, toolName string) {
293			toolInputStartCalled = true
294			require.Equal(t, "tool-1", id)
295			require.Equal(t, "echo", toolName)
296		},
297		OnToolInputDelta: func(id, delta string) {
298			toolInputDeltaCalled = true
299			require.Equal(t, "tool-1", id)
300			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
301		},
302		OnToolInputEnd: func(id string) {
303			toolInputEndCalled = true
304			require.Equal(t, "tool-1", id)
305		},
306		OnToolCall: func(toolCall ToolCallContent) {
307			toolCallCalled = true
308			require.Equal(t, "echo", toolCall.ToolName)
309			require.Equal(t, `{"message": "test"}`, toolCall.Input)
310		},
311		OnToolResult: func(result ToolResultContent) {
312			toolResultCalled = true
313			require.Equal(t, "echo", result.ToolName)
314		},
315	}
316
317	// Execute streaming agent
318	result, err := agent.Stream(ctx, streamCall)
319	require.NoError(t, err)
320
321	// Verify results
322	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
323	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
324	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
325	require.True(t, toolCallCalled, "OnToolCall should have been called")
326	require.True(t, toolResultCalled, "OnToolResult should have been called")
327	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
328
329	// Check that tool was executed in first step
330	firstStep := result.Steps[0]
331	toolCalls := firstStep.Content.ToolCalls()
332	require.Equal(t, 1, len(toolCalls))
333	require.Equal(t, "echo", toolCalls[0].ToolName)
334
335	toolResults := firstStep.Content.ToolResults()
336	require.Equal(t, 1, len(toolResults))
337	require.Equal(t, "echo", toolResults[0].ToolName)
338}
339
340// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
341func TestStreamingAgentTextDeltas(t *testing.T) {
342	t.Parallel()
343
344	// Create a mock language model that returns text deltas
345	mockModel := &mockLanguageModel{
346		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
347			return func(yield func(StreamPart) bool) {
348				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
349					return
350				}
351				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
352					return
353				}
354				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
355					return
356				}
357				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
358					return
359				}
360				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
361					return
362				}
363				yield(StreamPart{
364					Type:         StreamPartTypeFinish,
365					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
366					FinishReason: FinishReasonStop,
367				})
368			}, nil
369		},
370	}
371
372	agent := NewAgent(mockModel)
373	ctx := context.Background()
374
375	// Track text deltas
376	var textDeltas []string
377
378	streamCall := AgentStreamCall{
379		Prompt: "Say hello",
380		OnTextDelta: func(id, text string) {
381			if text != "" {
382				textDeltas = append(textDeltas, text)
383			}
384		},
385	}
386
387	result, err := agent.Stream(ctx, streamCall)
388	require.NoError(t, err)
389
390	// Verify text deltas match expected pattern
391	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
392	require.Equal(t, "Hello, world!", result.Response.Content.Text())
393	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
394}
395
396// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
397func TestStreamingAgentReasoning(t *testing.T) {
398	t.Parallel()
399
400	mockModel := &mockLanguageModel{
401		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
402			return func(yield func(StreamPart) bool) {
403				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
404					return
405				}
406				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
407					return
408				}
409				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
410					return
411				}
412				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
413					return
414				}
415				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
416					return
417				}
418				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
419					return
420				}
421				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
422					return
423				}
424				yield(StreamPart{
425					Type:         StreamPartTypeFinish,
426					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
427					FinishReason: FinishReasonStop,
428				})
429			}, nil
430		},
431	}
432
433	agent := NewAgent(mockModel)
434	ctx := context.Background()
435
436	var reasoningDeltas []string
437	var textDeltas []string
438
439	streamCall := AgentStreamCall{
440		Prompt: "Think and respond",
441		OnReasoningDelta: func(id, text string) {
442			reasoningDeltas = append(reasoningDeltas, text)
443		},
444		OnTextDelta: func(id, text string) {
445			textDeltas = append(textDeltas, text)
446		},
447	}
448
449	result, err := agent.Stream(ctx, streamCall)
450	require.NoError(t, err)
451
452	// Verify reasoning and text are separate
453	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
454	require.Equal(t, []string{"Hi there!"}, textDeltas)
455	require.Equal(t, "Hi there!", result.Response.Content.Text())
456	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
457}
458
459// TestStreamingAgentError tests error handling (mirrors TS error tests)
460func TestStreamingAgentError(t *testing.T) {
461	t.Parallel()
462
463	// Create a mock language model that returns an error
464	mockModel := &mockLanguageModel{
465		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
466			return func(yield func(StreamPart) bool) {
467				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
468			}, nil
469		},
470	}
471
472	agent := NewAgent(mockModel)
473	ctx := context.Background()
474
475	// Track error callbacks
476	var streamErrorOccurred bool
477	var errorOccurred bool
478	var errorMessage string
479
480	streamCall := AgentStreamCall{
481		Prompt: "This will fail",
482		OnStreamError: func(err error) {
483			streamErrorOccurred = true
484		},
485		OnError: func(err error) {
486			errorOccurred = true
487			errorMessage = err.Error()
488		},
489	}
490
491	// Execute streaming agent
492	result, err := agent.Stream(ctx, streamCall)
493	require.Error(t, err)
494	require.Nil(t, result)
495	require.True(t, streamErrorOccurred, "OnStreamError should have been called")
496	require.True(t, errorOccurred, "OnError should have been called")
497	require.Contains(t, errorMessage, "mock stream error")
498}
499
500// TestStreamingAgentSources tests source handling (mirrors TS source tests)
501func TestStreamingAgentSources(t *testing.T) {
502	t.Parallel()
503
504	mockModel := &mockLanguageModel{
505		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
506			return func(yield func(StreamPart) bool) {
507				if !yield(StreamPart{
508					Type:       StreamPartTypeSource,
509					ID:         "source-1",
510					SourceType: SourceTypeURL,
511					URL:        "https://example.com",
512					Title:      "Example",
513				}) {
514					return
515				}
516				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
517					return
518				}
519				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
520					return
521				}
522				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
523					return
524				}
525				if !yield(StreamPart{
526					Type:       StreamPartTypeSource,
527					ID:         "source-2",
528					SourceType: SourceTypeDocument,
529					Title:      "Document Example",
530				}) {
531					return
532				}
533				yield(StreamPart{
534					Type:         StreamPartTypeFinish,
535					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
536					FinishReason: FinishReasonStop,
537				})
538			}, nil
539		},
540	}
541
542	agent := NewAgent(mockModel)
543	ctx := context.Background()
544
545	var sources []SourceContent
546
547	streamCall := AgentStreamCall{
548		Prompt: "Search and respond",
549		OnSource: func(source SourceContent) {
550			sources = append(sources, source)
551		},
552	}
553
554	result, err := agent.Stream(ctx, streamCall)
555	require.NoError(t, err)
556
557	// Verify sources were captured
558	require.Equal(t, 2, len(sources))
559	require.Equal(t, SourceTypeURL, sources[0].SourceType)
560	require.Equal(t, "https://example.com", sources[0].URL)
561	require.Equal(t, "Example", sources[0].Title)
562	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
563	require.Equal(t, "Document Example", sources[1].Title)
564
565	// Verify sources are in final result
566	resultSources := result.Response.Content.Sources()
567	require.Equal(t, 2, len(resultSources))
568}