1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "strings"
7 "testing"
8
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/agent/tools"
11 "github.com/charmbracelet/crush/internal/message"
12 "github.com/charmbracelet/crush/internal/shell"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
16
17 _ "github.com/joho/godotenv/autoload"
18)
19
20var modelPairs = []modelPair{
21 {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
22 {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
23 {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
24 {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
25}
26
27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
28 large, err := pair.largeModel(t, r)
29 require.NoError(t, err)
30 small, err := pair.smallModel(t, r)
31 require.NoError(t, err)
32 return large, small
33}
34
35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
36 r := newRecorder(t)
37 large, small := getModels(t, r, pair)
38 env := testEnv(t)
39
40 createSimpleGoProject(t, env.workingDir)
41 agent, err := coderAgent(r, env, large, small)
42 shell.Reset(env.workingDir)
43 require.NoError(t, err)
44 return agent, env
45}
46
47func TestCoderAgent(t *testing.T) {
48 for _, pair := range modelPairs {
49 t.Run(pair.name, func(t *testing.T) {
50 t.Run("simple test", func(t *testing.T) {
51 agent, env := setupAgent(t, pair)
52
53 session, err := env.sessions.Create(t.Context(), "New Session")
54 require.NoError(t, err)
55
56 res, err := agent.Run(t.Context(), SessionAgentCall{
57 Prompt: "Hello",
58 SessionID: session.ID,
59 MaxOutputTokens: 10000,
60 })
61 require.NoError(t, err)
62 assert.NotNil(t, res)
63
64 msgs, err := env.messages.List(t.Context(), session.ID)
65 require.NoError(t, err)
66 // Should have the agent and user message
67 assert.Equal(t, len(msgs), 2)
68 })
69 t.Run("read a file", func(t *testing.T) {
70 agent, env := setupAgent(t, pair)
71
72 session, err := env.sessions.Create(t.Context(), "New Session")
73 require.NoError(t, err)
74 res, err := agent.Run(t.Context(), SessionAgentCall{
75 Prompt: "Read the go mod",
76 SessionID: session.ID,
77 MaxOutputTokens: 10000,
78 })
79
80 require.NoError(t, err)
81 assert.NotNil(t, res)
82
83 msgs, err := env.messages.List(t.Context(), session.ID)
84 require.NoError(t, err)
85 foundFile := false
86 var tcID string
87 out:
88 for _, msg := range msgs {
89 if msg.Role == message.Assistant {
90 for _, tc := range msg.ToolCalls() {
91 if tc.Name == tools.ViewToolName {
92 tcID = tc.ID
93 }
94 }
95 }
96 if msg.Role == message.Tool {
97 for _, tr := range msg.ToolResults() {
98 if tr.ToolCallID == tcID {
99 if strings.Contains(tr.Content, "module example.com/testproject") {
100 foundFile = true
101 break out
102 }
103 }
104 }
105 }
106 }
107 require.True(t, foundFile)
108 })
109 t.Run("update a file", func(t *testing.T) {
110 agent, env := setupAgent(t, pair)
111
112 session, err := env.sessions.Create(t.Context(), "New Session")
113 require.NoError(t, err)
114
115 res, err := agent.Run(t.Context(), SessionAgentCall{
116 Prompt: "update the main.go file by changing the print to say hello from crush",
117 SessionID: session.ID,
118 MaxOutputTokens: 10000,
119 })
120 require.NoError(t, err)
121 assert.NotNil(t, res)
122
123 msgs, err := env.messages.List(t.Context(), session.ID)
124 require.NoError(t, err)
125
126 foundRead := false
127 foundWrite := false
128 var readTCID, writeTCID string
129
130 for _, msg := range msgs {
131 if msg.Role == message.Assistant {
132 for _, tc := range msg.ToolCalls() {
133 if tc.Name == tools.ViewToolName {
134 readTCID = tc.ID
135 }
136 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
137 writeTCID = tc.ID
138 }
139 }
140 }
141 if msg.Role == message.Tool {
142 for _, tr := range msg.ToolResults() {
143 if tr.ToolCallID == readTCID {
144 foundRead = true
145 }
146 if tr.ToolCallID == writeTCID {
147 foundWrite = true
148 }
149 }
150 }
151 }
152
153 require.True(t, foundRead, "Expected to find a read operation")
154 require.True(t, foundWrite, "Expected to find a write operation")
155
156 mainGoPath := filepath.Join(env.workingDir, "main.go")
157 content, err := os.ReadFile(mainGoPath)
158 require.NoError(t, err)
159 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
160 })
161 t.Run("bash tool", func(t *testing.T) {
162 agent, env := setupAgent(t, pair)
163
164 session, err := env.sessions.Create(t.Context(), "New Session")
165 require.NoError(t, err)
166
167 res, err := agent.Run(t.Context(), SessionAgentCall{
168 Prompt: "use bash to create a file named test.txt with content 'hello bash'",
169 SessionID: session.ID,
170 MaxOutputTokens: 10000,
171 })
172 require.NoError(t, err)
173 assert.NotNil(t, res)
174
175 msgs, err := env.messages.List(t.Context(), session.ID)
176 require.NoError(t, err)
177
178 foundBash := false
179 var bashTCID string
180
181 for _, msg := range msgs {
182 if msg.Role == message.Assistant {
183 for _, tc := range msg.ToolCalls() {
184 if tc.Name == tools.BashToolName {
185 bashTCID = tc.ID
186 }
187 }
188 }
189 if msg.Role == message.Tool {
190 for _, tr := range msg.ToolResults() {
191 if tr.ToolCallID == bashTCID {
192 foundBash = true
193 }
194 }
195 }
196 }
197
198 require.True(t, foundBash, "Expected to find a bash operation")
199
200 testFilePath := filepath.Join(env.workingDir, "test.txt")
201 content, err := os.ReadFile(testFilePath)
202 require.NoError(t, err)
203 require.Contains(t, string(content), "hello bash")
204 })
205 t.Run("download tool", func(t *testing.T) {
206 agent, env := setupAgent(t, pair)
207
208 session, err := env.sessions.Create(t.Context(), "New Session")
209 require.NoError(t, err)
210
211 res, err := agent.Run(t.Context(), SessionAgentCall{
212 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
213 SessionID: session.ID,
214 MaxOutputTokens: 10000,
215 })
216 require.NoError(t, err)
217 assert.NotNil(t, res)
218
219 msgs, err := env.messages.List(t.Context(), session.ID)
220 require.NoError(t, err)
221
222 foundDownload := false
223 var downloadTCID string
224
225 for _, msg := range msgs {
226 if msg.Role == message.Assistant {
227 for _, tc := range msg.ToolCalls() {
228 if tc.Name == tools.DownloadToolName {
229 downloadTCID = tc.ID
230 }
231 }
232 }
233 if msg.Role == message.Tool {
234 for _, tr := range msg.ToolResults() {
235 if tr.ToolCallID == downloadTCID {
236 foundDownload = true
237 }
238 }
239 }
240 }
241
242 require.True(t, foundDownload, "Expected to find a download operation")
243
244 examplePath := filepath.Join(env.workingDir, "example.txt")
245 _, err = os.Stat(examplePath)
246 require.NoError(t, err, "Expected example.txt file to exist")
247 })
248 t.Run("glob tool", func(t *testing.T) {
249 agent, env := setupAgent(t, pair)
250
251 session, err := env.sessions.Create(t.Context(), "New Session")
252 require.NoError(t, err)
253
254 res, err := agent.Run(t.Context(), SessionAgentCall{
255 Prompt: "use glob to find all .go files in the current directory",
256 SessionID: session.ID,
257 MaxOutputTokens: 10000,
258 })
259 require.NoError(t, err)
260 assert.NotNil(t, res)
261
262 msgs, err := env.messages.List(t.Context(), session.ID)
263 require.NoError(t, err)
264
265 foundGlob := false
266 var globTCID string
267
268 for _, msg := range msgs {
269 if msg.Role == message.Assistant {
270 for _, tc := range msg.ToolCalls() {
271 if tc.Name == tools.GlobToolName {
272 globTCID = tc.ID
273 }
274 }
275 }
276 if msg.Role == message.Tool {
277 for _, tr := range msg.ToolResults() {
278 if tr.ToolCallID == globTCID {
279 foundGlob = true
280 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
281 }
282 }
283 }
284 }
285
286 require.True(t, foundGlob, "Expected to find a glob operation")
287 })
288 t.Run("grep tool", func(t *testing.T) {
289 agent, env := setupAgent(t, pair)
290
291 session, err := env.sessions.Create(t.Context(), "New Session")
292 require.NoError(t, err)
293
294 res, err := agent.Run(t.Context(), SessionAgentCall{
295 Prompt: "use grep to search for the word 'package' in go files",
296 SessionID: session.ID,
297 MaxOutputTokens: 10000,
298 })
299 require.NoError(t, err)
300 assert.NotNil(t, res)
301
302 msgs, err := env.messages.List(t.Context(), session.ID)
303 require.NoError(t, err)
304
305 foundGrep := false
306 var grepTCID string
307
308 for _, msg := range msgs {
309 if msg.Role == message.Assistant {
310 for _, tc := range msg.ToolCalls() {
311 if tc.Name == tools.GrepToolName {
312 grepTCID = tc.ID
313 }
314 }
315 }
316 if msg.Role == message.Tool {
317 for _, tr := range msg.ToolResults() {
318 if tr.ToolCallID == grepTCID {
319 foundGrep = true
320 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
321 }
322 }
323 }
324 }
325
326 require.True(t, foundGrep, "Expected to find a grep operation")
327 })
328 t.Run("ls tool", func(t *testing.T) {
329 agent, env := setupAgent(t, pair)
330
331 session, err := env.sessions.Create(t.Context(), "New Session")
332 require.NoError(t, err)
333
334 res, err := agent.Run(t.Context(), SessionAgentCall{
335 Prompt: "use ls to list the files in the current directory",
336 SessionID: session.ID,
337 MaxOutputTokens: 10000,
338 })
339 require.NoError(t, err)
340 assert.NotNil(t, res)
341
342 msgs, err := env.messages.List(t.Context(), session.ID)
343 require.NoError(t, err)
344
345 foundLS := false
346 var lsTCID string
347
348 for _, msg := range msgs {
349 if msg.Role == message.Assistant {
350 for _, tc := range msg.ToolCalls() {
351 if tc.Name == tools.LSToolName {
352 lsTCID = tc.ID
353 }
354 }
355 }
356 if msg.Role == message.Tool {
357 for _, tr := range msg.ToolResults() {
358 if tr.ToolCallID == lsTCID {
359 foundLS = true
360 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
361 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
362 }
363 }
364 }
365 }
366
367 require.True(t, foundLS, "Expected to find an ls operation")
368 })
369 t.Run("multiedit tool", func(t *testing.T) {
370 agent, env := setupAgent(t, pair)
371
372 session, err := env.sessions.Create(t.Context(), "New Session")
373 require.NoError(t, err)
374
375 res, err := agent.Run(t.Context(), SessionAgentCall{
376 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
377 SessionID: session.ID,
378 MaxOutputTokens: 10000,
379 })
380 require.NoError(t, err)
381 assert.NotNil(t, res)
382
383 msgs, err := env.messages.List(t.Context(), session.ID)
384 require.NoError(t, err)
385
386 foundMultiEdit := false
387 var multiEditTCID string
388
389 for _, msg := range msgs {
390 if msg.Role == message.Assistant {
391 for _, tc := range msg.ToolCalls() {
392 if tc.Name == tools.MultiEditToolName {
393 multiEditTCID = tc.ID
394 }
395 }
396 }
397 if msg.Role == message.Tool {
398 for _, tr := range msg.ToolResults() {
399 if tr.ToolCallID == multiEditTCID {
400 foundMultiEdit = true
401 }
402 }
403 }
404 }
405
406 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
407
408 mainGoPath := filepath.Join(env.workingDir, "main.go")
409 content, err := os.ReadFile(mainGoPath)
410 require.NoError(t, err)
411 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
412 })
413 t.Run("sourcegraph tool", func(t *testing.T) {
414 agent, env := setupAgent(t, pair)
415
416 session, err := env.sessions.Create(t.Context(), "New Session")
417 require.NoError(t, err)
418
419 res, err := agent.Run(t.Context(), SessionAgentCall{
420 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
421 SessionID: session.ID,
422 MaxOutputTokens: 10000,
423 })
424 require.NoError(t, err)
425 assert.NotNil(t, res)
426
427 msgs, err := env.messages.List(t.Context(), session.ID)
428 require.NoError(t, err)
429
430 foundSourcegraph := false
431 var sourcegraphTCID string
432
433 for _, msg := range msgs {
434 if msg.Role == message.Assistant {
435 for _, tc := range msg.ToolCalls() {
436 if tc.Name == tools.SourcegraphToolName {
437 sourcegraphTCID = tc.ID
438 }
439 }
440 }
441 if msg.Role == message.Tool {
442 for _, tr := range msg.ToolResults() {
443 if tr.ToolCallID == sourcegraphTCID {
444 foundSourcegraph = true
445 }
446 }
447 }
448 }
449
450 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
451 })
452 t.Run("write tool", func(t *testing.T) {
453 agent, env := setupAgent(t, pair)
454
455 session, err := env.sessions.Create(t.Context(), "New Session")
456 require.NoError(t, err)
457
458 res, err := agent.Run(t.Context(), SessionAgentCall{
459 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
460 SessionID: session.ID,
461 MaxOutputTokens: 10000,
462 })
463 require.NoError(t, err)
464 assert.NotNil(t, res)
465
466 msgs, err := env.messages.List(t.Context(), session.ID)
467 require.NoError(t, err)
468
469 foundWrite := false
470 var writeTCID string
471
472 for _, msg := range msgs {
473 if msg.Role == message.Assistant {
474 for _, tc := range msg.ToolCalls() {
475 if tc.Name == tools.WriteToolName {
476 writeTCID = tc.ID
477 }
478 }
479 }
480 if msg.Role == message.Tool {
481 for _, tr := range msg.ToolResults() {
482 if tr.ToolCallID == writeTCID {
483 foundWrite = true
484 }
485 }
486 }
487 }
488
489 require.True(t, foundWrite, "Expected to find a write operation")
490
491 configPath := filepath.Join(env.workingDir, "config.json")
492 content, err := os.ReadFile(configPath)
493 require.NoError(t, err)
494 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
495 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
496 })
497 t.Run("parallel tool calls", func(t *testing.T) {
498 agent, env := setupAgent(t, pair)
499
500 session, err := env.sessions.Create(t.Context(), "New Session")
501 require.NoError(t, err)
502
503 res, err := agent.Run(t.Context(), SessionAgentCall{
504 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
505 SessionID: session.ID,
506 MaxOutputTokens: 10000,
507 })
508 require.NoError(t, err)
509 assert.NotNil(t, res)
510
511 msgs, err := env.messages.List(t.Context(), session.ID)
512 require.NoError(t, err)
513
514 var assistantMsg *message.Message
515 var toolMsgs []message.Message
516
517 for _, msg := range msgs {
518 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
519 assistantMsg = &msg
520 }
521 if msg.Role == message.Tool {
522 toolMsgs = append(toolMsgs, msg)
523 }
524 }
525
526 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
527 require.NotNil(t, toolMsgs, "Expected to find a tool message")
528
529 toolCalls := assistantMsg.ToolCalls()
530 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
531
532 foundGlob := false
533 foundLS := false
534 var globTCID, lsTCID string
535
536 for _, tc := range toolCalls {
537 if tc.Name == tools.GlobToolName {
538 foundGlob = true
539 globTCID = tc.ID
540 }
541 if tc.Name == tools.LSToolName {
542 foundLS = true
543 lsTCID = tc.ID
544 }
545 }
546
547 require.True(t, foundGlob, "Expected to find a glob tool call")
548 require.True(t, foundLS, "Expected to find an ls tool call")
549
550 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
551
552 foundGlobResult := false
553 foundLSResult := false
554
555 for _, msg := range toolMsgs {
556 for _, tr := range msg.ToolResults() {
557 if tr.ToolCallID == globTCID {
558 foundGlobResult = true
559 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
560 require.False(t, tr.IsError, "Expected glob result to not be an error")
561 }
562 if tr.ToolCallID == lsTCID {
563 foundLSResult = true
564 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
565 require.False(t, tr.IsError, "Expected ls result to not be an error")
566 }
567 }
568 }
569
570 require.True(t, foundGlobResult, "Expected to find glob tool result")
571 require.True(t, foundLSResult, "Expected to find ls tool result")
572 })
573 })
574 }
575}