1package fantasy
2
3import (
4 "context"
5 "encoding/base64"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "testing"
10
11 "github.com/stretchr/testify/require"
12)
13
14// Mock tool for testing
15type mockTool struct {
16 name string
17 providerOptions ProviderOptions
18 description string
19 parameters map[string]any
20 required []string
21 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
22}
23
24func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
25 m.providerOptions = opts
26}
27
28func (m *mockTool) ProviderOptions() ProviderOptions {
29 return m.providerOptions
30}
31
32func (m *mockTool) Info() ToolInfo {
33 return ToolInfo{
34 Name: m.name,
35 Description: m.description,
36 Parameters: m.parameters,
37 Required: m.required,
38 }
39}
40
41func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
42 if m.executeFunc != nil {
43 return m.executeFunc(ctx, call)
44 }
45 return ToolResponse{Content: "mock result", IsError: false}, nil
46}
47
48// Mock language model for testing
49type mockLanguageModel struct {
50 generateFunc func(ctx context.Context, call Call) (*Response, error)
51 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
52}
53
54func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
55 if m.generateFunc != nil {
56 return m.generateFunc(ctx, call)
57 }
58 return &Response{
59 Content: []Content{
60 TextContent{Text: "Hello, world!"},
61 },
62 Usage: Usage{
63 InputTokens: 3,
64 OutputTokens: 10,
65 TotalTokens: 13,
66 },
67 FinishReason: FinishReasonStop,
68 }, nil
69}
70
71func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
72 if m.streamFunc != nil {
73 return m.streamFunc(ctx, call)
74 }
75 return nil, fmt.Errorf("mock stream not implemented")
76}
77
78func (m *mockLanguageModel) Provider() string {
79 return "mock-provider"
80}
81
82func (m *mockLanguageModel) Model() string {
83 return "mock-model"
84}
85
86func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
87 return nil, fmt.Errorf("mock GenerateObject not implemented")
88}
89
90func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
91 return nil, fmt.Errorf("mock StreamObject not implemented")
92}
93
94// Test result.content - comprehensive content types (matches TS test)
95func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
96 t.Parallel()
97
98 // Create a type-safe tool using the new API
99 type TestInput struct {
100 Value string `json:"value" description:"Test value"`
101 }
102
103 tool1 := NewAgentTool(
104 "tool1",
105 "Test tool",
106 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
107 require.Equal(t, "value", input.Value)
108 return ToolResponse{Content: "result1", IsError: false}, nil
109 },
110 )
111
112 model := &mockLanguageModel{
113 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
114 return &Response{
115 Content: []Content{
116 TextContent{Text: "Hello, world!"},
117 SourceContent{
118 ID: "123",
119 URL: "https://example.com",
120 Title: "Example",
121 SourceType: SourceTypeURL,
122 },
123 FileContent{
124 Data: []byte{1, 2, 3},
125 MediaType: "image/png",
126 },
127 ReasoningContent{
128 Text: "I will open the conversation with witty banter.",
129 },
130 ToolCallContent{
131 ToolCallID: "call-1",
132 ToolName: "tool1",
133 Input: `{"value":"value"}`,
134 },
135 TextContent{Text: "More text"},
136 },
137 Usage: Usage{
138 InputTokens: 3,
139 OutputTokens: 10,
140 TotalTokens: 13,
141 },
142 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
143 }, nil
144 },
145 }
146
147 agent := NewAgent(model, WithTools(tool1))
148 result, err := agent.Generate(context.Background(), AgentCall{
149 Prompt: "prompt",
150 })
151
152 require.NoError(t, err)
153 require.NotNil(t, result)
154 require.Len(t, result.Steps, 1) // Single step like TypeScript
155
156 // Check final response content includes tool result
157 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
158
159 // Verify each content type in order
160 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
161 require.True(t, ok)
162 require.Equal(t, "Hello, world!", textContent.Text)
163
164 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
165 require.True(t, ok)
166 require.Equal(t, "123", sourceContent.ID)
167
168 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
169 require.True(t, ok)
170 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
171
172 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
173 require.True(t, ok)
174 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
175
176 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
177 require.True(t, ok)
178 require.Equal(t, "call-1", toolCallContent.ToolCallID)
179
180 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
181 require.True(t, ok)
182 require.Equal(t, "More text", moreTextContent.Text)
183
184 // Tool result should be appended
185 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
186 require.True(t, ok)
187 require.Equal(t, "call-1", toolResultContent.ToolCallID)
188 require.Equal(t, "tool1", toolResultContent.ToolName)
189}
190
191// Test result.text extraction
192func TestAgent_Generate_ResultText(t *testing.T) {
193 t.Parallel()
194
195 model := &mockLanguageModel{
196 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
197 return &Response{
198 Content: []Content{
199 TextContent{Text: "Hello, world!"},
200 },
201 Usage: Usage{
202 InputTokens: 3,
203 OutputTokens: 10,
204 TotalTokens: 13,
205 },
206 FinishReason: FinishReasonStop,
207 }, nil
208 },
209 }
210
211 agent := NewAgent(model)
212 result, err := agent.Generate(context.Background(), AgentCall{
213 Prompt: "prompt",
214 })
215
216 require.NoError(t, err)
217 require.NotNil(t, result)
218
219 // Test text extraction from content
220 text := result.Response.Content.Text()
221 require.Equal(t, "Hello, world!", text)
222}
223
224// Test result.toolCalls extraction (matches TS test exactly)
225func TestAgent_Generate_ResultToolCalls(t *testing.T) {
226 t.Parallel()
227
228 // Create type-safe tools using the new API
229 type Tool1Input struct {
230 Value string `json:"value" description:"Test value"`
231 }
232
233 type Tool2Input struct {
234 SomethingElse string `json:"somethingElse" description:"Another test value"`
235 }
236
237 tool1 := NewAgentTool(
238 "tool1",
239 "Test tool 1",
240 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
241 return ToolResponse{Content: "result1", IsError: false}, nil
242 },
243 )
244
245 tool2 := NewAgentTool(
246 "tool2",
247 "Test tool 2",
248 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
249 return ToolResponse{Content: "result2", IsError: false}, nil
250 },
251 )
252
253 model := &mockLanguageModel{
254 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
255 // Verify tools are passed correctly
256 require.Len(t, call.Tools, 2)
257 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
258
259 // Verify prompt structure
260 require.Len(t, call.Prompt, 1)
261 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
262
263 return &Response{
264 Content: []Content{
265 ToolCallContent{
266 ToolCallID: "call-1",
267 ToolName: "tool1",
268 Input: `{"value":"value"}`,
269 },
270 },
271 Usage: Usage{
272 InputTokens: 3,
273 OutputTokens: 10,
274 TotalTokens: 13,
275 },
276 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
277 }, nil
278 },
279 }
280
281 agent := NewAgent(model, WithTools(tool1, tool2))
282 result, err := agent.Generate(context.Background(), AgentCall{
283 Prompt: "test-input",
284 })
285
286 require.NoError(t, err)
287 require.NotNil(t, result)
288 require.Len(t, result.Steps, 1) // Single step
289
290 // Extract tool calls from final response (should be empty since tools don't execute)
291 var toolCalls []ToolCallContent
292 for _, content := range result.Response.Content {
293 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
294 toolCalls = append(toolCalls, toolCall)
295 }
296 }
297
298 require.Len(t, toolCalls, 1)
299 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
300 require.Equal(t, "tool1", toolCalls[0].ToolName)
301
302 // Parse and verify input
303 var input map[string]any
304 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
305 require.NoError(t, err)
306 require.Equal(t, "value", input["value"])
307}
308
309// Test result.toolResults extraction (matches TS test exactly)
310func TestAgent_Generate_ResultToolResults(t *testing.T) {
311 t.Parallel()
312
313 // Create type-safe tool using the new API
314 type TestInput struct {
315 Value string `json:"value" description:"Test value"`
316 }
317
318 tool1 := NewAgentTool(
319 "tool1",
320 "Test tool",
321 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
322 require.Equal(t, "value", input.Value)
323 return ToolResponse{Content: "result1", IsError: false}, nil
324 },
325 )
326
327 model := &mockLanguageModel{
328 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
329 // Verify tools and tool choice
330 require.Len(t, call.Tools, 1)
331 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
332
333 // Verify prompt
334 require.Len(t, call.Prompt, 1)
335 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
336
337 return &Response{
338 Content: []Content{
339 ToolCallContent{
340 ToolCallID: "call-1",
341 ToolName: "tool1",
342 Input: `{"value":"value"}`,
343 },
344 },
345 Usage: Usage{
346 InputTokens: 3,
347 OutputTokens: 10,
348 TotalTokens: 13,
349 },
350 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
351 }, nil
352 },
353 }
354
355 agent := NewAgent(model, WithTools(tool1))
356 result, err := agent.Generate(context.Background(), AgentCall{
357 Prompt: "test-input",
358 })
359
360 require.NoError(t, err)
361 require.NotNil(t, result)
362 require.Len(t, result.Steps, 1) // Single step
363
364 // Extract tool results from final response
365 var toolResults []ToolResultContent
366 for _, content := range result.Response.Content {
367 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
368 toolResults = append(toolResults, toolResult)
369 }
370 }
371
372 require.Len(t, toolResults, 1)
373 require.Equal(t, "call-1", toolResults[0].ToolCallID)
374 require.Equal(t, "tool1", toolResults[0].ToolName)
375
376 // Verify result content
377 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
378 require.True(t, ok)
379 require.Equal(t, "result1", textResult.Text)
380}
381
382// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
383func TestAgent_Generate_MultipleSteps(t *testing.T) {
384 t.Parallel()
385
386 // Create type-safe tool using the new API
387 type TestInput struct {
388 Value string `json:"value" description:"Test value"`
389 }
390
391 tool1 := NewAgentTool(
392 "tool1",
393 "Test tool",
394 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
395 require.Equal(t, "value", input.Value)
396 return ToolResponse{Content: "result1", IsError: false}, nil
397 },
398 )
399
400 callCount := 0
401 model := &mockLanguageModel{
402 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
403 callCount++
404 switch callCount {
405 case 1:
406 // First call - return tool call with FinishReasonToolCalls
407 return &Response{
408 Content: []Content{
409 ToolCallContent{
410 ToolCallID: "call-1",
411 ToolName: "tool1",
412 Input: `{"value":"value"}`,
413 },
414 },
415 Usage: Usage{
416 InputTokens: 10,
417 OutputTokens: 5,
418 TotalTokens: 15,
419 },
420 FinishReason: FinishReasonToolCalls, // This triggers multi-step
421 }, nil
422 case 2:
423 // Second call - return final text
424 return &Response{
425 Content: []Content{
426 TextContent{Text: "Hello, world!"},
427 },
428 Usage: Usage{
429 InputTokens: 3,
430 OutputTokens: 10,
431 TotalTokens: 13,
432 },
433 FinishReason: FinishReasonStop,
434 }, nil
435 default:
436 t.Fatalf("Unexpected call count: %d", callCount)
437 return nil, nil
438 }
439 },
440 }
441
442 agent := NewAgent(model, WithTools(tool1))
443 result, err := agent.Generate(context.Background(), AgentCall{
444 Prompt: "test-input",
445 })
446
447 require.NoError(t, err)
448 require.NotNil(t, result)
449 require.Len(t, result.Steps, 2)
450
451 // Check total usage sums both steps
452 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
453 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
454 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
455
456 // Final response should be from last step
457 require.Len(t, result.Response.Content, 1)
458 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
459 require.True(t, ok)
460 require.Equal(t, "Hello, world!", textContent.Text)
461
462 // result.toolCalls should be empty (from last step)
463 var toolCalls []ToolCallContent
464 for _, content := range result.Response.Content {
465 if _, ok := AsContentType[ToolCallContent](content); ok {
466 toolCalls = append(toolCalls, content.(ToolCallContent))
467 }
468 }
469 require.Len(t, toolCalls, 0)
470
471 // result.toolResults should be empty (from last step)
472 var toolResults []ToolResultContent
473 for _, content := range result.Response.Content {
474 if _, ok := AsContentType[ToolResultContent](content); ok {
475 toolResults = append(toolResults, content.(ToolResultContent))
476 }
477 }
478 require.Len(t, toolResults, 0)
479}
480
481// Test basic text generation
482func TestAgent_Generate_BasicText(t *testing.T) {
483 t.Parallel()
484
485 model := &mockLanguageModel{
486 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
487 return &Response{
488 Content: []Content{
489 TextContent{Text: "Hello, world!"},
490 },
491 Usage: Usage{
492 InputTokens: 3,
493 OutputTokens: 10,
494 TotalTokens: 13,
495 },
496 FinishReason: FinishReasonStop,
497 }, nil
498 },
499 }
500
501 agent := NewAgent(model)
502 result, err := agent.Generate(context.Background(), AgentCall{
503 Prompt: "test prompt",
504 })
505
506 require.NoError(t, err)
507 require.NotNil(t, result)
508 require.Len(t, result.Steps, 1)
509
510 // Check final response
511 require.Len(t, result.Response.Content, 1)
512 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
513 require.True(t, ok)
514 require.Equal(t, "Hello, world!", textContent.Text)
515
516 // Check usage
517 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
518 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
519 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
520
521 // Check total usage
522 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
523 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
524 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
525}
526
527// Test empty prompt validation
528func TestAgent_Generate_EmptyPrompt(t *testing.T) {
529 t.Parallel()
530
531 model := &mockLanguageModel{}
532 agent := NewAgent(model)
533
534 t.Run("fails without messages", func(t *testing.T) {
535 result, err := agent.Generate(context.Background(), AgentCall{
536 Prompt: "",
537 })
538 require.Error(t, err)
539 require.Nil(t, result)
540 require.Contains(t, err.Error(), "prompt can't be empty when there are no messages")
541 })
542
543 t.Run("fails with files even if messages exist", func(t *testing.T) {
544 result, err := agent.Generate(context.Background(), AgentCall{
545 Prompt: "",
546 Messages: []Message{
547 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
548 },
549 Files: []FilePart{{Filename: "test.txt", Data: []byte("test"), MediaType: "text/plain"}},
550 })
551 require.Error(t, err)
552 require.Nil(t, result)
553 require.Contains(t, err.Error(), "prompt can't be empty when there are files")
554 })
555
556 t.Run("fails when last message is assistant", func(t *testing.T) {
557 result, err := agent.Generate(context.Background(), AgentCall{
558 Prompt: "",
559 Messages: []Message{
560 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
561 {Role: MessageRoleAssistant, Content: []MessagePart{TextPart{Text: "hi there"}}},
562 },
563 })
564 require.Error(t, err)
565 require.Nil(t, result)
566 require.Contains(t, err.Error(), "prompt can't be empty when the last message is not a user or tool message")
567 })
568
569 t.Run("succeeds when last message is user", func(t *testing.T) {
570 model := &mockLanguageModel{
571 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
572 return &Response{
573 Content: []Content{TextContent{Text: "response"}},
574 FinishReason: FinishReasonStop,
575 }, nil
576 },
577 }
578 agent := NewAgent(model)
579
580 result, err := agent.Generate(context.Background(), AgentCall{
581 Prompt: "",
582 Messages: []Message{
583 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
584 },
585 })
586 require.NoError(t, err)
587 require.NotNil(t, result)
588 })
589
590 t.Run("succeeds when last message is tool", func(t *testing.T) {
591 model := &mockLanguageModel{
592 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
593 return &Response{
594 Content: []Content{TextContent{Text: "response"}},
595 FinishReason: FinishReasonStop,
596 }, nil
597 },
598 }
599 agent := NewAgent(model)
600
601 result, err := agent.Generate(context.Background(), AgentCall{
602 Prompt: "",
603 Messages: []Message{
604 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
605 {Role: MessageRoleAssistant, Content: []MessagePart{ToolCallPart{ToolCallID: "call_1", ToolName: "test"}}},
606 {Role: MessageRoleTool, Content: []MessagePart{ToolResultPart{ToolCallID: "call_1", Output: ToolResultOutputContentText{Text: "result"}}}},
607 },
608 })
609 require.NoError(t, err)
610 require.NotNil(t, result)
611 })
612}
613
614// Test with system prompt
615func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
616 t.Parallel()
617
618 model := &mockLanguageModel{
619 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
620 // Verify system message is included
621 require.Len(t, call.Prompt, 2) // system + user
622 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
623 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
624
625 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
626 require.True(t, ok)
627 require.Equal(t, "You are a helpful assistant", systemPart.Text)
628
629 return &Response{
630 Content: []Content{
631 TextContent{Text: "Hello, world!"},
632 },
633 Usage: Usage{
634 InputTokens: 3,
635 OutputTokens: 10,
636 TotalTokens: 13,
637 },
638 FinishReason: FinishReasonStop,
639 }, nil
640 },
641 }
642
643 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
644 result, err := agent.Generate(context.Background(), AgentCall{
645 Prompt: "test prompt",
646 })
647
648 require.NoError(t, err)
649 require.NotNil(t, result)
650}
651
652// Test options.activeTools filtering
653func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
654 t.Parallel()
655
656 tool1 := &mockTool{
657 name: "tool1",
658 description: "Test tool 1",
659 parameters: map[string]any{
660 "value": map[string]any{"type": "string"},
661 },
662 required: []string{"value"},
663 }
664
665 tool2 := &mockTool{
666 name: "tool2",
667 description: "Test tool 2",
668 parameters: map[string]any{
669 "value": map[string]any{"type": "string"},
670 },
671 required: []string{"value"},
672 }
673
674 model := &mockLanguageModel{
675 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
676 // Verify only tool1 is available
677 require.Len(t, call.Tools, 1)
678 functionTool, ok := call.Tools[0].(FunctionTool)
679 require.True(t, ok)
680 require.Equal(t, "tool1", functionTool.Name)
681
682 return &Response{
683 Content: []Content{
684 TextContent{Text: "Hello, world!"},
685 },
686 Usage: Usage{
687 InputTokens: 3,
688 OutputTokens: 10,
689 TotalTokens: 13,
690 },
691 FinishReason: FinishReasonStop,
692 }, nil
693 },
694 }
695
696 agent := NewAgent(model, WithTools(tool1, tool2))
697 result, err := agent.Generate(context.Background(), AgentCall{
698 Prompt: "test-input",
699 ActiveTools: []string{"tool1"}, // Only tool1 should be active
700 })
701
702 require.NoError(t, err)
703 require.NotNil(t, result)
704}
705
706func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
707 t.Parallel()
708
709 tool1 := &mockTool{
710 name: "tool1",
711 description: "Test tool 1",
712 parameters: map[string]any{
713 "value": map[string]any{"type": "string"},
714 },
715 required: []string{"value"},
716 }
717
718 providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
719 providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
720
721 model := &mockLanguageModel{
722 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
723 require.Len(t, call.Tools, 2)
724
725 functionTool, ok := call.Tools[0].(FunctionTool)
726 require.True(t, ok)
727 require.Equal(t, "tool1", functionTool.Name)
728
729 providerTool, ok := call.Tools[1].(ProviderDefinedTool)
730 require.True(t, ok)
731 require.Equal(t, "web_search", providerTool.Name)
732
733 return &Response{
734 Content: []Content{
735 TextContent{Text: "Hello, world!"},
736 },
737 Usage: Usage{
738 InputTokens: 3,
739 OutputTokens: 10,
740 TotalTokens: 13,
741 },
742 FinishReason: FinishReasonStop,
743 }, nil
744 },
745 }
746
747 agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
748 result, err := agent.Generate(context.Background(), AgentCall{
749 Prompt: "test-input",
750 ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
751 })
752
753 require.NoError(t, err)
754 require.NotNil(t, result)
755}
756
757func TestResponseContent_Getters(t *testing.T) {
758 t.Parallel()
759
760 // Create test content with all types
761 content := ResponseContent{
762 TextContent{Text: "Hello world"},
763 ReasoningContent{Text: "Let me think..."},
764 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
765 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
766 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
767 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
768 }
769
770 // Test Text()
771 require.Equal(t, "Hello world", content.Text())
772
773 // Test Reasoning()
774 reasoning := content.Reasoning()
775 require.Len(t, reasoning, 1)
776 require.Equal(t, "Let me think...", reasoning[0].Text)
777
778 // Test ReasoningText()
779 require.Equal(t, "Let me think...", content.ReasoningText())
780
781 // Test Files()
782 files := content.Files()
783 require.Len(t, files, 1)
784 require.Equal(t, "text/plain", files[0].MediaType)
785 require.Equal(t, []byte("file data"), files[0].Data)
786
787 // Test Sources()
788 sources := content.Sources()
789 require.Len(t, sources, 1)
790 require.Equal(t, SourceTypeURL, sources[0].SourceType)
791 require.Equal(t, "https://example.com", sources[0].URL)
792 require.Equal(t, "Example", sources[0].Title)
793
794 // Test ToolCalls()
795 toolCalls := content.ToolCalls()
796 require.Len(t, toolCalls, 1)
797 require.Equal(t, "call1", toolCalls[0].ToolCallID)
798 require.Equal(t, "test_tool", toolCalls[0].ToolName)
799 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
800
801 // Test ToolResults()
802 toolResults := content.ToolResults()
803 require.Len(t, toolResults, 1)
804 require.Equal(t, "call1", toolResults[0].ToolCallID)
805 require.Equal(t, "test_tool", toolResults[0].ToolName)
806 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
807 require.True(t, ok)
808 require.Equal(t, "result", result.Text)
809}
810
811func TestResponseContent_Getters_Empty(t *testing.T) {
812 t.Parallel()
813
814 // Test with empty content
815 content := ResponseContent{}
816
817 require.Equal(t, "", content.Text())
818 require.Equal(t, "", content.ReasoningText())
819 require.Empty(t, content.Reasoning())
820 require.Empty(t, content.Files())
821 require.Empty(t, content.Sources())
822 require.Empty(t, content.ToolCalls())
823 require.Empty(t, content.ToolResults())
824}
825
826func TestResponseContent_Getters_MultipleItems(t *testing.T) {
827 t.Parallel()
828
829 // Test with multiple items of same type
830 content := ResponseContent{
831 ReasoningContent{Text: "First thought"},
832 ReasoningContent{Text: "Second thought"},
833 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
834 FileContent{Data: []byte("file2"), MediaType: "image/png"},
835 }
836
837 // Test multiple reasoning
838 reasoning := content.Reasoning()
839 require.Len(t, reasoning, 2)
840 require.Equal(t, "First thought", reasoning[0].Text)
841 require.Equal(t, "Second thought", reasoning[1].Text)
842
843 // Test concatenated reasoning text
844 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
845
846 // Test multiple files
847 files := content.Files()
848 require.Len(t, files, 2)
849 require.Equal(t, "text/plain", files[0].MediaType)
850 require.Equal(t, "image/png", files[1].MediaType)
851}
852
853func TestStopConditions(t *testing.T) {
854 t.Parallel()
855
856 // Create test steps
857 step1 := StepResult{
858 Response: Response{
859 Content: ResponseContent{
860 TextContent{Text: "Hello"},
861 },
862 FinishReason: FinishReasonToolCalls,
863 Usage: Usage{TotalTokens: 10},
864 },
865 }
866
867 step2 := StepResult{
868 Response: Response{
869 Content: ResponseContent{
870 TextContent{Text: "World"},
871 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
872 },
873 FinishReason: FinishReasonStop,
874 Usage: Usage{TotalTokens: 15},
875 },
876 }
877
878 step3 := StepResult{
879 Response: Response{
880 Content: ResponseContent{
881 ReasoningContent{Text: "Let me think..."},
882 FileContent{Data: []byte("data"), MediaType: "text/plain"},
883 },
884 FinishReason: FinishReasonLength,
885 Usage: Usage{TotalTokens: 20},
886 },
887 }
888
889 t.Run("StepCountIs", func(t *testing.T) {
890 t.Parallel()
891 condition := StepCountIs(2)
892
893 // Should not stop with 1 step
894 require.False(t, condition([]StepResult{step1}))
895
896 // Should stop with 2 steps
897 require.True(t, condition([]StepResult{step1, step2}))
898
899 // Should stop with more than 2 steps
900 require.True(t, condition([]StepResult{step1, step2, step3}))
901
902 // Should not stop with empty steps
903 require.False(t, condition([]StepResult{}))
904 })
905
906 t.Run("HasToolCall", func(t *testing.T) {
907 t.Parallel()
908 condition := HasToolCall("search")
909
910 // Should not stop when tool not called
911 require.False(t, condition([]StepResult{step1}))
912
913 // Should stop when tool is called in last step
914 require.True(t, condition([]StepResult{step1, step2}))
915
916 // Should not stop when tool called in earlier step but not last
917 require.False(t, condition([]StepResult{step1, step2, step3}))
918
919 // Should not stop with empty steps
920 require.False(t, condition([]StepResult{}))
921
922 // Should not stop when different tool is called
923 differentToolCondition := HasToolCall("different_tool")
924 require.False(t, differentToolCondition([]StepResult{step1, step2}))
925 })
926
927 t.Run("HasContent", func(t *testing.T) {
928 t.Parallel()
929 reasoningCondition := HasContent(ContentTypeReasoning)
930 fileCondition := HasContent(ContentTypeFile)
931
932 // Should not stop when content type not present
933 require.False(t, reasoningCondition([]StepResult{step1, step2}))
934
935 // Should stop when content type is present in last step
936 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
937 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
938
939 // Should not stop with empty steps
940 require.False(t, reasoningCondition([]StepResult{}))
941 })
942
943 t.Run("FinishReasonIs", func(t *testing.T) {
944 t.Parallel()
945 stopCondition := FinishReasonIs(FinishReasonStop)
946 lengthCondition := FinishReasonIs(FinishReasonLength)
947
948 // Should not stop when finish reason doesn't match
949 require.False(t, stopCondition([]StepResult{step1}))
950
951 // Should stop when finish reason matches in last step
952 require.True(t, stopCondition([]StepResult{step1, step2}))
953 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
954
955 // Should not stop with empty steps
956 require.False(t, stopCondition([]StepResult{}))
957 })
958
959 t.Run("MaxTokensUsed", func(t *testing.T) {
960 condition := MaxTokensUsed(30)
961
962 // Should not stop when under limit
963 require.False(t, condition([]StepResult{step1})) // 10 tokens
964 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
965
966 // Should stop when at or over limit
967 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
968
969 // Should not stop with empty steps
970 require.False(t, condition([]StepResult{}))
971 })
972}
973
974func TestStopConditions_Integration(t *testing.T) {
975 t.Parallel()
976
977 t.Run("StepCountIs integration", func(t *testing.T) {
978 t.Parallel()
979 model := &mockLanguageModel{
980 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
981 return &Response{
982 Content: ResponseContent{
983 TextContent{Text: "Mock response"},
984 },
985 Usage: Usage{
986 InputTokens: 3,
987 OutputTokens: 10,
988 TotalTokens: 13,
989 },
990 FinishReason: FinishReasonStop,
991 }, nil
992 },
993 }
994
995 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
996
997 result, err := agent.Generate(context.Background(), AgentCall{
998 Prompt: "test prompt",
999 })
1000
1001 require.NoError(t, err)
1002 require.NotNil(t, result)
1003 require.Len(t, result.Steps, 1) // Should stop after 1 step
1004 })
1005
1006 t.Run("Multiple stop conditions", func(t *testing.T) {
1007 t.Parallel()
1008 model := &mockLanguageModel{
1009 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1010 return &Response{
1011 Content: ResponseContent{
1012 TextContent{Text: "Mock response"},
1013 },
1014 Usage: Usage{
1015 InputTokens: 3,
1016 OutputTokens: 10,
1017 TotalTokens: 13,
1018 },
1019 FinishReason: FinishReasonStop,
1020 }, nil
1021 },
1022 }
1023
1024 agent := NewAgent(model, WithStopConditions(
1025 StepCountIs(5), // Stop after 5 steps
1026 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
1027 ))
1028
1029 result, err := agent.Generate(context.Background(), AgentCall{
1030 Prompt: "test prompt",
1031 })
1032
1033 require.NoError(t, err)
1034 require.NotNil(t, result)
1035 // Should stop on first condition met (finish reason stop)
1036 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
1037 })
1038}
1039
1040func TestPrepareStep(t *testing.T) {
1041 t.Parallel()
1042
1043 t.Run("System prompt modification", func(t *testing.T) {
1044 t.Parallel()
1045 var capturedSystemPrompt string
1046 model := &mockLanguageModel{
1047 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1048 // Capture the system message to verify it was modified
1049 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1050 if len(call.Prompt[0].Content) > 0 {
1051 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1052 capturedSystemPrompt = textPart.Text
1053 }
1054 }
1055 }
1056 return &Response{
1057 Content: ResponseContent{
1058 TextContent{Text: "Response"},
1059 },
1060 Usage: Usage{TotalTokens: 10},
1061 FinishReason: FinishReasonStop,
1062 }, nil
1063 },
1064 }
1065
1066 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1067 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
1068 return ctx, PrepareStepResult{
1069 Model: options.Model,
1070 Messages: options.Messages,
1071 System: &newSystem,
1072 }, nil
1073 }
1074
1075 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
1076
1077 result, err := agent.Generate(context.Background(), AgentCall{
1078 Prompt: "test prompt",
1079 PrepareStep: prepareStepFunc,
1080 })
1081
1082 require.NoError(t, err)
1083 require.NotNil(t, result)
1084 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
1085 })
1086
1087 t.Run("Tool choice modification", func(t *testing.T) {
1088 t.Parallel()
1089 var capturedToolChoice *ToolChoice
1090 model := &mockLanguageModel{
1091 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1092 capturedToolChoice = call.ToolChoice
1093 return &Response{
1094 Content: ResponseContent{
1095 TextContent{Text: "Response"},
1096 },
1097 Usage: Usage{TotalTokens: 10},
1098 FinishReason: FinishReasonStop,
1099 }, nil
1100 },
1101 }
1102
1103 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1104 toolChoice := ToolChoiceNone
1105 return ctx, PrepareStepResult{
1106 Model: options.Model,
1107 Messages: options.Messages,
1108 ToolChoice: &toolChoice,
1109 }, nil
1110 }
1111
1112 agent := NewAgent(model)
1113
1114 result, err := agent.Generate(context.Background(), AgentCall{
1115 Prompt: "test prompt",
1116 PrepareStep: prepareStepFunc,
1117 })
1118
1119 require.NoError(t, err)
1120 require.NotNil(t, result)
1121 require.NotNil(t, capturedToolChoice)
1122 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1123 })
1124
1125 t.Run("Active tools modification", func(t *testing.T) {
1126 t.Parallel()
1127 var capturedToolNames []string
1128 model := &mockLanguageModel{
1129 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1130 // Capture tool names to verify active tools were modified
1131 for _, tool := range call.Tools {
1132 capturedToolNames = append(capturedToolNames, tool.GetName())
1133 }
1134 return &Response{
1135 Content: ResponseContent{
1136 TextContent{Text: "Response"},
1137 },
1138 Usage: Usage{TotalTokens: 10},
1139 FinishReason: FinishReasonStop,
1140 }, nil
1141 },
1142 }
1143
1144 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1145 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1146 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1147
1148 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1149 activeTools := []string{"tool2"} // Only tool2 should be active
1150 return ctx, PrepareStepResult{
1151 Model: options.Model,
1152 Messages: options.Messages,
1153 ActiveTools: activeTools,
1154 }, nil
1155 }
1156
1157 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1158
1159 result, err := agent.Generate(context.Background(), AgentCall{
1160 Prompt: "test prompt",
1161 PrepareStep: prepareStepFunc,
1162 })
1163
1164 require.NoError(t, err)
1165 require.NotNil(t, result)
1166 require.Len(t, capturedToolNames, 1)
1167 require.Equal(t, "tool2", capturedToolNames[0])
1168 })
1169
1170 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1171 t.Parallel()
1172 var capturedToolCount int
1173 model := &mockLanguageModel{
1174 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1175 capturedToolCount = len(call.Tools)
1176 return &Response{
1177 Content: ResponseContent{
1178 TextContent{Text: "Response"},
1179 },
1180 Usage: Usage{TotalTokens: 10},
1181 FinishReason: FinishReasonStop,
1182 }, nil
1183 },
1184 }
1185
1186 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1187
1188 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1189 return ctx, PrepareStepResult{
1190 Model: options.Model,
1191 Messages: options.Messages,
1192 DisableAllTools: true, // Disable all tools for this step
1193 }, nil
1194 }
1195
1196 agent := NewAgent(model, WithTools(tool1))
1197
1198 result, err := agent.Generate(context.Background(), AgentCall{
1199 Prompt: "test prompt",
1200 PrepareStep: prepareStepFunc,
1201 })
1202
1203 require.NoError(t, err)
1204 require.NotNil(t, result)
1205 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1206 })
1207
1208 t.Run("All fields modified together", func(t *testing.T) {
1209 t.Parallel()
1210 var capturedSystemPrompt string
1211 var capturedToolChoice *ToolChoice
1212 var capturedToolNames []string
1213
1214 model := &mockLanguageModel{
1215 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1216 // Capture system prompt
1217 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1218 if len(call.Prompt[0].Content) > 0 {
1219 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1220 capturedSystemPrompt = textPart.Text
1221 }
1222 }
1223 }
1224 // Capture tool choice
1225 capturedToolChoice = call.ToolChoice
1226 // Capture tool names
1227 for _, tool := range call.Tools {
1228 capturedToolNames = append(capturedToolNames, tool.GetName())
1229 }
1230 return &Response{
1231 Content: ResponseContent{
1232 TextContent{Text: "Response"},
1233 },
1234 Usage: Usage{TotalTokens: 10},
1235 FinishReason: FinishReasonStop,
1236 }, nil
1237 },
1238 }
1239
1240 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1241 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1242
1243 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1244 newSystem := "Step-specific system"
1245 toolChoice := SpecificToolChoice("tool1")
1246 activeTools := []string{"tool1"}
1247 return ctx, PrepareStepResult{
1248 Model: options.Model,
1249 Messages: options.Messages,
1250 System: &newSystem,
1251 ToolChoice: &toolChoice,
1252 ActiveTools: activeTools,
1253 }, nil
1254 }
1255
1256 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1257
1258 result, err := agent.Generate(context.Background(), AgentCall{
1259 Prompt: "test prompt",
1260 PrepareStep: prepareStepFunc,
1261 })
1262
1263 require.NoError(t, err)
1264 require.NotNil(t, result)
1265 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1266 require.NotNil(t, capturedToolChoice)
1267 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1268 require.Len(t, capturedToolNames, 1)
1269 require.Equal(t, "tool1", capturedToolNames[0])
1270 })
1271
1272 t.Run("Nil fields use parent values", func(t *testing.T) {
1273 t.Parallel()
1274 var capturedSystemPrompt string
1275 var capturedToolChoice *ToolChoice
1276 var capturedToolNames []string
1277
1278 model := &mockLanguageModel{
1279 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1280 // Capture system prompt
1281 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1282 if len(call.Prompt[0].Content) > 0 {
1283 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1284 capturedSystemPrompt = textPart.Text
1285 }
1286 }
1287 }
1288 // Capture tool choice
1289 capturedToolChoice = call.ToolChoice
1290 // Capture tool names
1291 for _, tool := range call.Tools {
1292 capturedToolNames = append(capturedToolNames, tool.GetName())
1293 }
1294 return &Response{
1295 Content: ResponseContent{
1296 TextContent{Text: "Response"},
1297 },
1298 Usage: Usage{TotalTokens: 10},
1299 FinishReason: FinishReasonStop,
1300 }, nil
1301 },
1302 }
1303
1304 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1305
1306 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1307 // All optional fields are nil, should use parent values
1308 return ctx, PrepareStepResult{
1309 Model: options.Model,
1310 Messages: options.Messages,
1311 System: nil, // Use parent
1312 ToolChoice: nil, // Use parent (auto)
1313 ActiveTools: nil, // Use parent (all tools)
1314 }, nil
1315 }
1316
1317 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1318
1319 result, err := agent.Generate(context.Background(), AgentCall{
1320 Prompt: "test prompt",
1321 PrepareStep: prepareStepFunc,
1322 })
1323
1324 require.NoError(t, err)
1325 require.NotNil(t, result)
1326 require.Equal(t, "Parent system", capturedSystemPrompt)
1327 require.NotNil(t, capturedToolChoice)
1328 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1329 require.Len(t, capturedToolNames, 1)
1330 require.Equal(t, "tool1", capturedToolNames[0])
1331 })
1332
1333 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1334 t.Parallel()
1335 var capturedToolNames []string
1336 model := &mockLanguageModel{
1337 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1338 // Capture tool names to verify all tools are included
1339 for _, tool := range call.Tools {
1340 capturedToolNames = append(capturedToolNames, tool.GetName())
1341 }
1342 return &Response{
1343 Content: ResponseContent{
1344 TextContent{Text: "Response"},
1345 },
1346 Usage: Usage{TotalTokens: 10},
1347 FinishReason: FinishReasonStop,
1348 }, nil
1349 },
1350 }
1351
1352 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1353 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1354
1355 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1356 return ctx, PrepareStepResult{
1357 Model: options.Model,
1358 Messages: options.Messages,
1359 ActiveTools: []string{}, // Empty slice means all tools
1360 }, nil
1361 }
1362
1363 agent := NewAgent(model, WithTools(tool1, tool2))
1364
1365 result, err := agent.Generate(context.Background(), AgentCall{
1366 Prompt: "test prompt",
1367 PrepareStep: prepareStepFunc,
1368 })
1369
1370 require.NoError(t, err)
1371 require.NotNil(t, result)
1372 require.Len(t, capturedToolNames, 2) // All tools should be included
1373 require.Contains(t, capturedToolNames, "tool1")
1374 require.Contains(t, capturedToolNames, "tool2")
1375 })
1376}
1377
1378func TestToolCallRepair(t *testing.T) {
1379 t.Parallel()
1380
1381 t.Run("Valid tool call passes validation", func(t *testing.T) {
1382 t.Parallel()
1383 model := &mockLanguageModel{
1384 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1385 return &Response{
1386 Content: ResponseContent{
1387 TextContent{Text: "Response"},
1388 ToolCallContent{
1389 ToolCallID: "call1",
1390 ToolName: "test_tool",
1391 Input: `{"value": "test"}`, // Valid JSON with required field
1392 },
1393 },
1394 Usage: Usage{TotalTokens: 10},
1395 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1396 }, nil
1397 },
1398 }
1399
1400 tool := &mockTool{
1401 name: "test_tool",
1402 description: "Test tool",
1403 parameters: map[string]any{
1404 "value": map[string]any{"type": "string"},
1405 },
1406 required: []string{"value"},
1407 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1408 return ToolResponse{Content: "success", IsError: false}, nil
1409 },
1410 }
1411
1412 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1413
1414 result, err := agent.Generate(context.Background(), AgentCall{
1415 Prompt: "test prompt",
1416 })
1417
1418 require.NoError(t, err)
1419 require.NotNil(t, result)
1420 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1421
1422 // Check that tool call was executed successfully
1423 toolCalls := result.Steps[0].Content.ToolCalls()
1424 require.Len(t, toolCalls, 1)
1425 require.False(t, toolCalls[0].Invalid) // Should be valid
1426 })
1427
1428 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1429 t.Parallel()
1430 model := &mockLanguageModel{
1431 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1432 return &Response{
1433 Content: ResponseContent{
1434 TextContent{Text: "Response"},
1435 ToolCallContent{
1436 ToolCallID: "call1",
1437 ToolName: "test_tool",
1438 Input: `{"wrong_field": "test"}`, // Missing required field
1439 },
1440 },
1441 Usage: Usage{TotalTokens: 10},
1442 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1443 }, nil
1444 },
1445 }
1446
1447 tool := &mockTool{
1448 name: "test_tool",
1449 description: "Test tool",
1450 parameters: map[string]any{
1451 "value": map[string]any{"type": "string"},
1452 },
1453 required: []string{"value"},
1454 }
1455
1456 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1457
1458 result, err := agent.Generate(context.Background(), AgentCall{
1459 Prompt: "test prompt",
1460 })
1461
1462 require.NoError(t, err)
1463 require.NotNil(t, result)
1464 require.Len(t, result.Steps, 1) // Only one step
1465
1466 // Check that tool call was marked as invalid
1467 toolCalls := result.Steps[0].Content.ToolCalls()
1468 require.Len(t, toolCalls, 1)
1469 require.True(t, toolCalls[0].Invalid) // Should be invalid
1470 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1471 })
1472
1473 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1474 t.Parallel()
1475 model := &mockLanguageModel{
1476 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1477 return &Response{
1478 Content: ResponseContent{
1479 TextContent{Text: "Response"},
1480 ToolCallContent{
1481 ToolCallID: "call1",
1482 ToolName: "test_tool",
1483 Input: `{"wrong_field": "test"}`, // Missing required field
1484 },
1485 },
1486 Usage: Usage{TotalTokens: 10},
1487 FinishReason: FinishReasonStop, // Changed to stop
1488 }, nil
1489 },
1490 }
1491
1492 tool := &mockTool{
1493 name: "test_tool",
1494 description: "Test tool",
1495 parameters: map[string]any{
1496 "value": map[string]any{"type": "string"},
1497 },
1498 required: []string{"value"},
1499 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1500 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1501 },
1502 }
1503
1504 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1505 // Simple repair: add the missing required field
1506 repairedToolCall := options.OriginalToolCall
1507 repairedToolCall.Input = `{"value": "repaired"}`
1508 return &repairedToolCall, nil
1509 }
1510
1511 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1512
1513 result, err := agent.Generate(context.Background(), AgentCall{
1514 Prompt: "test prompt",
1515 })
1516
1517 require.NoError(t, err)
1518 require.NotNil(t, result)
1519 require.Len(t, result.Steps, 1) // Only one step
1520
1521 // Check that tool call was repaired and is now valid
1522 toolCalls := result.Steps[0].Content.ToolCalls()
1523 require.Len(t, toolCalls, 1)
1524 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1525 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1526 })
1527
1528 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1529 t.Parallel()
1530 model := &mockLanguageModel{
1531 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1532 return &Response{
1533 Content: ResponseContent{
1534 TextContent{Text: "Response"},
1535 ToolCallContent{
1536 ToolCallID: "call1",
1537 ToolName: "test_tool",
1538 Input: `{"wrong_field": "test"}`, // Missing required field
1539 },
1540 },
1541 Usage: Usage{TotalTokens: 10},
1542 FinishReason: FinishReasonStop, // Changed to stop
1543 }, nil
1544 },
1545 }
1546
1547 tool := &mockTool{
1548 name: "test_tool",
1549 description: "Test tool",
1550 parameters: map[string]any{
1551 "value": map[string]any{"type": "string"},
1552 },
1553 required: []string{"value"},
1554 }
1555
1556 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1557 // Repair function fails
1558 return nil, errors.New("repair failed")
1559 }
1560
1561 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1562
1563 result, err := agent.Generate(context.Background(), AgentCall{
1564 Prompt: "test prompt",
1565 })
1566
1567 require.NoError(t, err)
1568 require.NotNil(t, result)
1569 require.Len(t, result.Steps, 1) // Only one step
1570
1571 // Check that tool call was marked as invalid since repair failed
1572 toolCalls := result.Steps[0].Content.ToolCalls()
1573 require.Len(t, toolCalls, 1)
1574 require.True(t, toolCalls[0].Invalid) // Should be invalid
1575 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1576 })
1577
1578 t.Run("Nonexistent tool call", func(t *testing.T) {
1579 t.Parallel()
1580 model := &mockLanguageModel{
1581 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1582 return &Response{
1583 Content: ResponseContent{
1584 TextContent{Text: "Response"},
1585 ToolCallContent{
1586 ToolCallID: "call1",
1587 ToolName: "nonexistent_tool",
1588 Input: `{"value": "test"}`,
1589 },
1590 },
1591 Usage: Usage{TotalTokens: 10},
1592 FinishReason: FinishReasonStop, // Changed to stop
1593 }, nil
1594 },
1595 }
1596
1597 tool := &mockTool{name: "test_tool", description: "Test tool"}
1598
1599 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1600
1601 result, err := agent.Generate(context.Background(), AgentCall{
1602 Prompt: "test prompt",
1603 })
1604
1605 require.NoError(t, err)
1606 require.NotNil(t, result)
1607 require.Len(t, result.Steps, 1) // Only one step
1608
1609 // Check that tool call was marked as invalid due to nonexistent tool
1610 toolCalls := result.Steps[0].Content.ToolCalls()
1611 require.Len(t, toolCalls, 1)
1612 require.True(t, toolCalls[0].Invalid) // Should be invalid
1613 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1614 })
1615
1616 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1617 t.Parallel()
1618 model := &mockLanguageModel{
1619 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1620 return &Response{
1621 Content: ResponseContent{
1622 TextContent{Text: "Response"},
1623 ToolCallContent{
1624 ToolCallID: "call1",
1625 ToolName: "test_tool",
1626 Input: `{invalid json}`, // Invalid JSON
1627 },
1628 },
1629 Usage: Usage{TotalTokens: 10},
1630 FinishReason: FinishReasonStop, // Changed to stop
1631 }, nil
1632 },
1633 }
1634
1635 tool := &mockTool{
1636 name: "test_tool",
1637 description: "Test tool",
1638 parameters: map[string]any{
1639 "value": map[string]any{"type": "string"},
1640 },
1641 required: []string{"value"},
1642 }
1643
1644 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1645
1646 result, err := agent.Generate(context.Background(), AgentCall{
1647 Prompt: "test prompt",
1648 })
1649
1650 require.NoError(t, err)
1651 require.NotNil(t, result)
1652 require.Len(t, result.Steps, 1) // Only one step
1653
1654 // Check that tool call was marked as invalid due to invalid JSON
1655 toolCalls := result.Steps[0].Content.ToolCalls()
1656 require.Len(t, toolCalls, 1)
1657 require.True(t, toolCalls[0].Invalid) // Should be invalid
1658 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1659 })
1660}
1661
1662// Test media and image tool responses
1663func TestAgent_MediaToolResponses(t *testing.T) {
1664 t.Parallel()
1665
1666 imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1667 audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1668
1669 t.Run("Image tool response", func(t *testing.T) {
1670 t.Parallel()
1671
1672 imageTool := &mockTool{
1673 name: "generate_image",
1674 description: "Generates an image",
1675 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1676 return NewImageResponse(imageData, "image/png"), nil
1677 },
1678 }
1679
1680 model := &mockLanguageModel{
1681 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1682 if len(call.Prompt) == 1 {
1683 // First call - request image tool
1684 return &Response{
1685 Content: []Content{
1686 ToolCallContent{
1687 ToolCallID: "img-1",
1688 ToolName: "generate_image",
1689 Input: `{}`,
1690 },
1691 },
1692 Usage: Usage{TotalTokens: 10},
1693 FinishReason: FinishReasonToolCalls,
1694 }, nil
1695 }
1696 // Second call - after tool execution
1697 return &Response{
1698 Content: []Content{TextContent{Text: "Image generated"}},
1699 Usage: Usage{TotalTokens: 20},
1700 FinishReason: FinishReasonStop,
1701 }, nil
1702 },
1703 }
1704
1705 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1706
1707 result, err := agent.Generate(context.Background(), AgentCall{
1708 Prompt: "Generate an image",
1709 })
1710
1711 require.NoError(t, err)
1712 require.NotNil(t, result)
1713 require.Len(t, result.Steps, 2) // Tool call step + final response
1714
1715 // Check tool results in first step
1716 toolResults := result.Steps[0].Content.ToolResults()
1717 require.Len(t, toolResults, 1)
1718
1719 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1720 require.True(t, ok, "Expected media result")
1721 require.Equal(t, base64.StdEncoding.EncodeToString(imageData), mediaResult.Data)
1722 require.Equal(t, "image/png", mediaResult.MediaType)
1723 })
1724
1725 t.Run("Media tool response (audio)", func(t *testing.T) {
1726 t.Parallel()
1727
1728 audioTool := &mockTool{
1729 name: "generate_audio",
1730 description: "Generates audio",
1731 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1732 return NewMediaResponse(audioData, "audio/wav"), nil
1733 },
1734 }
1735
1736 model := &mockLanguageModel{
1737 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1738 if len(call.Prompt) == 1 {
1739 return &Response{
1740 Content: []Content{
1741 ToolCallContent{
1742 ToolCallID: "audio-1",
1743 ToolName: "generate_audio",
1744 Input: `{}`,
1745 },
1746 },
1747 Usage: Usage{TotalTokens: 10},
1748 FinishReason: FinishReasonToolCalls,
1749 }, nil
1750 }
1751 return &Response{
1752 Content: []Content{TextContent{Text: "Audio generated"}},
1753 Usage: Usage{TotalTokens: 20},
1754 FinishReason: FinishReasonStop,
1755 }, nil
1756 },
1757 }
1758
1759 agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1760
1761 result, err := agent.Generate(context.Background(), AgentCall{
1762 Prompt: "Generate audio",
1763 })
1764
1765 require.NoError(t, err)
1766 require.NotNil(t, result)
1767
1768 toolResults := result.Steps[0].Content.ToolResults()
1769 require.Len(t, toolResults, 1)
1770
1771 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1772 require.True(t, ok, "Expected media result")
1773 require.Equal(t, base64.StdEncoding.EncodeToString(audioData), mediaResult.Data)
1774 require.Equal(t, "audio/wav", mediaResult.MediaType)
1775 })
1776
1777 t.Run("Media response with text", func(t *testing.T) {
1778 t.Parallel()
1779
1780 imageTool := &mockTool{
1781 name: "screenshot",
1782 description: "Takes a screenshot",
1783 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1784 resp := NewImageResponse(imageData, "image/png")
1785 resp.Content = "Screenshot captured successfully"
1786 return resp, nil
1787 },
1788 }
1789
1790 model := &mockLanguageModel{
1791 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1792 if len(call.Prompt) == 1 {
1793 return &Response{
1794 Content: []Content{
1795 ToolCallContent{
1796 ToolCallID: "screen-1",
1797 ToolName: "screenshot",
1798 Input: `{}`,
1799 },
1800 },
1801 Usage: Usage{TotalTokens: 10},
1802 FinishReason: FinishReasonToolCalls,
1803 }, nil
1804 }
1805 return &Response{
1806 Content: []Content{TextContent{Text: "Done"}},
1807 Usage: Usage{TotalTokens: 20},
1808 FinishReason: FinishReasonStop,
1809 }, nil
1810 },
1811 }
1812
1813 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1814
1815 result, err := agent.Generate(context.Background(), AgentCall{
1816 Prompt: "Take a screenshot",
1817 })
1818
1819 require.NoError(t, err)
1820 require.NotNil(t, result)
1821
1822 toolResults := result.Steps[0].Content.ToolResults()
1823 require.Len(t, toolResults, 1)
1824
1825 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1826 require.True(t, ok, "Expected media result")
1827 require.Equal(t, base64.StdEncoding.EncodeToString(imageData), mediaResult.Data)
1828 require.Equal(t, "image/png", mediaResult.MediaType)
1829 require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1830 })
1831
1832 t.Run("Media response preserves metadata", func(t *testing.T) {
1833 t.Parallel()
1834
1835 type ImageMetadata struct {
1836 Width int `json:"width"`
1837 Height int `json:"height"`
1838 }
1839
1840 imageTool := &mockTool{
1841 name: "generate_image",
1842 description: "Generates an image",
1843 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1844 resp := NewImageResponse(imageData, "image/png")
1845 return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1846 },
1847 }
1848
1849 model := &mockLanguageModel{
1850 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1851 if len(call.Prompt) == 1 {
1852 return &Response{
1853 Content: []Content{
1854 ToolCallContent{
1855 ToolCallID: "img-1",
1856 ToolName: "generate_image",
1857 Input: `{}`,
1858 },
1859 },
1860 Usage: Usage{TotalTokens: 10},
1861 FinishReason: FinishReasonToolCalls,
1862 }, nil
1863 }
1864 return &Response{
1865 Content: []Content{TextContent{Text: "Done"}},
1866 Usage: Usage{TotalTokens: 20},
1867 FinishReason: FinishReasonStop,
1868 }, nil
1869 },
1870 }
1871
1872 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1873
1874 result, err := agent.Generate(context.Background(), AgentCall{
1875 Prompt: "Generate image",
1876 })
1877
1878 require.NoError(t, err)
1879 require.NotNil(t, result)
1880
1881 toolResults := result.Steps[0].Content.ToolResults()
1882 require.Len(t, toolResults, 1)
1883
1884 // Check metadata was preserved
1885 require.NotEmpty(t, toolResults[0].ClientMetadata)
1886
1887 var metadata ImageMetadata
1888 err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1889 require.NoError(t, err)
1890 require.Equal(t, 800, metadata.Width)
1891 require.Equal(t, 600, metadata.Height)
1892 })
1893}
1894
1895func TestToResponseMessages_ProviderExecutedRouting(t *testing.T) {
1896 t.Parallel()
1897
1898 // Build step content that mixes a provider-executed tool call/result
1899 // (e.g. web search) with a regular local tool call/result.
1900 content := []Content{
1901 // Provider-executed tool call.
1902 &ToolCallContent{
1903 ToolCallID: "srvtoolu_01",
1904 ToolName: "web_search",
1905 Input: `{"query":"test"}`,
1906 ProviderExecuted: true,
1907 },
1908 // Provider-executed tool result.
1909 &ToolResultContent{
1910 ToolCallID: "srvtoolu_01",
1911 ProviderExecuted: true,
1912 },
1913 // Regular (locally-executed) tool call.
1914 &ToolCallContent{
1915 ToolCallID: "toolu_02",
1916 ToolName: "calculator",
1917 Input: `{"expr":"1+1"}`,
1918 },
1919 // Regular tool result.
1920 &ToolResultContent{
1921 ToolCallID: "toolu_02",
1922 Result: ToolResultOutputContentText{Text: "2"},
1923 },
1924 // Some trailing text.
1925 &TextContent{Text: "Done."},
1926 }
1927
1928 msgs := toResponseMessages(content)
1929
1930 // Expect two messages: assistant + tool.
1931 require.Len(t, msgs, 2)
1932
1933 // Assistant message should contain:
1934 // 1. provider-executed ToolCallPart
1935 // 2. provider-executed ToolResultPart
1936 // 3. regular ToolCallPart
1937 // 4. TextPart
1938 assistant := msgs[0]
1939 require.Equal(t, MessageRoleAssistant, assistant.Role)
1940 require.Len(t, assistant.Content, 4)
1941
1942 // Verify provider-executed tool call is in assistant.
1943 tc1, ok := AsMessagePart[ToolCallPart](assistant.Content[0])
1944 require.True(t, ok)
1945 require.Equal(t, "srvtoolu_01", tc1.ToolCallID)
1946 require.True(t, tc1.ProviderExecuted)
1947
1948 // Verify provider-executed tool result is in assistant.
1949 tr1, ok := AsMessagePart[ToolResultPart](assistant.Content[1])
1950 require.True(t, ok)
1951 require.Equal(t, "srvtoolu_01", tr1.ToolCallID)
1952 require.True(t, tr1.ProviderExecuted)
1953
1954 // Verify regular tool call is in assistant.
1955 tc2, ok := AsMessagePart[ToolCallPart](assistant.Content[2])
1956 require.True(t, ok)
1957 require.Equal(t, "toolu_02", tc2.ToolCallID)
1958 require.False(t, tc2.ProviderExecuted)
1959
1960 // Verify text part is in assistant.
1961 text, ok := AsMessagePart[TextPart](assistant.Content[3])
1962 require.True(t, ok)
1963 require.Equal(t, "Done.", text.Text)
1964
1965 // Tool message should contain only the regular tool result.
1966 toolMsg := msgs[1]
1967 require.Equal(t, MessageRoleTool, toolMsg.Role)
1968 require.Len(t, toolMsg.Content, 1)
1969
1970 tr2, ok := AsMessagePart[ToolResultPart](toolMsg.Content[0])
1971 require.True(t, ok)
1972 require.Equal(t, "toolu_02", tr2.ToolCallID)
1973 require.False(t, tr2.ProviderExecuted)
1974}
1975
1976// TestAgent_Generate_ExecutableProviderTool verifies that an
1977// ExecutableProviderTool registered via WithProviderDefinedTools is
1978// executed by the agent when the model returns a matching tool call.
1979func TestAgent_Generate_ExecutableProviderTool(t *testing.T) {
1980 t.Parallel()
1981
1982 runCalled := false
1983 execTool := NewExecutableProviderTool(
1984 ProviderDefinedTool{
1985 ID: "test.computer",
1986 Name: "computer",
1987 Args: map[string]any{"display_width_px": 1920},
1988 },
1989 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1990 runCalled = true
1991 return NewTextResponse("screenshot taken"), nil
1992 },
1993 )
1994
1995 model := &mockLanguageModel{
1996 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1997 return &Response{
1998 Content: []Content{
1999 ToolCallContent{
2000 ToolCallID: "call-1",
2001 ToolName: "computer",
2002 Input: `{"action":"screenshot"}`,
2003 },
2004 },
2005 Usage: Usage{TotalTokens: 10},
2006 FinishReason: FinishReasonStop,
2007 }, nil
2008 },
2009 }
2010
2011 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2012 result, err := agent.Generate(context.Background(), AgentCall{
2013 Prompt: "take a screenshot",
2014 })
2015
2016 require.NoError(t, err)
2017 require.NotNil(t, result)
2018 require.True(t, runCalled, "expected Run func to be called")
2019 require.Len(t, result.Steps, 1)
2020
2021 // Verify tool result is in the response.
2022 var toolResults []ToolResultContent
2023 for _, c := range result.Response.Content {
2024 if tr, ok := AsContentType[ToolResultContent](c); ok {
2025 toolResults = append(toolResults, tr)
2026 }
2027 }
2028 require.Len(t, toolResults, 1)
2029 require.Equal(t, "call-1", toolResults[0].ToolCallID)
2030 require.Equal(t, "computer", toolResults[0].ToolName)
2031
2032 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
2033 require.True(t, ok)
2034 require.Equal(t, "screenshot taken", textResult.Text)
2035}
2036
2037// TestAgent_Generate_ExecutableProviderTool_ActiveTools verifies that
2038// active tool filtering works for ExecutableProviderTool.
2039func TestAgent_Generate_ExecutableProviderTool_ActiveTools(t *testing.T) {
2040 t.Parallel()
2041
2042 execTool := NewExecutableProviderTool(
2043 ProviderDefinedTool{
2044 ID: "test.computer",
2045 Name: "computer",
2046 Args: map[string]any{"display_width_px": 1920},
2047 },
2048 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2049 return NewTextResponse("ok"), nil
2050 },
2051 )
2052
2053 model := &mockLanguageModel{
2054 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2055 // With ActiveTools=["other"], computer should be filtered out.
2056 require.Empty(t, call.Tools)
2057
2058 return &Response{
2059 Content: []Content{TextContent{Text: "no tools"}},
2060 Usage: Usage{TotalTokens: 5},
2061 FinishReason: FinishReasonStop,
2062 }, nil
2063 },
2064 }
2065
2066 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2067 result, err := agent.Generate(context.Background(), AgentCall{
2068 Prompt: "test",
2069 ActiveTools: []string{"other"},
2070 })
2071
2072 require.NoError(t, err)
2073 require.NotNil(t, result)
2074}
2075
2076// TestAgent_Generate_ExecutableProviderTool_ActiveTools_Rejected
2077// verifies that a hallucinated tool call for an EPT excluded by
2078// activeTools is rejected at validation and execution time.
2079func TestAgent_Generate_ExecutableProviderTool_ActiveTools_Rejected(t *testing.T) {
2080 t.Parallel()
2081
2082 runCalled := false
2083 execTool := NewExecutableProviderTool(
2084 ProviderDefinedTool{
2085 ID: "test.computer",
2086 Name: "computer",
2087 Args: map[string]any{"display_width_px": 1920},
2088 },
2089 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2090 runCalled = true
2091 return NewTextResponse("ok"), nil
2092 },
2093 )
2094
2095 callCount := 0
2096 model := &mockLanguageModel{
2097 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2098 callCount++
2099 if callCount == 1 {
2100 // Model hallucinates a call to the excluded tool.
2101 return &Response{
2102 Content: []Content{ToolCallContent{
2103 ToolCallID: "call-1",
2104 ToolName: "computer",
2105 Input: `{"action":"screenshot"}`,
2106 }},
2107 Usage: Usage{TotalTokens: 5},
2108 FinishReason: FinishReasonToolCalls,
2109 }, nil
2110 }
2111 // Second call: model stops.
2112 return &Response{
2113 Content: []Content{TextContent{Text: "done"}},
2114 Usage: Usage{TotalTokens: 3},
2115 FinishReason: FinishReasonStop,
2116 }, nil
2117 },
2118 }
2119
2120 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2121 result, err := agent.Generate(context.Background(), AgentCall{
2122 Prompt: "test",
2123 ActiveTools: []string{"other"},
2124 })
2125
2126 require.NoError(t, err)
2127 require.NotNil(t, result)
2128 require.False(t, runCalled, "excluded EPT should not have been executed")
2129
2130 // The tool call should have been marked invalid.
2131 var foundInvalidToolResult bool
2132 for _, step := range result.Steps {
2133 for _, content := range step.Content {
2134 if tr, ok := AsContentType[ToolResultContent](content); ok {
2135 if errResult, ok := tr.Result.(ToolResultOutputContentError); ok {
2136 require.Contains(t, errResult.Error.Error(), "tool not found")
2137 foundInvalidToolResult = true
2138 }
2139 }
2140 }
2141 }
2142 require.True(t, foundInvalidToolResult, "expected an error result for the excluded tool call")
2143}
2144
2145// TestAgent_Stream_ExecutableProviderTool verifies that an
2146// ExecutableProviderTool works through the Stream path.
2147func TestAgent_Stream_ExecutableProviderTool(t *testing.T) {
2148 t.Parallel()
2149
2150 runCalled := false
2151 execTool := NewExecutableProviderTool(
2152 ProviderDefinedTool{
2153 ID: "test.computer",
2154 Name: "computer",
2155 Args: map[string]any{"display_width_px": 1920},
2156 },
2157 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2158 runCalled = true
2159 return NewTextResponse("screenshot taken"), nil
2160 },
2161 )
2162
2163 model := &mockLanguageModel{
2164 streamFunc: func(ctx context.Context, call Call) (StreamResponse, error) {
2165 return func(yield func(StreamPart) bool) {
2166 if !yield(StreamPart{
2167 Type: StreamPartTypeToolCall,
2168 ID: "call-1",
2169 ToolCallName: "computer",
2170 ToolCallInput: `{"action":"screenshot"}`,
2171 }) {
2172 return
2173 }
2174 yield(StreamPart{
2175 Type: StreamPartTypeFinish,
2176 FinishReason: FinishReasonStop,
2177 Usage: Usage{TotalTokens: 10},
2178 })
2179 }, nil
2180 },
2181 }
2182
2183 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2184 result, err := agent.Stream(context.Background(), AgentStreamCall{
2185 Prompt: "take a screenshot",
2186 })
2187
2188 require.NoError(t, err)
2189 require.NotNil(t, result)
2190 require.True(t, runCalled, "expected Run func to be called")
2191 require.Len(t, result.Steps, 1)
2192
2193 // Verify tool result is in the step content.
2194 var toolResults []ToolResultContent
2195 for _, c := range result.Steps[0].Content {
2196 if tr, ok := AsContentType[ToolResultContent](c); ok {
2197 toolResults = append(toolResults, tr)
2198 }
2199 }
2200 require.Len(t, toolResults, 1)
2201 require.Equal(t, "call-1", toolResults[0].ToolCallID)
2202}
2203
2204// TestAgent_PrepareTools_ExecutableProviderTool verifies that
2205// prepareTools emits a ProviderDefinedTool (not a FunctionTool) when
2206// an ExecutableProviderTool is registered via WithProviderDefinedTools.
2207func TestAgent_PrepareTools_ExecutableProviderTool(t *testing.T) {
2208 t.Parallel()
2209
2210 execTool := NewExecutableProviderTool(
2211 ProviderDefinedTool{
2212 ID: "test.computer",
2213 Name: "computer",
2214 Args: map[string]any{"display_width_px": 1920},
2215 },
2216 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2217 return NewTextResponse("ok"), nil
2218 },
2219 )
2220
2221 model := &mockLanguageModel{
2222 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2223 // Verify the tool is emitted as a ProviderDefinedTool.
2224 require.Len(t, call.Tools, 1)
2225 pdt, ok := call.Tools[0].(ProviderDefinedTool)
2226 require.True(t, ok, "expected ProviderDefinedTool, got %T", call.Tools[0])
2227 require.Equal(t, "computer", pdt.Name)
2228 require.Equal(t, "test.computer", pdt.ID)
2229
2230 return &Response{
2231 Content: []Content{TextContent{Text: "done"}},
2232 Usage: Usage{TotalTokens: 5},
2233 FinishReason: FinishReasonStop,
2234 }, nil
2235 },
2236 }
2237
2238 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2239 _, err := agent.Generate(context.Background(), AgentCall{
2240 Prompt: "test",
2241 })
2242 require.NoError(t, err)
2243}
2244
2245// TestAgent_ValidateToolCall_ExecutableProviderTool verifies that
2246// schema validation is skipped for executable provider tools, but
2247// JSON parsing is still checked.
2248func TestAgent_ValidateToolCall_ExecutableProviderTool(t *testing.T) {
2249 t.Parallel()
2250
2251 execTool := NewExecutableProviderTool(
2252 ProviderDefinedTool{
2253 ID: "test.computer",
2254 Name: "computer",
2255 },
2256 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2257 return NewTextResponse("ok"), nil
2258 },
2259 )
2260
2261 a := &agent{
2262 settings: agentSettings{
2263 executableProviderTools: []ExecutableProviderTool{execTool},
2264 },
2265 }
2266
2267 // Valid JSON should pass even without required fields.
2268 err := a.validateToolCall(ToolCallContent{
2269 ToolName: "computer",
2270 Input: `{"action":"screenshot"}`,
2271 }, []AgentTool{}, []ExecutableProviderTool{execTool})
2272 require.NoError(t, err)
2273
2274 // Invalid JSON should still fail.
2275 err = a.validateToolCall(ToolCallContent{
2276 ToolName: "computer",
2277 Input: `not-json`,
2278 }, []AgentTool{}, []ExecutableProviderTool{execTool})
2279 require.Error(t, err)
2280 require.Contains(t, err.Error(), "invalid JSON")
2281}
2282
2283// TestAgent_WithProviderDefinedTools_BackwardCompat verifies that
2284// passing a plain ProviderDefinedTool to WithProviderDefinedTools
2285// still works (web search path).
2286func TestAgent_WithProviderDefinedTools_BackwardCompat(t *testing.T) {
2287 t.Parallel()
2288
2289 webSearch := ProviderDefinedTool{
2290 ID: "anthropic.web_search",
2291 Name: "web_search",
2292 Args: map[string]any{"max_results": 5},
2293 }
2294
2295 model := &mockLanguageModel{
2296 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2297 require.Len(t, call.Tools, 1)
2298 pdt, ok := call.Tools[0].(ProviderDefinedTool)
2299 require.True(t, ok, "expected ProviderDefinedTool, got %T", call.Tools[0])
2300 require.Equal(t, "web_search", pdt.Name)
2301 require.Equal(t, "anthropic.web_search", pdt.ID)
2302
2303 return &Response{
2304 Content: []Content{TextContent{Text: "search results"}},
2305 Usage: Usage{TotalTokens: 5},
2306 FinishReason: FinishReasonStop,
2307 }, nil
2308 },
2309 }
2310
2311 agent := NewAgent(model, WithProviderDefinedTools(webSearch))
2312 result, err := agent.Generate(context.Background(), AgentCall{
2313 Prompt: "search for something",
2314 })
2315
2316 require.NoError(t, err)
2317 require.NotNil(t, result)
2318 require.Equal(t, "search results", result.Response.Content.Text())
2319}
2320
2321// TestAgent_Generate_ExecutableProviderTool_ImageBase64 verifies that
2322// image data returned by an ExecutableProviderTool's run function is
2323// base64-encoded when stored in ToolResultOutputContentMedia.Data.
2324func TestAgent_Generate_ExecutableProviderTool_ImageBase64(t *testing.T) {
2325 t.Parallel()
2326
2327 rawPNG := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
2328
2329 execTool := NewExecutableProviderTool(
2330 ProviderDefinedTool{
2331 ID: "test.computer",
2332 Name: "computer",
2333 Args: map[string]any{"display_width_px": 1920},
2334 },
2335 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2336 return NewImageResponse(rawPNG, "image/png"), nil
2337 },
2338 )
2339
2340 callCount := 0
2341 model := &mockLanguageModel{
2342 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2343 callCount++
2344 if callCount == 1 {
2345 return &Response{
2346 Content: []Content{
2347 ToolCallContent{
2348 ToolCallID: "call-1",
2349 ToolName: "computer",
2350 Input: `{"action":"screenshot"}`,
2351 },
2352 },
2353 Usage: Usage{TotalTokens: 10},
2354 FinishReason: FinishReasonToolCalls,
2355 }, nil
2356 }
2357 return &Response{
2358 Content: []Content{TextContent{Text: "done"}},
2359 Usage: Usage{TotalTokens: 5},
2360 FinishReason: FinishReasonStop,
2361 }, nil
2362 },
2363 }
2364
2365 agent := NewAgent(model, WithProviderDefinedTools(execTool))
2366 result, err := agent.Generate(context.Background(), AgentCall{
2367 Prompt: "take a screenshot",
2368 })
2369
2370 require.NoError(t, err)
2371 require.NotNil(t, result)
2372 require.Len(t, result.Steps, 2)
2373
2374 // The tool result in the first step must have base64-encoded data.
2375 toolResults := result.Steps[0].Content.ToolResults()
2376 require.Len(t, toolResults, 1)
2377
2378 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
2379 require.True(t, ok, "expected media result")
2380 require.Equal(t, base64.StdEncoding.EncodeToString(rawPNG), mediaResult.Data)
2381 require.Equal(t, "image/png", mediaResult.MediaType)
2382}
2383
2384// TestAgent_Generate_ExecutableProviderTool_CriticalError verifies
2385// that a Go error returned from an ExecutableProviderTool's run
2386// function is treated as a critical error, stopping the agent loop.
2387func TestAgent_Generate_ExecutableProviderTool_CriticalError(t *testing.T) {
2388 t.Parallel()
2389
2390 execTool := NewExecutableProviderTool(
2391 ProviderDefinedTool{
2392 ID: "test.computer",
2393 Name: "computer",
2394 Args: map[string]any{"display_width_px": 1920},
2395 },
2396 func(ctx context.Context, call ToolCall) (ToolResponse, error) {
2397 return ToolResponse{}, fmt.Errorf("vnc connection lost")
2398 },
2399 )
2400
2401 callCount := 0
2402 model := &mockLanguageModel{
2403 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
2404 callCount++
2405 return &Response{
2406 Content: []Content{
2407 ToolCallContent{
2408 ToolCallID: "call-1",
2409 ToolName: "computer",
2410 Input: `{"action":"screenshot"}`,
2411 },
2412 },
2413 Usage: Usage{TotalTokens: 10},
2414 FinishReason: FinishReasonToolCalls,
2415 }, nil
2416 },
2417 }
2418
2419 agent := NewAgent(model, WithProviderDefinedTools(execTool), WithStopConditions(StepCountIs(5)))
2420 result, err := agent.Generate(context.Background(), AgentCall{
2421 Prompt: "take a screenshot",
2422 })
2423
2424 require.NoError(t, err)
2425 require.NotNil(t, result)
2426 // The model should only be called once — the critical error stops
2427 // the loop before a second model call.
2428 require.Equal(t, 1, callCount)
2429 require.Len(t, result.Steps, 1)
2430}