1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 description string
17 parameters map[string]any
18 required []string
19 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
20}
21
22func (m *mockTool) Info() ToolInfo {
23 return ToolInfo{
24 Name: m.name,
25 Description: m.description,
26 Parameters: m.parameters,
27 Required: m.required,
28 }
29}
30
31func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
32 if m.executeFunc != nil {
33 return m.executeFunc(ctx, call)
34 }
35 return ToolResponse{Content: "mock result", IsError: false}, nil
36}
37
38// Mock language model for testing
39type mockLanguageModel struct {
40 generateFunc func(ctx context.Context, call Call) (*Response, error)
41 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
42}
43
44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
45 if m.generateFunc != nil {
46 return m.generateFunc(ctx, call)
47 }
48 return &Response{
49 Content: []Content{
50 TextContent{Text: "Hello, world!"},
51 },
52 Usage: Usage{
53 InputTokens: 3,
54 OutputTokens: 10,
55 TotalTokens: 13,
56 },
57 FinishReason: FinishReasonStop,
58 }, nil
59}
60
61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
62 if m.streamFunc != nil {
63 return m.streamFunc(ctx, call)
64 }
65 return nil, fmt.Errorf("mock stream not implemented")
66}
67
68func (m *mockLanguageModel) Provider() string {
69 return "mock-provider"
70}
71
72func (m *mockLanguageModel) Model() string {
73 return "mock-model"
74}
75
76// Test result.content - comprehensive content types (matches TS test)
77func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
78 t.Parallel()
79
80 // Create a type-safe tool using the new API
81 type TestInput struct {
82 Value string `json:"value" description:"Test value"`
83 }
84
85 tool1 := NewAgentTool(
86 "tool1",
87 "Test tool",
88 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
89 require.Equal(t, "value", input.Value)
90 return ToolResponse{Content: "result1", IsError: false}, nil
91 },
92 )
93
94 model := &mockLanguageModel{
95 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
96 return &Response{
97 Content: []Content{
98 TextContent{Text: "Hello, world!"},
99 SourceContent{
100 ID: "123",
101 URL: "https://example.com",
102 Title: "Example",
103 SourceType: SourceTypeURL,
104 ProviderMetadata: ProviderMetadata{
105 "provider": map[string]any{"custom": "value"},
106 },
107 },
108 FileContent{
109 Data: []byte{1, 2, 3},
110 MediaType: "image/png",
111 },
112 ReasoningContent{
113 Text: "I will open the conversation with witty banter.",
114 },
115 ToolCallContent{
116 ToolCallID: "call-1",
117 ToolName: "tool1",
118 Input: `{"value":"value"}`,
119 },
120 TextContent{Text: "More text"},
121 },
122 Usage: Usage{
123 InputTokens: 3,
124 OutputTokens: 10,
125 TotalTokens: 13,
126 },
127 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
128 }, nil
129 },
130 }
131
132 agent := NewAgent(model, WithTools(tool1))
133 result, err := agent.Generate(context.Background(), AgentCall{
134 Prompt: "prompt",
135 })
136
137 require.NoError(t, err)
138 require.NotNil(t, result)
139 require.Len(t, result.Steps, 1) // Single step like TypeScript
140
141 // Check final response content includes tool result
142 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
143
144 // Verify each content type in order
145 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
146 require.True(t, ok)
147 require.Equal(t, "Hello, world!", textContent.Text)
148
149 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
150 require.True(t, ok)
151 require.Equal(t, "123", sourceContent.ID)
152
153 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
154 require.True(t, ok)
155 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
156
157 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
158 require.True(t, ok)
159 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
160
161 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
162 require.True(t, ok)
163 require.Equal(t, "call-1", toolCallContent.ToolCallID)
164
165 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
166 require.True(t, ok)
167 require.Equal(t, "More text", moreTextContent.Text)
168
169 // Tool result should be appended
170 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
171 require.True(t, ok)
172 require.Equal(t, "call-1", toolResultContent.ToolCallID)
173 require.Equal(t, "tool1", toolResultContent.ToolName)
174}
175
176// Test result.text extraction
177func TestAgent_Generate_ResultText(t *testing.T) {
178 t.Parallel()
179
180 model := &mockLanguageModel{
181 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
182 return &Response{
183 Content: []Content{
184 TextContent{Text: "Hello, world!"},
185 },
186 Usage: Usage{
187 InputTokens: 3,
188 OutputTokens: 10,
189 TotalTokens: 13,
190 },
191 FinishReason: FinishReasonStop,
192 }, nil
193 },
194 }
195
196 agent := NewAgent(model)
197 result, err := agent.Generate(context.Background(), AgentCall{
198 Prompt: "prompt",
199 })
200
201 require.NoError(t, err)
202 require.NotNil(t, result)
203
204 // Test text extraction from content
205 text := result.Response.Content.Text()
206 require.Equal(t, "Hello, world!", text)
207}
208
209// Test result.toolCalls extraction (matches TS test exactly)
210func TestAgent_Generate_ResultToolCalls(t *testing.T) {
211 t.Parallel()
212
213 // Create type-safe tools using the new API
214 type Tool1Input struct {
215 Value string `json:"value" description:"Test value"`
216 }
217
218 type Tool2Input struct {
219 SomethingElse string `json:"somethingElse" description:"Another test value"`
220 }
221
222 tool1 := NewAgentTool(
223 "tool1",
224 "Test tool 1",
225 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
226 return ToolResponse{Content: "result1", IsError: false}, nil
227 },
228 )
229
230 tool2 := NewAgentTool(
231 "tool2",
232 "Test tool 2",
233 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
234 return ToolResponse{Content: "result2", IsError: false}, nil
235 },
236 )
237
238 model := &mockLanguageModel{
239 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
240 // Verify tools are passed correctly
241 require.Len(t, call.Tools, 2)
242 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
243
244 // Verify prompt structure
245 require.Len(t, call.Prompt, 1)
246 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
247
248 return &Response{
249 Content: []Content{
250 ToolCallContent{
251 ToolCallID: "call-1",
252 ToolName: "tool1",
253 Input: `{"value":"value"}`,
254 },
255 },
256 Usage: Usage{
257 InputTokens: 3,
258 OutputTokens: 10,
259 TotalTokens: 13,
260 },
261 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
262 }, nil
263 },
264 }
265
266 agent := NewAgent(model, WithTools(tool1, tool2))
267 result, err := agent.Generate(context.Background(), AgentCall{
268 Prompt: "test-input",
269 })
270
271 require.NoError(t, err)
272 require.NotNil(t, result)
273 require.Len(t, result.Steps, 1) // Single step
274
275 // Extract tool calls from final response (should be empty since tools don't execute)
276 var toolCalls []ToolCallContent
277 for _, content := range result.Response.Content {
278 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
279 toolCalls = append(toolCalls, toolCall)
280 }
281 }
282
283 require.Len(t, toolCalls, 1)
284 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
285 require.Equal(t, "tool1", toolCalls[0].ToolName)
286
287 // Parse and verify input
288 var input map[string]any
289 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
290 require.NoError(t, err)
291 require.Equal(t, "value", input["value"])
292}
293
294// Test result.toolResults extraction (matches TS test exactly)
295func TestAgent_Generate_ResultToolResults(t *testing.T) {
296 t.Parallel()
297
298 // Create type-safe tool using the new API
299 type TestInput struct {
300 Value string `json:"value" description:"Test value"`
301 }
302
303 tool1 := NewAgentTool(
304 "tool1",
305 "Test tool",
306 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
307 require.Equal(t, "value", input.Value)
308 return ToolResponse{Content: "result1", IsError: false}, nil
309 },
310 )
311
312 model := &mockLanguageModel{
313 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
314 // Verify tools and tool choice
315 require.Len(t, call.Tools, 1)
316 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
317
318 // Verify prompt
319 require.Len(t, call.Prompt, 1)
320 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
321
322 return &Response{
323 Content: []Content{
324 ToolCallContent{
325 ToolCallID: "call-1",
326 ToolName: "tool1",
327 Input: `{"value":"value"}`,
328 },
329 },
330 Usage: Usage{
331 InputTokens: 3,
332 OutputTokens: 10,
333 TotalTokens: 13,
334 },
335 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
336 }, nil
337 },
338 }
339
340 agent := NewAgent(model, WithTools(tool1))
341 result, err := agent.Generate(context.Background(), AgentCall{
342 Prompt: "test-input",
343 })
344
345 require.NoError(t, err)
346 require.NotNil(t, result)
347 require.Len(t, result.Steps, 1) // Single step
348
349 // Extract tool results from final response
350 var toolResults []ToolResultContent
351 for _, content := range result.Response.Content {
352 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
353 toolResults = append(toolResults, toolResult)
354 }
355 }
356
357 require.Len(t, toolResults, 1)
358 require.Equal(t, "call-1", toolResults[0].ToolCallID)
359 require.Equal(t, "tool1", toolResults[0].ToolName)
360
361 // Verify result content
362 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
363 require.True(t, ok)
364 require.Equal(t, "result1", textResult.Text)
365}
366
367// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
368func TestAgent_Generate_MultipleSteps(t *testing.T) {
369 t.Parallel()
370
371 // Create type-safe tool using the new API
372 type TestInput struct {
373 Value string `json:"value" description:"Test value"`
374 }
375
376 tool1 := NewAgentTool(
377 "tool1",
378 "Test tool",
379 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
380 require.Equal(t, "value", input.Value)
381 return ToolResponse{Content: "result1", IsError: false}, nil
382 },
383 )
384
385 callCount := 0
386 model := &mockLanguageModel{
387 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
388 callCount++
389 switch callCount {
390 case 1:
391 // First call - return tool call with FinishReasonToolCalls
392 return &Response{
393 Content: []Content{
394 ToolCallContent{
395 ToolCallID: "call-1",
396 ToolName: "tool1",
397 Input: `{"value":"value"}`,
398 },
399 },
400 Usage: Usage{
401 InputTokens: 10,
402 OutputTokens: 5,
403 TotalTokens: 15,
404 },
405 FinishReason: FinishReasonToolCalls, // This triggers multi-step
406 }, nil
407 case 2:
408 // Second call - return final text
409 return &Response{
410 Content: []Content{
411 TextContent{Text: "Hello, world!"},
412 },
413 Usage: Usage{
414 InputTokens: 3,
415 OutputTokens: 10,
416 TotalTokens: 13,
417 },
418 FinishReason: FinishReasonStop,
419 }, nil
420 default:
421 t.Fatalf("Unexpected call count: %d", callCount)
422 return nil, nil
423 }
424 },
425 }
426
427 agent := NewAgent(model, WithTools(tool1))
428 result, err := agent.Generate(context.Background(), AgentCall{
429 Prompt: "test-input",
430 })
431
432 require.NoError(t, err)
433 require.NotNil(t, result)
434 require.Len(t, result.Steps, 2)
435
436 // Check total usage sums both steps
437 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
438 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
439 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
440
441 // Final response should be from last step
442 require.Len(t, result.Response.Content, 1)
443 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
444 require.True(t, ok)
445 require.Equal(t, "Hello, world!", textContent.Text)
446
447 // result.toolCalls should be empty (from last step)
448 var toolCalls []ToolCallContent
449 for _, content := range result.Response.Content {
450 if _, ok := AsContentType[ToolCallContent](content); ok {
451 toolCalls = append(toolCalls, content.(ToolCallContent))
452 }
453 }
454 require.Len(t, toolCalls, 0)
455
456 // result.toolResults should be empty (from last step)
457 var toolResults []ToolResultContent
458 for _, content := range result.Response.Content {
459 if _, ok := AsContentType[ToolResultContent](content); ok {
460 toolResults = append(toolResults, content.(ToolResultContent))
461 }
462 }
463 require.Len(t, toolResults, 0)
464}
465
466// Test basic text generation
467func TestAgent_Generate_BasicText(t *testing.T) {
468 t.Parallel()
469
470 model := &mockLanguageModel{
471 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
472 return &Response{
473 Content: []Content{
474 TextContent{Text: "Hello, world!"},
475 },
476 Usage: Usage{
477 InputTokens: 3,
478 OutputTokens: 10,
479 TotalTokens: 13,
480 },
481 FinishReason: FinishReasonStop,
482 }, nil
483 },
484 }
485
486 agent := NewAgent(model)
487 result, err := agent.Generate(context.Background(), AgentCall{
488 Prompt: "test prompt",
489 })
490
491 require.NoError(t, err)
492 require.NotNil(t, result)
493 require.Len(t, result.Steps, 1)
494
495 // Check final response
496 require.Len(t, result.Response.Content, 1)
497 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
498 require.True(t, ok)
499 require.Equal(t, "Hello, world!", textContent.Text)
500
501 // Check usage
502 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
503 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
504 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
505
506 // Check total usage
507 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
508 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
509 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// Test empty prompt error
513func TestAgent_Generate_EmptyPrompt(t *testing.T) {
514 t.Parallel()
515
516 model := &mockLanguageModel{}
517 agent := NewAgent(model)
518
519 result, err := agent.Generate(context.Background(), AgentCall{
520 Prompt: "", // Empty prompt should cause error
521 })
522
523 require.Error(t, err)
524 require.Nil(t, result)
525 require.Contains(t, err.Error(), "Prompt can't be empty")
526}
527
528// Test with system prompt
529func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
530 t.Parallel()
531
532 model := &mockLanguageModel{
533 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
534 // Verify system message is included
535 require.Len(t, call.Prompt, 2) // system + user
536 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
537 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
538
539 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
540 require.True(t, ok)
541 require.Equal(t, "You are a helpful assistant", systemPart.Text)
542
543 return &Response{
544 Content: []Content{
545 TextContent{Text: "Hello, world!"},
546 },
547 Usage: Usage{
548 InputTokens: 3,
549 OutputTokens: 10,
550 TotalTokens: 13,
551 },
552 FinishReason: FinishReasonStop,
553 }, nil
554 },
555 }
556
557 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
558 result, err := agent.Generate(context.Background(), AgentCall{
559 Prompt: "test prompt",
560 })
561
562 require.NoError(t, err)
563 require.NotNil(t, result)
564}
565
566// Test options.headers
567func TestAgent_Generate_OptionsHeaders(t *testing.T) {
568 t.Parallel()
569
570 model := &mockLanguageModel{
571 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
572 // Verify headers are passed
573 require.Equal(t, map[string]string{
574 "custom-request-header": "request-header-value",
575 }, call.Headers)
576
577 return &Response{
578 Content: []Content{
579 TextContent{Text: "Hello, world!"},
580 },
581 Usage: Usage{
582 InputTokens: 3,
583 OutputTokens: 10,
584 TotalTokens: 13,
585 },
586 FinishReason: FinishReasonStop,
587 }, nil
588 },
589 }
590
591 agent := NewAgent(model)
592 result, err := agent.Generate(context.Background(), AgentCall{
593 Prompt: "test-input",
594 Headers: map[string]string{"custom-request-header": "request-header-value"},
595 })
596
597 require.NoError(t, err)
598 require.NotNil(t, result)
599 require.Equal(t, "Hello, world!", result.Response.Content.Text())
600}
601
602// Test options.activeTools filtering
603func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
604 t.Parallel()
605
606 tool1 := &mockTool{
607 name: "tool1",
608 description: "Test tool 1",
609 parameters: map[string]any{
610 "value": map[string]any{"type": "string"},
611 },
612 required: []string{"value"},
613 }
614
615 tool2 := &mockTool{
616 name: "tool2",
617 description: "Test tool 2",
618 parameters: map[string]any{
619 "value": map[string]any{"type": "string"},
620 },
621 required: []string{"value"},
622 }
623
624 model := &mockLanguageModel{
625 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
626 // Verify only tool1 is available
627 require.Len(t, call.Tools, 1)
628 functionTool, ok := call.Tools[0].(FunctionTool)
629 require.True(t, ok)
630 require.Equal(t, "tool1", functionTool.Name)
631
632 return &Response{
633 Content: []Content{
634 TextContent{Text: "Hello, world!"},
635 },
636 Usage: Usage{
637 InputTokens: 3,
638 OutputTokens: 10,
639 TotalTokens: 13,
640 },
641 FinishReason: FinishReasonStop,
642 }, nil
643 },
644 }
645
646 agent := NewAgent(model, WithTools(tool1, tool2))
647 result, err := agent.Generate(context.Background(), AgentCall{
648 Prompt: "test-input",
649 ActiveTools: []string{"tool1"}, // Only tool1 should be active
650 })
651
652 require.NoError(t, err)
653 require.NotNil(t, result)
654}
655
656func TestResponseContent_Getters(t *testing.T) {
657 t.Parallel()
658
659 // Create test content with all types
660 content := ResponseContent{
661 TextContent{Text: "Hello world"},
662 ReasoningContent{Text: "Let me think..."},
663 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
664 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
665 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
666 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
667 }
668
669 // Test Text()
670 require.Equal(t, "Hello world", content.Text())
671
672 // Test Reasoning()
673 reasoning := content.Reasoning()
674 require.Len(t, reasoning, 1)
675 require.Equal(t, "Let me think...", reasoning[0].Text)
676
677 // Test ReasoningText()
678 require.Equal(t, "Let me think...", content.ReasoningText())
679
680 // Test Files()
681 files := content.Files()
682 require.Len(t, files, 1)
683 require.Equal(t, "text/plain", files[0].MediaType)
684 require.Equal(t, []byte("file data"), files[0].Data)
685
686 // Test Sources()
687 sources := content.Sources()
688 require.Len(t, sources, 1)
689 require.Equal(t, SourceTypeURL, sources[0].SourceType)
690 require.Equal(t, "https://example.com", sources[0].URL)
691 require.Equal(t, "Example", sources[0].Title)
692
693 // Test ToolCalls()
694 toolCalls := content.ToolCalls()
695 require.Len(t, toolCalls, 1)
696 require.Equal(t, "call1", toolCalls[0].ToolCallID)
697 require.Equal(t, "test_tool", toolCalls[0].ToolName)
698 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
699
700 // Test ToolResults()
701 toolResults := content.ToolResults()
702 require.Len(t, toolResults, 1)
703 require.Equal(t, "call1", toolResults[0].ToolCallID)
704 require.Equal(t, "test_tool", toolResults[0].ToolName)
705 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
706 require.True(t, ok)
707 require.Equal(t, "result", result.Text)
708}
709
710func TestResponseContent_Getters_Empty(t *testing.T) {
711 t.Parallel()
712
713 // Test with empty content
714 content := ResponseContent{}
715
716 require.Equal(t, "", content.Text())
717 require.Equal(t, "", content.ReasoningText())
718 require.Empty(t, content.Reasoning())
719 require.Empty(t, content.Files())
720 require.Empty(t, content.Sources())
721 require.Empty(t, content.ToolCalls())
722 require.Empty(t, content.ToolResults())
723}
724
725func TestResponseContent_Getters_MultipleItems(t *testing.T) {
726 t.Parallel()
727
728 // Test with multiple items of same type
729 content := ResponseContent{
730 ReasoningContent{Text: "First thought"},
731 ReasoningContent{Text: "Second thought"},
732 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
733 FileContent{Data: []byte("file2"), MediaType: "image/png"},
734 }
735
736 // Test multiple reasoning
737 reasoning := content.Reasoning()
738 require.Len(t, reasoning, 2)
739 require.Equal(t, "First thought", reasoning[0].Text)
740 require.Equal(t, "Second thought", reasoning[1].Text)
741
742 // Test concatenated reasoning text
743 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
744
745 // Test multiple files
746 files := content.Files()
747 require.Len(t, files, 2)
748 require.Equal(t, "text/plain", files[0].MediaType)
749 require.Equal(t, "image/png", files[1].MediaType)
750}
751
752func TestStopConditions(t *testing.T) {
753 t.Parallel()
754
755 // Create test steps
756 step1 := StepResult{
757 Response: Response{
758 Content: ResponseContent{
759 TextContent{Text: "Hello"},
760 },
761 FinishReason: FinishReasonToolCalls,
762 Usage: Usage{TotalTokens: 10},
763 },
764 }
765
766 step2 := StepResult{
767 Response: Response{
768 Content: ResponseContent{
769 TextContent{Text: "World"},
770 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
771 },
772 FinishReason: FinishReasonStop,
773 Usage: Usage{TotalTokens: 15},
774 },
775 }
776
777 step3 := StepResult{
778 Response: Response{
779 Content: ResponseContent{
780 ReasoningContent{Text: "Let me think..."},
781 FileContent{Data: []byte("data"), MediaType: "text/plain"},
782 },
783 FinishReason: FinishReasonLength,
784 Usage: Usage{TotalTokens: 20},
785 },
786 }
787
788 t.Run("StepCountIs", func(t *testing.T) {
789 t.Parallel()
790 condition := StepCountIs(2)
791
792 // Should not stop with 1 step
793 require.False(t, condition([]StepResult{step1}))
794
795 // Should stop with 2 steps
796 require.True(t, condition([]StepResult{step1, step2}))
797
798 // Should stop with more than 2 steps
799 require.True(t, condition([]StepResult{step1, step2, step3}))
800
801 // Should not stop with empty steps
802 require.False(t, condition([]StepResult{}))
803 })
804
805 t.Run("HasToolCall", func(t *testing.T) {
806 t.Parallel()
807 condition := HasToolCall("search")
808
809 // Should not stop when tool not called
810 require.False(t, condition([]StepResult{step1}))
811
812 // Should stop when tool is called in last step
813 require.True(t, condition([]StepResult{step1, step2}))
814
815 // Should not stop when tool called in earlier step but not last
816 require.False(t, condition([]StepResult{step1, step2, step3}))
817
818 // Should not stop with empty steps
819 require.False(t, condition([]StepResult{}))
820
821 // Should not stop when different tool is called
822 differentToolCondition := HasToolCall("different_tool")
823 require.False(t, differentToolCondition([]StepResult{step1, step2}))
824 })
825
826 t.Run("HasContent", func(t *testing.T) {
827 t.Parallel()
828 reasoningCondition := HasContent(ContentTypeReasoning)
829 fileCondition := HasContent(ContentTypeFile)
830
831 // Should not stop when content type not present
832 require.False(t, reasoningCondition([]StepResult{step1, step2}))
833
834 // Should stop when content type is present in last step
835 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
836 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
837
838 // Should not stop with empty steps
839 require.False(t, reasoningCondition([]StepResult{}))
840 })
841
842 t.Run("FinishReasonIs", func(t *testing.T) {
843 t.Parallel()
844 stopCondition := FinishReasonIs(FinishReasonStop)
845 lengthCondition := FinishReasonIs(FinishReasonLength)
846
847 // Should not stop when finish reason doesn't match
848 require.False(t, stopCondition([]StepResult{step1}))
849
850 // Should stop when finish reason matches in last step
851 require.True(t, stopCondition([]StepResult{step1, step2}))
852 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
853
854 // Should not stop with empty steps
855 require.False(t, stopCondition([]StepResult{}))
856 })
857
858 t.Run("MaxTokensUsed", func(t *testing.T) {
859 condition := MaxTokensUsed(30)
860
861 // Should not stop when under limit
862 require.False(t, condition([]StepResult{step1})) // 10 tokens
863 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
864
865 // Should stop when at or over limit
866 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
867
868 // Should not stop with empty steps
869 require.False(t, condition([]StepResult{}))
870 })
871}
872
873func TestStopConditions_Integration(t *testing.T) {
874 t.Parallel()
875
876 t.Run("StepCountIs integration", func(t *testing.T) {
877 t.Parallel()
878 model := &mockLanguageModel{
879 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
880 return &Response{
881 Content: ResponseContent{
882 TextContent{Text: "Mock response"},
883 },
884 Usage: Usage{
885 InputTokens: 3,
886 OutputTokens: 10,
887 TotalTokens: 13,
888 },
889 FinishReason: FinishReasonStop,
890 }, nil
891 },
892 }
893
894 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
895
896 result, err := agent.Generate(context.Background(), AgentCall{
897 Prompt: "test prompt",
898 })
899
900 require.NoError(t, err)
901 require.NotNil(t, result)
902 require.Len(t, result.Steps, 1) // Should stop after 1 step
903 })
904
905 t.Run("Multiple stop conditions", func(t *testing.T) {
906 t.Parallel()
907 model := &mockLanguageModel{
908 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
909 return &Response{
910 Content: ResponseContent{
911 TextContent{Text: "Mock response"},
912 },
913 Usage: Usage{
914 InputTokens: 3,
915 OutputTokens: 10,
916 TotalTokens: 13,
917 },
918 FinishReason: FinishReasonStop,
919 }, nil
920 },
921 }
922
923 agent := NewAgent(model, WithStopConditions(
924 StepCountIs(5), // Stop after 5 steps
925 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
926 ))
927
928 result, err := agent.Generate(context.Background(), AgentCall{
929 Prompt: "test prompt",
930 })
931
932 require.NoError(t, err)
933 require.NotNil(t, result)
934 // Should stop on first condition met (finish reason stop)
935 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
936 })
937}
938
939func TestPrepareStep(t *testing.T) {
940 t.Parallel()
941
942 t.Run("System prompt modification", func(t *testing.T) {
943 t.Parallel()
944 var capturedSystemPrompt string
945 model := &mockLanguageModel{
946 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
947 // Capture the system message to verify it was modified
948 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
949 if len(call.Prompt[0].Content) > 0 {
950 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
951 capturedSystemPrompt = textPart.Text
952 }
953 }
954 }
955 return &Response{
956 Content: ResponseContent{
957 TextContent{Text: "Response"},
958 },
959 Usage: Usage{TotalTokens: 10},
960 FinishReason: FinishReasonStop,
961 }, nil
962 },
963 }
964
965 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
966 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
967 return PrepareStepResult{
968 Model: options.Model,
969 Messages: options.Messages,
970 System: &newSystem,
971 }
972 }
973
974 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
975
976 result, err := agent.Generate(context.Background(), AgentCall{
977 Prompt: "test prompt",
978 PrepareStep: prepareStepFunc,
979 })
980
981 require.NoError(t, err)
982 require.NotNil(t, result)
983 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
984 })
985
986 t.Run("Tool choice modification", func(t *testing.T) {
987 t.Parallel()
988 var capturedToolChoice *ToolChoice
989 model := &mockLanguageModel{
990 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
991 capturedToolChoice = call.ToolChoice
992 return &Response{
993 Content: ResponseContent{
994 TextContent{Text: "Response"},
995 },
996 Usage: Usage{TotalTokens: 10},
997 FinishReason: FinishReasonStop,
998 }, nil
999 },
1000 }
1001
1002 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1003 toolChoice := ToolChoiceNone
1004 return PrepareStepResult{
1005 Model: options.Model,
1006 Messages: options.Messages,
1007 ToolChoice: &toolChoice,
1008 }
1009 }
1010
1011 agent := NewAgent(model)
1012
1013 result, err := agent.Generate(context.Background(), AgentCall{
1014 Prompt: "test prompt",
1015 PrepareStep: prepareStepFunc,
1016 })
1017
1018 require.NoError(t, err)
1019 require.NotNil(t, result)
1020 require.NotNil(t, capturedToolChoice)
1021 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1022 })
1023
1024 t.Run("Active tools modification", func(t *testing.T) {
1025 t.Parallel()
1026 var capturedToolNames []string
1027 model := &mockLanguageModel{
1028 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1029 // Capture tool names to verify active tools were modified
1030 for _, tool := range call.Tools {
1031 capturedToolNames = append(capturedToolNames, tool.GetName())
1032 }
1033 return &Response{
1034 Content: ResponseContent{
1035 TextContent{Text: "Response"},
1036 },
1037 Usage: Usage{TotalTokens: 10},
1038 FinishReason: FinishReasonStop,
1039 }, nil
1040 },
1041 }
1042
1043 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1044 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1045 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1046
1047 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1048 activeTools := []string{"tool2"} // Only tool2 should be active
1049 return PrepareStepResult{
1050 Model: options.Model,
1051 Messages: options.Messages,
1052 ActiveTools: activeTools,
1053 }
1054 }
1055
1056 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1057
1058 result, err := agent.Generate(context.Background(), AgentCall{
1059 Prompt: "test prompt",
1060 PrepareStep: prepareStepFunc,
1061 })
1062
1063 require.NoError(t, err)
1064 require.NotNil(t, result)
1065 require.Len(t, capturedToolNames, 1)
1066 require.Equal(t, "tool2", capturedToolNames[0])
1067 })
1068
1069 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1070 t.Parallel()
1071 var capturedToolCount int
1072 model := &mockLanguageModel{
1073 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1074 capturedToolCount = len(call.Tools)
1075 return &Response{
1076 Content: ResponseContent{
1077 TextContent{Text: "Response"},
1078 },
1079 Usage: Usage{TotalTokens: 10},
1080 FinishReason: FinishReasonStop,
1081 }, nil
1082 },
1083 }
1084
1085 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1086
1087 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1088 return PrepareStepResult{
1089 Model: options.Model,
1090 Messages: options.Messages,
1091 DisableAllTools: true, // Disable all tools for this step
1092 }
1093 }
1094
1095 agent := NewAgent(model, WithTools(tool1))
1096
1097 result, err := agent.Generate(context.Background(), AgentCall{
1098 Prompt: "test prompt",
1099 PrepareStep: prepareStepFunc,
1100 })
1101
1102 require.NoError(t, err)
1103 require.NotNil(t, result)
1104 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1105 })
1106
1107 t.Run("All fields modified together", func(t *testing.T) {
1108 t.Parallel()
1109 var capturedSystemPrompt string
1110 var capturedToolChoice *ToolChoice
1111 var capturedToolNames []string
1112
1113 model := &mockLanguageModel{
1114 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1115 // Capture system prompt
1116 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1117 if len(call.Prompt[0].Content) > 0 {
1118 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1119 capturedSystemPrompt = textPart.Text
1120 }
1121 }
1122 }
1123 // Capture tool choice
1124 capturedToolChoice = call.ToolChoice
1125 // Capture tool names
1126 for _, tool := range call.Tools {
1127 capturedToolNames = append(capturedToolNames, tool.GetName())
1128 }
1129 return &Response{
1130 Content: ResponseContent{
1131 TextContent{Text: "Response"},
1132 },
1133 Usage: Usage{TotalTokens: 10},
1134 FinishReason: FinishReasonStop,
1135 }, nil
1136 },
1137 }
1138
1139 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1140 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1141
1142 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1143 newSystem := "Step-specific system"
1144 toolChoice := SpecificToolChoice("tool1")
1145 activeTools := []string{"tool1"}
1146 return PrepareStepResult{
1147 Model: options.Model,
1148 Messages: options.Messages,
1149 System: &newSystem,
1150 ToolChoice: &toolChoice,
1151 ActiveTools: activeTools,
1152 }
1153 }
1154
1155 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1156
1157 result, err := agent.Generate(context.Background(), AgentCall{
1158 Prompt: "test prompt",
1159 PrepareStep: prepareStepFunc,
1160 })
1161
1162 require.NoError(t, err)
1163 require.NotNil(t, result)
1164 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1165 require.NotNil(t, capturedToolChoice)
1166 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1167 require.Len(t, capturedToolNames, 1)
1168 require.Equal(t, "tool1", capturedToolNames[0])
1169 })
1170
1171 t.Run("Nil fields use parent values", func(t *testing.T) {
1172 t.Parallel()
1173 var capturedSystemPrompt string
1174 var capturedToolChoice *ToolChoice
1175 var capturedToolNames []string
1176
1177 model := &mockLanguageModel{
1178 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1179 // Capture system prompt
1180 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1181 if len(call.Prompt[0].Content) > 0 {
1182 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1183 capturedSystemPrompt = textPart.Text
1184 }
1185 }
1186 }
1187 // Capture tool choice
1188 capturedToolChoice = call.ToolChoice
1189 // Capture tool names
1190 for _, tool := range call.Tools {
1191 capturedToolNames = append(capturedToolNames, tool.GetName())
1192 }
1193 return &Response{
1194 Content: ResponseContent{
1195 TextContent{Text: "Response"},
1196 },
1197 Usage: Usage{TotalTokens: 10},
1198 FinishReason: FinishReasonStop,
1199 }, nil
1200 },
1201 }
1202
1203 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1204
1205 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1206 // All optional fields are nil, should use parent values
1207 return PrepareStepResult{
1208 Model: options.Model,
1209 Messages: options.Messages,
1210 System: nil, // Use parent
1211 ToolChoice: nil, // Use parent (auto)
1212 ActiveTools: nil, // Use parent (all tools)
1213 }
1214 }
1215
1216 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1217
1218 result, err := agent.Generate(context.Background(), AgentCall{
1219 Prompt: "test prompt",
1220 PrepareStep: prepareStepFunc,
1221 })
1222
1223 require.NoError(t, err)
1224 require.NotNil(t, result)
1225 require.Equal(t, "Parent system", capturedSystemPrompt)
1226 require.NotNil(t, capturedToolChoice)
1227 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1228 require.Len(t, capturedToolNames, 1)
1229 require.Equal(t, "tool1", capturedToolNames[0])
1230 })
1231
1232 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1233 t.Parallel()
1234 var capturedToolNames []string
1235 model := &mockLanguageModel{
1236 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1237 // Capture tool names to verify all tools are included
1238 for _, tool := range call.Tools {
1239 capturedToolNames = append(capturedToolNames, tool.GetName())
1240 }
1241 return &Response{
1242 Content: ResponseContent{
1243 TextContent{Text: "Response"},
1244 },
1245 Usage: Usage{TotalTokens: 10},
1246 FinishReason: FinishReasonStop,
1247 }, nil
1248 },
1249 }
1250
1251 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1252 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1253
1254 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1255 return PrepareStepResult{
1256 Model: options.Model,
1257 Messages: options.Messages,
1258 ActiveTools: []string{}, // Empty slice means all tools
1259 }
1260 }
1261
1262 agent := NewAgent(model, WithTools(tool1, tool2))
1263
1264 result, err := agent.Generate(context.Background(), AgentCall{
1265 Prompt: "test prompt",
1266 PrepareStep: prepareStepFunc,
1267 })
1268
1269 require.NoError(t, err)
1270 require.NotNil(t, result)
1271 require.Len(t, capturedToolNames, 2) // All tools should be included
1272 require.Contains(t, capturedToolNames, "tool1")
1273 require.Contains(t, capturedToolNames, "tool2")
1274 })
1275}
1276
1277func TestToolCallRepair(t *testing.T) {
1278 t.Parallel()
1279
1280 t.Run("Valid tool call passes validation", func(t *testing.T) {
1281 t.Parallel()
1282 model := &mockLanguageModel{
1283 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1284 return &Response{
1285 Content: ResponseContent{
1286 TextContent{Text: "Response"},
1287 ToolCallContent{
1288 ToolCallID: "call1",
1289 ToolName: "test_tool",
1290 Input: `{"value": "test"}`, // Valid JSON with required field
1291 },
1292 },
1293 Usage: Usage{TotalTokens: 10},
1294 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1295 }, nil
1296 },
1297 }
1298
1299 tool := &mockTool{
1300 name: "test_tool",
1301 description: "Test tool",
1302 parameters: map[string]any{
1303 "value": map[string]any{"type": "string"},
1304 },
1305 required: []string{"value"},
1306 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1307 return ToolResponse{Content: "success", IsError: false}, nil
1308 },
1309 }
1310
1311 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1312
1313 result, err := agent.Generate(context.Background(), AgentCall{
1314 Prompt: "test prompt",
1315 })
1316
1317 require.NoError(t, err)
1318 require.NotNil(t, result)
1319 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1320
1321 // Check that tool call was executed successfully
1322 toolCalls := result.Steps[0].Content.ToolCalls()
1323 require.Len(t, toolCalls, 1)
1324 require.False(t, toolCalls[0].Invalid) // Should be valid
1325 })
1326
1327 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1328 t.Parallel()
1329 model := &mockLanguageModel{
1330 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1331 return &Response{
1332 Content: ResponseContent{
1333 TextContent{Text: "Response"},
1334 ToolCallContent{
1335 ToolCallID: "call1",
1336 ToolName: "test_tool",
1337 Input: `{"wrong_field": "test"}`, // Missing required field
1338 },
1339 },
1340 Usage: Usage{TotalTokens: 10},
1341 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1342 }, nil
1343 },
1344 }
1345
1346 tool := &mockTool{
1347 name: "test_tool",
1348 description: "Test tool",
1349 parameters: map[string]any{
1350 "value": map[string]any{"type": "string"},
1351 },
1352 required: []string{"value"},
1353 }
1354
1355 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1356
1357 result, err := agent.Generate(context.Background(), AgentCall{
1358 Prompt: "test prompt",
1359 })
1360
1361 require.NoError(t, err)
1362 require.NotNil(t, result)
1363 require.Len(t, result.Steps, 1) // Only one step
1364
1365 // Check that tool call was marked as invalid
1366 toolCalls := result.Steps[0].Content.ToolCalls()
1367 require.Len(t, toolCalls, 1)
1368 require.True(t, toolCalls[0].Invalid) // Should be invalid
1369 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1370 })
1371
1372 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1373 t.Parallel()
1374 model := &mockLanguageModel{
1375 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1376 return &Response{
1377 Content: ResponseContent{
1378 TextContent{Text: "Response"},
1379 ToolCallContent{
1380 ToolCallID: "call1",
1381 ToolName: "test_tool",
1382 Input: `{"wrong_field": "test"}`, // Missing required field
1383 },
1384 },
1385 Usage: Usage{TotalTokens: 10},
1386 FinishReason: FinishReasonStop, // Changed to stop
1387 }, nil
1388 },
1389 }
1390
1391 tool := &mockTool{
1392 name: "test_tool",
1393 description: "Test tool",
1394 parameters: map[string]any{
1395 "value": map[string]any{"type": "string"},
1396 },
1397 required: []string{"value"},
1398 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1399 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1400 },
1401 }
1402
1403 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1404 // Simple repair: add the missing required field
1405 repairedToolCall := options.OriginalToolCall
1406 repairedToolCall.Input = `{"value": "repaired"}`
1407 return &repairedToolCall, nil
1408 }
1409
1410 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1411
1412 result, err := agent.Generate(context.Background(), AgentCall{
1413 Prompt: "test prompt",
1414 })
1415
1416 require.NoError(t, err)
1417 require.NotNil(t, result)
1418 require.Len(t, result.Steps, 1) // Only one step
1419
1420 // Check that tool call was repaired and is now valid
1421 toolCalls := result.Steps[0].Content.ToolCalls()
1422 require.Len(t, toolCalls, 1)
1423 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1424 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1425 })
1426
1427 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1428 t.Parallel()
1429 model := &mockLanguageModel{
1430 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1431 return &Response{
1432 Content: ResponseContent{
1433 TextContent{Text: "Response"},
1434 ToolCallContent{
1435 ToolCallID: "call1",
1436 ToolName: "test_tool",
1437 Input: `{"wrong_field": "test"}`, // Missing required field
1438 },
1439 },
1440 Usage: Usage{TotalTokens: 10},
1441 FinishReason: FinishReasonStop, // Changed to stop
1442 }, nil
1443 },
1444 }
1445
1446 tool := &mockTool{
1447 name: "test_tool",
1448 description: "Test tool",
1449 parameters: map[string]any{
1450 "value": map[string]any{"type": "string"},
1451 },
1452 required: []string{"value"},
1453 }
1454
1455 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1456 // Repair function fails
1457 return nil, errors.New("repair failed")
1458 }
1459
1460 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1461
1462 result, err := agent.Generate(context.Background(), AgentCall{
1463 Prompt: "test prompt",
1464 })
1465
1466 require.NoError(t, err)
1467 require.NotNil(t, result)
1468 require.Len(t, result.Steps, 1) // Only one step
1469
1470 // Check that tool call was marked as invalid since repair failed
1471 toolCalls := result.Steps[0].Content.ToolCalls()
1472 require.Len(t, toolCalls, 1)
1473 require.True(t, toolCalls[0].Invalid) // Should be invalid
1474 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1475 })
1476
1477 t.Run("Nonexistent tool call", func(t *testing.T) {
1478 t.Parallel()
1479 model := &mockLanguageModel{
1480 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1481 return &Response{
1482 Content: ResponseContent{
1483 TextContent{Text: "Response"},
1484 ToolCallContent{
1485 ToolCallID: "call1",
1486 ToolName: "nonexistent_tool",
1487 Input: `{"value": "test"}`,
1488 },
1489 },
1490 Usage: Usage{TotalTokens: 10},
1491 FinishReason: FinishReasonStop, // Changed to stop
1492 }, nil
1493 },
1494 }
1495
1496 tool := &mockTool{name: "test_tool", description: "Test tool"}
1497
1498 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1499
1500 result, err := agent.Generate(context.Background(), AgentCall{
1501 Prompt: "test prompt",
1502 })
1503
1504 require.NoError(t, err)
1505 require.NotNil(t, result)
1506 require.Len(t, result.Steps, 1) // Only one step
1507
1508 // Check that tool call was marked as invalid due to nonexistent tool
1509 toolCalls := result.Steps[0].Content.ToolCalls()
1510 require.Len(t, toolCalls, 1)
1511 require.True(t, toolCalls[0].Invalid) // Should be invalid
1512 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1513 })
1514
1515 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1516 t.Parallel()
1517 model := &mockLanguageModel{
1518 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1519 return &Response{
1520 Content: ResponseContent{
1521 TextContent{Text: "Response"},
1522 ToolCallContent{
1523 ToolCallID: "call1",
1524 ToolName: "test_tool",
1525 Input: `{invalid json}`, // Invalid JSON
1526 },
1527 },
1528 Usage: Usage{TotalTokens: 10},
1529 FinishReason: FinishReasonStop, // Changed to stop
1530 }, nil
1531 },
1532 }
1533
1534 tool := &mockTool{
1535 name: "test_tool",
1536 description: "Test tool",
1537 parameters: map[string]any{
1538 "value": map[string]any{"type": "string"},
1539 },
1540 required: []string{"value"},
1541 }
1542
1543 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1544
1545 result, err := agent.Generate(context.Background(), AgentCall{
1546 Prompt: "test prompt",
1547 })
1548
1549 require.NoError(t, err)
1550 require.NotNil(t, result)
1551 require.Len(t, result.Steps, 1) // Only one step
1552
1553 // Check that tool call was marked as invalid due to invalid JSON
1554 toolCalls := result.Steps[0].Content.ToolCalls()
1555 require.Len(t, toolCalls, 1)
1556 require.True(t, toolCalls[0].Invalid) // Should be invalid
1557 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1558 })
1559}