1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "runtime"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "github.com/charmbracelet/crush/internal/agent/tools"
12 "github.com/charmbracelet/crush/internal/message"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16
17 _ "github.com/joho/godotenv/autoload"
18)
19
20var modelPairs = []modelPair{
21 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
22 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
23 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
24 {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
25}
26
27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
28 large, err := pair.largeModel(t, r)
29 require.NoError(t, err)
30 small, err := pair.smallModel(t, r)
31 require.NoError(t, err)
32 return large, small
33}
34
35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
36 r := newRecorder(t)
37 large, small := getModels(t, r, pair)
38 env := testEnv(t)
39
40 createSimpleGoProject(t, env.workingDir)
41 agent, err := coderAgent(r, env, large, small)
42 require.NoError(t, err)
43 return agent, env
44}
45
46func TestCoderAgent(t *testing.T) {
47 if runtime.GOOS == "windows" {
48 t.Skip("skipping on windows for now")
49 }
50
51 for _, pair := range modelPairs {
52 t.Run(pair.name, func(t *testing.T) {
53 t.Run("simple test", func(t *testing.T) {
54 agent, env := setupAgent(t, pair)
55
56 session, err := env.sessions.Create(t.Context(), "New Session")
57 require.NoError(t, err)
58
59 res, err := agent.Run(t.Context(), SessionAgentCall{
60 Prompt: "Hello",
61 SessionID: session.ID,
62 MaxOutputTokens: 10000,
63 })
64 require.NoError(t, err)
65 assert.NotNil(t, res)
66
67 msgs, err := env.messages.List(t.Context(), session.ID)
68 require.NoError(t, err)
69 // Should have the agent and user message
70 assert.Equal(t, len(msgs), 2)
71 })
72 t.Run("read a file", func(t *testing.T) {
73 agent, env := setupAgent(t, pair)
74
75 session, err := env.sessions.Create(t.Context(), "New Session")
76 require.NoError(t, err)
77 res, err := agent.Run(t.Context(), SessionAgentCall{
78 Prompt: "Read the go mod",
79 SessionID: session.ID,
80 MaxOutputTokens: 10000,
81 })
82
83 require.NoError(t, err)
84 assert.NotNil(t, res)
85
86 msgs, err := env.messages.List(t.Context(), session.ID)
87 require.NoError(t, err)
88 foundFile := false
89 var tcID string
90 out:
91 for _, msg := range msgs {
92 if msg.Role == message.Assistant {
93 for _, tc := range msg.ToolCalls() {
94 if tc.Name == tools.ViewToolName {
95 tcID = tc.ID
96 }
97 }
98 }
99 if msg.Role == message.Tool {
100 for _, tr := range msg.ToolResults() {
101 if tr.ToolCallID == tcID {
102 if strings.Contains(tr.Content, "module example.com/testproject") {
103 foundFile = true
104 break out
105 }
106 }
107 }
108 }
109 }
110 require.True(t, foundFile)
111 })
112 t.Run("update a file", func(t *testing.T) {
113 agent, env := setupAgent(t, pair)
114
115 session, err := env.sessions.Create(t.Context(), "New Session")
116 require.NoError(t, err)
117
118 res, err := agent.Run(t.Context(), SessionAgentCall{
119 Prompt: "update the main.go file by changing the print to say hello from crush",
120 SessionID: session.ID,
121 MaxOutputTokens: 10000,
122 })
123 require.NoError(t, err)
124 assert.NotNil(t, res)
125
126 msgs, err := env.messages.List(t.Context(), session.ID)
127 require.NoError(t, err)
128
129 foundRead := false
130 foundWrite := false
131 var readTCID, writeTCID string
132
133 for _, msg := range msgs {
134 if msg.Role == message.Assistant {
135 for _, tc := range msg.ToolCalls() {
136 if tc.Name == tools.ViewToolName {
137 readTCID = tc.ID
138 }
139 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
140 writeTCID = tc.ID
141 }
142 }
143 }
144 if msg.Role == message.Tool {
145 for _, tr := range msg.ToolResults() {
146 if tr.ToolCallID == readTCID {
147 foundRead = true
148 }
149 if tr.ToolCallID == writeTCID {
150 foundWrite = true
151 }
152 }
153 }
154 }
155
156 require.True(t, foundRead, "Expected to find a read operation")
157 require.True(t, foundWrite, "Expected to find a write operation")
158
159 mainGoPath := filepath.Join(env.workingDir, "main.go")
160 content, err := os.ReadFile(mainGoPath)
161 require.NoError(t, err)
162 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
163 })
164 t.Run("bash tool", func(t *testing.T) {
165 agent, env := setupAgent(t, pair)
166
167 session, err := env.sessions.Create(t.Context(), "New Session")
168 require.NoError(t, err)
169
170 res, err := agent.Run(t.Context(), SessionAgentCall{
171 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
172 SessionID: session.ID,
173 MaxOutputTokens: 10000,
174 })
175 require.NoError(t, err)
176 assert.NotNil(t, res)
177
178 msgs, err := env.messages.List(t.Context(), session.ID)
179 require.NoError(t, err)
180
181 foundBash := false
182 var bashTCID string
183
184 for _, msg := range msgs {
185 if msg.Role == message.Assistant {
186 for _, tc := range msg.ToolCalls() {
187 if tc.Name == tools.BashToolName {
188 bashTCID = tc.ID
189 }
190 }
191 }
192 if msg.Role == message.Tool {
193 for _, tr := range msg.ToolResults() {
194 if tr.ToolCallID == bashTCID {
195 foundBash = true
196 }
197 }
198 }
199 }
200
201 require.True(t, foundBash, "Expected to find a bash operation")
202
203 testFilePath := filepath.Join(env.workingDir, "test.txt")
204 content, err := os.ReadFile(testFilePath)
205 require.NoError(t, err)
206 require.Contains(t, string(content), "hello bash")
207 })
208 t.Run("download tool", func(t *testing.T) {
209 agent, env := setupAgent(t, pair)
210
211 session, err := env.sessions.Create(t.Context(), "New Session")
212 require.NoError(t, err)
213
214 res, err := agent.Run(t.Context(), SessionAgentCall{
215 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
216 SessionID: session.ID,
217 MaxOutputTokens: 10000,
218 })
219 require.NoError(t, err)
220 assert.NotNil(t, res)
221
222 msgs, err := env.messages.List(t.Context(), session.ID)
223 require.NoError(t, err)
224
225 foundDownload := false
226 var downloadTCID string
227
228 for _, msg := range msgs {
229 if msg.Role == message.Assistant {
230 for _, tc := range msg.ToolCalls() {
231 if tc.Name == tools.DownloadToolName {
232 downloadTCID = tc.ID
233 }
234 }
235 }
236 if msg.Role == message.Tool {
237 for _, tr := range msg.ToolResults() {
238 if tr.ToolCallID == downloadTCID {
239 foundDownload = true
240 }
241 }
242 }
243 }
244
245 require.True(t, foundDownload, "Expected to find a download operation")
246
247 examplePath := filepath.Join(env.workingDir, "example.txt")
248 _, err = os.Stat(examplePath)
249 require.NoError(t, err, "Expected example.txt file to exist")
250 })
251 t.Run("fetch tool", func(t *testing.T) {
252 agent, env := setupAgent(t, pair)
253
254 session, err := env.sessions.Create(t.Context(), "New Session")
255 require.NoError(t, err)
256
257 res, err := agent.Run(t.Context(), SessionAgentCall{
258 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
259 SessionID: session.ID,
260 MaxOutputTokens: 10000,
261 })
262 require.NoError(t, err)
263 assert.NotNil(t, res)
264
265 msgs, err := env.messages.List(t.Context(), session.ID)
266 require.NoError(t, err)
267
268 foundFetch := false
269 var fetchTCID string
270
271 for _, msg := range msgs {
272 if msg.Role == message.Assistant {
273 for _, tc := range msg.ToolCalls() {
274 if tc.Name == tools.FetchToolName {
275 fetchTCID = tc.ID
276 }
277 }
278 }
279 if msg.Role == message.Tool {
280 for _, tr := range msg.ToolResults() {
281 if tr.ToolCallID == fetchTCID {
282 foundFetch = true
283 }
284 }
285 }
286 }
287
288 require.True(t, foundFetch, "Expected to find a fetch operation")
289 })
290 t.Run("glob tool", func(t *testing.T) {
291 agent, env := setupAgent(t, pair)
292
293 session, err := env.sessions.Create(t.Context(), "New Session")
294 require.NoError(t, err)
295
296 res, err := agent.Run(t.Context(), SessionAgentCall{
297 Prompt: "use glob to find all .go files in the current directory",
298 SessionID: session.ID,
299 MaxOutputTokens: 10000,
300 })
301 require.NoError(t, err)
302 assert.NotNil(t, res)
303
304 msgs, err := env.messages.List(t.Context(), session.ID)
305 require.NoError(t, err)
306
307 foundGlob := false
308 var globTCID string
309
310 for _, msg := range msgs {
311 if msg.Role == message.Assistant {
312 for _, tc := range msg.ToolCalls() {
313 if tc.Name == tools.GlobToolName {
314 globTCID = tc.ID
315 }
316 }
317 }
318 if msg.Role == message.Tool {
319 for _, tr := range msg.ToolResults() {
320 if tr.ToolCallID == globTCID {
321 foundGlob = true
322 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
323 }
324 }
325 }
326 }
327
328 require.True(t, foundGlob, "Expected to find a glob operation")
329 })
330 t.Run("grep tool", func(t *testing.T) {
331 agent, env := setupAgent(t, pair)
332
333 session, err := env.sessions.Create(t.Context(), "New Session")
334 require.NoError(t, err)
335
336 res, err := agent.Run(t.Context(), SessionAgentCall{
337 Prompt: "use grep to search for the word 'package' in go files",
338 SessionID: session.ID,
339 MaxOutputTokens: 10000,
340 })
341 require.NoError(t, err)
342 assert.NotNil(t, res)
343
344 msgs, err := env.messages.List(t.Context(), session.ID)
345 require.NoError(t, err)
346
347 foundGrep := false
348 var grepTCID string
349
350 for _, msg := range msgs {
351 if msg.Role == message.Assistant {
352 for _, tc := range msg.ToolCalls() {
353 if tc.Name == tools.GrepToolName {
354 grepTCID = tc.ID
355 }
356 }
357 }
358 if msg.Role == message.Tool {
359 for _, tr := range msg.ToolResults() {
360 if tr.ToolCallID == grepTCID {
361 foundGrep = true
362 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
363 }
364 }
365 }
366 }
367
368 require.True(t, foundGrep, "Expected to find a grep operation")
369 })
370 t.Run("ls tool", func(t *testing.T) {
371 agent, env := setupAgent(t, pair)
372
373 session, err := env.sessions.Create(t.Context(), "New Session")
374 require.NoError(t, err)
375
376 res, err := agent.Run(t.Context(), SessionAgentCall{
377 Prompt: "use ls to list the files in the current directory",
378 SessionID: session.ID,
379 MaxOutputTokens: 10000,
380 })
381 require.NoError(t, err)
382 assert.NotNil(t, res)
383
384 msgs, err := env.messages.List(t.Context(), session.ID)
385 require.NoError(t, err)
386
387 foundLS := false
388 var lsTCID string
389
390 for _, msg := range msgs {
391 if msg.Role == message.Assistant {
392 for _, tc := range msg.ToolCalls() {
393 if tc.Name == tools.LSToolName {
394 lsTCID = tc.ID
395 }
396 }
397 }
398 if msg.Role == message.Tool {
399 for _, tr := range msg.ToolResults() {
400 if tr.ToolCallID == lsTCID {
401 foundLS = true
402 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
403 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
404 }
405 }
406 }
407 }
408
409 require.True(t, foundLS, "Expected to find an ls operation")
410 })
411 t.Run("multiedit tool", func(t *testing.T) {
412 agent, env := setupAgent(t, pair)
413
414 session, err := env.sessions.Create(t.Context(), "New Session")
415 require.NoError(t, err)
416
417 res, err := agent.Run(t.Context(), SessionAgentCall{
418 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
419 SessionID: session.ID,
420 MaxOutputTokens: 10000,
421 })
422 require.NoError(t, err)
423 assert.NotNil(t, res)
424
425 msgs, err := env.messages.List(t.Context(), session.ID)
426 require.NoError(t, err)
427
428 foundMultiEdit := false
429 var multiEditTCID string
430
431 for _, msg := range msgs {
432 if msg.Role == message.Assistant {
433 for _, tc := range msg.ToolCalls() {
434 if tc.Name == tools.MultiEditToolName {
435 multiEditTCID = tc.ID
436 }
437 }
438 }
439 if msg.Role == message.Tool {
440 for _, tr := range msg.ToolResults() {
441 if tr.ToolCallID == multiEditTCID {
442 foundMultiEdit = true
443 }
444 }
445 }
446 }
447
448 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
449
450 mainGoPath := filepath.Join(env.workingDir, "main.go")
451 content, err := os.ReadFile(mainGoPath)
452 require.NoError(t, err)
453 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
454 })
455 t.Run("sourcegraph tool", func(t *testing.T) {
456 agent, env := setupAgent(t, pair)
457
458 session, err := env.sessions.Create(t.Context(), "New Session")
459 require.NoError(t, err)
460
461 res, err := agent.Run(t.Context(), SessionAgentCall{
462 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
463 SessionID: session.ID,
464 MaxOutputTokens: 10000,
465 })
466 require.NoError(t, err)
467 assert.NotNil(t, res)
468
469 msgs, err := env.messages.List(t.Context(), session.ID)
470 require.NoError(t, err)
471
472 foundSourcegraph := false
473 var sourcegraphTCID string
474
475 for _, msg := range msgs {
476 if msg.Role == message.Assistant {
477 for _, tc := range msg.ToolCalls() {
478 if tc.Name == tools.SourcegraphToolName {
479 sourcegraphTCID = tc.ID
480 }
481 }
482 }
483 if msg.Role == message.Tool {
484 for _, tr := range msg.ToolResults() {
485 if tr.ToolCallID == sourcegraphTCID {
486 foundSourcegraph = true
487 }
488 }
489 }
490 }
491
492 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
493 })
494 t.Run("write tool", func(t *testing.T) {
495 agent, env := setupAgent(t, pair)
496
497 session, err := env.sessions.Create(t.Context(), "New Session")
498 require.NoError(t, err)
499
500 res, err := agent.Run(t.Context(), SessionAgentCall{
501 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
502 SessionID: session.ID,
503 MaxOutputTokens: 10000,
504 })
505 require.NoError(t, err)
506 assert.NotNil(t, res)
507
508 msgs, err := env.messages.List(t.Context(), session.ID)
509 require.NoError(t, err)
510
511 foundWrite := false
512 var writeTCID string
513
514 for _, msg := range msgs {
515 if msg.Role == message.Assistant {
516 for _, tc := range msg.ToolCalls() {
517 if tc.Name == tools.WriteToolName {
518 writeTCID = tc.ID
519 }
520 }
521 }
522 if msg.Role == message.Tool {
523 for _, tr := range msg.ToolResults() {
524 if tr.ToolCallID == writeTCID {
525 foundWrite = true
526 }
527 }
528 }
529 }
530
531 require.True(t, foundWrite, "Expected to find a write operation")
532
533 configPath := filepath.Join(env.workingDir, "config.json")
534 content, err := os.ReadFile(configPath)
535 require.NoError(t, err)
536 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
537 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
538 })
539 t.Run("parallel tool calls", func(t *testing.T) {
540 agent, env := setupAgent(t, pair)
541
542 session, err := env.sessions.Create(t.Context(), "New Session")
543 require.NoError(t, err)
544
545 res, err := agent.Run(t.Context(), SessionAgentCall{
546 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
547 SessionID: session.ID,
548 MaxOutputTokens: 10000,
549 })
550 require.NoError(t, err)
551 assert.NotNil(t, res)
552
553 msgs, err := env.messages.List(t.Context(), session.ID)
554 require.NoError(t, err)
555
556 var assistantMsg *message.Message
557 var toolMsgs []message.Message
558
559 for _, msg := range msgs {
560 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
561 assistantMsg = &msg
562 }
563 if msg.Role == message.Tool {
564 toolMsgs = append(toolMsgs, msg)
565 }
566 }
567
568 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
569 require.NotNil(t, toolMsgs, "Expected to find a tool message")
570
571 toolCalls := assistantMsg.ToolCalls()
572 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
573
574 foundGlob := false
575 foundLS := false
576 var globTCID, lsTCID string
577
578 for _, tc := range toolCalls {
579 if tc.Name == tools.GlobToolName {
580 foundGlob = true
581 globTCID = tc.ID
582 }
583 if tc.Name == tools.LSToolName {
584 foundLS = true
585 lsTCID = tc.ID
586 }
587 }
588
589 require.True(t, foundGlob, "Expected to find a glob tool call")
590 require.True(t, foundLS, "Expected to find an ls tool call")
591
592 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
593
594 foundGlobResult := false
595 foundLSResult := false
596
597 for _, msg := range toolMsgs {
598 for _, tr := range msg.ToolResults() {
599 if tr.ToolCallID == globTCID {
600 foundGlobResult = true
601 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
602 require.False(t, tr.IsError, "Expected glob result to not be an error")
603 }
604 if tr.ToolCallID == lsTCID {
605 foundLSResult = true
606 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
607 require.False(t, tr.IsError, "Expected ls result to not be an error")
608 }
609 }
610 }
611
612 require.True(t, foundGlobResult, "Expected to find glob tool result")
613 require.True(t, foundLSResult, "Expected to find ls tool result")
614 })
615 })
616 }
617}