1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/stretchr/testify/require"
12)
13
14// Mock tool for testing
15type mockTool struct {
16 name string
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
21}
22
23func (m *mockTool) Info() tools.ToolInfo {
24 return tools.ToolInfo{
25 Name: m.name,
26 Description: m.description,
27 Parameters: m.parameters,
28 Required: m.required,
29 }
30}
31
32func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
33 if m.executeFunc != nil {
34 return m.executeFunc(ctx, call)
35 }
36 return tools.ToolResponse{Content: "mock result", IsError: false}, nil
37}
38
39// Mock language model for testing
40type mockLanguageModel struct {
41 generateFunc func(ctx context.Context, call Call) (*Response, error)
42 streamFunc func(ctx context.Context, call Call) (StreamResponse, error)
43}
44
45func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
46 if m.generateFunc != nil {
47 return m.generateFunc(ctx, call)
48 }
49 return &Response{
50 Content: []Content{
51 TextContent{Text: "Hello, world!"},
52 },
53 Usage: Usage{
54 InputTokens: 3,
55 OutputTokens: 10,
56 TotalTokens: 13,
57 },
58 FinishReason: FinishReasonStop,
59 }, nil
60}
61
62func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
63 if m.streamFunc != nil {
64 return m.streamFunc(ctx, call)
65 }
66 return nil, fmt.Errorf("mock stream not implemented")
67}
68
69func (m *mockLanguageModel) Provider() string {
70 return "mock-provider"
71}
72
73func (m *mockLanguageModel) Model() string {
74 return "mock-model"
75}
76
77// Test result.content - comprehensive content types (matches TS test)
78func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
79 t.Parallel()
80
81 tool1 := &mockTool{
82 name: "tool1",
83 description: "Test tool",
84 parameters: map[string]any{
85 "value": map[string]any{"type": "string"},
86 },
87 required: []string{"value"},
88 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
89 var input map[string]any
90 err := json.Unmarshal([]byte(call.Input), &input)
91 require.NoError(t, err)
92 require.Equal(t, "value", input["value"])
93 return tools.ToolResponse{Content: "result1", IsError: false}, nil
94 },
95 }
96
97 model := &mockLanguageModel{
98 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
99 return &Response{
100 Content: []Content{
101 TextContent{Text: "Hello, world!"},
102 SourceContent{
103 ID: "123",
104 URL: "https://example.com",
105 Title: "Example",
106 SourceType: SourceTypeURL,
107 ProviderMetadata: ProviderMetadata{
108 "provider": map[string]any{"custom": "value"},
109 },
110 },
111 FileContent{
112 Data: []byte{1, 2, 3},
113 MediaType: "image/png",
114 },
115 ReasoningContent{
116 Text: "I will open the conversation with witty banter.",
117 },
118 ToolCallContent{
119 ToolCallID: "call-1",
120 ToolName: "tool1",
121 Input: `{"value":"value"}`,
122 },
123 TextContent{Text: "More text"},
124 },
125 Usage: Usage{
126 InputTokens: 3,
127 OutputTokens: 10,
128 TotalTokens: 13,
129 },
130 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
131 }, nil
132 },
133 }
134
135 agent := NewAgent(model, WithTools(tool1))
136 result, err := agent.Generate(context.Background(), AgentCall{
137 Prompt: "prompt",
138 })
139
140 require.NoError(t, err)
141 require.NotNil(t, result)
142 require.Len(t, result.Steps, 1) // Single step like TypeScript
143
144 // Check final response content includes tool result
145 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
146
147 // Verify each content type in order
148 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
149 require.True(t, ok)
150 require.Equal(t, "Hello, world!", textContent.Text)
151
152 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
153 require.True(t, ok)
154 require.Equal(t, "123", sourceContent.ID)
155
156 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
157 require.True(t, ok)
158 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
159
160 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
161 require.True(t, ok)
162 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
163
164 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
165 require.True(t, ok)
166 require.Equal(t, "call-1", toolCallContent.ToolCallID)
167
168 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
169 require.True(t, ok)
170 require.Equal(t, "More text", moreTextContent.Text)
171
172 // Tool result should be appended
173 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
174 require.True(t, ok)
175 require.Equal(t, "call-1", toolResultContent.ToolCallID)
176 require.Equal(t, "tool1", toolResultContent.ToolName)
177}
178
179// Test result.text extraction
180func TestAgent_Generate_ResultText(t *testing.T) {
181 t.Parallel()
182
183 model := &mockLanguageModel{
184 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
185 return &Response{
186 Content: []Content{
187 TextContent{Text: "Hello, world!"},
188 },
189 Usage: Usage{
190 InputTokens: 3,
191 OutputTokens: 10,
192 TotalTokens: 13,
193 },
194 FinishReason: FinishReasonStop,
195 }, nil
196 },
197 }
198
199 agent := NewAgent(model)
200 result, err := agent.Generate(context.Background(), AgentCall{
201 Prompt: "prompt",
202 })
203
204 require.NoError(t, err)
205 require.NotNil(t, result)
206
207 // Test text extraction from content
208 text := result.Response.Content.Text()
209 require.Equal(t, "Hello, world!", text)
210}
211
212// Test result.toolCalls extraction (matches TS test exactly)
213func TestAgent_Generate_ResultToolCalls(t *testing.T) {
214 t.Parallel()
215
216 tool1 := &mockTool{
217 name: "tool1",
218 description: "Test tool 1",
219 parameters: map[string]any{
220 "value": map[string]any{"type": "string"},
221 },
222 required: []string{"value"},
223 }
224
225 tool2 := &mockTool{
226 name: "tool2",
227 description: "Test tool 2",
228 parameters: map[string]any{
229 "somethingElse": map[string]any{"type": "string"},
230 },
231 required: []string{"somethingElse"},
232 }
233
234 model := &mockLanguageModel{
235 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
236 // Verify tools are passed correctly
237 require.Len(t, call.Tools, 2)
238 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
239
240 // Verify prompt structure
241 require.Len(t, call.Prompt, 1)
242 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
243
244 return &Response{
245 Content: []Content{
246 ToolCallContent{
247 ToolCallID: "call-1",
248 ToolName: "tool1",
249 Input: `{"value":"value"}`,
250 },
251 },
252 Usage: Usage{
253 InputTokens: 3,
254 OutputTokens: 10,
255 TotalTokens: 13,
256 },
257 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
258 }, nil
259 },
260 }
261
262 agent := NewAgent(model, WithTools(tool1, tool2))
263 result, err := agent.Generate(context.Background(), AgentCall{
264 Prompt: "test-input",
265 })
266
267 require.NoError(t, err)
268 require.NotNil(t, result)
269 require.Len(t, result.Steps, 1) // Single step
270
271 // Extract tool calls from final response (should be empty since tools don't execute)
272 var toolCalls []ToolCallContent
273 for _, content := range result.Response.Content {
274 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
275 toolCalls = append(toolCalls, toolCall)
276 }
277 }
278
279 require.Len(t, toolCalls, 1)
280 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
281 require.Equal(t, "tool1", toolCalls[0].ToolName)
282
283 // Parse and verify input
284 var input map[string]any
285 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
286 require.NoError(t, err)
287 require.Equal(t, "value", input["value"])
288}
289
290// Test result.toolResults extraction (matches TS test exactly)
291func TestAgent_Generate_ResultToolResults(t *testing.T) {
292 t.Parallel()
293
294 tool1 := &mockTool{
295 name: "tool1",
296 description: "Test tool",
297 parameters: map[string]any{
298 "value": map[string]any{"type": "string"},
299 },
300 required: []string{"value"},
301 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
302 var input map[string]any
303 err := json.Unmarshal([]byte(call.Input), &input)
304 require.NoError(t, err)
305 require.Equal(t, "value", input["value"])
306 return tools.ToolResponse{Content: "result1", IsError: false}, nil
307 },
308 }
309
310 model := &mockLanguageModel{
311 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
312 // Verify tools and tool choice
313 require.Len(t, call.Tools, 1)
314 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
315
316 // Verify prompt
317 require.Len(t, call.Prompt, 1)
318 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
319
320 return &Response{
321 Content: []Content{
322 ToolCallContent{
323 ToolCallID: "call-1",
324 ToolName: "tool1",
325 Input: `{"value":"value"}`,
326 },
327 },
328 Usage: Usage{
329 InputTokens: 3,
330 OutputTokens: 10,
331 TotalTokens: 13,
332 },
333 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
334 }, nil
335 },
336 }
337
338 agent := NewAgent(model, WithTools(tool1))
339 result, err := agent.Generate(context.Background(), AgentCall{
340 Prompt: "test-input",
341 })
342
343 require.NoError(t, err)
344 require.NotNil(t, result)
345 require.Len(t, result.Steps, 1) // Single step
346
347 // Extract tool results from final response
348 var toolResults []ToolResultContent
349 for _, content := range result.Response.Content {
350 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
351 toolResults = append(toolResults, toolResult)
352 }
353 }
354
355 require.Len(t, toolResults, 1)
356 require.Equal(t, "call-1", toolResults[0].ToolCallID)
357 require.Equal(t, "tool1", toolResults[0].ToolName)
358
359 // Verify result content
360 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
361 require.True(t, ok)
362 require.Equal(t, "result1", textResult.Text)
363}
364
365// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
366func TestAgent_Generate_MultipleSteps(t *testing.T) {
367 t.Parallel()
368
369 tool1 := &mockTool{
370 name: "tool1",
371 description: "Test tool",
372 parameters: map[string]any{
373 "value": map[string]any{"type": "string"},
374 },
375 required: []string{"value"},
376 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
377 var input map[string]any
378 err := json.Unmarshal([]byte(call.Input), &input)
379 require.NoError(t, err)
380 require.Equal(t, "value", input["value"])
381 return tools.ToolResponse{Content: "result1", IsError: false}, nil
382 },
383 }
384
385 callCount := 0
386 model := &mockLanguageModel{
387 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
388 callCount++
389 switch callCount {
390 case 1:
391 // First call - return tool call with FinishReasonToolCalls
392 return &Response{
393 Content: []Content{
394 ToolCallContent{
395 ToolCallID: "call-1",
396 ToolName: "tool1",
397 Input: `{"value":"value"}`,
398 },
399 },
400 Usage: Usage{
401 InputTokens: 10,
402 OutputTokens: 5,
403 TotalTokens: 15,
404 },
405 FinishReason: FinishReasonToolCalls, // This triggers multi-step
406 }, nil
407 case 2:
408 // Second call - return final text
409 return &Response{
410 Content: []Content{
411 TextContent{Text: "Hello, world!"},
412 },
413 Usage: Usage{
414 InputTokens: 3,
415 OutputTokens: 10,
416 TotalTokens: 13,
417 },
418 FinishReason: FinishReasonStop,
419 }, nil
420 default:
421 t.Fatalf("Unexpected call count: %d", callCount)
422 return nil, nil
423 }
424 },
425 }
426
427 agent := NewAgent(model, WithTools(tool1))
428 result, err := agent.Generate(context.Background(), AgentCall{
429 Prompt: "test-input",
430 })
431
432 require.NoError(t, err)
433 require.NotNil(t, result)
434 require.Len(t, result.Steps, 2)
435
436 // Check total usage sums both steps
437 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
438 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
439 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
440
441 // Final response should be from last step
442 require.Len(t, result.Response.Content, 1)
443 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
444 require.True(t, ok)
445 require.Equal(t, "Hello, world!", textContent.Text)
446
447 // result.toolCalls should be empty (from last step)
448 var toolCalls []ToolCallContent
449 for _, content := range result.Response.Content {
450 if _, ok := AsContentType[ToolCallContent](content); ok {
451 toolCalls = append(toolCalls, content.(ToolCallContent))
452 }
453 }
454 require.Len(t, toolCalls, 0)
455
456 // result.toolResults should be empty (from last step)
457 var toolResults []ToolResultContent
458 for _, content := range result.Response.Content {
459 if _, ok := AsContentType[ToolResultContent](content); ok {
460 toolResults = append(toolResults, content.(ToolResultContent))
461 }
462 }
463 require.Len(t, toolResults, 0)
464}
465
466// Test basic text generation
467func TestAgent_Generate_BasicText(t *testing.T) {
468 t.Parallel()
469
470 model := &mockLanguageModel{
471 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
472 return &Response{
473 Content: []Content{
474 TextContent{Text: "Hello, world!"},
475 },
476 Usage: Usage{
477 InputTokens: 3,
478 OutputTokens: 10,
479 TotalTokens: 13,
480 },
481 FinishReason: FinishReasonStop,
482 }, nil
483 },
484 }
485
486 agent := NewAgent(model)
487 result, err := agent.Generate(context.Background(), AgentCall{
488 Prompt: "test prompt",
489 })
490
491 require.NoError(t, err)
492 require.NotNil(t, result)
493 require.Len(t, result.Steps, 1)
494
495 // Check final response
496 require.Len(t, result.Response.Content, 1)
497 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
498 require.True(t, ok)
499 require.Equal(t, "Hello, world!", textContent.Text)
500
501 // Check usage
502 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
503 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
504 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
505
506 // Check total usage
507 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
508 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
509 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
510}
511
512// Test empty prompt error
513func TestAgent_Generate_EmptyPrompt(t *testing.T) {
514 t.Parallel()
515
516 model := &mockLanguageModel{}
517 agent := NewAgent(model)
518
519 result, err := agent.Generate(context.Background(), AgentCall{
520 Prompt: "", // Empty prompt should cause error
521 })
522
523 require.Error(t, err)
524 require.Nil(t, result)
525 require.Contains(t, err.Error(), "Prompt can't be empty")
526}
527
528// Test with system prompt
529func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
530 t.Parallel()
531
532 model := &mockLanguageModel{
533 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
534 // Verify system message is included
535 require.Len(t, call.Prompt, 2) // system + user
536 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
537 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
538
539 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
540 require.True(t, ok)
541 require.Equal(t, "You are a helpful assistant", systemPart.Text)
542
543 return &Response{
544 Content: []Content{
545 TextContent{Text: "Hello, world!"},
546 },
547 Usage: Usage{
548 InputTokens: 3,
549 OutputTokens: 10,
550 TotalTokens: 13,
551 },
552 FinishReason: FinishReasonStop,
553 }, nil
554 },
555 }
556
557 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
558 result, err := agent.Generate(context.Background(), AgentCall{
559 Prompt: "test prompt",
560 })
561
562 require.NoError(t, err)
563 require.NotNil(t, result)
564}
565
566// Test options.headers
567func TestAgent_Generate_OptionsHeaders(t *testing.T) {
568 t.Parallel()
569
570 model := &mockLanguageModel{
571 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
572 // Verify headers are passed
573 require.Equal(t, map[string]string{
574 "custom-request-header": "request-header-value",
575 }, call.Headers)
576
577 return &Response{
578 Content: []Content{
579 TextContent{Text: "Hello, world!"},
580 },
581 Usage: Usage{
582 InputTokens: 3,
583 OutputTokens: 10,
584 TotalTokens: 13,
585 },
586 FinishReason: FinishReasonStop,
587 }, nil
588 },
589 }
590
591 agent := NewAgent(model)
592 result, err := agent.Generate(context.Background(), AgentCall{
593 Prompt: "test-input",
594 Headers: map[string]string{"custom-request-header": "request-header-value"},
595 })
596
597 require.NoError(t, err)
598 require.NotNil(t, result)
599 require.Equal(t, "Hello, world!", result.Response.Content.Text())
600}
601
602// Test options.activeTools filtering
603func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
604 t.Parallel()
605
606 tool1 := &mockTool{
607 name: "tool1",
608 description: "Test tool 1",
609 parameters: map[string]any{
610 "value": map[string]any{"type": "string"},
611 },
612 required: []string{"value"},
613 }
614
615 tool2 := &mockTool{
616 name: "tool2",
617 description: "Test tool 2",
618 parameters: map[string]any{
619 "value": map[string]any{"type": "string"},
620 },
621 required: []string{"value"},
622 }
623
624 model := &mockLanguageModel{
625 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
626 // Verify only tool1 is available
627 require.Len(t, call.Tools, 1)
628 functionTool, ok := call.Tools[0].(FunctionTool)
629 require.True(t, ok)
630 require.Equal(t, "tool1", functionTool.Name)
631
632 return &Response{
633 Content: []Content{
634 TextContent{Text: "Hello, world!"},
635 },
636 Usage: Usage{
637 InputTokens: 3,
638 OutputTokens: 10,
639 TotalTokens: 13,
640 },
641 FinishReason: FinishReasonStop,
642 }, nil
643 },
644 }
645
646 agent := NewAgent(model, WithTools(tool1, tool2))
647 result, err := agent.Generate(context.Background(), AgentCall{
648 Prompt: "test-input",
649 ActiveTools: []string{"tool1"}, // Only tool1 should be active
650 })
651
652 require.NoError(t, err)
653 require.NotNil(t, result)
654}
655
656func TestResponseContent_Getters(t *testing.T) {
657 t.Parallel()
658
659 // Create test content with all types
660 content := ResponseContent{
661 TextContent{Text: "Hello world"},
662 ReasoningContent{Text: "Let me think..."},
663 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
664 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
665 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
666 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
667 }
668
669 // Test Text()
670 require.Equal(t, "Hello world", content.Text())
671
672 // Test Reasoning()
673 reasoning := content.Reasoning()
674 require.Len(t, reasoning, 1)
675 require.Equal(t, "Let me think...", reasoning[0].Text)
676
677 // Test ReasoningText()
678 require.Equal(t, "Let me think...", content.ReasoningText())
679
680 // Test Files()
681 files := content.Files()
682 require.Len(t, files, 1)
683 require.Equal(t, "text/plain", files[0].MediaType)
684 require.Equal(t, []byte("file data"), files[0].Data)
685
686 // Test Sources()
687 sources := content.Sources()
688 require.Len(t, sources, 1)
689 require.Equal(t, SourceTypeURL, sources[0].SourceType)
690 require.Equal(t, "https://example.com", sources[0].URL)
691 require.Equal(t, "Example", sources[0].Title)
692
693 // Test ToolCalls()
694 toolCalls := content.ToolCalls()
695 require.Len(t, toolCalls, 1)
696 require.Equal(t, "call1", toolCalls[0].ToolCallID)
697 require.Equal(t, "test_tool", toolCalls[0].ToolName)
698 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
699
700 // Test ToolResults()
701 toolResults := content.ToolResults()
702 require.Len(t, toolResults, 1)
703 require.Equal(t, "call1", toolResults[0].ToolCallID)
704 require.Equal(t, "test_tool", toolResults[0].ToolName)
705 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
706 require.True(t, ok)
707 require.Equal(t, "result", result.Text)
708}
709
710func TestResponseContent_Getters_Empty(t *testing.T) {
711 t.Parallel()
712
713 // Test with empty content
714 content := ResponseContent{}
715
716 require.Equal(t, "", content.Text())
717 require.Equal(t, "", content.ReasoningText())
718 require.Empty(t, content.Reasoning())
719 require.Empty(t, content.Files())
720 require.Empty(t, content.Sources())
721 require.Empty(t, content.ToolCalls())
722 require.Empty(t, content.ToolResults())
723}
724
725func TestResponseContent_Getters_MultipleItems(t *testing.T) {
726 t.Parallel()
727
728 // Test with multiple items of same type
729 content := ResponseContent{
730 ReasoningContent{Text: "First thought"},
731 ReasoningContent{Text: "Second thought"},
732 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
733 FileContent{Data: []byte("file2"), MediaType: "image/png"},
734 }
735
736 // Test multiple reasoning
737 reasoning := content.Reasoning()
738 require.Len(t, reasoning, 2)
739 require.Equal(t, "First thought", reasoning[0].Text)
740 require.Equal(t, "Second thought", reasoning[1].Text)
741
742 // Test concatenated reasoning text
743 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
744
745 // Test multiple files
746 files := content.Files()
747 require.Len(t, files, 2)
748 require.Equal(t, "text/plain", files[0].MediaType)
749 require.Equal(t, "image/png", files[1].MediaType)
750}
751
752func TestStopConditions(t *testing.T) {
753 t.Parallel()
754
755 // Create test steps
756 step1 := StepResult{
757 Response: Response{
758 Content: ResponseContent{
759 TextContent{Text: "Hello"},
760 },
761 FinishReason: FinishReasonToolCalls,
762 Usage: Usage{TotalTokens: 10},
763 },
764 }
765
766 step2 := StepResult{
767 Response: Response{
768 Content: ResponseContent{
769 TextContent{Text: "World"},
770 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
771 },
772 FinishReason: FinishReasonStop,
773 Usage: Usage{TotalTokens: 15},
774 },
775 }
776
777 step3 := StepResult{
778 Response: Response{
779 Content: ResponseContent{
780 ReasoningContent{Text: "Let me think..."},
781 FileContent{Data: []byte("data"), MediaType: "text/plain"},
782 },
783 FinishReason: FinishReasonLength,
784 Usage: Usage{TotalTokens: 20},
785 },
786 }
787
788 t.Run("StepCountIs", func(t *testing.T) {
789 condition := StepCountIs(2)
790
791 // Should not stop with 1 step
792 require.False(t, condition([]StepResult{step1}))
793
794 // Should stop with 2 steps
795 require.True(t, condition([]StepResult{step1, step2}))
796
797 // Should stop with more than 2 steps
798 require.True(t, condition([]StepResult{step1, step2, step3}))
799
800 // Should not stop with empty steps
801 require.False(t, condition([]StepResult{}))
802 })
803
804 t.Run("HasToolCall", func(t *testing.T) {
805 condition := HasToolCall("search")
806
807 // Should not stop when tool not called
808 require.False(t, condition([]StepResult{step1}))
809
810 // Should stop when tool is called in last step
811 require.True(t, condition([]StepResult{step1, step2}))
812
813 // Should not stop when tool called in earlier step but not last
814 require.False(t, condition([]StepResult{step1, step2, step3}))
815
816 // Should not stop with empty steps
817 require.False(t, condition([]StepResult{}))
818
819 // Should not stop when different tool is called
820 differentToolCondition := HasToolCall("different_tool")
821 require.False(t, differentToolCondition([]StepResult{step1, step2}))
822 })
823
824 t.Run("HasContent", func(t *testing.T) {
825 reasoningCondition := HasContent(ContentTypeReasoning)
826 fileCondition := HasContent(ContentTypeFile)
827
828 // Should not stop when content type not present
829 require.False(t, reasoningCondition([]StepResult{step1, step2}))
830
831 // Should stop when content type is present in last step
832 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
833 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
834
835 // Should not stop with empty steps
836 require.False(t, reasoningCondition([]StepResult{}))
837 })
838
839 t.Run("FinishReasonIs", func(t *testing.T) {
840 stopCondition := FinishReasonIs(FinishReasonStop)
841 lengthCondition := FinishReasonIs(FinishReasonLength)
842
843 // Should not stop when finish reason doesn't match
844 require.False(t, stopCondition([]StepResult{step1}))
845
846 // Should stop when finish reason matches in last step
847 require.True(t, stopCondition([]StepResult{step1, step2}))
848 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
849
850 // Should not stop with empty steps
851 require.False(t, stopCondition([]StepResult{}))
852 })
853
854 t.Run("MaxTokensUsed", func(t *testing.T) {
855 condition := MaxTokensUsed(30)
856
857 // Should not stop when under limit
858 require.False(t, condition([]StepResult{step1})) // 10 tokens
859 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
860
861 // Should stop when at or over limit
862 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
863
864 // Should not stop with empty steps
865 require.False(t, condition([]StepResult{}))
866 })
867}
868
869func TestStopConditions_Integration(t *testing.T) {
870 t.Parallel()
871
872 t.Run("StepCountIs integration", func(t *testing.T) {
873 model := &mockLanguageModel{
874 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
875 return &Response{
876 Content: ResponseContent{
877 TextContent{Text: "Mock response"},
878 },
879 Usage: Usage{
880 InputTokens: 3,
881 OutputTokens: 10,
882 TotalTokens: 13,
883 },
884 FinishReason: FinishReasonStop,
885 }, nil
886 },
887 }
888
889 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
890
891 result, err := agent.Generate(context.Background(), AgentCall{
892 Prompt: "test prompt",
893 })
894
895 require.NoError(t, err)
896 require.NotNil(t, result)
897 require.Len(t, result.Steps, 1) // Should stop after 1 step
898 })
899
900 t.Run("Multiple stop conditions", func(t *testing.T) {
901 model := &mockLanguageModel{
902 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
903 return &Response{
904 Content: ResponseContent{
905 TextContent{Text: "Mock response"},
906 },
907 Usage: Usage{
908 InputTokens: 3,
909 OutputTokens: 10,
910 TotalTokens: 13,
911 },
912 FinishReason: FinishReasonStop,
913 }, nil
914 },
915 }
916
917 agent := NewAgent(model, WithStopConditions(
918 StepCountIs(5), // Stop after 5 steps
919 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
920 ))
921
922 result, err := agent.Generate(context.Background(), AgentCall{
923 Prompt: "test prompt",
924 })
925
926 require.NoError(t, err)
927 require.NotNil(t, result)
928 // Should stop on first condition met (finish reason stop)
929 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
930 })
931}
932
933func TestPrepareStep(t *testing.T) {
934 t.Parallel()
935
936 t.Run("System prompt modification", func(t *testing.T) {
937 var capturedSystemPrompt string
938 model := &mockLanguageModel{
939 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
940 // Capture the system message to verify it was modified
941 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
942 if len(call.Prompt[0].Content) > 0 {
943 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
944 capturedSystemPrompt = textPart.Text
945 }
946 }
947 }
948 return &Response{
949 Content: ResponseContent{
950 TextContent{Text: "Response"},
951 },
952 Usage: Usage{TotalTokens: 10},
953 FinishReason: FinishReasonStop,
954 }, nil
955 },
956 }
957
958 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
959 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
960 return PrepareStepResult{
961 Model: options.Model,
962 Messages: options.Messages,
963 System: &newSystem,
964 }
965 }
966
967 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
968
969 result, err := agent.Generate(context.Background(), AgentCall{
970 Prompt: "test prompt",
971 PrepareStep: prepareStepFunc,
972 })
973
974 require.NoError(t, err)
975 require.NotNil(t, result)
976 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
977 })
978
979 t.Run("Tool choice modification", func(t *testing.T) {
980 var capturedToolChoice *ToolChoice
981 model := &mockLanguageModel{
982 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
983 capturedToolChoice = call.ToolChoice
984 return &Response{
985 Content: ResponseContent{
986 TextContent{Text: "Response"},
987 },
988 Usage: Usage{TotalTokens: 10},
989 FinishReason: FinishReasonStop,
990 }, nil
991 },
992 }
993
994 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
995 toolChoice := ToolChoiceNone
996 return PrepareStepResult{
997 Model: options.Model,
998 Messages: options.Messages,
999 ToolChoice: &toolChoice,
1000 }
1001 }
1002
1003 agent := NewAgent(model)
1004
1005 result, err := agent.Generate(context.Background(), AgentCall{
1006 Prompt: "test prompt",
1007 PrepareStep: prepareStepFunc,
1008 })
1009
1010 require.NoError(t, err)
1011 require.NotNil(t, result)
1012 require.NotNil(t, capturedToolChoice)
1013 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1014 })
1015
1016 t.Run("Active tools modification", func(t *testing.T) {
1017 var capturedToolNames []string
1018 model := &mockLanguageModel{
1019 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1020 // Capture tool names to verify active tools were modified
1021 for _, tool := range call.Tools {
1022 capturedToolNames = append(capturedToolNames, tool.GetName())
1023 }
1024 return &Response{
1025 Content: ResponseContent{
1026 TextContent{Text: "Response"},
1027 },
1028 Usage: Usage{TotalTokens: 10},
1029 FinishReason: FinishReasonStop,
1030 }, nil
1031 },
1032 }
1033
1034 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1035 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1036 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1037
1038 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1039 activeTools := []string{"tool2"} // Only tool2 should be active
1040 return PrepareStepResult{
1041 Model: options.Model,
1042 Messages: options.Messages,
1043 ActiveTools: activeTools,
1044 }
1045 }
1046
1047 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1048
1049 result, err := agent.Generate(context.Background(), AgentCall{
1050 Prompt: "test prompt",
1051 PrepareStep: prepareStepFunc,
1052 })
1053
1054 require.NoError(t, err)
1055 require.NotNil(t, result)
1056 require.Len(t, capturedToolNames, 1)
1057 require.Equal(t, "tool2", capturedToolNames[0])
1058 })
1059
1060 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1061 var capturedToolCount int
1062 model := &mockLanguageModel{
1063 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1064 capturedToolCount = len(call.Tools)
1065 return &Response{
1066 Content: ResponseContent{
1067 TextContent{Text: "Response"},
1068 },
1069 Usage: Usage{TotalTokens: 10},
1070 FinishReason: FinishReasonStop,
1071 }, nil
1072 },
1073 }
1074
1075 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1076
1077 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1078 return PrepareStepResult{
1079 Model: options.Model,
1080 Messages: options.Messages,
1081 DisableAllTools: true, // Disable all tools for this step
1082 }
1083 }
1084
1085 agent := NewAgent(model, WithTools(tool1))
1086
1087 result, err := agent.Generate(context.Background(), AgentCall{
1088 Prompt: "test prompt",
1089 PrepareStep: prepareStepFunc,
1090 })
1091
1092 require.NoError(t, err)
1093 require.NotNil(t, result)
1094 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1095 })
1096
1097 t.Run("All fields modified together", func(t *testing.T) {
1098 var capturedSystemPrompt string
1099 var capturedToolChoice *ToolChoice
1100 var capturedToolNames []string
1101
1102 model := &mockLanguageModel{
1103 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1104 // Capture system prompt
1105 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1106 if len(call.Prompt[0].Content) > 0 {
1107 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1108 capturedSystemPrompt = textPart.Text
1109 }
1110 }
1111 }
1112 // Capture tool choice
1113 capturedToolChoice = call.ToolChoice
1114 // Capture tool names
1115 for _, tool := range call.Tools {
1116 capturedToolNames = append(capturedToolNames, tool.GetName())
1117 }
1118 return &Response{
1119 Content: ResponseContent{
1120 TextContent{Text: "Response"},
1121 },
1122 Usage: Usage{TotalTokens: 10},
1123 FinishReason: FinishReasonStop,
1124 }, nil
1125 },
1126 }
1127
1128 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1129 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1130
1131 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1132 newSystem := "Step-specific system"
1133 toolChoice := SpecificToolChoice("tool1")
1134 activeTools := []string{"tool1"}
1135 return PrepareStepResult{
1136 Model: options.Model,
1137 Messages: options.Messages,
1138 System: &newSystem,
1139 ToolChoice: &toolChoice,
1140 ActiveTools: activeTools,
1141 }
1142 }
1143
1144 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1145
1146 result, err := agent.Generate(context.Background(), AgentCall{
1147 Prompt: "test prompt",
1148 PrepareStep: prepareStepFunc,
1149 })
1150
1151 require.NoError(t, err)
1152 require.NotNil(t, result)
1153 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1154 require.NotNil(t, capturedToolChoice)
1155 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1156 require.Len(t, capturedToolNames, 1)
1157 require.Equal(t, "tool1", capturedToolNames[0])
1158 })
1159
1160 t.Run("Nil fields use parent values", func(t *testing.T) {
1161 var capturedSystemPrompt string
1162 var capturedToolChoice *ToolChoice
1163 var capturedToolNames []string
1164
1165 model := &mockLanguageModel{
1166 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1167 // Capture system prompt
1168 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1169 if len(call.Prompt[0].Content) > 0 {
1170 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1171 capturedSystemPrompt = textPart.Text
1172 }
1173 }
1174 }
1175 // Capture tool choice
1176 capturedToolChoice = call.ToolChoice
1177 // Capture tool names
1178 for _, tool := range call.Tools {
1179 capturedToolNames = append(capturedToolNames, tool.GetName())
1180 }
1181 return &Response{
1182 Content: ResponseContent{
1183 TextContent{Text: "Response"},
1184 },
1185 Usage: Usage{TotalTokens: 10},
1186 FinishReason: FinishReasonStop,
1187 }, nil
1188 },
1189 }
1190
1191 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1192
1193 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1194 // All optional fields are nil, should use parent values
1195 return PrepareStepResult{
1196 Model: options.Model,
1197 Messages: options.Messages,
1198 System: nil, // Use parent
1199 ToolChoice: nil, // Use parent (auto)
1200 ActiveTools: nil, // Use parent (all tools)
1201 }
1202 }
1203
1204 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1205
1206 result, err := agent.Generate(context.Background(), AgentCall{
1207 Prompt: "test prompt",
1208 PrepareStep: prepareStepFunc,
1209 })
1210
1211 require.NoError(t, err)
1212 require.NotNil(t, result)
1213 require.Equal(t, "Parent system", capturedSystemPrompt)
1214 require.NotNil(t, capturedToolChoice)
1215 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1216 require.Len(t, capturedToolNames, 1)
1217 require.Equal(t, "tool1", capturedToolNames[0])
1218 })
1219
1220 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1221 var capturedToolNames []string
1222 model := &mockLanguageModel{
1223 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1224 // Capture tool names to verify all tools are included
1225 for _, tool := range call.Tools {
1226 capturedToolNames = append(capturedToolNames, tool.GetName())
1227 }
1228 return &Response{
1229 Content: ResponseContent{
1230 TextContent{Text: "Response"},
1231 },
1232 Usage: Usage{TotalTokens: 10},
1233 FinishReason: FinishReasonStop,
1234 }, nil
1235 },
1236 }
1237
1238 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1239 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1240
1241 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1242 return PrepareStepResult{
1243 Model: options.Model,
1244 Messages: options.Messages,
1245 ActiveTools: []string{}, // Empty slice means all tools
1246 }
1247 }
1248
1249 agent := NewAgent(model, WithTools(tool1, tool2))
1250
1251 result, err := agent.Generate(context.Background(), AgentCall{
1252 Prompt: "test prompt",
1253 PrepareStep: prepareStepFunc,
1254 })
1255
1256 require.NoError(t, err)
1257 require.NotNil(t, result)
1258 require.Len(t, capturedToolNames, 2) // All tools should be included
1259 require.Contains(t, capturedToolNames, "tool1")
1260 require.Contains(t, capturedToolNames, "tool2")
1261 })
1262}
1263
1264func TestToolCallRepair(t *testing.T) {
1265 t.Parallel()
1266
1267 t.Run("Valid tool call passes validation", func(t *testing.T) {
1268 model := &mockLanguageModel{
1269 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1270 return &Response{
1271 Content: ResponseContent{
1272 TextContent{Text: "Response"},
1273 ToolCallContent{
1274 ToolCallID: "call1",
1275 ToolName: "test_tool",
1276 Input: `{"value": "test"}`, // Valid JSON with required field
1277 },
1278 },
1279 Usage: Usage{TotalTokens: 10},
1280 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1281 }, nil
1282 },
1283 }
1284
1285 tool := &mockTool{
1286 name: "test_tool",
1287 description: "Test tool",
1288 parameters: map[string]any{
1289 "value": map[string]any{"type": "string"},
1290 },
1291 required: []string{"value"},
1292 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1293 return tools.ToolResponse{Content: "success", IsError: false}, nil
1294 },
1295 }
1296
1297 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1298
1299 result, err := agent.Generate(context.Background(), AgentCall{
1300 Prompt: "test prompt",
1301 })
1302
1303 require.NoError(t, err)
1304 require.NotNil(t, result)
1305 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1306
1307 // Check that tool call was executed successfully
1308 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1309 require.Len(t, toolCalls, 1)
1310 require.False(t, toolCalls[0].Invalid) // Should be valid
1311 })
1312
1313 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1314 model := &mockLanguageModel{
1315 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1316 return &Response{
1317 Content: ResponseContent{
1318 TextContent{Text: "Response"},
1319 ToolCallContent{
1320 ToolCallID: "call1",
1321 ToolName: "test_tool",
1322 Input: `{"wrong_field": "test"}`, // Missing required field
1323 },
1324 },
1325 Usage: Usage{TotalTokens: 10},
1326 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1327 }, nil
1328 },
1329 }
1330
1331 tool := &mockTool{
1332 name: "test_tool",
1333 description: "Test tool",
1334 parameters: map[string]any{
1335 "value": map[string]any{"type": "string"},
1336 },
1337 required: []string{"value"},
1338 }
1339
1340 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1341
1342 result, err := agent.Generate(context.Background(), AgentCall{
1343 Prompt: "test prompt",
1344 })
1345
1346 require.NoError(t, err)
1347 require.NotNil(t, result)
1348 require.Len(t, result.Steps, 1) // Only one step
1349
1350 // Check that tool call was marked as invalid
1351 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1352 require.Len(t, toolCalls, 1)
1353 require.True(t, toolCalls[0].Invalid) // Should be invalid
1354 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1355 })
1356
1357 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1358 model := &mockLanguageModel{
1359 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1360 return &Response{
1361 Content: ResponseContent{
1362 TextContent{Text: "Response"},
1363 ToolCallContent{
1364 ToolCallID: "call1",
1365 ToolName: "test_tool",
1366 Input: `{"wrong_field": "test"}`, // Missing required field
1367 },
1368 },
1369 Usage: Usage{TotalTokens: 10},
1370 FinishReason: FinishReasonStop, // Changed to stop
1371 }, nil
1372 },
1373 }
1374
1375 tool := &mockTool{
1376 name: "test_tool",
1377 description: "Test tool",
1378 parameters: map[string]any{
1379 "value": map[string]any{"type": "string"},
1380 },
1381 required: []string{"value"},
1382 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1383 return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
1384 },
1385 }
1386
1387 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1388 // Simple repair: add the missing required field
1389 repairedToolCall := options.OriginalToolCall
1390 repairedToolCall.Input = `{"value": "repaired"}`
1391 return &repairedToolCall, nil
1392 }
1393
1394 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1395
1396 result, err := agent.Generate(context.Background(), AgentCall{
1397 Prompt: "test prompt",
1398 })
1399
1400 require.NoError(t, err)
1401 require.NotNil(t, result)
1402 require.Len(t, result.Steps, 1) // Only one step
1403
1404 // Check that tool call was repaired and is now valid
1405 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1406 require.Len(t, toolCalls, 1)
1407 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1408 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1409 })
1410
1411 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1412 model := &mockLanguageModel{
1413 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1414 return &Response{
1415 Content: ResponseContent{
1416 TextContent{Text: "Response"},
1417 ToolCallContent{
1418 ToolCallID: "call1",
1419 ToolName: "test_tool",
1420 Input: `{"wrong_field": "test"}`, // Missing required field
1421 },
1422 },
1423 Usage: Usage{TotalTokens: 10},
1424 FinishReason: FinishReasonStop, // Changed to stop
1425 }, nil
1426 },
1427 }
1428
1429 tool := &mockTool{
1430 name: "test_tool",
1431 description: "Test tool",
1432 parameters: map[string]any{
1433 "value": map[string]any{"type": "string"},
1434 },
1435 required: []string{"value"},
1436 }
1437
1438 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1439 // Repair function fails
1440 return nil, errors.New("repair failed")
1441 }
1442
1443 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1444
1445 result, err := agent.Generate(context.Background(), AgentCall{
1446 Prompt: "test prompt",
1447 })
1448
1449 require.NoError(t, err)
1450 require.NotNil(t, result)
1451 require.Len(t, result.Steps, 1) // Only one step
1452
1453 // Check that tool call was marked as invalid since repair failed
1454 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1455 require.Len(t, toolCalls, 1)
1456 require.True(t, toolCalls[0].Invalid) // Should be invalid
1457 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1458 })
1459
1460 t.Run("Nonexistent tool call", func(t *testing.T) {
1461 model := &mockLanguageModel{
1462 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1463 return &Response{
1464 Content: ResponseContent{
1465 TextContent{Text: "Response"},
1466 ToolCallContent{
1467 ToolCallID: "call1",
1468 ToolName: "nonexistent_tool",
1469 Input: `{"value": "test"}`,
1470 },
1471 },
1472 Usage: Usage{TotalTokens: 10},
1473 FinishReason: FinishReasonStop, // Changed to stop
1474 }, nil
1475 },
1476 }
1477
1478 tool := &mockTool{name: "test_tool", description: "Test tool"}
1479
1480 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1481
1482 result, err := agent.Generate(context.Background(), AgentCall{
1483 Prompt: "test prompt",
1484 })
1485
1486 require.NoError(t, err)
1487 require.NotNil(t, result)
1488 require.Len(t, result.Steps, 1) // Only one step
1489
1490 // Check that tool call was marked as invalid due to nonexistent tool
1491 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1492 require.Len(t, toolCalls, 1)
1493 require.True(t, toolCalls[0].Invalid) // Should be invalid
1494 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1495 })
1496
1497 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1498 model := &mockLanguageModel{
1499 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1500 return &Response{
1501 Content: ResponseContent{
1502 TextContent{Text: "Response"},
1503 ToolCallContent{
1504 ToolCallID: "call1",
1505 ToolName: "test_tool",
1506 Input: `{invalid json}`, // Invalid JSON
1507 },
1508 },
1509 Usage: Usage{TotalTokens: 10},
1510 FinishReason: FinishReasonStop, // Changed to stop
1511 }, nil
1512 },
1513 }
1514
1515 tool := &mockTool{
1516 name: "test_tool",
1517 description: "Test tool",
1518 parameters: map[string]any{
1519 "value": map[string]any{"type": "string"},
1520 },
1521 required: []string{"value"},
1522 }
1523
1524 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1525
1526 result, err := agent.Generate(context.Background(), AgentCall{
1527 Prompt: "test prompt",
1528 })
1529
1530 require.NoError(t, err)
1531 require.NotNil(t, result)
1532 require.Len(t, result.Steps, 1) // Only one step
1533
1534 // Check that tool call was marked as invalid due to invalid JSON
1535 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1536 require.Len(t, toolCalls, 1)
1537 require.True(t, toolCalls[0].Invalid) // Should be invalid
1538 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1539 })
1540}