agent_test.go

  1package ai
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"testing"
  7
  8	"github.com/charmbracelet/crush/internal/llm/tools"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// Mock tool for testing
 13type mockTool struct {
 14	name        string
 15	description string
 16	parameters  map[string]any
 17	required    []string
 18	executeFunc func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error)
 19}
 20
 21func (m *mockTool) Info() tools.ToolInfo {
 22	return tools.ToolInfo{
 23		Name:        m.name,
 24		Description: m.description,
 25		Parameters:  m.parameters,
 26		Required:    m.required,
 27	}
 28}
 29
 30func (m *mockTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
 31	if m.executeFunc != nil {
 32		return m.executeFunc(ctx, call)
 33	}
 34	return tools.ToolResponse{Content: "mock result", IsError: false}, nil
 35}
 36
 37// Mock language model for testing
 38type mockLanguageModel struct {
 39	generateFunc func(ctx context.Context, call Call) (*Response, error)
 40}
 41
 42func (m *mockLanguageModel) Generate(ctx context.Context, call Call) (*Response, error) {
 43	if m.generateFunc != nil {
 44		return m.generateFunc(ctx, call)
 45	}
 46	return &Response{
 47		Content: []Content{
 48			TextContent{Text: "Hello, world!"},
 49		},
 50		Usage: Usage{
 51			InputTokens:  3,
 52			OutputTokens: 10,
 53			TotalTokens:  13,
 54		},
 55		FinishReason: FinishReasonStop,
 56	}, nil
 57}
 58
 59func (m *mockLanguageModel) Stream(ctx context.Context, call Call) (StreamResponse, error) {
 60	panic("not implemented")
 61}
 62
 63func (m *mockLanguageModel) Provider() string {
 64	return "mock-provider"
 65}
 66
 67func (m *mockLanguageModel) Model() string {
 68	return "mock-model"
 69}
 70
 71// Test result.content - comprehensive content types (matches TS test)
 72func TestAgent_Generate_ResultContent_AllTypes(t *testing.T) {
 73	t.Parallel()
 74
 75	tool1 := &mockTool{
 76		name:        "tool1",
 77		description: "Test tool",
 78		parameters: map[string]any{
 79			"value": map[string]any{"type": "string"},
 80		},
 81		required: []string{"value"},
 82		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
 83			var input map[string]any
 84			err := json.Unmarshal([]byte(call.Input), &input)
 85			require.NoError(t, err)
 86			require.Equal(t, "value", input["value"])
 87			return tools.ToolResponse{Content: "result1", IsError: false}, nil
 88		},
 89	}
 90
 91	model := &mockLanguageModel{
 92		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
 93			return &Response{
 94				Content: []Content{
 95					TextContent{Text: "Hello, world!"},
 96					SourceContent{
 97						ID:         "123",
 98						URL:        "https://example.com",
 99						Title:      "Example",
100						SourceType: SourceTypeURL,
101						ProviderMetadata: ProviderMetadata{
102							"provider": map[string]any{"custom": "value"},
103						},
104					},
105					FileContent{
106						Data:      []byte{1, 2, 3},
107						MediaType: "image/png",
108					},
109					ReasoningContent{
110						Text: "I will open the conversation with witty banter.",
111					},
112					ToolCallContent{
113						ToolCallID: "call-1",
114						ToolName:   "tool1",
115						Input:      `{"value":"value"}`,
116					},
117					TextContent{Text: "More text"},
118				},
119				Usage: Usage{
120					InputTokens:  3,
121					OutputTokens: 10,
122					TotalTokens:  13,
123				},
124				FinishReason: FinishReasonStop, // Note: FinishReasonStop, not ToolCalls
125			}, nil
126		},
127	}
128
129	agent := NewAgent(model, WithTools(tool1))
130	result, err := agent.Generate(context.Background(), AgentCall{
131		Prompt: "prompt",
132	})
133
134	require.NoError(t, err)
135	require.NotNil(t, result)
136	require.Len(t, result.Steps, 1) // Single step like TypeScript
137
138	// Check final response content includes tool result
139	require.Len(t, result.Response.Content, 7) // original 6 + 1 tool result
140
141	// Verify each content type in order
142	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
143	require.True(t, ok)
144	require.Equal(t, "Hello, world!", textContent.Text)
145
146	sourceContent, ok := AsContentType[SourceContent](result.Response.Content[1])
147	require.True(t, ok)
148	require.Equal(t, "123", sourceContent.ID)
149
150	fileContent, ok := AsContentType[FileContent](result.Response.Content[2])
151	require.True(t, ok)
152	require.Equal(t, []byte{1, 2, 3}, fileContent.Data)
153
154	reasoningContent, ok := AsContentType[ReasoningContent](result.Response.Content[3])
155	require.True(t, ok)
156	require.Equal(t, "I will open the conversation with witty banter.", reasoningContent.Text)
157
158	toolCallContent, ok := AsContentType[ToolCallContent](result.Response.Content[4])
159	require.True(t, ok)
160	require.Equal(t, "call-1", toolCallContent.ToolCallID)
161
162	moreTextContent, ok := AsContentType[TextContent](result.Response.Content[5])
163	require.True(t, ok)
164	require.Equal(t, "More text", moreTextContent.Text)
165
166	// Tool result should be appended
167	toolResultContent, ok := AsContentType[ToolResultContent](result.Response.Content[6])
168	require.True(t, ok)
169	require.Equal(t, "call-1", toolResultContent.ToolCallID)
170	require.Equal(t, "tool1", toolResultContent.ToolName)
171}
172
173// Test result.text extraction
174func TestAgent_Generate_ResultText(t *testing.T) {
175	t.Parallel()
176
177	model := &mockLanguageModel{
178		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
179			return &Response{
180				Content: []Content{
181					TextContent{Text: "Hello, world!"},
182				},
183				Usage: Usage{
184					InputTokens:  3,
185					OutputTokens: 10,
186					TotalTokens:  13,
187				},
188				FinishReason: FinishReasonStop,
189			}, nil
190		},
191	}
192
193	agent := NewAgent(model)
194	result, err := agent.Generate(context.Background(), AgentCall{
195		Prompt: "prompt",
196	})
197
198	require.NoError(t, err)
199	require.NotNil(t, result)
200
201	// Test text extraction from content
202	text := result.Response.Content.Text()
203	require.Equal(t, "Hello, world!", text)
204}
205
206// Test result.toolCalls extraction (matches TS test exactly)
207func TestAgent_Generate_ResultToolCalls(t *testing.T) {
208	t.Parallel()
209
210	tool1 := &mockTool{
211		name:        "tool1",
212		description: "Test tool 1",
213		parameters: map[string]any{
214			"value": map[string]any{"type": "string"},
215		},
216		required: []string{"value"},
217	}
218
219	tool2 := &mockTool{
220		name:        "tool2",
221		description: "Test tool 2",
222		parameters: map[string]any{
223			"somethingElse": map[string]any{"type": "string"},
224		},
225		required: []string{"somethingElse"},
226	}
227
228	model := &mockLanguageModel{
229		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
230			// Verify tools are passed correctly
231			require.Len(t, call.Tools, 2)
232			require.Equal(t, ToolChoiceAuto, *call.ToolChoice) // Should be auto, not required
233
234			// Verify prompt structure
235			require.Len(t, call.Prompt, 1)
236			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
237
238			return &Response{
239				Content: []Content{
240					ToolCallContent{
241						ToolCallID: "call-1",
242						ToolName:   "tool1",
243						Input:      `{"value":"value"}`,
244					},
245				},
246				Usage: Usage{
247					InputTokens:  3,
248					OutputTokens: 10,
249					TotalTokens:  13,
250				},
251				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
252			}, nil
253		},
254	}
255
256	agent := NewAgent(model, WithTools(tool1, tool2))
257	result, err := agent.Generate(context.Background(), AgentCall{
258		Prompt: "test-input",
259	})
260
261	require.NoError(t, err)
262	require.NotNil(t, result)
263	require.Len(t, result.Steps, 1) // Single step
264
265	// Extract tool calls from final response (should be empty since tools don't execute)
266	var toolCalls []ToolCallContent
267	for _, content := range result.Response.Content {
268		if toolCall, ok := AsContentType[ToolCallContent](content); ok {
269			toolCalls = append(toolCalls, toolCall)
270		}
271	}
272
273	require.Len(t, toolCalls, 1)
274	require.Equal(t, "call-1", toolCalls[0].ToolCallID)
275	require.Equal(t, "tool1", toolCalls[0].ToolName)
276
277	// Parse and verify input
278	var input map[string]any
279	err = json.Unmarshal([]byte(toolCalls[0].Input), &input)
280	require.NoError(t, err)
281	require.Equal(t, "value", input["value"])
282}
283
284// Test result.toolResults extraction (matches TS test exactly)
285func TestAgent_Generate_ResultToolResults(t *testing.T) {
286	t.Parallel()
287
288	tool1 := &mockTool{
289		name:        "tool1",
290		description: "Test tool",
291		parameters: map[string]any{
292			"value": map[string]any{"type": "string"},
293		},
294		required: []string{"value"},
295		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
296			var input map[string]any
297			err := json.Unmarshal([]byte(call.Input), &input)
298			require.NoError(t, err)
299			require.Equal(t, "value", input["value"])
300			return tools.ToolResponse{Content: "result1", IsError: false}, nil
301		},
302	}
303
304	model := &mockLanguageModel{
305		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
306			// Verify tools and tool choice
307			require.Len(t, call.Tools, 1)
308			require.Equal(t, ToolChoiceAuto, *call.ToolChoice)
309
310			// Verify prompt
311			require.Len(t, call.Prompt, 1)
312			require.Equal(t, MessageRoleUser, call.Prompt[0].Role)
313
314			return &Response{
315				Content: []Content{
316					ToolCallContent{
317						ToolCallID: "call-1",
318						ToolName:   "tool1",
319						Input:      `{"value":"value"}`,
320					},
321				},
322				Usage: Usage{
323					InputTokens:  3,
324					OutputTokens: 10,
325					TotalTokens:  13,
326				},
327				FinishReason: FinishReasonStop, // Note: Stop, not ToolCalls
328			}, nil
329		},
330	}
331
332	agent := NewAgent(model, WithTools(tool1))
333	result, err := agent.Generate(context.Background(), AgentCall{
334		Prompt: "test-input",
335	})
336
337	require.NoError(t, err)
338	require.NotNil(t, result)
339	require.Len(t, result.Steps, 1) // Single step
340
341	// Extract tool results from final response
342	var toolResults []ToolResultContent
343	for _, content := range result.Response.Content {
344		if toolResult, ok := AsContentType[ToolResultContent](content); ok {
345			toolResults = append(toolResults, toolResult)
346		}
347	}
348
349	require.Len(t, toolResults, 1)
350	require.Equal(t, "call-1", toolResults[0].ToolCallID)
351	require.Equal(t, "tool1", toolResults[0].ToolName)
352
353	// Verify result content
354	textResult, ok := toolResults[0].Result.(ToolResultOutputContentText)
355	require.True(t, ok)
356	require.Equal(t, "result1", textResult.Text)
357}
358
359// Test multi-step scenario (matches TS "2 steps: initial, tool-result" test)
360func TestAgent_Generate_MultipleSteps(t *testing.T) {
361	t.Parallel()
362
363	tool1 := &mockTool{
364		name:        "tool1",
365		description: "Test tool",
366		parameters: map[string]any{
367			"value": map[string]any{"type": "string"},
368		},
369		required: []string{"value"},
370		executeFunc: func(ctx context.Context, call tools.ToolCall) (tools.ToolResponse, error) {
371			var input map[string]any
372			err := json.Unmarshal([]byte(call.Input), &input)
373			require.NoError(t, err)
374			require.Equal(t, "value", input["value"])
375			return tools.ToolResponse{Content: "result1", IsError: false}, nil
376		},
377	}
378
379	callCount := 0
380	model := &mockLanguageModel{
381		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
382			callCount++
383			switch callCount {
384			case 1:
385				// First call - return tool call with FinishReasonToolCalls
386				return &Response{
387					Content: []Content{
388						ToolCallContent{
389							ToolCallID: "call-1",
390							ToolName:   "tool1",
391							Input:      `{"value":"value"}`,
392						},
393					},
394					Usage: Usage{
395						InputTokens:  10,
396						OutputTokens: 5,
397						TotalTokens:  15,
398					},
399					FinishReason: FinishReasonToolCalls, // This triggers multi-step
400				}, nil
401			case 2:
402				// Second call - return final text
403				return &Response{
404					Content: []Content{
405						TextContent{Text: "Hello, world!"},
406					},
407					Usage: Usage{
408						InputTokens:  3,
409						OutputTokens: 10,
410						TotalTokens:  13,
411					},
412					FinishReason: FinishReasonStop,
413				}, nil
414			default:
415				t.Fatalf("Unexpected call count: %d", callCount)
416				return nil, nil
417			}
418		},
419	}
420
421	agent := NewAgent(model, WithTools(tool1))
422	result, err := agent.Generate(context.Background(), AgentCall{
423		Prompt: "test-input",
424	})
425
426	require.NoError(t, err)
427	require.NotNil(t, result)
428	require.Len(t, result.Steps, 2)
429
430	// Check total usage sums both steps
431	require.Equal(t, int64(13), result.TotalUsage.InputTokens)  // 10 + 3
432	require.Equal(t, int64(15), result.TotalUsage.OutputTokens) // 5 + 10
433	require.Equal(t, int64(28), result.TotalUsage.TotalTokens)  // 15 + 13
434
435	// Final response should be from last step
436	require.Len(t, result.Response.Content, 1)
437	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
438	require.True(t, ok)
439	require.Equal(t, "Hello, world!", textContent.Text)
440
441	// result.toolCalls should be empty (from last step)
442	var toolCalls []ToolCallContent
443	for _, content := range result.Response.Content {
444		if _, ok := AsContentType[ToolCallContent](content); ok {
445			toolCalls = append(toolCalls, content.(ToolCallContent))
446		}
447	}
448	require.Len(t, toolCalls, 0)
449
450	// result.toolResults should be empty (from last step)
451	var toolResults []ToolResultContent
452	for _, content := range result.Response.Content {
453		if _, ok := AsContentType[ToolResultContent](content); ok {
454			toolResults = append(toolResults, content.(ToolResultContent))
455		}
456	}
457	require.Len(t, toolResults, 0)
458}
459
460// Test basic text generation
461func TestAgent_Generate_BasicText(t *testing.T) {
462	t.Parallel()
463
464	model := &mockLanguageModel{
465		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
466			return &Response{
467				Content: []Content{
468					TextContent{Text: "Hello, world!"},
469				},
470				Usage: Usage{
471					InputTokens:  3,
472					OutputTokens: 10,
473					TotalTokens:  13,
474				},
475				FinishReason: FinishReasonStop,
476			}, nil
477		},
478	}
479
480	agent := NewAgent(model)
481	result, err := agent.Generate(context.Background(), AgentCall{
482		Prompt: "test prompt",
483	})
484
485	require.NoError(t, err)
486	require.NotNil(t, result)
487	require.Len(t, result.Steps, 1)
488
489	// Check final response
490	require.Len(t, result.Response.Content, 1)
491	textContent, ok := AsContentType[TextContent](result.Response.Content[0])
492	require.True(t, ok)
493	require.Equal(t, "Hello, world!", textContent.Text)
494
495	// Check usage
496	require.Equal(t, int64(3), result.Response.Usage.InputTokens)
497	require.Equal(t, int64(10), result.Response.Usage.OutputTokens)
498	require.Equal(t, int64(13), result.Response.Usage.TotalTokens)
499
500	// Check total usage
501	require.Equal(t, int64(3), result.TotalUsage.InputTokens)
502	require.Equal(t, int64(10), result.TotalUsage.OutputTokens)
503	require.Equal(t, int64(13), result.TotalUsage.TotalTokens)
504}
505
506// Test empty prompt error
507func TestAgent_Generate_EmptyPrompt(t *testing.T) {
508	t.Parallel()
509
510	model := &mockLanguageModel{}
511	agent := NewAgent(model)
512
513	result, err := agent.Generate(context.Background(), AgentCall{
514		Prompt: "", // Empty prompt should cause error
515	})
516
517	require.Error(t, err)
518	require.Nil(t, result)
519	require.Contains(t, err.Error(), "Prompt can't be empty")
520}
521
522// Test with system prompt
523func TestAgent_Generate_WithSystemPrompt(t *testing.T) {
524	t.Parallel()
525
526	model := &mockLanguageModel{
527		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
528			// Verify system message is included
529			require.Len(t, call.Prompt, 2) // system + user
530			require.Equal(t, MessageRoleSystem, call.Prompt[0].Role)
531			require.Equal(t, MessageRoleUser, call.Prompt[1].Role)
532
533			systemPart, ok := call.Prompt[0].Content[0].(TextPart)
534			require.True(t, ok)
535			require.Equal(t, "You are a helpful assistant", systemPart.Text)
536
537			return &Response{
538				Content: []Content{
539					TextContent{Text: "Hello, world!"},
540				},
541				Usage: Usage{
542					InputTokens:  3,
543					OutputTokens: 10,
544					TotalTokens:  13,
545				},
546				FinishReason: FinishReasonStop,
547			}, nil
548		},
549	}
550
551	agent := NewAgent(model, WithSystemPrompt("You are a helpful assistant"))
552	result, err := agent.Generate(context.Background(), AgentCall{
553		Prompt: "test prompt",
554	})
555
556	require.NoError(t, err)
557	require.NotNil(t, result)
558}
559
560// Test options.headers
561func TestAgent_Generate_OptionsHeaders(t *testing.T) {
562	t.Parallel()
563
564	model := &mockLanguageModel{
565		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
566			// Verify headers are passed
567			require.Equal(t, map[string]string{
568				"custom-request-header": "request-header-value",
569			}, call.Headers)
570
571			return &Response{
572				Content: []Content{
573					TextContent{Text: "Hello, world!"},
574				},
575				Usage: Usage{
576					InputTokens:  3,
577					OutputTokens: 10,
578					TotalTokens:  13,
579				},
580				FinishReason: FinishReasonStop,
581			}, nil
582		},
583	}
584
585	agent := NewAgent(model)
586	result, err := agent.Generate(context.Background(), AgentCall{
587		Prompt:  "test-input",
588		Headers: map[string]string{"custom-request-header": "request-header-value"},
589	})
590
591	require.NoError(t, err)
592	require.NotNil(t, result)
593	require.Equal(t, "Hello, world!", result.Response.Content.Text())
594}
595
596// Test options.activeTools filtering
597func TestAgent_Generate_OptionsActiveTools(t *testing.T) {
598	t.Parallel()
599
600	tool1 := &mockTool{
601		name:        "tool1",
602		description: "Test tool 1",
603		parameters: map[string]any{
604			"value": map[string]any{"type": "string"},
605		},
606		required: []string{"value"},
607	}
608
609	tool2 := &mockTool{
610		name:        "tool2",
611		description: "Test tool 2",
612		parameters: map[string]any{
613			"value": map[string]any{"type": "string"},
614		},
615		required: []string{"value"},
616	}
617
618	model := &mockLanguageModel{
619		generateFunc: func(ctx context.Context, call Call) (*Response, error) {
620			// Verify only tool1 is available
621			require.Len(t, call.Tools, 1)
622			functionTool, ok := call.Tools[0].(FunctionTool)
623			require.True(t, ok)
624			require.Equal(t, "tool1", functionTool.Name)
625
626			return &Response{
627				Content: []Content{
628					TextContent{Text: "Hello, world!"},
629				},
630				Usage: Usage{
631					InputTokens:  3,
632					OutputTokens: 10,
633					TotalTokens:  13,
634				},
635				FinishReason: FinishReasonStop,
636			}, nil
637		},
638	}
639
640	agent := NewAgent(model, WithTools(tool1, tool2))
641	result, err := agent.Generate(context.Background(), AgentCall{
642		Prompt:      "test-input",
643		ActiveTools: []string{"tool1"}, // Only tool1 should be active
644	})
645
646	require.NoError(t, err)
647	require.NotNil(t, result)
648}