1package agent
2
3import (
4 "encoding/json"
5 "fmt"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/agent/tools"
10 "github.com/charmbracelet/crush/internal/message"
11 "github.com/charmbracelet/fantasy/ai"
12 "github.com/stretchr/testify/assert"
13 "github.com/stretchr/testify/require"
14
15 _ "github.com/joho/godotenv/autoload"
16)
17
18var modelPairs = []modelPair{
19 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
20 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
21 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
22}
23
24func getModels(t *testing.T, pair modelPair) (ai.LanguageModel, ai.LanguageModel) {
25 r := newRecorder(t)
26 large, err := pair.largeModel(t, r)
27 require.NoError(t, err)
28 small, err := pair.smallModel(t, r)
29 require.NoError(t, err)
30 return large, small
31}
32
33func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
34 large, small := getModels(t, pair)
35 env := testEnv(t)
36
37 createSimpleGoProject(t, env.workingDir)
38 agent, err := coderAgent(env, large, small)
39 require.NoError(t, err)
40 return agent, env
41}
42
43func TestCoderAgent(t *testing.T) {
44 for _, pair := range modelPairs {
45 t.Run(pair.name, func(t *testing.T) {
46 t.Run("simple test", func(t *testing.T) {
47 agent, env := setupAgent(t, pair)
48
49 session, err := env.sessions.Create(t.Context(), "New Session")
50 require.NoError(t, err)
51
52 res, err := agent.Run(t.Context(), SessionAgentCall{
53 Prompt: "Hello",
54 SessionID: session.ID,
55 MaxOutputTokens: 10000,
56 })
57 require.NoError(t, err)
58 assert.NotNil(t, res)
59
60 msgs, err := env.messages.List(t.Context(), session.ID)
61 require.NoError(t, err)
62 // Should have the agent and user message
63 assert.Equal(t, len(msgs), 2)
64 })
65 t.Run("read a file", func(t *testing.T) {
66 agent, env := setupAgent(t, pair)
67
68 session, err := env.sessions.Create(t.Context(), "New Session")
69 require.NoError(t, err)
70 res, err := agent.Run(t.Context(), SessionAgentCall{
71 Prompt: "Read the go mod",
72 SessionID: session.ID,
73 MaxOutputTokens: 10000,
74 })
75
76 require.NoError(t, err)
77 assert.NotNil(t, res)
78
79 msgs, err := env.messages.List(t.Context(), session.ID)
80 require.NoError(t, err)
81 foundFile := false
82 var tcID string
83 out:
84 for _, msg := range msgs {
85 data, _ := json.Marshal(msg)
86 fmt.Println(string(data))
87 if msg.Role == message.Assistant {
88 for _, tc := range msg.ToolCalls() {
89 if tc.Name == tools.ViewToolName {
90 tcID = tc.ID
91 }
92 }
93 }
94 if msg.Role == message.Tool {
95 for _, tr := range msg.ToolResults() {
96 if tr.ToolCallID == tcID {
97 if strings.Contains(tr.Content, "module example.com/testproject") {
98 foundFile = true
99 break out
100 }
101 }
102 }
103 }
104 }
105 require.True(t, foundFile)
106 })
107 })
108 }
109}