1package fantasy
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 providerOptions ProviderOptions
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
21}
22
23func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
24 m.providerOptions = opts
25}
26
27func (m *mockTool) ProviderOptions() ProviderOptions {
28 return m.providerOptions
29}
30
31func (m *mockTool) Info() ToolInfo {
32 return ToolInfo{
33 Name: m.name,
34 Description: m.description,
35 Parameters: m.parameters,
36 Required: m.required,
37 }
38}
39
40func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
41 if m.executeFunc != nil {
42 return m.executeFunc(ctx, call)
43 }
44 return ToolResponse{Content: "mock result", IsError: false}, nil
45}
46
47// Mock language model for testing
48type mockLanguageModel struct {
49 generateFunc func(ctx context.Context, call Call) (*Response, error)
50 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
51}
52
53func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
54 if m.generateFunc != nil {
55 return m.generateFunc(ctx, call)
56 }
57 return &Response{
58 Content: []Content{
59 TextContent{Text: "Hello, world!"},
60 },
61 Usage: Usage{
62 InputTokens: 3,
63 OutputTokens: 10,
64 TotalTokens: 13,
65 },
66 FinishReason: FinishReasonStop,
67 }, nil
68}
69
70func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
71 if m.streamFunc != nil {
72 return m.streamFunc(ctx, call)
73 }
74 return nil, fmt.Errorf("mock stream not implemented")
75}
76
77func (m *mockLanguageModel) Provider() string {
78 return "mock-provider"
79}
80
81func (m *mockLanguageModel) Model() string {
82 return "mock-model"
83}
84
85func (m *mockLanguageModel) GenerateObject(ctx context.Context, call ObjectCall) (*ObjectResponse, error) {
86 return nil, fmt.Errorf("mock GenerateObject not implemented")
87}
88
89func (m *mockLanguageModel) StreamObject(ctx context.Context, call ObjectCall) (ObjectStreamResponse, error) {
90 return nil, fmt.Errorf("mock StreamObject not implemented")
91}
92
93// Test result.content - comprehensive content types (matches TS test)
94func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
95 t.Parallel()
96
97 // Create a type-safe tool using the new API
98 type TestInput struct {
99 Value string `json:"value" description:"Test value"`
100 }
101
102 tool1 := NewAgentTool(
103 "tool1",
104 "Test tool",
105 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
106 require.Equal(t, "value", input.Value)
107 return ToolResponse{Content: "result1", IsError: false}, nil
108 },
109 )
110
111 model := &mockLanguageModel{
112 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
113 return &Response{
114 Content: []Content{
115 TextContent{Text: "Hello, world!"},
116 SourceContent{
117 ID: "123",
118 URL: "https://example.com",
119 Title: "Example",
120 SourceType: SourceTypeURL,
121 },
122 FileContent{
123 Data: []byte{1, 2, 3},
124 MediaType: "image/png",
125 },
126 ReasoningContent{
127 Text: "I will open the conversation with witty banter.",
128 },
129 ToolCallContent{
130 ToolCallID: "call-1",
131 ToolName: "tool1",
132 Input: `{"value":"value"}`,
133 },
134 TextContent{Text: "More text"},
135 },
136 Usage: Usage{
137 InputTokens: 3,
138 OutputTokens: 10,
139 TotalTokens: 13,
140 },
141 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
142 }, nil
143 },
144 }
145
146 agent := NewAgent(model, WithTools(tool1))
147 result, err := agent.Generate(context.Background(), AgentCall{
148 Prompt: "prompt",
149 })
150
151 require.NoError(t, err)
152 require.NotNil(t, result)
153 require.Len(t, result.Steps, 1) // Single step like TypeScript
154
155 // Check final response content includes tool result
156 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
157
158 // Verify each content type in order
159 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
160 require.True(t, ok)
161 require.Equal(t, "Hello, world!", textContent.Text)
162
163 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
164 require.True(t, ok)
165 require.Equal(t, "123", sourceContent.ID)
166
167 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
168 require.True(t, ok)
169 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
170
171 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
172 require.True(t, ok)
173 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
174
175 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
176 require.True(t, ok)
177 require.Equal(t, "call-1", toolCallContent.ToolCallID)
178
179 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
180 require.True(t, ok)
181 require.Equal(t, "More text", moreTextContent.Text)
182
183 // Tool result should be appended
184 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
185 require.True(t, ok)
186 require.Equal(t, "call-1", toolResultContent.ToolCallID)
187 require.Equal(t, "tool1", toolResultContent.ToolName)
188}
189
190// Test result.text extraction
191func TestAgent_Generate_ResultText(t *testing.T) {
192 t.Parallel()
193
194 model := &mockLanguageModel{
195 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
196 return &Response{
197 Content: []Content{
198 TextContent{Text: "Hello, world!"},
199 },
200 Usage: Usage{
201 InputTokens: 3,
202 OutputTokens: 10,
203 TotalTokens: 13,
204 },
205 FinishReason: FinishReasonStop,
206 }, nil
207 },
208 }
209
210 agent := NewAgent(model)
211 result, err := agent.Generate(context.Background(), AgentCall{
212 Prompt: "prompt",
213 })
214
215 require.NoError(t, err)
216 require.NotNil(t, result)
217
218 // Test text extraction from content
219 text := result.Response.Content.Text()
220 require.Equal(t, "Hello, world!", text)
221}
222
223// Test result.toolCalls extraction (matches TS test exactly)
224func TestAgent_Generate_ResultToolCalls(t *testing.T) {
225 t.Parallel()
226
227 // Create type-safe tools using the new API
228 type Tool1Input struct {
229 Value string `json:"value" description:"Test value"`
230 }
231
232 type Tool2Input struct {
233 SomethingElse string `json:"somethingElse" description:"Another test value"`
234 }
235
236 tool1 := NewAgentTool(
237 "tool1",
238 "Test tool 1",
239 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
240 return ToolResponse{Content: "result1", IsError: false}, nil
241 },
242 )
243
244 tool2 := NewAgentTool(
245 "tool2",
246 "Test tool 2",
247 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
248 return ToolResponse{Content: "result2", IsError: false}, nil
249 },
250 )
251
252 model := &mockLanguageModel{
253 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
254 // Verify tools are passed correctly
255 require.Len(t, call.Tools, 2)
256 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
257
258 // Verify prompt structure
259 require.Len(t, call.Prompt, 1)
260 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
261
262 return &Response{
263 Content: []Content{
264 ToolCallContent{
265 ToolCallID: "call-1",
266 ToolName: "tool1",
267 Input: `{"value":"value"}`,
268 },
269 },
270 Usage: Usage{
271 InputTokens: 3,
272 OutputTokens: 10,
273 TotalTokens: 13,
274 },
275 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
276 }, nil
277 },
278 }
279
280 agent := NewAgent(model, WithTools(tool1, tool2))
281 result, err := agent.Generate(context.Background(), AgentCall{
282 Prompt: "test-input",
283 })
284
285 require.NoError(t, err)
286 require.NotNil(t, result)
287 require.Len(t, result.Steps, 1) // Single step
288
289 // Extract tool calls from final response (should be empty since tools don't execute)
290 var toolCalls []ToolCallContent
291 for _, content := range result.Response.Content {
292 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
293 toolCalls = append(toolCalls, toolCall)
294 }
295 }
296
297 require.Len(t, toolCalls, 1)
298 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
299 require.Equal(t, "tool1", toolCalls[0].ToolName)
300
301 // Parse and verify input
302 var input map[string]any
303 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
304 require.NoError(t, err)
305 require.Equal(t, "value", input["value"])
306}
307
308// Test result.toolResults extraction (matches TS test exactly)
309func TestAgent_Generate_ResultToolResults(t *testing.T) {
310 t.Parallel()
311
312 // Create type-safe tool using the new API
313 type TestInput struct {
314 Value string `json:"value" description:"Test value"`
315 }
316
317 tool1 := NewAgentTool(
318 "tool1",
319 "Test tool",
320 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
321 require.Equal(t, "value", input.Value)
322 return ToolResponse{Content: "result1", IsError: false}, nil
323 },
324 )
325
326 model := &mockLanguageModel{
327 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
328 // Verify tools and tool choice
329 require.Len(t, call.Tools, 1)
330 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
331
332 // Verify prompt
333 require.Len(t, call.Prompt, 1)
334 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
335
336 return &Response{
337 Content: []Content{
338 ToolCallContent{
339 ToolCallID: "call-1",
340 ToolName: "tool1",
341 Input: `{"value":"value"}`,
342 },
343 },
344 Usage: Usage{
345 InputTokens: 3,
346 OutputTokens: 10,
347 TotalTokens: 13,
348 },
349 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
350 }, nil
351 },
352 }
353
354 agent := NewAgent(model, WithTools(tool1))
355 result, err := agent.Generate(context.Background(), AgentCall{
356 Prompt: "test-input",
357 })
358
359 require.NoError(t, err)
360 require.NotNil(t, result)
361 require.Len(t, result.Steps, 1) // Single step
362
363 // Extract tool results from final response
364 var toolResults []ToolResultContent
365 for _, content := range result.Response.Content {
366 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
367 toolResults = append(toolResults, toolResult)
368 }
369 }
370
371 require.Len(t, toolResults, 1)
372 require.Equal(t, "call-1", toolResults[0].ToolCallID)
373 require.Equal(t, "tool1", toolResults[0].ToolName)
374
375 // Verify result content
376 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
377 require.True(t, ok)
378 require.Equal(t, "result1", textResult.Text)
379}
380
381// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
382func TestAgent_Generate_MultipleSteps(t *testing.T) {
383 t.Parallel()
384
385 // Create type-safe tool using the new API
386 type TestInput struct {
387 Value string `json:"value" description:"Test value"`
388 }
389
390 tool1 := NewAgentTool(
391 "tool1",
392 "Test tool",
393 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
394 require.Equal(t, "value", input.Value)
395 return ToolResponse{Content: "result1", IsError: false}, nil
396 },
397 )
398
399 callCount := 0
400 model := &mockLanguageModel{
401 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
402 callCount++
403 switch callCount {
404 case 1:
405 // First call - return tool call with FinishReasonToolCalls
406 return &Response{
407 Content: []Content{
408 ToolCallContent{
409 ToolCallID: "call-1",
410 ToolName: "tool1",
411 Input: `{"value":"value"}`,
412 },
413 },
414 Usage: Usage{
415 InputTokens: 10,
416 OutputTokens: 5,
417 TotalTokens: 15,
418 },
419 FinishReason: FinishReasonToolCalls, // This triggers multi-step
420 }, nil
421 case 2:
422 // Second call - return final text
423 return &Response{
424 Content: []Content{
425 TextContent{Text: "Hello, world!"},
426 },
427 Usage: Usage{
428 InputTokens: 3,
429 OutputTokens: 10,
430 TotalTokens: 13,
431 },
432 FinishReason: FinishReasonStop,
433 }, nil
434 default:
435 t.Fatalf("Unexpected call count: %d", callCount)
436 return nil, nil
437 }
438 },
439 }
440
441 agent := NewAgent(model, WithTools(tool1))
442 result, err := agent.Generate(context.Background(), AgentCall{
443 Prompt: "test-input",
444 })
445
446 require.NoError(t, err)
447 require.NotNil(t, result)
448 require.Len(t, result.Steps, 2)
449
450 // Check total usage sums both steps
451 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
452 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
453 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
454
455 // Final response should be from last step
456 require.Len(t, result.Response.Content, 1)
457 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
458 require.True(t, ok)
459 require.Equal(t, "Hello, world!", textContent.Text)
460
461 // result.toolCalls should be empty (from last step)
462 var toolCalls []ToolCallContent
463 for _, content := range result.Response.Content {
464 if _, ok := AsContentType[ToolCallContent](content); ok {
465 toolCalls = append(toolCalls, content.(ToolCallContent))
466 }
467 }
468 require.Len(t, toolCalls, 0)
469
470 // result.toolResults should be empty (from last step)
471 var toolResults []ToolResultContent
472 for _, content := range result.Response.Content {
473 if _, ok := AsContentType[ToolResultContent](content); ok {
474 toolResults = append(toolResults, content.(ToolResultContent))
475 }
476 }
477 require.Len(t, toolResults, 0)
478}
479
480// Test basic text generation
481func TestAgent_Generate_BasicText(t *testing.T) {
482 t.Parallel()
483
484 model := &mockLanguageModel{
485 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
486 return &Response{
487 Content: []Content{
488 TextContent{Text: "Hello, world!"},
489 },
490 Usage: Usage{
491 InputTokens: 3,
492 OutputTokens: 10,
493 TotalTokens: 13,
494 },
495 FinishReason: FinishReasonStop,
496 }, nil
497 },
498 }
499
500 agent := NewAgent(model)
501 result, err := agent.Generate(context.Background(), AgentCall{
502 Prompt: "test prompt",
503 })
504
505 require.NoError(t, err)
506 require.NotNil(t, result)
507 require.Len(t, result.Steps, 1)
508
509 // Check final response
510 require.Len(t, result.Response.Content, 1)
511 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
512 require.True(t, ok)
513 require.Equal(t, "Hello, world!", textContent.Text)
514
515 // Check usage
516 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
517 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
518 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
519
520 // Check total usage
521 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
522 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
523 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
524}
525
526// Test empty prompt error
527func TestAgent_Generate_EmptyPrompt(t *testing.T) {
528 t.Parallel()
529
530 model := &mockLanguageModel{}
531 agent := NewAgent(model)
532
533 result, err := agent.Generate(context.Background(), AgentCall{
534 Prompt: "", // Empty prompt should cause error
535 })
536
537 require.Error(t, err)
538 require.Nil(t, result)
539 require.Contains(t, err.Error(), "invalid argument: prompt can't be empty")
540}
541
542// Test with system prompt
543func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
544 t.Parallel()
545
546 model := &mockLanguageModel{
547 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
548 // Verify system message is included
549 require.Len(t, call.Prompt, 2) // system + user
550 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
551 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
552
553 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
554 require.True(t, ok)
555 require.Equal(t, "You are a helpful assistant", systemPart.Text)
556
557 return &Response{
558 Content: []Content{
559 TextContent{Text: "Hello, world!"},
560 },
561 Usage: Usage{
562 InputTokens: 3,
563 OutputTokens: 10,
564 TotalTokens: 13,
565 },
566 FinishReason: FinishReasonStop,
567 }, nil
568 },
569 }
570
571 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
572 result, err := agent.Generate(context.Background(), AgentCall{
573 Prompt: "test prompt",
574 })
575
576 require.NoError(t, err)
577 require.NotNil(t, result)
578}
579
580// Test options.activeTools filtering
581func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
582 t.Parallel()
583
584 tool1 := &mockTool{
585 name: "tool1",
586 description: "Test tool 1",
587 parameters: map[string]any{
588 "value": map[string]any{"type": "string"},
589 },
590 required: []string{"value"},
591 }
592
593 tool2 := &mockTool{
594 name: "tool2",
595 description: "Test tool 2",
596 parameters: map[string]any{
597 "value": map[string]any{"type": "string"},
598 },
599 required: []string{"value"},
600 }
601
602 model := &mockLanguageModel{
603 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
604 // Verify only tool1 is available
605 require.Len(t, call.Tools, 1)
606 functionTool, ok := call.Tools[0].(FunctionTool)
607 require.True(t, ok)
608 require.Equal(t, "tool1", functionTool.Name)
609
610 return &Response{
611 Content: []Content{
612 TextContent{Text: "Hello, world!"},
613 },
614 Usage: Usage{
615 InputTokens: 3,
616 OutputTokens: 10,
617 TotalTokens: 13,
618 },
619 FinishReason: FinishReasonStop,
620 }, nil
621 },
622 }
623
624 agent := NewAgent(model, WithTools(tool1, tool2))
625 result, err := agent.Generate(context.Background(), AgentCall{
626 Prompt: "test-input",
627 ActiveTools: []string{"tool1"}, // Only tool1 should be active
628 })
629
630 require.NoError(t, err)
631 require.NotNil(t, result)
632}
633
634func TestResponseContent_Getters(t *testing.T) {
635 t.Parallel()
636
637 // Create test content with all types
638 content := ResponseContent{
639 TextContent{Text: "Hello world"},
640 ReasoningContent{Text: "Let me think..."},
641 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
642 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
643 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
644 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
645 }
646
647 // Test Text()
648 require.Equal(t, "Hello world", content.Text())
649
650 // Test Reasoning()
651 reasoning := content.Reasoning()
652 require.Len(t, reasoning, 1)
653 require.Equal(t, "Let me think...", reasoning[0].Text)
654
655 // Test ReasoningText()
656 require.Equal(t, "Let me think...", content.ReasoningText())
657
658 // Test Files()
659 files := content.Files()
660 require.Len(t, files, 1)
661 require.Equal(t, "text/plain", files[0].MediaType)
662 require.Equal(t, []byte("file data"), files[0].Data)
663
664 // Test Sources()
665 sources := content.Sources()
666 require.Len(t, sources, 1)
667 require.Equal(t, SourceTypeURL, sources[0].SourceType)
668 require.Equal(t, "https://example.com", sources[0].URL)
669 require.Equal(t, "Example", sources[0].Title)
670
671 // Test ToolCalls()
672 toolCalls := content.ToolCalls()
673 require.Len(t, toolCalls, 1)
674 require.Equal(t, "call1", toolCalls[0].ToolCallID)
675 require.Equal(t, "test_tool", toolCalls[0].ToolName)
676 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
677
678 // Test ToolResults()
679 toolResults := content.ToolResults()
680 require.Len(t, toolResults, 1)
681 require.Equal(t, "call1", toolResults[0].ToolCallID)
682 require.Equal(t, "test_tool", toolResults[0].ToolName)
683 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
684 require.True(t, ok)
685 require.Equal(t, "result", result.Text)
686}
687
688func TestResponseContent_Getters_Empty(t *testing.T) {
689 t.Parallel()
690
691 // Test with empty content
692 content := ResponseContent{}
693
694 require.Equal(t, "", content.Text())
695 require.Equal(t, "", content.ReasoningText())
696 require.Empty(t, content.Reasoning())
697 require.Empty(t, content.Files())
698 require.Empty(t, content.Sources())
699 require.Empty(t, content.ToolCalls())
700 require.Empty(t, content.ToolResults())
701}
702
703func TestResponseContent_Getters_MultipleItems(t *testing.T) {
704 t.Parallel()
705
706 // Test with multiple items of same type
707 content := ResponseContent{
708 ReasoningContent{Text: "First thought"},
709 ReasoningContent{Text: "Second thought"},
710 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
711 FileContent{Data: []byte("file2"), MediaType: "image/png"},
712 }
713
714 // Test multiple reasoning
715 reasoning := content.Reasoning()
716 require.Len(t, reasoning, 2)
717 require.Equal(t, "First thought", reasoning[0].Text)
718 require.Equal(t, "Second thought", reasoning[1].Text)
719
720 // Test concatenated reasoning text
721 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
722
723 // Test multiple files
724 files := content.Files()
725 require.Len(t, files, 2)
726 require.Equal(t, "text/plain", files[0].MediaType)
727 require.Equal(t, "image/png", files[1].MediaType)
728}
729
730func TestStopConditions(t *testing.T) {
731 t.Parallel()
732
733 // Create test steps
734 step1 := StepResult{
735 Response: Response{
736 Content: ResponseContent{
737 TextContent{Text: "Hello"},
738 },
739 FinishReason: FinishReasonToolCalls,
740 Usage: Usage{TotalTokens: 10},
741 },
742 }
743
744 step2 := StepResult{
745 Response: Response{
746 Content: ResponseContent{
747 TextContent{Text: "World"},
748 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
749 },
750 FinishReason: FinishReasonStop,
751 Usage: Usage{TotalTokens: 15},
752 },
753 }
754
755 step3 := StepResult{
756 Response: Response{
757 Content: ResponseContent{
758 ReasoningContent{Text: "Let me think..."},
759 FileContent{Data: []byte("data"), MediaType: "text/plain"},
760 },
761 FinishReason: FinishReasonLength,
762 Usage: Usage{TotalTokens: 20},
763 },
764 }
765
766 t.Run("StepCountIs", func(t *testing.T) {
767 t.Parallel()
768 condition := StepCountIs(2)
769
770 // Should not stop with 1 step
771 require.False(t, condition([]StepResult{step1}))
772
773 // Should stop with 2 steps
774 require.True(t, condition([]StepResult{step1, step2}))
775
776 // Should stop with more than 2 steps
777 require.True(t, condition([]StepResult{step1, step2, step3}))
778
779 // Should not stop with empty steps
780 require.False(t, condition([]StepResult{}))
781 })
782
783 t.Run("HasToolCall", func(t *testing.T) {
784 t.Parallel()
785 condition := HasToolCall("search")
786
787 // Should not stop when tool not called
788 require.False(t, condition([]StepResult{step1}))
789
790 // Should stop when tool is called in last step
791 require.True(t, condition([]StepResult{step1, step2}))
792
793 // Should not stop when tool called in earlier step but not last
794 require.False(t, condition([]StepResult{step1, step2, step3}))
795
796 // Should not stop with empty steps
797 require.False(t, condition([]StepResult{}))
798
799 // Should not stop when different tool is called
800 differentToolCondition := HasToolCall("different_tool")
801 require.False(t, differentToolCondition([]StepResult{step1, step2}))
802 })
803
804 t.Run("HasContent", func(t *testing.T) {
805 t.Parallel()
806 reasoningCondition := HasContent(ContentTypeReasoning)
807 fileCondition := HasContent(ContentTypeFile)
808
809 // Should not stop when content type not present
810 require.False(t, reasoningCondition([]StepResult{step1, step2}))
811
812 // Should stop when content type is present in last step
813 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
814 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
815
816 // Should not stop with empty steps
817 require.False(t, reasoningCondition([]StepResult{}))
818 })
819
820 t.Run("FinishReasonIs", func(t *testing.T) {
821 t.Parallel()
822 stopCondition := FinishReasonIs(FinishReasonStop)
823 lengthCondition := FinishReasonIs(FinishReasonLength)
824
825 // Should not stop when finish reason doesn't match
826 require.False(t, stopCondition([]StepResult{step1}))
827
828 // Should stop when finish reason matches in last step
829 require.True(t, stopCondition([]StepResult{step1, step2}))
830 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
831
832 // Should not stop with empty steps
833 require.False(t, stopCondition([]StepResult{}))
834 })
835
836 t.Run("MaxTokensUsed", func(t *testing.T) {
837 condition := MaxTokensUsed(30)
838
839 // Should not stop when under limit
840 require.False(t, condition([]StepResult{step1})) // 10 tokens
841 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
842
843 // Should stop when at or over limit
844 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
845
846 // Should not stop with empty steps
847 require.False(t, condition([]StepResult{}))
848 })
849}
850
851func TestStopConditions_Integration(t *testing.T) {
852 t.Parallel()
853
854 t.Run("StepCountIs integration", func(t *testing.T) {
855 t.Parallel()
856 model := &mockLanguageModel{
857 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
858 return &Response{
859 Content: ResponseContent{
860 TextContent{Text: "Mock response"},
861 },
862 Usage: Usage{
863 InputTokens: 3,
864 OutputTokens: 10,
865 TotalTokens: 13,
866 },
867 FinishReason: FinishReasonStop,
868 }, nil
869 },
870 }
871
872 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
873
874 result, err := agent.Generate(context.Background(), AgentCall{
875 Prompt: "test prompt",
876 })
877
878 require.NoError(t, err)
879 require.NotNil(t, result)
880 require.Len(t, result.Steps, 1) // Should stop after 1 step
881 })
882
883 t.Run("Multiple stop conditions", func(t *testing.T) {
884 t.Parallel()
885 model := &mockLanguageModel{
886 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
887 return &Response{
888 Content: ResponseContent{
889 TextContent{Text: "Mock response"},
890 },
891 Usage: Usage{
892 InputTokens: 3,
893 OutputTokens: 10,
894 TotalTokens: 13,
895 },
896 FinishReason: FinishReasonStop,
897 }, nil
898 },
899 }
900
901 agent := NewAgent(model, WithStopConditions(
902 StepCountIs(5), // Stop after 5 steps
903 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
904 ))
905
906 result, err := agent.Generate(context.Background(), AgentCall{
907 Prompt: "test prompt",
908 })
909
910 require.NoError(t, err)
911 require.NotNil(t, result)
912 // Should stop on first condition met (finish reason stop)
913 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
914 })
915}
916
917func TestPrepareStep(t *testing.T) {
918 t.Parallel()
919
920 t.Run("System prompt modification", func(t *testing.T) {
921 t.Parallel()
922 var capturedSystemPrompt string
923 model := &mockLanguageModel{
924 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
925 // Capture the system message to verify it was modified
926 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
927 if len(call.Prompt[0].Content) > 0 {
928 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
929 capturedSystemPrompt = textPart.Text
930 }
931 }
932 }
933 return &Response{
934 Content: ResponseContent{
935 TextContent{Text: "Response"},
936 },
937 Usage: Usage{TotalTokens: 10},
938 FinishReason: FinishReasonStop,
939 }, nil
940 },
941 }
942
943 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
944 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
945 return ctx, PrepareStepResult{
946 Model: options.Model,
947 Messages: options.Messages,
948 System: &newSystem,
949 }, nil
950 }
951
952 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
953
954 result, err := agent.Generate(context.Background(), AgentCall{
955 Prompt: "test prompt",
956 PrepareStep: prepareStepFunc,
957 })
958
959 require.NoError(t, err)
960 require.NotNil(t, result)
961 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
962 })
963
964 t.Run("Tool choice modification", func(t *testing.T) {
965 t.Parallel()
966 var capturedToolChoice *ToolChoice
967 model := &mockLanguageModel{
968 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
969 capturedToolChoice = call.ToolChoice
970 return &Response{
971 Content: ResponseContent{
972 TextContent{Text: "Response"},
973 },
974 Usage: Usage{TotalTokens: 10},
975 FinishReason: FinishReasonStop,
976 }, nil
977 },
978 }
979
980 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
981 toolChoice := ToolChoiceNone
982 return ctx, PrepareStepResult{
983 Model: options.Model,
984 Messages: options.Messages,
985 ToolChoice: &toolChoice,
986 }, nil
987 }
988
989 agent := NewAgent(model)
990
991 result, err := agent.Generate(context.Background(), AgentCall{
992 Prompt: "test prompt",
993 PrepareStep: prepareStepFunc,
994 })
995
996 require.NoError(t, err)
997 require.NotNil(t, result)
998 require.NotNil(t, capturedToolChoice)
999 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1000 })
1001
1002 t.Run("Active tools modification", func(t *testing.T) {
1003 t.Parallel()
1004 var capturedToolNames []string
1005 model := &mockLanguageModel{
1006 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1007 // Capture tool names to verify active tools were modified
1008 for _, tool := range call.Tools {
1009 capturedToolNames = append(capturedToolNames, tool.GetName())
1010 }
1011 return &Response{
1012 Content: ResponseContent{
1013 TextContent{Text: "Response"},
1014 },
1015 Usage: Usage{TotalTokens: 10},
1016 FinishReason: FinishReasonStop,
1017 }, nil
1018 },
1019 }
1020
1021 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1022 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1023 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1024
1025 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1026 activeTools := []string{"tool2"} // Only tool2 should be active
1027 return ctx, PrepareStepResult{
1028 Model: options.Model,
1029 Messages: options.Messages,
1030 ActiveTools: activeTools,
1031 }, nil
1032 }
1033
1034 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1035
1036 result, err := agent.Generate(context.Background(), AgentCall{
1037 Prompt: "test prompt",
1038 PrepareStep: prepareStepFunc,
1039 })
1040
1041 require.NoError(t, err)
1042 require.NotNil(t, result)
1043 require.Len(t, capturedToolNames, 1)
1044 require.Equal(t, "tool2", capturedToolNames[0])
1045 })
1046
1047 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1048 t.Parallel()
1049 var capturedToolCount int
1050 model := &mockLanguageModel{
1051 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1052 capturedToolCount = len(call.Tools)
1053 return &Response{
1054 Content: ResponseContent{
1055 TextContent{Text: "Response"},
1056 },
1057 Usage: Usage{TotalTokens: 10},
1058 FinishReason: FinishReasonStop,
1059 }, nil
1060 },
1061 }
1062
1063 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1064
1065 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1066 return ctx, PrepareStepResult{
1067 Model: options.Model,
1068 Messages: options.Messages,
1069 DisableAllTools: true, // Disable all tools for this step
1070 }, nil
1071 }
1072
1073 agent := NewAgent(model, WithTools(tool1))
1074
1075 result, err := agent.Generate(context.Background(), AgentCall{
1076 Prompt: "test prompt",
1077 PrepareStep: prepareStepFunc,
1078 })
1079
1080 require.NoError(t, err)
1081 require.NotNil(t, result)
1082 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1083 })
1084
1085 t.Run("All fields modified together", func(t *testing.T) {
1086 t.Parallel()
1087 var capturedSystemPrompt string
1088 var capturedToolChoice *ToolChoice
1089 var capturedToolNames []string
1090
1091 model := &mockLanguageModel{
1092 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1093 // Capture system prompt
1094 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1095 if len(call.Prompt[0].Content) > 0 {
1096 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1097 capturedSystemPrompt = textPart.Text
1098 }
1099 }
1100 }
1101 // Capture tool choice
1102 capturedToolChoice = call.ToolChoice
1103 // Capture tool names
1104 for _, tool := range call.Tools {
1105 capturedToolNames = append(capturedToolNames, tool.GetName())
1106 }
1107 return &Response{
1108 Content: ResponseContent{
1109 TextContent{Text: "Response"},
1110 },
1111 Usage: Usage{TotalTokens: 10},
1112 FinishReason: FinishReasonStop,
1113 }, nil
1114 },
1115 }
1116
1117 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1118 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1119
1120 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1121 newSystem := "Step-specific system"
1122 toolChoice := SpecificToolChoice("tool1")
1123 activeTools := []string{"tool1"}
1124 return ctx, PrepareStepResult{
1125 Model: options.Model,
1126 Messages: options.Messages,
1127 System: &newSystem,
1128 ToolChoice: &toolChoice,
1129 ActiveTools: activeTools,
1130 }, nil
1131 }
1132
1133 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1134
1135 result, err := agent.Generate(context.Background(), AgentCall{
1136 Prompt: "test prompt",
1137 PrepareStep: prepareStepFunc,
1138 })
1139
1140 require.NoError(t, err)
1141 require.NotNil(t, result)
1142 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1143 require.NotNil(t, capturedToolChoice)
1144 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1145 require.Len(t, capturedToolNames, 1)
1146 require.Equal(t, "tool1", capturedToolNames[0])
1147 })
1148
1149 t.Run("Nil fields use parent values", func(t *testing.T) {
1150 t.Parallel()
1151 var capturedSystemPrompt string
1152 var capturedToolChoice *ToolChoice
1153 var capturedToolNames []string
1154
1155 model := &mockLanguageModel{
1156 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1157 // Capture system prompt
1158 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1159 if len(call.Prompt[0].Content) > 0 {
1160 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1161 capturedSystemPrompt = textPart.Text
1162 }
1163 }
1164 }
1165 // Capture tool choice
1166 capturedToolChoice = call.ToolChoice
1167 // Capture tool names
1168 for _, tool := range call.Tools {
1169 capturedToolNames = append(capturedToolNames, tool.GetName())
1170 }
1171 return &Response{
1172 Content: ResponseContent{
1173 TextContent{Text: "Response"},
1174 },
1175 Usage: Usage{TotalTokens: 10},
1176 FinishReason: FinishReasonStop,
1177 }, nil
1178 },
1179 }
1180
1181 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1182
1183 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1184 // All optional fields are nil, should use parent values
1185 return ctx, PrepareStepResult{
1186 Model: options.Model,
1187 Messages: options.Messages,
1188 System: nil, // Use parent
1189 ToolChoice: nil, // Use parent (auto)
1190 ActiveTools: nil, // Use parent (all tools)
1191 }, nil
1192 }
1193
1194 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1195
1196 result, err := agent.Generate(context.Background(), AgentCall{
1197 Prompt: "test prompt",
1198 PrepareStep: prepareStepFunc,
1199 })
1200
1201 require.NoError(t, err)
1202 require.NotNil(t, result)
1203 require.Equal(t, "Parent system", capturedSystemPrompt)
1204 require.NotNil(t, capturedToolChoice)
1205 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1206 require.Len(t, capturedToolNames, 1)
1207 require.Equal(t, "tool1", capturedToolNames[0])
1208 })
1209
1210 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1211 t.Parallel()
1212 var capturedToolNames []string
1213 model := &mockLanguageModel{
1214 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1215 // Capture tool names to verify all tools are included
1216 for _, tool := range call.Tools {
1217 capturedToolNames = append(capturedToolNames, tool.GetName())
1218 }
1219 return &Response{
1220 Content: ResponseContent{
1221 TextContent{Text: "Response"},
1222 },
1223 Usage: Usage{TotalTokens: 10},
1224 FinishReason: FinishReasonStop,
1225 }, nil
1226 },
1227 }
1228
1229 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1230 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1231
1232 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1233 return ctx, PrepareStepResult{
1234 Model: options.Model,
1235 Messages: options.Messages,
1236 ActiveTools: []string{}, // Empty slice means all tools
1237 }, nil
1238 }
1239
1240 agent := NewAgent(model, WithTools(tool1, tool2))
1241
1242 result, err := agent.Generate(context.Background(), AgentCall{
1243 Prompt: "test prompt",
1244 PrepareStep: prepareStepFunc,
1245 })
1246
1247 require.NoError(t, err)
1248 require.NotNil(t, result)
1249 require.Len(t, capturedToolNames, 2) // All tools should be included
1250 require.Contains(t, capturedToolNames, "tool1")
1251 require.Contains(t, capturedToolNames, "tool2")
1252 })
1253}
1254
1255func TestToolCallRepair(t *testing.T) {
1256 t.Parallel()
1257
1258 t.Run("Valid tool call passes validation", func(t *testing.T) {
1259 t.Parallel()
1260 model := &mockLanguageModel{
1261 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1262 return &Response{
1263 Content: ResponseContent{
1264 TextContent{Text: "Response"},
1265 ToolCallContent{
1266 ToolCallID: "call1",
1267 ToolName: "test_tool",
1268 Input: `{"value": "test"}`, // Valid JSON with required field
1269 },
1270 },
1271 Usage: Usage{TotalTokens: 10},
1272 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1273 }, nil
1274 },
1275 }
1276
1277 tool := &mockTool{
1278 name: "test_tool",
1279 description: "Test tool",
1280 parameters: map[string]any{
1281 "value": map[string]any{"type": "string"},
1282 },
1283 required: []string{"value"},
1284 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1285 return ToolResponse{Content: "success", IsError: false}, nil
1286 },
1287 }
1288
1289 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1290
1291 result, err := agent.Generate(context.Background(), AgentCall{
1292 Prompt: "test prompt",
1293 })
1294
1295 require.NoError(t, err)
1296 require.NotNil(t, result)
1297 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1298
1299 // Check that tool call was executed successfully
1300 toolCalls := result.Steps[0].Content.ToolCalls()
1301 require.Len(t, toolCalls, 1)
1302 require.False(t, toolCalls[0].Invalid) // Should be valid
1303 })
1304
1305 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1306 t.Parallel()
1307 model := &mockLanguageModel{
1308 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1309 return &Response{
1310 Content: ResponseContent{
1311 TextContent{Text: "Response"},
1312 ToolCallContent{
1313 ToolCallID: "call1",
1314 ToolName: "test_tool",
1315 Input: `{"wrong_field": "test"}`, // Missing required field
1316 },
1317 },
1318 Usage: Usage{TotalTokens: 10},
1319 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1320 }, nil
1321 },
1322 }
1323
1324 tool := &mockTool{
1325 name: "test_tool",
1326 description: "Test tool",
1327 parameters: map[string]any{
1328 "value": map[string]any{"type": "string"},
1329 },
1330 required: []string{"value"},
1331 }
1332
1333 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1334
1335 result, err := agent.Generate(context.Background(), AgentCall{
1336 Prompt: "test prompt",
1337 })
1338
1339 require.NoError(t, err)
1340 require.NotNil(t, result)
1341 require.Len(t, result.Steps, 1) // Only one step
1342
1343 // Check that tool call was marked as invalid
1344 toolCalls := result.Steps[0].Content.ToolCalls()
1345 require.Len(t, toolCalls, 1)
1346 require.True(t, toolCalls[0].Invalid) // Should be invalid
1347 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1348 })
1349
1350 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1351 t.Parallel()
1352 model := &mockLanguageModel{
1353 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1354 return &Response{
1355 Content: ResponseContent{
1356 TextContent{Text: "Response"},
1357 ToolCallContent{
1358 ToolCallID: "call1",
1359 ToolName: "test_tool",
1360 Input: `{"wrong_field": "test"}`, // Missing required field
1361 },
1362 },
1363 Usage: Usage{TotalTokens: 10},
1364 FinishReason: FinishReasonStop, // Changed to stop
1365 }, nil
1366 },
1367 }
1368
1369 tool := &mockTool{
1370 name: "test_tool",
1371 description: "Test tool",
1372 parameters: map[string]any{
1373 "value": map[string]any{"type": "string"},
1374 },
1375 required: []string{"value"},
1376 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1377 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1378 },
1379 }
1380
1381 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1382 // Simple repair: add the missing required field
1383 repairedToolCall := options.OriginalToolCall
1384 repairedToolCall.Input = `{"value": "repaired"}`
1385 return &repairedToolCall, nil
1386 }
1387
1388 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1389
1390 result, err := agent.Generate(context.Background(), AgentCall{
1391 Prompt: "test prompt",
1392 })
1393
1394 require.NoError(t, err)
1395 require.NotNil(t, result)
1396 require.Len(t, result.Steps, 1) // Only one step
1397
1398 // Check that tool call was repaired and is now valid
1399 toolCalls := result.Steps[0].Content.ToolCalls()
1400 require.Len(t, toolCalls, 1)
1401 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1402 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1403 })
1404
1405 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1406 t.Parallel()
1407 model := &mockLanguageModel{
1408 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1409 return &Response{
1410 Content: ResponseContent{
1411 TextContent{Text: "Response"},
1412 ToolCallContent{
1413 ToolCallID: "call1",
1414 ToolName: "test_tool",
1415 Input: `{"wrong_field": "test"}`, // Missing required field
1416 },
1417 },
1418 Usage: Usage{TotalTokens: 10},
1419 FinishReason: FinishReasonStop, // Changed to stop
1420 }, nil
1421 },
1422 }
1423
1424 tool := &mockTool{
1425 name: "test_tool",
1426 description: "Test tool",
1427 parameters: map[string]any{
1428 "value": map[string]any{"type": "string"},
1429 },
1430 required: []string{"value"},
1431 }
1432
1433 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1434 // Repair function fails
1435 return nil, errors.New("repair failed")
1436 }
1437
1438 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1439
1440 result, err := agent.Generate(context.Background(), AgentCall{
1441 Prompt: "test prompt",
1442 })
1443
1444 require.NoError(t, err)
1445 require.NotNil(t, result)
1446 require.Len(t, result.Steps, 1) // Only one step
1447
1448 // Check that tool call was marked as invalid since repair failed
1449 toolCalls := result.Steps[0].Content.ToolCalls()
1450 require.Len(t, toolCalls, 1)
1451 require.True(t, toolCalls[0].Invalid) // Should be invalid
1452 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1453 })
1454
1455 t.Run("Nonexistent tool call", func(t *testing.T) {
1456 t.Parallel()
1457 model := &mockLanguageModel{
1458 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1459 return &Response{
1460 Content: ResponseContent{
1461 TextContent{Text: "Response"},
1462 ToolCallContent{
1463 ToolCallID: "call1",
1464 ToolName: "nonexistent_tool",
1465 Input: `{"value": "test"}`,
1466 },
1467 },
1468 Usage: Usage{TotalTokens: 10},
1469 FinishReason: FinishReasonStop, // Changed to stop
1470 }, nil
1471 },
1472 }
1473
1474 tool := &mockTool{name: "test_tool", description: "Test tool"}
1475
1476 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1477
1478 result, err := agent.Generate(context.Background(), AgentCall{
1479 Prompt: "test prompt",
1480 })
1481
1482 require.NoError(t, err)
1483 require.NotNil(t, result)
1484 require.Len(t, result.Steps, 1) // Only one step
1485
1486 // Check that tool call was marked as invalid due to nonexistent tool
1487 toolCalls := result.Steps[0].Content.ToolCalls()
1488 require.Len(t, toolCalls, 1)
1489 require.True(t, toolCalls[0].Invalid) // Should be invalid
1490 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1491 })
1492
1493 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1494 t.Parallel()
1495 model := &mockLanguageModel{
1496 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1497 return &Response{
1498 Content: ResponseContent{
1499 TextContent{Text: "Response"},
1500 ToolCallContent{
1501 ToolCallID: "call1",
1502 ToolName: "test_tool",
1503 Input: `{invalid json}`, // Invalid JSON
1504 },
1505 },
1506 Usage: Usage{TotalTokens: 10},
1507 FinishReason: FinishReasonStop, // Changed to stop
1508 }, nil
1509 },
1510 }
1511
1512 tool := &mockTool{
1513 name: "test_tool",
1514 description: "Test tool",
1515 parameters: map[string]any{
1516 "value": map[string]any{"type": "string"},
1517 },
1518 required: []string{"value"},
1519 }
1520
1521 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1522
1523 result, err := agent.Generate(context.Background(), AgentCall{
1524 Prompt: "test prompt",
1525 })
1526
1527 require.NoError(t, err)
1528 require.NotNil(t, result)
1529 require.Len(t, result.Steps, 1) // Only one step
1530
1531 // Check that tool call was marked as invalid due to invalid JSON
1532 toolCalls := result.Steps[0].Content.ToolCalls()
1533 require.Len(t, toolCalls, 1)
1534 require.True(t, toolCalls[0].Invalid) // Should be invalid
1535 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1536 })
1537}
1538
1539// Test media and image tool responses
1540func TestAgent_MediaToolResponses(t *testing.T) {
1541 t.Parallel()
1542
1543 imageData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG header bytes
1544 audioData := []byte{0x52, 0x49, 0x46, 0x46} // RIFF header bytes
1545
1546 t.Run("Image tool response", func(t *testing.T) {
1547 t.Parallel()
1548
1549 imageTool := &mockTool{
1550 name: "generate_image",
1551 description: "Generates an image",
1552 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1553 return NewImageResponse(imageData, "image/png"), nil
1554 },
1555 }
1556
1557 model := &mockLanguageModel{
1558 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1559 if len(call.Prompt) == 1 {
1560 // First call - request image tool
1561 return &Response{
1562 Content: []Content{
1563 ToolCallContent{
1564 ToolCallID: "img-1",
1565 ToolName: "generate_image",
1566 Input: `{}`,
1567 },
1568 },
1569 Usage: Usage{TotalTokens: 10},
1570 FinishReason: FinishReasonToolCalls,
1571 }, nil
1572 }
1573 // Second call - after tool execution
1574 return &Response{
1575 Content: []Content{TextContent{Text: "Image generated"}},
1576 Usage: Usage{TotalTokens: 20},
1577 FinishReason: FinishReasonStop,
1578 }, nil
1579 },
1580 }
1581
1582 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1583
1584 result, err := agent.Generate(context.Background(), AgentCall{
1585 Prompt: "Generate an image",
1586 })
1587
1588 require.NoError(t, err)
1589 require.NotNil(t, result)
1590 require.Len(t, result.Steps, 2) // Tool call step + final response
1591
1592 // Check tool results in first step
1593 toolResults := result.Steps[0].Content.ToolResults()
1594 require.Len(t, toolResults, 1)
1595
1596 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1597 require.True(t, ok, "Expected media result")
1598 require.Equal(t, string(imageData), mediaResult.Data)
1599 require.Equal(t, "image/png", mediaResult.MediaType)
1600 })
1601
1602 t.Run("Media tool response (audio)", func(t *testing.T) {
1603 t.Parallel()
1604
1605 audioTool := &mockTool{
1606 name: "generate_audio",
1607 description: "Generates audio",
1608 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1609 return NewMediaResponse(audioData, "audio/wav"), nil
1610 },
1611 }
1612
1613 model := &mockLanguageModel{
1614 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1615 if len(call.Prompt) == 1 {
1616 return &Response{
1617 Content: []Content{
1618 ToolCallContent{
1619 ToolCallID: "audio-1",
1620 ToolName: "generate_audio",
1621 Input: `{}`,
1622 },
1623 },
1624 Usage: Usage{TotalTokens: 10},
1625 FinishReason: FinishReasonToolCalls,
1626 }, nil
1627 }
1628 return &Response{
1629 Content: []Content{TextContent{Text: "Audio generated"}},
1630 Usage: Usage{TotalTokens: 20},
1631 FinishReason: FinishReasonStop,
1632 }, nil
1633 },
1634 }
1635
1636 agent := NewAgent(model, WithTools(audioTool), WithStopConditions(StepCountIs(3)))
1637
1638 result, err := agent.Generate(context.Background(), AgentCall{
1639 Prompt: "Generate audio",
1640 })
1641
1642 require.NoError(t, err)
1643 require.NotNil(t, result)
1644
1645 toolResults := result.Steps[0].Content.ToolResults()
1646 require.Len(t, toolResults, 1)
1647
1648 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1649 require.True(t, ok, "Expected media result")
1650 require.Equal(t, string(audioData), mediaResult.Data)
1651 require.Equal(t, "audio/wav", mediaResult.MediaType)
1652 })
1653
1654 t.Run("Media response with text", func(t *testing.T) {
1655 t.Parallel()
1656
1657 imageTool := &mockTool{
1658 name: "screenshot",
1659 description: "Takes a screenshot",
1660 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1661 resp := NewImageResponse(imageData, "image/png")
1662 resp.Content = "Screenshot captured successfully"
1663 return resp, nil
1664 },
1665 }
1666
1667 model := &mockLanguageModel{
1668 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1669 if len(call.Prompt) == 1 {
1670 return &Response{
1671 Content: []Content{
1672 ToolCallContent{
1673 ToolCallID: "screen-1",
1674 ToolName: "screenshot",
1675 Input: `{}`,
1676 },
1677 },
1678 Usage: Usage{TotalTokens: 10},
1679 FinishReason: FinishReasonToolCalls,
1680 }, nil
1681 }
1682 return &Response{
1683 Content: []Content{TextContent{Text: "Done"}},
1684 Usage: Usage{TotalTokens: 20},
1685 FinishReason: FinishReasonStop,
1686 }, nil
1687 },
1688 }
1689
1690 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1691
1692 result, err := agent.Generate(context.Background(), AgentCall{
1693 Prompt: "Take a screenshot",
1694 })
1695
1696 require.NoError(t, err)
1697 require.NotNil(t, result)
1698
1699 toolResults := result.Steps[0].Content.ToolResults()
1700 require.Len(t, toolResults, 1)
1701
1702 mediaResult, ok := toolResults[0].Result.(ToolResultOutputContentMedia)
1703 require.True(t, ok, "Expected media result")
1704 require.Equal(t, string(imageData), mediaResult.Data)
1705 require.Equal(t, "image/png", mediaResult.MediaType)
1706 require.Equal(t, "Screenshot captured successfully", mediaResult.Text)
1707 })
1708
1709 t.Run("Media response preserves metadata", func(t *testing.T) {
1710 t.Parallel()
1711
1712 type ImageMetadata struct {
1713 Width int `json:"width"`
1714 Height int `json:"height"`
1715 }
1716
1717 imageTool := &mockTool{
1718 name: "generate_image",
1719 description: "Generates an image",
1720 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1721 resp := NewImageResponse(imageData, "image/png")
1722 return WithResponseMetadata(resp, ImageMetadata{Width: 800, Height: 600}), nil
1723 },
1724 }
1725
1726 model := &mockLanguageModel{
1727 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1728 if len(call.Prompt) == 1 {
1729 return &Response{
1730 Content: []Content{
1731 ToolCallContent{
1732 ToolCallID: "img-1",
1733 ToolName: "generate_image",
1734 Input: `{}`,
1735 },
1736 },
1737 Usage: Usage{TotalTokens: 10},
1738 FinishReason: FinishReasonToolCalls,
1739 }, nil
1740 }
1741 return &Response{
1742 Content: []Content{TextContent{Text: "Done"}},
1743 Usage: Usage{TotalTokens: 20},
1744 FinishReason: FinishReasonStop,
1745 }, nil
1746 },
1747 }
1748
1749 agent := NewAgent(model, WithTools(imageTool), WithStopConditions(StepCountIs(3)))
1750
1751 result, err := agent.Generate(context.Background(), AgentCall{
1752 Prompt: "Generate image",
1753 })
1754
1755 require.NoError(t, err)
1756 require.NotNil(t, result)
1757
1758 toolResults := result.Steps[0].Content.ToolResults()
1759 require.Len(t, toolResults, 1)
1760
1761 // Check metadata was preserved
1762 require.NotEmpty(t, toolResults[0].ClientMetadata)
1763
1764 var metadata ImageMetadata
1765 err = json.Unmarshal([]byte(toolResults[0].ClientMetadata), &metadata)
1766 require.NoError(t, err)
1767 require.Equal(t, 800, metadata.Width)
1768 require.Equal(t, 600, metadata.Height)
1769 })
1770}