1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 description string
17 parameters map[string]any
18 required []string
19 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
20}
21
22func (m *mockTool) Info() ToolInfo {
23 return ToolInfo{
24 Name: m.name,
25 Description: m.description,
26 Parameters: m.parameters,
27 Required: m.required,
28 }
29}
30
31func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
32 if m.executeFunc != nil {
33 return m.executeFunc(ctx, call)
34 }
35 return ToolResponse{Content: "mock result", IsError: false}, nil
36}
37
38// Mock language model for testing
39type mockLanguageModel struct {
40 generateFunc func(ctx context.Context, call Call) (*Response, error)
41 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
42}
43
44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
45 if m.generateFunc != nil {
46 return m.generateFunc(ctx, call)
47 }
48 return &Response{
49 Content: []Content{
50 TextContent{Text: "Hello, world!"},
51 },
52 Usage: Usage{
53 InputTokens: 3,
54 OutputTokens: 10,
55 TotalTokens: 13,
56 },
57 FinishReason: FinishReasonStop,
58 }, nil
59}
60
61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
62 if m.streamFunc != nil {
63 return m.streamFunc(ctx, call)
64 }
65 return nil, fmt.Errorf("mock stream not implemented")
66}
67
68func (m *mockLanguageModel) Provider() string {
69 return "mock-provider"
70}
71
72func (m *mockLanguageModel) Model() string {
73 return "mock-model"
74}
75
76// Test result.content - comprehensive content types (matches TS test)
77func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
78 t.Parallel()
79
80 // Create a type-safe tool using the new API
81 type TestInput struct {
82 Value string `json:"value" description:"Test value"`
83 }
84
85 tool1 := NewAgentTool(
86 "tool1",
87 "Test tool",
88 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
89 require.Equal(t, "value", input.Value)
90 return ToolResponse{Content: "result1", IsError: false}, nil
91 },
92 )
93
94 model := &mockLanguageModel{
95 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
96 return &Response{
97 Content: []Content{
98 TextContent{Text: "Hello, world!"},
99 SourceContent{
100 ID: "123",
101 URL: "https://example.com",
102 Title: "Example",
103 SourceType: SourceTypeURL,
104 },
105 FileContent{
106 Data: []byte{1, 2, 3},
107 MediaType: "image/png",
108 },
109 ReasoningContent{
110 Text: "I will open the conversation with witty banter.",
111 },
112 ToolCallContent{
113 ToolCallID: "call-1",
114 ToolName: "tool1",
115 Input: `{"value":"value"}`,
116 },
117 TextContent{Text: "More text"},
118 },
119 Usage: Usage{
120 InputTokens: 3,
121 OutputTokens: 10,
122 TotalTokens: 13,
123 },
124 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
125 }, nil
126 },
127 }
128
129 agent := NewAgent(model, WithTools(tool1))
130 result, err := agent.Generate(context.Background(), AgentCall{
131 Prompt: "prompt",
132 })
133
134 require.NoError(t, err)
135 require.NotNil(t, result)
136 require.Len(t, result.Steps, 1) // Single step like TypeScript
137
138 // Check final response content includes tool result
139 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
140
141 // Verify each content type in order
142 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
143 require.True(t, ok)
144 require.Equal(t, "Hello, world!", textContent.Text)
145
146 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
147 require.True(t, ok)
148 require.Equal(t, "123", sourceContent.ID)
149
150 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
151 require.True(t, ok)
152 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
153
154 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
155 require.True(t, ok)
156 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
157
158 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
159 require.True(t, ok)
160 require.Equal(t, "call-1", toolCallContent.ToolCallID)
161
162 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
163 require.True(t, ok)
164 require.Equal(t, "More text", moreTextContent.Text)
165
166 // Tool result should be appended
167 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
168 require.True(t, ok)
169 require.Equal(t, "call-1", toolResultContent.ToolCallID)
170 require.Equal(t, "tool1", toolResultContent.ToolName)
171}
172
173// Test result.text extraction
174func TestAgent_Generate_ResultText(t *testing.T) {
175 t.Parallel()
176
177 model := &mockLanguageModel{
178 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
179 return &Response{
180 Content: []Content{
181 TextContent{Text: "Hello, world!"},
182 },
183 Usage: Usage{
184 InputTokens: 3,
185 OutputTokens: 10,
186 TotalTokens: 13,
187 },
188 FinishReason: FinishReasonStop,
189 }, nil
190 },
191 }
192
193 agent := NewAgent(model)
194 result, err := agent.Generate(context.Background(), AgentCall{
195 Prompt: "prompt",
196 })
197
198 require.NoError(t, err)
199 require.NotNil(t, result)
200
201 // Test text extraction from content
202 text := result.Response.Content.Text()
203 require.Equal(t, "Hello, world!", text)
204}
205
206// Test result.toolCalls extraction (matches TS test exactly)
207func TestAgent_Generate_ResultToolCalls(t *testing.T) {
208 t.Parallel()
209
210 // Create type-safe tools using the new API
211 type Tool1Input struct {
212 Value string `json:"value" description:"Test value"`
213 }
214
215 type Tool2Input struct {
216 SomethingElse string `json:"somethingElse" description:"Another test value"`
217 }
218
219 tool1 := NewAgentTool(
220 "tool1",
221 "Test tool 1",
222 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
223 return ToolResponse{Content: "result1", IsError: false}, nil
224 },
225 )
226
227 tool2 := NewAgentTool(
228 "tool2",
229 "Test tool 2",
230 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
231 return ToolResponse{Content: "result2", IsError: false}, nil
232 },
233 )
234
235 model := &mockLanguageModel{
236 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
237 // Verify tools are passed correctly
238 require.Len(t, call.Tools, 2)
239 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
240
241 // Verify prompt structure
242 require.Len(t, call.Prompt, 1)
243 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
244
245 return &Response{
246 Content: []Content{
247 ToolCallContent{
248 ToolCallID: "call-1",
249 ToolName: "tool1",
250 Input: `{"value":"value"}`,
251 },
252 },
253 Usage: Usage{
254 InputTokens: 3,
255 OutputTokens: 10,
256 TotalTokens: 13,
257 },
258 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
259 }, nil
260 },
261 }
262
263 agent := NewAgent(model, WithTools(tool1, tool2))
264 result, err := agent.Generate(context.Background(), AgentCall{
265 Prompt: "test-input",
266 })
267
268 require.NoError(t, err)
269 require.NotNil(t, result)
270 require.Len(t, result.Steps, 1) // Single step
271
272 // Extract tool calls from final response (should be empty since tools don't execute)
273 var toolCalls []ToolCallContent
274 for _, content := range result.Response.Content {
275 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
276 toolCalls = append(toolCalls, toolCall)
277 }
278 }
279
280 require.Len(t, toolCalls, 1)
281 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
282 require.Equal(t, "tool1", toolCalls[0].ToolName)
283
284 // Parse and verify input
285 var input map[string]any
286 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
287 require.NoError(t, err)
288 require.Equal(t, "value", input["value"])
289}
290
291// Test result.toolResults extraction (matches TS test exactly)
292func TestAgent_Generate_ResultToolResults(t *testing.T) {
293 t.Parallel()
294
295 // Create type-safe tool using the new API
296 type TestInput struct {
297 Value string `json:"value" description:"Test value"`
298 }
299
300 tool1 := NewAgentTool(
301 "tool1",
302 "Test tool",
303 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
304 require.Equal(t, "value", input.Value)
305 return ToolResponse{Content: "result1", IsError: false}, nil
306 },
307 )
308
309 model := &mockLanguageModel{
310 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
311 // Verify tools and tool choice
312 require.Len(t, call.Tools, 1)
313 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
314
315 // Verify prompt
316 require.Len(t, call.Prompt, 1)
317 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
318
319 return &Response{
320 Content: []Content{
321 ToolCallContent{
322 ToolCallID: "call-1",
323 ToolName: "tool1",
324 Input: `{"value":"value"}`,
325 },
326 },
327 Usage: Usage{
328 InputTokens: 3,
329 OutputTokens: 10,
330 TotalTokens: 13,
331 },
332 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
333 }, nil
334 },
335 }
336
337 agent := NewAgent(model, WithTools(tool1))
338 result, err := agent.Generate(context.Background(), AgentCall{
339 Prompt: "test-input",
340 })
341
342 require.NoError(t, err)
343 require.NotNil(t, result)
344 require.Len(t, result.Steps, 1) // Single step
345
346 // Extract tool results from final response
347 var toolResults []ToolResultContent
348 for _, content := range result.Response.Content {
349 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
350 toolResults = append(toolResults, toolResult)
351 }
352 }
353
354 require.Len(t, toolResults, 1)
355 require.Equal(t, "call-1", toolResults[0].ToolCallID)
356 require.Equal(t, "tool1", toolResults[0].ToolName)
357
358 // Verify result content
359 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
360 require.True(t, ok)
361 require.Equal(t, "result1", textResult.Text)
362}
363
364// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
365func TestAgent_Generate_MultipleSteps(t *testing.T) {
366 t.Parallel()
367
368 // Create type-safe tool using the new API
369 type TestInput struct {
370 Value string `json:"value" description:"Test value"`
371 }
372
373 tool1 := NewAgentTool(
374 "tool1",
375 "Test tool",
376 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
377 require.Equal(t, "value", input.Value)
378 return ToolResponse{Content: "result1", IsError: false}, nil
379 },
380 )
381
382 callCount := 0
383 model := &mockLanguageModel{
384 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
385 callCount++
386 switch callCount {
387 case 1:
388 // First call - return tool call with FinishReasonToolCalls
389 return &Response{
390 Content: []Content{
391 ToolCallContent{
392 ToolCallID: "call-1",
393 ToolName: "tool1",
394 Input: `{"value":"value"}`,
395 },
396 },
397 Usage: Usage{
398 InputTokens: 10,
399 OutputTokens: 5,
400 TotalTokens: 15,
401 },
402 FinishReason: FinishReasonToolCalls, // This triggers multi-step
403 }, nil
404 case 2:
405 // Second call - return final text
406 return &Response{
407 Content: []Content{
408 TextContent{Text: "Hello, world!"},
409 },
410 Usage: Usage{
411 InputTokens: 3,
412 OutputTokens: 10,
413 TotalTokens: 13,
414 },
415 FinishReason: FinishReasonStop,
416 }, nil
417 default:
418 t.Fatalf("Unexpected call count: %d", callCount)
419 return nil, nil
420 }
421 },
422 }
423
424 agent := NewAgent(model, WithTools(tool1))
425 result, err := agent.Generate(context.Background(), AgentCall{
426 Prompt: "test-input",
427 })
428
429 require.NoError(t, err)
430 require.NotNil(t, result)
431 require.Len(t, result.Steps, 2)
432
433 // Check total usage sums both steps
434 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
435 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
436 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
437
438 // Final response should be from last step
439 require.Len(t, result.Response.Content, 1)
440 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
441 require.True(t, ok)
442 require.Equal(t, "Hello, world!", textContent.Text)
443
444 // result.toolCalls should be empty (from last step)
445 var toolCalls []ToolCallContent
446 for _, content := range result.Response.Content {
447 if _, ok := AsContentType[ToolCallContent](content); ok {
448 toolCalls = append(toolCalls, content.(ToolCallContent))
449 }
450 }
451 require.Len(t, toolCalls, 0)
452
453 // result.toolResults should be empty (from last step)
454 var toolResults []ToolResultContent
455 for _, content := range result.Response.Content {
456 if _, ok := AsContentType[ToolResultContent](content); ok {
457 toolResults = append(toolResults, content.(ToolResultContent))
458 }
459 }
460 require.Len(t, toolResults, 0)
461}
462
463// Test basic text generation
464func TestAgent_Generate_BasicText(t *testing.T) {
465 t.Parallel()
466
467 model := &mockLanguageModel{
468 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
469 return &Response{
470 Content: []Content{
471 TextContent{Text: "Hello, world!"},
472 },
473 Usage: Usage{
474 InputTokens: 3,
475 OutputTokens: 10,
476 TotalTokens: 13,
477 },
478 FinishReason: FinishReasonStop,
479 }, nil
480 },
481 }
482
483 agent := NewAgent(model)
484 result, err := agent.Generate(context.Background(), AgentCall{
485 Prompt: "test prompt",
486 })
487
488 require.NoError(t, err)
489 require.NotNil(t, result)
490 require.Len(t, result.Steps, 1)
491
492 // Check final response
493 require.Len(t, result.Response.Content, 1)
494 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
495 require.True(t, ok)
496 require.Equal(t, "Hello, world!", textContent.Text)
497
498 // Check usage
499 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
500 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
501 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
502
503 // Check total usage
504 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
505 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
506 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
507}
508
509// Test empty prompt error
510func TestAgent_Generate_EmptyPrompt(t *testing.T) {
511 t.Parallel()
512
513 model := &mockLanguageModel{}
514 agent := NewAgent(model)
515
516 result, err := agent.Generate(context.Background(), AgentCall{
517 Prompt: "", // Empty prompt should cause error
518 })
519
520 require.Error(t, err)
521 require.Nil(t, result)
522 require.Contains(t, err.Error(), "Prompt can't be empty")
523}
524
525// Test with system prompt
526func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
527 t.Parallel()
528
529 model := &mockLanguageModel{
530 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
531 // Verify system message is included
532 require.Len(t, call.Prompt, 2) // system + user
533 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
534 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
535
536 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
537 require.True(t, ok)
538 require.Equal(t, "You are a helpful assistant", systemPart.Text)
539
540 return &Response{
541 Content: []Content{
542 TextContent{Text: "Hello, world!"},
543 },
544 Usage: Usage{
545 InputTokens: 3,
546 OutputTokens: 10,
547 TotalTokens: 13,
548 },
549 FinishReason: FinishReasonStop,
550 }, nil
551 },
552 }
553
554 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
555 result, err := agent.Generate(context.Background(), AgentCall{
556 Prompt: "test prompt",
557 })
558
559 require.NoError(t, err)
560 require.NotNil(t, result)
561}
562
563// Test options.activeTools filtering
564func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
565 t.Parallel()
566
567 tool1 := &mockTool{
568 name: "tool1",
569 description: "Test tool 1",
570 parameters: map[string]any{
571 "value": map[string]any{"type": "string"},
572 },
573 required: []string{"value"},
574 }
575
576 tool2 := &mockTool{
577 name: "tool2",
578 description: "Test tool 2",
579 parameters: map[string]any{
580 "value": map[string]any{"type": "string"},
581 },
582 required: []string{"value"},
583 }
584
585 model := &mockLanguageModel{
586 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
587 // Verify only tool1 is available
588 require.Len(t, call.Tools, 1)
589 functionTool, ok := call.Tools[0].(FunctionTool)
590 require.True(t, ok)
591 require.Equal(t, "tool1", functionTool.Name)
592
593 return &Response{
594 Content: []Content{
595 TextContent{Text: "Hello, world!"},
596 },
597 Usage: Usage{
598 InputTokens: 3,
599 OutputTokens: 10,
600 TotalTokens: 13,
601 },
602 FinishReason: FinishReasonStop,
603 }, nil
604 },
605 }
606
607 agent := NewAgent(model, WithTools(tool1, tool2))
608 result, err := agent.Generate(context.Background(), AgentCall{
609 Prompt: "test-input",
610 ActiveTools: []string{"tool1"}, // Only tool1 should be active
611 })
612
613 require.NoError(t, err)
614 require.NotNil(t, result)
615}
616
617func TestResponseContent_Getters(t *testing.T) {
618 t.Parallel()
619
620 // Create test content with all types
621 content := ResponseContent{
622 TextContent{Text: "Hello world"},
623 ReasoningContent{Text: "Let me think..."},
624 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
625 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
626 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
627 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
628 }
629
630 // Test Text()
631 require.Equal(t, "Hello world", content.Text())
632
633 // Test Reasoning()
634 reasoning := content.Reasoning()
635 require.Len(t, reasoning, 1)
636 require.Equal(t, "Let me think...", reasoning[0].Text)
637
638 // Test ReasoningText()
639 require.Equal(t, "Let me think...", content.ReasoningText())
640
641 // Test Files()
642 files := content.Files()
643 require.Len(t, files, 1)
644 require.Equal(t, "text/plain", files[0].MediaType)
645 require.Equal(t, []byte("file data"), files[0].Data)
646
647 // Test Sources()
648 sources := content.Sources()
649 require.Len(t, sources, 1)
650 require.Equal(t, SourceTypeURL, sources[0].SourceType)
651 require.Equal(t, "https://example.com", sources[0].URL)
652 require.Equal(t, "Example", sources[0].Title)
653
654 // Test ToolCalls()
655 toolCalls := content.ToolCalls()
656 require.Len(t, toolCalls, 1)
657 require.Equal(t, "call1", toolCalls[0].ToolCallID)
658 require.Equal(t, "test_tool", toolCalls[0].ToolName)
659 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
660
661 // Test ToolResults()
662 toolResults := content.ToolResults()
663 require.Len(t, toolResults, 1)
664 require.Equal(t, "call1", toolResults[0].ToolCallID)
665 require.Equal(t, "test_tool", toolResults[0].ToolName)
666 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
667 require.True(t, ok)
668 require.Equal(t, "result", result.Text)
669}
670
671func TestResponseContent_Getters_Empty(t *testing.T) {
672 t.Parallel()
673
674 // Test with empty content
675 content := ResponseContent{}
676
677 require.Equal(t, "", content.Text())
678 require.Equal(t, "", content.ReasoningText())
679 require.Empty(t, content.Reasoning())
680 require.Empty(t, content.Files())
681 require.Empty(t, content.Sources())
682 require.Empty(t, content.ToolCalls())
683 require.Empty(t, content.ToolResults())
684}
685
686func TestResponseContent_Getters_MultipleItems(t *testing.T) {
687 t.Parallel()
688
689 // Test with multiple items of same type
690 content := ResponseContent{
691 ReasoningContent{Text: "First thought"},
692 ReasoningContent{Text: "Second thought"},
693 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
694 FileContent{Data: []byte("file2"), MediaType: "image/png"},
695 }
696
697 // Test multiple reasoning
698 reasoning := content.Reasoning()
699 require.Len(t, reasoning, 2)
700 require.Equal(t, "First thought", reasoning[0].Text)
701 require.Equal(t, "Second thought", reasoning[1].Text)
702
703 // Test concatenated reasoning text
704 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
705
706 // Test multiple files
707 files := content.Files()
708 require.Len(t, files, 2)
709 require.Equal(t, "text/plain", files[0].MediaType)
710 require.Equal(t, "image/png", files[1].MediaType)
711}
712
713func TestStopConditions(t *testing.T) {
714 t.Parallel()
715
716 // Create test steps
717 step1 := StepResult{
718 Response: Response{
719 Content: ResponseContent{
720 TextContent{Text: "Hello"},
721 },
722 FinishReason: FinishReasonToolCalls,
723 Usage: Usage{TotalTokens: 10},
724 },
725 }
726
727 step2 := StepResult{
728 Response: Response{
729 Content: ResponseContent{
730 TextContent{Text: "World"},
731 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
732 },
733 FinishReason: FinishReasonStop,
734 Usage: Usage{TotalTokens: 15},
735 },
736 }
737
738 step3 := StepResult{
739 Response: Response{
740 Content: ResponseContent{
741 ReasoningContent{Text: "Let me think..."},
742 FileContent{Data: []byte("data"), MediaType: "text/plain"},
743 },
744 FinishReason: FinishReasonLength,
745 Usage: Usage{TotalTokens: 20},
746 },
747 }
748
749 t.Run("StepCountIs", func(t *testing.T) {
750 t.Parallel()
751 condition := StepCountIs(2)
752
753 // Should not stop with 1 step
754 require.False(t, condition([]StepResult{step1}))
755
756 // Should stop with 2 steps
757 require.True(t, condition([]StepResult{step1, step2}))
758
759 // Should stop with more than 2 steps
760 require.True(t, condition([]StepResult{step1, step2, step3}))
761
762 // Should not stop with empty steps
763 require.False(t, condition([]StepResult{}))
764 })
765
766 t.Run("HasToolCall", func(t *testing.T) {
767 t.Parallel()
768 condition := HasToolCall("search")
769
770 // Should not stop when tool not called
771 require.False(t, condition([]StepResult{step1}))
772
773 // Should stop when tool is called in last step
774 require.True(t, condition([]StepResult{step1, step2}))
775
776 // Should not stop when tool called in earlier step but not last
777 require.False(t, condition([]StepResult{step1, step2, step3}))
778
779 // Should not stop with empty steps
780 require.False(t, condition([]StepResult{}))
781
782 // Should not stop when different tool is called
783 differentToolCondition := HasToolCall("different_tool")
784 require.False(t, differentToolCondition([]StepResult{step1, step2}))
785 })
786
787 t.Run("HasContent", func(t *testing.T) {
788 t.Parallel()
789 reasoningCondition := HasContent(ContentTypeReasoning)
790 fileCondition := HasContent(ContentTypeFile)
791
792 // Should not stop when content type not present
793 require.False(t, reasoningCondition([]StepResult{step1, step2}))
794
795 // Should stop when content type is present in last step
796 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
797 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
798
799 // Should not stop with empty steps
800 require.False(t, reasoningCondition([]StepResult{}))
801 })
802
803 t.Run("FinishReasonIs", func(t *testing.T) {
804 t.Parallel()
805 stopCondition := FinishReasonIs(FinishReasonStop)
806 lengthCondition := FinishReasonIs(FinishReasonLength)
807
808 // Should not stop when finish reason doesn't match
809 require.False(t, stopCondition([]StepResult{step1}))
810
811 // Should stop when finish reason matches in last step
812 require.True(t, stopCondition([]StepResult{step1, step2}))
813 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
814
815 // Should not stop with empty steps
816 require.False(t, stopCondition([]StepResult{}))
817 })
818
819 t.Run("MaxTokensUsed", func(t *testing.T) {
820 condition := MaxTokensUsed(30)
821
822 // Should not stop when under limit
823 require.False(t, condition([]StepResult{step1})) // 10 tokens
824 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
825
826 // Should stop when at or over limit
827 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
828
829 // Should not stop with empty steps
830 require.False(t, condition([]StepResult{}))
831 })
832}
833
834func TestStopConditions_Integration(t *testing.T) {
835 t.Parallel()
836
837 t.Run("StepCountIs integration", func(t *testing.T) {
838 t.Parallel()
839 model := &mockLanguageModel{
840 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
841 return &Response{
842 Content: ResponseContent{
843 TextContent{Text: "Mock response"},
844 },
845 Usage: Usage{
846 InputTokens: 3,
847 OutputTokens: 10,
848 TotalTokens: 13,
849 },
850 FinishReason: FinishReasonStop,
851 }, nil
852 },
853 }
854
855 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
856
857 result, err := agent.Generate(context.Background(), AgentCall{
858 Prompt: "test prompt",
859 })
860
861 require.NoError(t, err)
862 require.NotNil(t, result)
863 require.Len(t, result.Steps, 1) // Should stop after 1 step
864 })
865
866 t.Run("Multiple stop conditions", func(t *testing.T) {
867 t.Parallel()
868 model := &mockLanguageModel{
869 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
870 return &Response{
871 Content: ResponseContent{
872 TextContent{Text: "Mock response"},
873 },
874 Usage: Usage{
875 InputTokens: 3,
876 OutputTokens: 10,
877 TotalTokens: 13,
878 },
879 FinishReason: FinishReasonStop,
880 }, nil
881 },
882 }
883
884 agent := NewAgent(model, WithStopConditions(
885 StepCountIs(5), // Stop after 5 steps
886 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
887 ))
888
889 result, err := agent.Generate(context.Background(), AgentCall{
890 Prompt: "test prompt",
891 })
892
893 require.NoError(t, err)
894 require.NotNil(t, result)
895 // Should stop on first condition met (finish reason stop)
896 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
897 })
898}
899
900func TestPrepareStep(t *testing.T) {
901 t.Parallel()
902
903 t.Run("System prompt modification", func(t *testing.T) {
904 t.Parallel()
905 var capturedSystemPrompt string
906 model := &mockLanguageModel{
907 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
908 // Capture the system message to verify it was modified
909 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
910 if len(call.Prompt[0].Content) > 0 {
911 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
912 capturedSystemPrompt = textPart.Text
913 }
914 }
915 }
916 return &Response{
917 Content: ResponseContent{
918 TextContent{Text: "Response"},
919 },
920 Usage: Usage{TotalTokens: 10},
921 FinishReason: FinishReasonStop,
922 }, nil
923 },
924 }
925
926 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
927 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
928 return PrepareStepResult{
929 Model: options.Model,
930 Messages: options.Messages,
931 System: &newSystem,
932 }, nil
933 }
934
935 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
936
937 result, err := agent.Generate(context.Background(), AgentCall{
938 Prompt: "test prompt",
939 PrepareStep: prepareStepFunc,
940 })
941
942 require.NoError(t, err)
943 require.NotNil(t, result)
944 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
945 })
946
947 t.Run("Tool choice modification", func(t *testing.T) {
948 t.Parallel()
949 var capturedToolChoice *ToolChoice
950 model := &mockLanguageModel{
951 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
952 capturedToolChoice = call.ToolChoice
953 return &Response{
954 Content: ResponseContent{
955 TextContent{Text: "Response"},
956 },
957 Usage: Usage{TotalTokens: 10},
958 FinishReason: FinishReasonStop,
959 }, nil
960 },
961 }
962
963 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
964 toolChoice := ToolChoiceNone
965 return PrepareStepResult{
966 Model: options.Model,
967 Messages: options.Messages,
968 ToolChoice: &toolChoice,
969 }, nil
970 }
971
972 agent := NewAgent(model)
973
974 result, err := agent.Generate(context.Background(), AgentCall{
975 Prompt: "test prompt",
976 PrepareStep: prepareStepFunc,
977 })
978
979 require.NoError(t, err)
980 require.NotNil(t, result)
981 require.NotNil(t, capturedToolChoice)
982 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
983 })
984
985 t.Run("Active tools modification", func(t *testing.T) {
986 t.Parallel()
987 var capturedToolNames []string
988 model := &mockLanguageModel{
989 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
990 // Capture tool names to verify active tools were modified
991 for _, tool := range call.Tools {
992 capturedToolNames = append(capturedToolNames, tool.GetName())
993 }
994 return &Response{
995 Content: ResponseContent{
996 TextContent{Text: "Response"},
997 },
998 Usage: Usage{TotalTokens: 10},
999 FinishReason: FinishReasonStop,
1000 }, nil
1001 },
1002 }
1003
1004 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1005 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1006 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1007
1008 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1009 activeTools := []string{"tool2"} // Only tool2 should be active
1010 return PrepareStepResult{
1011 Model: options.Model,
1012 Messages: options.Messages,
1013 ActiveTools: activeTools,
1014 }, nil
1015 }
1016
1017 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1018
1019 result, err := agent.Generate(context.Background(), AgentCall{
1020 Prompt: "test prompt",
1021 PrepareStep: prepareStepFunc,
1022 })
1023
1024 require.NoError(t, err)
1025 require.NotNil(t, result)
1026 require.Len(t, capturedToolNames, 1)
1027 require.Equal(t, "tool2", capturedToolNames[0])
1028 })
1029
1030 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1031 t.Parallel()
1032 var capturedToolCount int
1033 model := &mockLanguageModel{
1034 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1035 capturedToolCount = len(call.Tools)
1036 return &Response{
1037 Content: ResponseContent{
1038 TextContent{Text: "Response"},
1039 },
1040 Usage: Usage{TotalTokens: 10},
1041 FinishReason: FinishReasonStop,
1042 }, nil
1043 },
1044 }
1045
1046 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1047
1048 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1049 return PrepareStepResult{
1050 Model: options.Model,
1051 Messages: options.Messages,
1052 DisableAllTools: true, // Disable all tools for this step
1053 }, nil
1054 }
1055
1056 agent := NewAgent(model, WithTools(tool1))
1057
1058 result, err := agent.Generate(context.Background(), AgentCall{
1059 Prompt: "test prompt",
1060 PrepareStep: prepareStepFunc,
1061 })
1062
1063 require.NoError(t, err)
1064 require.NotNil(t, result)
1065 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1066 })
1067
1068 t.Run("All fields modified together", func(t *testing.T) {
1069 t.Parallel()
1070 var capturedSystemPrompt string
1071 var capturedToolChoice *ToolChoice
1072 var capturedToolNames []string
1073
1074 model := &mockLanguageModel{
1075 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1076 // Capture system prompt
1077 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1078 if len(call.Prompt[0].Content) > 0 {
1079 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1080 capturedSystemPrompt = textPart.Text
1081 }
1082 }
1083 }
1084 // Capture tool choice
1085 capturedToolChoice = call.ToolChoice
1086 // Capture tool names
1087 for _, tool := range call.Tools {
1088 capturedToolNames = append(capturedToolNames, tool.GetName())
1089 }
1090 return &Response{
1091 Content: ResponseContent{
1092 TextContent{Text: "Response"},
1093 },
1094 Usage: Usage{TotalTokens: 10},
1095 FinishReason: FinishReasonStop,
1096 }, nil
1097 },
1098 }
1099
1100 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1101 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1102
1103 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1104 newSystem := "Step-specific system"
1105 toolChoice := SpecificToolChoice("tool1")
1106 activeTools := []string{"tool1"}
1107 return PrepareStepResult{
1108 Model: options.Model,
1109 Messages: options.Messages,
1110 System: &newSystem,
1111 ToolChoice: &toolChoice,
1112 ActiveTools: activeTools,
1113 }, nil
1114 }
1115
1116 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1117
1118 result, err := agent.Generate(context.Background(), AgentCall{
1119 Prompt: "test prompt",
1120 PrepareStep: prepareStepFunc,
1121 })
1122
1123 require.NoError(t, err)
1124 require.NotNil(t, result)
1125 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1126 require.NotNil(t, capturedToolChoice)
1127 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1128 require.Len(t, capturedToolNames, 1)
1129 require.Equal(t, "tool1", capturedToolNames[0])
1130 })
1131
1132 t.Run("Nil fields use parent values", func(t *testing.T) {
1133 t.Parallel()
1134 var capturedSystemPrompt string
1135 var capturedToolChoice *ToolChoice
1136 var capturedToolNames []string
1137
1138 model := &mockLanguageModel{
1139 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1140 // Capture system prompt
1141 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1142 if len(call.Prompt[0].Content) > 0 {
1143 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1144 capturedSystemPrompt = textPart.Text
1145 }
1146 }
1147 }
1148 // Capture tool choice
1149 capturedToolChoice = call.ToolChoice
1150 // Capture tool names
1151 for _, tool := range call.Tools {
1152 capturedToolNames = append(capturedToolNames, tool.GetName())
1153 }
1154 return &Response{
1155 Content: ResponseContent{
1156 TextContent{Text: "Response"},
1157 },
1158 Usage: Usage{TotalTokens: 10},
1159 FinishReason: FinishReasonStop,
1160 }, nil
1161 },
1162 }
1163
1164 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1165
1166 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1167 // All optional fields are nil, should use parent values
1168 return PrepareStepResult{
1169 Model: options.Model,
1170 Messages: options.Messages,
1171 System: nil, // Use parent
1172 ToolChoice: nil, // Use parent (auto)
1173 ActiveTools: nil, // Use parent (all tools)
1174 }, nil
1175 }
1176
1177 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1178
1179 result, err := agent.Generate(context.Background(), AgentCall{
1180 Prompt: "test prompt",
1181 PrepareStep: prepareStepFunc,
1182 })
1183
1184 require.NoError(t, err)
1185 require.NotNil(t, result)
1186 require.Equal(t, "Parent system", capturedSystemPrompt)
1187 require.NotNil(t, capturedToolChoice)
1188 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1189 require.Len(t, capturedToolNames, 1)
1190 require.Equal(t, "tool1", capturedToolNames[0])
1191 })
1192
1193 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1194 t.Parallel()
1195 var capturedToolNames []string
1196 model := &mockLanguageModel{
1197 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1198 // Capture tool names to verify all tools are included
1199 for _, tool := range call.Tools {
1200 capturedToolNames = append(capturedToolNames, tool.GetName())
1201 }
1202 return &Response{
1203 Content: ResponseContent{
1204 TextContent{Text: "Response"},
1205 },
1206 Usage: Usage{TotalTokens: 10},
1207 FinishReason: FinishReasonStop,
1208 }, nil
1209 },
1210 }
1211
1212 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1213 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1214
1215 prepareStepFunc := func(options PrepareStepFunctionOptions) (PrepareStepResult, error) {
1216 return PrepareStepResult{
1217 Model: options.Model,
1218 Messages: options.Messages,
1219 ActiveTools: []string{}, // Empty slice means all tools
1220 }, nil
1221 }
1222
1223 agent := NewAgent(model, WithTools(tool1, tool2))
1224
1225 result, err := agent.Generate(context.Background(), AgentCall{
1226 Prompt: "test prompt",
1227 PrepareStep: prepareStepFunc,
1228 })
1229
1230 require.NoError(t, err)
1231 require.NotNil(t, result)
1232 require.Len(t, capturedToolNames, 2) // All tools should be included
1233 require.Contains(t, capturedToolNames, "tool1")
1234 require.Contains(t, capturedToolNames, "tool2")
1235 })
1236}
1237
1238func TestToolCallRepair(t *testing.T) {
1239 t.Parallel()
1240
1241 t.Run("Valid tool call passes validation", func(t *testing.T) {
1242 t.Parallel()
1243 model := &mockLanguageModel{
1244 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1245 return &Response{
1246 Content: ResponseContent{
1247 TextContent{Text: "Response"},
1248 ToolCallContent{
1249 ToolCallID: "call1",
1250 ToolName: "test_tool",
1251 Input: `{"value": "test"}`, // Valid JSON with required field
1252 },
1253 },
1254 Usage: Usage{TotalTokens: 10},
1255 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1256 }, nil
1257 },
1258 }
1259
1260 tool := &mockTool{
1261 name: "test_tool",
1262 description: "Test tool",
1263 parameters: map[string]any{
1264 "value": map[string]any{"type": "string"},
1265 },
1266 required: []string{"value"},
1267 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1268 return ToolResponse{Content: "success", IsError: false}, nil
1269 },
1270 }
1271
1272 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1273
1274 result, err := agent.Generate(context.Background(), AgentCall{
1275 Prompt: "test prompt",
1276 })
1277
1278 require.NoError(t, err)
1279 require.NotNil(t, result)
1280 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1281
1282 // Check that tool call was executed successfully
1283 toolCalls := result.Steps[0].Content.ToolCalls()
1284 require.Len(t, toolCalls, 1)
1285 require.False(t, toolCalls[0].Invalid) // Should be valid
1286 })
1287
1288 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1289 t.Parallel()
1290 model := &mockLanguageModel{
1291 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1292 return &Response{
1293 Content: ResponseContent{
1294 TextContent{Text: "Response"},
1295 ToolCallContent{
1296 ToolCallID: "call1",
1297 ToolName: "test_tool",
1298 Input: `{"wrong_field": "test"}`, // Missing required field
1299 },
1300 },
1301 Usage: Usage{TotalTokens: 10},
1302 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1303 }, nil
1304 },
1305 }
1306
1307 tool := &mockTool{
1308 name: "test_tool",
1309 description: "Test tool",
1310 parameters: map[string]any{
1311 "value": map[string]any{"type": "string"},
1312 },
1313 required: []string{"value"},
1314 }
1315
1316 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1317
1318 result, err := agent.Generate(context.Background(), AgentCall{
1319 Prompt: "test prompt",
1320 })
1321
1322 require.NoError(t, err)
1323 require.NotNil(t, result)
1324 require.Len(t, result.Steps, 1) // Only one step
1325
1326 // Check that tool call was marked as invalid
1327 toolCalls := result.Steps[0].Content.ToolCalls()
1328 require.Len(t, toolCalls, 1)
1329 require.True(t, toolCalls[0].Invalid) // Should be invalid
1330 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1331 })
1332
1333 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1334 t.Parallel()
1335 model := &mockLanguageModel{
1336 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1337 return &Response{
1338 Content: ResponseContent{
1339 TextContent{Text: "Response"},
1340 ToolCallContent{
1341 ToolCallID: "call1",
1342 ToolName: "test_tool",
1343 Input: `{"wrong_field": "test"}`, // Missing required field
1344 },
1345 },
1346 Usage: Usage{TotalTokens: 10},
1347 FinishReason: FinishReasonStop, // Changed to stop
1348 }, nil
1349 },
1350 }
1351
1352 tool := &mockTool{
1353 name: "test_tool",
1354 description: "Test tool",
1355 parameters: map[string]any{
1356 "value": map[string]any{"type": "string"},
1357 },
1358 required: []string{"value"},
1359 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1360 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1361 },
1362 }
1363
1364 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1365 // Simple repair: add the missing required field
1366 repairedToolCall := options.OriginalToolCall
1367 repairedToolCall.Input = `{"value": "repaired"}`
1368 return &repairedToolCall, nil
1369 }
1370
1371 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1372
1373 result, err := agent.Generate(context.Background(), AgentCall{
1374 Prompt: "test prompt",
1375 })
1376
1377 require.NoError(t, err)
1378 require.NotNil(t, result)
1379 require.Len(t, result.Steps, 1) // Only one step
1380
1381 // Check that tool call was repaired and is now valid
1382 toolCalls := result.Steps[0].Content.ToolCalls()
1383 require.Len(t, toolCalls, 1)
1384 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1385 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1386 })
1387
1388 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1389 t.Parallel()
1390 model := &mockLanguageModel{
1391 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1392 return &Response{
1393 Content: ResponseContent{
1394 TextContent{Text: "Response"},
1395 ToolCallContent{
1396 ToolCallID: "call1",
1397 ToolName: "test_tool",
1398 Input: `{"wrong_field": "test"}`, // Missing required field
1399 },
1400 },
1401 Usage: Usage{TotalTokens: 10},
1402 FinishReason: FinishReasonStop, // Changed to stop
1403 }, nil
1404 },
1405 }
1406
1407 tool := &mockTool{
1408 name: "test_tool",
1409 description: "Test tool",
1410 parameters: map[string]any{
1411 "value": map[string]any{"type": "string"},
1412 },
1413 required: []string{"value"},
1414 }
1415
1416 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1417 // Repair function fails
1418 return nil, errors.New("repair failed")
1419 }
1420
1421 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1422
1423 result, err := agent.Generate(context.Background(), AgentCall{
1424 Prompt: "test prompt",
1425 })
1426
1427 require.NoError(t, err)
1428 require.NotNil(t, result)
1429 require.Len(t, result.Steps, 1) // Only one step
1430
1431 // Check that tool call was marked as invalid since repair failed
1432 toolCalls := result.Steps[0].Content.ToolCalls()
1433 require.Len(t, toolCalls, 1)
1434 require.True(t, toolCalls[0].Invalid) // Should be invalid
1435 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1436 })
1437
1438 t.Run("Nonexistent tool call", func(t *testing.T) {
1439 t.Parallel()
1440 model := &mockLanguageModel{
1441 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1442 return &Response{
1443 Content: ResponseContent{
1444 TextContent{Text: "Response"},
1445 ToolCallContent{
1446 ToolCallID: "call1",
1447 ToolName: "nonexistent_tool",
1448 Input: `{"value": "test"}`,
1449 },
1450 },
1451 Usage: Usage{TotalTokens: 10},
1452 FinishReason: FinishReasonStop, // Changed to stop
1453 }, nil
1454 },
1455 }
1456
1457 tool := &mockTool{name: "test_tool", description: "Test tool"}
1458
1459 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1460
1461 result, err := agent.Generate(context.Background(), AgentCall{
1462 Prompt: "test prompt",
1463 })
1464
1465 require.NoError(t, err)
1466 require.NotNil(t, result)
1467 require.Len(t, result.Steps, 1) // Only one step
1468
1469 // Check that tool call was marked as invalid due to nonexistent tool
1470 toolCalls := result.Steps[0].Content.ToolCalls()
1471 require.Len(t, toolCalls, 1)
1472 require.True(t, toolCalls[0].Invalid) // Should be invalid
1473 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1474 })
1475
1476 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1477 t.Parallel()
1478 model := &mockLanguageModel{
1479 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1480 return &Response{
1481 Content: ResponseContent{
1482 TextContent{Text: "Response"},
1483 ToolCallContent{
1484 ToolCallID: "call1",
1485 ToolName: "test_tool",
1486 Input: `{invalid json}`, // Invalid JSON
1487 },
1488 },
1489 Usage: Usage{TotalTokens: 10},
1490 FinishReason: FinishReasonStop, // Changed to stop
1491 }, nil
1492 },
1493 }
1494
1495 tool := &mockTool{
1496 name: "test_tool",
1497 description: "Test tool",
1498 parameters: map[string]any{
1499 "value": map[string]any{"type": "string"},
1500 },
1501 required: []string{"value"},
1502 }
1503
1504 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1505
1506 result, err := agent.Generate(context.Background(), AgentCall{
1507 Prompt: "test prompt",
1508 })
1509
1510 require.NoError(t, err)
1511 require.NotNil(t, result)
1512 require.Len(t, result.Steps, 1) // Only one step
1513
1514 // Check that tool call was marked as invalid due to invalid JSON
1515 toolCalls := result.Steps[0].Content.ToolCalls()
1516 require.Len(t, toolCalls, 1)
1517 require.True(t, toolCalls[0].Invalid) // Should be invalid
1518 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1519 })
1520}