1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 providerOptions ProviderOptions
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
21}
22
23func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
24 m.providerOptions = opts
25}
26
27func (m *mockTool) ProviderOptions() ProviderOptions {
28 return m.providerOptions
29}
30
31func (m *mockTool) Info() ToolInfo {
32 return ToolInfo{
33 Name: m.name,
34 Description: m.description,
35 Parameters: m.parameters,
36 Required: m.required,
37 }
38}
39
40func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
41 if m.executeFunc != nil {
42 return m.executeFunc(ctx, call)
43 }
44 return ToolResponse{Content: "mock result", IsError: false}, nil
45}
46
47// Mock language model for testing
48type mockLanguageModel struct {
49 generateFunc func(ctx context.Context, call Call) (*Response, error)
50 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
51}
52
53func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
54 if m.generateFunc != nil {
55 return m.generateFunc(ctx, call)
56 }
57 return &Response{
58 Content: []Content{
59 TextContent{Text: "Hello, world!"},
60 },
61 Usage: Usage{
62 InputTokens: 3,
63 OutputTokens: 10,
64 TotalTokens: 13,
65 },
66 FinishReason: FinishReasonStop,
67 }, nil
68}
69
70func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
71 if m.streamFunc != nil {
72 return m.streamFunc(ctx, call)
73 }
74 return nil, fmt.Errorf("mock stream not implemented")
75}
76
77func (m *mockLanguageModel) Provider() string {
78 return "mock-provider"
79}
80
81func (m *mockLanguageModel) Model() string {
82 return "mock-model"
83}
84
85func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
86 return nil, fmt.Errorf("mock GenerateObject not implemented")
87}
88
89func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
90 return nil, fmt.Errorf("mock StreamObject not implemented")
91}
92
93// Test result.content - comprehensive content types (matches TS test)
94func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
95 t.Parallel()
96
97 // Create a type-safe tool using the new API
98 type TestInput struct {
99 Value string `json:"value" description:"Test value"`
100 }
101
102 tool1 := NewAgentTool(
103 "tool1",
104 "Test tool",
105 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
106 require.Equal(t, "value", input.Value)
107 return ToolResponse{Content: "result1", IsError: false}, nil
108 },
109 )
110
111 model := &mockLanguageModel{
112 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
113 return &Response{
114 Content: []Content{
115 TextContent{Text: "Hello, world!"},
116 SourceContent{
117 ID: "123",
118 URL: "https://example.com",
119 Title: "Example",
120 SourceType: SourceTypeURL,
121 },
122 FileContent{
123 Data: []byte{1, 2, 3},
124 MediaType: "image/png",
125 },
126 ReasoningContent{
127 Text: "I will open the conversation with witty banter.",
128 },
129 ToolCallContent{
130 ToolCallID: "call-1",
131 ToolName: "tool1",
132 Input: `{"value":"value"}`,
133 },
134 TextContent{Text: "More text"},
135 },
136 Usage: Usage{
137 InputTokens: 3,
138 OutputTokens: 10,
139 TotalTokens: 13,
140 },
141 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
142 }, nil
143 },
144 }
145
146 agent := NewAgent(model, WithTools(tool1))
147 result, err := agent.Generate(context.Background(), AgentCall{
148 Prompt: "prompt",
149 })
150
151 require.NoError(t, err)
152 require.NotNil(t, result)
153 require.Len(t, result.Steps, 1) // Single step like TypeScript
154
155 // Check final response content includes tool result
156 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
157
158 // Verify each content type in order
159 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
160 require.True(t, ok)
161 require.Equal(t, "Hello, world!", textContent.Text)
162
163 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
164 require.True(t, ok)
165 require.Equal(t, "123", sourceContent.ID)
166
167 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
168 require.True(t, ok)
169 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
170
171 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
172 require.True(t, ok)
173 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
174
175 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
176 require.True(t, ok)
177 require.Equal(t, "call-1", toolCallContent.ToolCallID)
178
179 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
180 require.True(t, ok)
181 require.Equal(t, "More text", moreTextContent.Text)
182
183 // Tool result should be appended
184 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
185 require.True(t, ok)
186 require.Equal(t, "call-1", toolResultContent.ToolCallID)
187 require.Equal(t, "tool1", toolResultContent.ToolName)
188}
189
190// Test result.text extraction
191func TestAgent_Generate_ResultText(t *testing.T) {
192 t.Parallel()
193
194 model := &mockLanguageModel{
195 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
196 return &Response{
197 Content: []Content{
198 TextContent{Text: "Hello, world!"},
199 },
200 Usage: Usage{
201 InputTokens: 3,
202 OutputTokens: 10,
203 TotalTokens: 13,
204 },
205 FinishReason: FinishReasonStop,
206 }, nil
207 },
208 }
209
210 agent := NewAgent(model)
211 result, err := agent.Generate(context.Background(), AgentCall{
212 Prompt: "prompt",
213 })
214
215 require.NoError(t, err)
216 require.NotNil(t, result)
217
218 // Test text extraction from content
219 text := result.Response.Content.Text()
220 require.Equal(t, "Hello, world!", text)
221}
222
223// Test result.toolCalls extraction (matches TS test exactly)
224func TestAgent_Generate_ResultToolCalls(t *testing.T) {
225 t.Parallel()
226
227 // Create type-safe tools using the new API
228 type Tool1Input struct {
229 Value string `json:"value" description:"Test value"`
230 }
231
232 type Tool2Input struct {
233 SomethingElse string `json:"somethingElse" description:"Another test value"`
234 }
235
236 tool1 := NewAgentTool(
237 "tool1",
238 "Test tool 1",
239 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
240 return ToolResponse{Content: "result1", IsError: false}, nil
241 },
242 )
243
244 tool2 := NewAgentTool(
245 "tool2",
246 "Test tool 2",
247 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
248 return ToolResponse{Content: "result2", IsError: false}, nil
249 },
250 )
251
252 model := &mockLanguageModel{
253 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
254 // Verify tools are passed correctly
255 require.Len(t, call.Tools, 2)
256 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
257
258 // Verify prompt structure
259 require.Len(t, call.Prompt, 1)
260 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
261
262 return &Response{
263 Content: []Content{
264 ToolCallContent{
265 ToolCallID: "call-1",
266 ToolName: "tool1",
267 Input: `{"value":"value"}`,
268 },
269 },
270 Usage: Usage{
271 InputTokens: 3,
272 OutputTokens: 10,
273 TotalTokens: 13,
274 },
275 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
276 }, nil
277 },
278 }
279
280 agent := NewAgent(model, WithTools(tool1, tool2))
281 result, err := agent.Generate(context.Background(), AgentCall{
282 Prompt: "test-input",
283 })
284
285 require.NoError(t, err)
286 require.NotNil(t, result)
287 require.Len(t, result.Steps, 1) // Single step
288
289 // Extract tool calls from final response (should be empty since tools don't execute)
290 var toolCalls []ToolCallContent
291 for _, content := range result.Response.Content {
292 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
293 toolCalls = append(toolCalls, toolCall)
294 }
295 }
296
297 require.Len(t, toolCalls, 1)
298 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
299 require.Equal(t, "tool1", toolCalls[0].ToolName)
300
301 // Parse and verify input
302 var input map[string]any
303 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
304 require.NoError(t, err)
305 require.Equal(t, "value", input["value"])
306}
307
308// Test result.toolResults extraction (matches TS test exactly)
309func TestAgent_Generate_ResultToolResults(t *testing.T) {
310 t.Parallel()
311
312 // Create type-safe tool using the new API
313 type TestInput struct {
314 Value string `json:"value" description:"Test value"`
315 }
316
317 tool1 := NewAgentTool(
318 "tool1",
319 "Test tool",
320 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
321 require.Equal(t, "value", input.Value)
322 return ToolResponse{Content: "result1", IsError: false}, nil
323 },
324 )
325
326 model := &mockLanguageModel{
327 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
328 // Verify tools and tool choice
329 require.Len(t, call.Tools, 1)
330 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
331
332 // Verify prompt
333 require.Len(t, call.Prompt, 1)
334 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
335
336 return &Response{
337 Content: []Content{
338 ToolCallContent{
339 ToolCallID: "call-1",
340 ToolName: "tool1",
341 Input: `{"value":"value"}`,
342 },
343 },
344 Usage: Usage{
345 InputTokens: 3,
346 OutputTokens: 10,
347 TotalTokens: 13,
348 },
349 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
350 }, nil
351 },
352 }
353
354 agent := NewAgent(model, WithTools(tool1))
355 result, err := agent.Generate(context.Background(), AgentCall{
356 Prompt: "test-input",
357 })
358
359 require.NoError(t, err)
360 require.NotNil(t, result)
361 require.Len(t, result.Steps, 1) // Single step
362
363 // Extract tool results from final response
364 var toolResults []ToolResultContent
365 for _, content := range result.Response.Content {
366 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
367 toolResults = append(toolResults, toolResult)
368 }
369 }
370
371 require.Len(t, toolResults, 1)
372 require.Equal(t, "call-1", toolResults[0].ToolCallID)
373 require.Equal(t, "tool1", toolResults[0].ToolName)
374
375 // Verify result content
376 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
377 require.True(t, ok)
378 require.Equal(t, "result1", textResult.Text)
379}
380
381// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
382func TestAgent_Generate_MultipleSteps(t *testing.T) {
383 t.Parallel()
384
385 // Create type-safe tool using the new API
386 type TestInput struct {
387 Value string `json:"value" description:"Test value"`
388 }
389
390 tool1 := NewAgentTool(
391 "tool1",
392 "Test tool",
393 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
394 require.Equal(t, "value", input.Value)
395 return ToolResponse{Content: "result1", IsError: false}, nil
396 },
397 )
398
399 callCount := 0
400 model := &mockLanguageModel{
401 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
402 callCount++
403 switch callCount {
404 case 1:
405 // First call - return tool call with FinishReasonToolCalls
406 return &Response{
407 Content: []Content{
408 ToolCallContent{
409 ToolCallID: "call-1",
410 ToolName: "tool1",
411 Input: `{"value":"value"}`,
412 },
413 },
414 Usage: Usage{
415 InputTokens: 10,
416 OutputTokens: 5,
417 TotalTokens: 15,
418 },
419 FinishReason: FinishReasonToolCalls, // This triggers multi-step
420 }, nil
421 case 2:
422 // Second call - return final text
423 return &Response{
424 Content: []Content{
425 TextContent{Text: "Hello, world!"},
426 },
427 Usage: Usage{
428 InputTokens: 3,
429 OutputTokens: 10,
430 TotalTokens: 13,
431 },
432 FinishReason: FinishReasonStop,
433 }, nil
434 default:
435 t.Fatalf("Unexpected call count: %d", callCount)
436 return nil, nil
437 }
438 },
439 }
440
441 agent := NewAgent(model, WithTools(tool1))
442 result, err := agent.Generate(context.Background(), AgentCall{
443 Prompt: "test-input",
444 })
445
446 require.NoError(t, err)
447 require.NotNil(t, result)
448 require.Len(t, result.Steps, 2)
449
450 // Check total usage sums both steps
451 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
452 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
453 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
454
455 // Final response should be from last step
456 require.Len(t, result.Response.Content, 1)
457 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
458 require.True(t, ok)
459 require.Equal(t, "Hello, world!", textContent.Text)
460
461 // result.toolCalls should be empty (from last step)
462 var toolCalls []ToolCallContent
463 for _, content := range result.Response.Content {
464 if _, ok := AsContentType[ToolCallContent](content); ok {
465 toolCalls = append(toolCalls, content.(ToolCallContent))
466 }
467 }
468 require.Len(t, toolCalls, 0)
469
470 // result.toolResults should be empty (from last step)
471 var toolResults []ToolResultContent
472 for _, content := range result.Response.Content {
473 if _, ok := AsContentType[ToolResultContent](content); ok {
474 toolResults = append(toolResults, content.(ToolResultContent))
475 }
476 }
477 require.Len(t, toolResults, 0)
478}
479
480// Test basic text generation
481func TestAgent_Generate_BasicText(t *testing.T) {
482 t.Parallel()
483
484 model := &mockLanguageModel{
485 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
486 return &Response{
487 Content: []Content{
488 TextContent{Text: "Hello, world!"},
489 },
490 Usage: Usage{
491 InputTokens: 3,
492 OutputTokens: 10,
493 TotalTokens: 13,
494 },
495 FinishReason: FinishReasonStop,
496 }, nil
497 },
498 }
499
500 agent := NewAgent(model)
501 result, err := agent.Generate(context.Background(), AgentCall{
502 Prompt: "test prompt",
503 })
504
505 require.NoError(t, err)
506 require.NotNil(t, result)
507 require.Len(t, result.Steps, 1)
508
509 // Check final response
510 require.Len(t, result.Response.Content, 1)
511 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
512 require.True(t, ok)
513 require.Equal(t, "Hello, world!", textContent.Text)
514
515 // Check usage
516 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
517 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
518 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
519
520 // Check total usage
521 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
522 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
523 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
524}
525
526// Test empty prompt error
527func TestAgent_Generate_EmptyPrompt(t *testing.T) {
528 t.Parallel()
529
530 model := &mockLanguageModel{}
531 agent := NewAgent(model)
532
533 result, err := agent.Generate(context.Background(), AgentCall{
534 Prompt: "", // Empty prompt should cause error
535 })
536
537 require.Error(t, err)
538 require.Nil(t, result)
539 require.Contains(t, err.Error(), "invalid argument: prompt can't be empty")
540}
541
542// Test with system prompt
543func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
544 t.Parallel()
545
546 model := &mockLanguageModel{
547 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
548 // Verify system message is included
549 require.Len(t, call.Prompt, 2) // system + user
550 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
551 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
552
553 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
554 require.True(t, ok)
555 require.Equal(t, "You are a helpful assistant", systemPart.Text)
556
557 return &Response{
558 Content: []Content{
559 TextContent{Text: "Hello, world!"},
560 },
561 Usage: Usage{
562 InputTokens: 3,
563 OutputTokens: 10,
564 TotalTokens: 13,
565 },
566 FinishReason: FinishReasonStop,
567 }, nil
568 },
569 }
570
571 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
572 result, err := agent.Generate(context.Background(), AgentCall{
573 Prompt: "test prompt",
574 })
575
576 require.NoError(t, err)
577 require.NotNil(t, result)
578}
579
580// Test options.activeTools filtering
581func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
582 t.Parallel()
583
584 tool1 := &mockTool{
585 name: "tool1",
586 description: "Test tool 1",
587 parameters: map[string]any{
588 "value": map[string]any{"type": "string"},
589 },
590 required: []string{"value"},
591 }
592
593 tool2 := &mockTool{
594 name: "tool2",
595 description: "Test tool 2",
596 parameters: map[string]any{
597 "value": map[string]any{"type": "string"},
598 },
599 required: []string{"value"},
600 }
601
602 model := &mockLanguageModel{
603 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
604 // Verify only tool1 is available
605 require.Len(t, call.Tools, 1)
606 functionTool, ok := call.Tools[0].(FunctionTool)
607 require.True(t, ok)
608 require.Equal(t, "tool1", functionTool.Name)
609
610 return &Response{
611 Content: []Content{
612 TextContent{Text: "Hello, world!"},
613 },
614 Usage: Usage{
615 InputTokens: 3,
616 OutputTokens: 10,
617 TotalTokens: 13,
618 },
619 FinishReason: FinishReasonStop,
620 }, nil
621 },
622 }
623
624 agent := NewAgent(model, WithTools(tool1, tool2))
625 result, err := agent.Generate(context.Background(), AgentCall{
626 Prompt: "test-input",
627 ActiveTools: []string{"tool1"}, // Only tool1 should be active
628 })
629
630 require.NoError(t, err)
631 require.NotNil(t, result)
632}
633
634func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
635 t.Parallel()
636
637 tool1 := &mockTool{
638 name: "tool1",
639 description: "Test tool 1",
640 parameters: map[string]any{
641 "value": map[string]any{"type": "string"},
642 },
643 required: []string{"value"},
644 }
645
646 providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
647 providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
648
649 model := &mockLanguageModel{
650 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
651 require.Len(t, call.Tools, 2)
652
653 functionTool, ok := call.Tools[0].(FunctionTool)
654 require.True(t, ok)
655 require.Equal(t, "tool1", functionTool.Name)
656
657 providerTool, ok := call.Tools[1].(ProviderDefinedTool)
658 require.True(t, ok)
659 require.Equal(t, "web_search", providerTool.Name)
660
661 return &Response{
662 Content: []Content{
663 TextContent{Text: "Hello, world!"},
664 },
665 Usage: Usage{
666 InputTokens: 3,
667 OutputTokens: 10,
668 TotalTokens: 13,
669 },
670 FinishReason: FinishReasonStop,
671 }, nil
672 },
673 }
674
675 agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
676 result, err := agent.Generate(context.Background(), AgentCall{
677 Prompt: "test-input",
678 ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
679 })
680
681 require.NoError(t, err)
682 require.NotNil(t, result)
683}
684
685func TestResponseContent_Getters(t *testing.T) {
686 t.Parallel()
687
688 // Create test content with all types
689 content := ResponseContent{
690 TextContent{Text: "Hello world"},
691 ReasoningContent{Text: "Let me think..."},
692 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
693 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
694 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
695 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
696 }
697
698 // Test Text()
699 require.Equal(t, "Hello world", content.Text())
700
701 // Test Reasoning()
702 reasoning := content.Reasoning()
703 require.Len(t, reasoning, 1)
704 require.Equal(t, "Let me think...", reasoning[0].Text)
705
706 // Test ReasoningText()
707 require.Equal(t, "Let me think...", content.ReasoningText())
708
709 // Test Files()
710 files := content.Files()
711 require.Len(t, files, 1)
712 require.Equal(t, "text/plain", files[0].MediaType)
713 require.Equal(t, []byte("file data"), files[0].Data)
714
715 // Test Sources()
716 sources := content.Sources()
717 require.Len(t, sources, 1)
718 require.Equal(t, SourceTypeURL, sources[0].SourceType)
719 require.Equal(t, "https://example.com", sources[0].URL)
720 require.Equal(t, "Example", sources[0].Title)
721
722 // Test ToolCalls()
723 toolCalls := content.ToolCalls()
724 require.Len(t, toolCalls, 1)
725 require.Equal(t, "call1", toolCalls[0].ToolCallID)
726 require.Equal(t, "test_tool", toolCalls[0].ToolName)
727 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
728
729 // Test ToolResults()
730 toolResults := content.ToolResults()
731 require.Len(t, toolResults, 1)
732 require.Equal(t, "call1", toolResults[0].ToolCallID)
733 require.Equal(t, "test_tool", toolResults[0].ToolName)
734 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
735 require.True(t, ok)
736 require.Equal(t, "result", result.Text)
737}
738
739func TestResponseContent_Getters_Empty(t *testing.T) {
740 t.Parallel()
741
742 // Test with empty content
743 content := ResponseContent{}
744
745 require.Equal(t, "", content.Text())
746 require.Equal(t, "", content.ReasoningText())
747 require.Empty(t, content.Reasoning())
748 require.Empty(t, content.Files())
749 require.Empty(t, content.Sources())
750 require.Empty(t, content.ToolCalls())
751 require.Empty(t, content.ToolResults())
752}
753
754func TestResponseContent_Getters_MultipleItems(t *testing.T) {
755 t.Parallel()
756
757 // Test with multiple items of same type
758 content := ResponseContent{
759 ReasoningContent{Text: "First thought"},
760 ReasoningContent{Text: "Second thought"},
761 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
762 FileContent{Data: []byte("file2"), MediaType: "image/png"},
763 }
764
765 // Test multiple reasoning
766 reasoning := content.Reasoning()
767 require.Len(t, reasoning, 2)
768 require.Equal(t, "First thought", reasoning[0].Text)
769 require.Equal(t, "Second thought", reasoning[1].Text)
770
771 // Test concatenated reasoning text
772 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
773
774 // Test multiple files
775 files := content.Files()
776 require.Len(t, files, 2)
777 require.Equal(t, "text/plain", files[0].MediaType)
778 require.Equal(t, "image/png", files[1].MediaType)
779}
780
781func TestStopConditions(t *testing.T) {
782 t.Parallel()
783
784 // Create test steps
785 step1 := StepResult{
786 Response: Response{
787 Content: ResponseContent{
788 TextContent{Text: "Hello"},
789 },
790 FinishReason: FinishReasonToolCalls,
791 Usage: Usage{TotalTokens: 10},
792 },
793 }
794
795 step2 := StepResult{
796 Response: Response{
797 Content: ResponseContent{
798 TextContent{Text: "World"},
799 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
800 },
801 FinishReason: FinishReasonStop,
802 Usage: Usage{TotalTokens: 15},
803 },
804 }
805
806 step3 := StepResult{
807 Response: Response{
808 Content: ResponseContent{
809 ReasoningContent{Text: "Let me think..."},
810 FileContent{Data: []byte("data"), MediaType: "text/plain"},
811 },
812 FinishReason: FinishReasonLength,
813 Usage: Usage{TotalTokens: 20},
814 },
815 }
816
817 t.Run("StepCountIs", func(t *testing.T) {
818 t.Parallel()
819 condition := StepCountIs(2)
820
821 // Should not stop with 1 step
822 require.False(t, condition([]StepResult{step1}))
823
824 // Should stop with 2 steps
825 require.True(t, condition([]StepResult{step1, step2}))
826
827 // Should stop with more than 2 steps
828 require.True(t, condition([]StepResult{step1, step2, step3}))
829
830 // Should not stop with empty steps
831 require.False(t, condition([]StepResult{}))
832 })
833
834 t.Run("HasToolCall", func(t *testing.T) {
835 t.Parallel()
836 condition := HasToolCall("search")
837
838 // Should not stop when tool not called
839 require.False(t, condition([]StepResult{step1}))
840
841 // Should stop when tool is called in last step
842 require.True(t, condition([]StepResult{step1, step2}))
843
844 // Should not stop when tool called in earlier step but not last
845 require.False(t, condition([]StepResult{step1, step2, step3}))
846
847 // Should not stop with empty steps
848 require.False(t, condition([]StepResult{}))
849
850 // Should not stop when different tool is called
851 differentToolCondition := HasToolCall("different_tool")
852 require.False(t, differentToolCondition([]StepResult{step1, step2}))
853 })
854
855 t.Run("HasContent", func(t *testing.T) {
856 t.Parallel()
857 reasoningCondition := HasContent(ContentTypeReasoning)
858 fileCondition := HasContent(ContentTypeFile)
859
860 // Should not stop when content type not present
861 require.False(t, reasoningCondition([]StepResult{step1, step2}))
862
863 // Should stop when content type is present in last step
864 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
865 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
866
867 // Should not stop with empty steps
868 require.False(t, reasoningCondition([]StepResult{}))
869 })
870
871 t.Run("FinishReasonIs", func(t *testing.T) {
872 t.Parallel()
873 stopCondition := FinishReasonIs(FinishReasonStop)
874 lengthCondition := FinishReasonIs(FinishReasonLength)
875
876 // Should not stop when finish reason doesn't match
877 require.False(t, stopCondition([]StepResult{step1}))
878
879 // Should stop when finish reason matches in last step
880 require.True(t, stopCondition([]StepResult{step1, step2}))
881 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
882
883 // Should not stop with empty steps
884 require.False(t, stopCondition([]StepResult{}))
885 })
886
887 t.Run("MaxTokensUsed", func(t *testing.T) {
888 condition := MaxTokensUsed(30)
889
890 // Should not stop when under limit
891 require.False(t, condition([]StepResult{step1})) // 10 tokens
892 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
893
894 // Should stop when at or over limit
895 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
896
897 // Should not stop with empty steps
898 require.False(t, condition([]StepResult{}))
899 })
900}
901
902func TestStopConditions_Integration(t *testing.T) {
903 t.Parallel()
904
905 t.Run("StepCountIs integration", func(t *testing.T) {
906 t.Parallel()
907 model := &mockLanguageModel{
908 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
909 return &Response{
910 Content: ResponseContent{
911 TextContent{Text: "Mock response"},
912 },
913 Usage: Usage{
914 InputTokens: 3,
915 OutputTokens: 10,
916 TotalTokens: 13,
917 },
918 FinishReason: FinishReasonStop,
919 }, nil
920 },
921 }
922
923 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
924
925 result, err := agent.Generate(context.Background(), AgentCall{
926 Prompt: "test prompt",
927 })
928
929 require.NoError(t, err)
930 require.NotNil(t, result)
931 require.Len(t, result.Steps, 1) // Should stop after 1 step
932 })
933
934 t.Run("Multiple stop conditions", func(t *testing.T) {
935 t.Parallel()
936 model := &mockLanguageModel{
937 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
938 return &Response{
939 Content: ResponseContent{
940 TextContent{Text: "Mock response"},
941 },
942 Usage: Usage{
943 InputTokens: 3,
944 OutputTokens: 10,
945 TotalTokens: 13,
946 },
947 FinishReason: FinishReasonStop,
948 }, nil
949 },
950 }
951
952 agent := NewAgent(model, WithStopConditions(
953 StepCountIs(5), // Stop after 5 steps
954 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
955 ))
956
957 result, err := agent.Generate(context.Background(), AgentCall{
958 Prompt: "test prompt",
959 })
960
961 require.NoError(t, err)
962 require.NotNil(t, result)
963 // Should stop on first condition met (finish reason stop)
964 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
965 })
966}
967
968func TestPrepareStep(t *testing.T) {
969 t.Parallel()
970
971 t.Run("System prompt modification", func(t *testing.T) {
972 t.Parallel()
973 var capturedSystemPrompt string
974 model := &mockLanguageModel{
975 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
976 // Capture the system message to verify it was modified
977 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
978 if len(call.Prompt[0].Content) > 0 {
979 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
980 capturedSystemPrompt = textPart.Text
981 }
982 }
983 }
984 return &Response{
985 Content: ResponseContent{
986 TextContent{Text: "Response"},
987 },
988 Usage: Usage{TotalTokens: 10},
989 FinishReason: FinishReasonStop,
990 }, nil
991 },
992 }
993
994 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
995 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
996 return ctx, PrepareStepResult{
997 Model: options.Model,
998 Messages: options.Messages,
999 System: &newSystem,
1000 }, nil
1001 }
1002
1003 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
1004
1005 result, err := agent.Generate(context.Background(), AgentCall{
1006 Prompt: "test prompt",
1007 PrepareStep: prepareStepFunc,
1008 })
1009
1010 require.NoError(t, err)
1011 require.NotNil(t, result)
1012 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
1013 })
1014
1015 t.Run("Tool choice modification", func(t *testing.T) {
1016 t.Parallel()
1017 var capturedToolChoice *ToolChoice
1018 model := &mockLanguageModel{
1019 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1020 capturedToolChoice = call.ToolChoice
1021 return &Response{
1022 Content: ResponseContent{
1023 TextContent{Text: "Response"},
1024 },
1025 Usage: Usage{TotalTokens: 10},
1026 FinishReason: FinishReasonStop,
1027 }, nil
1028 },
1029 }
1030
1031 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1032 toolChoice := ToolChoiceNone
1033 return ctx, PrepareStepResult{
1034 Model: options.Model,
1035 Messages: options.Messages,
1036 ToolChoice: &toolChoice,
1037 }, nil
1038 }
1039
1040 agent := NewAgent(model)
1041
1042 result, err := agent.Generate(context.Background(), AgentCall{
1043 Prompt: "test prompt",
1044 PrepareStep: prepareStepFunc,
1045 })
1046
1047 require.NoError(t, err)
1048 require.NotNil(t, result)
1049 require.NotNil(t, capturedToolChoice)
1050 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1051 })
1052
1053 t.Run("Active tools modification", func(t *testing.T) {
1054 t.Parallel()
1055 var capturedToolNames []string
1056 model := &mockLanguageModel{
1057 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1058 // Capture tool names to verify active tools were modified
1059 for _, tool := range call.Tools {
1060 capturedToolNames = append(capturedToolNames, tool.GetName())
1061 }
1062 return &Response{
1063 Content: ResponseContent{
1064 TextContent{Text: "Response"},
1065 },
1066 Usage: Usage{TotalTokens: 10},
1067 FinishReason: FinishReasonStop,
1068 }, nil
1069 },
1070 }
1071
1072 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1073 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1074 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1075
1076 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1077 activeTools := []string{"tool2"} // Only tool2 should be active
1078 return ctx, PrepareStepResult{
1079 Model: options.Model,
1080 Messages: options.Messages,
1081 ActiveTools: activeTools,
1082 }, nil
1083 }
1084
1085 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1086
1087 result, err := agent.Generate(context.Background(), AgentCall{
1088 Prompt: "test prompt",
1089 PrepareStep: prepareStepFunc,
1090 })
1091
1092 require.NoError(t, err)
1093 require.NotNil(t, result)
1094 require.Len(t, capturedToolNames, 1)
1095 require.Equal(t, "tool2", capturedToolNames[0])
1096 })
1097
1098 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1099 t.Parallel()
1100 var capturedToolCount int
1101 model := &mockLanguageModel{
1102 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1103 capturedToolCount = len(call.Tools)
1104 return &Response{
1105 Content: ResponseContent{
1106 TextContent{Text: "Response"},
1107 },
1108 Usage: Usage{TotalTokens: 10},
1109 FinishReason: FinishReasonStop,
1110 }, nil
1111 },
1112 }
1113
1114 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1115
1116 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1117 return ctx, PrepareStepResult{
1118 Model: options.Model,
1119 Messages: options.Messages,
1120 DisableAllTools: true, // Disable all tools for this step
1121 }, nil
1122 }
1123
1124 agent := NewAgent(model, WithTools(tool1))
1125
1126 result, err := agent.Generate(context.Background(), AgentCall{
1127 Prompt: "test prompt",
1128 PrepareStep: prepareStepFunc,
1129 })
1130
1131 require.NoError(t, err)
1132 require.NotNil(t, result)
1133 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1134 })
1135
1136 t.Run("All fields modified together", func(t *testing.T) {
1137 t.Parallel()
1138 var capturedSystemPrompt string
1139 var capturedToolChoice *ToolChoice
1140 var capturedToolNames []string
1141
1142 model := &mockLanguageModel{
1143 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1144 // Capture system prompt
1145 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1146 if len(call.Prompt[0].Content) > 0 {
1147 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1148 capturedSystemPrompt = textPart.Text
1149 }
1150 }
1151 }
1152 // Capture tool choice
1153 capturedToolChoice = call.ToolChoice
1154 // Capture tool names
1155 for _, tool := range call.Tools {
1156 capturedToolNames = append(capturedToolNames, tool.GetName())
1157 }
1158 return &Response{
1159 Content: ResponseContent{
1160 TextContent{Text: "Response"},
1161 },
1162 Usage: Usage{TotalTokens: 10},
1163 FinishReason: FinishReasonStop,
1164 }, nil
1165 },
1166 }
1167
1168 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1169 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1170
1171 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1172 newSystem := "Step-specific system"
1173 toolChoice := SpecificToolChoice("tool1")
1174 activeTools := []string{"tool1"}
1175 return ctx, PrepareStepResult{
1176 Model: options.Model,
1177 Messages: options.Messages,
1178 System: &newSystem,
1179 ToolChoice: &toolChoice,
1180 ActiveTools: activeTools,
1181 }, nil
1182 }
1183
1184 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1185
1186 result, err := agent.Generate(context.Background(), AgentCall{
1187 Prompt: "test prompt",
1188 PrepareStep: prepareStepFunc,
1189 })
1190
1191 require.NoError(t, err)
1192 require.NotNil(t, result)
1193 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1194 require.NotNil(t, capturedToolChoice)
1195 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1196 require.Len(t, capturedToolNames, 1)
1197 require.Equal(t, "tool1", capturedToolNames[0])
1198 })
1199
1200 t.Run("Nil fields use parent values", func(t *testing.T) {
1201 t.Parallel()
1202 var capturedSystemPrompt string
1203 var capturedToolChoice *ToolChoice
1204 var capturedToolNames []string
1205
1206 model := &mockLanguageModel{
1207 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1208 // Capture system prompt
1209 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1210 if len(call.Prompt[0].Content) > 0 {
1211 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1212 capturedSystemPrompt = textPart.Text
1213 }
1214 }
1215 }
1216 // Capture tool choice
1217 capturedToolChoice = call.ToolChoice
1218 // Capture tool names
1219 for _, tool := range call.Tools {
1220 capturedToolNames = append(capturedToolNames, tool.GetName())
1221 }
1222 return &Response{
1223 Content: ResponseContent{
1224 TextContent{Text: "Response"},
1225 },
1226 Usage: Usage{TotalTokens: 10},
1227 FinishReason: FinishReasonStop,
1228 }, nil
1229 },
1230 }
1231
1232 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1233
1234 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1235 // All optional fields are nil, should use parent values
1236 return ctx, PrepareStepResult{
1237 Model: options.Model,
1238 Messages: options.Messages,
1239 System: nil, // Use parent
1240 ToolChoice: nil, // Use parent (auto)
1241 ActiveTools: nil, // Use parent (all tools)
1242 }, nil
1243 }
1244
1245 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1246
1247 result, err := agent.Generate(context.Background(), AgentCall{
1248 Prompt: "test prompt",
1249 PrepareStep: prepareStepFunc,
1250 })
1251
1252 require.NoError(t, err)
1253 require.NotNil(t, result)
1254 require.Equal(t, "Parent system", capturedSystemPrompt)
1255 require.NotNil(t, capturedToolChoice)
1256 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1257 require.Len(t, capturedToolNames, 1)
1258 require.Equal(t, "tool1", capturedToolNames[0])
1259 })
1260
1261 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1262 t.Parallel()
1263 var capturedToolNames []string
1264 model := &mockLanguageModel{
1265 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1266 // Capture tool names to verify all tools are included
1267 for _, tool := range call.Tools {
1268 capturedToolNames = append(capturedToolNames, tool.GetName())
1269 }
1270 return &Response{
1271 Content: ResponseContent{
1272 TextContent{Text: "Response"},
1273 },
1274 Usage: Usage{TotalTokens: 10},
1275 FinishReason: FinishReasonStop,
1276 }, nil
1277 },
1278 }
1279
1280 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1281 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1282
1283 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1284 return ctx, PrepareStepResult{
1285 Model: options.Model,
1286 Messages: options.Messages,
1287 ActiveTools: []string{}, // Empty slice means all tools
1288 }, nil
1289 }
1290
1291 agent := NewAgent(model, WithTools(tool1, tool2))
1292
1293 result, err := agent.Generate(context.Background(), AgentCall{
1294 Prompt: "test prompt",
1295 PrepareStep: prepareStepFunc,
1296 })
1297
1298 require.NoError(t, err)
1299 require.NotNil(t, result)
1300 require.Len(t, capturedToolNames, 2) // All tools should be included
1301 require.Contains(t, capturedToolNames, "tool1")
1302 require.Contains(t, capturedToolNames, "tool2")
1303 })
1304}
1305
1306func TestToolCallRepair(t *testing.T) {
1307 t.Parallel()
1308
1309 t.Run("Valid tool call passes validation", func(t *testing.T) {
1310 t.Parallel()
1311 model := &mockLanguageModel{
1312 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1313 return &Response{
1314 Content: ResponseContent{
1315 TextContent{Text: "Response"},
1316 ToolCallContent{
1317 ToolCallID: "call1",
1318 ToolName: "test_tool",
1319 Input: `{"value": "test"}`, // Valid JSON with required field
1320 },
1321 },
1322 Usage: Usage{TotalTokens: 10},
1323 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1324 }, nil
1325 },
1326 }
1327
1328 tool := &mockTool{
1329 name: "test_tool",
1330 description: "Test tool",
1331 parameters: map[string]any{
1332 "value": map[string]any{"type": "string"},
1333 },
1334 required: []string{"value"},
1335 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1336 return ToolResponse{Content: "success", IsError: false}, nil
1337 },
1338 }
1339
1340 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1341
1342 result, err := agent.Generate(context.Background(), AgentCall{
1343 Prompt: "test prompt",
1344 })
1345
1346 require.NoError(t, err)
1347 require.NotNil(t, result)
1348 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1349
1350 // Check that tool call was executed successfully
1351 toolCalls := result.Steps[0].Content.ToolCalls()
1352 require.Len(t, toolCalls, 1)
1353 require.False(t, toolCalls[0].Invalid) // Should be valid
1354 })
1355
1356 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1357 t.Parallel()
1358 model := &mockLanguageModel{
1359 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1360 return &Response{
1361 Content: ResponseContent{
1362 TextContent{Text: "Response"},
1363 ToolCallContent{
1364 ToolCallID: "call1",
1365 ToolName: "test_tool",
1366 Input: `{"wrong_field": "test"}`, // Missing required field
1367 },
1368 },
1369 Usage: Usage{TotalTokens: 10},
1370 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1371 }, nil
1372 },
1373 }
1374
1375 tool := &mockTool{
1376 name: "test_tool",
1377 description: "Test tool",
1378 parameters: map[string]any{
1379 "value": map[string]any{"type": "string"},
1380 },
1381 required: []string{"value"},
1382 }
1383
1384 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1385
1386 result, err := agent.Generate(context.Background(), AgentCall{
1387 Prompt: "test prompt",
1388 })
1389
1390 require.NoError(t, err)
1391 require.NotNil(t, result)
1392 require.Len(t, result.Steps, 1) // Only one step
1393
1394 // Check that tool call was marked as invalid
1395 toolCalls := result.Steps[0].Content.ToolCalls()
1396 require.Len(t, toolCalls, 1)
1397 require.True(t, toolCalls[0].Invalid) // Should be invalid
1398 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1399 })
1400
1401 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1402 t.Parallel()
1403 model := &mockLanguageModel{
1404 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1405 return &Response{
1406 Content: ResponseContent{
1407 TextContent{Text: "Response"},
1408 ToolCallContent{
1409 ToolCallID: "call1",
1410 ToolName: "test_tool",
1411 Input: `{"wrong_field": "test"}`, // Missing required field
1412 },
1413 },
1414 Usage: Usage{TotalTokens: 10},
1415 FinishReason: FinishReasonStop, // Changed to stop
1416 }, nil
1417 },
1418 }
1419
1420 tool := &mockTool{
1421 name: "test_tool",
1422 description: "Test tool",
1423 parameters: map[string]any{
1424 "value": map[string]any{"type": "string"},
1425 },
1426 required: []string{"value"},
1427 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1428 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1429 },
1430 }
1431
1432 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1433 // Simple repair: add the missing required field
1434 repairedToolCall := options.OriginalToolCall
1435 repairedToolCall.Input = `{"value": "repaired"}`
1436 return &repairedToolCall, nil
1437 }
1438
1439 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1440
1441 result, err := agent.Generate(context.Background(), AgentCall{
1442 Prompt: "test prompt",
1443 })
1444
1445 require.NoError(t, err)
1446 require.NotNil(t, result)
1447 require.Len(t, result.Steps, 1) // Only one step
1448
1449 // Check that tool call was repaired and is now valid
1450 toolCalls := result.Steps[0].Content.ToolCalls()
1451 require.Len(t, toolCalls, 1)
1452 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1453 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1454 })
1455
1456 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1457 t.Parallel()
1458 model := &mockLanguageModel{
1459 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1460 return &Response{
1461 Content: ResponseContent{
1462 TextContent{Text: "Response"},
1463 ToolCallContent{
1464 ToolCallID: "call1",
1465 ToolName: "test_tool",
1466 Input: `{"wrong_field": "test"}`, // Missing required field
1467 },
1468 },
1469 Usage: Usage{TotalTokens: 10},
1470 FinishReason: FinishReasonStop, // Changed to stop
1471 }, nil
1472 },
1473 }
1474
1475 tool := &mockTool{
1476 name: "test_tool",
1477 description: "Test tool",
1478 parameters: map[string]any{
1479 "value": map[string]any{"type": "string"},
1480 },
1481 required: []string{"value"},
1482 }
1483
1484 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1485 // Repair function fails
1486 return nil, errors.New("repair failed")
1487 }
1488
1489 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1490
1491 result, err := agent.Generate(context.Background(), AgentCall{
1492 Prompt: "test prompt",
1493 })
1494
1495 require.NoError(t, err)
1496 require.NotNil(t, result)
1497 require.Len(t, result.Steps, 1) // Only one step
1498
1499 // Check that tool call was marked as invalid since repair failed
1500 toolCalls := result.Steps[0].Content.ToolCalls()
1501 require.Len(t, toolCalls, 1)
1502 require.True(t, toolCalls[0].Invalid) // Should be invalid
1503 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1504 })
1505
1506 t.Run("Nonexistent tool call", func(t *testing.T) {
1507 t.Parallel()
1508 model := &mockLanguageModel{
1509 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1510 return &Response{
1511 Content: ResponseContent{
1512 TextContent{Text: "Response"},
1513 ToolCallContent{
1514 ToolCallID: "call1",
1515 ToolName: "nonexistent_tool",
1516 Input: `{"value": "test"}`,
1517 },
1518 },
1519 Usage: Usage{TotalTokens: 10},
1520 FinishReason: FinishReasonStop, // Changed to stop
1521 }, nil
1522 },
1523 }
1524
1525 tool := &mockTool{name: "test_tool", description: "Test tool"}
1526
1527 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1528
1529 result, err := agent.Generate(context.Background(), AgentCall{
1530 Prompt: "test prompt",
1531 })
1532
1533 require.NoError(t, err)
1534 require.NotNil(t, result)
1535 require.Len(t, result.Steps, 1) // Only one step
1536
1537 // Check that tool call was marked as invalid due to nonexistent tool
1538 toolCalls := result.Steps[0].Content.ToolCalls()
1539 require.Len(t, toolCalls, 1)
1540 require.True(t, toolCalls[0].Invalid) // Should be invalid
1541 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1542 })
1543
1544 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1545 t.Parallel()
1546 model := &mockLanguageModel{
1547 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1548 return &Response{
1549 Content: ResponseContent{
1550 TextContent{Text: "Response"},
1551 ToolCallContent{
1552 ToolCallID: "call1",
1553 ToolName: "test_tool",
1554 Input: `{invalid json}`, // Invalid JSON
1555 },
1556 },
1557 Usage: Usage{TotalTokens: 10},
1558 FinishReason: FinishReasonStop, // Changed to stop
1559 }, nil
1560 },
1561 }
1562
1563 tool := &mockTool{
1564 name: "test_tool",
1565 description: "Test tool",
1566 parameters: map[string]any{
1567 "value": map[string]any{"type": "string"},
1568 },
1569 required: []string{"value"},
1570 }
1571
1572 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1573
1574 result, err := agent.Generate(context.Background(), AgentCall{
1575 Prompt: "test prompt",
1576 })
1577
1578 require.NoError(t, err)
1579 require.NotNil(t, result)
1580 require.Len(t, result.Steps, 1) // Only one step
1581
1582 // Check that tool call was marked as invalid due to invalid JSON
1583 toolCalls := result.Steps[0].Content.ToolCalls()
1584 require.Len(t, toolCalls, 1)
1585 require.True(t, toolCalls[0].Invalid) // Should be invalid
1586 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1587 })
1588}
1589
1590// Test media and image tool responses
1591func TestAgent_MediaToolResponses(t *testing.T) {
1592 t.Parallel()
1593
1594 imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1595 audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1596
1597 t.Run("Image tool response", func(t *testing.T) {
1598 t.Parallel()
1599
1600 imageTool := &mockTool{
1601 name: "generate_image",
1602 description: "Generates an image",
1603 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1604 return NewImageResponse(imageData, "image/png"), nil
1605 },
1606 }
1607
1608 model := &mockLanguageModel{
1609 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1610 if len(call.Prompt) == 1 {
1611 // First call - request image tool
1612 return &Response{
1613 Content: []Content{
1614 ToolCallContent{
1615 ToolCallID: "img-1",
1616 ToolName: "generate_image",
1617 Input: `{}`,
1618 },
1619 },
1620 Usage: Usage{TotalTokens: 10},
1621 FinishReason: FinishReasonToolCalls,
1622 }, nil
1623 }
1624 // Second call - after tool execution
1625 return &Response{
1626 Content: []Content{TextContent{Text: "Image generated"}},
1627 Usage: Usage{TotalTokens: 20},
1628 FinishReason: FinishReasonStop,
1629 }, nil
1630 },
1631 }
1632
1633 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1634
1635 result, err := agent.Generate(context.Background(), AgentCall{
1636 Prompt: "Generate an image",
1637 })
1638
1639 require.NoError(t, err)
1640 require.NotNil(t, result)
1641 require.Len(t, result.Steps, 2) // Tool call step + final response
1642
1643 // Check tool results in first step
1644 toolResults := result.Steps[0].Content.ToolResults()
1645 require.Len(t, toolResults, 1)
1646
1647 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1648 require.True(t, ok, "Expected media result")
1649 require.Equal(t, string(imageData), mediaResult.Data)
1650 require.Equal(t, "image/png", mediaResult.MediaType)
1651 })
1652
1653 t.Run("Media tool response (audio)", func(t *testing.T) {
1654 t.Parallel()
1655
1656 audioTool := &mockTool{
1657 name: "generate_audio",
1658 description: "Generates audio",
1659 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1660 return NewMediaResponse(audioData, "audio/wav"), nil
1661 },
1662 }
1663
1664 model := &mockLanguageModel{
1665 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1666 if len(call.Prompt) == 1 {
1667 return &Response{
1668 Content: []Content{
1669 ToolCallContent{
1670 ToolCallID: "audio-1",
1671 ToolName: "generate_audio",
1672 Input: `{}`,
1673 },
1674 },
1675 Usage: Usage{TotalTokens: 10},
1676 FinishReason: FinishReasonToolCalls,
1677 }, nil
1678 }
1679 return &Response{
1680 Content: []Content{TextContent{Text: "Audio generated"}},
1681 Usage: Usage{TotalTokens: 20},
1682 FinishReason: FinishReasonStop,
1683 }, nil
1684 },
1685 }
1686
1687 agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1688
1689 result, err := agent.Generate(context.Background(), AgentCall{
1690 Prompt: "Generate audio",
1691 })
1692
1693 require.NoError(t, err)
1694 require.NotNil(t, result)
1695
1696 toolResults := result.Steps[0].Content.ToolResults()
1697 require.Len(t, toolResults, 1)
1698
1699 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1700 require.True(t, ok, "Expected media result")
1701 require.Equal(t, string(audioData), mediaResult.Data)
1702 require.Equal(t, "audio/wav", mediaResult.MediaType)
1703 })
1704
1705 t.Run("Media response with text", func(t *testing.T) {
1706 t.Parallel()
1707
1708 imageTool := &mockTool{
1709 name: "screenshot",
1710 description: "Takes a screenshot",
1711 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1712 resp := NewImageResponse(imageData, "image/png")
1713 resp.Content = "Screenshot captured successfully"
1714 return resp, nil
1715 },
1716 }
1717
1718 model := &mockLanguageModel{
1719 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1720 if len(call.Prompt) == 1 {
1721 return &Response{
1722 Content: []Content{
1723 ToolCallContent{
1724 ToolCallID: "screen-1",
1725 ToolName: "screenshot",
1726 Input: `{}`,
1727 },
1728 },
1729 Usage: Usage{TotalTokens: 10},
1730 FinishReason: FinishReasonToolCalls,
1731 }, nil
1732 }
1733 return &Response{
1734 Content: []Content{TextContent{Text: "Done"}},
1735 Usage: Usage{TotalTokens: 20},
1736 FinishReason: FinishReasonStop,
1737 }, nil
1738 },
1739 }
1740
1741 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1742
1743 result, err := agent.Generate(context.Background(), AgentCall{
1744 Prompt: "Take a screenshot",
1745 })
1746
1747 require.NoError(t, err)
1748 require.NotNil(t, result)
1749
1750 toolResults := result.Steps[0].Content.ToolResults()
1751 require.Len(t, toolResults, 1)
1752
1753 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1754 require.True(t, ok, "Expected media result")
1755 require.Equal(t, string(imageData), mediaResult.Data)
1756 require.Equal(t, "image/png", mediaResult.MediaType)
1757 require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1758 })
1759
1760 t.Run("Media response preserves metadata", func(t *testing.T) {
1761 t.Parallel()
1762
1763 type ImageMetadata struct {
1764 Width int `json:"width"`
1765 Height int `json:"height"`
1766 }
1767
1768 imageTool := &mockTool{
1769 name: "generate_image",
1770 description: "Generates an image",
1771 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1772 resp := NewImageResponse(imageData, "image/png")
1773 return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1774 },
1775 }
1776
1777 model := &mockLanguageModel{
1778 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1779 if len(call.Prompt) == 1 {
1780 return &Response{
1781 Content: []Content{
1782 ToolCallContent{
1783 ToolCallID: "img-1",
1784 ToolName: "generate_image",
1785 Input: `{}`,
1786 },
1787 },
1788 Usage: Usage{TotalTokens: 10},
1789 FinishReason: FinishReasonToolCalls,
1790 }, nil
1791 }
1792 return &Response{
1793 Content: []Content{TextContent{Text: "Done"}},
1794 Usage: Usage{TotalTokens: 20},
1795 FinishReason: FinishReasonStop,
1796 }, nil
1797 },
1798 }
1799
1800 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1801
1802 result, err := agent.Generate(context.Background(), AgentCall{
1803 Prompt: "Generate image",
1804 })
1805
1806 require.NoError(t, err)
1807 require.NotNil(t, result)
1808
1809 toolResults := result.Steps[0].Content.ToolResults()
1810 require.Len(t, toolResults, 1)
1811
1812 // Check metadata was preserved
1813 require.NotEmpty(t, toolResults[0].ClientMetadata)
1814
1815 var metadata ImageMetadata
1816 err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1817 require.NoError(t, err)
1818 require.Equal(t, 800, metadata.Width)
1819 require.Equal(t, 600, metadata.Height)
1820 })
1821}
1822
1823func TestToResponseMessages_ProviderExecutedRouting(t *testing.T) {
1824 t.Parallel()
1825
1826 // Build step content that mixes a provider-executed tool call/result
1827 // (e.g. web search) with a regular local tool call/result.
1828 content := []Content{
1829 // Provider-executed tool call.
1830 &ToolCallContent{
1831 ToolCallID: "srvtoolu_01",
1832 ToolName: "web_search",
1833 Input: `{"query":"test"}`,
1834 ProviderExecuted: true,
1835 },
1836 // Provider-executed tool result.
1837 &ToolResultContent{
1838 ToolCallID: "srvtoolu_01",
1839 ProviderExecuted: true,
1840 },
1841 // Regular (locally-executed) tool call.
1842 &ToolCallContent{
1843 ToolCallID: "toolu_02",
1844 ToolName: "calculator",
1845 Input: `{"expr":"1+1"}`,
1846 },
1847 // Regular tool result.
1848 &ToolResultContent{
1849 ToolCallID: "toolu_02",
1850 Result: ToolResultOutputContentText{Text: "2"},
1851 },
1852 // Some trailing text.
1853 &TextContent{Text: "Done."},
1854 }
1855
1856 msgs := toResponseMessages(content)
1857
1858 // Expect two messages: assistant + tool.
1859 require.Len(t, msgs, 2)
1860
1861 // Assistant message should contain:
1862 // 1. provider-executed ToolCallPart
1863 // 2. provider-executed ToolResultPart
1864 // 3. regular ToolCallPart
1865 // 4. TextPart
1866 assistant := msgs[0]
1867 require.Equal(t, MessageRoleAssistant, assistant.Role)
1868 require.Len(t, assistant.Content, 4)
1869
1870 // Verify provider-executed tool call is in assistant.
1871 tc1, ok := AsMessagePart[ToolCallPart](assistant.Content[0])
1872 require.True(t, ok)
1873 require.Equal(t, "srvtoolu_01", tc1.ToolCallID)
1874 require.True(t, tc1.ProviderExecuted)
1875
1876 // Verify provider-executed tool result is in assistant.
1877 tr1, ok := AsMessagePart[ToolResultPart](assistant.Content[1])
1878 require.True(t, ok)
1879 require.Equal(t, "srvtoolu_01", tr1.ToolCallID)
1880 require.True(t, tr1.ProviderExecuted)
1881
1882 // Verify regular tool call is in assistant.
1883 tc2, ok := AsMessagePart[ToolCallPart](assistant.Content[2])
1884 require.True(t, ok)
1885 require.Equal(t, "toolu_02", tc2.ToolCallID)
1886 require.False(t, tc2.ProviderExecuted)
1887
1888 // Verify text part is in assistant.
1889 text, ok := AsMessagePart[TextPart](assistant.Content[3])
1890 require.True(t, ok)
1891 require.Equal(t, "Done.", text.Text)
1892
1893 // Tool message should contain only the regular tool result.
1894 toolMsg := msgs[1]
1895 require.Equal(t, MessageRoleTool, toolMsg.Role)
1896 require.Len(t, toolMsg.Content, 1)
1897
1898 tr2, ok := AsMessagePart[ToolResultPart](toolMsg.Content[0])
1899 require.True(t, ok)
1900 require.Equal(t, "toolu_02", tr2.ToolCallID)
1901 require.False(t, tr2.ProviderExecuted)
1902}