1package ai
2
3import (
4 "context"
5 "encoding/json"
6 "testing"
7
8 "github.com/charmbracelet/crush/internal/llm/tools"
9 "github.com/stretchr/testify/require"
10)
11
12// Mock tool for testing
13type mockTool struct {
14 name string
15 description string
16 parameters map[string]any
17 required []string
18 executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
19}
20
21func (m *mockTool) Info() tools.ToolInfo {
22 return tools.ToolInfo{
23 Name: m.name,
24 Description: m.description,
25 Parameters: m.parameters,
26 Required: m.required,
27 }
28}
29
30func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
31 if m.executeFunc != nil {
32 return m.executeFunc(ctx, call)
33 }
34 return tools.ToolResponse{Content: "mock result", IsError: false}, nil
35}
36
37// Mock language model for testing
38type mockLanguageModel struct {
39 generateFunc func(ctx context.Context, call Call) (*Response, error)
40}
41
42func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
43 if m.generateFunc != nil {
44 return m.generateFunc(ctx, call)
45 }
46 return &Response{
47 Content: []Content{
48 TextContent{Text: "Hello, world!"},
49 },
50 Usage: Usage{
51 InputTokens: 3,
52 OutputTokens: 10,
53 TotalTokens: 13,
54 },
55 FinishReason: FinishReasonStop,
56 }, nil
57}
58
59func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
60 panic("not implemented")
61}
62
63func (m *mockLanguageModel) Provider() string {
64 return "mock-provider"
65}
66
67func (m *mockLanguageModel) Model() string {
68 return "mock-model"
69}
70
71// Test result.content - comprehensive content types (matches TS test)
72func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
73 t.Parallel()
74
75 tool1 := &mockTool{
76 name: "tool1",
77 description: "Test tool",
78 parameters: map[string]any{
79 "value": map[string]any{"type": "string"},
80 },
81 required: []string{"value"},
82 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
83 var input map[string]any
84 err := json.Unmarshal([]byte(call.Input), &input)
85 require.NoError(t, err)
86 require.Equal(t, "value", input["value"])
87 return tools.ToolResponse{Content: "result1", IsError: false}, nil
88 },
89 }
90
91 model := &mockLanguageModel{
92 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
93 return &Response{
94 Content: []Content{
95 TextContent{Text: "Hello, world!"},
96 SourceContent{
97 ID: "123",
98 URL: "https://example.com",
99 Title: "Example",
100 SourceType: SourceTypeURL,
101 ProviderMetadata: ProviderMetadata{
102 "provider": map[string]any{"custom": "value"},
103 },
104 },
105 FileContent{
106 Data: []byte{1, 2, 3},
107 MediaType: "image/png",
108 },
109 ReasoningContent{
110 Text: "I will open the conversation with witty banter.",
111 },
112 ToolCallContent{
113 ToolCallID: "call-1",
114 ToolName: "tool1",
115 Input: `{"value":"value"}`,
116 },
117 TextContent{Text: "More text"},
118 },
119 Usage: Usage{
120 InputTokens: 3,
121 OutputTokens: 10,
122 TotalTokens: 13,
123 },
124 FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
125 }, nil
126 },
127 }
128
129 agent := NewAgent(model, WithTools(tool1))
130 result, err := agent.Generate(context.Background(), AgentCall{
131 Prompt: "prompt",
132 })
133
134 require.NoError(t, err)
135 require.NotNil(t, result)
136 require.Len(t, result.Steps, 1) // Single step like TypeScript
137
138 // Check final response content includes tool result
139 require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
140
141 // Verify each content type in order
142 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
143 require.True(t, ok)
144 require.Equal(t, "Hello, world!", textContent.Text)
145
146 sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
147 require.True(t, ok)
148 require.Equal(t, "123", sourceContent.ID)
149
150 fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
151 require.True(t, ok)
152 require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
153
154 reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
155 require.True(t, ok)
156 require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
157
158 toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
159 require.True(t, ok)
160 require.Equal(t, "call-1", toolCallContent.ToolCallID)
161
162 moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
163 require.True(t, ok)
164 require.Equal(t, "More text", moreTextContent.Text)
165
166 // Tool result should be appended
167 toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
168 require.True(t, ok)
169 require.Equal(t, "call-1", toolResultContent.ToolCallID)
170 require.Equal(t, "tool1", toolResultContent.ToolName)
171}
172
173// Test result.text extraction
174func TestAgent_Generate_ResultText(t *testing.T) {
175 t.Parallel()
176
177 model := &mockLanguageModel{
178 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
179 return &Response{
180 Content: []Content{
181 TextContent{Text: "Hello, world!"},
182 },
183 Usage: Usage{
184 InputTokens: 3,
185 OutputTokens: 10,
186 TotalTokens: 13,
187 },
188 FinishReason: FinishReasonStop,
189 }, nil
190 },
191 }
192
193 agent := NewAgent(model)
194 result, err := agent.Generate(context.Background(), AgentCall{
195 Prompt: "prompt",
196 })
197
198 require.NoError(t, err)
199 require.NotNil(t, result)
200
201 // Test text extraction from content
202 text := result.Response.Content.Text()
203 require.Equal(t, "Hello, world!", text)
204}
205
206// Test result.toolCalls extraction (matches TS test exactly)
207func TestAgent_Generate_ResultToolCalls(t *testing.T) {
208 t.Parallel()
209
210 tool1 := &mockTool{
211 name: "tool1",
212 description: "Test tool 1",
213 parameters: map[string]any{
214 "value": map[string]any{"type": "string"},
215 },
216 required: []string{"value"},
217 }
218
219 tool2 := &mockTool{
220 name: "tool2",
221 description: "Test tool 2",
222 parameters: map[string]any{
223 "somethingElse": map[string]any{"type": "string"},
224 },
225 required: []string{"somethingElse"},
226 }
227
228 model := &mockLanguageModel{
229 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
230 // Verify tools are passed correctly
231 require.Len(t, call.Tools, 2)
232 require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
233
234 // Verify prompt structure
235 require.Len(t, call.Prompt, 1)
236 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
237
238 return &Response{
239 Content: []Content{
240 ToolCallContent{
241 ToolCallID: "call-1",
242 ToolName: "tool1",
243 Input: `{"value":"value"}`,
244 },
245 },
246 Usage: Usage{
247 InputTokens: 3,
248 OutputTokens: 10,
249 TotalTokens: 13,
250 },
251 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
252 }, nil
253 },
254 }
255
256 agent := NewAgent(model, WithTools(tool1, tool2))
257 result, err := agent.Generate(context.Background(), AgentCall{
258 Prompt: "test-input",
259 })
260
261 require.NoError(t, err)
262 require.NotNil(t, result)
263 require.Len(t, result.Steps, 1) // Single step
264
265 // Extract tool calls from final response (should be empty since tools don't execute)
266 var toolCalls []ToolCallContent
267 for _, content := range result.Response.Content {
268 if toolCall, ok := AsContentType[ToolCallContent](content); ok {
269 toolCalls = append(toolCalls, toolCall)
270 }
271 }
272
273 require.Len(t, toolCalls, 1)
274 require.Equal(t, "call-1", toolCalls[0].ToolCallID)
275 require.Equal(t, "tool1", toolCalls[0].ToolName)
276
277 // Parse and verify input
278 var input map[string]any
279 err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
280 require.NoError(t, err)
281 require.Equal(t, "value", input["value"])
282}
283
284// Test result.toolResults extraction (matches TS test exactly)
285func TestAgent_Generate_ResultToolResults(t *testing.T) {
286 t.Parallel()
287
288 tool1 := &mockTool{
289 name: "tool1",
290 description: "Test tool",
291 parameters: map[string]any{
292 "value": map[string]any{"type": "string"},
293 },
294 required: []string{"value"},
295 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
296 var input map[string]any
297 err := json.Unmarshal([]byte(call.Input), &input)
298 require.NoError(t, err)
299 require.Equal(t, "value", input["value"])
300 return tools.ToolResponse{Content: "result1", IsError: false}, nil
301 },
302 }
303
304 model := &mockLanguageModel{
305 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
306 // Verify tools and tool choice
307 require.Len(t, call.Tools, 1)
308 require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
309
310 // Verify prompt
311 require.Len(t, call.Prompt, 1)
312 require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
313
314 return &Response{
315 Content: []Content{
316 ToolCallContent{
317 ToolCallID: "call-1",
318 ToolName: "tool1",
319 Input: `{"value":"value"}`,
320 },
321 },
322 Usage: Usage{
323 InputTokens: 3,
324 OutputTokens: 10,
325 TotalTokens: 13,
326 },
327 FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
328 }, nil
329 },
330 }
331
332 agent := NewAgent(model, WithTools(tool1))
333 result, err := agent.Generate(context.Background(), AgentCall{
334 Prompt: "test-input",
335 })
336
337 require.NoError(t, err)
338 require.NotNil(t, result)
339 require.Len(t, result.Steps, 1) // Single step
340
341 // Extract tool results from final response
342 var toolResults []ToolResultContent
343 for _, content := range result.Response.Content {
344 if toolResult, ok := AsContentType[ToolResultContent](content); ok {
345 toolResults = append(toolResults, toolResult)
346 }
347 }
348
349 require.Len(t, toolResults, 1)
350 require.Equal(t, "call-1", toolResults[0].ToolCallID)
351 require.Equal(t, "tool1", toolResults[0].ToolName)
352
353 // Verify result content
354 textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
355 require.True(t, ok)
356 require.Equal(t, "result1", textResult.Text)
357}
358
359// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
360func TestAgent_Generate_MultipleSteps(t *testing.T) {
361 t.Parallel()
362
363 tool1 := &mockTool{
364 name: "tool1",
365 description: "Test tool",
366 parameters: map[string]any{
367 "value": map[string]any{"type": "string"},
368 },
369 required: []string{"value"},
370 executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
371 var input map[string]any
372 err := json.Unmarshal([]byte(call.Input), &input)
373 require.NoError(t, err)
374 require.Equal(t, "value", input["value"])
375 return tools.ToolResponse{Content: "result1", IsError: false}, nil
376 },
377 }
378
379 callCount := 0
380 model := &mockLanguageModel{
381 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
382 callCount++
383 switch callCount {
384 case 1:
385 // First call - return tool call with FinishReasonToolCalls
386 return &Response{
387 Content: []Content{
388 ToolCallContent{
389 ToolCallID: "call-1",
390 ToolName: "tool1",
391 Input: `{"value":"value"}`,
392 },
393 },
394 Usage: Usage{
395 InputTokens: 10,
396 OutputTokens: 5,
397 TotalTokens: 15,
398 },
399 FinishReason: FinishReasonToolCalls, // This triggers multi-step
400 }, nil
401 case 2:
402 // Second call - return final text
403 return &Response{
404 Content: []Content{
405 TextContent{Text: "Hello, world!"},
406 },
407 Usage: Usage{
408 InputTokens: 3,
409 OutputTokens: 10,
410 TotalTokens: 13,
411 },
412 FinishReason: FinishReasonStop,
413 }, nil
414 default:
415 t.Fatalf("Unexpected call count: %d", callCount)
416 return nil, nil
417 }
418 },
419 }
420
421 agent := NewAgent(model, WithTools(tool1))
422 result, err := agent.Generate(context.Background(), AgentCall{
423 Prompt: "test-input",
424 })
425
426 require.NoError(t, err)
427 require.NotNil(t, result)
428 require.Len(t, result.Steps, 2)
429
430 // Check total usage sums both steps
431 require.Equal(t, int64(13), result.TotalUsage.InputTokens) // 10 + 3
432 require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
433 require.Equal(t, int64(28), result.TotalUsage.TotalTokens) // 15 + 13
434
435 // Final response should be from last step
436 require.Len(t, result.Response.Content, 1)
437 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
438 require.True(t, ok)
439 require.Equal(t, "Hello, world!", textContent.Text)
440
441 // result.toolCalls should be empty (from last step)
442 var toolCalls []ToolCallContent
443 for _, content := range result.Response.Content {
444 if _, ok := AsContentType[ToolCallContent](content); ok {
445 toolCalls = append(toolCalls, content.(ToolCallContent))
446 }
447 }
448 require.Len(t, toolCalls, 0)
449
450 // result.toolResults should be empty (from last step)
451 var toolResults []ToolResultContent
452 for _, content := range result.Response.Content {
453 if _, ok := AsContentType[ToolResultContent](content); ok {
454 toolResults = append(toolResults, content.(ToolResultContent))
455 }
456 }
457 require.Len(t, toolResults, 0)
458}
459
460// Test basic text generation
461func TestAgent_Generate_BasicText(t *testing.T) {
462 t.Parallel()
463
464 model := &mockLanguageModel{
465 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
466 return &Response{
467 Content: []Content{
468 TextContent{Text: "Hello, world!"},
469 },
470 Usage: Usage{
471 InputTokens: 3,
472 OutputTokens: 10,
473 TotalTokens: 13,
474 },
475 FinishReason: FinishReasonStop,
476 }, nil
477 },
478 }
479
480 agent := NewAgent(model)
481 result, err := agent.Generate(context.Background(), AgentCall{
482 Prompt: "test prompt",
483 })
484
485 require.NoError(t, err)
486 require.NotNil(t, result)
487 require.Len(t, result.Steps, 1)
488
489 // Check final response
490 require.Len(t, result.Response.Content, 1)
491 textContent, ok := AsContentType[TextContent](result.Response.Content[0])
492 require.True(t, ok)
493 require.Equal(t, "Hello, world!", textContent.Text)
494
495 // Check usage
496 require.Equal(t, int64(3), result.Response.Usage.InputTokens)
497 require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
498 require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
499
500 // Check total usage
501 require.Equal(t, int64(3), result.TotalUsage.InputTokens)
502 require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
503 require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
504}
505
506// Test empty prompt error
507func TestAgent_Generate_EmptyPrompt(t *testing.T) {
508 t.Parallel()
509
510 model := &mockLanguageModel{}
511 agent := NewAgent(model)
512
513 result, err := agent.Generate(context.Background(), AgentCall{
514 Prompt: "", // Empty prompt should cause error
515 })
516
517 require.Error(t, err)
518 require.Nil(t, result)
519 require.Contains(t, err.Error(), "Prompt can't be empty")
520}
521
522// Test with system prompt
523func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
524 t.Parallel()
525
526 model := &mockLanguageModel{
527 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
528 // Verify system message is included
529 require.Len(t, call.Prompt, 2) // system + user
530 require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
531 require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
532
533 systemPart, ok := call.Prompt[0].Content[0].(TextPart)
534 require.True(t, ok)
535 require.Equal(t, "You are a helpful assistant", systemPart.Text)
536
537 return &Response{
538 Content: []Content{
539 TextContent{Text: "Hello, world!"},
540 },
541 Usage: Usage{
542 InputTokens: 3,
543 OutputTokens: 10,
544 TotalTokens: 13,
545 },
546 FinishReason: FinishReasonStop,
547 }, nil
548 },
549 }
550
551 agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
552 result, err := agent.Generate(context.Background(), AgentCall{
553 Prompt: "test prompt",
554 })
555
556 require.NoError(t, err)
557 require.NotNil(t, result)
558}
559
560// Test options.headers
561func TestAgent_Generate_OptionsHeaders(t *testing.T) {
562 t.Parallel()
563
564 model := &mockLanguageModel{
565 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
566 // Verify headers are passed
567 require.Equal(t, map[string]string{
568 "custom-request-header": "request-header-value",
569 }, call.Headers)
570
571 return &Response{
572 Content: []Content{
573 TextContent{Text: "Hello, world!"},
574 },
575 Usage: Usage{
576 InputTokens: 3,
577 OutputTokens: 10,
578 TotalTokens: 13,
579 },
580 FinishReason: FinishReasonStop,
581 }, nil
582 },
583 }
584
585 agent := NewAgent(model)
586 result, err := agent.Generate(context.Background(), AgentCall{
587 Prompt: "test-input",
588 Headers: map[string]string{"custom-request-header": "request-header-value"},
589 })
590
591 require.NoError(t, err)
592 require.NotNil(t, result)
593 require.Equal(t, "Hello, world!", result.Response.Content.Text())
594}
595
596// Test options.activeTools filtering
597func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
598 t.Parallel()
599
600 tool1 := &mockTool{
601 name: "tool1",
602 description: "Test tool 1",
603 parameters: map[string]any{
604 "value": map[string]any{"type": "string"},
605 },
606 required: []string{"value"},
607 }
608
609 tool2 := &mockTool{
610 name: "tool2",
611 description: "Test tool 2",
612 parameters: map[string]any{
613 "value": map[string]any{"type": "string"},
614 },
615 required: []string{"value"},
616 }
617
618 model := &mockLanguageModel{
619 generateFunc: func(ctx context.Context, call Call) (*Response, error) {
620 // Verify only tool1 is available
621 require.Len(t, call.Tools, 1)
622 functionTool, ok := call.Tools[0].(FunctionTool)
623 require.True(t, ok)
624 require.Equal(t, "tool1", functionTool.Name)
625
626 return &Response{
627 Content: []Content{
628 TextContent{Text: "Hello, world!"},
629 },
630 Usage: Usage{
631 InputTokens: 3,
632 OutputTokens: 10,
633 TotalTokens: 13,
634 },
635 FinishReason: FinishReasonStop,
636 }, nil
637 },
638 }
639
640 agent := NewAgent(model, WithTools(tool1, tool2))
641 result, err := agent.Generate(context.Background(), AgentCall{
642 Prompt: "test-input",
643 ActiveTools: []string{"tool1"}, // Only tool1 should be active
644 })
645
646 require.NoError(t, err)
647 require.NotNil(t, result)
648}