1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 providerOptions ProviderOptions
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
21}
22
23func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
24 m.providerOptions = opts
25}
26
27func (m *mockTool) ProviderOptions() ProviderOptions {
28 return m.providerOptions
29}
30
31func (m *mockTool) Info() ToolInfo {
32 return ToolInfo{
33 Name: m.name,
34 Description: m.description,
35 Parameters: m.parameters,
36 Required: m.required,
37 }
38}
39
40func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
41 if m.executeFunc != nil {
42 return m.executeFunc(ctx, call)
43 }
44 return ToolResponse{Content: "mock result", IsError: false}, nil
45}
46
47// Mock language model for testing
48type mockLanguageModel struct {
49 generateFunc func(ctx context.Context, call Call) (*Response, error)
50 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
51}
52
53func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
54 if m.generateFunc != nil {
55 return m.generateFunc(ctx, call)
56 }
57 return &Response{
58 Content: []Content{
59 TextContent{Text: "Hello, world!"},
60 },
61 Usage: Usage{
62 InputTokens: 3,
63 OutputTokens: 10,
64 TotalTokens: 13,
65 },
66 FinishReason: FinishReasonStop,
67 }, nil
68}
69
70func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
71 if m.streamFunc != nil {
72 return m.streamFunc(ctx, call)
73 }
74 return nil, fmt.Errorf("mock stream not implemented")
75}
76
77func (m *mockLanguageModel) Provider() string {
78 return "mock-provider"
79}
80
81func (m *mockLanguageModel) Model() string {
82 return "mock-model"
83}
84
85func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
86 return nil, fmt.Errorf("mock GenerateObject not implemented")
87}
88
89func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
90 return nil, fmt.Errorf("mock StreamObject not implemented")
91}
92
93// Test result.content - comprehensive content types (matches TS test)
94func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
95 t.Parallel()
96
97 // Create a type-safe tool using the new API
98 type TestInput struct {
99 Value string `json:"value" description:"Test value"`
100 }
101
102 tool1 := NewAgentTool(
103 "tool1",
104 "Test tool",
105 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
106 require.Equal(t, "value", input.Value)
107 return ToolResponse{Content: "result1", IsError: false}, nil
108 },
109 )
110
111 model := &mockLanguageModel{
112 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
113 return &Response{
114 Content: []Content{
115 TextContent{Text: "Hello, world!"},
116 SourceContent{
117 ID: "123",
118 URL: "https://example.com",
119 Title: "Example",
120 SourceType: SourceTypeURL,
121 },
122 FileContent{
123 Data: []byte{1, 2, 3},
124 MediaType: "image/png",
125 },
126 ReasoningContent{
127 Text: "I will open the conversation with witty banter.",
128 },
129 ToolCallContent{
130 ToolCallID: "call-1",
131 ToolName: "tool1",
132 Input: `{"value":"value"}`,
133 },
134 TextContent{Text: "More text"},
135 },
136 Usage: Usage{
137 InputTokens: 3,
138 OutputTokens: 10,
139 TotalTokens: 13,
140 },
141 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
142 }, nil
143 },
144 }
145
146 agent := NewAgent(model, WithTools(tool1))
147 result, err := agent.Generate(context.Background(), AgentCall{
148 Prompt: "prompt",
149 })
150
151 require.NoError(t, err)
152 require.NotNil(t, result)
153 require.Len(t, result.Steps, 1) // Single step like TypeScript
154
155 // Check final response content includes tool result
156 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
157
158 // Verify each content type in order
159 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
160 require.True(t, ok)
161 require.Equal(t, "Hello, world!", textContent.Text)
162
163 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
164 require.True(t, ok)
165 require.Equal(t, "123", sourceContent.ID)
166
167 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
168 require.True(t, ok)
169 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
170
171 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
172 require.True(t, ok)
173 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
174
175 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
176 require.True(t, ok)
177 require.Equal(t, "call-1", toolCallContent.ToolCallID)
178
179 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
180 require.True(t, ok)
181 require.Equal(t, "More text", moreTextContent.Text)
182
183 // Tool result should be appended
184 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
185 require.True(t, ok)
186 require.Equal(t, "call-1", toolResultContent.ToolCallID)
187 require.Equal(t, "tool1", toolResultContent.ToolName)
188}
189
190// Test result.text extraction
191func TestAgent_Generate_ResultText(t *testing.T) {
192 t.Parallel()
193
194 model := &mockLanguageModel{
195 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
196 return &Response{
197 Content: []Content{
198 TextContent{Text: "Hello, world!"},
199 },
200 Usage: Usage{
201 InputTokens: 3,
202 OutputTokens: 10,
203 TotalTokens: 13,
204 },
205 FinishReason: FinishReasonStop,
206 }, nil
207 },
208 }
209
210 agent := NewAgent(model)
211 result, err := agent.Generate(context.Background(), AgentCall{
212 Prompt: "prompt",
213 })
214
215 require.NoError(t, err)
216 require.NotNil(t, result)
217
218 // Test text extraction from content
219 text := result.Response.Content.Text()
220 require.Equal(t, "Hello, world!", text)
221}
222
223// Test result.toolCalls extraction (matches TS test exactly)
224func TestAgent_Generate_ResultToolCalls(t *testing.T) {
225 t.Parallel()
226
227 // Create type-safe tools using the new API
228 type Tool1Input struct {
229 Value string `json:"value" description:"Test value"`
230 }
231
232 type Tool2Input struct {
233 SomethingElse string `json:"somethingElse" description:"Another test value"`
234 }
235
236 tool1 := NewAgentTool(
237 "tool1",
238 "Test tool 1",
239 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
240 return ToolResponse{Content: "result1", IsError: false}, nil
241 },
242 )
243
244 tool2 := NewAgentTool(
245 "tool2",
246 "Test tool 2",
247 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
248 return ToolResponse{Content: "result2", IsError: false}, nil
249 },
250 )
251
252 model := &mockLanguageModel{
253 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
254 // Verify tools are passed correctly
255 require.Len(t, call.Tools, 2)
256 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
257
258 // Verify prompt structure
259 require.Len(t, call.Prompt, 1)
260 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
261
262 return &Response{
263 Content: []Content{
264 ToolCallContent{
265 ToolCallID: "call-1",
266 ToolName: "tool1",
267 Input: `{"value":"value"}`,
268 },
269 },
270 Usage: Usage{
271 InputTokens: 3,
272 OutputTokens: 10,
273 TotalTokens: 13,
274 },
275 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
276 }, nil
277 },
278 }
279
280 agent := NewAgent(model, WithTools(tool1, tool2))
281 result, err := agent.Generate(context.Background(), AgentCall{
282 Prompt: "test-input",
283 })
284
285 require.NoError(t, err)
286 require.NotNil(t, result)
287 require.Len(t, result.Steps, 1) // Single step
288
289 // Extract tool calls from final response (should be empty since tools don't execute)
290 var toolCalls []ToolCallContent
291 for _, content := range result.Response.Content {
292 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
293 toolCalls = append(toolCalls, toolCall)
294 }
295 }
296
297 require.Len(t, toolCalls, 1)
298 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
299 require.Equal(t, "tool1", toolCalls[0].ToolName)
300
301 // Parse and verify input
302 var input map[string]any
303 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
304 require.NoError(t, err)
305 require.Equal(t, "value", input["value"])
306}
307
308// Test result.toolResults extraction (matches TS test exactly)
309func TestAgent_Generate_ResultToolResults(t *testing.T) {
310 t.Parallel()
311
312 // Create type-safe tool using the new API
313 type TestInput struct {
314 Value string `json:"value" description:"Test value"`
315 }
316
317 tool1 := NewAgentTool(
318 "tool1",
319 "Test tool",
320 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
321 require.Equal(t, "value", input.Value)
322 return ToolResponse{Content: "result1", IsError: false}, nil
323 },
324 )
325
326 model := &mockLanguageModel{
327 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
328 // Verify tools and tool choice
329 require.Len(t, call.Tools, 1)
330 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
331
332 // Verify prompt
333 require.Len(t, call.Prompt, 1)
334 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
335
336 return &Response{
337 Content: []Content{
338 ToolCallContent{
339 ToolCallID: "call-1",
340 ToolName: "tool1",
341 Input: `{"value":"value"}`,
342 },
343 },
344 Usage: Usage{
345 InputTokens: 3,
346 OutputTokens: 10,
347 TotalTokens: 13,
348 },
349 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
350 }, nil
351 },
352 }
353
354 agent := NewAgent(model, WithTools(tool1))
355 result, err := agent.Generate(context.Background(), AgentCall{
356 Prompt: "test-input",
357 })
358
359 require.NoError(t, err)
360 require.NotNil(t, result)
361 require.Len(t, result.Steps, 1) // Single step
362
363 // Extract tool results from final response
364 var toolResults []ToolResultContent
365 for _, content := range result.Response.Content {
366 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
367 toolResults = append(toolResults, toolResult)
368 }
369 }
370
371 require.Len(t, toolResults, 1)
372 require.Equal(t, "call-1", toolResults[0].ToolCallID)
373 require.Equal(t, "tool1", toolResults[0].ToolName)
374
375 // Verify result content
376 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
377 require.True(t, ok)
378 require.Equal(t, "result1", textResult.Text)
379}
380
381// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
382func TestAgent_Generate_MultipleSteps(t *testing.T) {
383 t.Parallel()
384
385 // Create type-safe tool using the new API
386 type TestInput struct {
387 Value string `json:"value" description:"Test value"`
388 }
389
390 tool1 := NewAgentTool(
391 "tool1",
392 "Test tool",
393 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
394 require.Equal(t, "value", input.Value)
395 return ToolResponse{Content: "result1", IsError: false}, nil
396 },
397 )
398
399 callCount := 0
400 model := &mockLanguageModel{
401 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
402 callCount++
403 switch callCount {
404 case 1:
405 // First call - return tool call with FinishReasonToolCalls
406 return &Response{
407 Content: []Content{
408 ToolCallContent{
409 ToolCallID: "call-1",
410 ToolName: "tool1",
411 Input: `{"value":"value"}`,
412 },
413 },
414 Usage: Usage{
415 InputTokens: 10,
416 OutputTokens: 5,
417 TotalTokens: 15,
418 },
419 FinishReason: FinishReasonToolCalls, // This triggers multi-step
420 }, nil
421 case 2:
422 // Second call - return final text
423 return &Response{
424 Content: []Content{
425 TextContent{Text: "Hello, world!"},
426 },
427 Usage: Usage{
428 InputTokens: 3,
429 OutputTokens: 10,
430 TotalTokens: 13,
431 },
432 FinishReason: FinishReasonStop,
433 }, nil
434 default:
435 t.Fatalf("Unexpected call count: %d", callCount)
436 return nil, nil
437 }
438 },
439 }
440
441 agent := NewAgent(model, WithTools(tool1))
442 result, err := agent.Generate(context.Background(), AgentCall{
443 Prompt: "test-input",
444 })
445
446 require.NoError(t, err)
447 require.NotNil(t, result)
448 require.Len(t, result.Steps, 2)
449
450 // Check total usage sums both steps
451 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
452 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
453 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
454
455 // Final response should be from last step
456 require.Len(t, result.Response.Content, 1)
457 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
458 require.True(t, ok)
459 require.Equal(t, "Hello, world!", textContent.Text)
460
461 // result.toolCalls should be empty (from last step)
462 var toolCalls []ToolCallContent
463 for _, content := range result.Response.Content {
464 if _, ok := AsContentType[ToolCallContent](content); ok {
465 toolCalls = append(toolCalls, content.(ToolCallContent))
466 }
467 }
468 require.Len(t, toolCalls, 0)
469
470 // result.toolResults should be empty (from last step)
471 var toolResults []ToolResultContent
472 for _, content := range result.Response.Content {
473 if _, ok := AsContentType[ToolResultContent](content); ok {
474 toolResults = append(toolResults, content.(ToolResultContent))
475 }
476 }
477 require.Len(t, toolResults, 0)
478}
479
480// Test basic text generation
481func TestAgent_Generate_BasicText(t *testing.T) {
482 t.Parallel()
483
484 model := &mockLanguageModel{
485 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
486 return &Response{
487 Content: []Content{
488 TextContent{Text: "Hello, world!"},
489 },
490 Usage: Usage{
491 InputTokens: 3,
492 OutputTokens: 10,
493 TotalTokens: 13,
494 },
495 FinishReason: FinishReasonStop,
496 }, nil
497 },
498 }
499
500 agent := NewAgent(model)
501 result, err := agent.Generate(context.Background(), AgentCall{
502 Prompt: "test prompt",
503 })
504
505 require.NoError(t, err)
506 require.NotNil(t, result)
507 require.Len(t, result.Steps, 1)
508
509 // Check final response
510 require.Len(t, result.Response.Content, 1)
511 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
512 require.True(t, ok)
513 require.Equal(t, "Hello, world!", textContent.Text)
514
515 // Check usage
516 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
517 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
518 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
519
520 // Check total usage
521 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
522 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
523 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
524}
525
526// Test empty prompt validation
527func TestAgent_Generate_EmptyPrompt(t *testing.T) {
528 t.Parallel()
529
530 model := &mockLanguageModel{}
531 agent := NewAgent(model)
532
533 t.Run("fails without messages", func(t *testing.T) {
534 result, err := agent.Generate(context.Background(), AgentCall{
535 Prompt: "",
536 })
537 require.Error(t, err)
538 require.Nil(t, result)
539 require.Contains(t, err.Error(), "prompt can't be empty when there are no messages")
540 })
541
542 t.Run("fails with files even if messages exist", func(t *testing.T) {
543 result, err := agent.Generate(context.Background(), AgentCall{
544 Prompt: "",
545 Messages: []Message{
546 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
547 },
548 Files: []FilePart{{Filename: "test.txt", Data: []byte("test"), MediaType: "text/plain"}},
549 })
550 require.Error(t, err)
551 require.Nil(t, result)
552 require.Contains(t, err.Error(), "prompt can't be empty when there are files")
553 })
554
555 t.Run("fails when last message is assistant", func(t *testing.T) {
556 result, err := agent.Generate(context.Background(), AgentCall{
557 Prompt: "",
558 Messages: []Message{
559 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
560 {Role: MessageRoleAssistant, Content: []MessagePart{TextPart{Text: "hi there"}}},
561 },
562 })
563 require.Error(t, err)
564 require.Nil(t, result)
565 require.Contains(t, err.Error(), "prompt can't be empty when the last message is not a user or tool message")
566 })
567
568 t.Run("succeeds when last message is user", func(t *testing.T) {
569 model := &mockLanguageModel{
570 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
571 return &Response{
572 Content: []Content{TextContent{Text: "response"}},
573 FinishReason: FinishReasonStop,
574 }, nil
575 },
576 }
577 agent := NewAgent(model)
578
579 result, err := agent.Generate(context.Background(), AgentCall{
580 Prompt: "",
581 Messages: []Message{
582 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
583 },
584 })
585 require.NoError(t, err)
586 require.NotNil(t, result)
587 })
588
589 t.Run("succeeds when last message is tool", func(t *testing.T) {
590 model := &mockLanguageModel{
591 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
592 return &Response{
593 Content: []Content{TextContent{Text: "response"}},
594 FinishReason: FinishReasonStop,
595 }, nil
596 },
597 }
598 agent := NewAgent(model)
599
600 result, err := agent.Generate(context.Background(), AgentCall{
601 Prompt: "",
602 Messages: []Message{
603 {Role: MessageRoleUser, Content: []MessagePart{TextPart{Text: "hello"}}},
604 {Role: MessageRoleAssistant, Content: []MessagePart{ToolCallPart{ToolCallID: "call_1", ToolName: "test"}}},
605 {Role: MessageRoleTool, Content: []MessagePart{ToolResultPart{ToolCallID: "call_1", Output: ToolResultOutputContentText{Text: "result"}}}},
606 },
607 })
608 require.NoError(t, err)
609 require.NotNil(t, result)
610 })
611}
612
613// Test with system prompt
614func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
615 t.Parallel()
616
617 model := &mockLanguageModel{
618 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
619 // Verify system message is included
620 require.Len(t, call.Prompt, 2) // system + user
621 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
622 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
623
624 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
625 require.True(t, ok)
626 require.Equal(t, "You are a helpful assistant", systemPart.Text)
627
628 return &Response{
629 Content: []Content{
630 TextContent{Text: "Hello, world!"},
631 },
632 Usage: Usage{
633 InputTokens: 3,
634 OutputTokens: 10,
635 TotalTokens: 13,
636 },
637 FinishReason: FinishReasonStop,
638 }, nil
639 },
640 }
641
642 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
643 result, err := agent.Generate(context.Background(), AgentCall{
644 Prompt: "test prompt",
645 })
646
647 require.NoError(t, err)
648 require.NotNil(t, result)
649}
650
651// Test options.activeTools filtering
652func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
653 t.Parallel()
654
655 tool1 := &mockTool{
656 name: "tool1",
657 description: "Test tool 1",
658 parameters: map[string]any{
659 "value": map[string]any{"type": "string"},
660 },
661 required: []string{"value"},
662 }
663
664 tool2 := &mockTool{
665 name: "tool2",
666 description: "Test tool 2",
667 parameters: map[string]any{
668 "value": map[string]any{"type": "string"},
669 },
670 required: []string{"value"},
671 }
672
673 model := &mockLanguageModel{
674 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
675 // Verify only tool1 is available
676 require.Len(t, call.Tools, 1)
677 functionTool, ok := call.Tools[0].(FunctionTool)
678 require.True(t, ok)
679 require.Equal(t, "tool1", functionTool.Name)
680
681 return &Response{
682 Content: []Content{
683 TextContent{Text: "Hello, world!"},
684 },
685 Usage: Usage{
686 InputTokens: 3,
687 OutputTokens: 10,
688 TotalTokens: 13,
689 },
690 FinishReason: FinishReasonStop,
691 }, nil
692 },
693 }
694
695 agent := NewAgent(model, WithTools(tool1, tool2))
696 result, err := agent.Generate(context.Background(), AgentCall{
697 Prompt: "test-input",
698 ActiveTools: []string{"tool1"}, // Only tool1 should be active
699 })
700
701 require.NoError(t, err)
702 require.NotNil(t, result)
703}
704
705func TestAgent_Generate_OptionsActiveTools_WithProviderDefinedTools(t *testing.T) {
706 t.Parallel()
707
708 tool1 := &mockTool{
709 name: "tool1",
710 description: "Test tool 1",
711 parameters: map[string]any{
712 "value": map[string]any{"type": "string"},
713 },
714 required: []string{"value"},
715 }
716
717 providerTool1 := ProviderDefinedTool{ID: "provider.web_search", Name: "web_search"}
718 providerTool2 := ProviderDefinedTool{ID: "provider.code_execution", Name: "code_execution"}
719
720 model := &mockLanguageModel{
721 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
722 require.Len(t, call.Tools, 2)
723
724 functionTool, ok := call.Tools[0].(FunctionTool)
725 require.True(t, ok)
726 require.Equal(t, "tool1", functionTool.Name)
727
728 providerTool, ok := call.Tools[1].(ProviderDefinedTool)
729 require.True(t, ok)
730 require.Equal(t, "web_search", providerTool.Name)
731
732 return &Response{
733 Content: []Content{
734 TextContent{Text: "Hello, world!"},
735 },
736 Usage: Usage{
737 InputTokens: 3,
738 OutputTokens: 10,
739 TotalTokens: 13,
740 },
741 FinishReason: FinishReasonStop,
742 }, nil
743 },
744 }
745
746 agent := NewAgent(model, WithTools(tool1), WithProviderDefinedTools(providerTool1, providerTool2))
747 result, err := agent.Generate(context.Background(), AgentCall{
748 Prompt: "test-input",
749 ActiveTools: []string{"tool1", "web_search"}, // Only tool1 and web_search should be active
750 })
751
752 require.NoError(t, err)
753 require.NotNil(t, result)
754}
755
756func TestResponseContent_Getters(t *testing.T) {
757 t.Parallel()
758
759 // Create test content with all types
760 content := ResponseContent{
761 TextContent{Text: "Hello world"},
762 ReasoningContent{Text: "Let me think..."},
763 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
764 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
765 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
766 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
767 }
768
769 // Test Text()
770 require.Equal(t, "Hello world", content.Text())
771
772 // Test Reasoning()
773 reasoning := content.Reasoning()
774 require.Len(t, reasoning, 1)
775 require.Equal(t, "Let me think...", reasoning[0].Text)
776
777 // Test ReasoningText()
778 require.Equal(t, "Let me think...", content.ReasoningText())
779
780 // Test Files()
781 files := content.Files()
782 require.Len(t, files, 1)
783 require.Equal(t, "text/plain", files[0].MediaType)
784 require.Equal(t, []byte("file data"), files[0].Data)
785
786 // Test Sources()
787 sources := content.Sources()
788 require.Len(t, sources, 1)
789 require.Equal(t, SourceTypeURL, sources[0].SourceType)
790 require.Equal(t, "https://example.com", sources[0].URL)
791 require.Equal(t, "Example", sources[0].Title)
792
793 // Test ToolCalls()
794 toolCalls := content.ToolCalls()
795 require.Len(t, toolCalls, 1)
796 require.Equal(t, "call1", toolCalls[0].ToolCallID)
797 require.Equal(t, "test_tool", toolCalls[0].ToolName)
798 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
799
800 // Test ToolResults()
801 toolResults := content.ToolResults()
802 require.Len(t, toolResults, 1)
803 require.Equal(t, "call1", toolResults[0].ToolCallID)
804 require.Equal(t, "test_tool", toolResults[0].ToolName)
805 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
806 require.True(t, ok)
807 require.Equal(t, "result", result.Text)
808}
809
810func TestResponseContent_Getters_Empty(t *testing.T) {
811 t.Parallel()
812
813 // Test with empty content
814 content := ResponseContent{}
815
816 require.Equal(t, "", content.Text())
817 require.Equal(t, "", content.ReasoningText())
818 require.Empty(t, content.Reasoning())
819 require.Empty(t, content.Files())
820 require.Empty(t, content.Sources())
821 require.Empty(t, content.ToolCalls())
822 require.Empty(t, content.ToolResults())
823}
824
825func TestResponseContent_Getters_MultipleItems(t *testing.T) {
826 t.Parallel()
827
828 // Test with multiple items of same type
829 content := ResponseContent{
830 ReasoningContent{Text: "First thought"},
831 ReasoningContent{Text: "Second thought"},
832 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
833 FileContent{Data: []byte("file2"), MediaType: "image/png"},
834 }
835
836 // Test multiple reasoning
837 reasoning := content.Reasoning()
838 require.Len(t, reasoning, 2)
839 require.Equal(t, "First thought", reasoning[0].Text)
840 require.Equal(t, "Second thought", reasoning[1].Text)
841
842 // Test concatenated reasoning text
843 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
844
845 // Test multiple files
846 files := content.Files()
847 require.Len(t, files, 2)
848 require.Equal(t, "text/plain", files[0].MediaType)
849 require.Equal(t, "image/png", files[1].MediaType)
850}
851
852func TestStopConditions(t *testing.T) {
853 t.Parallel()
854
855 // Create test steps
856 step1 := StepResult{
857 Response: Response{
858 Content: ResponseContent{
859 TextContent{Text: "Hello"},
860 },
861 FinishReason: FinishReasonToolCalls,
862 Usage: Usage{TotalTokens: 10},
863 },
864 }
865
866 step2 := StepResult{
867 Response: Response{
868 Content: ResponseContent{
869 TextContent{Text: "World"},
870 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
871 },
872 FinishReason: FinishReasonStop,
873 Usage: Usage{TotalTokens: 15},
874 },
875 }
876
877 step3 := StepResult{
878 Response: Response{
879 Content: ResponseContent{
880 ReasoningContent{Text: "Let me think..."},
881 FileContent{Data: []byte("data"), MediaType: "text/plain"},
882 },
883 FinishReason: FinishReasonLength,
884 Usage: Usage{TotalTokens: 20},
885 },
886 }
887
888 t.Run("StepCountIs", func(t *testing.T) {
889 t.Parallel()
890 condition := StepCountIs(2)
891
892 // Should not stop with 1 step
893 require.False(t, condition([]StepResult{step1}))
894
895 // Should stop with 2 steps
896 require.True(t, condition([]StepResult{step1, step2}))
897
898 // Should stop with more than 2 steps
899 require.True(t, condition([]StepResult{step1, step2, step3}))
900
901 // Should not stop with empty steps
902 require.False(t, condition([]StepResult{}))
903 })
904
905 t.Run("HasToolCall", func(t *testing.T) {
906 t.Parallel()
907 condition := HasToolCall("search")
908
909 // Should not stop when tool not called
910 require.False(t, condition([]StepResult{step1}))
911
912 // Should stop when tool is called in last step
913 require.True(t, condition([]StepResult{step1, step2}))
914
915 // Should not stop when tool called in earlier step but not last
916 require.False(t, condition([]StepResult{step1, step2, step3}))
917
918 // Should not stop with empty steps
919 require.False(t, condition([]StepResult{}))
920
921 // Should not stop when different tool is called
922 differentToolCondition := HasToolCall("different_tool")
923 require.False(t, differentToolCondition([]StepResult{step1, step2}))
924 })
925
926 t.Run("HasContent", func(t *testing.T) {
927 t.Parallel()
928 reasoningCondition := HasContent(ContentTypeReasoning)
929 fileCondition := HasContent(ContentTypeFile)
930
931 // Should not stop when content type not present
932 require.False(t, reasoningCondition([]StepResult{step1, step2}))
933
934 // Should stop when content type is present in last step
935 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
936 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
937
938 // Should not stop with empty steps
939 require.False(t, reasoningCondition([]StepResult{}))
940 })
941
942 t.Run("FinishReasonIs", func(t *testing.T) {
943 t.Parallel()
944 stopCondition := FinishReasonIs(FinishReasonStop)
945 lengthCondition := FinishReasonIs(FinishReasonLength)
946
947 // Should not stop when finish reason doesn't match
948 require.False(t, stopCondition([]StepResult{step1}))
949
950 // Should stop when finish reason matches in last step
951 require.True(t, stopCondition([]StepResult{step1, step2}))
952 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
953
954 // Should not stop with empty steps
955 require.False(t, stopCondition([]StepResult{}))
956 })
957
958 t.Run("MaxTokensUsed", func(t *testing.T) {
959 condition := MaxTokensUsed(30)
960
961 // Should not stop when under limit
962 require.False(t, condition([]StepResult{step1})) // 10 tokens
963 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
964
965 // Should stop when at or over limit
966 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
967
968 // Should not stop with empty steps
969 require.False(t, condition([]StepResult{}))
970 })
971}
972
973func TestStopConditions_Integration(t *testing.T) {
974 t.Parallel()
975
976 t.Run("StepCountIs integration", func(t *testing.T) {
977 t.Parallel()
978 model := &mockLanguageModel{
979 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
980 return &Response{
981 Content: ResponseContent{
982 TextContent{Text: "Mock response"},
983 },
984 Usage: Usage{
985 InputTokens: 3,
986 OutputTokens: 10,
987 TotalTokens: 13,
988 },
989 FinishReason: FinishReasonStop,
990 }, nil
991 },
992 }
993
994 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
995
996 result, err := agent.Generate(context.Background(), AgentCall{
997 Prompt: "test prompt",
998 })
999
1000 require.NoError(t, err)
1001 require.NotNil(t, result)
1002 require.Len(t, result.Steps, 1) // Should stop after 1 step
1003 })
1004
1005 t.Run("Multiple stop conditions", func(t *testing.T) {
1006 t.Parallel()
1007 model := &mockLanguageModel{
1008 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1009 return &Response{
1010 Content: ResponseContent{
1011 TextContent{Text: "Mock response"},
1012 },
1013 Usage: Usage{
1014 InputTokens: 3,
1015 OutputTokens: 10,
1016 TotalTokens: 13,
1017 },
1018 FinishReason: FinishReasonStop,
1019 }, nil
1020 },
1021 }
1022
1023 agent := NewAgent(model, WithStopConditions(
1024 StepCountIs(5), // Stop after 5 steps
1025 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
1026 ))
1027
1028 result, err := agent.Generate(context.Background(), AgentCall{
1029 Prompt: "test prompt",
1030 })
1031
1032 require.NoError(t, err)
1033 require.NotNil(t, result)
1034 // Should stop on first condition met (finish reason stop)
1035 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
1036 })
1037}
1038
1039func TestPrepareStep(t *testing.T) {
1040 t.Parallel()
1041
1042 t.Run("System prompt modification", func(t *testing.T) {
1043 t.Parallel()
1044 var capturedSystemPrompt string
1045 model := &mockLanguageModel{
1046 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1047 // Capture the system message to verify it was modified
1048 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1049 if len(call.Prompt[0].Content) > 0 {
1050 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1051 capturedSystemPrompt = textPart.Text
1052 }
1053 }
1054 }
1055 return &Response{
1056 Content: ResponseContent{
1057 TextContent{Text: "Response"},
1058 },
1059 Usage: Usage{TotalTokens: 10},
1060 FinishReason: FinishReasonStop,
1061 }, nil
1062 },
1063 }
1064
1065 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1066 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
1067 return ctx, PrepareStepResult{
1068 Model: options.Model,
1069 Messages: options.Messages,
1070 System: &newSystem,
1071 }, nil
1072 }
1073
1074 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
1075
1076 result, err := agent.Generate(context.Background(), AgentCall{
1077 Prompt: "test prompt",
1078 PrepareStep: prepareStepFunc,
1079 })
1080
1081 require.NoError(t, err)
1082 require.NotNil(t, result)
1083 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
1084 })
1085
1086 t.Run("Tool choice modification", func(t *testing.T) {
1087 t.Parallel()
1088 var capturedToolChoice *ToolChoice
1089 model := &mockLanguageModel{
1090 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1091 capturedToolChoice = call.ToolChoice
1092 return &Response{
1093 Content: ResponseContent{
1094 TextContent{Text: "Response"},
1095 },
1096 Usage: Usage{TotalTokens: 10},
1097 FinishReason: FinishReasonStop,
1098 }, nil
1099 },
1100 }
1101
1102 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1103 toolChoice := ToolChoiceNone
1104 return ctx, PrepareStepResult{
1105 Model: options.Model,
1106 Messages: options.Messages,
1107 ToolChoice: &toolChoice,
1108 }, nil
1109 }
1110
1111 agent := NewAgent(model)
1112
1113 result, err := agent.Generate(context.Background(), AgentCall{
1114 Prompt: "test prompt",
1115 PrepareStep: prepareStepFunc,
1116 })
1117
1118 require.NoError(t, err)
1119 require.NotNil(t, result)
1120 require.NotNil(t, capturedToolChoice)
1121 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1122 })
1123
1124 t.Run("Active tools modification", func(t *testing.T) {
1125 t.Parallel()
1126 var capturedToolNames []string
1127 model := &mockLanguageModel{
1128 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1129 // Capture tool names to verify active tools were modified
1130 for _, tool := range call.Tools {
1131 capturedToolNames = append(capturedToolNames, tool.GetName())
1132 }
1133 return &Response{
1134 Content: ResponseContent{
1135 TextContent{Text: "Response"},
1136 },
1137 Usage: Usage{TotalTokens: 10},
1138 FinishReason: FinishReasonStop,
1139 }, nil
1140 },
1141 }
1142
1143 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1144 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1145 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1146
1147 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1148 activeTools := []string{"tool2"} // Only tool2 should be active
1149 return ctx, PrepareStepResult{
1150 Model: options.Model,
1151 Messages: options.Messages,
1152 ActiveTools: activeTools,
1153 }, nil
1154 }
1155
1156 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1157
1158 result, err := agent.Generate(context.Background(), AgentCall{
1159 Prompt: "test prompt",
1160 PrepareStep: prepareStepFunc,
1161 })
1162
1163 require.NoError(t, err)
1164 require.NotNil(t, result)
1165 require.Len(t, capturedToolNames, 1)
1166 require.Equal(t, "tool2", capturedToolNames[0])
1167 })
1168
1169 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1170 t.Parallel()
1171 var capturedToolCount int
1172 model := &mockLanguageModel{
1173 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1174 capturedToolCount = len(call.Tools)
1175 return &Response{
1176 Content: ResponseContent{
1177 TextContent{Text: "Response"},
1178 },
1179 Usage: Usage{TotalTokens: 10},
1180 FinishReason: FinishReasonStop,
1181 }, nil
1182 },
1183 }
1184
1185 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1186
1187 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1188 return ctx, PrepareStepResult{
1189 Model: options.Model,
1190 Messages: options.Messages,
1191 DisableAllTools: true, // Disable all tools for this step
1192 }, nil
1193 }
1194
1195 agent := NewAgent(model, WithTools(tool1))
1196
1197 result, err := agent.Generate(context.Background(), AgentCall{
1198 Prompt: "test prompt",
1199 PrepareStep: prepareStepFunc,
1200 })
1201
1202 require.NoError(t, err)
1203 require.NotNil(t, result)
1204 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1205 })
1206
1207 t.Run("All fields modified together", func(t *testing.T) {
1208 t.Parallel()
1209 var capturedSystemPrompt string
1210 var capturedToolChoice *ToolChoice
1211 var capturedToolNames []string
1212
1213 model := &mockLanguageModel{
1214 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1215 // Capture system prompt
1216 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1217 if len(call.Prompt[0].Content) > 0 {
1218 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1219 capturedSystemPrompt = textPart.Text
1220 }
1221 }
1222 }
1223 // Capture tool choice
1224 capturedToolChoice = call.ToolChoice
1225 // Capture tool names
1226 for _, tool := range call.Tools {
1227 capturedToolNames = append(capturedToolNames, tool.GetName())
1228 }
1229 return &Response{
1230 Content: ResponseContent{
1231 TextContent{Text: "Response"},
1232 },
1233 Usage: Usage{TotalTokens: 10},
1234 FinishReason: FinishReasonStop,
1235 }, nil
1236 },
1237 }
1238
1239 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1240 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1241
1242 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1243 newSystem := "Step-specific system"
1244 toolChoice := SpecificToolChoice("tool1")
1245 activeTools := []string{"tool1"}
1246 return ctx, PrepareStepResult{
1247 Model: options.Model,
1248 Messages: options.Messages,
1249 System: &newSystem,
1250 ToolChoice: &toolChoice,
1251 ActiveTools: activeTools,
1252 }, nil
1253 }
1254
1255 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1256
1257 result, err := agent.Generate(context.Background(), AgentCall{
1258 Prompt: "test prompt",
1259 PrepareStep: prepareStepFunc,
1260 })
1261
1262 require.NoError(t, err)
1263 require.NotNil(t, result)
1264 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1265 require.NotNil(t, capturedToolChoice)
1266 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1267 require.Len(t, capturedToolNames, 1)
1268 require.Equal(t, "tool1", capturedToolNames[0])
1269 })
1270
1271 t.Run("Nil fields use parent values", func(t *testing.T) {
1272 t.Parallel()
1273 var capturedSystemPrompt string
1274 var capturedToolChoice *ToolChoice
1275 var capturedToolNames []string
1276
1277 model := &mockLanguageModel{
1278 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1279 // Capture system prompt
1280 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1281 if len(call.Prompt[0].Content) > 0 {
1282 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1283 capturedSystemPrompt = textPart.Text
1284 }
1285 }
1286 }
1287 // Capture tool choice
1288 capturedToolChoice = call.ToolChoice
1289 // Capture tool names
1290 for _, tool := range call.Tools {
1291 capturedToolNames = append(capturedToolNames, tool.GetName())
1292 }
1293 return &Response{
1294 Content: ResponseContent{
1295 TextContent{Text: "Response"},
1296 },
1297 Usage: Usage{TotalTokens: 10},
1298 FinishReason: FinishReasonStop,
1299 }, nil
1300 },
1301 }
1302
1303 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1304
1305 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1306 // All optional fields are nil, should use parent values
1307 return ctx, PrepareStepResult{
1308 Model: options.Model,
1309 Messages: options.Messages,
1310 System: nil, // Use parent
1311 ToolChoice: nil, // Use parent (auto)
1312 ActiveTools: nil, // Use parent (all tools)
1313 }, nil
1314 }
1315
1316 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1317
1318 result, err := agent.Generate(context.Background(), AgentCall{
1319 Prompt: "test prompt",
1320 PrepareStep: prepareStepFunc,
1321 })
1322
1323 require.NoError(t, err)
1324 require.NotNil(t, result)
1325 require.Equal(t, "Parent system", capturedSystemPrompt)
1326 require.NotNil(t, capturedToolChoice)
1327 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1328 require.Len(t, capturedToolNames, 1)
1329 require.Equal(t, "tool1", capturedToolNames[0])
1330 })
1331
1332 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1333 t.Parallel()
1334 var capturedToolNames []string
1335 model := &mockLanguageModel{
1336 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1337 // Capture tool names to verify all tools are included
1338 for _, tool := range call.Tools {
1339 capturedToolNames = append(capturedToolNames, tool.GetName())
1340 }
1341 return &Response{
1342 Content: ResponseContent{
1343 TextContent{Text: "Response"},
1344 },
1345 Usage: Usage{TotalTokens: 10},
1346 FinishReason: FinishReasonStop,
1347 }, nil
1348 },
1349 }
1350
1351 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1352 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1353
1354 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1355 return ctx, PrepareStepResult{
1356 Model: options.Model,
1357 Messages: options.Messages,
1358 ActiveTools: []string{}, // Empty slice means all tools
1359 }, nil
1360 }
1361
1362 agent := NewAgent(model, WithTools(tool1, tool2))
1363
1364 result, err := agent.Generate(context.Background(), AgentCall{
1365 Prompt: "test prompt",
1366 PrepareStep: prepareStepFunc,
1367 })
1368
1369 require.NoError(t, err)
1370 require.NotNil(t, result)
1371 require.Len(t, capturedToolNames, 2) // All tools should be included
1372 require.Contains(t, capturedToolNames, "tool1")
1373 require.Contains(t, capturedToolNames, "tool2")
1374 })
1375}
1376
1377func TestToolCallRepair(t *testing.T) {
1378 t.Parallel()
1379
1380 t.Run("Valid tool call passes validation", func(t *testing.T) {
1381 t.Parallel()
1382 model := &mockLanguageModel{
1383 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1384 return &Response{
1385 Content: ResponseContent{
1386 TextContent{Text: "Response"},
1387 ToolCallContent{
1388 ToolCallID: "call1",
1389 ToolName: "test_tool",
1390 Input: `{"value": "test"}`, // Valid JSON with required field
1391 },
1392 },
1393 Usage: Usage{TotalTokens: 10},
1394 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1395 }, nil
1396 },
1397 }
1398
1399 tool := &mockTool{
1400 name: "test_tool",
1401 description: "Test tool",
1402 parameters: map[string]any{
1403 "value": map[string]any{"type": "string"},
1404 },
1405 required: []string{"value"},
1406 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1407 return ToolResponse{Content: "success", IsError: false}, nil
1408 },
1409 }
1410
1411 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1412
1413 result, err := agent.Generate(context.Background(), AgentCall{
1414 Prompt: "test prompt",
1415 })
1416
1417 require.NoError(t, err)
1418 require.NotNil(t, result)
1419 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1420
1421 // Check that tool call was executed successfully
1422 toolCalls := result.Steps[0].Content.ToolCalls()
1423 require.Len(t, toolCalls, 1)
1424 require.False(t, toolCalls[0].Invalid) // Should be valid
1425 })
1426
1427 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1428 t.Parallel()
1429 model := &mockLanguageModel{
1430 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1431 return &Response{
1432 Content: ResponseContent{
1433 TextContent{Text: "Response"},
1434 ToolCallContent{
1435 ToolCallID: "call1",
1436 ToolName: "test_tool",
1437 Input: `{"wrong_field": "test"}`, // Missing required field
1438 },
1439 },
1440 Usage: Usage{TotalTokens: 10},
1441 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1442 }, nil
1443 },
1444 }
1445
1446 tool := &mockTool{
1447 name: "test_tool",
1448 description: "Test tool",
1449 parameters: map[string]any{
1450 "value": map[string]any{"type": "string"},
1451 },
1452 required: []string{"value"},
1453 }
1454
1455 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1456
1457 result, err := agent.Generate(context.Background(), AgentCall{
1458 Prompt: "test prompt",
1459 })
1460
1461 require.NoError(t, err)
1462 require.NotNil(t, result)
1463 require.Len(t, result.Steps, 1) // Only one step
1464
1465 // Check that tool call was marked as invalid
1466 toolCalls := result.Steps[0].Content.ToolCalls()
1467 require.Len(t, toolCalls, 1)
1468 require.True(t, toolCalls[0].Invalid) // Should be invalid
1469 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1470 })
1471
1472 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1473 t.Parallel()
1474 model := &mockLanguageModel{
1475 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1476 return &Response{
1477 Content: ResponseContent{
1478 TextContent{Text: "Response"},
1479 ToolCallContent{
1480 ToolCallID: "call1",
1481 ToolName: "test_tool",
1482 Input: `{"wrong_field": "test"}`, // Missing required field
1483 },
1484 },
1485 Usage: Usage{TotalTokens: 10},
1486 FinishReason: FinishReasonStop, // Changed to stop
1487 }, nil
1488 },
1489 }
1490
1491 tool := &mockTool{
1492 name: "test_tool",
1493 description: "Test tool",
1494 parameters: map[string]any{
1495 "value": map[string]any{"type": "string"},
1496 },
1497 required: []string{"value"},
1498 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1499 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1500 },
1501 }
1502
1503 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1504 // Simple repair: add the missing required field
1505 repairedToolCall := options.OriginalToolCall
1506 repairedToolCall.Input = `{"value": "repaired"}`
1507 return &repairedToolCall, nil
1508 }
1509
1510 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1511
1512 result, err := agent.Generate(context.Background(), AgentCall{
1513 Prompt: "test prompt",
1514 })
1515
1516 require.NoError(t, err)
1517 require.NotNil(t, result)
1518 require.Len(t, result.Steps, 1) // Only one step
1519
1520 // Check that tool call was repaired and is now valid
1521 toolCalls := result.Steps[0].Content.ToolCalls()
1522 require.Len(t, toolCalls, 1)
1523 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1524 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1525 })
1526
1527 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1528 t.Parallel()
1529 model := &mockLanguageModel{
1530 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1531 return &Response{
1532 Content: ResponseContent{
1533 TextContent{Text: "Response"},
1534 ToolCallContent{
1535 ToolCallID: "call1",
1536 ToolName: "test_tool",
1537 Input: `{"wrong_field": "test"}`, // Missing required field
1538 },
1539 },
1540 Usage: Usage{TotalTokens: 10},
1541 FinishReason: FinishReasonStop, // Changed to stop
1542 }, nil
1543 },
1544 }
1545
1546 tool := &mockTool{
1547 name: "test_tool",
1548 description: "Test tool",
1549 parameters: map[string]any{
1550 "value": map[string]any{"type": "string"},
1551 },
1552 required: []string{"value"},
1553 }
1554
1555 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1556 // Repair function fails
1557 return nil, errors.New("repair failed")
1558 }
1559
1560 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1561
1562 result, err := agent.Generate(context.Background(), AgentCall{
1563 Prompt: "test prompt",
1564 })
1565
1566 require.NoError(t, err)
1567 require.NotNil(t, result)
1568 require.Len(t, result.Steps, 1) // Only one step
1569
1570 // Check that tool call was marked as invalid since repair failed
1571 toolCalls := result.Steps[0].Content.ToolCalls()
1572 require.Len(t, toolCalls, 1)
1573 require.True(t, toolCalls[0].Invalid) // Should be invalid
1574 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1575 })
1576
1577 t.Run("Nonexistent tool call", func(t *testing.T) {
1578 t.Parallel()
1579 model := &mockLanguageModel{
1580 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1581 return &Response{
1582 Content: ResponseContent{
1583 TextContent{Text: "Response"},
1584 ToolCallContent{
1585 ToolCallID: "call1",
1586 ToolName: "nonexistent_tool",
1587 Input: `{"value": "test"}`,
1588 },
1589 },
1590 Usage: Usage{TotalTokens: 10},
1591 FinishReason: FinishReasonStop, // Changed to stop
1592 }, nil
1593 },
1594 }
1595
1596 tool := &mockTool{name: "test_tool", description: "Test tool"}
1597
1598 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1599
1600 result, err := agent.Generate(context.Background(), AgentCall{
1601 Prompt: "test prompt",
1602 })
1603
1604 require.NoError(t, err)
1605 require.NotNil(t, result)
1606 require.Len(t, result.Steps, 1) // Only one step
1607
1608 // Check that tool call was marked as invalid due to nonexistent tool
1609 toolCalls := result.Steps[0].Content.ToolCalls()
1610 require.Len(t, toolCalls, 1)
1611 require.True(t, toolCalls[0].Invalid) // Should be invalid
1612 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1613 })
1614
1615 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1616 t.Parallel()
1617 model := &mockLanguageModel{
1618 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1619 return &Response{
1620 Content: ResponseContent{
1621 TextContent{Text: "Response"},
1622 ToolCallContent{
1623 ToolCallID: "call1",
1624 ToolName: "test_tool",
1625 Input: `{invalid json}`, // Invalid JSON
1626 },
1627 },
1628 Usage: Usage{TotalTokens: 10},
1629 FinishReason: FinishReasonStop, // Changed to stop
1630 }, nil
1631 },
1632 }
1633
1634 tool := &mockTool{
1635 name: "test_tool",
1636 description: "Test tool",
1637 parameters: map[string]any{
1638 "value": map[string]any{"type": "string"},
1639 },
1640 required: []string{"value"},
1641 }
1642
1643 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1644
1645 result, err := agent.Generate(context.Background(), AgentCall{
1646 Prompt: "test prompt",
1647 })
1648
1649 require.NoError(t, err)
1650 require.NotNil(t, result)
1651 require.Len(t, result.Steps, 1) // Only one step
1652
1653 // Check that tool call was marked as invalid due to invalid JSON
1654 toolCalls := result.Steps[0].Content.ToolCalls()
1655 require.Len(t, toolCalls, 1)
1656 require.True(t, toolCalls[0].Invalid) // Should be invalid
1657 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1658 })
1659}
1660
1661// Test media and image tool responses
1662func TestAgent_MediaToolResponses(t *testing.T) {
1663 t.Parallel()
1664
1665 imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1666 audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1667
1668 t.Run("Image tool response", func(t *testing.T) {
1669 t.Parallel()
1670
1671 imageTool := &mockTool{
1672 name: "generate_image",
1673 description: "Generates an image",
1674 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1675 return NewImageResponse(imageData, "image/png"), nil
1676 },
1677 }
1678
1679 model := &mockLanguageModel{
1680 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1681 if len(call.Prompt) == 1 {
1682 // First call - request image tool
1683 return &Response{
1684 Content: []Content{
1685 ToolCallContent{
1686 ToolCallID: "img-1",
1687 ToolName: "generate_image",
1688 Input: `{}`,
1689 },
1690 },
1691 Usage: Usage{TotalTokens: 10},
1692 FinishReason: FinishReasonToolCalls,
1693 }, nil
1694 }
1695 // Second call - after tool execution
1696 return &Response{
1697 Content: []Content{TextContent{Text: "Image generated"}},
1698 Usage: Usage{TotalTokens: 20},
1699 FinishReason: FinishReasonStop,
1700 }, nil
1701 },
1702 }
1703
1704 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1705
1706 result, err := agent.Generate(context.Background(), AgentCall{
1707 Prompt: "Generate an image",
1708 })
1709
1710 require.NoError(t, err)
1711 require.NotNil(t, result)
1712 require.Len(t, result.Steps, 2) // Tool call step + final response
1713
1714 // Check tool results in first step
1715 toolResults := result.Steps[0].Content.ToolResults()
1716 require.Len(t, toolResults, 1)
1717
1718 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1719 require.True(t, ok, "Expected media result")
1720 require.Equal(t, string(imageData), mediaResult.Data)
1721 require.Equal(t, "image/png", mediaResult.MediaType)
1722 })
1723
1724 t.Run("Media tool response (audio)", func(t *testing.T) {
1725 t.Parallel()
1726
1727 audioTool := &mockTool{
1728 name: "generate_audio",
1729 description: "Generates audio",
1730 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1731 return NewMediaResponse(audioData, "audio/wav"), nil
1732 },
1733 }
1734
1735 model := &mockLanguageModel{
1736 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1737 if len(call.Prompt) == 1 {
1738 return &Response{
1739 Content: []Content{
1740 ToolCallContent{
1741 ToolCallID: "audio-1",
1742 ToolName: "generate_audio",
1743 Input: `{}`,
1744 },
1745 },
1746 Usage: Usage{TotalTokens: 10},
1747 FinishReason: FinishReasonToolCalls,
1748 }, nil
1749 }
1750 return &Response{
1751 Content: []Content{TextContent{Text: "Audio generated"}},
1752 Usage: Usage{TotalTokens: 20},
1753 FinishReason: FinishReasonStop,
1754 }, nil
1755 },
1756 }
1757
1758 agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1759
1760 result, err := agent.Generate(context.Background(), AgentCall{
1761 Prompt: "Generate audio",
1762 })
1763
1764 require.NoError(t, err)
1765 require.NotNil(t, result)
1766
1767 toolResults := result.Steps[0].Content.ToolResults()
1768 require.Len(t, toolResults, 1)
1769
1770 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1771 require.True(t, ok, "Expected media result")
1772 require.Equal(t, string(audioData), mediaResult.Data)
1773 require.Equal(t, "audio/wav", mediaResult.MediaType)
1774 })
1775
1776 t.Run("Media response with text", func(t *testing.T) {
1777 t.Parallel()
1778
1779 imageTool := &mockTool{
1780 name: "screenshot",
1781 description: "Takes a screenshot",
1782 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1783 resp := NewImageResponse(imageData, "image/png")
1784 resp.Content = "Screenshot captured successfully"
1785 return resp, nil
1786 },
1787 }
1788
1789 model := &mockLanguageModel{
1790 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1791 if len(call.Prompt) == 1 {
1792 return &Response{
1793 Content: []Content{
1794 ToolCallContent{
1795 ToolCallID: "screen-1",
1796 ToolName: "screenshot",
1797 Input: `{}`,
1798 },
1799 },
1800 Usage: Usage{TotalTokens: 10},
1801 FinishReason: FinishReasonToolCalls,
1802 }, nil
1803 }
1804 return &Response{
1805 Content: []Content{TextContent{Text: "Done"}},
1806 Usage: Usage{TotalTokens: 20},
1807 FinishReason: FinishReasonStop,
1808 }, nil
1809 },
1810 }
1811
1812 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1813
1814 result, err := agent.Generate(context.Background(), AgentCall{
1815 Prompt: "Take a screenshot",
1816 })
1817
1818 require.NoError(t, err)
1819 require.NotNil(t, result)
1820
1821 toolResults := result.Steps[0].Content.ToolResults()
1822 require.Len(t, toolResults, 1)
1823
1824 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1825 require.True(t, ok, "Expected media result")
1826 require.Equal(t, string(imageData), mediaResult.Data)
1827 require.Equal(t, "image/png", mediaResult.MediaType)
1828 require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1829 })
1830
1831 t.Run("Media response preserves metadata", func(t *testing.T) {
1832 t.Parallel()
1833
1834 type ImageMetadata struct {
1835 Width int `json:"width"`
1836 Height int `json:"height"`
1837 }
1838
1839 imageTool := &mockTool{
1840 name: "generate_image",
1841 description: "Generates an image",
1842 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1843 resp := NewImageResponse(imageData, "image/png")
1844 return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1845 },
1846 }
1847
1848 model := &mockLanguageModel{
1849 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1850 if len(call.Prompt) == 1 {
1851 return &Response{
1852 Content: []Content{
1853 ToolCallContent{
1854 ToolCallID: "img-1",
1855 ToolName: "generate_image",
1856 Input: `{}`,
1857 },
1858 },
1859 Usage: Usage{TotalTokens: 10},
1860 FinishReason: FinishReasonToolCalls,
1861 }, nil
1862 }
1863 return &Response{
1864 Content: []Content{TextContent{Text: "Done"}},
1865 Usage: Usage{TotalTokens: 20},
1866 FinishReason: FinishReasonStop,
1867 }, nil
1868 },
1869 }
1870
1871 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1872
1873 result, err := agent.Generate(context.Background(), AgentCall{
1874 Prompt: "Generate image",
1875 })
1876
1877 require.NoError(t, err)
1878 require.NotNil(t, result)
1879
1880 toolResults := result.Steps[0].Content.ToolResults()
1881 require.Len(t, toolResults, 1)
1882
1883 // Check metadata was preserved
1884 require.NotEmpty(t, toolResults[0].ClientMetadata)
1885
1886 var metadata ImageMetadata
1887 err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1888 require.NoError(t, err)
1889 require.Equal(t, 800, metadata.Width)
1890 require.Equal(t, 600, metadata.Height)
1891 })
1892}
1893
1894func TestToResponseMessages_ProviderExecutedRouting(t *testing.T) {
1895 t.Parallel()
1896
1897 // Build step content that mixes a provider-executed tool call/result
1898 // (e.g. web search) with a regular local tool call/result.
1899 content := []Content{
1900 // Provider-executed tool call.
1901 &ToolCallContent{
1902 ToolCallID: "srvtoolu_01",
1903 ToolName: "web_search",
1904 Input: `{"query":"test"}`,
1905 ProviderExecuted: true,
1906 },
1907 // Provider-executed tool result.
1908 &ToolResultContent{
1909 ToolCallID: "srvtoolu_01",
1910 ProviderExecuted: true,
1911 },
1912 // Regular (locally-executed) tool call.
1913 &ToolCallContent{
1914 ToolCallID: "toolu_02",
1915 ToolName: "calculator",
1916 Input: `{"expr":"1+1"}`,
1917 },
1918 // Regular tool result.
1919 &ToolResultContent{
1920 ToolCallID: "toolu_02",
1921 Result: ToolResultOutputContentText{Text: "2"},
1922 },
1923 // Some trailing text.
1924 &TextContent{Text: "Done."},
1925 }
1926
1927 msgs := toResponseMessages(content)
1928
1929 // Expect two messages: assistant + tool.
1930 require.Len(t, msgs, 2)
1931
1932 // Assistant message should contain:
1933 // 1. provider-executed ToolCallPart
1934 // 2. provider-executed ToolResultPart
1935 // 3. regular ToolCallPart
1936 // 4. TextPart
1937 assistant := msgs[0]
1938 require.Equal(t, MessageRoleAssistant, assistant.Role)
1939 require.Len(t, assistant.Content, 4)
1940
1941 // Verify provider-executed tool call is in assistant.
1942 tc1, ok := AsMessagePart[ToolCallPart](assistant.Content[0])
1943 require.True(t, ok)
1944 require.Equal(t, "srvtoolu_01", tc1.ToolCallID)
1945 require.True(t, tc1.ProviderExecuted)
1946
1947 // Verify provider-executed tool result is in assistant.
1948 tr1, ok := AsMessagePart[ToolResultPart](assistant.Content[1])
1949 require.True(t, ok)
1950 require.Equal(t, "srvtoolu_01", tr1.ToolCallID)
1951 require.True(t, tr1.ProviderExecuted)
1952
1953 // Verify regular tool call is in assistant.
1954 tc2, ok := AsMessagePart[ToolCallPart](assistant.Content[2])
1955 require.True(t, ok)
1956 require.Equal(t, "toolu_02", tc2.ToolCallID)
1957 require.False(t, tc2.ProviderExecuted)
1958
1959 // Verify text part is in assistant.
1960 text, ok := AsMessagePart[TextPart](assistant.Content[3])
1961 require.True(t, ok)
1962 require.Equal(t, "Done.", text.Text)
1963
1964 // Tool message should contain only the regular tool result.
1965 toolMsg := msgs[1]
1966 require.Equal(t, MessageRoleTool, toolMsg.Role)
1967 require.Len(t, toolMsg.Content, 1)
1968
1969 tr2, ok := AsMessagePart[ToolResultPart](toolMsg.Content[0])
1970 require.True(t, ok)
1971 require.Equal(t, "toolu_02", tr2.ToolCallID)
1972 require.False(t, tr2.ProviderExecuted)
1973}