1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "runtime"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "github.com/charmbracelet/crush/internal/agent/tools"
12 "github.com/charmbracelet/crush/internal/message"
13 "github.com/charmbracelet/crush/internal/shell"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
17
18 _ "github.com/joho/godotenv/autoload"
19)
20
21var modelPairs = []modelPair{
22 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
23 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
24 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
25 {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
26}
27
28func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
29 large, err := pair.largeModel(t, r)
30 require.NoError(t, err)
31 small, err := pair.smallModel(t, r)
32 require.NoError(t, err)
33 return large, small
34}
35
36func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
37 r := newRecorder(t)
38 large, small := getModels(t, r, pair)
39 env := testEnv(t)
40
41 createSimpleGoProject(t, env.workingDir)
42 agent, err := coderAgent(r, env, large, small)
43 shell.Reset(env.workingDir)
44 require.NoError(t, err)
45 return agent, env
46}
47
48func TestCoderAgent(t *testing.T) {
49 if runtime.GOOS == "windows" {
50 t.Skip("skipping on windows for now")
51 }
52
53 for _, pair := range modelPairs {
54 t.Run(pair.name, func(t *testing.T) {
55 t.Run("simple test", func(t *testing.T) {
56 agent, env := setupAgent(t, pair)
57
58 session, err := env.sessions.Create(t.Context(), "New Session")
59 require.NoError(t, err)
60
61 res, err := agent.Run(t.Context(), SessionAgentCall{
62 Prompt: "Hello",
63 SessionID: session.ID,
64 MaxOutputTokens: 10000,
65 })
66 require.NoError(t, err)
67 assert.NotNil(t, res)
68
69 msgs, err := env.messages.List(t.Context(), session.ID)
70 require.NoError(t, err)
71 // Should have the agent and user message
72 assert.Equal(t, len(msgs), 2)
73 })
74 t.Run("read a file", func(t *testing.T) {
75 agent, env := setupAgent(t, pair)
76
77 session, err := env.sessions.Create(t.Context(), "New Session")
78 require.NoError(t, err)
79 res, err := agent.Run(t.Context(), SessionAgentCall{
80 Prompt: "Read the go mod",
81 SessionID: session.ID,
82 MaxOutputTokens: 10000,
83 })
84
85 require.NoError(t, err)
86 assert.NotNil(t, res)
87
88 msgs, err := env.messages.List(t.Context(), session.ID)
89 require.NoError(t, err)
90 foundFile := false
91 var tcID string
92 out:
93 for _, msg := range msgs {
94 if msg.Role == message.Assistant {
95 for _, tc := range msg.ToolCalls() {
96 if tc.Name == tools.ViewToolName {
97 tcID = tc.ID
98 }
99 }
100 }
101 if msg.Role == message.Tool {
102 for _, tr := range msg.ToolResults() {
103 if tr.ToolCallID == tcID {
104 if strings.Contains(tr.Content, "module example.com/testproject") {
105 foundFile = true
106 break out
107 }
108 }
109 }
110 }
111 }
112 require.True(t, foundFile)
113 })
114 t.Run("update a file", func(t *testing.T) {
115 agent, env := setupAgent(t, pair)
116
117 session, err := env.sessions.Create(t.Context(), "New Session")
118 require.NoError(t, err)
119
120 res, err := agent.Run(t.Context(), SessionAgentCall{
121 Prompt: "update the main.go file by changing the print to say hello from crush",
122 SessionID: session.ID,
123 MaxOutputTokens: 10000,
124 })
125 require.NoError(t, err)
126 assert.NotNil(t, res)
127
128 msgs, err := env.messages.List(t.Context(), session.ID)
129 require.NoError(t, err)
130
131 foundRead := false
132 foundWrite := false
133 var readTCID, writeTCID string
134
135 for _, msg := range msgs {
136 if msg.Role == message.Assistant {
137 for _, tc := range msg.ToolCalls() {
138 if tc.Name == tools.ViewToolName {
139 readTCID = tc.ID
140 }
141 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142 writeTCID = tc.ID
143 }
144 }
145 }
146 if msg.Role == message.Tool {
147 for _, tr := range msg.ToolResults() {
148 if tr.ToolCallID == readTCID {
149 foundRead = true
150 }
151 if tr.ToolCallID == writeTCID {
152 foundWrite = true
153 }
154 }
155 }
156 }
157
158 require.True(t, foundRead, "Expected to find a read operation")
159 require.True(t, foundWrite, "Expected to find a write operation")
160
161 mainGoPath := filepath.Join(env.workingDir, "main.go")
162 content, err := os.ReadFile(mainGoPath)
163 require.NoError(t, err)
164 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165 })
166 t.Run("bash tool", func(t *testing.T) {
167 agent, env := setupAgent(t, pair)
168
169 session, err := env.sessions.Create(t.Context(), "New Session")
170 require.NoError(t, err)
171
172 res, err := agent.Run(t.Context(), SessionAgentCall{
173 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174 SessionID: session.ID,
175 MaxOutputTokens: 10000,
176 })
177 require.NoError(t, err)
178 assert.NotNil(t, res)
179
180 msgs, err := env.messages.List(t.Context(), session.ID)
181 require.NoError(t, err)
182
183 foundBash := false
184 var bashTCID string
185
186 for _, msg := range msgs {
187 if msg.Role == message.Assistant {
188 for _, tc := range msg.ToolCalls() {
189 if tc.Name == tools.BashToolName {
190 bashTCID = tc.ID
191 }
192 }
193 }
194 if msg.Role == message.Tool {
195 for _, tr := range msg.ToolResults() {
196 if tr.ToolCallID == bashTCID {
197 foundBash = true
198 }
199 }
200 }
201 }
202
203 require.True(t, foundBash, "Expected to find a bash operation")
204
205 testFilePath := filepath.Join(env.workingDir, "test.txt")
206 content, err := os.ReadFile(testFilePath)
207 require.NoError(t, err)
208 require.Contains(t, string(content), "hello bash")
209 })
210 t.Run("download tool", func(t *testing.T) {
211 agent, env := setupAgent(t, pair)
212
213 session, err := env.sessions.Create(t.Context(), "New Session")
214 require.NoError(t, err)
215
216 res, err := agent.Run(t.Context(), SessionAgentCall{
217 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218 SessionID: session.ID,
219 MaxOutputTokens: 10000,
220 })
221 require.NoError(t, err)
222 assert.NotNil(t, res)
223
224 msgs, err := env.messages.List(t.Context(), session.ID)
225 require.NoError(t, err)
226
227 foundDownload := false
228 var downloadTCID string
229
230 for _, msg := range msgs {
231 if msg.Role == message.Assistant {
232 for _, tc := range msg.ToolCalls() {
233 if tc.Name == tools.DownloadToolName {
234 downloadTCID = tc.ID
235 }
236 }
237 }
238 if msg.Role == message.Tool {
239 for _, tr := range msg.ToolResults() {
240 if tr.ToolCallID == downloadTCID {
241 foundDownload = true
242 }
243 }
244 }
245 }
246
247 require.True(t, foundDownload, "Expected to find a download operation")
248
249 examplePath := filepath.Join(env.workingDir, "example.txt")
250 _, err = os.Stat(examplePath)
251 require.NoError(t, err, "Expected example.txt file to exist")
252 })
253 t.Run("fetch tool", func(t *testing.T) {
254 agent, env := setupAgent(t, pair)
255
256 session, err := env.sessions.Create(t.Context(), "New Session")
257 require.NoError(t, err)
258
259 res, err := agent.Run(t.Context(), SessionAgentCall{
260 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
261 SessionID: session.ID,
262 MaxOutputTokens: 10000,
263 })
264 require.NoError(t, err)
265 assert.NotNil(t, res)
266
267 msgs, err := env.messages.List(t.Context(), session.ID)
268 require.NoError(t, err)
269
270 foundFetch := false
271 var fetchTCID string
272
273 for _, msg := range msgs {
274 if msg.Role == message.Assistant {
275 for _, tc := range msg.ToolCalls() {
276 if tc.Name == tools.FetchToolName {
277 fetchTCID = tc.ID
278 }
279 }
280 }
281 if msg.Role == message.Tool {
282 for _, tr := range msg.ToolResults() {
283 if tr.ToolCallID == fetchTCID {
284 foundFetch = true
285 }
286 }
287 }
288 }
289
290 require.True(t, foundFetch, "Expected to find a fetch operation")
291 })
292 t.Run("glob tool", func(t *testing.T) {
293 agent, env := setupAgent(t, pair)
294
295 session, err := env.sessions.Create(t.Context(), "New Session")
296 require.NoError(t, err)
297
298 res, err := agent.Run(t.Context(), SessionAgentCall{
299 Prompt: "use glob to find all .go files in the current directory",
300 SessionID: session.ID,
301 MaxOutputTokens: 10000,
302 })
303 require.NoError(t, err)
304 assert.NotNil(t, res)
305
306 msgs, err := env.messages.List(t.Context(), session.ID)
307 require.NoError(t, err)
308
309 foundGlob := false
310 var globTCID string
311
312 for _, msg := range msgs {
313 if msg.Role == message.Assistant {
314 for _, tc := range msg.ToolCalls() {
315 if tc.Name == tools.GlobToolName {
316 globTCID = tc.ID
317 }
318 }
319 }
320 if msg.Role == message.Tool {
321 for _, tr := range msg.ToolResults() {
322 if tr.ToolCallID == globTCID {
323 foundGlob = true
324 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
325 }
326 }
327 }
328 }
329
330 require.True(t, foundGlob, "Expected to find a glob operation")
331 })
332 t.Run("grep tool", func(t *testing.T) {
333 agent, env := setupAgent(t, pair)
334
335 session, err := env.sessions.Create(t.Context(), "New Session")
336 require.NoError(t, err)
337
338 res, err := agent.Run(t.Context(), SessionAgentCall{
339 Prompt: "use grep to search for the word 'package' in go files",
340 SessionID: session.ID,
341 MaxOutputTokens: 10000,
342 })
343 require.NoError(t, err)
344 assert.NotNil(t, res)
345
346 msgs, err := env.messages.List(t.Context(), session.ID)
347 require.NoError(t, err)
348
349 foundGrep := false
350 var grepTCID string
351
352 for _, msg := range msgs {
353 if msg.Role == message.Assistant {
354 for _, tc := range msg.ToolCalls() {
355 if tc.Name == tools.GrepToolName {
356 grepTCID = tc.ID
357 }
358 }
359 }
360 if msg.Role == message.Tool {
361 for _, tr := range msg.ToolResults() {
362 if tr.ToolCallID == grepTCID {
363 foundGrep = true
364 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
365 }
366 }
367 }
368 }
369
370 require.True(t, foundGrep, "Expected to find a grep operation")
371 })
372 t.Run("ls tool", func(t *testing.T) {
373 agent, env := setupAgent(t, pair)
374
375 session, err := env.sessions.Create(t.Context(), "New Session")
376 require.NoError(t, err)
377
378 res, err := agent.Run(t.Context(), SessionAgentCall{
379 Prompt: "use ls to list the files in the current directory",
380 SessionID: session.ID,
381 MaxOutputTokens: 10000,
382 })
383 require.NoError(t, err)
384 assert.NotNil(t, res)
385
386 msgs, err := env.messages.List(t.Context(), session.ID)
387 require.NoError(t, err)
388
389 foundLS := false
390 var lsTCID string
391
392 for _, msg := range msgs {
393 if msg.Role == message.Assistant {
394 for _, tc := range msg.ToolCalls() {
395 if tc.Name == tools.LSToolName {
396 lsTCID = tc.ID
397 }
398 }
399 }
400 if msg.Role == message.Tool {
401 for _, tr := range msg.ToolResults() {
402 if tr.ToolCallID == lsTCID {
403 foundLS = true
404 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
405 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
406 }
407 }
408 }
409 }
410
411 require.True(t, foundLS, "Expected to find an ls operation")
412 })
413 t.Run("multiedit tool", func(t *testing.T) {
414 agent, env := setupAgent(t, pair)
415
416 session, err := env.sessions.Create(t.Context(), "New Session")
417 require.NoError(t, err)
418
419 res, err := agent.Run(t.Context(), SessionAgentCall{
420 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
421 SessionID: session.ID,
422 MaxOutputTokens: 10000,
423 })
424 require.NoError(t, err)
425 assert.NotNil(t, res)
426
427 msgs, err := env.messages.List(t.Context(), session.ID)
428 require.NoError(t, err)
429
430 foundMultiEdit := false
431 var multiEditTCID string
432
433 for _, msg := range msgs {
434 if msg.Role == message.Assistant {
435 for _, tc := range msg.ToolCalls() {
436 if tc.Name == tools.MultiEditToolName {
437 multiEditTCID = tc.ID
438 }
439 }
440 }
441 if msg.Role == message.Tool {
442 for _, tr := range msg.ToolResults() {
443 if tr.ToolCallID == multiEditTCID {
444 foundMultiEdit = true
445 }
446 }
447 }
448 }
449
450 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
451
452 mainGoPath := filepath.Join(env.workingDir, "main.go")
453 content, err := os.ReadFile(mainGoPath)
454 require.NoError(t, err)
455 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
456 })
457 t.Run("sourcegraph tool", func(t *testing.T) {
458 agent, env := setupAgent(t, pair)
459
460 session, err := env.sessions.Create(t.Context(), "New Session")
461 require.NoError(t, err)
462
463 res, err := agent.Run(t.Context(), SessionAgentCall{
464 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
465 SessionID: session.ID,
466 MaxOutputTokens: 10000,
467 })
468 require.NoError(t, err)
469 assert.NotNil(t, res)
470
471 msgs, err := env.messages.List(t.Context(), session.ID)
472 require.NoError(t, err)
473
474 foundSourcegraph := false
475 var sourcegraphTCID string
476
477 for _, msg := range msgs {
478 if msg.Role == message.Assistant {
479 for _, tc := range msg.ToolCalls() {
480 if tc.Name == tools.SourcegraphToolName {
481 sourcegraphTCID = tc.ID
482 }
483 }
484 }
485 if msg.Role == message.Tool {
486 for _, tr := range msg.ToolResults() {
487 if tr.ToolCallID == sourcegraphTCID {
488 foundSourcegraph = true
489 }
490 }
491 }
492 }
493
494 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
495 })
496 t.Run("write tool", func(t *testing.T) {
497 agent, env := setupAgent(t, pair)
498
499 session, err := env.sessions.Create(t.Context(), "New Session")
500 require.NoError(t, err)
501
502 res, err := agent.Run(t.Context(), SessionAgentCall{
503 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
504 SessionID: session.ID,
505 MaxOutputTokens: 10000,
506 })
507 require.NoError(t, err)
508 assert.NotNil(t, res)
509
510 msgs, err := env.messages.List(t.Context(), session.ID)
511 require.NoError(t, err)
512
513 foundWrite := false
514 var writeTCID string
515
516 for _, msg := range msgs {
517 if msg.Role == message.Assistant {
518 for _, tc := range msg.ToolCalls() {
519 if tc.Name == tools.WriteToolName {
520 writeTCID = tc.ID
521 }
522 }
523 }
524 if msg.Role == message.Tool {
525 for _, tr := range msg.ToolResults() {
526 if tr.ToolCallID == writeTCID {
527 foundWrite = true
528 }
529 }
530 }
531 }
532
533 require.True(t, foundWrite, "Expected to find a write operation")
534
535 configPath := filepath.Join(env.workingDir, "config.json")
536 content, err := os.ReadFile(configPath)
537 require.NoError(t, err)
538 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
539 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
540 })
541 t.Run("parallel tool calls", func(t *testing.T) {
542 agent, env := setupAgent(t, pair)
543
544 session, err := env.sessions.Create(t.Context(), "New Session")
545 require.NoError(t, err)
546
547 res, err := agent.Run(t.Context(), SessionAgentCall{
548 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
549 SessionID: session.ID,
550 MaxOutputTokens: 10000,
551 })
552 require.NoError(t, err)
553 assert.NotNil(t, res)
554
555 msgs, err := env.messages.List(t.Context(), session.ID)
556 require.NoError(t, err)
557
558 var assistantMsg *message.Message
559 var toolMsgs []message.Message
560
561 for _, msg := range msgs {
562 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
563 assistantMsg = &msg
564 }
565 if msg.Role == message.Tool {
566 toolMsgs = append(toolMsgs, msg)
567 }
568 }
569
570 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
571 require.NotNil(t, toolMsgs, "Expected to find a tool message")
572
573 toolCalls := assistantMsg.ToolCalls()
574 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
575
576 foundGlob := false
577 foundLS := false
578 var globTCID, lsTCID string
579
580 for _, tc := range toolCalls {
581 if tc.Name == tools.GlobToolName {
582 foundGlob = true
583 globTCID = tc.ID
584 }
585 if tc.Name == tools.LSToolName {
586 foundLS = true
587 lsTCID = tc.ID
588 }
589 }
590
591 require.True(t, foundGlob, "Expected to find a glob tool call")
592 require.True(t, foundLS, "Expected to find an ls tool call")
593
594 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
595
596 foundGlobResult := false
597 foundLSResult := false
598
599 for _, msg := range toolMsgs {
600 for _, tr := range msg.ToolResults() {
601 if tr.ToolCallID == globTCID {
602 foundGlobResult = true
603 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
604 require.False(t, tr.IsError, "Expected glob result to not be an error")
605 }
606 if tr.ToolCallID == lsTCID {
607 foundLSResult = true
608 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
609 require.False(t, tr.IsError, "Expected ls result to not be an error")
610 }
611 }
612 }
613
614 require.True(t, foundGlobResult, "Expected to find glob tool result")
615 require.True(t, foundLSResult, "Expected to find ls tool result")
616 })
617 })
618 }
619}