1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "strings"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/agent/tools"
10 "github.com/charmbracelet/crush/internal/message"
11 "github.com/charmbracelet/crush/internal/shell"
12 "github.com/charmbracelet/fantasy/ai"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16
17 _ "github.com/joho/godotenv/autoload"
18)
19
20var modelPairs = []modelPair{
21 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
22 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
23 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
24}
25
26func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (ai.LanguageModel, ai.LanguageModel) {
27 large, err := pair.largeModel(t, r)
28 require.NoError(t, err)
29 small, err := pair.smallModel(t, r)
30 require.NoError(t, err)
31 return large, small
32}
33
34func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
35 r := newRecorder(t)
36 large, small := getModels(t, r, pair)
37 env := testEnv(t)
38
39 createSimpleGoProject(t, env.workingDir)
40 agent, err := coderAgent(r, env, large, small)
41 shell.Reset(env.workingDir)
42 require.NoError(t, err)
43 return agent, env
44}
45
46func TestCoderAgent(t *testing.T) {
47 for _, pair := range modelPairs {
48 t.Run(pair.name, func(t *testing.T) {
49 t.Run("simple test", func(t *testing.T) {
50 agent, env := setupAgent(t, pair)
51
52 session, err := env.sessions.Create(t.Context(), "New Session")
53 require.NoError(t, err)
54
55 res, err := agent.Run(t.Context(), SessionAgentCall{
56 Prompt: "Hello",
57 SessionID: session.ID,
58 MaxOutputTokens: 10000,
59 })
60 require.NoError(t, err)
61 assert.NotNil(t, res)
62
63 msgs, err := env.messages.List(t.Context(), session.ID)
64 require.NoError(t, err)
65 // Should have the agent and user message
66 assert.Equal(t, len(msgs), 2)
67 })
68 t.Run("read a file", func(t *testing.T) {
69 agent, env := setupAgent(t, pair)
70
71 session, err := env.sessions.Create(t.Context(), "New Session")
72 require.NoError(t, err)
73 res, err := agent.Run(t.Context(), SessionAgentCall{
74 Prompt: "Read the go mod",
75 SessionID: session.ID,
76 MaxOutputTokens: 10000,
77 })
78
79 require.NoError(t, err)
80 assert.NotNil(t, res)
81
82 msgs, err := env.messages.List(t.Context(), session.ID)
83 require.NoError(t, err)
84 foundFile := false
85 var tcID string
86 out:
87 for _, msg := range msgs {
88 if msg.Role == message.Assistant {
89 for _, tc := range msg.ToolCalls() {
90 if tc.Name == tools.ViewToolName {
91 tcID = tc.ID
92 }
93 }
94 }
95 if msg.Role == message.Tool {
96 for _, tr := range msg.ToolResults() {
97 if tr.ToolCallID == tcID {
98 if strings.Contains(tr.Content, "module example.com/testproject") {
99 foundFile = true
100 break out
101 }
102 }
103 }
104 }
105 }
106 require.True(t, foundFile)
107 })
108 t.Run("update a file", func(t *testing.T) {
109 agent, env := setupAgent(t, pair)
110
111 session, err := env.sessions.Create(t.Context(), "New Session")
112 require.NoError(t, err)
113
114 res, err := agent.Run(t.Context(), SessionAgentCall{
115 Prompt: "update the main.go file by changing the print to say hello from crush",
116 SessionID: session.ID,
117 MaxOutputTokens: 10000,
118 })
119 require.NoError(t, err)
120 assert.NotNil(t, res)
121
122 msgs, err := env.messages.List(t.Context(), session.ID)
123 require.NoError(t, err)
124
125 foundRead := false
126 foundWrite := false
127 var readTCID, writeTCID string
128
129 for _, msg := range msgs {
130 if msg.Role == message.Assistant {
131 for _, tc := range msg.ToolCalls() {
132 if tc.Name == tools.ViewToolName {
133 readTCID = tc.ID
134 }
135 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
136 writeTCID = tc.ID
137 }
138 }
139 }
140 if msg.Role == message.Tool {
141 for _, tr := range msg.ToolResults() {
142 if tr.ToolCallID == readTCID {
143 foundRead = true
144 }
145 if tr.ToolCallID == writeTCID {
146 foundWrite = true
147 }
148 }
149 }
150 }
151
152 require.True(t, foundRead, "Expected to find a read operation")
153 require.True(t, foundWrite, "Expected to find a write operation")
154
155 mainGoPath := filepath.Join(env.workingDir, "main.go")
156 content, err := os.ReadFile(mainGoPath)
157 require.NoError(t, err)
158 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
159 })
160 t.Run("bash tool", func(t *testing.T) {
161 agent, env := setupAgent(t, pair)
162
163 session, err := env.sessions.Create(t.Context(), "New Session")
164 require.NoError(t, err)
165
166 res, err := agent.Run(t.Context(), SessionAgentCall{
167 Prompt: "use bash to create a file named test.txt with content 'hello bash'",
168 SessionID: session.ID,
169 MaxOutputTokens: 10000,
170 })
171 require.NoError(t, err)
172 assert.NotNil(t, res)
173
174 msgs, err := env.messages.List(t.Context(), session.ID)
175 require.NoError(t, err)
176
177 foundBash := false
178 var bashTCID string
179
180 for _, msg := range msgs {
181 if msg.Role == message.Assistant {
182 for _, tc := range msg.ToolCalls() {
183 if tc.Name == tools.BashToolName {
184 bashTCID = tc.ID
185 }
186 }
187 }
188 if msg.Role == message.Tool {
189 for _, tr := range msg.ToolResults() {
190 if tr.ToolCallID == bashTCID {
191 foundBash = true
192 }
193 }
194 }
195 }
196
197 require.True(t, foundBash, "Expected to find a bash operation")
198
199 testFilePath := filepath.Join(env.workingDir, "test.txt")
200 content, err := os.ReadFile(testFilePath)
201 require.NoError(t, err)
202 require.Contains(t, string(content), "hello bash")
203 })
204 t.Run("download tool", func(t *testing.T) {
205 agent, env := setupAgent(t, pair)
206
207 session, err := env.sessions.Create(t.Context(), "New Session")
208 require.NoError(t, err)
209
210 res, err := agent.Run(t.Context(), SessionAgentCall{
211 Prompt: "download the file from https://httpbin.org/robots.txt and save it as robots.txt",
212 SessionID: session.ID,
213 MaxOutputTokens: 10000,
214 })
215 require.NoError(t, err)
216 assert.NotNil(t, res)
217
218 msgs, err := env.messages.List(t.Context(), session.ID)
219 require.NoError(t, err)
220
221 foundDownload := false
222 var downloadTCID string
223
224 for _, msg := range msgs {
225 if msg.Role == message.Assistant {
226 for _, tc := range msg.ToolCalls() {
227 if tc.Name == tools.DownloadToolName {
228 downloadTCID = tc.ID
229 }
230 }
231 }
232 if msg.Role == message.Tool {
233 for _, tr := range msg.ToolResults() {
234 if tr.ToolCallID == downloadTCID {
235 foundDownload = true
236 }
237 }
238 }
239 }
240
241 require.True(t, foundDownload, "Expected to find a download operation")
242
243 robotsPath := filepath.Join(env.workingDir, "robots.txt")
244 _, err = os.Stat(robotsPath)
245 require.NoError(t, err, "Expected robots.txt file to exist")
246 })
247 t.Run("fetch tool", func(t *testing.T) {
248 agent, env := setupAgent(t, pair)
249
250 session, err := env.sessions.Create(t.Context(), "New Session")
251 require.NoError(t, err)
252
253 res, err := agent.Run(t.Context(), SessionAgentCall{
254 Prompt: "fetch the content from https://httpbin.org/html and tell me if it contains the word 'Herman'",
255 SessionID: session.ID,
256 MaxOutputTokens: 10000,
257 })
258 require.NoError(t, err)
259 assert.NotNil(t, res)
260
261 msgs, err := env.messages.List(t.Context(), session.ID)
262 require.NoError(t, err)
263
264 foundFetch := false
265 var fetchTCID string
266
267 for _, msg := range msgs {
268 if msg.Role == message.Assistant {
269 for _, tc := range msg.ToolCalls() {
270 if tc.Name == tools.FetchToolName {
271 fetchTCID = tc.ID
272 }
273 }
274 }
275 if msg.Role == message.Tool {
276 for _, tr := range msg.ToolResults() {
277 if tr.ToolCallID == fetchTCID {
278 foundFetch = true
279 }
280 }
281 }
282 }
283
284 require.True(t, foundFetch, "Expected to find a fetch operation")
285 })
286 t.Run("glob tool", func(t *testing.T) {
287 agent, env := setupAgent(t, pair)
288
289 session, err := env.sessions.Create(t.Context(), "New Session")
290 require.NoError(t, err)
291
292 res, err := agent.Run(t.Context(), SessionAgentCall{
293 Prompt: "use glob to find all .go files in the current directory",
294 SessionID: session.ID,
295 MaxOutputTokens: 10000,
296 })
297 require.NoError(t, err)
298 assert.NotNil(t, res)
299
300 msgs, err := env.messages.List(t.Context(), session.ID)
301 require.NoError(t, err)
302
303 foundGlob := false
304 var globTCID string
305
306 for _, msg := range msgs {
307 if msg.Role == message.Assistant {
308 for _, tc := range msg.ToolCalls() {
309 if tc.Name == tools.GlobToolName {
310 globTCID = tc.ID
311 }
312 }
313 }
314 if msg.Role == message.Tool {
315 for _, tr := range msg.ToolResults() {
316 if tr.ToolCallID == globTCID {
317 foundGlob = true
318 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
319 }
320 }
321 }
322 }
323
324 require.True(t, foundGlob, "Expected to find a glob operation")
325 })
326 t.Run("grep tool", func(t *testing.T) {
327 agent, env := setupAgent(t, pair)
328
329 session, err := env.sessions.Create(t.Context(), "New Session")
330 require.NoError(t, err)
331
332 res, err := agent.Run(t.Context(), SessionAgentCall{
333 Prompt: "use grep to search for the word 'package' in go files",
334 SessionID: session.ID,
335 MaxOutputTokens: 10000,
336 })
337 require.NoError(t, err)
338 assert.NotNil(t, res)
339
340 msgs, err := env.messages.List(t.Context(), session.ID)
341 require.NoError(t, err)
342
343 foundGrep := false
344 var grepTCID string
345
346 for _, msg := range msgs {
347 if msg.Role == message.Assistant {
348 for _, tc := range msg.ToolCalls() {
349 if tc.Name == tools.GrepToolName {
350 grepTCID = tc.ID
351 }
352 }
353 }
354 if msg.Role == message.Tool {
355 for _, tr := range msg.ToolResults() {
356 if tr.ToolCallID == grepTCID {
357 foundGrep = true
358 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
359 }
360 }
361 }
362 }
363
364 require.True(t, foundGrep, "Expected to find a grep operation")
365 })
366 t.Run("ls tool", func(t *testing.T) {
367 agent, env := setupAgent(t, pair)
368
369 session, err := env.sessions.Create(t.Context(), "New Session")
370 require.NoError(t, err)
371
372 res, err := agent.Run(t.Context(), SessionAgentCall{
373 Prompt: "use ls to list the files in the current directory",
374 SessionID: session.ID,
375 MaxOutputTokens: 10000,
376 })
377 require.NoError(t, err)
378 assert.NotNil(t, res)
379
380 msgs, err := env.messages.List(t.Context(), session.ID)
381 require.NoError(t, err)
382
383 foundLS := false
384 var lsTCID string
385
386 for _, msg := range msgs {
387 if msg.Role == message.Assistant {
388 for _, tc := range msg.ToolCalls() {
389 if tc.Name == tools.LSToolName {
390 lsTCID = tc.ID
391 }
392 }
393 }
394 if msg.Role == message.Tool {
395 for _, tr := range msg.ToolResults() {
396 if tr.ToolCallID == lsTCID {
397 foundLS = true
398 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
399 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
400 }
401 }
402 }
403 }
404
405 require.True(t, foundLS, "Expected to find an ls operation")
406 })
407 t.Run("multiedit tool", func(t *testing.T) {
408 agent, env := setupAgent(t, pair)
409
410 session, err := env.sessions.Create(t.Context(), "New Session")
411 require.NoError(t, err)
412
413 res, err := agent.Run(t.Context(), SessionAgentCall{
414 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
415 SessionID: session.ID,
416 MaxOutputTokens: 10000,
417 })
418 require.NoError(t, err)
419 assert.NotNil(t, res)
420
421 msgs, err := env.messages.List(t.Context(), session.ID)
422 require.NoError(t, err)
423
424 foundMultiEdit := false
425 var multiEditTCID string
426
427 for _, msg := range msgs {
428 if msg.Role == message.Assistant {
429 for _, tc := range msg.ToolCalls() {
430 if tc.Name == tools.MultiEditToolName {
431 multiEditTCID = tc.ID
432 }
433 }
434 }
435 if msg.Role == message.Tool {
436 for _, tr := range msg.ToolResults() {
437 if tr.ToolCallID == multiEditTCID {
438 foundMultiEdit = true
439 }
440 }
441 }
442 }
443
444 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
445
446 mainGoPath := filepath.Join(env.workingDir, "main.go")
447 content, err := os.ReadFile(mainGoPath)
448 require.NoError(t, err)
449 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
450 })
451 t.Run("sourcegraph tool", func(t *testing.T) {
452 agent, env := setupAgent(t, pair)
453
454 session, err := env.sessions.Create(t.Context(), "New Session")
455 require.NoError(t, err)
456
457 res, err := agent.Run(t.Context(), SessionAgentCall{
458 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
459 SessionID: session.ID,
460 MaxOutputTokens: 10000,
461 })
462 require.NoError(t, err)
463 assert.NotNil(t, res)
464
465 msgs, err := env.messages.List(t.Context(), session.ID)
466 require.NoError(t, err)
467
468 foundSourcegraph := false
469 var sourcegraphTCID string
470
471 for _, msg := range msgs {
472 if msg.Role == message.Assistant {
473 for _, tc := range msg.ToolCalls() {
474 if tc.Name == tools.SourcegraphToolName {
475 sourcegraphTCID = tc.ID
476 }
477 }
478 }
479 if msg.Role == message.Tool {
480 for _, tr := range msg.ToolResults() {
481 if tr.ToolCallID == sourcegraphTCID {
482 foundSourcegraph = true
483 }
484 }
485 }
486 }
487
488 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
489 })
490 t.Run("write tool", func(t *testing.T) {
491 agent, env := setupAgent(t, pair)
492
493 session, err := env.sessions.Create(t.Context(), "New Session")
494 require.NoError(t, err)
495
496 res, err := agent.Run(t.Context(), SessionAgentCall{
497 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
498 SessionID: session.ID,
499 MaxOutputTokens: 10000,
500 })
501 require.NoError(t, err)
502 assert.NotNil(t, res)
503
504 msgs, err := env.messages.List(t.Context(), session.ID)
505 require.NoError(t, err)
506
507 foundWrite := false
508 var writeTCID string
509
510 for _, msg := range msgs {
511 if msg.Role == message.Assistant {
512 for _, tc := range msg.ToolCalls() {
513 if tc.Name == tools.WriteToolName {
514 writeTCID = tc.ID
515 }
516 }
517 }
518 if msg.Role == message.Tool {
519 for _, tr := range msg.ToolResults() {
520 if tr.ToolCallID == writeTCID {
521 foundWrite = true
522 }
523 }
524 }
525 }
526
527 require.True(t, foundWrite, "Expected to find a write operation")
528
529 configPath := filepath.Join(env.workingDir, "config.json")
530 content, err := os.ReadFile(configPath)
531 require.NoError(t, err)
532 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
533 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
534 })
535 t.Run("parallel tool calls", func(t *testing.T) {
536 agent, env := setupAgent(t, pair)
537
538 session, err := env.sessions.Create(t.Context(), "New Session")
539 require.NoError(t, err)
540
541 res, err := agent.Run(t.Context(), SessionAgentCall{
542 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
543 SessionID: session.ID,
544 MaxOutputTokens: 10000,
545 })
546 require.NoError(t, err)
547 assert.NotNil(t, res)
548
549 msgs, err := env.messages.List(t.Context(), session.ID)
550 require.NoError(t, err)
551
552 var assistantMsg *message.Message
553 var toolMsgs []message.Message
554
555 for _, msg := range msgs {
556 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
557 assistantMsg = &msg
558 }
559 if msg.Role == message.Tool {
560 toolMsgs = append(toolMsgs, msg)
561 }
562 }
563
564 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
565 require.NotNil(t, toolMsgs, "Expected to find a tool message")
566
567 toolCalls := assistantMsg.ToolCalls()
568 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
569
570 foundGlob := false
571 foundLS := false
572 var globTCID, lsTCID string
573
574 for _, tc := range toolCalls {
575 if tc.Name == tools.GlobToolName {
576 foundGlob = true
577 globTCID = tc.ID
578 }
579 if tc.Name == tools.LSToolName {
580 foundLS = true
581 lsTCID = tc.ID
582 }
583 }
584
585 require.True(t, foundGlob, "Expected to find a glob tool call")
586 require.True(t, foundLS, "Expected to find an ls tool call")
587
588 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
589
590 foundGlobResult := false
591 foundLSResult := false
592
593 for _, msg := range toolMsgs {
594 for _, tr := range msg.ToolResults() {
595 if tr.ToolCallID == globTCID {
596 foundGlobResult = true
597 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
598 require.False(t, tr.IsError, "Expected glob result to not be an error")
599 }
600 if tr.ToolCallID == lsTCID {
601 foundLSResult = true
602 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
603 require.False(t, tr.IsError, "Expected ls result to not be an error")
604 }
605 }
606 }
607
608 require.True(t, foundGlobResult, "Expected to find glob tool result")
609 require.True(t, foundLSResult, "Expected to find ls tool result")
610 })
611 })
612 }
613}