1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 description string
17 parameters map[string]any
18 required []string
19 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
20}
21
22func (m *mockTool) Info() ToolInfo {
23 return ToolInfo{
24 Name: m.name,
25 Description: m.description,
26 Parameters: m.parameters,
27 Required: m.required,
28 }
29}
30
31func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
32 if m.executeFunc != nil {
33 return m.executeFunc(ctx, call)
34 }
35 return ToolResponse{Content: "mock result", IsError: false}, nil
36}
37
38// Mock language model for testing
39type mockLanguageModel struct {
40 generateFunc func(ctx context.Context, call Call) (*Response, error)
41 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
42}
43
44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
45 if m.generateFunc != nil {
46 return m.generateFunc(ctx, call)
47 }
48 return &Response{
49 Content: []Content{
50 TextContent{Text: "Hello, world!"},
51 },
52 Usage: Usage{
53 InputTokens: 3,
54 OutputTokens: 10,
55 TotalTokens: 13,
56 },
57 FinishReason: FinishReasonStop,
58 }, nil
59}
60
61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
62 if m.streamFunc != nil {
63 return m.streamFunc(ctx, call)
64 }
65 return nil, fmt.Errorf("mock stream not implemented")
66}
67
68func (m *mockLanguageModel) Provider() string {
69 return "mock-provider"
70}
71
72func (m *mockLanguageModel) Model() string {
73 return "mock-model"
74}
75
76// Test result.content - comprehensive content types (matches TS test)
77func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
78 t.Parallel()
79
80 // Create a type-safe tool using the new API
81 type TestInput struct {
82 Value string `json:"value" description:"Test value"`
83 }
84
85 tool1 := NewAgentTool(
86 "tool1",
87 "Test tool",
88 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
89 require.Equal(t, "value", input.Value)
90 return ToolResponse{Content: "result1", IsError: false}, nil
91 },
92 )
93
94 model := &mockLanguageModel{
95 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
96 return &Response{
97 Content: []Content{
98 TextContent{Text: "Hello, world!"},
99 SourceContent{
100 ID: "123",
101 URL: "https://example.com",
102 Title: "Example",
103 SourceType: SourceTypeURL,
104 ProviderMetadata: ProviderMetadata{
105 "provider": map[string]any{"custom": "value"},
106 },
107 },
108 FileContent{
109 Data: []byte{1, 2, 3},
110 MediaType: "image/png",
111 },
112 ReasoningContent{
113 Text: "I will open the conversation with witty banter.",
114 },
115 ToolCallContent{
116 ToolCallID: "call-1",
117 ToolName: "tool1",
118 Input: `{"value":"value"}`,
119 },
120 TextContent{Text: "More text"},
121 },
122 Usage: Usage{
123 InputTokens: 3,
124 OutputTokens: 10,
125 TotalTokens: 13,
126 },
127 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
128 }, nil
129 },
130 }
131
132 agent := NewAgent(model, WithTools(tool1))
133 result, err := agent.Generate(context.Background(), AgentCall{
134 Prompt: "prompt",
135 })
136
137 require.NoError(t, err)
138 require.NotNil(t, result)
139 require.Len(t, result.Steps, 1) // Single step like TypeScript
140
141 // Check final response content includes tool result
142 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
143
144 // Verify each content type in order
145 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
146 require.True(t, ok)
147 require.Equal(t, "Hello, world!", textContent.Text)
148
149 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
150 require.True(t, ok)
151 require.Equal(t, "123", sourceContent.ID)
152
153 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
154 require.True(t, ok)
155 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
156
157 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
158 require.True(t, ok)
159 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
160
161 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
162 require.True(t, ok)
163 require.Equal(t, "call-1", toolCallContent.ToolCallID)
164
165 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
166 require.True(t, ok)
167 require.Equal(t, "More text", moreTextContent.Text)
168
169 // Tool result should be appended
170 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
171 require.True(t, ok)
172 require.Equal(t, "call-1", toolResultContent.ToolCallID)
173 require.Equal(t, "tool1", toolResultContent.ToolName)
174}
175
176// Test result.text extraction
177func TestAgent_Generate_ResultText(t *testing.T) {
178 t.Parallel()
179
180 model := &mockLanguageModel{
181 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
182 return &Response{
183 Content: []Content{
184 TextContent{Text: "Hello, world!"},
185 },
186 Usage: Usage{
187 InputTokens: 3,
188 OutputTokens: 10,
189 TotalTokens: 13,
190 },
191 FinishReason: FinishReasonStop,
192 }, nil
193 },
194 }
195
196 agent := NewAgent(model)
197 result, err := agent.Generate(context.Background(), AgentCall{
198 Prompt: "prompt",
199 })
200
201 require.NoError(t, err)
202 require.NotNil(t, result)
203
204 // Test text extraction from content
205 text := result.Response.Content.Text()
206 require.Equal(t, "Hello, world!", text)
207}
208
209// Test result.toolCalls extraction (matches TS test exactly)
210func TestAgent_Generate_ResultToolCalls(t *testing.T) {
211 t.Parallel()
212
213 // Create type-safe tools using the new API
214 type Tool1Input struct {
215 Value string `json:"value" description:"Test value"`
216 }
217
218 type Tool2Input struct {
219 SomethingElse string `json:"somethingElse" description:"Another test value"`
220 }
221
222 tool1 := NewAgentTool(
223 "tool1",
224 "Test tool 1",
225 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
226 return ToolResponse{Content: "result1", IsError: false}, nil
227 },
228 )
229
230 tool2 := NewAgentTool(
231 "tool2",
232 "Test tool 2",
233 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
234 return ToolResponse{Content: "result2", IsError: false}, nil
235 },
236 )
237
238 model := &mockLanguageModel{
239 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
240 // Verify tools are passed correctly
241 require.Len(t, call.Tools, 2)
242 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
243
244 // Verify prompt structure
245 require.Len(t, call.Prompt, 1)
246 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
247
248 return &Response{
249 Content: []Content{
250 ToolCallContent{
251 ToolCallID: "call-1",
252 ToolName: "tool1",
253 Input: `{"value":"value"}`,
254 },
255 },
256 Usage: Usage{
257 InputTokens: 3,
258 OutputTokens: 10,
259 TotalTokens: 13,
260 },
261 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
262 }, nil
263 },
264 }
265
266 agent := NewAgent(model, WithTools(tool1, tool2))
267 result, err := agent.Generate(context.Background(), AgentCall{
268 Prompt: "test-input",
269 })
270
271 require.NoError(t, err)
272 require.NotNil(t, result)
273 require.Len(t, result.Steps, 1) // Single step
274
275 // Extract tool calls from final response (should be empty since tools don't execute)
276 var toolCalls []ToolCallContent
277 for _, content := range result.Response.Content {
278 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
279 toolCalls = append(toolCalls, toolCall)
280 }
281 }
282
283 require.Len(t, toolCalls, 1)
284 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
285 require.Equal(t, "tool1", toolCalls[0].ToolName)
286
287 // Parse and verify input
288 var input map[string]any
289 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
290 require.NoError(t, err)
291 require.Equal(t, "value", input["value"])
292}
293
294// Test result.toolResults extraction (matches TS test exactly)
295func TestAgent_Generate_ResultToolResults(t *testing.T) {
296 t.Parallel()
297
298 // Create type-safe tool using the new API
299 type TestInput struct {
300 Value string `json:"value" description:"Test value"`
301 }
302
303 tool1 := NewAgentTool(
304 "tool1",
305 "Test tool",
306 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
307 require.Equal(t, "value", input.Value)
308 return ToolResponse{Content: "result1", IsError: false}, nil
309 },
310 )
311
312 model := &mockLanguageModel{
313 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
314 // Verify tools and tool choice
315 require.Len(t, call.Tools, 1)
316 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
317
318 // Verify prompt
319 require.Len(t, call.Prompt, 1)
320 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
321
322 return &Response{
323 Content: []Content{
324 ToolCallContent{
325 ToolCallID: "call-1",
326 ToolName: "tool1",
327 Input: `{"value":"value"}`,
328 },
329 },
330 Usage: Usage{
331 InputTokens: 3,
332 OutputTokens: 10,
333 TotalTokens: 13,
334 },
335 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
336 }, nil
337 },
338 }
339
340 agent := NewAgent(model, WithTools(tool1))
341 result, err := agent.Generate(context.Background(), AgentCall{
342 Prompt: "test-input",
343 })
344
345 require.NoError(t, err)
346 require.NotNil(t, result)
347 require.Len(t, result.Steps, 1) // Single step
348
349 // Extract tool results from final response
350 var toolResults []ToolResultContent
351 for _, content := range result.Response.Content {
352 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
353 toolResults = append(toolResults, toolResult)
354 }
355 }
356
357 require.Len(t, toolResults, 1)
358 require.Equal(t, "call-1", toolResults[0].ToolCallID)
359 require.Equal(t, "tool1", toolResults[0].ToolName)
360
361 // Verify result content
362 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
363 require.True(t, ok)
364 require.Equal(t, "result1", textResult.Text)
365}
366
367// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
368func TestAgent_Generate_MultipleSteps(t *testing.T) {
369 t.Parallel()
370
371 // Create type-safe tool using the new API
372 type TestInput struct {
373 Value string `json:"value" description:"Test value"`
374 }
375
376 tool1 := NewAgentTool(
377 "tool1",
378 "Test tool",
379 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
380 require.Equal(t, "value", input.Value)
381 return ToolResponse{Content: "result1", IsError: false}, nil
382 },
383 )
384
385 callCount := 0
386 model := &mockLanguageModel{
387 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
388 callCount++
389 switch callCount {
390 case 1:
391 // First call - return tool call with FinishReasonToolCalls
392 return &Response{
393 Content: []Content{
394 ToolCallContent{
395 ToolCallID: "call-1",
396 ToolName: "tool1",
397 Input: `{"value":"value"}`,
398 },
399 },
400 Usage: Usage{
401 InputTokens: 10,
402 OutputTokens: 5,
403 TotalTokens: 15,
404 },
405 FinishReason: FinishReasonToolCalls, // This triggers multi-step
406 }, nil
407 case 2:
408 // Second call - return final text
409 return &Response{
410 Content: []Content{
411 TextContent{Text: "Hello, world!"},
412 },
413 Usage: Usage{
414 InputTokens: 3,
415 OutputTokens: 10,
416 TotalTokens: 13,
417 },
418 FinishReason: FinishReasonStop,
419 }, nil
420 default:
421 t.Fatalf("Unexpected call count: %d", callCount)
422 return nil, nil
423 }
424 },
425 }
426
427 agent := NewAgent(model, WithTools(tool1))
428 result, err := agent.Generate(context.Background(), AgentCall{
429 Prompt: "test-input",
430 })
431
432 require.NoError(t, err)
433 require.NotNil(t, result)
434 require.Len(t, result.Steps, 2)
435
436 // Check total usage sums both steps
437 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
438 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
439 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
440
441 // Final response should be from last step
442 require.Len(t, result.Response.Content, 1)
443 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
444 require.True(t, ok)
445 require.Equal(t, "Hello, world!", textContent.Text)
446
447 // result.toolCalls should be empty (from last step)
448 var toolCalls []ToolCallContent
449 for _, content := range result.Response.Content {
450 if _, ok := AsContentType[ToolCallContent](content); ok {
451 toolCalls = append(toolCalls, content.(ToolCallContent))
452 }
453 }
454 require.Len(t, toolCalls, 0)
455
456 // result.toolResults should be empty (from last step)
457 var toolResults []ToolResultContent
458 for _, content := range result.Response.Content {
459 if _, ok := AsContentType[ToolResultContent](content); ok {
460 toolResults = append(toolResults, content.(ToolResultContent))
461 }
462 }
463 require.Len(t, toolResults, 0)
464}
465
466// Test basic text generation
467func TestAgent_Generate_BasicText(t *testing.T) {
468 t.Parallel()
469
470 model := &mockLanguageModel{
471 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
472 return &Response{
473 Content: []Content{
474 TextContent{Text: "Hello, world!"},
475 },
476 Usage: Usage{
477 InputTokens: 3,
478 OutputTokens: 10,
479 TotalTokens: 13,
480 },
481 FinishReason: FinishReasonStop,
482 }, nil
483 },
484 }
485
486 agent := NewAgent(model)
487 result, err := agent.Generate(context.Background(), AgentCall{
488 Prompt: "test prompt",
489 })
490
491 require.NoError(t, err)
492 require.NotNil(t, result)
493 require.Len(t, result.Steps, 1)
494
495 // Check final response
496 require.Len(t, result.Response.Content, 1)
497 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
498 require.True(t, ok)
499 require.Equal(t, "Hello, world!", textContent.Text)
500
501 // Check usage
502 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
503 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
504 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
505
506 // Check total usage
507 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
508 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
509 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// Test empty prompt error
513func TestAgent_Generate_EmptyPrompt(t *testing.T) {
514 t.Parallel()
515
516 model := &mockLanguageModel{}
517 agent := NewAgent(model)
518
519 result, err := agent.Generate(context.Background(), AgentCall{
520 Prompt: "", // Empty prompt should cause error
521 })
522
523 require.Error(t, err)
524 require.Nil(t, result)
525 require.Contains(t, err.Error(), "Prompt can't be empty")
526}
527
528// Test with system prompt
529func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
530 t.Parallel()
531
532 model := &mockLanguageModel{
533 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
534 // Verify system message is included
535 require.Len(t, call.Prompt, 2) // system + user
536 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
537 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
538
539 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
540 require.True(t, ok)
541 require.Equal(t, "You are a helpful assistant", systemPart.Text)
542
543 return &Response{
544 Content: []Content{
545 TextContent{Text: "Hello, world!"},
546 },
547 Usage: Usage{
548 InputTokens: 3,
549 OutputTokens: 10,
550 TotalTokens: 13,
551 },
552 FinishReason: FinishReasonStop,
553 }, nil
554 },
555 }
556
557 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
558 result, err := agent.Generate(context.Background(), AgentCall{
559 Prompt: "test prompt",
560 })
561
562 require.NoError(t, err)
563 require.NotNil(t, result)
564}
565
566// Test options.activeTools filtering
567func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
568 t.Parallel()
569
570 tool1 := &mockTool{
571 name: "tool1",
572 description: "Test tool 1",
573 parameters: map[string]any{
574 "value": map[string]any{"type": "string"},
575 },
576 required: []string{"value"},
577 }
578
579 tool2 := &mockTool{
580 name: "tool2",
581 description: "Test tool 2",
582 parameters: map[string]any{
583 "value": map[string]any{"type": "string"},
584 },
585 required: []string{"value"},
586 }
587
588 model := &mockLanguageModel{
589 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
590 // Verify only tool1 is available
591 require.Len(t, call.Tools, 1)
592 functionTool, ok := call.Tools[0].(FunctionTool)
593 require.True(t, ok)
594 require.Equal(t, "tool1", functionTool.Name)
595
596 return &Response{
597 Content: []Content{
598 TextContent{Text: "Hello, world!"},
599 },
600 Usage: Usage{
601 InputTokens: 3,
602 OutputTokens: 10,
603 TotalTokens: 13,
604 },
605 FinishReason: FinishReasonStop,
606 }, nil
607 },
608 }
609
610 agent := NewAgent(model, WithTools(tool1, tool2))
611 result, err := agent.Generate(context.Background(), AgentCall{
612 Prompt: "test-input",
613 ActiveTools: []string{"tool1"}, // Only tool1 should be active
614 })
615
616 require.NoError(t, err)
617 require.NotNil(t, result)
618}
619
620func TestResponseContent_Getters(t *testing.T) {
621 t.Parallel()
622
623 // Create test content with all types
624 content := ResponseContent{
625 TextContent{Text: "Hello world"},
626 ReasoningContent{Text: "Let me think..."},
627 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
628 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
629 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
630 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
631 }
632
633 // Test Text()
634 require.Equal(t, "Hello world", content.Text())
635
636 // Test Reasoning()
637 reasoning := content.Reasoning()
638 require.Len(t, reasoning, 1)
639 require.Equal(t, "Let me think...", reasoning[0].Text)
640
641 // Test ReasoningText()
642 require.Equal(t, "Let me think...", content.ReasoningText())
643
644 // Test Files()
645 files := content.Files()
646 require.Len(t, files, 1)
647 require.Equal(t, "text/plain", files[0].MediaType)
648 require.Equal(t, []byte("file data"), files[0].Data)
649
650 // Test Sources()
651 sources := content.Sources()
652 require.Len(t, sources, 1)
653 require.Equal(t, SourceTypeURL, sources[0].SourceType)
654 require.Equal(t, "https://example.com", sources[0].URL)
655 require.Equal(t, "Example", sources[0].Title)
656
657 // Test ToolCalls()
658 toolCalls := content.ToolCalls()
659 require.Len(t, toolCalls, 1)
660 require.Equal(t, "call1", toolCalls[0].ToolCallID)
661 require.Equal(t, "test_tool", toolCalls[0].ToolName)
662 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
663
664 // Test ToolResults()
665 toolResults := content.ToolResults()
666 require.Len(t, toolResults, 1)
667 require.Equal(t, "call1", toolResults[0].ToolCallID)
668 require.Equal(t, "test_tool", toolResults[0].ToolName)
669 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
670 require.True(t, ok)
671 require.Equal(t, "result", result.Text)
672}
673
674func TestResponseContent_Getters_Empty(t *testing.T) {
675 t.Parallel()
676
677 // Test with empty content
678 content := ResponseContent{}
679
680 require.Equal(t, "", content.Text())
681 require.Equal(t, "", content.ReasoningText())
682 require.Empty(t, content.Reasoning())
683 require.Empty(t, content.Files())
684 require.Empty(t, content.Sources())
685 require.Empty(t, content.ToolCalls())
686 require.Empty(t, content.ToolResults())
687}
688
689func TestResponseContent_Getters_MultipleItems(t *testing.T) {
690 t.Parallel()
691
692 // Test with multiple items of same type
693 content := ResponseContent{
694 ReasoningContent{Text: "First thought"},
695 ReasoningContent{Text: "Second thought"},
696 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
697 FileContent{Data: []byte("file2"), MediaType: "image/png"},
698 }
699
700 // Test multiple reasoning
701 reasoning := content.Reasoning()
702 require.Len(t, reasoning, 2)
703 require.Equal(t, "First thought", reasoning[0].Text)
704 require.Equal(t, "Second thought", reasoning[1].Text)
705
706 // Test concatenated reasoning text
707 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
708
709 // Test multiple files
710 files := content.Files()
711 require.Len(t, files, 2)
712 require.Equal(t, "text/plain", files[0].MediaType)
713 require.Equal(t, "image/png", files[1].MediaType)
714}
715
716func TestStopConditions(t *testing.T) {
717 t.Parallel()
718
719 // Create test steps
720 step1 := StepResult{
721 Response: Response{
722 Content: ResponseContent{
723 TextContent{Text: "Hello"},
724 },
725 FinishReason: FinishReasonToolCalls,
726 Usage: Usage{TotalTokens: 10},
727 },
728 }
729
730 step2 := StepResult{
731 Response: Response{
732 Content: ResponseContent{
733 TextContent{Text: "World"},
734 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
735 },
736 FinishReason: FinishReasonStop,
737 Usage: Usage{TotalTokens: 15},
738 },
739 }
740
741 step3 := StepResult{
742 Response: Response{
743 Content: ResponseContent{
744 ReasoningContent{Text: "Let me think..."},
745 FileContent{Data: []byte("data"), MediaType: "text/plain"},
746 },
747 FinishReason: FinishReasonLength,
748 Usage: Usage{TotalTokens: 20},
749 },
750 }
751
752 t.Run("StepCountIs", func(t *testing.T) {
753 t.Parallel()
754 condition := StepCountIs(2)
755
756 // Should not stop with 1 step
757 require.False(t, condition([]StepResult{step1}))
758
759 // Should stop with 2 steps
760 require.True(t, condition([]StepResult{step1, step2}))
761
762 // Should stop with more than 2 steps
763 require.True(t, condition([]StepResult{step1, step2, step3}))
764
765 // Should not stop with empty steps
766 require.False(t, condition([]StepResult{}))
767 })
768
769 t.Run("HasToolCall", func(t *testing.T) {
770 t.Parallel()
771 condition := HasToolCall("search")
772
773 // Should not stop when tool not called
774 require.False(t, condition([]StepResult{step1}))
775
776 // Should stop when tool is called in last step
777 require.True(t, condition([]StepResult{step1, step2}))
778
779 // Should not stop when tool called in earlier step but not last
780 require.False(t, condition([]StepResult{step1, step2, step3}))
781
782 // Should not stop with empty steps
783 require.False(t, condition([]StepResult{}))
784
785 // Should not stop when different tool is called
786 differentToolCondition := HasToolCall("different_tool")
787 require.False(t, differentToolCondition([]StepResult{step1, step2}))
788 })
789
790 t.Run("HasContent", func(t *testing.T) {
791 t.Parallel()
792 reasoningCondition := HasContent(ContentTypeReasoning)
793 fileCondition := HasContent(ContentTypeFile)
794
795 // Should not stop when content type not present
796 require.False(t, reasoningCondition([]StepResult{step1, step2}))
797
798 // Should stop when content type is present in last step
799 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
800 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
801
802 // Should not stop with empty steps
803 require.False(t, reasoningCondition([]StepResult{}))
804 })
805
806 t.Run("FinishReasonIs", func(t *testing.T) {
807 t.Parallel()
808 stopCondition := FinishReasonIs(FinishReasonStop)
809 lengthCondition := FinishReasonIs(FinishReasonLength)
810
811 // Should not stop when finish reason doesn't match
812 require.False(t, stopCondition([]StepResult{step1}))
813
814 // Should stop when finish reason matches in last step
815 require.True(t, stopCondition([]StepResult{step1, step2}))
816 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
817
818 // Should not stop with empty steps
819 require.False(t, stopCondition([]StepResult{}))
820 })
821
822 t.Run("MaxTokensUsed", func(t *testing.T) {
823 condition := MaxTokensUsed(30)
824
825 // Should not stop when under limit
826 require.False(t, condition([]StepResult{step1})) // 10 tokens
827 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
828
829 // Should stop when at or over limit
830 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
831
832 // Should not stop with empty steps
833 require.False(t, condition([]StepResult{}))
834 })
835}
836
837func TestStopConditions_Integration(t *testing.T) {
838 t.Parallel()
839
840 t.Run("StepCountIs integration", func(t *testing.T) {
841 t.Parallel()
842 model := &mockLanguageModel{
843 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
844 return &Response{
845 Content: ResponseContent{
846 TextContent{Text: "Mock response"},
847 },
848 Usage: Usage{
849 InputTokens: 3,
850 OutputTokens: 10,
851 TotalTokens: 13,
852 },
853 FinishReason: FinishReasonStop,
854 }, nil
855 },
856 }
857
858 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
859
860 result, err := agent.Generate(context.Background(), AgentCall{
861 Prompt: "test prompt",
862 })
863
864 require.NoError(t, err)
865 require.NotNil(t, result)
866 require.Len(t, result.Steps, 1) // Should stop after 1 step
867 })
868
869 t.Run("Multiple stop conditions", func(t *testing.T) {
870 t.Parallel()
871 model := &mockLanguageModel{
872 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
873 return &Response{
874 Content: ResponseContent{
875 TextContent{Text: "Mock response"},
876 },
877 Usage: Usage{
878 InputTokens: 3,
879 OutputTokens: 10,
880 TotalTokens: 13,
881 },
882 FinishReason: FinishReasonStop,
883 }, nil
884 },
885 }
886
887 agent := NewAgent(model, WithStopConditions(
888 StepCountIs(5), // Stop after 5 steps
889 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
890 ))
891
892 result, err := agent.Generate(context.Background(), AgentCall{
893 Prompt: "test prompt",
894 })
895
896 require.NoError(t, err)
897 require.NotNil(t, result)
898 // Should stop on first condition met (finish reason stop)
899 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
900 })
901}
902
903func TestPrepareStep(t *testing.T) {
904 t.Parallel()
905
906 t.Run("System prompt modification", func(t *testing.T) {
907 t.Parallel()
908 var capturedSystemPrompt string
909 model := &mockLanguageModel{
910 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
911 // Capture the system message to verify it was modified
912 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
913 if len(call.Prompt[0].Content) > 0 {
914 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
915 capturedSystemPrompt = textPart.Text
916 }
917 }
918 }
919 return &Response{
920 Content: ResponseContent{
921 TextContent{Text: "Response"},
922 },
923 Usage: Usage{TotalTokens: 10},
924 FinishReason: FinishReasonStop,
925 }, nil
926 },
927 }
928
929 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
930 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
931 return PrepareStepResult{
932 Model: options.Model,
933 Messages: options.Messages,
934 System: &newSystem,
935 }, nil
936 }
937
938 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
939
940 result, err := agent.Generate(context.Background(), AgentCall{
941 Prompt: "test prompt",
942 PrepareStep: prepareStepFunc,
943 })
944
945 require.NoError(t, err)
946 require.NotNil(t, result)
947 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
948 })
949
950 t.Run("Tool choice modification", func(t *testing.T) {
951 t.Parallel()
952 var capturedToolChoice *ToolChoice
953 model := &mockLanguageModel{
954 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
955 capturedToolChoice = call.ToolChoice
956 return &Response{
957 Content: ResponseContent{
958 TextContent{Text: "Response"},
959 },
960 Usage: Usage{TotalTokens: 10},
961 FinishReason: FinishReasonStop,
962 }, nil
963 },
964 }
965
966 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
967 toolChoice := ToolChoiceNone
968 return PrepareStepResult{
969 Model: options.Model,
970 Messages: options.Messages,
971 ToolChoice: &toolChoice,
972 }, nil
973 }
974
975 agent := NewAgent(model)
976
977 result, err := agent.Generate(context.Background(), AgentCall{
978 Prompt: "test prompt",
979 PrepareStep: prepareStepFunc,
980 })
981
982 require.NoError(t, err)
983 require.NotNil(t, result)
984 require.NotNil(t, capturedToolChoice)
985 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
986 })
987
988 t.Run("Active tools modification", func(t *testing.T) {
989 t.Parallel()
990 var capturedToolNames []string
991 model := &mockLanguageModel{
992 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
993 // Capture tool names to verify active tools were modified
994 for _, tool := range call.Tools {
995 capturedToolNames = append(capturedToolNames, tool.GetName())
996 }
997 return &Response{
998 Content: ResponseContent{
999 TextContent{Text: "Response"},
1000 },
1001 Usage: Usage{TotalTokens: 10},
1002 FinishReason: FinishReasonStop,
1003 }, nil
1004 },
1005 }
1006
1007 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1008 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1009 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1010
1011 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1012 activeTools := []string{"tool2"} // Only tool2 should be active
1013 return PrepareStepResult{
1014 Model: options.Model,
1015 Messages: options.Messages,
1016 ActiveTools: activeTools,
1017 }, nil
1018 }
1019
1020 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1021
1022 result, err := agent.Generate(context.Background(), AgentCall{
1023 Prompt: "test prompt",
1024 PrepareStep: prepareStepFunc,
1025 })
1026
1027 require.NoError(t, err)
1028 require.NotNil(t, result)
1029 require.Len(t, capturedToolNames, 1)
1030 require.Equal(t, "tool2", capturedToolNames[0])
1031 })
1032
1033 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1034 t.Parallel()
1035 var capturedToolCount int
1036 model := &mockLanguageModel{
1037 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1038 capturedToolCount = len(call.Tools)
1039 return &Response{
1040 Content: ResponseContent{
1041 TextContent{Text: "Response"},
1042 },
1043 Usage: Usage{TotalTokens: 10},
1044 FinishReason: FinishReasonStop,
1045 }, nil
1046 },
1047 }
1048
1049 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1050
1051 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1052 return PrepareStepResult{
1053 Model: options.Model,
1054 Messages: options.Messages,
1055 DisableAllTools: true, // Disable all tools for this step
1056 }, nil
1057 }
1058
1059 agent := NewAgent(model, WithTools(tool1))
1060
1061 result, err := agent.Generate(context.Background(), AgentCall{
1062 Prompt: "test prompt",
1063 PrepareStep: prepareStepFunc,
1064 })
1065
1066 require.NoError(t, err)
1067 require.NotNil(t, result)
1068 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1069 })
1070
1071 t.Run("All fields modified together", func(t *testing.T) {
1072 t.Parallel()
1073 var capturedSystemPrompt string
1074 var capturedToolChoice *ToolChoice
1075 var capturedToolNames []string
1076
1077 model := &mockLanguageModel{
1078 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1079 // Capture system prompt
1080 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1081 if len(call.Prompt[0].Content) > 0 {
1082 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1083 capturedSystemPrompt = textPart.Text
1084 }
1085 }
1086 }
1087 // Capture tool choice
1088 capturedToolChoice = call.ToolChoice
1089 // Capture tool names
1090 for _, tool := range call.Tools {
1091 capturedToolNames = append(capturedToolNames, tool.GetName())
1092 }
1093 return &Response{
1094 Content: ResponseContent{
1095 TextContent{Text: "Response"},
1096 },
1097 Usage: Usage{TotalTokens: 10},
1098 FinishReason: FinishReasonStop,
1099 }, nil
1100 },
1101 }
1102
1103 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1104 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1105
1106 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1107 newSystem := "Step-specific system"
1108 toolChoice := SpecificToolChoice("tool1")
1109 activeTools := []string{"tool1"}
1110 return PrepareStepResult{
1111 Model: options.Model,
1112 Messages: options.Messages,
1113 System: &newSystem,
1114 ToolChoice: &toolChoice,
1115 ActiveTools: activeTools,
1116 }, nil
1117 }
1118
1119 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1120
1121 result, err := agent.Generate(context.Background(), AgentCall{
1122 Prompt: "test prompt",
1123 PrepareStep: prepareStepFunc,
1124 })
1125
1126 require.NoError(t, err)
1127 require.NotNil(t, result)
1128 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1129 require.NotNil(t, capturedToolChoice)
1130 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1131 require.Len(t, capturedToolNames, 1)
1132 require.Equal(t, "tool1", capturedToolNames[0])
1133 })
1134
1135 t.Run("Nil fields use parent values", func(t *testing.T) {
1136 t.Parallel()
1137 var capturedSystemPrompt string
1138 var capturedToolChoice *ToolChoice
1139 var capturedToolNames []string
1140
1141 model := &mockLanguageModel{
1142 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1143 // Capture system prompt
1144 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1145 if len(call.Prompt[0].Content) > 0 {
1146 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1147 capturedSystemPrompt = textPart.Text
1148 }
1149 }
1150 }
1151 // Capture tool choice
1152 capturedToolChoice = call.ToolChoice
1153 // Capture tool names
1154 for _, tool := range call.Tools {
1155 capturedToolNames = append(capturedToolNames, tool.GetName())
1156 }
1157 return &Response{
1158 Content: ResponseContent{
1159 TextContent{Text: "Response"},
1160 },
1161 Usage: Usage{TotalTokens: 10},
1162 FinishReason: FinishReasonStop,
1163 }, nil
1164 },
1165 }
1166
1167 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1168
1169 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1170 // All optional fields are nil, should use parent values
1171 return PrepareStepResult{
1172 Model: options.Model,
1173 Messages: options.Messages,
1174 System: nil, // Use parent
1175 ToolChoice: nil, // Use parent (auto)
1176 ActiveTools: nil, // Use parent (all tools)
1177 }, nil
1178 }
1179
1180 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1181
1182 result, err := agent.Generate(context.Background(), AgentCall{
1183 Prompt: "test prompt",
1184 PrepareStep: prepareStepFunc,
1185 })
1186
1187 require.NoError(t, err)
1188 require.NotNil(t, result)
1189 require.Equal(t, "Parent system", capturedSystemPrompt)
1190 require.NotNil(t, capturedToolChoice)
1191 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1192 require.Len(t, capturedToolNames, 1)
1193 require.Equal(t, "tool1", capturedToolNames[0])
1194 })
1195
1196 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1197 t.Parallel()
1198 var capturedToolNames []string
1199 model := &mockLanguageModel{
1200 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1201 // Capture tool names to verify all tools are included
1202 for _, tool := range call.Tools {
1203 capturedToolNames = append(capturedToolNames, tool.GetName())
1204 }
1205 return &Response{
1206 Content: ResponseContent{
1207 TextContent{Text: "Response"},
1208 },
1209 Usage: Usage{TotalTokens: 10},
1210 FinishReason: FinishReasonStop,
1211 }, nil
1212 },
1213 }
1214
1215 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1216 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1217
1218 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1219 return PrepareStepResult{
1220 Model: options.Model,
1221 Messages: options.Messages,
1222 ActiveTools: []string{}, // Empty slice means all tools
1223 }, nil
1224 }
1225
1226 agent := NewAgent(model, WithTools(tool1, tool2))
1227
1228 result, err := agent.Generate(context.Background(), AgentCall{
1229 Prompt: "test prompt",
1230 PrepareStep: prepareStepFunc,
1231 })
1232
1233 require.NoError(t, err)
1234 require.NotNil(t, result)
1235 require.Len(t, capturedToolNames, 2) // All tools should be included
1236 require.Contains(t, capturedToolNames, "tool1")
1237 require.Contains(t, capturedToolNames, "tool2")
1238 })
1239}
1240
1241func TestToolCallRepair(t *testing.T) {
1242 t.Parallel()
1243
1244 t.Run("Valid tool call passes validation", func(t *testing.T) {
1245 t.Parallel()
1246 model := &mockLanguageModel{
1247 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1248 return &Response{
1249 Content: ResponseContent{
1250 TextContent{Text: "Response"},
1251 ToolCallContent{
1252 ToolCallID: "call1",
1253 ToolName: "test_tool",
1254 Input: `{"value": "test"}`, // Valid JSON with required field
1255 },
1256 },
1257 Usage: Usage{TotalTokens: 10},
1258 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1259 }, nil
1260 },
1261 }
1262
1263 tool := &mockTool{
1264 name: "test_tool",
1265 description: "Test tool",
1266 parameters: map[string]any{
1267 "value": map[string]any{"type": "string"},
1268 },
1269 required: []string{"value"},
1270 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1271 return ToolResponse{Content: "success", IsError: false}, nil
1272 },
1273 }
1274
1275 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1276
1277 result, err := agent.Generate(context.Background(), AgentCall{
1278 Prompt: "test prompt",
1279 })
1280
1281 require.NoError(t, err)
1282 require.NotNil(t, result)
1283 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1284
1285 // Check that tool call was executed successfully
1286 toolCalls := result.Steps[0].Content.ToolCalls()
1287 require.Len(t, toolCalls, 1)
1288 require.False(t, toolCalls[0].Invalid) // Should be valid
1289 })
1290
1291 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1292 t.Parallel()
1293 model := &mockLanguageModel{
1294 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1295 return &Response{
1296 Content: ResponseContent{
1297 TextContent{Text: "Response"},
1298 ToolCallContent{
1299 ToolCallID: "call1",
1300 ToolName: "test_tool",
1301 Input: `{"wrong_field": "test"}`, // Missing required field
1302 },
1303 },
1304 Usage: Usage{TotalTokens: 10},
1305 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1306 }, nil
1307 },
1308 }
1309
1310 tool := &mockTool{
1311 name: "test_tool",
1312 description: "Test tool",
1313 parameters: map[string]any{
1314 "value": map[string]any{"type": "string"},
1315 },
1316 required: []string{"value"},
1317 }
1318
1319 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1320
1321 result, err := agent.Generate(context.Background(), AgentCall{
1322 Prompt: "test prompt",
1323 })
1324
1325 require.NoError(t, err)
1326 require.NotNil(t, result)
1327 require.Len(t, result.Steps, 1) // Only one step
1328
1329 // Check that tool call was marked as invalid
1330 toolCalls := result.Steps[0].Content.ToolCalls()
1331 require.Len(t, toolCalls, 1)
1332 require.True(t, toolCalls[0].Invalid) // Should be invalid
1333 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1334 })
1335
1336 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1337 t.Parallel()
1338 model := &mockLanguageModel{
1339 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1340 return &Response{
1341 Content: ResponseContent{
1342 TextContent{Text: "Response"},
1343 ToolCallContent{
1344 ToolCallID: "call1",
1345 ToolName: "test_tool",
1346 Input: `{"wrong_field": "test"}`, // Missing required field
1347 },
1348 },
1349 Usage: Usage{TotalTokens: 10},
1350 FinishReason: FinishReasonStop, // Changed to stop
1351 }, nil
1352 },
1353 }
1354
1355 tool := &mockTool{
1356 name: "test_tool",
1357 description: "Test tool",
1358 parameters: map[string]any{
1359 "value": map[string]any{"type": "string"},
1360 },
1361 required: []string{"value"},
1362 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1363 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1364 },
1365 }
1366
1367 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1368 // Simple repair: add the missing required field
1369 repairedToolCall := options.OriginalToolCall
1370 repairedToolCall.Input = `{"value": "repaired"}`
1371 return &repairedToolCall, nil
1372 }
1373
1374 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1375
1376 result, err := agent.Generate(context.Background(), AgentCall{
1377 Prompt: "test prompt",
1378 })
1379
1380 require.NoError(t, err)
1381 require.NotNil(t, result)
1382 require.Len(t, result.Steps, 1) // Only one step
1383
1384 // Check that tool call was repaired and is now valid
1385 toolCalls := result.Steps[0].Content.ToolCalls()
1386 require.Len(t, toolCalls, 1)
1387 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1388 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1389 })
1390
1391 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1392 t.Parallel()
1393 model := &mockLanguageModel{
1394 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1395 return &Response{
1396 Content: ResponseContent{
1397 TextContent{Text: "Response"},
1398 ToolCallContent{
1399 ToolCallID: "call1",
1400 ToolName: "test_tool",
1401 Input: `{"wrong_field": "test"}`, // Missing required field
1402 },
1403 },
1404 Usage: Usage{TotalTokens: 10},
1405 FinishReason: FinishReasonStop, // Changed to stop
1406 }, nil
1407 },
1408 }
1409
1410 tool := &mockTool{
1411 name: "test_tool",
1412 description: "Test tool",
1413 parameters: map[string]any{
1414 "value": map[string]any{"type": "string"},
1415 },
1416 required: []string{"value"},
1417 }
1418
1419 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1420 // Repair function fails
1421 return nil, errors.New("repair failed")
1422 }
1423
1424 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1425
1426 result, err := agent.Generate(context.Background(), AgentCall{
1427 Prompt: "test prompt",
1428 })
1429
1430 require.NoError(t, err)
1431 require.NotNil(t, result)
1432 require.Len(t, result.Steps, 1) // Only one step
1433
1434 // Check that tool call was marked as invalid since repair failed
1435 toolCalls := result.Steps[0].Content.ToolCalls()
1436 require.Len(t, toolCalls, 1)
1437 require.True(t, toolCalls[0].Invalid) // Should be invalid
1438 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1439 })
1440
1441 t.Run("Nonexistent tool call", func(t *testing.T) {
1442 t.Parallel()
1443 model := &mockLanguageModel{
1444 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1445 return &Response{
1446 Content: ResponseContent{
1447 TextContent{Text: "Response"},
1448 ToolCallContent{
1449 ToolCallID: "call1",
1450 ToolName: "nonexistent_tool",
1451 Input: `{"value": "test"}`,
1452 },
1453 },
1454 Usage: Usage{TotalTokens: 10},
1455 FinishReason: FinishReasonStop, // Changed to stop
1456 }, nil
1457 },
1458 }
1459
1460 tool := &mockTool{name: "test_tool", description: "Test tool"}
1461
1462 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1463
1464 result, err := agent.Generate(context.Background(), AgentCall{
1465 Prompt: "test prompt",
1466 })
1467
1468 require.NoError(t, err)
1469 require.NotNil(t, result)
1470 require.Len(t, result.Steps, 1) // Only one step
1471
1472 // Check that tool call was marked as invalid due to nonexistent tool
1473 toolCalls := result.Steps[0].Content.ToolCalls()
1474 require.Len(t, toolCalls, 1)
1475 require.True(t, toolCalls[0].Invalid) // Should be invalid
1476 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1477 })
1478
1479 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1480 t.Parallel()
1481 model := &mockLanguageModel{
1482 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1483 return &Response{
1484 Content: ResponseContent{
1485 TextContent{Text: "Response"},
1486 ToolCallContent{
1487 ToolCallID: "call1",
1488 ToolName: "test_tool",
1489 Input: `{invalid json}`, // Invalid JSON
1490 },
1491 },
1492 Usage: Usage{TotalTokens: 10},
1493 FinishReason: FinishReasonStop, // Changed to stop
1494 }, nil
1495 },
1496 }
1497
1498 tool := &mockTool{
1499 name: "test_tool",
1500 description: "Test tool",
1501 parameters: map[string]any{
1502 "value": map[string]any{"type": "string"},
1503 },
1504 required: []string{"value"},
1505 }
1506
1507 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1508
1509 result, err := agent.Generate(context.Background(), AgentCall{
1510 Prompt: "test prompt",
1511 })
1512
1513 require.NoError(t, err)
1514 require.NotNil(t, result)
1515 require.Len(t, result.Steps, 1) // Only one step
1516
1517 // Check that tool call was marked as invalid due to invalid JSON
1518 toolCalls := result.Steps[0].Content.ToolCalls()
1519 require.Len(t, toolCalls, 1)
1520 require.True(t, toolCalls[0].Invalid) // Should be invalid
1521 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1522 })
1523}