1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/stretchr/testify/require"
12)
13
14// Mock tool for testing
15type mockTool struct {
16 name string
17 description string
18 parameters map[string]any
19 required []string
20 executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
21}
22
23func (m *mockTool) Info() tools.ToolInfo {
24 return tools.ToolInfo{
25 Name: m.name,
26 Description: m.description,
27 Parameters: m.parameters,
28 Required: m.required,
29 }
30}
31
32func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
33 if m.executeFunc != nil {
34 return m.executeFunc(ctx, call)
35 }
36 return tools.ToolResponse{Content: "mock result", IsError: false}, nil
37}
38
39// Mock language model for testing
40type mockLanguageModel struct {
41 generateFunc func(ctx context.Context, call Call) (*Response, error)
42}
43
44func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
45 if m.generateFunc != nil {
46 return m.generateFunc(ctx, call)
47 }
48 return &Response{
49 Content: []Content{
50 TextContent{Text: "Hello, world!"},
51 },
52 Usage: Usage{
53 InputTokens: 3,
54 OutputTokens: 10,
55 TotalTokens: 13,
56 },
57 FinishReason: FinishReasonStop,
58 }, nil
59}
60
61func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
62 panic("not implemented")
63}
64
65func (m *mockLanguageModel) Provider() string {
66 return "mock-provider"
67}
68
69func (m *mockLanguageModel) Model() string {
70 return "mock-model"
71}
72
73// Test result.content - comprehensive content types (matches TS test)
74func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
75 t.Parallel()
76
77 tool1 := &mockTool{
78 name: "tool1",
79 description: "Test tool",
80 parameters: map[string]any{
81 "value": map[string]any{"type": "string"},
82 },
83 required: []string{"value"},
84 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
85 var input map[string]any
86 err := json.Unmarshal([]byte(call.Input), &input)
87 require.NoError(t, err)
88 require.Equal(t, "value", input["value"])
89 return tools.ToolResponse{Content: "result1", IsError: false}, nil
90 },
91 }
92
93 model := &mockLanguageModel{
94 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
95 return &Response{
96 Content: []Content{
97 TextContent{Text: "Hello, world!"},
98 SourceContent{
99 ID: "123",
100 URL: "https://example.com",
101 Title: "Example",
102 SourceType: SourceTypeURL,
103 ProviderMetadata: ProviderMetadata{
104 "provider": map[string]any{"custom": "value"},
105 },
106 },
107 FileContent{
108 Data: []byte{1, 2, 3},
109 MediaType: "image/png",
110 },
111 ReasoningContent{
112 Text: "I will open the conversation with witty banter.",
113 },
114 ToolCallContent{
115 ToolCallID: "call-1",
116 ToolName: "tool1",
117 Input: `{"value":"value"}`,
118 },
119 TextContent{Text: "More text"},
120 },
121 Usage: Usage{
122 InputTokens: 3,
123 OutputTokens: 10,
124 TotalTokens: 13,
125 },
126 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
127 }, nil
128 },
129 }
130
131 agent := NewAgent(model, WithTools(tool1))
132 result, err := agent.Generate(context.Background(), AgentCall{
133 Prompt: "prompt",
134 })
135
136 require.NoError(t, err)
137 require.NotNil(t, result)
138 require.Len(t, result.Steps, 1) // Single step like TypeScript
139
140 // Check final response content includes tool result
141 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
142
143 // Verify each content type in order
144 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
145 require.True(t, ok)
146 require.Equal(t, "Hello, world!", textContent.Text)
147
148 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
149 require.True(t, ok)
150 require.Equal(t, "123", sourceContent.ID)
151
152 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
153 require.True(t, ok)
154 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
155
156 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
157 require.True(t, ok)
158 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
159
160 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
161 require.True(t, ok)
162 require.Equal(t, "call-1", toolCallContent.ToolCallID)
163
164 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
165 require.True(t, ok)
166 require.Equal(t, "More text", moreTextContent.Text)
167
168 // Tool result should be appended
169 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
170 require.True(t, ok)
171 require.Equal(t, "call-1", toolResultContent.ToolCallID)
172 require.Equal(t, "tool1", toolResultContent.ToolName)
173}
174
175// Test result.text extraction
176func TestAgent_Generate_ResultText(t *testing.T) {
177 t.Parallel()
178
179 model := &mockLanguageModel{
180 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
181 return &Response{
182 Content: []Content{
183 TextContent{Text: "Hello, world!"},
184 },
185 Usage: Usage{
186 InputTokens: 3,
187 OutputTokens: 10,
188 TotalTokens: 13,
189 },
190 FinishReason: FinishReasonStop,
191 }, nil
192 },
193 }
194
195 agent := NewAgent(model)
196 result, err := agent.Generate(context.Background(), AgentCall{
197 Prompt: "prompt",
198 })
199
200 require.NoError(t, err)
201 require.NotNil(t, result)
202
203 // Test text extraction from content
204 text := result.Response.Content.Text()
205 require.Equal(t, "Hello, world!", text)
206}
207
208// Test result.toolCalls extraction (matches TS test exactly)
209func TestAgent_Generate_ResultToolCalls(t *testing.T) {
210 t.Parallel()
211
212 tool1 := &mockTool{
213 name: "tool1",
214 description: "Test tool 1",
215 parameters: map[string]any{
216 "value": map[string]any{"type": "string"},
217 },
218 required: []string{"value"},
219 }
220
221 tool2 := &mockTool{
222 name: "tool2",
223 description: "Test tool 2",
224 parameters: map[string]any{
225 "somethingElse": map[string]any{"type": "string"},
226 },
227 required: []string{"somethingElse"},
228 }
229
230 model := &mockLanguageModel{
231 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
232 // Verify tools are passed correctly
233 require.Len(t, call.Tools, 2)
234 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
235
236 // Verify prompt structure
237 require.Len(t, call.Prompt, 1)
238 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
239
240 return &Response{
241 Content: []Content{
242 ToolCallContent{
243 ToolCallID: "call-1",
244 ToolName: "tool1",
245 Input: `{"value":"value"}`,
246 },
247 },
248 Usage: Usage{
249 InputTokens: 3,
250 OutputTokens: 10,
251 TotalTokens: 13,
252 },
253 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
254 }, nil
255 },
256 }
257
258 agent := NewAgent(model, WithTools(tool1, tool2))
259 result, err := agent.Generate(context.Background(), AgentCall{
260 Prompt: "test-input",
261 })
262
263 require.NoError(t, err)
264 require.NotNil(t, result)
265 require.Len(t, result.Steps, 1) // Single step
266
267 // Extract tool calls from final response (should be empty since tools don't execute)
268 var toolCalls []ToolCallContent
269 for _, content := range result.Response.Content {
270 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
271 toolCalls = append(toolCalls, toolCall)
272 }
273 }
274
275 require.Len(t, toolCalls, 1)
276 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
277 require.Equal(t, "tool1", toolCalls[0].ToolName)
278
279 // Parse and verify input
280 var input map[string]any
281 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
282 require.NoError(t, err)
283 require.Equal(t, "value", input["value"])
284}
285
286// Test result.toolResults extraction (matches TS test exactly)
287func TestAgent_Generate_ResultToolResults(t *testing.T) {
288 t.Parallel()
289
290 tool1 := &mockTool{
291 name: "tool1",
292 description: "Test tool",
293 parameters: map[string]any{
294 "value": map[string]any{"type": "string"},
295 },
296 required: []string{"value"},
297 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
298 var input map[string]any
299 err := json.Unmarshal([]byte(call.Input), &input)
300 require.NoError(t, err)
301 require.Equal(t, "value", input["value"])
302 return tools.ToolResponse{Content: "result1", IsError: false}, nil
303 },
304 }
305
306 model := &mockLanguageModel{
307 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
308 // Verify tools and tool choice
309 require.Len(t, call.Tools, 1)
310 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
311
312 // Verify prompt
313 require.Len(t, call.Prompt, 1)
314 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
315
316 return &Response{
317 Content: []Content{
318 ToolCallContent{
319 ToolCallID: "call-1",
320 ToolName: "tool1",
321 Input: `{"value":"value"}`,
322 },
323 },
324 Usage: Usage{
325 InputTokens: 3,
326 OutputTokens: 10,
327 TotalTokens: 13,
328 },
329 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
330 }, nil
331 },
332 }
333
334 agent := NewAgent(model, WithTools(tool1))
335 result, err := agent.Generate(context.Background(), AgentCall{
336 Prompt: "test-input",
337 })
338
339 require.NoError(t, err)
340 require.NotNil(t, result)
341 require.Len(t, result.Steps, 1) // Single step
342
343 // Extract tool results from final response
344 var toolResults []ToolResultContent
345 for _, content := range result.Response.Content {
346 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
347 toolResults = append(toolResults, toolResult)
348 }
349 }
350
351 require.Len(t, toolResults, 1)
352 require.Equal(t, "call-1", toolResults[0].ToolCallID)
353 require.Equal(t, "tool1", toolResults[0].ToolName)
354
355 // Verify result content
356 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
357 require.True(t, ok)
358 require.Equal(t, "result1", textResult.Text)
359}
360
361// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
362func TestAgent_Generate_MultipleSteps(t *testing.T) {
363 t.Parallel()
364
365 tool1 := &mockTool{
366 name: "tool1",
367 description: "Test tool",
368 parameters: map[string]any{
369 "value": map[string]any{"type": "string"},
370 },
371 required: []string{"value"},
372 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
373 var input map[string]any
374 err := json.Unmarshal([]byte(call.Input), &input)
375 require.NoError(t, err)
376 require.Equal(t, "value", input["value"])
377 return tools.ToolResponse{Content: "result1", IsError: false}, nil
378 },
379 }
380
381 callCount := 0
382 model := &mockLanguageModel{
383 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
384 callCount++
385 switch callCount {
386 case 1:
387 // First call - return tool call with FinishReasonToolCalls
388 return &Response{
389 Content: []Content{
390 ToolCallContent{
391 ToolCallID: "call-1",
392 ToolName: "tool1",
393 Input: `{"value":"value"}`,
394 },
395 },
396 Usage: Usage{
397 InputTokens: 10,
398 OutputTokens: 5,
399 TotalTokens: 15,
400 },
401 FinishReason: FinishReasonToolCalls, // This triggers multi-step
402 }, nil
403 case 2:
404 // Second call - return final text
405 return &Response{
406 Content: []Content{
407 TextContent{Text: "Hello, world!"},
408 },
409 Usage: Usage{
410 InputTokens: 3,
411 OutputTokens: 10,
412 TotalTokens: 13,
413 },
414 FinishReason: FinishReasonStop,
415 }, nil
416 default:
417 t.Fatalf("Unexpected call count: %d", callCount)
418 return nil, nil
419 }
420 },
421 }
422
423 agent := NewAgent(model, WithTools(tool1))
424 result, err := agent.Generate(context.Background(), AgentCall{
425 Prompt: "test-input",
426 })
427
428 require.NoError(t, err)
429 require.NotNil(t, result)
430 require.Len(t, result.Steps, 2)
431
432 // Check total usage sums both steps
433 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
434 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
435 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
436
437 // Final response should be from last step
438 require.Len(t, result.Response.Content, 1)
439 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
440 require.True(t, ok)
441 require.Equal(t, "Hello, world!", textContent.Text)
442
443 // result.toolCalls should be empty (from last step)
444 var toolCalls []ToolCallContent
445 for _, content := range result.Response.Content {
446 if _, ok := AsContentType[ToolCallContent](content); ok {
447 toolCalls = append(toolCalls, content.(ToolCallContent))
448 }
449 }
450 require.Len(t, toolCalls, 0)
451
452 // result.toolResults should be empty (from last step)
453 var toolResults []ToolResultContent
454 for _, content := range result.Response.Content {
455 if _, ok := AsContentType[ToolResultContent](content); ok {
456 toolResults = append(toolResults, content.(ToolResultContent))
457 }
458 }
459 require.Len(t, toolResults, 0)
460}
461
462// Test basic text generation
463func TestAgent_Generate_BasicText(t *testing.T) {
464 t.Parallel()
465
466 model := &mockLanguageModel{
467 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
468 return &Response{
469 Content: []Content{
470 TextContent{Text: "Hello, world!"},
471 },
472 Usage: Usage{
473 InputTokens: 3,
474 OutputTokens: 10,
475 TotalTokens: 13,
476 },
477 FinishReason: FinishReasonStop,
478 }, nil
479 },
480 }
481
482 agent := NewAgent(model)
483 result, err := agent.Generate(context.Background(), AgentCall{
484 Prompt: "test prompt",
485 })
486
487 require.NoError(t, err)
488 require.NotNil(t, result)
489 require.Len(t, result.Steps, 1)
490
491 // Check final response
492 require.Len(t, result.Response.Content, 1)
493 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
494 require.True(t, ok)
495 require.Equal(t, "Hello, world!", textContent.Text)
496
497 // Check usage
498 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
499 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
500 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
501
502 // Check total usage
503 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
504 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
505 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
506}
507
508// Test empty prompt error
509func TestAgent_Generate_EmptyPrompt(t *testing.T) {
510 t.Parallel()
511
512 model := &mockLanguageModel{}
513 agent := NewAgent(model)
514
515 result, err := agent.Generate(context.Background(), AgentCall{
516 Prompt: "", // Empty prompt should cause error
517 })
518
519 require.Error(t, err)
520 require.Nil(t, result)
521 require.Contains(t, err.Error(), "Prompt can't be empty")
522}
523
524// Test with system prompt
525func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
526 t.Parallel()
527
528 model := &mockLanguageModel{
529 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
530 // Verify system message is included
531 require.Len(t, call.Prompt, 2) // system + user
532 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
533 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
534
535 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
536 require.True(t, ok)
537 require.Equal(t, "You are a helpful assistant", systemPart.Text)
538
539 return &Response{
540 Content: []Content{
541 TextContent{Text: "Hello, world!"},
542 },
543 Usage: Usage{
544 InputTokens: 3,
545 OutputTokens: 10,
546 TotalTokens: 13,
547 },
548 FinishReason: FinishReasonStop,
549 }, nil
550 },
551 }
552
553 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
554 result, err := agent.Generate(context.Background(), AgentCall{
555 Prompt: "test prompt",
556 })
557
558 require.NoError(t, err)
559 require.NotNil(t, result)
560}
561
562// Test options.headers
563func TestAgent_Generate_OptionsHeaders(t *testing.T) {
564 t.Parallel()
565
566 model := &mockLanguageModel{
567 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
568 // Verify headers are passed
569 require.Equal(t, map[string]string{
570 "custom-request-header": "request-header-value",
571 }, call.Headers)
572
573 return &Response{
574 Content: []Content{
575 TextContent{Text: "Hello, world!"},
576 },
577 Usage: Usage{
578 InputTokens: 3,
579 OutputTokens: 10,
580 TotalTokens: 13,
581 },
582 FinishReason: FinishReasonStop,
583 }, nil
584 },
585 }
586
587 agent := NewAgent(model)
588 result, err := agent.Generate(context.Background(), AgentCall{
589 Prompt: "test-input",
590 Headers: map[string]string{"custom-request-header": "request-header-value"},
591 })
592
593 require.NoError(t, err)
594 require.NotNil(t, result)
595 require.Equal(t, "Hello, world!", result.Response.Content.Text())
596}
597
598// Test options.activeTools filtering
599func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
600 t.Parallel()
601
602 tool1 := &mockTool{
603 name: "tool1",
604 description: "Test tool 1",
605 parameters: map[string]any{
606 "value": map[string]any{"type": "string"},
607 },
608 required: []string{"value"},
609 }
610
611 tool2 := &mockTool{
612 name: "tool2",
613 description: "Test tool 2",
614 parameters: map[string]any{
615 "value": map[string]any{"type": "string"},
616 },
617 required: []string{"value"},
618 }
619
620 model := &mockLanguageModel{
621 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
622 // Verify only tool1 is available
623 require.Len(t, call.Tools, 1)
624 functionTool, ok := call.Tools[0].(FunctionTool)
625 require.True(t, ok)
626 require.Equal(t, "tool1", functionTool.Name)
627
628 return &Response{
629 Content: []Content{
630 TextContent{Text: "Hello, world!"},
631 },
632 Usage: Usage{
633 InputTokens: 3,
634 OutputTokens: 10,
635 TotalTokens: 13,
636 },
637 FinishReason: FinishReasonStop,
638 }, nil
639 },
640 }
641
642 agent := NewAgent(model, WithTools(tool1, tool2))
643 result, err := agent.Generate(context.Background(), AgentCall{
644 Prompt: "test-input",
645 ActiveTools: []string{"tool1"}, // Only tool1 should be active
646 })
647
648 require.NoError(t, err)
649 require.NotNil(t, result)
650}
651
652func TestResponseContent_Getters(t *testing.T) {
653 t.Parallel()
654
655 // Create test content with all types
656 content := ResponseContent{
657 TextContent{Text: "Hello world"},
658 ReasoningContent{Text: "Let me think..."},
659 FileContent{Data: []byte("file data"), MediaType: "text/plain"},
660 SourceContent{SourceType: SourceTypeURL, URL: "https://example.com", Title: "Example"},
661 ToolCallContent{ToolCallID: "call1", ToolName: "test_tool", Input: `{"arg": "value"}`},
662 ToolResultContent{ToolCallID: "call1", ToolName: "test_tool", Result: ToolResultOutputContentText{Text: "result"}},
663 }
664
665 // Test Text()
666 require.Equal(t, "Hello world", content.Text())
667
668 // Test Reasoning()
669 reasoning := content.Reasoning()
670 require.Len(t, reasoning, 1)
671 require.Equal(t, "Let me think...", reasoning[0].Text)
672
673 // Test ReasoningText()
674 require.Equal(t, "Let me think...", content.ReasoningText())
675
676 // Test Files()
677 files := content.Files()
678 require.Len(t, files, 1)
679 require.Equal(t, "text/plain", files[0].MediaType)
680 require.Equal(t, []byte("file data"), files[0].Data)
681
682 // Test Sources()
683 sources := content.Sources()
684 require.Len(t, sources, 1)
685 require.Equal(t, SourceTypeURL, sources[0].SourceType)
686 require.Equal(t, "https://example.com", sources[0].URL)
687 require.Equal(t, "Example", sources[0].Title)
688
689 // Test ToolCalls()
690 toolCalls := content.ToolCalls()
691 require.Len(t, toolCalls, 1)
692 require.Equal(t, "call1", toolCalls[0].ToolCallID)
693 require.Equal(t, "test_tool", toolCalls[0].ToolName)
694 require.Equal(t, `{"arg": "value"}`, toolCalls[0].Input)
695
696 // Test ToolResults()
697 toolResults := content.ToolResults()
698 require.Len(t, toolResults, 1)
699 require.Equal(t, "call1", toolResults[0].ToolCallID)
700 require.Equal(t, "test_tool", toolResults[0].ToolName)
701 result, ok := AsToolResultOutputType[ToolResultOutputContentText](toolResults[0].Result)
702 require.True(t, ok)
703 require.Equal(t, "result", result.Text)
704}
705
706func TestResponseContent_Getters_Empty(t *testing.T) {
707 t.Parallel()
708
709 // Test with empty content
710 content := ResponseContent{}
711
712 require.Equal(t, "", content.Text())
713 require.Equal(t, "", content.ReasoningText())
714 require.Empty(t, content.Reasoning())
715 require.Empty(t, content.Files())
716 require.Empty(t, content.Sources())
717 require.Empty(t, content.ToolCalls())
718 require.Empty(t, content.ToolResults())
719}
720
721func TestResponseContent_Getters_MultipleItems(t *testing.T) {
722 t.Parallel()
723
724 // Test with multiple items of same type
725 content := ResponseContent{
726 ReasoningContent{Text: "First thought"},
727 ReasoningContent{Text: "Second thought"},
728 FileContent{Data: []byte("file1"), MediaType: "text/plain"},
729 FileContent{Data: []byte("file2"), MediaType: "image/png"},
730 }
731
732 // Test multiple reasoning
733 reasoning := content.Reasoning()
734 require.Len(t, reasoning, 2)
735 require.Equal(t, "First thought", reasoning[0].Text)
736 require.Equal(t, "Second thought", reasoning[1].Text)
737
738 // Test concatenated reasoning text
739 require.Equal(t, "First thoughtSecond thought", content.ReasoningText())
740
741 // Test multiple files
742 files := content.Files()
743 require.Len(t, files, 2)
744 require.Equal(t, "text/plain", files[0].MediaType)
745 require.Equal(t, "image/png", files[1].MediaType)
746}
747
748func TestStopConditions(t *testing.T) {
749 t.Parallel()
750
751 // Create test steps
752 step1 := StepResult{
753 Response: Response{
754 Content: ResponseContent{
755 TextContent{Text: "Hello"},
756 },
757 FinishReason: FinishReasonToolCalls,
758 Usage: Usage{TotalTokens: 10},
759 },
760 }
761
762 step2 := StepResult{
763 Response: Response{
764 Content: ResponseContent{
765 TextContent{Text: "World"},
766 ToolCallContent{ToolCallID: "call1", ToolName: "search", Input: `{"query": "test"}`},
767 },
768 FinishReason: FinishReasonStop,
769 Usage: Usage{TotalTokens: 15},
770 },
771 }
772
773 step3 := StepResult{
774 Response: Response{
775 Content: ResponseContent{
776 ReasoningContent{Text: "Let me think..."},
777 FileContent{Data: []byte("data"), MediaType: "text/plain"},
778 },
779 FinishReason: FinishReasonLength,
780 Usage: Usage{TotalTokens: 20},
781 },
782 }
783
784 t.Run("StepCountIs", func(t *testing.T) {
785 condition := StepCountIs(2)
786
787 // Should not stop with 1 step
788 require.False(t, condition([]StepResult{step1}))
789
790 // Should stop with 2 steps
791 require.True(t, condition([]StepResult{step1, step2}))
792
793 // Should stop with more than 2 steps
794 require.True(t, condition([]StepResult{step1, step2, step3}))
795
796 // Should not stop with empty steps
797 require.False(t, condition([]StepResult{}))
798 })
799
800 t.Run("HasToolCall", func(t *testing.T) {
801 condition := HasToolCall("search")
802
803 // Should not stop when tool not called
804 require.False(t, condition([]StepResult{step1}))
805
806 // Should stop when tool is called in last step
807 require.True(t, condition([]StepResult{step1, step2}))
808
809 // Should not stop when tool called in earlier step but not last
810 require.False(t, condition([]StepResult{step1, step2, step3}))
811
812 // Should not stop with empty steps
813 require.False(t, condition([]StepResult{}))
814
815 // Should not stop when different tool is called
816 differentToolCondition := HasToolCall("different_tool")
817 require.False(t, differentToolCondition([]StepResult{step1, step2}))
818 })
819
820 t.Run("HasContent", func(t *testing.T) {
821 reasoningCondition := HasContent(ContentTypeReasoning)
822 fileCondition := HasContent(ContentTypeFile)
823
824 // Should not stop when content type not present
825 require.False(t, reasoningCondition([]StepResult{step1, step2}))
826
827 // Should stop when content type is present in last step
828 require.True(t, reasoningCondition([]StepResult{step1, step2, step3}))
829 require.True(t, fileCondition([]StepResult{step1, step2, step3}))
830
831 // Should not stop with empty steps
832 require.False(t, reasoningCondition([]StepResult{}))
833 })
834
835 t.Run("FinishReasonIs", func(t *testing.T) {
836 stopCondition := FinishReasonIs(FinishReasonStop)
837 lengthCondition := FinishReasonIs(FinishReasonLength)
838
839 // Should not stop when finish reason doesn't match
840 require.False(t, stopCondition([]StepResult{step1}))
841
842 // Should stop when finish reason matches in last step
843 require.True(t, stopCondition([]StepResult{step1, step2}))
844 require.True(t, lengthCondition([]StepResult{step1, step2, step3}))
845
846 // Should not stop with empty steps
847 require.False(t, stopCondition([]StepResult{}))
848 })
849
850 t.Run("MaxTokensUsed", func(t *testing.T) {
851 condition := MaxTokensUsed(30)
852
853 // Should not stop when under limit
854 require.False(t, condition([]StepResult{step1})) // 10 tokens
855 require.False(t, condition([]StepResult{step1, step2})) // 25 tokens
856
857 // Should stop when at or over limit
858 require.True(t, condition([]StepResult{step1, step2, step3})) // 45 tokens
859
860 // Should not stop with empty steps
861 require.False(t, condition([]StepResult{}))
862 })
863}
864
865func TestStopConditions_Integration(t *testing.T) {
866 t.Parallel()
867
868 t.Run("StepCountIs integration", func(t *testing.T) {
869 model := &mockLanguageModel{
870 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
871 return &Response{
872 Content: ResponseContent{
873 TextContent{Text: "Mock response"},
874 },
875 Usage: Usage{
876 InputTokens: 3,
877 OutputTokens: 10,
878 TotalTokens: 13,
879 },
880 FinishReason: FinishReasonStop,
881 }, nil
882 },
883 }
884
885 agent := NewAgent(model, WithStopConditions(StepCountIs(1)))
886
887 result, err := agent.Generate(context.Background(), AgentCall{
888 Prompt: "test prompt",
889 })
890
891 require.NoError(t, err)
892 require.NotNil(t, result)
893 require.Len(t, result.Steps, 1) // Should stop after 1 step
894 })
895
896 t.Run("Multiple stop conditions", func(t *testing.T) {
897 model := &mockLanguageModel{
898 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
899 return &Response{
900 Content: ResponseContent{
901 TextContent{Text: "Mock response"},
902 },
903 Usage: Usage{
904 InputTokens: 3,
905 OutputTokens: 10,
906 TotalTokens: 13,
907 },
908 FinishReason: FinishReasonStop,
909 }, nil
910 },
911 }
912
913 agent := NewAgent(model, WithStopConditions(
914 StepCountIs(5), // Stop after 5 steps
915 FinishReasonIs(FinishReasonStop), // Or stop on finish reason
916 ))
917
918 result, err := agent.Generate(context.Background(), AgentCall{
919 Prompt: "test prompt",
920 })
921
922 require.NoError(t, err)
923 require.NotNil(t, result)
924 // Should stop on first condition met (finish reason stop)
925 require.Equal(t, FinishReasonStop, result.Response.FinishReason)
926 })
927}
928
929func TestPrepareStep(t *testing.T) {
930 t.Parallel()
931
932 t.Run("System prompt modification", func(t *testing.T) {
933 var capturedSystemPrompt string
934 model := &mockLanguageModel{
935 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
936 // Capture the system message to verify it was modified
937 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
938 if len(call.Prompt[0].Content) > 0 {
939 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
940 capturedSystemPrompt = textPart.Text
941 }
942 }
943 }
944 return &Response{
945 Content: ResponseContent{
946 TextContent{Text: "Response"},
947 },
948 Usage: Usage{TotalTokens: 10},
949 FinishReason: FinishReasonStop,
950 }, nil
951 },
952 }
953
954 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
955 newSystem := "Modified system prompt for step " + fmt.Sprintf("%d", options.StepNumber)
956 return PrepareStepResult{
957 Model: options.Model,
958 Messages: options.Messages,
959 System: &newSystem,
960 }
961 }
962
963 agent := NewAgent(model, WithSystemPrompt("Original system prompt"))
964
965 result, err := agent.Generate(context.Background(), AgentCall{
966 Prompt: "test prompt",
967 PrepareStep: prepareStepFunc,
968 })
969
970 require.NoError(t, err)
971 require.NotNil(t, result)
972 require.Equal(t, "Modified system prompt for step 0", capturedSystemPrompt)
973 })
974
975 t.Run("Tool choice modification", func(t *testing.T) {
976 var capturedToolChoice *ToolChoice
977 model := &mockLanguageModel{
978 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
979 capturedToolChoice = call.ToolChoice
980 return &Response{
981 Content: ResponseContent{
982 TextContent{Text: "Response"},
983 },
984 Usage: Usage{TotalTokens: 10},
985 FinishReason: FinishReasonStop,
986 }, nil
987 },
988 }
989
990 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
991 toolChoice := ToolChoiceNone
992 return PrepareStepResult{
993 Model: options.Model,
994 Messages: options.Messages,
995 ToolChoice: &toolChoice,
996 }
997 }
998
999 agent := NewAgent(model)
1000
1001 result, err := agent.Generate(context.Background(), AgentCall{
1002 Prompt: "test prompt",
1003 PrepareStep: prepareStepFunc,
1004 })
1005
1006 require.NoError(t, err)
1007 require.NotNil(t, result)
1008 require.NotNil(t, capturedToolChoice)
1009 require.Equal(t, ToolChoiceNone, *capturedToolChoice)
1010 })
1011
1012 t.Run("Active tools modification", func(t *testing.T) {
1013 var capturedToolNames []string
1014 model := &mockLanguageModel{
1015 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1016 // Capture tool names to verify active tools were modified
1017 for _, tool := range call.Tools {
1018 capturedToolNames = append(capturedToolNames, tool.GetName())
1019 }
1020 return &Response{
1021 Content: ResponseContent{
1022 TextContent{Text: "Response"},
1023 },
1024 Usage: Usage{TotalTokens: 10},
1025 FinishReason: FinishReasonStop,
1026 }, nil
1027 },
1028 }
1029
1030 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1031 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1032 tool3 := &mockTool{name: "tool3", description: "Tool 3"}
1033
1034 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1035 activeTools := []string{"tool2"} // Only tool2 should be active
1036 return PrepareStepResult{
1037 Model: options.Model,
1038 Messages: options.Messages,
1039 ActiveTools: activeTools,
1040 }
1041 }
1042
1043 agent := NewAgent(model, WithTools(tool1, tool2, tool3))
1044
1045 result, err := agent.Generate(context.Background(), AgentCall{
1046 Prompt: "test prompt",
1047 PrepareStep: prepareStepFunc,
1048 })
1049
1050 require.NoError(t, err)
1051 require.NotNil(t, result)
1052 require.Len(t, capturedToolNames, 1)
1053 require.Equal(t, "tool2", capturedToolNames[0])
1054 })
1055
1056 t.Run("No tools when DisableAllTools is true", func(t *testing.T) {
1057 var capturedToolCount int
1058 model := &mockLanguageModel{
1059 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1060 capturedToolCount = len(call.Tools)
1061 return &Response{
1062 Content: ResponseContent{
1063 TextContent{Text: "Response"},
1064 },
1065 Usage: Usage{TotalTokens: 10},
1066 FinishReason: FinishReasonStop,
1067 }, nil
1068 },
1069 }
1070
1071 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1072
1073 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1074 return PrepareStepResult{
1075 Model: options.Model,
1076 Messages: options.Messages,
1077 DisableAllTools: true, // Disable all tools for this step
1078 }
1079 }
1080
1081 agent := NewAgent(model, WithTools(tool1))
1082
1083 result, err := agent.Generate(context.Background(), AgentCall{
1084 Prompt: "test prompt",
1085 PrepareStep: prepareStepFunc,
1086 })
1087
1088 require.NoError(t, err)
1089 require.NotNil(t, result)
1090 require.Equal(t, 0, capturedToolCount) // No tools should be passed
1091 })
1092
1093 t.Run("All fields modified together", func(t *testing.T) {
1094 var capturedSystemPrompt string
1095 var capturedToolChoice *ToolChoice
1096 var capturedToolNames []string
1097
1098 model := &mockLanguageModel{
1099 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1100 // Capture system prompt
1101 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1102 if len(call.Prompt[0].Content) > 0 {
1103 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1104 capturedSystemPrompt = textPart.Text
1105 }
1106 }
1107 }
1108 // Capture tool choice
1109 capturedToolChoice = call.ToolChoice
1110 // Capture tool names
1111 for _, tool := range call.Tools {
1112 capturedToolNames = append(capturedToolNames, tool.GetName())
1113 }
1114 return &Response{
1115 Content: ResponseContent{
1116 TextContent{Text: "Response"},
1117 },
1118 Usage: Usage{TotalTokens: 10},
1119 FinishReason: FinishReasonStop,
1120 }, nil
1121 },
1122 }
1123
1124 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1125 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1126
1127 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1128 newSystem := "Step-specific system"
1129 toolChoice := SpecificToolChoice("tool1")
1130 activeTools := []string{"tool1"}
1131 return PrepareStepResult{
1132 Model: options.Model,
1133 Messages: options.Messages,
1134 System: &newSystem,
1135 ToolChoice: &toolChoice,
1136 ActiveTools: activeTools,
1137 }
1138 }
1139
1140 agent := NewAgent(model, WithSystemPrompt("Original system"), WithTools(tool1, tool2))
1141
1142 result, err := agent.Generate(context.Background(), AgentCall{
1143 Prompt: "test prompt",
1144 PrepareStep: prepareStepFunc,
1145 })
1146
1147 require.NoError(t, err)
1148 require.NotNil(t, result)
1149 require.Equal(t, "Step-specific system", capturedSystemPrompt)
1150 require.NotNil(t, capturedToolChoice)
1151 require.Equal(t, SpecificToolChoice("tool1"), *capturedToolChoice)
1152 require.Len(t, capturedToolNames, 1)
1153 require.Equal(t, "tool1", capturedToolNames[0])
1154 })
1155
1156 t.Run("Nil fields use parent values", func(t *testing.T) {
1157 var capturedSystemPrompt string
1158 var capturedToolChoice *ToolChoice
1159 var capturedToolNames []string
1160
1161 model := &mockLanguageModel{
1162 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1163 // Capture system prompt
1164 if len(call.Prompt) > 0 && call.Prompt[0].Role == MessageRoleSystem {
1165 if len(call.Prompt[0].Content) > 0 {
1166 if textPart, ok := AsContentType[TextPart](call.Prompt[0].Content[0]); ok {
1167 capturedSystemPrompt = textPart.Text
1168 }
1169 }
1170 }
1171 // Capture tool choice
1172 capturedToolChoice = call.ToolChoice
1173 // Capture tool names
1174 for _, tool := range call.Tools {
1175 capturedToolNames = append(capturedToolNames, tool.GetName())
1176 }
1177 return &Response{
1178 Content: ResponseContent{
1179 TextContent{Text: "Response"},
1180 },
1181 Usage: Usage{TotalTokens: 10},
1182 FinishReason: FinishReasonStop,
1183 }, nil
1184 },
1185 }
1186
1187 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1188
1189 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1190 // All optional fields are nil, should use parent values
1191 return PrepareStepResult{
1192 Model: options.Model,
1193 Messages: options.Messages,
1194 System: nil, // Use parent
1195 ToolChoice: nil, // Use parent (auto)
1196 ActiveTools: nil, // Use parent (all tools)
1197 }
1198 }
1199
1200 agent := NewAgent(model, WithSystemPrompt("Parent system"), WithTools(tool1))
1201
1202 result, err := agent.Generate(context.Background(), AgentCall{
1203 Prompt: "test prompt",
1204 PrepareStep: prepareStepFunc,
1205 })
1206
1207 require.NoError(t, err)
1208 require.NotNil(t, result)
1209 require.Equal(t, "Parent system", capturedSystemPrompt)
1210 require.NotNil(t, capturedToolChoice)
1211 require.Equal(t, ToolChoiceAuto, *capturedToolChoice) // Default
1212 require.Len(t, capturedToolNames, 1)
1213 require.Equal(t, "tool1", capturedToolNames[0])
1214 })
1215
1216 t.Run("Empty ActiveTools means all tools", func(t *testing.T) {
1217 var capturedToolNames []string
1218 model := &mockLanguageModel{
1219 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1220 // Capture tool names to verify all tools are included
1221 for _, tool := range call.Tools {
1222 capturedToolNames = append(capturedToolNames, tool.GetName())
1223 }
1224 return &Response{
1225 Content: ResponseContent{
1226 TextContent{Text: "Response"},
1227 },
1228 Usage: Usage{TotalTokens: 10},
1229 FinishReason: FinishReasonStop,
1230 }, nil
1231 },
1232 }
1233
1234 tool1 := &mockTool{name: "tool1", description: "Tool 1"}
1235 tool2 := &mockTool{name: "tool2", description: "Tool 2"}
1236
1237 prepareStepFunc := func(options PrepareStepFunctionOptions) PrepareStepResult {
1238 return PrepareStepResult{
1239 Model: options.Model,
1240 Messages: options.Messages,
1241 ActiveTools: []string{}, // Empty slice means all tools
1242 }
1243 }
1244
1245 agent := NewAgent(model, WithTools(tool1, tool2))
1246
1247 result, err := agent.Generate(context.Background(), AgentCall{
1248 Prompt: "test prompt",
1249 PrepareStep: prepareStepFunc,
1250 })
1251
1252 require.NoError(t, err)
1253 require.NotNil(t, result)
1254 require.Len(t, capturedToolNames, 2) // All tools should be included
1255 require.Contains(t, capturedToolNames, "tool1")
1256 require.Contains(t, capturedToolNames, "tool2")
1257 })
1258}
1259
1260func TestToolCallRepair(t *testing.T) {
1261 t.Parallel()
1262
1263 t.Run("Valid tool call passes validation", func(t *testing.T) {
1264 model := &mockLanguageModel{
1265 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1266 return &Response{
1267 Content: ResponseContent{
1268 TextContent{Text: "Response"},
1269 ToolCallContent{
1270 ToolCallID: "call1",
1271 ToolName: "test_tool",
1272 Input: `{"value": "test"}`, // Valid JSON with required field
1273 },
1274 },
1275 Usage: Usage{TotalTokens: 10},
1276 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1277 }, nil
1278 },
1279 }
1280
1281 tool := &mockTool{
1282 name: "test_tool",
1283 description: "Test tool",
1284 parameters: map[string]any{
1285 "value": map[string]any{"type": "string"},
1286 },
1287 required: []string{"value"},
1288 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1289 return tools.ToolResponse{Content: "success", IsError: false}, nil
1290 },
1291 }
1292
1293 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1294
1295 result, err := agent.Generate(context.Background(), AgentCall{
1296 Prompt: "test prompt",
1297 })
1298
1299 require.NoError(t, err)
1300 require.NotNil(t, result)
1301 require.Len(t, result.Steps, 1) // Only one step since FinishReason is stop
1302
1303 // Check that tool call was executed successfully
1304 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1305 require.Len(t, toolCalls, 1)
1306 require.False(t, toolCalls[0].Invalid) // Should be valid
1307 })
1308
1309 t.Run("Invalid tool call without repair function", func(t *testing.T) {
1310 model := &mockLanguageModel{
1311 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1312 return &Response{
1313 Content: ResponseContent{
1314 TextContent{Text: "Response"},
1315 ToolCallContent{
1316 ToolCallID: "call1",
1317 ToolName: "test_tool",
1318 Input: `{"wrong_field": "test"}`, // Missing required field
1319 },
1320 },
1321 Usage: Usage{TotalTokens: 10},
1322 FinishReason: FinishReasonStop, // Changed to stop to avoid infinite loop
1323 }, nil
1324 },
1325 }
1326
1327 tool := &mockTool{
1328 name: "test_tool",
1329 description: "Test tool",
1330 parameters: map[string]any{
1331 "value": map[string]any{"type": "string"},
1332 },
1333 required: []string{"value"},
1334 }
1335
1336 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2))) // Limit steps
1337
1338 result, err := agent.Generate(context.Background(), AgentCall{
1339 Prompt: "test prompt",
1340 })
1341
1342 require.NoError(t, err)
1343 require.NotNil(t, result)
1344 require.Len(t, result.Steps, 1) // Only one step
1345
1346 // Check that tool call was marked as invalid
1347 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1348 require.Len(t, toolCalls, 1)
1349 require.True(t, toolCalls[0].Invalid) // Should be invalid
1350 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1351 })
1352
1353 t.Run("Invalid tool call with successful repair", func(t *testing.T) {
1354 model := &mockLanguageModel{
1355 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1356 return &Response{
1357 Content: ResponseContent{
1358 TextContent{Text: "Response"},
1359 ToolCallContent{
1360 ToolCallID: "call1",
1361 ToolName: "test_tool",
1362 Input: `{"wrong_field": "test"}`, // Missing required field
1363 },
1364 },
1365 Usage: Usage{TotalTokens: 10},
1366 FinishReason: FinishReasonStop, // Changed to stop
1367 }, nil
1368 },
1369 }
1370
1371 tool := &mockTool{
1372 name: "test_tool",
1373 description: "Test tool",
1374 parameters: map[string]any{
1375 "value": map[string]any{"type": "string"},
1376 },
1377 required: []string{"value"},
1378 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
1379 return tools.ToolResponse{Content: "repaired_success", IsError: false}, nil
1380 },
1381 }
1382
1383 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1384 // Simple repair: add the missing required field
1385 repairedToolCall := options.OriginalToolCall
1386 repairedToolCall.Input = `{"value": "repaired"}`
1387 return &repairedToolCall, nil
1388 }
1389
1390 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1391
1392 result, err := agent.Generate(context.Background(), AgentCall{
1393 Prompt: "test prompt",
1394 })
1395
1396 require.NoError(t, err)
1397 require.NotNil(t, result)
1398 require.Len(t, result.Steps, 1) // Only one step
1399
1400 // Check that tool call was repaired and is now valid
1401 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1402 require.Len(t, toolCalls, 1)
1403 require.False(t, toolCalls[0].Invalid) // Should be valid after repair
1404 require.Equal(t, `{"value": "repaired"}`, toolCalls[0].Input) // Should have repaired input
1405 })
1406
1407 t.Run("Invalid tool call with failed repair", func(t *testing.T) {
1408 model := &mockLanguageModel{
1409 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1410 return &Response{
1411 Content: ResponseContent{
1412 TextContent{Text: "Response"},
1413 ToolCallContent{
1414 ToolCallID: "call1",
1415 ToolName: "test_tool",
1416 Input: `{"wrong_field": "test"}`, // Missing required field
1417 },
1418 },
1419 Usage: Usage{TotalTokens: 10},
1420 FinishReason: FinishReasonStop, // Changed to stop
1421 }, nil
1422 },
1423 }
1424
1425 tool := &mockTool{
1426 name: "test_tool",
1427 description: "Test tool",
1428 parameters: map[string]any{
1429 "value": map[string]any{"type": "string"},
1430 },
1431 required: []string{"value"},
1432 }
1433
1434 repairFunc := func(ctx context.Context, options ToolCallRepairOptions) (*ToolCallContent, error) {
1435 // Repair function fails
1436 return nil, errors.New("repair failed")
1437 }
1438
1439 agent := NewAgent(model, WithTools(tool), WithRepairToolCall(repairFunc), WithStopConditions(StepCountIs(2)))
1440
1441 result, err := agent.Generate(context.Background(), AgentCall{
1442 Prompt: "test prompt",
1443 })
1444
1445 require.NoError(t, err)
1446 require.NotNil(t, result)
1447 require.Len(t, result.Steps, 1) // Only one step
1448
1449 // Check that tool call was marked as invalid since repair failed
1450 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1451 require.Len(t, toolCalls, 1)
1452 require.True(t, toolCalls[0].Invalid) // Should be invalid
1453 require.Contains(t, toolCalls[0].ValidationError.Error(), "missing required parameter: value")
1454 })
1455
1456 t.Run("Nonexistent tool call", func(t *testing.T) {
1457 model := &mockLanguageModel{
1458 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1459 return &Response{
1460 Content: ResponseContent{
1461 TextContent{Text: "Response"},
1462 ToolCallContent{
1463 ToolCallID: "call1",
1464 ToolName: "nonexistent_tool",
1465 Input: `{"value": "test"}`,
1466 },
1467 },
1468 Usage: Usage{TotalTokens: 10},
1469 FinishReason: FinishReasonStop, // Changed to stop
1470 }, nil
1471 },
1472 }
1473
1474 tool := &mockTool{name: "test_tool", description: "Test tool"}
1475
1476 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1477
1478 result, err := agent.Generate(context.Background(), AgentCall{
1479 Prompt: "test prompt",
1480 })
1481
1482 require.NoError(t, err)
1483 require.NotNil(t, result)
1484 require.Len(t, result.Steps, 1) // Only one step
1485
1486 // Check that tool call was marked as invalid due to nonexistent tool
1487 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1488 require.Len(t, toolCalls, 1)
1489 require.True(t, toolCalls[0].Invalid) // Should be invalid
1490 require.Contains(t, toolCalls[0].ValidationError.Error(), "tool not found: nonexistent_tool")
1491 })
1492
1493 t.Run("Invalid JSON in tool call", func(t *testing.T) {
1494 model := &mockLanguageModel{
1495 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
1496 return &Response{
1497 Content: ResponseContent{
1498 TextContent{Text: "Response"},
1499 ToolCallContent{
1500 ToolCallID: "call1",
1501 ToolName: "test_tool",
1502 Input: `{invalid json}`, // Invalid JSON
1503 },
1504 },
1505 Usage: Usage{TotalTokens: 10},
1506 FinishReason: FinishReasonStop, // Changed to stop
1507 }, nil
1508 },
1509 }
1510
1511 tool := &mockTool{
1512 name: "test_tool",
1513 description: "Test tool",
1514 parameters: map[string]any{
1515 "value": map[string]any{"type": "string"},
1516 },
1517 required: []string{"value"},
1518 }
1519
1520 agent := NewAgent(model, WithTools(tool), WithStopConditions(StepCountIs(2)))
1521
1522 result, err := agent.Generate(context.Background(), AgentCall{
1523 Prompt: "test prompt",
1524 })
1525
1526 require.NoError(t, err)
1527 require.NotNil(t, result)
1528 require.Len(t, result.Steps, 1) // Only one step
1529
1530 // Check that tool call was marked as invalid due to invalid JSON
1531 toolCalls := result.Steps[0].Response.Content.ToolCalls()
1532 require.Len(t, toolCalls, 1)
1533 require.True(t, toolCalls[0].Invalid) // Should be invalid
1534 require.Contains(t, toolCalls[0].ValidationError.Error(), "invalid JSON input")
1535 })
1536}