agent_stream_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// EchoTool is a simple tool that echoes back the input message
 13type EchoTool struct {
 14	providerOptions ProviderOptions
 15}
 16
 17func (e *EchoTool) SetProviderOptions(opts ProviderOptions) {
 18	e.providerOptions = opts
 19}
 20
 21func (e *EchoTool) ProviderOptions() ProviderOptions {
 22	return e.providerOptions
 23}
 24
 25// Info returns the tool information
 26func (e *EchoTool) Info() ToolInfo {
 27	return ToolInfo{
 28		Name:        "echo",
 29		Description: "Echo back the provided message",
 30		Parameters: map[string]any{
 31			"message": map[string]any{
 32				"type":        "string",
 33				"description": "The message to echo back",
 34			},
 35		},
 36		Required: []string{"message"},
 37	}
 38}
 39
 40// Run executes the echo tool
 41func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 42	var input struct {
 43		Message string `json:"message"`
 44	}
 45
 46	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 47		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 48	}
 49
 50	if input.Message == "" {
 51		return NewTextErrorResponse("Message cannot be empty"), nil
 52	}
 53
 54	return NewTextResponse("Echo: " + input.Message), nil
 55}
 56
 57// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 58func TestStreamingAgentCallbacks(t *testing.T) {
 59	t.Parallel()
 60
 61	// Track which callbacks were called
 62	callbacks := make(map[string]bool)
 63
 64	// Create a mock language model that returns various stream parts
 65	mockModel := &mockLanguageModel{
 66		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 67			return func(yield func(StreamPart) bool) {
 68				// Test all stream part types
 69				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 70					return
 71				}
 72				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 73					return
 74				}
 75				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 76					return
 77				}
 78				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 79					return
 80				}
 81				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 82					return
 83				}
 84				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 85					return
 86				}
 87				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 88					return
 89				}
 90				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 91					return
 92				}
 93				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 94					return
 95				}
 96				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 97					return
 98				}
 99				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
