1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/stretchr/testify/require"
11)
12
13// Mock tool for testing
14type mockTool struct {
15 name string
16 description string
17 parameters map[string]any
18 required []string
19 executeFunc func(ctx context.Context, call ToolCall) (ToolResponse, error)
20}
21
22func (m *mockTool) Info() ToolInfo {
23 return ToolInfo{
24 Name: m.name,
25 Description: m.description,
26 Parameters: m.parameters,
27 Required: m.required,
28 }
29}
30
31func (m *mockTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
32 if m.executeFunc != nil {
33 return m.executeFunc(ctx, call)
34 }
35 return ToolResponse{Content: "mock result", IsError: false}, nil
36}
37
38// Mock language model for testing
39type mockLanguageModel struct {
40 generateFunc func(ctx context.Context, call Call) (*Response, error)
41 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
42}
43
44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
45 if m.generateFunc != nil {
46 return m.generateFunc(ctx, call)
47 }
48 return &Response{
49 Content: []Content{
50 TextContent{Text: "Hello, world!"},
51 },
52 Usage: Usage{
53 InputTokens: 3,
54 OutputTokens: 10,
55 TotalTokens: 13,
56 },
57 FinishReason: FinishReasonStop,
58 }, nil
59}
60
61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
62 if m.streamFunc != nil {
63 return m.streamFunc(ctx, call)
64 }
65 return nil, fmt.Errorf("mock stream not implemented")
66}
67
68func (m *mockLanguageModel) Provider() string {
69 return "mock-provider"
70}
71
72func (m *mockLanguageModel) Model() string {
73 return "mock-model"
74}
75
76// Test result.content - comprehensive content types (matches TS test)
77func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
78 t.Parallel()
79
80 // Create a type-safe tool using the new API
81 type TestInput struct {
82 Value string `json:"value" description:"Test value"`
83 }
84
85 tool1 := NewTypedToolFunc(
86 "tool1",
87 "Test tool",
88 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
89 require.Equal(t, "value", input.Value)
90 return ToolResponse{Content: "result1", IsError: false}, nil
91 },
92 )
93
94 model := &mockLanguageModel{
95 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
96 return &Response{
97 Content: []Content{
98 TextContent{Text: "Hello, world!"},
99 SourceContent{
100 ID: "123",
101 URL: "https://example.com",
102 Title: "Example",
103 SourceType: SourceTypeURL,
104 ProviderMetadata: ProviderMetadata{
105 "provider": map[string]any{"custom": "value"},
106 },
107 },
108 FileContent{
109 Data: []byte{1, 2, 3},
110 MediaType: "image/png",
111 },
112 ReasoningContent{
113 Text: "I will open the conversation with witty banter.",
114 },
115 ToolCallContent{
116 ToolCallID: "call-1",
117 ToolName: "tool1",
118 Input: `{"value":"value"}`,
119 },
120 TextContent{Text: "More text"},
121 },
122 Usage: Usage{
123 InputTokens: 3,
124 OutputTokens: 10,
125 TotalTokens: 13,
126 },
127 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
128 }, nil
129 },
130 }
131
132 agent := NewAgent(model, WithTools(tool1))
133 result, err := agent.Generate(context.Background(), AgentCall{
134 Prompt: "prompt",
135 })
136
137 require.NoError(t, err)
138 require.NotNil(t, result)
139 require.Len(t, result.Steps, 1) // Single step like TypeScript
140
141 // Check final response content includes tool result
142 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
143
144 // Verify each content type in order
145 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
146 require.True(t, ok)
147 require.Equal(t, "Hello, world!", textContent.Text)
148
149 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
150 require.True(t, ok)
151 require.Equal(t, "123", sourceContent.ID)
152
153 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
154 require.True(t, ok)
155 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
156
157 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
158 require.True(t, ok)
159 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
160
161 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
162 require.True(t, ok)
163 require.Equal(t, "call-1", toolCallContent.ToolCallID)
164
165 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
166 require.True(t, ok)
167 require.Equal(t, "More text", moreTextContent.Text)
168
169 // Tool result should be appended
170 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
171 require.True(t, ok)
172 require.Equal(t, "call-1", toolResultContent.ToolCallID)
173 require.Equal(t, "tool1", toolResultContent.ToolName)
174}
175
176// Test result.text extraction
177func TestAgent_Generate_ResultText(t *testing.T) {
178 t.Parallel()
179
180 model := &mockLanguageModel{
181 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
182 return &Response{
183 Content: []Content{
184 TextContent{Text: "Hello, world!"},
185 },
186 Usage: Usage{
187 InputTokens: 3,
188 OutputTokens: 10,
189 TotalTokens: 13,
190 },
191 FinishReason: FinishReasonStop,
192 }, nil
193 },
194 }
195
196 agent := NewAgent(model)
197 result, err := agent.Generate(context.Background(), AgentCall{
198 Prompt: "prompt",
199 })
200
201 require.NoError(t, err)
202 require.NotNil(t, result)
203
204 // Test text extraction from content
205 text := result.Response.Content.Text()
206 require.Equal(t, "Hello, world!", text)
207}
208
209// Test result.toolCalls extraction (matches TS test exactly)
210func TestAgent_Generate_ResultToolCalls(t *testing.T) {
211 t.Parallel()
212
213 // Create type-safe tools using the new API
214 type Tool1Input struct {
215 Value string `json:"value" description:"Test value"`
216 }
217
218 type Tool2Input struct {
219 SomethingElse string `json:"somethingElse" description:"Another test value"`
220 }
221
222 tool1 := NewTypedToolFunc(
223 "tool1",
224 "Test tool 1",
225 func(ctx context.Context, input Tool1Input, _ ToolCall) (ToolResponse, error) {
226 return ToolResponse{Content: "result1", IsError: false}, nil
227 },
228 )
229
230 tool2 := NewTypedToolFunc(
231 "tool2",
232 "Test tool 2",
233 func(ctx context.Context, input Tool2Input, _ ToolCall) (ToolResponse, error) {
234 return ToolResponse{Content: "result2", IsError: false}, nil
235 },
236 )
237
238 model := &mockLanguageModel{
239 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
240 // Verify tools are passed correctly
241 require.Len(t, call.Tools, 2)
242 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
243
244 // Verify prompt structure
245 require.Len(t, call.Prompt, 1)
246 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
247
248 return &Response{
249 Content: []Content{
250 ToolCallContent{
251 ToolCallID: "call-1",
252 ToolName: "tool1",
253 Input: `{"value":"value"}`,
254 },
255 },
256 Usage: Usage{
257 InputTokens: 3,
258 OutputTokens: 10,
259 TotalTokens: 13,
260 },
261 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
262 }, nil
263 },
264 }
265
266 agent := NewAgent(model, WithTools(tool1, tool2))
267 result, err := agent.Generate(context.Background(), AgentCall{
268 Prompt: "test-input",
269 })
270
271 require.NoError(t, err)
272 require.NotNil(t, result)
273 require.Len(t, result.Steps, 1) // Single step
274
275 // Extract tool calls from final response (should be empty since tools don't execute)
276 var toolCalls []ToolCallContent
277 for _, content := range result.Response.Content {
278 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
279 toolCalls = append(toolCalls, toolCall)
280 }
281 }
282
283 require.Len(t, toolCalls, 1)
284 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
285 require.Equal(t, "tool1", toolCalls[0].ToolName)
286
287 // Parse and verify input
288 var input map[string]any
289 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
290 require.NoError(t, err)
291 require.Equal(t, "value", input["value"])
292}
293
294// Test result.toolResults extraction (matches TS test exactly)
295func TestAgent_Generate_ResultToolResults(t *testing.T) {
296 t.Parallel()
297
298 // Create type-safe tool using the new API
299 type TestInput struct {
300 Value string `json:"value" description:"Test value"`
301 }
302
303 tool1 := NewTypedToolFunc(
304 "tool1",
305 "Test tool",
306 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
307 require.Equal(t, "value", input.Value)
308 return ToolResponse{Content: "result1", IsError: false}, nil
309 },
310 )
311
312 model := &mockLanguageModel{
313 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
314 // Verify tools and tool choice
315 require.Len(t, call.Tools, 1)
316 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
317
318 // Verify prompt
319 require.Len(t, call.Prompt, 1)
320 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
321
322 return &Response{
323 Content: []Content{
324 ToolCallContent{
325 ToolCallID: "call-1",
326 ToolName: "tool1",
327 Input: `{"value":"value"}`,
328 },
329 },
330 Usage: Usage{
331 InputTokens: 3,
332 OutputTokens: 10,
333 TotalTokens: 13,
334 },
335 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
336 }, nil
337 },
338 }
339
340 agent := NewAgent(model, WithTools(tool1))
341 result, err := agent.Generate(context.Background(), AgentCall{
342 Prompt: "test-input",
343 })
344
345 require.NoError(t, err)
346 require.NotNil(t, result)
347 require.Len(t, result.Steps, 1) // Single step
348
349 // Extract tool results from final response
350 var toolResults []ToolResultContent
351 for _, content := range result.Response.Content {
352 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
353 toolResults = append(toolResults, toolResult)
354 }
355 }
356
357 require.Len(t, toolResults, 1)
358 require.Equal(t, "call-1", toolResults[0].ToolCallID)
359 require.Equal(t, "tool1", toolResults[0].ToolName)
360
361 // Verify result content
362 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
363 require.True(t, ok)
364 require.Equal(t, "result1", textResult.Text)
365}
366
367// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
368func TestAgent_Generate_MultipleSteps(t *testing.T) {
369 t.Parallel()
370
371 // Create type-safe tool using the new API
372 type TestInput struct {
373 Value string `json:"value" description:"Test value"`
374 }
375
376 tool1 := NewTypedToolFunc(
377 "tool1",
378 "Test tool",
379 func(ctx context.Context, input TestInput, _ ToolCall) (ToolResponse, error) {
380 require.Equal(t, "value", input.Value)
381 return ToolResponse{Content: "result1", IsError: false}, nil
382 },
383 )
384
385 callCount := 0
386 model := &mockLanguageModel{
387 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
388 callCount++
389 switch callCount {
390 case 1:
391 // First call - return tool call with FinishReasonToolCalls
392 return &Response{
393 Content: []Content{
394 ToolCallContent{
395 ToolCallID: "call-1",
396 ToolName: "tool1",
397 Input: `{"value":"value"}`,
398 },
399 },
400 Usage: Usage{
401 InputTokens: 10,
402 OutputTokens: 5,
403 TotalTokens: 15,
404 },
405 FinishReason: FinishReasonToolCalls, // This triggers multi-step
406 }, nil
407 case 2:
408 // Second call - return final text
409 return &Response{
410 Content: []Content{
411 TextContent{Text: "Hello, world!"},
412 },
413 Usage: Usage{
414 InputTokens: 3,
415 OutputTokens: 10,
416 TotalTokens: 13,
417 },
418 FinishReason: FinishReasonStop,
419 }, nil
420 default:
421 t.Fatalf("Unexpected call count: %d", callCount)
422 return nil, nil
423 }
424 },
425 }
426
427 agent := NewAgent(model, WithTools(tool1))
428 result, err := agent.Generate(context.Background(), AgentCall{
429 Prompt: "test-input",
430 })
431
432 require.NoError(t, err)
433 require.NotNil(t, result)
434 require.Len(t, result.Steps, 2)
435
436 // Check total usage sums both steps
437 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
438 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
439 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
440
441 // Final response should be from last step
442 require.Len(t, result.Response.Content, 1)
443 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
444 require.True(t, ok)
445 require.Equal(t, "Hello, world!", textContent.Text)
446
447 // result.toolCalls should be empty (from last step)
448 var toolCalls []ToolCallContent
449 for _, content := range result.Response.Content {
450 if _, ok := AsContentType[ToolCallContent](content); ok {
451 toolCalls = append(toolCalls, content.(ToolCallContent))
452 }
453 }
454 require.Len(t, toolCalls, 0)
455
456 // result.toolResults should be empty (from last step)
457 var toolResults []ToolResultContent
458 for _, content := range result.Response.Content {
459 if _, ok := AsContentType[ToolResultContent](content); ok {
460 toolResults = append(toolResults, content.(ToolResultContent))
461 }
462 }
463 require.Len(t, toolResults, 0)
464}
465
466// Test basic text generation
467func TestAgent_Generate_BasicText(t *testing.T) {
468 t.Parallel()
469
470 model := &mockLanguageModel{
471 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
472 return &Response{
473 Content: []Content{
474 TextContent{Text: "Hello, world!"},
475 },
476 Usage: Usage{
477 InputTokens: 3,
478 OutputTokens: 10,
479 TotalTokens: 13,
480 },
481 FinishReason: FinishReasonStop,
482 }, nil
483 },
484 }
485
486 agent := NewAgent(model)
487 result, err := agent.Generate(context.Background(), AgentCall{
488 Prompt: "test prompt",
489 })
490
491 require.NoError(t, err)
492 require.NotNil(t, result)
493 require.Len(t, result.Steps, 1)
494
495 // Check final response
496 require.Len(t, result.Response.Content, 1)
497 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
498 require.True(t, ok)
499 require.Equal(t, "Hello, world!", textContent.Text)
500
501 // Check usage
502 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
503 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
504 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
505
506 // Check total usage
507 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
508 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
509 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// Test empty prompt error
513func TestAgent_Generate_EmptyPrompt(t *testing.T) {
514 t.Parallel()
515
516 model := &mockLanguageModel{}
517 agent := NewAgent(model)
518
519 result, err := agent.Generate(context.Background(), AgentCall{
520 Prompt: "", // Empty prompt should cause error
521 })
522
523 require.Error(t, err)
524 require.Nil(t, result)
525 require.Contains(t, err.Error(), "Prompt can't be empty")
526}
527
528// Test with system prompt
529func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
530 t.Parallel()
531
532 model := &mockLanguageModel{
533 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
534 // Verify system message is included
535 require.Len(t, call.Prompt, 2) // system + user
536 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
537 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
538
539 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
540 require.True(t, ok)
541 require.Equal(t, "You are a helpful assistant", systemPart.Text)
542
543 return &Response{
544 Content: []Content{
545 TextContent{Text: "Hello, world!"},
546 },
547 Usage: Usage{
548 InputTokens: 3,
549 OutputTokens: 10,
550 TotalTokens: 13,
551 },
552 FinishReason: FinishReasonStop,
553 }, nil
554 },
555 }
556
557 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
558 result, err := agent.Generate(context.Background(), AgentCall{
559 Prompt: "test prompt",
560 })
561
562 require.NoError(t, err)
563 require.NotNil(t, result)
564}
565
566// Test options.headers
567func TestAgent_Generate_OptionsHeaders(t *testing.T) {
568 t.Parallel()
569
570 model := &mockLanguageModel{
571 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
572 // Verify headers are passed
573 require.Equal(t, map[string]string{
574 "custom-request-header": "request-header-value",
575 }, call.Headers)
576
577 return &Response{
578 Content: []Content{
579 TextContent{Text: "Hello, world!"},
580 },
581 Usage: Usage{
582 InputTokens: 3,
583 OutputTokens: 10,
584 TotalTokens: 13,
585 },
586 FinishReason: FinishReasonStop,
587 }, nil
588 },
589 }
590
591 agent := NewAgent(model)
592 result, err := agent.Generate(context.Background(), AgentCall{
593 Prompt: "test-input",
594 Headers: map[string]string{"custom-request-header": "request-header-value"},
595 })
596
597 require.NoError(t, err)
598 require.NotNil(t, result)
599 require.Equal(t, "Hello, world!", result.Response.Content.Text())
600}
601
602// Test options.activeTools filtering
603func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
604 t.Parallel()
605
606 tool1 := &mockTool{
607 name: "tool1",
608 description: "Test tool 1",
609 parameters: map[string]any{
610 "value": map[string]any{"type": "string"},
611 },
612 required: []string{"value"},
613 }
614
615 tool2 := &mockTool{
616 name: "tool2",
617 description: "Test tool 2",
618 parameters: map[string]any{
619 "value": map[string]any{"type": "string"},
620 },
621 required: []string{"value"},
622 }
623
624 model := &mockLanguageModel{
625 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
626 // Verify only tool1 is available
627 require.Len(t, call.Tools, 1)
628 functionTool, ok := call.Tools[0].(FunctionTool)
629 require.True(t, ok)
630 require.Equal(t, "tool1", functionTool.Name)
631
632 return &Response{
633 Content: []Content{
634 TextContent{Text: "Hello, world!"},
635 },
636 Usage: Usage{
637 InputTokens: 3,
638 OutputTokens: 10,
639 TotalTokens: 13,
640 },
641 FinishReason: FinishReasonStop,
642 }, nil
643 },
644 }
645
646 agent := NewAgent(model, WithTools(tool1, tool2))
647 result, err := agent.Generate(context.Background(), AgentCall{
648 Prompt: "test-input",
649 ActiveTools: []string{"tool1"}, // Only tool1 should be active
650 })
651
652 require.NoError(t, err)
653 require.NotNil(t, result)
654}
655
656func TestResponseContent_Getters(t *testing.T) {
657 t.Parallel()
658
659 // Create test content with all types
660 content := ResponseContent{
661 TextContent{Text: "Hello world"},
662 ReasoningContent{Text: "Let me think..."},
663 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
664 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
665 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
666 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
667 }
668
669 // Test Text()
670 require.Equal(t, "Hello world", content.Text())
671
672 // Test Reasoning()
673 reasoning := content.Reasoning()
674 require.Len(t, reasoning, 1)
675 require.Equal(t, "Let me think...", reasoning[0].Text)
676
677 // Test ReasoningText()
678 require.Equal(t, "Let me think...", content.ReasoningText())
679
680 // Test Files()
681 files := content.Files()
682 require.Len(t, files, 1)
683 require.Equal(t, "text/plain", files[0].MediaType)
684 require.Equal(t, []byte("file data"), files[0].Data)
685
686 // Test Sources()
687 sources := content.Sources()
688 require.Len(t, sources, 1)
689 require.Equal(t, SourceTypeURL, sources[0].SourceType)
690 require.Equal(t, "https://example.com", sources[0].URL)
691 require.Equal(t, "Example", sources[0].Title)
692
693 // Test ToolCalls()
694 toolCalls := content.ToolCalls()
695 require.Len(t, toolCalls, 1)
696 require.Equal(t, "call1", toolCalls[0].ToolCallID)
697 require.Equal(t, "test_tool", toolCalls[0].ToolName)
698 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
699
700 // Test ToolResults()
701 toolResults := content.ToolResults()
702 require.Len(t, toolResults, 1)
703 require.Equal(t, "call1", toolResults[0].ToolCallID)
704 require.Equal(t, "test_tool", toolResults[0].ToolName)
705 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
706 require.True(t, ok)
707 require.Equal(t, "result", result.Text)
708}
709
710func TestResponseContent_Getters_Empty(t *testing.T) {
711 t.Parallel()
712
713 // Test with empty content
714 content := ResponseContent{}
715
716 require.Equal(t, "", content.Text())
717 require.Equal(t, "", content.ReasoningText())
718 require.Empty(t, content.Reasoning())
719 require.Empty(t, content.Files())
720 require.Empty(t, content.Sources())
721 require.Empty(t, content.ToolCalls())
722 require.Empty(t, content.ToolResults())
723}
724
725func TestResponseContent_Getters_MultipleItems(t *testing.T) {
726 t.Parallel()
727
728 // Test with multiple items of same type
729 content := ResponseContent{
730 ReasoningContent{Text: "First thought"},
731 ReasoningContent{Text: "Second thought"},
732 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
733 FileContent{Data: []byte("file2"), MediaType: "image/png"},
734 }
735
736 // Test multiple reasoning
737 reasoning := content.Reasoning()
738 require.Len(t, reasoning, 2)
739 require.Equal(t, "First thought", reasoning[0].Text)
740 require.Equal(t, "Second thought", reasoning[1].Text)
741
742 // Test concatenated reasoning text
743 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
744
745 // Test multiple files
746 files := content.Files()
747 require.Len(t, files, 2)
748 require.Equal(t, "text/plain", files[0].MediaType)
749 require.Equal(t, "image/png", files[1].MediaType)
750}
751
752func TestStopConditions(t *testing.T) {
753 t.Parallel()
754
755 // Create test steps
756 step1 := StepResult{
757 Response: Response{
758 Content: ResponseContent{
759 TextContent{Text: "Hello"},
760 },
761 FinishReason: FinishReasonToolCalls,
762 Usage: Usage{TotalTokens: 10},
763 },
764 }
765
766 step2 := StepResult{
767 Response: Response{
768 Content: ResponseContent{
769 TextContent{Text: "World"},
770 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
771 },
772 FinishReason: FinishReasonStop,
773 Usage: Usage{TotalTokens: 15},
774 },
775 }
776
777 step3 := StepResult{
778 Response: Response{
779 Content: ResponseContent{
780 ReasoningContent{Text: "Let me think..."},
781 FileContent{Data: []byte("data"), MediaType: "text/plain"},
782 },
783 FinishReason: FinishReasonLength,
784 Usage: Usage{TotalTokens: 20},
785 },
786 }
787
788 t.Run("StepCountIs", func(t *testing.T) {
789 condition := StepCountIs(2)
790
791 // Should not stop with 1 step
792 require.False(t, condition([]StepResult{step1}))
793
794 // Should stop with 2 steps
795 require.True(t, condition([]StepResult{step1, step2}))
796
797 // Should stop with more than 2 steps
798 require.True(t, condition([]StepResult{step1, step2, step3}))
799
800 // Should not stop with empty steps
801 require.False(t, condition([]StepResult{}))
802 })
803
804 t.Run("HasToolCall", func(t *testing.T) {
805 condition := HasToolCall("search")
806
807 // Should not stop when tool not called
808 require.False(t, condition([]StepResult{step1}))
809
810 // Should stop when tool is called in last step
811 require.True(t, condition([]StepResult{step1, step2}))
812
813 // Should not stop when tool called in earlier step but not last
814 require.False(t, condition([]StepResult{step1, step2, step3}))
815
816 // Should not stop with empty steps
817 require.False(t, condition([]StepResult{}))
818
819 // Should not stop when different tool is called
820 differentToolCondition := HasToolCall("different_tool")
821 require.False(t, differentToolCondition([]StepResult{step1, step2}))
822 })
823
824 t.Run("HasContent", func(t *testing.T) {
825 reasoningCondition := HasContent(ContentTypeReasoning)
826 fileCondition := HasContent(ContentTypeFile)
827
828 // Should not stop when content type not present
829 require.False(t, reasoningCondition([]StepResult{step1, step2}))
830
831 // Should stop when content type is present in last step
832 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
833 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
834
835 // Should not stop with empty steps
836 require.False(t, reasoningCondition([]StepResult{}))
837 })
838
839 t.Run("FinishReasonIs", func(t *testing.T) {
840 stopCondition := FinishReasonIs(FinishReasonStop)
841 lengthCondition := FinishReasonIs(FinishReasonLength)
842
843 // Should not stop when finish reason doesn't match
844 require.False(t, stopCondition([]StepResult{step1}))
845
846 // Should stop when finish reason matches in last step
847 require.True(t, stopCondition([]StepResult{step1, step2}))
848 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
849
850 // Should not stop with empty steps
851 require.False(t, stopCondition([]StepResult{}))
852 })
853
854 t.Run("MaxTokensUsed", func(t *testing.T) {
855 condition := MaxTokensUsed(30)
856
857 // Should not stop when under limit
858 require.False(t, condition([]StepResult{step1})) // 10 tokens
859 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
860
861 // Should stop when at or over limit
862 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
863
864 // Should not stop with empty steps
865 require.False(t, condition([]StepResult{}))
866 })
867}
868
869func TestStopConditions_Integration(t *testing.T) {
870 t.Parallel()
871
872 t.Run("StepCountIs integration", func(t *testing.T) {
873 model := &mockLanguageModel{
874 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
875 return &Response{
876 Content: ResponseContent{
877 TextContent{Text: "Mock response"},
878 },
879 Usage: Usage{
880 InputTokens: 3,
881 OutputTokens: 10,
882 TotalTokens: 13,
883 },
884 FinishReason: FinishReasonStop,
885 }, nil
886 },
887 }
888
889 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
890
891 result, err := agent.Generate(context.Background(), AgentCall{
892 Prompt: "test prompt",
893 })
894
895 require.NoError(t, err)
896 require.NotNil(t, result)
897 require.Len(t, result.Steps, 1) // Should stop after 1 step
898 })
899
900 t.Run("Multiple stop conditions", func(t *testing.T) {
901 model := &mockLanguageModel{
902 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
903 return &Response{
904 Content: ResponseContent{
905 TextContent{Text: "Mock response"},
906 },
907 Usage: Usage{
908 InputTokens: 3,
909 OutputTokens: 10,
910 TotalTokens: 13,
911 },
912 FinishReason: FinishReasonStop,
913 }, nil
914 },
915 }
916
917 agent := NewAgent(model, WithStopConditions(
918 StepCountIs(5), // Stop after 5 steps
919 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
920 ))
921
922 result, err := agent.Generate(context.Background(), AgentCall{
923 Prompt: "test prompt",
924 })
925
926 require.NoError(t, err)
927 require.NotNil(t, result)
928 // Should stop on first condition met (finish reason stop)
929 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
930 })
931}
932
933func TestPrepareStep(t *testing.T) {
934 t.Parallel()
935
936 t.Run("System prompt modification", func(t *testing.T) {
937 var capturedSystemPrompt string
938 model := &mockLanguageModel{
939 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
940 // Capture the system message to verify it was modified
941 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
942 if len(call.Prompt[0].Content) > 0 {
943 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
944 capturedSystemPrompt = textPart.Text
945 }
946 }
947 }
948 return &Response{
949 Content: ResponseContent{
950 TextContent{Text: "Response"},
951 },
952 Usage: Usage{TotalTokens: 10},
953 FinishReason: FinishReasonStop,
954 }, nil
955 },
956 }
957
958 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
959 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
960 return PrepareStepResult{
961 Model: options.Model,
962 Messages: options.Messages,
963 System: &newSystem,
964 }
965 }
966
967 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
968
969 result, err := agent.Generate(context.Background(), AgentCall{
970 Prompt: "test prompt",
971 PrepareStep: prepareStepFunc,
972 })
973
974 require.NoError(t, err)
975 require.NotNil(t, result)
976 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
977 })
978
979 t.Run("Tool choice modification", func(t *testing.T) {
980 var capturedToolChoice *ToolChoice
981 model := &mockLanguageModel{
982 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
983 capturedToolChoice = call.ToolChoice
984 return &Response{
985 Content: ResponseContent{
986 TextContent{Text: "Response"},
987 },
988 Usage: Usage{TotalTokens: 10},
989 FinishReason: FinishReasonStop,
990 }, nil
991 },
992 }
993
994 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
995 toolChoice := ToolChoiceNone
996 return PrepareStepResult{
997 Model: options.Model,
998 Messages: options.Messages,
999 ToolChoice: &toolChoice,
1000 }
1001 }
1002
1003 agent := NewAgent(model)
1004
1005 result, err := agent.Generate(context.Background(), AgentCall{
1006 Prompt: "test prompt",
1007 PrepareStep: prepareStepFunc,
1008 })
1009
1010 require.NoError(t, err)
1011 require.NotNil(t, result)
1012 require.NotNil(t, capturedToolChoice)
1013 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1014 })
1015
1016 t.Run("Active tools modification", func(t *testing.T) {
1017 var capturedToolNames []string
1018 model := &mockLanguageModel{
1019 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1020 // Capture tool names to verify active tools were modified
1021 for _, tool := range call.Tools {
1022 capturedToolNames = append(capturedToolNames, tool.GetName())
1023 }
1024 return &Response{
1025 Content: ResponseContent{
1026 TextContent{Text: "Response"},
1027 },
1028 Usage: Usage{TotalTokens: 10},
1029 FinishReason: FinishReasonStop,
1030 }, nil
1031 },
1032 }
1033
1034 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1035 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1036 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1037
1038 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1039 activeTools := []string{"tool2"} // Only tool2 should be active
1040 return PrepareStepResult{
1041 Model: options.Model,
1042 Messages: options.Messages,
1043 ActiveTools: activeTools,
1044 }
1045 }
1046
1047 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1048
1049 result, err := agent.Generate(context.Background(), AgentCall{
1050 Prompt: "test prompt",
1051 PrepareStep: prepareStepFunc,
1052 })
1053
1054 require.NoError(t, err)
1055 require.NotNil(t, result)
1056 require.Len(t, capturedToolNames, 1)
1057 require.Equal(t, "tool2", capturedToolNames[0])
1058 })
1059
1060 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1061 var capturedToolCount int
1062 model := &mockLanguageModel{
1063 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1064 capturedToolCount = len(call.Tools)
1065 return &Response{
1066 Content: ResponseContent{
1067 TextContent{Text: "Response"},
1068 },
1069 Usage: Usage{TotalTokens: 10},
1070 FinishReason: FinishReasonStop,
1071 }, nil
1072 },
1073 }
1074
1075 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1076
1077 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1078 return PrepareStepResult{
1079 Model: options.Model,
1080 Messages: options.Messages,
1081 DisableAllTools: true, // Disable all tools for this step
1082 }
1083 }
1084
1085 agent := NewAgent(model, WithTools(tool1))
1086
1087 result, err := agent.Generate(context.Background(), AgentCall{
1088 Prompt: "test prompt",
1089 PrepareStep: prepareStepFunc,
1090 })
1091
1092 require.NoError(t, err)
1093 require.NotNil(t, result)
1094 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1095 })
1096
1097 t.Run("All fields modified together", func(t *testing.T) {
1098 var capturedSystemPrompt string
1099 var capturedToolChoice *ToolChoice
1100 var capturedToolNames []string
1101
1102 model := &mockLanguageModel{
1103 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1104 // Capture system prompt
1105 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1106 if len(call.Prompt[0].Content) > 0 {
1107 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1108 capturedSystemPrompt = textPart.Text
1109 }
1110 }
1111 }
1112 // Capture tool choice
1113 capturedToolChoice = call.ToolChoice
1114 // Capture tool names
1115 for _, tool := range call.Tools {
1116 capturedToolNames = append(capturedToolNames, tool.GetName())
1117 }
1118 return &Response{
1119 Content: ResponseContent{
1120 TextContent{Text: "Response"},
1121 },
1122 Usage: Usage{TotalTokens: 10},
1123 FinishReason: FinishReasonStop,
1124 }, nil
1125 },
1126 }
1127
1128 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1129 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1130
1131 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1132 newSystem := "Step-specific system"
1133 toolChoice := SpecificToolChoice("tool1")
1134 activeTools := []string{"tool1"}
1135 return PrepareStepResult{
1136 Model: options.Model,
1137 Messages: options.Messages,
1138 System: &newSystem,
1139 ToolChoice: &toolChoice,
1140 ActiveTools: activeTools,
1141 }
1142 }
1143
1144 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1145
1146 result, err := agent.Generate(context.Background(), AgentCall{
1147 Prompt: "test prompt",
1148 PrepareStep: prepareStepFunc,
1149 })
1150
1151 require.NoError(t, err)
1152 require.NotNil(t, result)
1153 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1154 require.NotNil(t, capturedToolChoice)
1155 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1156 require.Len(t, capturedToolNames, 1)
1157 require.Equal(t, "tool1", capturedToolNames[0])
1158 })
1159
1160 t.Run("Nil fields use parent values", func(t *testing.T) {
1161 var capturedSystemPrompt string
1162 var capturedToolChoice *ToolChoice
1163 var capturedToolNames []string
1164
1165 model := &mockLanguageModel{
1166 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1167 // Capture system prompt
1168 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1169 if len(call.Prompt[0].Content) > 0 {
1170 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1171 capturedSystemPrompt = textPart.Text
1172 }
1173 }
1174 }
1175 // Capture tool choice
1176 capturedToolChoice = call.ToolChoice
1177 // Capture tool names
1178 for _, tool := range call.Tools {
1179 capturedToolNames = append(capturedToolNames, tool.GetName())
1180 }
1181 return &Response{
1182 Content: ResponseContent{
1183 TextContent{Text: "Response"},
1184 },
1185 Usage: Usage{TotalTokens: 10},
1186 FinishReason: FinishReasonStop,
1187 }, nil
1188 },
1189 }
1190
1191 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1192
1193 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1194 // All optional fields are nil, should use parent values
1195 return PrepareStepResult{
1196 Model: options.Model,
1197 Messages: options.Messages,
1198 System: nil, // Use parent
1199 ToolChoice: nil, // Use parent (auto)
1200 ActiveTools: nil, // Use parent (all tools)
1201 }
1202 }
1203
1204 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1205
1206 result, err := agent.Generate(context.Background(), AgentCall{
1207 Prompt: "test prompt",
1208 PrepareStep: prepareStepFunc,
1209 })
1210
1211 require.NoError(t, err)
1212 require.NotNil(t, result)
1213 require.Equal(t, "Parent system", capturedSystemPrompt)
1214 require.NotNil(t, capturedToolChoice)
1215 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1216 require.Len(t, capturedToolNames, 1)
1217 require.Equal(t, "tool1", capturedToolNames[0])
1218 })
1219
1220 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1221 var capturedToolNames []string
1222 model := &mockLanguageModel{
1223 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1224 // Capture tool names to verify all tools are included
1225 for _, tool := range call.Tools {
1226 capturedToolNames = append(capturedToolNames, tool.GetName())
1227 }
1228 return &Response{
1229 Content: ResponseContent{
1230 TextContent{Text: "Response"},
1231 },
1232 Usage: Usage{TotalTokens: 10},
1233 FinishReason: FinishReasonStop,
1234 }, nil
1235 },
1236 }
1237
1238 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1239 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1240
1241 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1242 return PrepareStepResult{
1243 Model: options.Model,
1244 Messages: options.Messages,
1245 ActiveTools: []string{}, // Empty slice means all tools
1246 }
1247 }
1248
1249 agent := NewAgent(model, WithTools(tool1, tool2))
1250
1251 result, err := agent.Generate(context.Background(), AgentCall{
1252 Prompt: "test prompt",
1253 PrepareStep: prepareStepFunc,
1254 })
1255
1256 require.NoError(t, err)
1257 require.NotNil(t, result)
1258 require.Len(t, capturedToolNames, 2) // All tools should be included
1259 require.Contains(t, capturedToolNames, "tool1")
1260 require.Contains(t, capturedToolNames, "tool2")
1261 })
1262}
1263
1264func TestToolCallRepair(t *testing.T) {
1265 t.Parallel()
1266
1267 t.Run("Valid tool call passes validation", func(t *testing.T) {
1268 model := &mockLanguageModel{
1269 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1270 return &Response{
1271 Content: ResponseContent{
1272 TextContent{Text: "Response"},
1273 ToolCallContent{
1274 ToolCallID: "call1",
1275 ToolName: "test_tool",
1276 Input: `{"value": "test"}`, // Valid JSON with required field
1277 },
1278 },
1279 Usage: Usage{TotalTokens: 10},
1280 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1281 }, nil
1282 },
1283 }
1284
1285 tool := &mockTool{
1286 name: "test_tool",
1287 description: "Test tool",
1288 parameters: map[string]any{
1289 "value": map[string]any{"type": "string"},
1290 },
1291 required: []string{"value"},
1292 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1293 return ToolResponse{Content: "success", IsError: false}, nil
1294 },
1295 }
1296
1297 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1298
1299 result, err := agent.Generate(context.Background(), AgentCall{
1300 Prompt: "test prompt",
1301 })
1302
1303 require.NoError(t, err)
1304 require.NotNil(t, result)
1305 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1306
1307 // Check that tool call was executed successfully
1308 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1309 require.Len(t, toolCalls, 1)
1310 require.False(t, toolCalls[0].Invalid) // Should be valid
1311 })
1312
1313 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1314 model := &mockLanguageModel{
1315 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1316 return &Response{
1317 Content: ResponseContent{
1318 TextContent{Text: "Response"},
1319 ToolCallContent{
1320 ToolCallID: "call1",
1321 ToolName: "test_tool",
1322 Input: `{"wrong_field": "test"}`, // Missing required field
1323 },
1324 },
1325 Usage: Usage{TotalTokens: 10},
1326 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1327 }, nil
1328 },
1329 }
1330
1331 tool := &mockTool{
1332 name: "test_tool",
1333 description: "Test tool",
1334 parameters: map[string]any{
1335 "value": map[string]any{"type": "string"},
1336 },
1337 required: []string{"value"},
1338 }
1339
1340 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1341
1342 result, err := agent.Generate(context.Background(), AgentCall{
1343 Prompt: "test prompt",
1344 })
1345
1346 require.NoError(t, err)
1347 require.NotNil(t, result)
1348 require.Len(t, result.Steps, 1) // Only one step
1349
1350 // Check that tool call was marked as invalid
1351 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1352 require.Len(t, toolCalls, 1)
1353 require.True(t, toolCalls[0].Invalid) // Should be invalid
1354 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1355 })
1356
1357 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1358 model := &mockLanguageModel{
1359 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1360 return &Response{
1361 Content: ResponseContent{
1362 TextContent{Text: "Response"},
1363 ToolCallContent{
1364 ToolCallID: "call1",
1365 ToolName: "test_tool",
1366 Input: `{"wrong_field": "test"}`, // Missing required field
1367 },
1368 },
1369 Usage: Usage{TotalTokens: 10},
1370 FinishReason: FinishReasonStop, // Changed to stop
1371 }, nil
1372 },
1373 }
1374
1375 tool := &mockTool{
1376 name: "test_tool",
1377 description: "Test tool",
1378 parameters: map[string]any{
1379 "value": map[string]any{"type": "string"},
1380 },
1381 required: []string{"value"},
1382 executeFunc: func(ctx context.Context, call ToolCall) (ToolResponse, error) {
1383 return ToolResponse{Content: "repaired_success", IsError: false}, nil
1384 },
1385 }
1386
1387 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1388 // Simple repair: add the missing required field
1389 repairedToolCall := options.OriginalToolCall
1390 repairedToolCall.Input = `{"value": "repaired"}`
1391 return &repairedToolCall, nil
1392 }
1393
1394 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1395
1396 result, err := agent.Generate(context.Background(), AgentCall{
1397 Prompt: "test prompt",
1398 })
1399
1400 require.NoError(t, err)
1401 require.NotNil(t, result)
1402 require.Len(t, result.Steps, 1) // Only one step
1403
1404 // Check that tool call was repaired and is now valid
1405 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1406 require.Len(t, toolCalls, 1)
1407 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1408 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1409 })
1410
1411 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1412 model := &mockLanguageModel{
1413 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1414 return &Response{
1415 Content: ResponseContent{
1416 TextContent{Text: "Response"},
1417 ToolCallContent{
1418 ToolCallID: "call1",
1419 ToolName: "test_tool",
1420 Input: `{"wrong_field": "test"}`, // Missing required field
1421 },
1422 },
1423 Usage: Usage{TotalTokens: 10},
1424 FinishReason: FinishReasonStop, // Changed to stop
1425 }, nil
1426 },
1427 }
1428
1429 tool := &mockTool{
1430 name: "test_tool",
1431 description: "Test tool",
1432 parameters: map[string]any{
1433 "value": map[string]any{"type": "string"},
1434 },
1435 required: []string{"value"},
1436 }
1437
1438 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1439 // Repair function fails
1440 return nil, errors.New("repair failed")
1441 }
1442
1443 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1444
1445 result, err := agent.Generate(context.Background(), AgentCall{
1446 Prompt: "test prompt",
1447 })
1448
1449 require.NoError(t, err)
1450 require.NotNil(t, result)
1451 require.Len(t, result.Steps, 1) // Only one step
1452
1453 // Check that tool call was marked as invalid since repair failed
1454 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1455 require.Len(t, toolCalls, 1)
1456 require.True(t, toolCalls[0].Invalid) // Should be invalid
1457 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1458 })
1459
1460 t.Run("Nonexistent tool call", func(t *testing.T) {
1461 model := &mockLanguageModel{
1462 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1463 return &Response{
1464 Content: ResponseContent{
1465 TextContent{Text: "Response"},
1466 ToolCallContent{
1467 ToolCallID: "call1",
1468 ToolName: "nonexistent_tool",
1469 Input: `{"value": "test"}`,
1470 },
1471 },
1472 Usage: Usage{TotalTokens: 10},
1473 FinishReason: FinishReasonStop, // Changed to stop
1474 }, nil
1475 },
1476 }
1477
1478 tool := &mockTool{name: "test_tool", description: "Test tool"}
1479
1480 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1481
1482 result, err := agent.Generate(context.Background(), AgentCall{
1483 Prompt: "test prompt",
1484 })
1485
1486 require.NoError(t, err)
1487 require.NotNil(t, result)
1488 require.Len(t, result.Steps, 1) // Only one step
1489
1490 // Check that tool call was marked as invalid due to nonexistent tool
1491 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1492 require.Len(t, toolCalls, 1)
1493 require.True(t, toolCalls[0].Invalid) // Should be invalid
1494 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1495 })
1496
1497 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1498 model := &mockLanguageModel{
1499 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1500 return &Response{
1501 Content: ResponseContent{
1502 TextContent{Text: "Response"},
1503 ToolCallContent{
1504 ToolCallID: "call1",
1505 ToolName: "test_tool",
1506 Input: `{invalid json}`, // Invalid JSON
1507 },
1508 },
1509 Usage: Usage{TotalTokens: 10},
1510 FinishReason: FinishReasonStop, // Changed to stop
1511 }, nil
1512 },
1513 }
1514
1515 tool := &mockTool{
1516 name: "test_tool",
1517 description: "Test tool",
1518 parameters: map[string]any{
1519 "value": map[string]any{"type": "string"},
1520 },
1521 required: []string{"value"},
1522 }
1523
1524 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1525
1526 result, err := agent.Generate(context.Background(), AgentCall{
1527 Prompt: "test prompt",
1528 })
1529
1530 require.NoError(t, err)
1531 require.NotNil(t, result)
1532 require.Len(t, result.Steps, 1) // Only one step
1533
1534 // Check that tool call was marked as invalid due to invalid JSON
1535 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1536 require.Len(t, toolCalls, 1)
1537 require.True(t, toolCalls[0].Invalid) // Should be invalid
1538 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1539 })
1540}