1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"testing"
  8
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// EchoTool is a simple tool that echoes back the input message
 13type EchoTool struct{}
 14
 15// Info returns the tool information
 16func (e *EchoTool) Info() ToolInfo {
 17	return ToolInfo{
 18		Name:        "echo",
 19		Description: "Echo back the provided message",
 20		Parameters: map[string]any{
 21			"message": map[string]any{
 22				"type":        "string",
 23				"description": "The message to echo back",
 24			},
 25		},
 26		Required: []string{"message"},
 27	}
 28}
 29
 30// Run executes the echo tool
 31func (e *EchoTool) Run(ctx context.Context, params ToolCall) (ToolResponse, error) {
 32	var input struct {
 33		Message string `json:"message"`
 34	}
 35
 36	if err := json.Unmarshal([]byte(params.Input), &input); err != nil {
 37		return NewTextErrorResponse("Invalid input: " + err.Error()), nil
 38	}
 39
 40	if input.Message == "" {
 41		return NewTextErrorResponse("Message cannot be empty"), nil
 42	}
 43
 44	return NewTextResponse("Echo: " + input.Message), nil
 45}
 46
 47// TestStreamingAgentCallbacks tests that all streaming callbacks are called correctly
 48func TestStreamingAgentCallbacks(t *testing.T) {
 49	t.Parallel()
 50
 51	// Track which callbacks were called
 52	callbacks := make(map[string]bool)
 53
 54	// Create a mock language model that returns various stream parts
 55	mockModel := &mockLanguageModel{
 56		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
 57			return func(yield func(StreamPart) bool) {
 58				// Test all stream part types
 59				if !yield(StreamPart{Type: StreamPartTypeWarnings, Warnings: []CallWarning{{Type: CallWarningTypeOther, Message: "test warning"}}}) {
 60					return
 61				}
 62				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
 63					return
 64				}
 65				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
 66					return
 67				}
 68				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
 69					return
 70				}
 71				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
 72					return
 73				}
 74				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "thinking..."}) {
 75					return
 76				}
 77				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
 78					return
 79				}
 80				if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "test_tool"}) {
 81					return
 82				}
 83				if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"param"`}) {
 84					return
 85				}
 86				if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
 87					return
 88				}
 89				if !yield(StreamPart{Type: StreamPartTypeSource, ID: "source-1", SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"}) {
 90					return
 91				}
 92				yield(StreamPart{
 93					Type:         StreamPartTypeFinish,
 94					Usage:        Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7},
 95					FinishReason: FinishReasonStop,
 96				})
 97			}, nil
 98		},
 99	}
100
101	// Create agent
102	agent := NewAgent(mockModel)
103
104	ctx := context.Background()
105
106	// Create streaming call with all callbacks
107	streamCall := AgentStreamCall{
108		Prompt: "Test all callbacks",
109		OnAgentStart: func() {
110			callbacks["OnAgentStart"] = true
111		},
112		OnAgentFinish: func(result *AgentResult) error {
113			callbacks["OnAgentFinish"] = true
114			return nil
115		},
116		OnStepStart: func(stepNumber int) error {
117			callbacks["OnStepStart"] = true
118			return nil
119		},
120		OnStepFinish: func(stepResult StepResult) error {
121			callbacks["OnStepFinish"] = true
122			return nil
123		},
124		OnFinish: func(result *AgentResult) {
125			callbacks["OnFinish"] = true
126		},
127		OnError: func(err error) {
128			callbacks["OnError"] = true
129		},
130		OnChunk: func(part StreamPart) error {
131			callbacks["OnChunk"] = true
132			return nil
133		},
134		OnWarnings: func(warnings []CallWarning) error {
135			callbacks["OnWarnings"] = true
136			return nil
137		},
138		OnTextStart: func(id string) error {
139			callbacks["OnTextStart"] = true
140			return nil
141		},
142		OnTextDelta: func(id, text string) error {
143			callbacks["OnTextDelta"] = true
144			return nil
145		},
146		OnTextEnd: func(id string) error {
147			callbacks["OnTextEnd"] = true
148			return nil
149		},
150		OnReasoningStart: func(id string) error {
151			callbacks["OnReasoningStart"] = true
152			return nil
153		},
154		OnReasoningDelta: func(id, text string) error {
155			callbacks["OnReasoningDelta"] = true
156			return nil
157		},
158		OnReasoningEnd: func(id string, content ReasoningContent) error {
159			callbacks["OnReasoningEnd"] = true
160			return nil
161		},
162		OnToolInputStart: func(id, toolName string) error {
163			callbacks["OnToolInputStart"] = true
164			return nil
165		},
166		OnToolInputDelta: func(id, delta string) error {
167			callbacks["OnToolInputDelta"] = true
168			return nil
169		},
170		OnToolInputEnd: func(id string) error {
171			callbacks["OnToolInputEnd"] = true
172			return nil
173		},
174		OnToolCall: func(toolCall ToolCallContent) error {
175			callbacks["OnToolCall"] = true
176			return nil
177		},
178		OnToolResult: func(result ToolResultContent) error {
179			callbacks["OnToolResult"] = true
180			return nil
181		},
182		OnSource: func(source SourceContent) error {
183			callbacks["OnSource"] = true
184			return nil
185		},
186		OnStreamFinish: func(usage Usage, finishReason FinishReason, providerMetadata ProviderMetadata) error {
187			callbacks["OnStreamFinish"] = true
188			return nil
189		},
190	}
191
192	// Execute streaming agent
193	result, err := agent.Stream(ctx, streamCall)
194	require.NoError(t, err)
195	require.NotNil(t, result)
196
197	// Verify that expected callbacks were called
198	expectedCallbacks := []string{
199		"OnAgentStart",
200		"OnAgentFinish",
201		"OnStepStart",
202		"OnStepFinish",
203		"OnFinish",
204		"OnChunk",
205		"OnWarnings",
206		"OnTextStart",
207		"OnTextDelta",
208		"OnTextEnd",
209		"OnReasoningStart",
210		"OnReasoningDelta",
211		"OnReasoningEnd",
212		"OnToolInputStart",
213		"OnToolInputDelta",
214		"OnToolInputEnd",
215		"OnSource",
216		"OnStreamFinish",
217	}
218
219	for _, callback := range expectedCallbacks {
220		require.True(t, callbacks[callback], "Expected callback %s to be called", callback)
221	}
222
223	// Verify that error callbacks were not called
224	require.False(t, callbacks["OnError"], "OnError should not be called in successful case")
225	require.False(t, callbacks["OnToolCall"], "OnToolCall should not be called without actual tool calls")
226	require.False(t, callbacks["OnToolResult"], "OnToolResult should not be called without actual tool results")
227}
228
229// TestStreamingAgentWithTools tests streaming agent with tool calls (mirrors TS test patterns)
230func TestStreamingAgentWithTools(t *testing.T) {
231	t.Parallel()
232
233	stepCount := 0
234	// Create a mock language model that makes a tool call then finishes
235	mockModel := &mockLanguageModel{
236		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
237			stepCount++
238			return func(yield func(StreamPart) bool) {
239				if stepCount == 1 {
240					// First step: make tool call
241					if !yield(StreamPart{Type: StreamPartTypeToolInputStart, ID: "tool-1", ToolCallName: "echo"}) {
242						return
243					}
244					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `{"message"`}) {
245						return
246					}
247					if !yield(StreamPart{Type: StreamPartTypeToolInputDelta, ID: "tool-1", Delta: `: "test"}`}) {
248						return
249					}
250					if !yield(StreamPart{Type: StreamPartTypeToolInputEnd, ID: "tool-1"}) {
251						return
252					}
253					if !yield(StreamPart{
254						Type:          StreamPartTypeToolCall,
255						ID:            "tool-1",
256						ToolCallName:  "echo",
257						ToolCallInput: `{"message": "test"}`,
258					}) {
259						return
260					}
261					yield(StreamPart{
262						Type:         StreamPartTypeFinish,
263						Usage:        Usage{InputTokens: 10, OutputTokens: 5, TotalTokens: 15},
264						FinishReason: FinishReasonToolCalls,
265					})
266				} else {
267					// Second step: finish after tool execution
268					if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
269						return
270					}
271					if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Tool executed successfully"}) {
272						return
273					}
274					if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
275						return
276					}
277					yield(StreamPart{
278						Type:         StreamPartTypeFinish,
279						Usage:        Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
280						FinishReason: FinishReasonStop,
281					})
282				}
283			}, nil
284		},
285	}
286
287	// Create agent with echo tool
288	agent := NewAgent(
289		mockModel,
290		WithSystemPrompt("You are a helpful assistant."),
291		WithTools(&EchoTool{}),
292	)
293
294	ctx := context.Background()
295
296	// Track callback invocations
297	var toolInputStartCalled bool
298	var toolInputDeltaCalled bool
299	var toolInputEndCalled bool
300	var toolCallCalled bool
301	var toolResultCalled bool
302
303	// Create streaming call with callbacks
304	streamCall := AgentStreamCall{
305		Prompt: "Echo 'test'",
306		OnToolInputStart: func(id, toolName string) error {
307			toolInputStartCalled = true
308			require.Equal(t, "tool-1", id)
309			require.Equal(t, "echo", toolName)
310			return nil
311		},
312		OnToolInputDelta: func(id, delta string) error {
313			toolInputDeltaCalled = true
314			require.Equal(t, "tool-1", id)
315			require.Contains(t, []string{`{"message"`, `: "test"}`}, delta)
316			return nil
317		},
318		OnToolInputEnd: func(id string) error {
319			toolInputEndCalled = true
320			require.Equal(t, "tool-1", id)
321			return nil
322		},
323		OnToolCall: func(toolCall ToolCallContent) error {
324			toolCallCalled = true
325			require.Equal(t, "echo", toolCall.ToolName)
326			require.Equal(t, `{"message": "test"}`, toolCall.Input)
327			return nil
328		},
329		OnToolResult: func(result ToolResultContent) error {
330			toolResultCalled = true
331			require.Equal(t, "echo", result.ToolName)
332			return nil
333		},
334	}
335
336	// Execute streaming agent
337	result, err := agent.Stream(ctx, streamCall)
338	require.NoError(t, err)
339
340	// Verify results
341	require.True(t, toolInputStartCalled, "OnToolInputStart should have been called")
342	require.True(t, toolInputDeltaCalled, "OnToolInputDelta should have been called")
343	require.True(t, toolInputEndCalled, "OnToolInputEnd should have been called")
344	require.True(t, toolCallCalled, "OnToolCall should have been called")
345	require.True(t, toolResultCalled, "OnToolResult should have been called")
346	require.Equal(t, 2, len(result.Steps)) // Two steps: tool call + final response
347
348	// Check that tool was executed in first step
349	firstStep := result.Steps[0]
350	toolCalls := firstStep.Content.ToolCalls()
351	require.Equal(t, 1, len(toolCalls))
352	require.Equal(t, "echo", toolCalls[0].ToolName)
353
354	toolResults := firstStep.Content.ToolResults()
355	require.Equal(t, 1, len(toolResults))
356	require.Equal(t, "echo", toolResults[0].ToolName)
357}
358
359// TestStreamingAgentTextDeltas tests text streaming (mirrors TS textStream tests)
360func TestStreamingAgentTextDeltas(t *testing.T) {
361	t.Parallel()
362
363	// Create a mock language model that returns text deltas
364	mockModel := &mockLanguageModel{
365		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
366			return func(yield func(StreamPart) bool) {
367				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
368					return
369				}
370				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello"}) {
371					return
372				}
373				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: ", "}) {
374					return
375				}
376				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "world!"}) {
377					return
378				}
379				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
380					return
381				}
382				yield(StreamPart{
383					Type:         StreamPartTypeFinish,
384					Usage:        Usage{InputTokens: 3, OutputTokens: 10, TotalTokens: 13},
385					FinishReason: FinishReasonStop,
386				})
387			}, nil
388		},
389	}
390
391	agent := NewAgent(mockModel)
392	ctx := context.Background()
393
394	// Track text deltas
395	var textDeltas []string
396
397	streamCall := AgentStreamCall{
398		Prompt: "Say hello",
399		OnTextDelta: func(id, text string) error {
400			if text != "" {
401				textDeltas = append(textDeltas, text)
402			}
403			return nil
404		},
405	}
406
407	result, err := agent.Stream(ctx, streamCall)
408	require.NoError(t, err)
409
410	// Verify text deltas match expected pattern
411	require.Equal(t, []string{"Hello", ", ", "world!"}, textDeltas)
412	require.Equal(t, "Hello, world!", result.Response.Content.Text())
413	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
414}
415
416// TestStreamingAgentReasoning tests reasoning content (mirrors TS reasoning tests)
417func TestStreamingAgentReasoning(t *testing.T) {
418	t.Parallel()
419
420	mockModel := &mockLanguageModel{
421		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
422			return func(yield func(StreamPart) bool) {
423				if !yield(StreamPart{Type: StreamPartTypeReasoningStart, ID: "reasoning-1"}) {
424					return
425				}
426				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: "I will open the conversation"}) {
427					return
428				}
429				if !yield(StreamPart{Type: StreamPartTypeReasoningDelta, ID: "reasoning-1", Delta: " with witty banter."}) {
430					return
431				}
432				if !yield(StreamPart{Type: StreamPartTypeReasoningEnd, ID: "reasoning-1"}) {
433					return
434				}
435				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
436					return
437				}
438				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hi there!"}) {
439					return
440				}
441				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
442					return
443				}
444				yield(StreamPart{
445					Type:         StreamPartTypeFinish,
446					Usage:        Usage{InputTokens: 5, OutputTokens: 15, TotalTokens: 20},
447					FinishReason: FinishReasonStop,
448				})
449			}, nil
450		},
451	}
452
453	agent := NewAgent(mockModel)
454	ctx := context.Background()
455
456	var reasoningDeltas []string
457	var textDeltas []string
458
459	streamCall := AgentStreamCall{
460		Prompt: "Think and respond",
461		OnReasoningDelta: func(id, text string) error {
462			reasoningDeltas = append(reasoningDeltas, text)
463			return nil
464		},
465		OnTextDelta: func(id, text string) error {
466			textDeltas = append(textDeltas, text)
467			return nil
468		},
469	}
470
471	result, err := agent.Stream(ctx, streamCall)
472	require.NoError(t, err)
473
474	// Verify reasoning and text are separate
475	require.Equal(t, []string{"I will open the conversation", " with witty banter."}, reasoningDeltas)
476	require.Equal(t, []string{"Hi there!"}, textDeltas)
477	require.Equal(t, "Hi there!", result.Response.Content.Text())
478	require.Equal(t, "I will open the conversation with witty banter.", result.Response.Content.ReasoningText())
479}
480
481// TestStreamingAgentError tests error handling (mirrors TS error tests)
482func TestStreamingAgentError(t *testing.T) {
483	t.Parallel()
484
485	// Create a mock language model that returns an error
486	mockModel := &mockLanguageModel{
487		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
488			return func(yield func(StreamPart) bool) {
489				yield(StreamPart{Type: StreamPartTypeError, Error: fmt.Errorf("mock stream error")})
490			}, nil
491		},
492	}
493
494	agent := NewAgent(mockModel)
495	ctx := context.Background()
496
497	// Track error callbacks
498	var errorOccurred bool
499	var errorMessage string
500
501	streamCall := AgentStreamCall{
502		Prompt: "This will fail",
503
504		OnError: func(err error) {
505			errorOccurred = true
506			errorMessage = err.Error()
507		},
508	}
509
510	// Execute streaming agent
511	result, err := agent.Stream(ctx, streamCall)
512	require.Error(t, err)
513	require.Nil(t, result)
514	require.True(t, errorOccurred, "OnError should have been called")
515	require.Contains(t, errorMessage, "mock stream error")
516}
517
518// TestStreamingAgentSources tests source handling (mirrors TS source tests)
519func TestStreamingAgentSources(t *testing.T) {
520	t.Parallel()
521
522	mockModel := &mockLanguageModel{
523		streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
524			return func(yield func(StreamPart) bool) {
525				if !yield(StreamPart{
526					Type:       StreamPartTypeSource,
527					ID:         "source-1",
528					SourceType: SourceTypeURL,
529					URL:        "https://example.com",
530					Title:      "Example",
531				}) {
532					return
533				}
534				if !yield(StreamPart{Type: StreamPartTypeTextStart, ID: "text-1"}) {
535					return
536				}
537				if !yield(StreamPart{Type: StreamPartTypeTextDelta, ID: "text-1", Delta: "Hello!"}) {
538					return
539				}
540				if !yield(StreamPart{Type: StreamPartTypeTextEnd, ID: "text-1"}) {
541					return
542				}
543				if !yield(StreamPart{
544					Type:       StreamPartTypeSource,
545					ID:         "source-2",
546					SourceType: SourceTypeDocument,
547					Title:      "Document Example",
548				}) {
549					return
550				}
551				yield(StreamPart{
552					Type:         StreamPartTypeFinish,
553					Usage:        Usage{InputTokens: 3, OutputTokens: 5, TotalTokens: 8},
554					FinishReason: FinishReasonStop,
555				})
556			}, nil
557		},
558	}
559
560	agent := NewAgent(mockModel)
561	ctx := context.Background()
562
563	var sources []SourceContent
564
565	streamCall := AgentStreamCall{
566		Prompt: "Search and respond",
567		OnSource: func(source SourceContent) error {
568			sources = append(sources, source)
569			return nil
570		},
571	}
572
573	result, err := agent.Stream(ctx, streamCall)
574	require.NoError(t, err)
575
576	// Verify sources were captured
577	require.Equal(t, 2, len(sources))
578	require.Equal(t, SourceTypeURL, sources[0].SourceType)
579	require.Equal(t, "https://example.com", sources[0].URL)
580	require.Equal(t, "Example", sources[0].Title)
581	require.Equal(t, SourceTypeDocument, sources[1].SourceType)
582	require.Equal(t, "Document Example", sources[1].Title)
583
584	// Verify sources are in final result
585	resultSources := result.Response.Content.Sources()
586	require.Equal(t, 2, len(resultSources))
587}