100					return
101				}
102				yield(StreamPart{
103					Type:         StreamPartTypeFinish,
104					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
105					FinishReason: FinishReasonStop,
106				})
107			}, nil
108		},
109	}
110
111	// Create agent
112	agent := NewAgent(mockModel)
113
114	ctx := context.Background()
115
116	// Create streaming call with all callbacks
117	streamCall := AgentStreamCall{
118		Prompt: "Test all callbacks",
119		OnAgentStart: func() {
120			callbacks["OnAgentStart"] = true
121		},
122		OnAgentFinish: func(result *AgentResult) error {
123			callbacks["OnAgentFinish"] = true
124			return nil
125		},
126		OnStepStart: func(stepNumber int) error {
127			callbacks["OnStepStart"] = true
128			return nil
129		},
130		OnStepFinish: func(stepResult StepResult) error {
131			callbacks["OnStepFinish"] = true
132			return nil
133		},
134		OnFinish: func(result *AgentResult) {
135			callbacks["OnFinish"] = true
136		},
137		OnError: func(err error) {
138			callbacks["OnError"] = true
139		},
140		OnChunk: func(part StreamPart) error {
141			callbacks["OnChunk"] = true
142			return nil
143		},
144		OnWarnings: func(warnings []CallWarning) error {
145			callbacks["OnWarnings"] = true
146			return nil
147		},
148		OnTextStart: func(id string) error {
149			callbacks["OnTextStart"] = true
150			return nil
151		},
152		OnTextDelta: func(id, text string) error {
153			callbacks["OnTextDelta"] = true
154			return nil
155		},
156		OnTextEnd: func(id string) error {
157			callbacks["OnTextEnd"] = true
158			return nil
159		},
160		OnReasoningStart: func(id string, _ ReasoningContent) error {
161			callbacks["OnReasoningStart"] = true
162			return nil
163		},
164		OnReasoningDelta: func(id, text string) error {
165			callbacks["OnReasoningDelta"] = true
166			return nil
167		},
168		OnReasoningEnd: func(id string, content ReasoningContent) error {
169			callbacks["OnReasoningEnd"] = true
170			return nil
171		},
172		OnToolInputStart: func(id, toolName string) error {
173			callbacks["OnToolInputStart"] = true
174			return nil
175		},
176		OnToolInputDelta: func(id, delta string) error {
177			callbacks["OnToolInputDelta"] = true
178			return nil
179		},
180		OnToolInputEnd: func(id string) error {
181			callbacks["OnToolInputEnd"] = true
182			return nil
183		},
184		OnToolCall: func(toolCall ToolCallContent) error {
185			callbacks["OnToolCall"] = true
186			return nil
187		},
188		OnToolResult: func(result ToolResultContent) error {
189			callbacks["OnToolResult"] = true
190			return nil
191		},
192		OnSource: func(source SourceContent) error {
193			callbacks["OnSource"] = true
194			return nil
195		},
196		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
197			callbacks["OnStreamFinish"] = true
198			return nil
199		},
200	}
201
202	// Execute streaming agent
203	result, err := agent.Stream(ctx, streamCall)
204	require.NoError(t, err)
205	require.NotNil(t, result)
206
207	// Verify that expected callbacks were called
208	expectedCallbacks := []string{
209		"OnAgentStart",
210		"OnAgentFinish",
211		"OnStepStart",
212		"OnStepFinish",
213		"OnFinish",
214		"OnChunk",
215		"OnWarnings",
216		"OnTextStart",
217		"OnTextDelta",
218		"OnTextEnd",
219		"OnReasoningStart",
220		"OnReasoningDelta",
221		"OnReasoningEnd",
222		"OnToolInputStart",
223		"OnToolInputDelta",
224		"OnToolInputEnd",
225		"OnSource",
226		"OnStreamFinish",
227	}
228
229	for _, callback := range expectedCallbacks {
230		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
231	}
232
233	// Verify that error callbacks were not called
234	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
235	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
236	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
237}
238
239// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
240func TestStreamingAgentWithTools(t *testing.T) {
241	t.Parallel()
242
243	stepCount := 0
244	// Create a mock language model that makes a tool call then finishes
245	mockModel := &mockLanguageModel{
246		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
247			stepCount++
248			return func(yield func(StreamPart) bool) {
249				if stepCount == 1 {
250					// First step: make tool call
251					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
252						return
253					}
254					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
255						return
256					}
257					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
258						return
259					}
260					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
261						return
262					}
263					if !yield(StreamPart{
264						Type:          StreamPartTypeToolCall,
265						ID:            "tool-1",
266						ToolCallName:  "echo",
267						ToolCallInput: `{"message": "test"}`,
268					}) {
269						return
270					}
271					yield(StreamPart{
272						Type:         StreamPartTypeFinish,
273						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
274						FinishReason: FinishReasonToolCalls,
275					})
276				} else {
277					// Second step: finish after tool execution
278					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
279						return
280					}
281					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
282						return
283					}
284					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
285						return
286					}
287					yield(StreamPart{
288						Type:         StreamPartTypeFinish,
289						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
290						FinishReason: FinishReasonStop,
291					})
292				}
293			}, nil
294		},
295	}
296
297	// Create agent with echo tool
298	agent := NewAgent(
299		mockModel,
300		WithSystemPrompt("You are a helpful assistant."),
301		WithTools(&EchoTool{}),
302	)
303
304	ctx := context.Background()
305
306	// Track callback invocations
307	var toolInputStartCalled bool
308	var toolInputDeltaCalled bool
309	var toolInputEndCalled bool
310	var toolCallCalled bool
311	var toolResultCalled bool
312
313	// Create streaming call with callbacks
314	streamCall := AgentStreamCall{
315		Prompt: "Echo 'test'",
316		OnToolInputStart: func(id, toolName string) error {
317			toolInputStartCalled = true
318			require.Equal(t, "tool-1", id)
319			require.Equal(t, "echo", toolName)
320			return nil
321		},
322		OnToolInputDelta: func(id, delta string) error {
323			toolInputDeltaCalled = true
324			require.Equal(t, "tool-1", id)
325			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
326			return nil
327		},
328		OnToolInputEnd: func(id string) error {
329			toolInputEndCalled = true
330			require.Equal(t, "tool-1", id)
331			return nil
332		},
333		OnToolCall: func(toolCall ToolCallContent) error {
334			toolCallCalled = true
335			require.Equal(t, "echo", toolCall.ToolName)
336			require.Equal(t, `{"message": "test"}`, toolCall.Input)
337			return nil
338		},
339		OnToolResult: func(result ToolResultContent) error {
340			toolResultCalled = true
341			require.Equal(t, "echo", result.ToolName)
342			return nil
343		},
344	}
345
346	// Execute streaming agent
347	result, err := agent.Stream(ctx, streamCall)
348	require.NoError(t, err)
349
350	// Verify results
351	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
352	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
353	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
354	require.True(t, toolCallCalled, "OnToolCall should have been called")
355	require.True(t, toolResultCalled, "OnToolResult should have been called")
356	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
357
358	// Check that tool was executed in first step
359	firstStep := result.Steps[0]
360	toolCalls := firstStep.Content.ToolCalls()
361	require.Equal(t, 1, len(toolCalls))
362	require.Equal(t, "echo", toolCalls[0].ToolName)
363
364	toolResults := firstStep.Content.ToolResults()
365	require.Equal(t, 1, len(toolResults))
366	require.Equal(t, "echo", toolResults[0].ToolName)
367}
368
369// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
370func TestStreamingAgentTextDeltas(t *testing.T) {
371	t.Parallel()
372
373	// Create a mock language model that returns text deltas
374	mockModel := &mockLanguageModel{
375		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
376			return func(yield func(StreamPart) bool) {
377				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
378					return
379				}
380				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
381					return
382				}
383				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
384					return
385				}
386				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
387					return
388				}
389				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
390					return
391				}
392				yield(StreamPart{
393					Type:         StreamPartTypeFinish,
394					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
395					FinishReason: FinishReasonStop,
396				})
397			}, nil
398		},
399	}
400
401	agent := NewAgent(mockModel)
402	ctx := context.Background()
403
404	// Track text deltas
405	var textDeltas []string
406
407	streamCall := AgentStreamCall{
408		Prompt: "Say hello",
409		OnTextDelta: func(id, text string) error {
410			if text != "" {
411				textDeltas = append(textDeltas, text)
412			}
413			return nil
414		},
415	}
416
417	result, err := agent.Stream(ctx, streamCall)
418	require.NoError(t, err)
419
420	// Verify text deltas match expected pattern
421	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
422	require.Equal(t, "Hello, world!", result.Response.Content.Text())
423	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
424}
425
426// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
427func TestStreamingAgentReasoning(t *testing.T) {
428	t.Parallel()
429
430	mockModel := &mockLanguageModel{
431		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
432			return func(yield func(StreamPart) bool) {
433				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
434					return
435				}
436				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
437					return
438				}
439				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
440					return
441				}
442				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
443					return
444				}
445				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
446					return
447				}
448				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
449					return
450				}
451				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
452					return
453				}
454				yield(StreamPart{
455					Type:         StreamPartTypeFinish,
456					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
457					FinishReason: FinishReasonStop,
458				})
459			}, nil
460		},
461	}
462
463	agent := NewAgent(mockModel)
464	ctx := context.Background()
465
466	var reasoningDeltas []string
467	var textDeltas []string
468
469	streamCall := AgentStreamCall{
470		Prompt: "Think and respond",
471		OnReasoningDelta: func(id, text string) error {
472			reasoningDeltas = append(reasoningDeltas, text)
473			return nil
474		},
475		OnTextDelta: func(id, text string) error {
476			textDeltas = append(textDeltas, text)
477			return nil
478		},
479	}
480
481	result, err := agent.Stream(ctx, streamCall)
482	require.NoError(t, err)
483
484	// Verify reasoning and text are separate
485	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
486	require.Equal(t, []string{"Hi there!"}, textDeltas)
487	require.Equal(t, "Hi there!", result.Response.Content.Text())
488	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
489}
490
491// TestStreamingAgentError tests error handling (mirrors TS error tests)
492func TestStreamingAgentError(t *testing.T) {
493	t.Parallel()
494
495	// Create a mock language model that returns an error
496	mockModel := &mockLanguageModel{
497		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
498			return func(yield func(StreamPart) bool) {
499				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
500			}, nil
501		},
502	}
503
504	agent := NewAgent(mockModel)
505	ctx := context.Background()
506
507	// Track error callbacks
508	var errorOccurred bool
509	var errorMessage string
510
511	streamCall := AgentStreamCall{
512		Prompt: "This will fail",
513
514		OnError: func(err error) {
515			errorOccurred = true
516			errorMessage = err.Error()
517		},
518	}
519
520	// Execute streaming agent
521	result, err := agent.Stream(ctx, streamCall)
522	require.Error(t, err)
523	require.Nil(t, result)
524	require.True(t, errorOccurred, "OnError should have been called")
525	require.Contains(t, errorMessage, "mock stream error")
526}
527
528// TestStreamingAgentSources tests source handling (mirrors TS source tests)
529func TestStreamingAgentSources(t *testing.T) {
530	t.Parallel()
531
532	mockModel := &mockLanguageModel{
533		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
534			return func(yield func(StreamPart) bool) {
535				if !yield(StreamPart{
536					Type:       StreamPartTypeSource,
537					ID:         "source-1",
538					SourceType: SourceTypeURL,
539					URL:        "https://example.com",
540					Title:      "Example",
541				}) {
542					return
543				}
544				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
545					return
546				}
547				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
548					return
549				}
550				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
551					return
552				}
553				if !yield(StreamPart{
554					Type:       StreamPartTypeSource,
555					ID:         "source-2",
556					SourceType: SourceTypeDocument,
557					Title:      "Document Example",
558				}) {
559					return
560				}
561				yield(StreamPart{
562					Type:         StreamPartTypeFinish,
563					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
564					FinishReason: FinishReasonStop,
565				})
566			}, nil
567		},
568	}
569
570	agent := NewAgent(mockModel)
571	ctx := context.Background()
572
573	var sources []SourceContent
574
575	streamCall := AgentStreamCall{
576		Prompt: "Search and respond",
577		OnSource: func(source SourceContent) error {
578			sources = append(sources, source)
579			return nil
580		},
581	}
582
583	result, err := agent.Stream(ctx, streamCall)
584	require.NoError(t, err)
585
586	// Verify sources were captured
587	require.Equal(t, 2, len(sources))
588	require.Equal(t, SourceTypeURL, sources[0].SourceType)
589	require.Equal(t, "https://example.com", sources[0].URL)
590	require.Equal(t, "Example", sources[0].Title)
591	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
592	require.Equal(t, "Document Example", sources[1].Title)
593
594	// Verify sources are in final result
595	resultSources := result.Response.Content.Sources()
596	require.Equal(t, 2, len(resultSources))
597}