1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "runtime"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "charm.land/x/vcr"
12 "github.com/charmbracelet/crush/internal/agent/tools"
13 "github.com/charmbracelet/crush/internal/message"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16
17 _ "github.com/joho/godotenv/autoload"
18)
19
20var modelPairs = []modelPair{
21 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
22 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
23 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
24 {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
25}
26
27func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
28 large, err := pair.largeModel(t, r)
29 require.NoError(t, err)
30 small, err := pair.smallModel(t, r)
31 require.NoError(t, err)
32 return large, small
33}
34
35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
36 r := vcr.NewRecorder(t)
37 large, small := getModels(t, r, pair)
38 env := testEnv(t)
39
40 createSimpleGoProject(t, env.workingDir)
41 agent, err := coderAgent(r, env, large, small)
42 require.NoError(t, err)
43 return agent, env
44}
45
46func TestCoderAgent(t *testing.T) {
47 if runtime.GOOS == "windows" {
48 t.Skip("skipping on windows for now")
49 }
50
51 for _, pair := range modelPairs {
52 t.Run(pair.name, func(t *testing.T) {
53 t.Run("simple test", func(t *testing.T) {
54 agent, env := setupAgent(t, pair)
55
56 session, err := env.sessions.Create(t.Context(), "New Session")
57 require.NoError(t, err)
58
59 res, err := agent.Run(t.Context(), SessionAgentCall{
60 Prompt: "Hello",
61 SessionID: session.ID,
62 MaxOutputTokens: 10000,
63 })
64 require.NoError(t, err)
65 assert.NotNil(t, res)
66
67 msgs, err := env.messages.List(t.Context(), session.ID)
68 require.NoError(t, err)
69 // Should have the agent and user message
70 assert.Equal(t, len(msgs), 2)
71 })
72 t.Run("read a file", func(t *testing.T) {
73 agent, env := setupAgent(t, pair)
74
75 session, err := env.sessions.Create(t.Context(), "New Session")
76 require.NoError(t, err)
77 res, err := agent.Run(t.Context(), SessionAgentCall{
78 Prompt: "Read the go mod",
79 SessionID: session.ID,
80 MaxOutputTokens: 10000,
81 })
82
83 require.NoError(t, err)
84 assert.NotNil(t, res)
85
86 msgs, err := env.messages.List(t.Context(), session.ID)
87 require.NoError(t, err)
88 foundFile := false
89 var tcID string
90 out:
91 for _, msg := range msgs {
92 if msg.Role == message.Assistant {
93 for _, tc := range msg.ToolCalls() {
94 if tc.Name == tools.ViewToolName {
95 tcID = tc.ID
96 }
97 }
98 }
99 if msg.Role == message.Tool {
100 for _, tr := range msg.ToolResults() {
101 if tr.ToolCallID == tcID {
102 if strings.Contains(tr.Content, "module example.com/testproject") {
103 foundFile = true
104 break out
105 }
106 }
107 }
108 }
109 }
110 require.True(t, foundFile)
111 })
112 t.Run("update a file", func(t *testing.T) {
113 agent, env := setupAgent(t, pair)
114
115 session, err := env.sessions.Create(t.Context(), "New Session")
116 require.NoError(t, err)
117
118 res, err := agent.Run(t.Context(), SessionAgentCall{
119 Prompt: "update the main.go file by changing the print to say hello from crush",
120 SessionID: session.ID,
121 MaxOutputTokens: 10000,
122 })
123 require.NoError(t, err)
124 assert.NotNil(t, res)
125
126 msgs, err := env.messages.List(t.Context(), session.ID)
127 require.NoError(t, err)
128
129 foundRead := false
130 foundWrite := false
131 var readTCID, writeTCID string
132
133 for _, msg := range msgs {
134 if msg.Role == message.Assistant {
135 for _, tc := range msg.ToolCalls() {
136 if tc.Name == tools.ViewToolName {
137 readTCID = tc.ID
138 }
139 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
140 writeTCID = tc.ID
141 }
142 }
143 }
144 if msg.Role == message.Tool {
145 for _, tr := range msg.ToolResults() {
146 if tr.ToolCallID == readTCID {
147 foundRead = true
148 }
149 if tr.ToolCallID == writeTCID {
150 foundWrite = true
151 }
152 }
153 }
154 }
155
156 require.True(t, foundRead, "Expected to find a read operation")
157 require.True(t, foundWrite, "Expected to find a write operation")
158
159 mainGoPath := filepath.Join(env.workingDir, "main.go")
160 content, err := os.ReadFile(mainGoPath)
161 require.NoError(t, err)
162 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
163 })
164 t.Run("bash tool", func(t *testing.T) {
165 agent, env := setupAgent(t, pair)
166
167 session, err := env.sessions.Create(t.Context(), "New Session")
168 require.NoError(t, err)
169
170 res, err := agent.Run(t.Context(), SessionAgentCall{
171 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
172 SessionID: session.ID,
173 MaxOutputTokens: 10000,
174 })
175 require.NoError(t, err)
176 assert.NotNil(t, res)
177
178 msgs, err := env.messages.List(t.Context(), session.ID)
179 require.NoError(t, err)
180
181 foundBash := false
182 var bashTCID string
183
184 for _, msg := range msgs {
185 if msg.Role == message.Assistant {
186 for _, tc := range msg.ToolCalls() {
187 if tc.Name == tools.BashToolName {
188 bashTCID = tc.ID
189 }
190 }
191 }
192 if msg.Role == message.Tool {
193 for _, tr := range msg.ToolResults() {
194 if tr.ToolCallID == bashTCID {
195 foundBash = true
196 }
197 }
198 }
199 }
200
201 require.True(t, foundBash, "Expected to find a bash operation")
202
203 testFilePath := filepath.Join(env.workingDir, "test.txt")
204 content, err := os.ReadFile(testFilePath)
205 require.NoError(t, err)
206 require.Contains(t, string(content), "hello bash")
207 })
208 t.Run("download tool", func(t *testing.T) {
209 agent, env := setupAgent(t, pair)
210
211 session, err := env.sessions.Create(t.Context(), "New Session")
212 require.NoError(t, err)
213
214 res, err := agent.Run(t.Context(), SessionAgentCall{
215 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
216 SessionID: session.ID,
217 MaxOutputTokens: 10000,
218 })
219 require.NoError(t, err)
220 assert.NotNil(t, res)
221
222 msgs, err := env.messages.List(t.Context(), session.ID)
223 require.NoError(t, err)
224
225 foundDownload := false
226 var downloadTCID string
227
228 for _, msg := range msgs {
229 if msg.Role == message.Assistant {
230 for _, tc := range msg.ToolCalls() {
231 if tc.Name == tools.DownloadToolName {
232 downloadTCID = tc.ID
233 }
234 }
235 }
236 if msg.Role == message.Tool {
237 for _, tr := range msg.ToolResults() {
238 if tr.ToolCallID == downloadTCID {
239 foundDownload = true
240 }
241 }
242 }
243 }
244
245 require.True(t, foundDownload, "Expected to find a download operation")
246
247 examplePath := filepath.Join(env.workingDir, "example.txt")
248 _, err = os.Stat(examplePath)
249 require.NoError(t, err, "Expected example.txt file to exist")
250 })
251 t.Run("fetch tool", func(t *testing.T) {
252 agent, env := setupAgent(t, pair)
253
254 session, err := env.sessions.Create(t.Context(), "New Session")
255 require.NoError(t, err)
256
257 res, err := agent.Run(t.Context(), SessionAgentCall{
258 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
259 SessionID: session.ID,
260 MaxOutputTokens: 10000,
261 })
262 require.NoError(t, err)
263 assert.NotNil(t, res)
264
265 msgs, err := env.messages.List(t.Context(), session.ID)
266 require.NoError(t, err)
267
268 foundFetch := false
269 var fetchTCID string
270
271 for _, msg := range msgs {
272 if msg.Role == message.Assistant {
273 for _, tc := range msg.ToolCalls() {
274 if tc.Name == tools.FetchToolName {
275 fetchTCID = tc.ID
276 }
277 }
278 }
279 if msg.Role == message.Tool {
280 for _, tr := range msg.ToolResults() {
281 if tr.ToolCallID == fetchTCID {
282 foundFetch = true
283 }
284 }
285 }
286 }
287
288 require.True(t, foundFetch, "Expected to find a fetch operation")
289 })
290 t.Run("glob tool", func(t *testing.T) {
291 agent, env := setupAgent(t, pair)
292
293 session, err := env.sessions.Create(t.Context(), "New Session")
294 require.NoError(t, err)
295
296 res, err := agent.Run(t.Context(), SessionAgentCall{
297 Prompt: "use glob to find all .go files in the current directory",
298 SessionID: session.ID,
299 MaxOutputTokens: 10000,
300 })
301 require.NoError(t, err)
302 assert.NotNil(t, res)
303
304 msgs, err := env.messages.List(t.Context(), session.ID)
305 require.NoError(t, err)
306
307 foundGlob := false
308 var globTCID string
309
310 for _, msg := range msgs {
311 if msg.Role == message.Assistant {
312 for _, tc := range msg.ToolCalls() {
313 if tc.Name == tools.GlobToolName {
314 globTCID = tc.ID
315 }
316 }
317 }
318 if msg.Role == message.Tool {
319 for _, tr := range msg.ToolResults() {
320 if tr.ToolCallID == globTCID {
321 foundGlob = true
322 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
323 }
324 }
325 }
326 }
327
328 require.True(t, foundGlob, "Expected to find a glob operation")
329 })
330 t.Run("grep tool", func(t *testing.T) {
331 agent, env := setupAgent(t, pair)
332
333 session, err := env.sessions.Create(t.Context(), "New Session")
334 require.NoError(t, err)
335
336 res, err := agent.Run(t.Context(), SessionAgentCall{
337 Prompt: "use grep to search for the word 'package' in go files",
338 SessionID: session.ID,
339 MaxOutputTokens: 10000,
340 })
341 require.NoError(t, err)
342 assert.NotNil(t, res)
343
344 msgs, err := env.messages.List(t.Context(), session.ID)
345 require.NoError(t, err)
346
347 foundGrep := false
348 var grepTCID string
349
350 for _, msg := range msgs {
351 if msg.Role == message.Assistant {
352 for _, tc := range msg.ToolCalls() {
353 if tc.Name == tools.GrepToolName {
354 grepTCID = tc.ID
355 }
356 }
357 }
358 if msg.Role == message.Tool {
359 for _, tr := range msg.ToolResults() {
360 if tr.ToolCallID == grepTCID {
361 foundGrep = true
362 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
363 }
364 }
365 }
366 }
367
368 require.True(t, foundGrep, "Expected to find a grep operation")
369 })
370 t.Run("ls tool", func(t *testing.T) {
371 agent, env := setupAgent(t, pair)
372
373 session, err := env.sessions.Create(t.Context(), "New Session")
374 require.NoError(t, err)
375
376 res, err := agent.Run(t.Context(), SessionAgentCall{
377 Prompt: "use ls to list the files in the current directory",
378 SessionID: session.ID,
379 MaxOutputTokens: 10000,
380 })
381 require.NoError(t, err)
382 assert.NotNil(t, res)
383
384 msgs, err := env.messages.List(t.Context(), session.ID)
385 require.NoError(t, err)
386
387 foundLS := false
388 var lsTCID string
389
390 for _, msg := range msgs {
391 if msg.Role == message.Assistant {
392 for _, tc := range msg.ToolCalls() {
393 if tc.Name == tools.LSToolName {
394 lsTCID = tc.ID
395 }
396 }
397 }
398 if msg.Role == message.Tool {
399 for _, tr := range msg.ToolResults() {
400 if tr.ToolCallID == lsTCID {
401 foundLS = true
402 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
403 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
404 }
405 }
406 }
407 }
408
409 require.True(t, foundLS, "Expected to find an ls operation")
410 })
411 t.Run("multiedit tool", func(t *testing.T) {
412 agent, env := setupAgent(t, pair)
413
414 session, err := env.sessions.Create(t.Context(), "New Session")
415 require.NoError(t, err)
416
417 res, err := agent.Run(t.Context(), SessionAgentCall{
418 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
419 SessionID: session.ID,
420 MaxOutputTokens: 10000,
421 })
422 require.NoError(t, err)
423 assert.NotNil(t, res)
424
425 msgs, err := env.messages.List(t.Context(), session.ID)
426 require.NoError(t, err)
427
428 foundMultiEdit := false
429 var multiEditTCID string
430
431 for _, msg := range msgs {
432 if msg.Role == message.Assistant {
433 for _, tc := range msg.ToolCalls() {
434 if tc.Name == tools.MultiEditToolName {
435 multiEditTCID = tc.ID
436 }
437 }
438 }
439 if msg.Role == message.Tool {
440 for _, tr := range msg.ToolResults() {
441 if tr.ToolCallID == multiEditTCID {
442 foundMultiEdit = true
443 }
444 }
445 }
446 }
447
448 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
449
450 mainGoPath := filepath.Join(env.workingDir, "main.go")
451 content, err := os.ReadFile(mainGoPath)
452 require.NoError(t, err)
453 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
454 })
455 t.Run("sourcegraph tool", func(t *testing.T) {
456 if runtime.GOOS == "darwin" {
457 t.Skip("skipping flacky test on macos for now")
458 }
459
460 agent, env := setupAgent(t, pair)
461
462 session, err := env.sessions.Create(t.Context(), "New Session")
463 require.NoError(t, err)
464
465 res, err := agent.Run(t.Context(), SessionAgentCall{
466 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
467 SessionID: session.ID,
468 MaxOutputTokens: 10000,
469 })
470 require.NoError(t, err)
471 assert.NotNil(t, res)
472
473 msgs, err := env.messages.List(t.Context(), session.ID)
474 require.NoError(t, err)
475
476 foundSourcegraph := false
477 var sourcegraphTCID string
478
479 for _, msg := range msgs {
480 if msg.Role == message.Assistant {
481 for _, tc := range msg.ToolCalls() {
482 if tc.Name == tools.SourcegraphToolName {
483 sourcegraphTCID = tc.ID
484 }
485 }
486 }
487 if msg.Role == message.Tool {
488 for _, tr := range msg.ToolResults() {
489 if tr.ToolCallID == sourcegraphTCID {
490 foundSourcegraph = true
491 }
492 }
493 }
494 }
495
496 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
497 })
498 t.Run("write tool", func(t *testing.T) {
499 agent, env := setupAgent(t, pair)
500
501 session, err := env.sessions.Create(t.Context(), "New Session")
502 require.NoError(t, err)
503
504 res, err := agent.Run(t.Context(), SessionAgentCall{
505 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
506 SessionID: session.ID,
507 MaxOutputTokens: 10000,
508 })
509 require.NoError(t, err)
510 assert.NotNil(t, res)
511
512 msgs, err := env.messages.List(t.Context(), session.ID)
513 require.NoError(t, err)
514
515 foundWrite := false
516 var writeTCID string
517
518 for _, msg := range msgs {
519 if msg.Role == message.Assistant {
520 for _, tc := range msg.ToolCalls() {
521 if tc.Name == tools.WriteToolName {
522 writeTCID = tc.ID
523 }
524 }
525 }
526 if msg.Role == message.Tool {
527 for _, tr := range msg.ToolResults() {
528 if tr.ToolCallID == writeTCID {
529 foundWrite = true
530 }
531 }
532 }
533 }
534
535 require.True(t, foundWrite, "Expected to find a write operation")
536
537 configPath := filepath.Join(env.workingDir, "config.json")
538 content, err := os.ReadFile(configPath)
539 require.NoError(t, err)
540 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
541 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
542 })
543 t.Run("parallel tool calls", func(t *testing.T) {
544 agent, env := setupAgent(t, pair)
545
546 session, err := env.sessions.Create(t.Context(), "New Session")
547 require.NoError(t, err)
548
549 res, err := agent.Run(t.Context(), SessionAgentCall{
550 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
551 SessionID: session.ID,
552 MaxOutputTokens: 10000,
553 })
554 require.NoError(t, err)
555 assert.NotNil(t, res)
556
557 msgs, err := env.messages.List(t.Context(), session.ID)
558 require.NoError(t, err)
559
560 var assistantMsg *message.Message
561 var toolMsgs []message.Message
562
563 for _, msg := range msgs {
564 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
565 assistantMsg = &msg
566 }
567 if msg.Role == message.Tool {
568 toolMsgs = append(toolMsgs, msg)
569 }
570 }
571
572 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
573 require.NotNil(t, toolMsgs, "Expected to find a tool message")
574
575 toolCalls := assistantMsg.ToolCalls()
576 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
577
578 foundGlob := false
579 foundLS := false
580 var globTCID, lsTCID string
581
582 for _, tc := range toolCalls {
583 if tc.Name == tools.GlobToolName {
584 foundGlob = true
585 globTCID = tc.ID
586 }
587 if tc.Name == tools.LSToolName {
588 foundLS = true
589 lsTCID = tc.ID
590 }
591 }
592
593 require.True(t, foundGlob, "Expected to find a glob tool call")
594 require.True(t, foundLS, "Expected to find an ls tool call")
595
596 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
597
598 foundGlobResult := false
599 foundLSResult := false
600
601 for _, msg := range toolMsgs {
602 for _, tr := range msg.ToolResults() {
603 if tr.ToolCallID == globTCID {
604 foundGlobResult = true
605 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
606 require.False(t, tr.IsError, "Expected glob result to not be an error")
607 }
608 if tr.ToolCallID == lsTCID {
609 foundLSResult = true
610 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
611 require.False(t, tr.IsError, "Expected ls result to not be an error")
612 }
613 }
614 }
615
616 require.True(t, foundGlobResult, "Expected to find glob tool result")
617 require.True(t, foundLSResult, "Expected to find ls tool result")
618 })
619 })
620 }
621}