1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 providerOptions ProviderOptions
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
21}
22
23func (m *mockTool) SetProviderOptions(opts ProviderOptions) {
24 m.providerOptions = opts
25}
26
27func (m *mockTool) ProviderOptions() ProviderOptions {
28 return m.providerOptions
29}
30
31func (m *mockTool) Info() ToolInfo {
32 return ToolInfo{
33 Name: m.name,
34 Description: m.description,
35 Parameters: m.parameters,
36 Required: m.required,
37 }
38}
39
40func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
41 if m.executeFunc != nil {
42 return m.executeFunc(ctx, call)
43 }
44 return ToolResponse{Content: "mock result", IsError: false}, nil
45}
46
47// Mock language model for testing
48type mockLanguageModel struct {
49 generateFunc func(ctx context.Context, call Call) (*Response, error)
50 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
51}
52
53func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
54 if m.generateFunc != nil {
55 return m.generateFunc(ctx, call)
56 }
57 return &Response{
58 Content: []Content{
59 TextContent{Text: "Hello, world!"},
60 },
61 Usage: Usage{
62 InputTokens: 3,
63 OutputTokens: 10,
64 TotalTokens: 13,
65 },
66 FinishReason: FinishReasonStop,
67 }, nil
68}
69
70func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
71 if m.streamFunc != nil {
72 return m.streamFunc(ctx, call)
73 }
74 return nil, fmt.Errorf("mock stream not implemented")
75}
76
77func (m *mockLanguageModel) Provider() string {
78 return "mock-provider"
79}
80
81func (m *mockLanguageModel) Model() string {
82 return "mock-model"
83}
84
85// Test result.content - comprehensive content types (matches TS test)
86func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
87 t.Parallel()
88
89 // Create a type-safe tool using the new API
90 type TestInput struct {
91 Value string `json:"value" description:"Test value"`
92 }
93
94 tool1 := NewAgentTool(
95 "tool1",
96 "Test tool",
97 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
98 require.Equal(t, "value", input.Value)
99 return ToolResponse{Content: "result1", IsError: false}, nil
100 },
101 )
102
103 model := &mockLanguageModel{
104 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
105 return &Response{
106 Content: []Content{
107 TextContent{Text: "Hello, world!"},
108 SourceContent{
109 ID: "123",
110 URL: "https://example.com",
111 Title: "Example",
112 SourceType: SourceTypeURL,
113 },
114 FileContent{
115 Data: []byte{1, 2, 3},
116 MediaType: "image/png",
117 },
118 ReasoningContent{
119 Text: "I will open the conversation with witty banter.",
120 },
121 ToolCallContent{
122 ToolCallID: "call-1",
123 ToolName: "tool1",
124 Input: `{"value":"value"}`,
125 },
126 TextContent{Text: "More text"},
127 },
128 Usage: Usage{
129 InputTokens: 3,
130 OutputTokens: 10,
131 TotalTokens: 13,
132 },
133 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
134 }, nil
135 },
136 }
137
138 agent := NewAgent(model, WithTools(tool1))
139 result, err := agent.Generate(context.Background(), AgentCall{
140 Prompt: "prompt",
141 })
142
143 require.NoError(t, err)
144 require.NotNil(t, result)
145 require.Len(t, result.Steps, 1) // Single step like TypeScript
146
147 // Check final response content includes tool result
148 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
149
150 // Verify each content type in order
151 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
152 require.True(t, ok)
153 require.Equal(t, "Hello, world!", textContent.Text)
154
155 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
156 require.True(t, ok)
157 require.Equal(t, "123", sourceContent.ID)
158
159 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
160 require.True(t, ok)
161 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
162
163 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
164 require.True(t, ok)
165 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
166
167 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
168 require.True(t, ok)
169 require.Equal(t, "call-1", toolCallContent.ToolCallID)
170
171 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
172 require.True(t, ok)
173 require.Equal(t, "More text", moreTextContent.Text)
174
175 // Tool result should be appended
176 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
177 require.True(t, ok)
178 require.Equal(t, "call-1", toolResultContent.ToolCallID)
179 require.Equal(t, "tool1", toolResultContent.ToolName)
180}
181
182// Test result.text extraction
183func TestAgent_Generate_ResultText(t *testing.T) {
184 t.Parallel()
185
186 model := &mockLanguageModel{
187 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
188 return &Response{
189 Content: []Content{
190 TextContent{Text: "Hello, world!"},
191 },
192 Usage: Usage{
193 InputTokens: 3,
194 OutputTokens: 10,
195 TotalTokens: 13,
196 },
197 FinishReason: FinishReasonStop,
198 }, nil
199 },
200 }
201
202 agent := NewAgent(model)
203 result, err := agent.Generate(context.Background(), AgentCall{
204 Prompt: "prompt",
205 })
206
207 require.NoError(t, err)
208 require.NotNil(t, result)
209
210 // Test text extraction from content
211 text := result.Response.Content.Text()
212 require.Equal(t, "Hello, world!", text)
213}
214
215// Test result.toolCalls extraction (matches TS test exactly)
216func TestAgent_Generate_ResultToolCalls(t *testing.T) {
217 t.Parallel()
218
219 // Create type-safe tools using the new API
220 type Tool1Input struct {
221 Value string `json:"value" description:"Test value"`
222 }
223
224 type Tool2Input struct {
225 SomethingElse string `json:"somethingElse" description:"Another test value"`
226 }
227
228 tool1 := NewAgentTool(
229 "tool1",
230 "Test tool 1",
231 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
232 return ToolResponse{Content: "result1", IsError: false}, nil
233 },
234 )
235
236 tool2 := NewAgentTool(
237 "tool2",
238 "Test tool 2",
239 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
240 return ToolResponse{Content: "result2", IsError: false}, nil
241 },
242 )
243
244 model := &mockLanguageModel{
245 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
246 // Verify tools are passed correctly
247 require.Len(t, call.Tools, 2)
248 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
249
250 // Verify prompt structure
251 require.Len(t, call.Prompt, 1)
252 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
253
254 return &Response{
255 Content: []Content{
256 ToolCallContent{
257 ToolCallID: "call-1",
258 ToolName: "tool1",
259 Input: `{"value":"value"}`,
260 },
261 },
262 Usage: Usage{
263 InputTokens: 3,
264 OutputTokens: 10,
265 TotalTokens: 13,
266 },
267 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
268 }, nil
269 },
270 }
271
272 agent := NewAgent(model, WithTools(tool1, tool2))
273 result, err := agent.Generate(context.Background(), AgentCall{
274 Prompt: "test-input",
275 })
276
277 require.NoError(t, err)
278 require.NotNil(t, result)
279 require.Len(t, result.Steps, 1) // Single step
280
281 // Extract tool calls from final response (should be empty since tools don't execute)
282 var toolCalls []ToolCallContent
283 for _, content := range result.Response.Content {
284 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
285 toolCalls = append(toolCalls, toolCall)
286 }
287 }
288
289 require.Len(t, toolCalls, 1)
290 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
291 require.Equal(t, "tool1", toolCalls[0].ToolName)
292
293 // Parse and verify input
294 var input map[string]any
295 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
296 require.NoError(t, err)
297 require.Equal(t, "value", input["value"])
298}
299
300// Test result.toolResults extraction (matches TS test exactly)
301func TestAgent_Generate_ResultToolResults(t *testing.T) {
302 t.Parallel()
303
304 // Create type-safe tool using the new API
305 type TestInput struct {
306 Value string `json:"value" description:"Test value"`
307 }
308
309 tool1 := NewAgentTool(
310 "tool1",
311 "Test tool",
312 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
313 require.Equal(t, "value", input.Value)
314 return ToolResponse{Content: "result1", IsError: false}, nil
315 },
316 )
317
318 model := &mockLanguageModel{
319 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
320 // Verify tools and tool choice
321 require.Len(t, call.Tools, 1)
322 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
323
324 // Verify prompt
325 require.Len(t, call.Prompt, 1)
326 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
327
328 return &Response{
329 Content: []Content{
330 ToolCallContent{
331 ToolCallID: "call-1",
332 ToolName: "tool1",
333 Input: `{"value":"value"}`,
334 },
335 },
336 Usage: Usage{
337 InputTokens: 3,
338 OutputTokens: 10,
339 TotalTokens: 13,
340 },
341 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
342 }, nil
343 },
344 }
345
346 agent := NewAgent(model, WithTools(tool1))
347 result, err := agent.Generate(context.Background(), AgentCall{
348 Prompt: "test-input",
349 })
350
351 require.NoError(t, err)
352 require.NotNil(t, result)
353 require.Len(t, result.Steps, 1) // Single step
354
355 // Extract tool results from final response
356 var toolResults []ToolResultContent
357 for _, content := range result.Response.Content {
358 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
359 toolResults = append(toolResults, toolResult)
360 }
361 }
362
363 require.Len(t, toolResults, 1)
364 require.Equal(t, "call-1", toolResults[0].ToolCallID)
365 require.Equal(t, "tool1", toolResults[0].ToolName)
366
367 // Verify result content
368 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
369 require.True(t, ok)
370 require.Equal(t, "result1", textResult.Text)
371}
372
373// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
374func TestAgent_Generate_MultipleSteps(t *testing.T) {
375 t.Parallel()
376
377 // Create type-safe tool using the new API
378 type TestInput struct {
379 Value string `json:"value" description:"Test value"`
380 }
381
382 tool1 := NewAgentTool(
383 "tool1",
384 "Test tool",
385 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
386 require.Equal(t, "value", input.Value)
387 return ToolResponse{Content: "result1", IsError: false}, nil
388 },
389 )
390
391 callCount := 0
392 model := &mockLanguageModel{
393 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
394 callCount++
395 switch callCount {
396 case 1:
397 // First call - return tool call with FinishReasonToolCalls
398 return &Response{
399 Content: []Content{
400 ToolCallContent{
401 ToolCallID: "call-1",
402 ToolName: "tool1",
403 Input: `{"value":"value"}`,
404 },
405 },
406 Usage: Usage{
407 InputTokens: 10,
408 OutputTokens: 5,
409 TotalTokens: 15,
410 },
411 FinishReason: FinishReasonToolCalls, // This triggers multi-step
412 }, nil
413 case 2:
414 // Second call - return final text
415 return &Response{
416 Content: []Content{
417 TextContent{Text: "Hello, world!"},
418 },
419 Usage: Usage{
420 InputTokens: 3,
421 OutputTokens: 10,
422 TotalTokens: 13,
423 },
424 FinishReason: FinishReasonStop,
425 }, nil
426 default:
427 t.Fatalf("Unexpected call count: %d", callCount)
428 return nil, nil
429 }
430 },
431 }
432
433 agent := NewAgent(model, WithTools(tool1))
434 result, err := agent.Generate(context.Background(), AgentCall{
435 Prompt: "test-input",
436 })
437
438 require.NoError(t, err)
439 require.NotNil(t, result)
440 require.Len(t, result.Steps, 2)
441
442 // Check total usage sums both steps
443 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
444 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
445 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
446
447 // Final response should be from last step
448 require.Len(t, result.Response.Content, 1)
449 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
450 require.True(t, ok)
451 require.Equal(t, "Hello, world!", textContent.Text)
452
453 // result.toolCalls should be empty (from last step)
454 var toolCalls []ToolCallContent
455 for _, content := range result.Response.Content {
456 if _, ok := AsContentType[ToolCallContent](content); ok {
457 toolCalls = append(toolCalls, content.(ToolCallContent))
458 }
459 }
460 require.Len(t, toolCalls, 0)
461
462 // result.toolResults should be empty (from last step)
463 var toolResults []ToolResultContent
464 for _, content := range result.Response.Content {
465 if _, ok := AsContentType[ToolResultContent](content); ok {
466 toolResults = append(toolResults, content.(ToolResultContent))
467 }
468 }
469 require.Len(t, toolResults, 0)
470}
471
472// Test basic text generation
473func TestAgent_Generate_BasicText(t *testing.T) {
474 t.Parallel()
475
476 model := &mockLanguageModel{
477 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
478 return &Response{
479 Content: []Content{
480 TextContent{Text: "Hello, world!"},
481 },
482 Usage: Usage{
483 InputTokens: 3,
484 OutputTokens: 10,
485 TotalTokens: 13,
486 },
487 FinishReason: FinishReasonStop,
488 }, nil
489 },
490 }
491
492 agent := NewAgent(model)
493 result, err := agent.Generate(context.Background(), AgentCall{
494 Prompt: "test prompt",
495 })
496
497 require.NoError(t, err)
498 require.NotNil(t, result)
499 require.Len(t, result.Steps, 1)
500
501 // Check final response
502 require.Len(t, result.Response.Content, 1)
503 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
504 require.True(t, ok)
505 require.Equal(t, "Hello, world!", textContent.Text)
506
507 // Check usage
508 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
509 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
510 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
511
512 // Check total usage
513 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
514 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
515 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
516}
517
518// Test empty prompt error
519func TestAgent_Generate_EmptyPrompt(t *testing.T) {
520 t.Parallel()
521
522 model := &mockLanguageModel{}
523 agent := NewAgent(model)
524
525 result, err := agent.Generate(context.Background(), AgentCall{
526 Prompt: "", // Empty prompt should cause error
527 })
528
529 require.Error(t, err)
530 require.Nil(t, result)
531 require.Contains(t, err.Error(), "Prompt can't be empty")
532}
533
534// Test with system prompt
535func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
536 t.Parallel()
537
538 model := &mockLanguageModel{
539 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
540 // Verify system message is included
541 require.Len(t, call.Prompt, 2) // system + user
542 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
543 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
544
545 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
546 require.True(t, ok)
547 require.Equal(t, "You are a helpful assistant", systemPart.Text)
548
549 return &Response{
550 Content: []Content{
551 TextContent{Text: "Hello, world!"},
552 },
553 Usage: Usage{
554 InputTokens: 3,
555 OutputTokens: 10,
556 TotalTokens: 13,
557 },
558 FinishReason: FinishReasonStop,
559 }, nil
560 },
561 }
562
563 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
564 result, err := agent.Generate(context.Background(), AgentCall{
565 Prompt: "test prompt",
566 })
567
568 require.NoError(t, err)
569 require.NotNil(t, result)
570}
571
572// Test options.activeTools filtering
573func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
574 t.Parallel()
575
576 tool1 := &mockTool{
577 name: "tool1",
578 description: "Test tool 1",
579 parameters: map[string]any{
580 "value": map[string]any{"type": "string"},
581 },
582 required: []string{"value"},
583 }
584
585 tool2 := &mockTool{
586 name: "tool2",
587 description: "Test tool 2",
588 parameters: map[string]any{
589 "value": map[string]any{"type": "string"},
590 },
591 required: []string{"value"},
592 }
593
594 model := &mockLanguageModel{
595 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
596 // Verify only tool1 is available
597 require.Len(t, call.Tools, 1)
598 functionTool, ok := call.Tools[0].(FunctionTool)
599 require.True(t, ok)
600 require.Equal(t, "tool1", functionTool.Name)
601
602 return &Response{
603 Content: []Content{
604 TextContent{Text: "Hello, world!"},
605 },
606 Usage: Usage{
607 InputTokens: 3,
608 OutputTokens: 10,
609 TotalTokens: 13,
610 },
611 FinishReason: FinishReasonStop,
612 }, nil
613 },
614 }
615
616 agent := NewAgent(model, WithTools(tool1, tool2))
617 result, err := agent.Generate(context.Background(), AgentCall{
618 Prompt: "test-input",
619 ActiveTools: []string{"tool1"}, // Only tool1 should be active
620 })
621
622 require.NoError(t, err)
623 require.NotNil(t, result)
624}
625
626func TestResponseContent_Getters(t *testing.T) {
627 t.Parallel()
628
629 // Create test content with all types
630 content := ResponseContent{
631 TextContent{Text: "Hello world"},
632 ReasoningContent{Text: "Let me think..."},
633 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
634 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
635 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
636 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
637 }
638
639 // Test Text()
640 require.Equal(t, "Hello world", content.Text())
641
642 // Test Reasoning()
643 reasoning := content.Reasoning()
644 require.Len(t, reasoning, 1)
645 require.Equal(t, "Let me think...", reasoning[0].Text)
646
647 // Test ReasoningText()
648 require.Equal(t, "Let me think...", content.ReasoningText())
649
650 // Test Files()
651 files := content.Files()
652 require.Len(t, files, 1)
653 require.Equal(t, "text/plain", files[0].MediaType)
654 require.Equal(t, []byte("file data"), files[0].Data)
655
656 // Test Sources()
657 sources := content.Sources()
658 require.Len(t, sources, 1)
659 require.Equal(t, SourceTypeURL, sources[0].SourceType)
660 require.Equal(t, "https://example.com", sources[0].URL)
661 require.Equal(t, "Example", sources[0].Title)
662
663 // Test ToolCalls()
664 toolCalls := content.ToolCalls()
665 require.Len(t, toolCalls, 1)
666 require.Equal(t, "call1", toolCalls[0].ToolCallID)
667 require.Equal(t, "test_tool", toolCalls[0].ToolName)
668 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
669
670 // Test ToolResults()
671 toolResults := content.ToolResults()
672 require.Len(t, toolResults, 1)
673 require.Equal(t, "call1", toolResults[0].ToolCallID)
674 require.Equal(t, "test_tool", toolResults[0].ToolName)
675 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
676 require.True(t, ok)
677 require.Equal(t, "result", result.Text)
678}
679
680func TestResponseContent_Getters_Empty(t *testing.T) {
681 t.Parallel()
682
683 // Test with empty content
684 content := ResponseContent{}
685
686 require.Equal(t, "", content.Text())
687 require.Equal(t, "", content.ReasoningText())
688 require.Empty(t, content.Reasoning())
689 require.Empty(t, content.Files())
690 require.Empty(t, content.Sources())
691 require.Empty(t, content.ToolCalls())
692 require.Empty(t, content.ToolResults())
693}
694
695func TestResponseContent_Getters_MultipleItems(t *testing.T) {
696 t.Parallel()
697
698 // Test with multiple items of same type
699 content := ResponseContent{
700 ReasoningContent{Text: "First thought"},
701 ReasoningContent{Text: "Second thought"},
702 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
703 FileContent{Data: []byte("file2"), MediaType: "image/png"},
704 }
705
706 // Test multiple reasoning
707 reasoning := content.Reasoning()
708 require.Len(t, reasoning, 2)
709 require.Equal(t, "First thought", reasoning[0].Text)
710 require.Equal(t, "Second thought", reasoning[1].Text)
711
712 // Test concatenated reasoning text
713 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
714
715 // Test multiple files
716 files := content.Files()
717 require.Len(t, files, 2)
718 require.Equal(t, "text/plain", files[0].MediaType)
719 require.Equal(t, "image/png", files[1].MediaType)
720}
721
722func TestStopConditions(t *testing.T) {
723 t.Parallel()
724
725 // Create test steps
726 step1 := StepResult{
727 Response: Response{
728 Content: ResponseContent{
729 TextContent{Text: "Hello"},
730 },
731 FinishReason: FinishReasonToolCalls,
732 Usage: Usage{TotalTokens: 10},
733 },
734 }
735
736 step2 := StepResult{
737 Response: Response{
738 Content: ResponseContent{
739 TextContent{Text: "World"},
740 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
741 },
742 FinishReason: FinishReasonStop,
743 Usage: Usage{TotalTokens: 15},
744 },
745 }
746
747 step3 := StepResult{
748 Response: Response{
749 Content: ResponseContent{
750 ReasoningContent{Text: "Let me think..."},
751 FileContent{Data: []byte("data"), MediaType: "text/plain"},
752 },
753 FinishReason: FinishReasonLength,
754 Usage: Usage{TotalTokens: 20},
755 },
756 }
757
758 t.Run("StepCountIs", func(t *testing.T) {
759 t.Parallel()
760 condition := StepCountIs(2)
761
762 // Should not stop with 1 step
763 require.False(t, condition([]StepResult{step1}))
764
765 // Should stop with 2 steps
766 require.True(t, condition([]StepResult{step1, step2}))
767
768 // Should stop with more than 2 steps
769 require.True(t, condition([]StepResult{step1, step2, step3}))
770
771 // Should not stop with empty steps
772 require.False(t, condition([]StepResult{}))
773 })
774
775 t.Run("HasToolCall", func(t *testing.T) {
776 t.Parallel()
777 condition := HasToolCall("search")
778
779 // Should not stop when tool not called
780 require.False(t, condition([]StepResult{step1}))
781
782 // Should stop when tool is called in last step
783 require.True(t, condition([]StepResult{step1, step2}))
784
785 // Should not stop when tool called in earlier step but not last
786 require.False(t, condition([]StepResult{step1, step2, step3}))
787
788 // Should not stop with empty steps
789 require.False(t, condition([]StepResult{}))
790
791 // Should not stop when different tool is called
792 differentToolCondition := HasToolCall("different_tool")
793 require.False(t, differentToolCondition([]StepResult{step1, step2}))
794 })
795
796 t.Run("HasContent", func(t *testing.T) {
797 t.Parallel()
798 reasoningCondition := HasContent(ContentTypeReasoning)
799 fileCondition := HasContent(ContentTypeFile)
800
801 // Should not stop when content type not present
802 require.False(t, reasoningCondition([]StepResult{step1, step2}))
803
804 // Should stop when content type is present in last step
805 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
806 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
807
808 // Should not stop with empty steps
809 require.False(t, reasoningCondition([]StepResult{}))
810 })
811
812 t.Run("FinishReasonIs", func(t *testing.T) {
813 t.Parallel()
814 stopCondition := FinishReasonIs(FinishReasonStop)
815 lengthCondition := FinishReasonIs(FinishReasonLength)
816
817 // Should not stop when finish reason doesn't match
818 require.False(t, stopCondition([]StepResult{step1}))
819
820 // Should stop when finish reason matches in last step
821 require.True(t, stopCondition([]StepResult{step1, step2}))
822 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
823
824 // Should not stop with empty steps
825 require.False(t, stopCondition([]StepResult{}))
826 })
827
828 t.Run("MaxTokensUsed", func(t *testing.T) {
829 condition := MaxTokensUsed(30)
830
831 // Should not stop when under limit
832 require.False(t, condition([]StepResult{step1})) // 10 tokens
833 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
834
835 // Should stop when at or over limit
836 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
837
838 // Should not stop with empty steps
839 require.False(t, condition([]StepResult{}))
840 })
841}
842
843func TestStopConditions_Integration(t *testing.T) {
844 t.Parallel()
845
846 t.Run("StepCountIs integration", func(t *testing.T) {
847 t.Parallel()
848 model := &mockLanguageModel{
849 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
850 return &Response{
851 Content: ResponseContent{
852 TextContent{Text: "Mock response"},
853 },
854 Usage: Usage{
855 InputTokens: 3,
856 OutputTokens: 10,
857 TotalTokens: 13,
858 },
859 FinishReason: FinishReasonStop,
860 }, nil
861 },
862 }
863
864 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
865
866 result, err := agent.Generate(context.Background(), AgentCall{
867 Prompt: "test prompt",
868 })
869
870 require.NoError(t, err)
871 require.NotNil(t, result)
872 require.Len(t, result.Steps, 1) // Should stop after 1 step
873 })
874
875 t.Run("Multiple stop conditions", func(t *testing.T) {
876 t.Parallel()
877 model := &mockLanguageModel{
878 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
879 return &Response{
880 Content: ResponseContent{
881 TextContent{Text: "Mock response"},
882 },
883 Usage: Usage{
884 InputTokens: 3,
885 OutputTokens: 10,
886 TotalTokens: 13,
887 },
888 FinishReason: FinishReasonStop,
889 }, nil
890 },
891 }
892
893 agent := NewAgent(model, WithStopConditions(
894 StepCountIs(5), // Stop after 5 steps
895 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
896 ))
897
898 result, err := agent.Generate(context.Background(), AgentCall{
899 Prompt: "test prompt",
900 })
901
902 require.NoError(t, err)
903 require.NotNil(t, result)
904 // Should stop on first condition met (finish reason stop)
905 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
906 })
907}
908
909func TestPrepareStep(t *testing.T) {
910 t.Parallel()
911
912 t.Run("System prompt modification", func(t *testing.T) {
913 t.Parallel()
914 var capturedSystemPrompt string
915 model := &mockLanguageModel{
916 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
917 // Capture the system message to verify it was modified
918 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
919 if len(call.Prompt[0].Content) > 0 {
920 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
921 capturedSystemPrompt = textPart.Text
922 }
923 }
924 }
925 return &Response{
926 Content: ResponseContent{
927 TextContent{Text: "Response"},
928 },
929 Usage: Usage{TotalTokens: 10},
930 FinishReason: FinishReasonStop,
931 }, nil
932 },
933 }
934
935 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
936 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
937 return ctx, PrepareStepResult{
938 Model: options.Model,
939 Messages: options.Messages,
940 System: &newSystem,
941 }, nil
942 }
943
944 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
945
946 result, err := agent.Generate(context.Background(), AgentCall{
947 Prompt: "test prompt",
948 PrepareStep: prepareStepFunc,
949 })
950
951 require.NoError(t, err)
952 require.NotNil(t, result)
953 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
954 })
955
956 t.Run("Tool choice modification", func(t *testing.T) {
957 t.Parallel()
958 var capturedToolChoice *ToolChoice
959 model := &mockLanguageModel{
960 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
961 capturedToolChoice = call.ToolChoice
962 return &Response{
963 Content: ResponseContent{
964 TextContent{Text: "Response"},
965 },
966 Usage: Usage{TotalTokens: 10},
967 FinishReason: FinishReasonStop,
968 }, nil
969 },
970 }
971
972 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
973 toolChoice := ToolChoiceNone
974 return ctx, PrepareStepResult{
975 Model: options.Model,
976 Messages: options.Messages,
977 ToolChoice: &toolChoice,
978 }, nil
979 }
980
981 agent := NewAgent(model)
982
983 result, err := agent.Generate(context.Background(), AgentCall{
984 Prompt: "test prompt",
985 PrepareStep: prepareStepFunc,
986 })
987
988 require.NoError(t, err)
989 require.NotNil(t, result)
990 require.NotNil(t, capturedToolChoice)
991 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
992 })
993
994 t.Run("Active tools modification", func(t *testing.T) {
995 t.Parallel()
996 var capturedToolNames []string
997 model := &mockLanguageModel{
998 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
999 // Capture tool names to verify active tools were modified
1000 for _, tool := range call.Tools {
1001 capturedToolNames = append(capturedToolNames, tool.GetName())
1002 }
1003 return &Response{
1004 Content: ResponseContent{
1005 TextContent{Text: "Response"},
1006 },
1007 Usage: Usage{TotalTokens: 10},
1008 FinishReason: FinishReasonStop,
1009 }, nil
1010 },
1011 }
1012
1013 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1014 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1015 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1016
1017 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1018 activeTools := []string{"tool2"} // Only tool2 should be active
1019 return ctx, PrepareStepResult{
1020 Model: options.Model,
1021 Messages: options.Messages,
1022 ActiveTools: activeTools,
1023 }, nil
1024 }
1025
1026 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1027
1028 result, err := agent.Generate(context.Background(), AgentCall{
1029 Prompt: "test prompt",
1030 PrepareStep: prepareStepFunc,
1031 })
1032
1033 require.NoError(t, err)
1034 require.NotNil(t, result)
1035 require.Len(t, capturedToolNames, 1)
1036 require.Equal(t, "tool2", capturedToolNames[0])
1037 })
1038
1039 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1040 t.Parallel()
1041 var capturedToolCount int
1042 model := &mockLanguageModel{
1043 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1044 capturedToolCount = len(call.Tools)
1045 return &Response{
1046 Content: ResponseContent{
1047 TextContent{Text: "Response"},
1048 },
1049 Usage: Usage{TotalTokens: 10},
1050 FinishReason: FinishReasonStop,
1051 }, nil
1052 },
1053 }
1054
1055 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1056
1057 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1058 return ctx, PrepareStepResult{
1059 Model: options.Model,
1060 Messages: options.Messages,
1061 DisableAllTools: true, // Disable all tools for this step
1062 }, nil
1063 }
1064
1065 agent := NewAgent(model, WithTools(tool1))
1066
1067 result, err := agent.Generate(context.Background(), AgentCall{
1068 Prompt: "test prompt",
1069 PrepareStep: prepareStepFunc,
1070 })
1071
1072 require.NoError(t, err)
1073 require.NotNil(t, result)
1074 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1075 })
1076
1077 t.Run("All fields modified together", func(t *testing.T) {
1078 t.Parallel()
1079 var capturedSystemPrompt string
1080 var capturedToolChoice *ToolChoice
1081 var capturedToolNames []string
1082
1083 model := &mockLanguageModel{
1084 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1085 // Capture system prompt
1086 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1087 if len(call.Prompt[0].Content) > 0 {
1088 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1089 capturedSystemPrompt = textPart.Text
1090 }
1091 }
1092 }
1093 // Capture tool choice
1094 capturedToolChoice = call.ToolChoice
1095 // Capture tool names
1096 for _, tool := range call.Tools {
1097 capturedToolNames = append(capturedToolNames, tool.GetName())
1098 }
1099 return &Response{
1100 Content: ResponseContent{
1101 TextContent{Text: "Response"},
1102 },
1103 Usage: Usage{TotalTokens: 10},
1104 FinishReason: FinishReasonStop,
1105 }, nil
1106 },
1107 }
1108
1109 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1110 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1111
1112 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1113 newSystem := "Step-specific system"
1114 toolChoice := SpecificToolChoice("tool1")
1115 activeTools := []string{"tool1"}
1116 return ctx, PrepareStepResult{
1117 Model: options.Model,
1118 Messages: options.Messages,
1119 System: &newSystem,
1120 ToolChoice: &toolChoice,
1121 ActiveTools: activeTools,
1122 }, nil
1123 }
1124
1125 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1126
1127 result, err := agent.Generate(context.Background(), AgentCall{
1128 Prompt: "test prompt",
1129 PrepareStep: prepareStepFunc,
1130 })
1131
1132 require.NoError(t, err)
1133 require.NotNil(t, result)
1134 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1135 require.NotNil(t, capturedToolChoice)
1136 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1137 require.Len(t, capturedToolNames, 1)
1138 require.Equal(t, "tool1", capturedToolNames[0])
1139 })
1140
1141 t.Run("Nil fields use parent values", func(t *testing.T) {
1142 t.Parallel()
1143 var capturedSystemPrompt string
1144 var capturedToolChoice *ToolChoice
1145 var capturedToolNames []string
1146
1147 model := &mockLanguageModel{
1148 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1149 // Capture system prompt
1150 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1151 if len(call.Prompt[0].Content) > 0 {
1152 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1153 capturedSystemPrompt = textPart.Text
1154 }
1155 }
1156 }
1157 // Capture tool choice
1158 capturedToolChoice = call.ToolChoice
1159 // Capture tool names
1160 for _, tool := range call.Tools {
1161 capturedToolNames = append(capturedToolNames, tool.GetName())
1162 }
1163 return &Response{
1164 Content: ResponseContent{
1165 TextContent{Text: "Response"},
1166 },
1167 Usage: Usage{TotalTokens: 10},
1168 FinishReason: FinishReasonStop,
1169 }, nil
1170 },
1171 }
1172
1173 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1174
1175 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1176 // All optional fields are nil, should use parent values
1177 return ctx, PrepareStepResult{
1178 Model: options.Model,
1179 Messages: options.Messages,
1180 System: nil, // Use parent
1181 ToolChoice: nil, // Use parent (auto)
1182 ActiveTools: nil, // Use parent (all tools)
1183 }, nil
1184 }
1185
1186 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1187
1188 result, err := agent.Generate(context.Background(), AgentCall{
1189 Prompt: "test prompt",
1190 PrepareStep: prepareStepFunc,
1191 })
1192
1193 require.NoError(t, err)
1194 require.NotNil(t, result)
1195 require.Equal(t, "Parent system", capturedSystemPrompt)
1196 require.NotNil(t, capturedToolChoice)
1197 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1198 require.Len(t, capturedToolNames, 1)
1199 require.Equal(t, "tool1", capturedToolNames[0])
1200 })
1201
1202 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1203 t.Parallel()
1204 var capturedToolNames []string
1205 model := &mockLanguageModel{
1206 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1207 // Capture tool names to verify all tools are included
1208 for _, tool := range call.Tools {
1209 capturedToolNames = append(capturedToolNames, tool.GetName())
1210 }
1211 return &Response{
1212 Content: ResponseContent{
1213 TextContent{Text: "Response"},
1214 },
1215 Usage: Usage{TotalTokens: 10},
1216 FinishReason: FinishReasonStop,
1217 }, nil
1218 },
1219 }
1220
1221 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1222 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1223
1224 prepareStepFunc := func(ctx context.Context, options PrepareStepFunctionOptions) (context.Context, PrepareStepResult, error) {
1225 return ctx, PrepareStepResult{
1226 Model: options.Model,
1227 Messages: options.Messages,
1228 ActiveTools: []string{}, // Empty slice means all tools
1229 }, nil
1230 }
1231
1232 agent := NewAgent(model, WithTools(tool1, tool2))
1233
1234 result, err := agent.Generate(context.Background(), AgentCall{
1235 Prompt: "test prompt",
1236 PrepareStep: prepareStepFunc,
1237 })
1238
1239 require.NoError(t, err)
1240 require.NotNil(t, result)
1241 require.Len(t, capturedToolNames, 2) // All tools should be included
1242 require.Contains(t, capturedToolNames, "tool1")
1243 require.Contains(t, capturedToolNames, "tool2")
1244 })
1245}
1246
1247func TestToolCallRepair(t *testing.T) {
1248 t.Parallel()
1249
1250 t.Run("Valid tool call passes validation", func(t *testing.T) {
1251 t.Parallel()
1252 model := &mockLanguageModel{
1253 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1254 return &Response{
1255 Content: ResponseContent{
1256 TextContent{Text: "Response"},
1257 ToolCallContent{
1258 ToolCallID: "call1",
1259 ToolName: "test_tool",
1260 Input: `{"value": "test"}`, // Valid JSON with required field
1261 },
1262 },
1263 Usage: Usage{TotalTokens: 10},
1264 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1265 }, nil
1266 },
1267 }
1268
1269 tool := &mockTool{
1270 name: "test_tool",
1271 description: "Test tool",
1272 parameters: map[string]any{
1273 "value": map[string]any{"type": "string"},
1274 },
1275 required: []string{"value"},
1276 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1277 return ToolResponse{Content: "success", IsError: false}, nil
1278 },
1279 }
1280
1281 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1282
1283 result, err := agent.Generate(context.Background(), AgentCall{
1284 Prompt: "test prompt",
1285 })
1286
1287 require.NoError(t, err)
1288 require.NotNil(t, result)
1289 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1290
1291 // Check that tool call was executed successfully
1292 toolCalls := result.Steps[0].Content.ToolCalls()
1293 require.Len(t, toolCalls, 1)
1294 require.False(t, toolCalls[0].Invalid) // Should be valid
1295 })
1296
1297 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1298 t.Parallel()
1299 model := &mockLanguageModel{
1300 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1301 return &Response{
1302 Content: ResponseContent{
1303 TextContent{Text: "Response"},
1304 ToolCallContent{
1305 ToolCallID: "call1",
1306 ToolName: "test_tool",
1307 Input: `{"wrong_field": "test"}`, // Missing required field
1308 },
1309 },
1310 Usage: Usage{TotalTokens: 10},
1311 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1312 }, nil
1313 },
1314 }
1315
1316 tool := &mockTool{
1317 name: "test_tool",
1318 description: "Test tool",
1319 parameters: map[string]any{
1320 "value": map[string]any{"type": "string"},
1321 },
1322 required: []string{"value"},
1323 }
1324
1325 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1326
1327 result, err := agent.Generate(context.Background(), AgentCall{
1328 Prompt: "test prompt",
1329 })
1330
1331 require.NoError(t, err)
1332 require.NotNil(t, result)
1333 require.Len(t, result.Steps, 1) // Only one step
1334
1335 // Check that tool call was marked as invalid
1336 toolCalls := result.Steps[0].Content.ToolCalls()
1337 require.Len(t, toolCalls, 1)
1338 require.True(t, toolCalls[0].Invalid) // Should be invalid
1339 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1340 })
1341
1342 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1343 t.Parallel()
1344 model := &mockLanguageModel{
1345 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1346 return &Response{
1347 Content: ResponseContent{
1348 TextContent{Text: "Response"},
1349 ToolCallContent{
1350 ToolCallID: "call1",
1351 ToolName: "test_tool",
1352 Input: `{"wrong_field": "test"}`, // Missing required field
1353 },
1354 },
1355 Usage: Usage{TotalTokens: 10},
1356 FinishReason: FinishReasonStop, // Changed to stop
1357 }, nil
1358 },
1359 }
1360
1361 tool := &mockTool{
1362 name: "test_tool",
1363 description: "Test tool",
1364 parameters: map[string]any{
1365 "value": map[string]any{"type": "string"},
1366 },
1367 required: []string{"value"},
1368 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1369 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1370 },
1371 }
1372
1373 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1374 // Simple repair: add the missing required field
1375 repairedToolCall := options.OriginalToolCall
1376 repairedToolCall.Input = `{"value": "repaired"}`
1377 return &repairedToolCall, nil
1378 }
1379
1380 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1381
1382 result, err := agent.Generate(context.Background(), AgentCall{
1383 Prompt: "test prompt",
1384 })
1385
1386 require.NoError(t, err)
1387 require.NotNil(t, result)
1388 require.Len(t, result.Steps, 1) // Only one step
1389
1390 // Check that tool call was repaired and is now valid
1391 toolCalls := result.Steps[0].Content.ToolCalls()
1392 require.Len(t, toolCalls, 1)
1393 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1394 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1395 })
1396
1397 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1398 t.Parallel()
1399 model := &mockLanguageModel{
1400 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1401 return &Response{
1402 Content: ResponseContent{
1403 TextContent{Text: "Response"},
1404 ToolCallContent{
1405 ToolCallID: "call1",
1406 ToolName: "test_tool",
1407 Input: `{"wrong_field": "test"}`, // Missing required field
1408 },
1409 },
1410 Usage: Usage{TotalTokens: 10},
1411 FinishReason: FinishReasonStop, // Changed to stop
1412 }, nil
1413 },
1414 }
1415
1416 tool := &mockTool{
1417 name: "test_tool",
1418 description: "Test tool",
1419 parameters: map[string]any{
1420 "value": map[string]any{"type": "string"},
1421 },
1422 required: []string{"value"},
1423 }
1424
1425 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1426 // Repair function fails
1427 return nil, errors.New("repair failed")
1428 }
1429
1430 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1431
1432 result, err := agent.Generate(context.Background(), AgentCall{
1433 Prompt: "test prompt",
1434 })
1435
1436 require.NoError(t, err)
1437 require.NotNil(t, result)
1438 require.Len(t, result.Steps, 1) // Only one step
1439
1440 // Check that tool call was marked as invalid since repair failed
1441 toolCalls := result.Steps[0].Content.ToolCalls()
1442 require.Len(t, toolCalls, 1)
1443 require.True(t, toolCalls[0].Invalid) // Should be invalid
1444 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1445 })
1446
1447 t.Run("Nonexistent tool call", func(t *testing.T) {
1448 t.Parallel()
1449 model := &mockLanguageModel{
1450 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1451 return &Response{
1452 Content: ResponseContent{
1453 TextContent{Text: "Response"},
1454 ToolCallContent{
1455 ToolCallID: "call1",
1456 ToolName: "nonexistent_tool",
1457 Input: `{"value": "test"}`,
1458 },
1459 },
1460 Usage: Usage{TotalTokens: 10},
1461 FinishReason: FinishReasonStop, // Changed to stop
1462 }, nil
1463 },
1464 }
1465
1466 tool := &mockTool{name: "test_tool", description: "Test tool"}
1467
1468 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1469
1470 result, err := agent.Generate(context.Background(), AgentCall{
1471 Prompt: "test prompt",
1472 })
1473
1474 require.NoError(t, err)
1475 require.NotNil(t, result)
1476 require.Len(t, result.Steps, 1) // Only one step
1477
1478 // Check that tool call was marked as invalid due to nonexistent tool
1479 toolCalls := result.Steps[0].Content.ToolCalls()
1480 require.Len(t, toolCalls, 1)
1481 require.True(t, toolCalls[0].Invalid) // Should be invalid
1482 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1483 })
1484
1485 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1486 t.Parallel()
1487 model := &mockLanguageModel{
1488 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1489 return &Response{
1490 Content: ResponseContent{
1491 TextContent{Text: "Response"},
1492 ToolCallContent{
1493 ToolCallID: "call1",
1494 ToolName: "test_tool",
1495 Input: `{invalid json}`, // Invalid JSON
1496 },
1497 },
1498 Usage: Usage{TotalTokens: 10},
1499 FinishReason: FinishReasonStop, // Changed to stop
1500 }, nil
1501 },
1502 }
1503
1504 tool := &mockTool{
1505 name: "test_tool",
1506 description: "Test tool",
1507 parameters: map[string]any{
1508 "value": map[string]any{"type": "string"},
1509 },
1510 required: []string{"value"},
1511 }
1512
1513 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1514
1515 result, err := agent.Generate(context.Background(), AgentCall{
1516 Prompt: "test prompt",
1517 })
1518
1519 require.NoError(t, err)
1520 require.NotNil(t, result)
1521 require.Len(t, result.Steps, 1) // Only one step
1522
1523 // Check that tool call was marked as invalid due to invalid JSON
1524 toolCalls := result.Steps[0].Content.ToolCalls()
1525 require.Len(t, toolCalls, 1)
1526 require.True(t, toolCalls[0].Invalid) // Should be invalid
1527 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1528 })
1529}