1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "strings"
7 "testing"
8
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/agent/tools"
11 "github.com/charmbracelet/crush/internal/message"
12 "github.com/charmbracelet/crush/internal/shell"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16
17 _ "github.com/joho/godotenv/autoload"
18)
19
20var modelPairs = []modelPair{
21 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
22 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
23 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
24 {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
25}
26
27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
28 large, err := pair.largeModel(t, r)
29 require.NoError(t, err)
30 small, err := pair.smallModel(t, r)
31 require.NoError(t, err)
32 return large, small
33}
34
35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
36 r := newRecorder(t)
37 large, small := getModels(t, r, pair)
38 env := testEnv(t)
39
40 createSimpleGoProject(t, env.workingDir)
41 agent, err := coderAgent(r, env, large, small)
42 shell.Reset(env.workingDir)
43 require.NoError(t, err)
44 return agent, env
45}
46
47func TestCoderAgent(t *testing.T) {
48 for _, pair := range modelPairs {
49 t.Run(pair.name, func(t *testing.T) {
50 t.Run("simple test", func(t *testing.T) {
51 agent, env := setupAgent(t, pair)
52
53 session, err := env.sessions.Create(t.Context(), "New Session")
54 require.NoError(t, err)
55
56 res, err := agent.Run(t.Context(), SessionAgentCall{
57 Prompt: "Hello",
58 SessionID: session.ID,
59 MaxOutputTokens: 10000,
60 })
61 require.NoError(t, err)
62 assert.NotNil(t, res)
63
64 msgs, err := env.messages.List(t.Context(), session.ID)
65 require.NoError(t, err)
66 // Should have the agent and user message
67 assert.Equal(t, len(msgs), 2)
68 })
69 t.Run("read a file", func(t *testing.T) {
70 agent, env := setupAgent(t, pair)
71
72 session, err := env.sessions.Create(t.Context(), "New Session")
73 require.NoError(t, err)
74 res, err := agent.Run(t.Context(), SessionAgentCall{
75 Prompt: "Read the go mod",
76 SessionID: session.ID,
77 MaxOutputTokens: 10000,
78 })
79
80 require.NoError(t, err)
81 assert.NotNil(t, res)
82
83 msgs, err := env.messages.List(t.Context(), session.ID)
84 require.NoError(t, err)
85 foundFile := false
86 var tcID string
87 out:
88 for _, msg := range msgs {
89 if msg.Role == message.Assistant {
90 for _, tc := range msg.ToolCalls() {
91 if tc.Name == tools.ViewToolName {
92 tcID = tc.ID
93 }
94 }
95 }
96 if msg.Role == message.Tool {
97 for _, tr := range msg.ToolResults() {
98 if tr.ToolCallID == tcID {
99 if strings.Contains(tr.Content, "module example.com/testproject") {
100 foundFile = true
101 break out
102 }
103 }
104 }
105 }
106 }
107 require.True(t, foundFile)
108 })
109 t.Run("update a file", func(t *testing.T) {
110 agent, env := setupAgent(t, pair)
111
112 session, err := env.sessions.Create(t.Context(), "New Session")
113 require.NoError(t, err)
114
115 res, err := agent.Run(t.Context(), SessionAgentCall{
116 Prompt: "update the main.go file by changing the print to say hello from crush",
117 SessionID: session.ID,
118 MaxOutputTokens: 10000,
119 })
120 require.NoError(t, err)
121 assert.NotNil(t, res)
122
123 msgs, err := env.messages.List(t.Context(), session.ID)
124 require.NoError(t, err)
125
126 foundRead := false
127 foundWrite := false
128 var readTCID, writeTCID string
129
130 for _, msg := range msgs {
131 if msg.Role == message.Assistant {
132 for _, tc := range msg.ToolCalls() {
133 if tc.Name == tools.ViewToolName {
134 readTCID = tc.ID
135 }
136 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
137 writeTCID = tc.ID
138 }
139 }
140 }
141 if msg.Role == message.Tool {
142 for _, tr := range msg.ToolResults() {
143 if tr.ToolCallID == readTCID {
144 foundRead = true
145 }
146 if tr.ToolCallID == writeTCID {
147 foundWrite = true
148 }
149 }
150 }
151 }
152
153 require.True(t, foundRead, "Expected to find a read operation")
154 require.True(t, foundWrite, "Expected to find a write operation")
155
156 mainGoPath := filepath.Join(env.workingDir, "main.go")
157 content, err := os.ReadFile(mainGoPath)
158 require.NoError(t, err)
159 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
160 })
161 t.Run("bash tool", func(t *testing.T) {
162 agent, env := setupAgent(t, pair)
163
164 session, err := env.sessions.Create(t.Context(), "New Session")
165 require.NoError(t, err)
166
167 res, err := agent.Run(t.Context(), SessionAgentCall{
168 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
169 SessionID: session.ID,
170 MaxOutputTokens: 10000,
171 })
172 require.NoError(t, err)
173 assert.NotNil(t, res)
174
175 msgs, err := env.messages.List(t.Context(), session.ID)
176 require.NoError(t, err)
177
178 foundBash := false
179 var bashTCID string
180
181 for _, msg := range msgs {
182 if msg.Role == message.Assistant {
183 for _, tc := range msg.ToolCalls() {
184 if tc.Name == tools.BashToolName {
185 bashTCID = tc.ID
186 }
187 }
188 }
189 if msg.Role == message.Tool {
190 for _, tr := range msg.ToolResults() {
191 if tr.ToolCallID == bashTCID {
192 foundBash = true
193 }
194 }
195 }
196 }
197
198 require.True(t, foundBash, "Expected to find a bash operation")
199
200 testFilePath := filepath.Join(env.workingDir, "test.txt")
201 content, err := os.ReadFile(testFilePath)
202 require.NoError(t, err)
203 require.Contains(t, string(content), "hello bash")
204 })
205 t.Run("download tool", func(t *testing.T) {
206 agent, env := setupAgent(t, pair)
207
208 session, err := env.sessions.Create(t.Context(), "New Session")
209 require.NoError(t, err)
210
211 res, err := agent.Run(t.Context(), SessionAgentCall{
212 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
213 SessionID: session.ID,
214 MaxOutputTokens: 10000,
215 })
216 require.NoError(t, err)
217 assert.NotNil(t, res)
218
219 msgs, err := env.messages.List(t.Context(), session.ID)
220 require.NoError(t, err)
221
222 foundDownload := false
223 var downloadTCID string
224
225 for _, msg := range msgs {
226 if msg.Role == message.Assistant {
227 for _, tc := range msg.ToolCalls() {
228 if tc.Name == tools.DownloadToolName {
229 downloadTCID = tc.ID
230 }
231 }
232 }
233 if msg.Role == message.Tool {
234 for _, tr := range msg.ToolResults() {
235 if tr.ToolCallID == downloadTCID {
236 foundDownload = true
237 }
238 }
239 }
240 }
241
242 require.True(t, foundDownload, "Expected to find a download operation")
243
244 examplePath := filepath.Join(env.workingDir, "example.txt")
245 _, err = os.Stat(examplePath)
246 require.NoError(t, err, "Expected example.txt file to exist")
247 })
248 t.Run("fetch tool", func(t *testing.T) {
249 agent, env := setupAgent(t, pair)
250
251 session, err := env.sessions.Create(t.Context(), "New Session")
252 require.NoError(t, err)
253
254 res, err := agent.Run(t.Context(), SessionAgentCall{
255 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
256 SessionID: session.ID,
257 MaxOutputTokens: 10000,
258 })
259 require.NoError(t, err)
260 assert.NotNil(t, res)
261
262 msgs, err := env.messages.List(t.Context(), session.ID)
263 require.NoError(t, err)
264
265 foundFetch := false
266 var fetchTCID string
267
268 for _, msg := range msgs {
269 if msg.Role == message.Assistant {
270 for _, tc := range msg.ToolCalls() {
271 if tc.Name == tools.FetchToolName {
272 fetchTCID = tc.ID
273 }
274 }
275 }
276 if msg.Role == message.Tool {
277 for _, tr := range msg.ToolResults() {
278 if tr.ToolCallID == fetchTCID {
279 foundFetch = true
280 }
281 }
282 }
283 }
284
285 require.True(t, foundFetch, "Expected to find a fetch operation")
286 })
287 t.Run("glob tool", func(t *testing.T) {
288 agent, env := setupAgent(t, pair)
289
290 session, err := env.sessions.Create(t.Context(), "New Session")
291 require.NoError(t, err)
292
293 res, err := agent.Run(t.Context(), SessionAgentCall{
294 Prompt: "use glob to find all .go files in the current directory",
295 SessionID: session.ID,
296 MaxOutputTokens: 10000,
297 })
298 require.NoError(t, err)
299 assert.NotNil(t, res)
300
301 msgs, err := env.messages.List(t.Context(), session.ID)
302 require.NoError(t, err)
303
304 foundGlob := false
305 var globTCID string
306
307 for _, msg := range msgs {
308 if msg.Role == message.Assistant {
309 for _, tc := range msg.ToolCalls() {
310 if tc.Name == tools.GlobToolName {
311 globTCID = tc.ID
312 }
313 }
314 }
315 if msg.Role == message.Tool {
316 for _, tr := range msg.ToolResults() {
317 if tr.ToolCallID == globTCID {
318 foundGlob = true
319 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
320 }
321 }
322 }
323 }
324
325 require.True(t, foundGlob, "Expected to find a glob operation")
326 })
327 t.Run("grep tool", func(t *testing.T) {
328 agent, env := setupAgent(t, pair)
329
330 session, err := env.sessions.Create(t.Context(), "New Session")
331 require.NoError(t, err)
332
333 res, err := agent.Run(t.Context(), SessionAgentCall{
334 Prompt: "use grep to search for the word 'package' in go files",
335 SessionID: session.ID,
336 MaxOutputTokens: 10000,
337 })
338 require.NoError(t, err)
339 assert.NotNil(t, res)
340
341 msgs, err := env.messages.List(t.Context(), session.ID)
342 require.NoError(t, err)
343
344 foundGrep := false
345 var grepTCID string
346
347 for _, msg := range msgs {
348 if msg.Role == message.Assistant {
349 for _, tc := range msg.ToolCalls() {
350 if tc.Name == tools.GrepToolName {
351 grepTCID = tc.ID
352 }
353 }
354 }
355 if msg.Role == message.Tool {
356 for _, tr := range msg.ToolResults() {
357 if tr.ToolCallID == grepTCID {
358 foundGrep = true
359 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
360 }
361 }
362 }
363 }
364
365 require.True(t, foundGrep, "Expected to find a grep operation")
366 })
367 t.Run("ls tool", func(t *testing.T) {
368 agent, env := setupAgent(t, pair)
369
370 session, err := env.sessions.Create(t.Context(), "New Session")
371 require.NoError(t, err)
372
373 res, err := agent.Run(t.Context(), SessionAgentCall{
374 Prompt: "use ls to list the files in the current directory",
375 SessionID: session.ID,
376 MaxOutputTokens: 10000,
377 })
378 require.NoError(t, err)
379 assert.NotNil(t, res)
380
381 msgs, err := env.messages.List(t.Context(), session.ID)
382 require.NoError(t, err)
383
384 foundLS := false
385 var lsTCID string
386
387 for _, msg := range msgs {
388 if msg.Role == message.Assistant {
389 for _, tc := range msg.ToolCalls() {
390 if tc.Name == tools.LSToolName {
391 lsTCID = tc.ID
392 }
393 }
394 }
395 if msg.Role == message.Tool {
396 for _, tr := range msg.ToolResults() {
397 if tr.ToolCallID == lsTCID {
398 foundLS = true
399 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
400 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
401 }
402 }
403 }
404 }
405
406 require.True(t, foundLS, "Expected to find an ls operation")
407 })
408 t.Run("multiedit tool", func(t *testing.T) {
409 agent, env := setupAgent(t, pair)
410
411 session, err := env.sessions.Create(t.Context(), "New Session")
412 require.NoError(t, err)
413
414 res, err := agent.Run(t.Context(), SessionAgentCall{
415 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
416 SessionID: session.ID,
417 MaxOutputTokens: 10000,
418 })
419 require.NoError(t, err)
420 assert.NotNil(t, res)
421
422 msgs, err := env.messages.List(t.Context(), session.ID)
423 require.NoError(t, err)
424
425 foundMultiEdit := false
426 var multiEditTCID string
427
428 for _, msg := range msgs {
429 if msg.Role == message.Assistant {
430 for _, tc := range msg.ToolCalls() {
431 if tc.Name == tools.MultiEditToolName {
432 multiEditTCID = tc.ID
433 }
434 }
435 }
436 if msg.Role == message.Tool {
437 for _, tr := range msg.ToolResults() {
438 if tr.ToolCallID == multiEditTCID {
439 foundMultiEdit = true
440 }
441 }
442 }
443 }
444
445 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
446
447 mainGoPath := filepath.Join(env.workingDir, "main.go")
448 content, err := os.ReadFile(mainGoPath)
449 require.NoError(t, err)
450 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
451 })
452 t.Run("sourcegraph tool", func(t *testing.T) {
453 agent, env := setupAgent(t, pair)
454
455 session, err := env.sessions.Create(t.Context(), "New Session")
456 require.NoError(t, err)
457
458 res, err := agent.Run(t.Context(), SessionAgentCall{
459 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
460 SessionID: session.ID,
461 MaxOutputTokens: 10000,
462 })
463 require.NoError(t, err)
464 assert.NotNil(t, res)
465
466 msgs, err := env.messages.List(t.Context(), session.ID)
467 require.NoError(t, err)
468
469 foundSourcegraph := false
470 var sourcegraphTCID string
471
472 for _, msg := range msgs {
473 if msg.Role == message.Assistant {
474 for _, tc := range msg.ToolCalls() {
475 if tc.Name == tools.SourcegraphToolName {
476 sourcegraphTCID = tc.ID
477 }
478 }
479 }
480 if msg.Role == message.Tool {
481 for _, tr := range msg.ToolResults() {
482 if tr.ToolCallID == sourcegraphTCID {
483 foundSourcegraph = true
484 }
485 }
486 }
487 }
488
489 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
490 })
491 t.Run("write tool", func(t *testing.T) {
492 agent, env := setupAgent(t, pair)
493
494 session, err := env.sessions.Create(t.Context(), "New Session")
495 require.NoError(t, err)
496
497 res, err := agent.Run(t.Context(), SessionAgentCall{
498 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
499 SessionID: session.ID,
500 MaxOutputTokens: 10000,
501 })
502 require.NoError(t, err)
503 assert.NotNil(t, res)
504
505 msgs, err := env.messages.List(t.Context(), session.ID)
506 require.NoError(t, err)
507
508 foundWrite := false
509 var writeTCID string
510
511 for _, msg := range msgs {
512 if msg.Role == message.Assistant {
513 for _, tc := range msg.ToolCalls() {
514 if tc.Name == tools.WriteToolName {
515 writeTCID = tc.ID
516 }
517 }
518 }
519 if msg.Role == message.Tool {
520 for _, tr := range msg.ToolResults() {
521 if tr.ToolCallID == writeTCID {
522 foundWrite = true
523 }
524 }
525 }
526 }
527
528 require.True(t, foundWrite, "Expected to find a write operation")
529
530 configPath := filepath.Join(env.workingDir, "config.json")
531 content, err := os.ReadFile(configPath)
532 require.NoError(t, err)
533 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
534 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
535 })
536 t.Run("parallel tool calls", func(t *testing.T) {
537 agent, env := setupAgent(t, pair)
538
539 session, err := env.sessions.Create(t.Context(), "New Session")
540 require.NoError(t, err)
541
542 res, err := agent.Run(t.Context(), SessionAgentCall{
543 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
544 SessionID: session.ID,
545 MaxOutputTokens: 10000,
546 })
547 require.NoError(t, err)
548 assert.NotNil(t, res)
549
550 msgs, err := env.messages.List(t.Context(), session.ID)
551 require.NoError(t, err)
552
553 var assistantMsg *message.Message
554 var toolMsgs []message.Message
555
556 for _, msg := range msgs {
557 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
558 assistantMsg = &msg
559 }
560 if msg.Role == message.Tool {
561 toolMsgs = append(toolMsgs, msg)
562 }
563 }
564
565 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
566 require.NotNil(t, toolMsgs, "Expected to find a tool message")
567
568 toolCalls := assistantMsg.ToolCalls()
569 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
570
571 foundGlob := false
572 foundLS := false
573 var globTCID, lsTCID string
574
575 for _, tc := range toolCalls {
576 if tc.Name == tools.GlobToolName {
577 foundGlob = true
578 globTCID = tc.ID
579 }
580 if tc.Name == tools.LSToolName {
581 foundLS = true
582 lsTCID = tc.ID
583 }
584 }
585
586 require.True(t, foundGlob, "Expected to find a glob tool call")
587 require.True(t, foundLS, "Expected to find an ls tool call")
588
589 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
590
591 foundGlobResult := false
592 foundLSResult := false
593
594 for _, msg := range toolMsgs {
595 for _, tr := range msg.ToolResults() {
596 if tr.ToolCallID == globTCID {
597 foundGlobResult = true
598 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
599 require.False(t, tr.IsError, "Expected glob result to not be an error")
600 }
601 if tr.ToolCallID == lsTCID {
602 foundLSResult = true
603 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
604 require.False(t, tr.IsError, "Expected ls result to not be an error")
605 }
606 }
607 }
608
609 require.True(t, foundGlobResult, "Expected to find glob tool result")
610 require.True(t, foundLSResult, "Expected to find ls tool result")
611 })
612 })
613 }
614}