1package agent
2
3import (
4 "os"
5 "path/filepath"
6 "runtime"
7 "strings"
8 "testing"
9
10 "charm.land/fantasy"
11 "github.com/charmbracelet/crush/internal/agent/tools"
12 "github.com/charmbracelet/crush/internal/message"
13 "github.com/charmbracelet/crush/internal/shell"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
17
18 _ "github.com/joho/godotenv/autoload"
19)
20
// modelPairs lists the provider/model combinations TestCoderAgent runs
// against; each entry becomes a subtest named after its first field.
// Entries are positional: a name, then the builder intended for the large
// model, then the builder intended for the small model. NOTE(review): this
// relies on modelPair declaring its fields in that order (name, largeModel,
// smallModel) — confirm against the type definition before reordering.
var modelPairs = []modelPair{
	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
}
27
28func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
29 large, err := pair.largeModel(t, r)
30 require.NoError(t, err)
31 small, err := pair.smallModel(t, r)
32 require.NoError(t, err)
33 return large, small
34}
35
36func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
37 r := newRecorder(t)
38 large, small := getModels(t, r, pair)
39 env := testEnv(t)
40
41 createSimpleGoProject(t, env.workingDir)
42 agent, err := coderAgent(r, env, large, small)
43 shell.Reset(env.workingDir)
44 require.NoError(t, err)
45 return agent, env
46}
47
48func TestCoderAgent(t *testing.T) {
49 if runtime.GOOS == "windows" {
50 t.Skip("skipping on windows for now")
51 }
52
53 for _, pair := range modelPairs {
54 t.Run(pair.name, func(t *testing.T) {
55 t.Run("simple test", func(t *testing.T) {
56 agent, env := setupAgent(t, pair)
57
58 session, err := env.sessions.Create(t.Context(), "New Session")
59 require.NoError(t, err)
60
61 res, err := agent.Run(t.Context(), SessionAgentCall{
62 Prompt: "Hello",
63 SessionID: session.ID,
64 MaxOutputTokens: 10000,
65 })
66 require.NoError(t, err)
67 assert.NotNil(t, res)
68
69 msgs, err := env.messages.List(t.Context(), session.ID)
70 require.NoError(t, err)
71 // Should have the agent and user message
72 assert.Equal(t, len(msgs), 2)
73 })
74 t.Run("read a file", func(t *testing.T) {
75 agent, env := setupAgent(t, pair)
76
77 session, err := env.sessions.Create(t.Context(), "New Session")
78 require.NoError(t, err)
79 res, err := agent.Run(t.Context(), SessionAgentCall{
80 Prompt: "Read the go mod",
81 SessionID: session.ID,
82 MaxOutputTokens: 10000,
83 })
84
85 require.NoError(t, err)
86 assert.NotNil(t, res)
87
88 msgs, err := env.messages.List(t.Context(), session.ID)
89 require.NoError(t, err)
90 foundFile := false
91 var tcID string
92 out:
93 for _, msg := range msgs {
94 if msg.Role == message.Assistant {
95 for _, tc := range msg.ToolCalls() {
96 if tc.Name == tools.ViewToolName {
97 tcID = tc.ID
98 }
99 }
100 }
101 if msg.Role == message.Tool {
102 for _, tr := range msg.ToolResults() {
103 if tr.ToolCallID == tcID {
104 if strings.Contains(tr.Content, "module example.com/testproject") {
105 foundFile = true
106 break out
107 }
108 }
109 }
110 }
111 }
112 require.True(t, foundFile)
113 })
114 t.Run("update a file", func(t *testing.T) {
115 agent, env := setupAgent(t, pair)
116
117 session, err := env.sessions.Create(t.Context(), "New Session")
118 require.NoError(t, err)
119
120 res, err := agent.Run(t.Context(), SessionAgentCall{
121 Prompt: "update the main.go file by changing the print to say hello from crush",
122 SessionID: session.ID,
123 MaxOutputTokens: 10000,
124 })
125 require.NoError(t, err)
126 assert.NotNil(t, res)
127
128 msgs, err := env.messages.List(t.Context(), session.ID)
129 require.NoError(t, err)
130
131 foundRead := false
132 foundWrite := false
133 var readTCID, writeTCID string
134
135 for _, msg := range msgs {
136 if msg.Role == message.Assistant {
137 for _, tc := range msg.ToolCalls() {
138 if tc.Name == tools.ViewToolName {
139 readTCID = tc.ID
140 }
141 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142 writeTCID = tc.ID
143 }
144 }
145 }
146 if msg.Role == message.Tool {
147 for _, tr := range msg.ToolResults() {
148 if tr.ToolCallID == readTCID {
149 foundRead = true
150 }
151 if tr.ToolCallID == writeTCID {
152 foundWrite = true
153 }
154 }
155 }
156 }
157
158 require.True(t, foundRead, "Expected to find a read operation")
159 require.True(t, foundWrite, "Expected to find a write operation")
160
161 mainGoPath := filepath.Join(env.workingDir, "main.go")
162 content, err := os.ReadFile(mainGoPath)
163 require.NoError(t, err)
164 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165 })
166 t.Run("bash tool", func(t *testing.T) {
167 agent, env := setupAgent(t, pair)
168
169 session, err := env.sessions.Create(t.Context(), "New Session")
170 require.NoError(t, err)
171
172 res, err := agent.Run(t.Context(), SessionAgentCall{
173 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174 SessionID: session.ID,
175 MaxOutputTokens: 10000,
176 })
177 require.NoError(t, err)
178 assert.NotNil(t, res)
179
180 msgs, err := env.messages.List(t.Context(), session.ID)
181 require.NoError(t, err)
182
183 foundBash := false
184 var bashTCID string
185
186 for _, msg := range msgs {
187 if msg.Role == message.Assistant {
188 for _, tc := range msg.ToolCalls() {
189 if tc.Name == tools.BashToolName {
190 bashTCID = tc.ID
191 }
192 }
193 }
194 if msg.Role == message.Tool {
195 for _, tr := range msg.ToolResults() {
196 if tr.ToolCallID == bashTCID {
197 foundBash = true
198 }
199 }
200 }
201 }
202
203 require.True(t, foundBash, "Expected to find a bash operation")
204
205 testFilePath := filepath.Join(env.workingDir, "test.txt")
206 content, err := os.ReadFile(testFilePath)
207 require.NoError(t, err)
208 require.Contains(t, string(content), "hello bash")
209 })
210 t.Run("download tool", func(t *testing.T) {
211 agent, env := setupAgent(t, pair)
212
213 session, err := env.sessions.Create(t.Context(), "New Session")
214 require.NoError(t, err)
215
216 res, err := agent.Run(t.Context(), SessionAgentCall{
217 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218 SessionID: session.ID,
219 MaxOutputTokens: 10000,
220 })
221 require.NoError(t, err)
222 assert.NotNil(t, res)
223
224 msgs, err := env.messages.List(t.Context(), session.ID)
225 require.NoError(t, err)
226
227 foundDownload := false
228 var downloadTCID string
229
230 for _, msg := range msgs {
231 if msg.Role == message.Assistant {
232 for _, tc := range msg.ToolCalls() {
233 if tc.Name == tools.DownloadToolName {
234 downloadTCID = tc.ID
235 }
236 }
237 }
238 if msg.Role == message.Tool {
239 for _, tr := range msg.ToolResults() {
240 if tr.ToolCallID == downloadTCID {
241 foundDownload = true
242 }
243 }
244 }
245 }
246
247 require.True(t, foundDownload, "Expected to find a download operation")
248
249 examplePath := filepath.Join(env.workingDir, "example.txt")
250 _, err = os.Stat(examplePath)
251 require.NoError(t, err, "Expected example.txt file to exist")
252 })
253 t.Run("glob tool", func(t *testing.T) {
254 agent, env := setupAgent(t, pair)
255
256 session, err := env.sessions.Create(t.Context(), "New Session")
257 require.NoError(t, err)
258
259 res, err := agent.Run(t.Context(), SessionAgentCall{
260 Prompt: "use glob to find all .go files in the current directory",
261 SessionID: session.ID,
262 MaxOutputTokens: 10000,
263 })
264 require.NoError(t, err)
265 assert.NotNil(t, res)
266
267 msgs, err := env.messages.List(t.Context(), session.ID)
268 require.NoError(t, err)
269
270 foundGlob := false
271 var globTCID string
272
273 for _, msg := range msgs {
274 if msg.Role == message.Assistant {
275 for _, tc := range msg.ToolCalls() {
276 if tc.Name == tools.GlobToolName {
277 globTCID = tc.ID
278 }
279 }
280 }
281 if msg.Role == message.Tool {
282 for _, tr := range msg.ToolResults() {
283 if tr.ToolCallID == globTCID {
284 foundGlob = true
285 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
286 }
287 }
288 }
289 }
290
291 require.True(t, foundGlob, "Expected to find a glob operation")
292 })
293 t.Run("grep tool", func(t *testing.T) {
294 agent, env := setupAgent(t, pair)
295
296 session, err := env.sessions.Create(t.Context(), "New Session")
297 require.NoError(t, err)
298
299 res, err := agent.Run(t.Context(), SessionAgentCall{
300 Prompt: "use grep to search for the word 'package' in go files",
301 SessionID: session.ID,
302 MaxOutputTokens: 10000,
303 })
304 require.NoError(t, err)
305 assert.NotNil(t, res)
306
307 msgs, err := env.messages.List(t.Context(), session.ID)
308 require.NoError(t, err)
309
310 foundGrep := false
311 var grepTCID string
312
313 for _, msg := range msgs {
314 if msg.Role == message.Assistant {
315 for _, tc := range msg.ToolCalls() {
316 if tc.Name == tools.GrepToolName {
317 grepTCID = tc.ID
318 }
319 }
320 }
321 if msg.Role == message.Tool {
322 for _, tr := range msg.ToolResults() {
323 if tr.ToolCallID == grepTCID {
324 foundGrep = true
325 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
326 }
327 }
328 }
329 }
330
331 require.True(t, foundGrep, "Expected to find a grep operation")
332 })
333 t.Run("ls tool", func(t *testing.T) {
334 agent, env := setupAgent(t, pair)
335
336 session, err := env.sessions.Create(t.Context(), "New Session")
337 require.NoError(t, err)
338
339 res, err := agent.Run(t.Context(), SessionAgentCall{
340 Prompt: "use ls to list the files in the current directory",
341 SessionID: session.ID,
342 MaxOutputTokens: 10000,
343 })
344 require.NoError(t, err)
345 assert.NotNil(t, res)
346
347 msgs, err := env.messages.List(t.Context(), session.ID)
348 require.NoError(t, err)
349
350 foundLS := false
351 var lsTCID string
352
353 for _, msg := range msgs {
354 if msg.Role == message.Assistant {
355 for _, tc := range msg.ToolCalls() {
356 if tc.Name == tools.LSToolName {
357 lsTCID = tc.ID
358 }
359 }
360 }
361 if msg.Role == message.Tool {
362 for _, tr := range msg.ToolResults() {
363 if tr.ToolCallID == lsTCID {
364 foundLS = true
365 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
366 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
367 }
368 }
369 }
370 }
371
372 require.True(t, foundLS, "Expected to find an ls operation")
373 })
374 t.Run("multiedit tool", func(t *testing.T) {
375 agent, env := setupAgent(t, pair)
376
377 session, err := env.sessions.Create(t.Context(), "New Session")
378 require.NoError(t, err)
379
380 res, err := agent.Run(t.Context(), SessionAgentCall{
381 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
382 SessionID: session.ID,
383 MaxOutputTokens: 10000,
384 })
385 require.NoError(t, err)
386 assert.NotNil(t, res)
387
388 msgs, err := env.messages.List(t.Context(), session.ID)
389 require.NoError(t, err)
390
391 foundMultiEdit := false
392 var multiEditTCID string
393
394 for _, msg := range msgs {
395 if msg.Role == message.Assistant {
396 for _, tc := range msg.ToolCalls() {
397 if tc.Name == tools.MultiEditToolName {
398 multiEditTCID = tc.ID
399 }
400 }
401 }
402 if msg.Role == message.Tool {
403 for _, tr := range msg.ToolResults() {
404 if tr.ToolCallID == multiEditTCID {
405 foundMultiEdit = true
406 }
407 }
408 }
409 }
410
411 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
412
413 mainGoPath := filepath.Join(env.workingDir, "main.go")
414 content, err := os.ReadFile(mainGoPath)
415 require.NoError(t, err)
416 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
417 })
418 t.Run("sourcegraph tool", func(t *testing.T) {
419 agent, env := setupAgent(t, pair)
420
421 session, err := env.sessions.Create(t.Context(), "New Session")
422 require.NoError(t, err)
423
424 res, err := agent.Run(t.Context(), SessionAgentCall{
425 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
426 SessionID: session.ID,
427 MaxOutputTokens: 10000,
428 })
429 require.NoError(t, err)
430 assert.NotNil(t, res)
431
432 msgs, err := env.messages.List(t.Context(), session.ID)
433 require.NoError(t, err)
434
435 foundSourcegraph := false
436 var sourcegraphTCID string
437
438 for _, msg := range msgs {
439 if msg.Role == message.Assistant {
440 for _, tc := range msg.ToolCalls() {
441 if tc.Name == tools.SourcegraphToolName {
442 sourcegraphTCID = tc.ID
443 }
444 }
445 }
446 if msg.Role == message.Tool {
447 for _, tr := range msg.ToolResults() {
448 if tr.ToolCallID == sourcegraphTCID {
449 foundSourcegraph = true
450 }
451 }
452 }
453 }
454
455 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
456 })
457 t.Run("write tool", func(t *testing.T) {
458 agent, env := setupAgent(t, pair)
459
460 session, err := env.sessions.Create(t.Context(), "New Session")
461 require.NoError(t, err)
462
463 res, err := agent.Run(t.Context(), SessionAgentCall{
464 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
465 SessionID: session.ID,
466 MaxOutputTokens: 10000,
467 })
468 require.NoError(t, err)
469 assert.NotNil(t, res)
470
471 msgs, err := env.messages.List(t.Context(), session.ID)
472 require.NoError(t, err)
473
474 foundWrite := false
475 var writeTCID string
476
477 for _, msg := range msgs {
478 if msg.Role == message.Assistant {
479 for _, tc := range msg.ToolCalls() {
480 if tc.Name == tools.WriteToolName {
481 writeTCID = tc.ID
482 }
483 }
484 }
485 if msg.Role == message.Tool {
486 for _, tr := range msg.ToolResults() {
487 if tr.ToolCallID == writeTCID {
488 foundWrite = true
489 }
490 }
491 }
492 }
493
494 require.True(t, foundWrite, "Expected to find a write operation")
495
496 configPath := filepath.Join(env.workingDir, "config.json")
497 content, err := os.ReadFile(configPath)
498 require.NoError(t, err)
499 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
500 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
501 })
502 t.Run("parallel tool calls", func(t *testing.T) {
503 agent, env := setupAgent(t, pair)
504
505 session, err := env.sessions.Create(t.Context(), "New Session")
506 require.NoError(t, err)
507
508 res, err := agent.Run(t.Context(), SessionAgentCall{
509 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
510 SessionID: session.ID,
511 MaxOutputTokens: 10000,
512 })
513 require.NoError(t, err)
514 assert.NotNil(t, res)
515
516 msgs, err := env.messages.List(t.Context(), session.ID)
517 require.NoError(t, err)
518
519 var assistantMsg *message.Message
520 var toolMsgs []message.Message
521
522 for _, msg := range msgs {
523 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
524 assistantMsg = &msg
525 }
526 if msg.Role == message.Tool {
527 toolMsgs = append(toolMsgs, msg)
528 }
529 }
530
531 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
532 require.NotNil(t, toolMsgs, "Expected to find a tool message")
533
534 toolCalls := assistantMsg.ToolCalls()
535 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
536
537 foundGlob := false
538 foundLS := false
539 var globTCID, lsTCID string
540
541 for _, tc := range toolCalls {
542 if tc.Name == tools.GlobToolName {
543 foundGlob = true
544 globTCID = tc.ID
545 }
546 if tc.Name == tools.LSToolName {
547 foundLS = true
548 lsTCID = tc.ID
549 }
550 }
551
552 require.True(t, foundGlob, "Expected to find a glob tool call")
553 require.True(t, foundLS, "Expected to find an ls tool call")
554
555 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
556
557 foundGlobResult := false
558 foundLSResult := false
559
560 for _, msg := range toolMsgs {
561 for _, tr := range msg.ToolResults() {
562 if tr.ToolCallID == globTCID {
563 foundGlobResult = true
564 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
565 require.False(t, tr.IsError, "Expected glob result to not be an error")
566 }
567 if tr.ToolCallID == lsTCID {
568 foundLSResult = true
569 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
570 require.False(t, tr.IsError, "Expected ls result to not be an error")
571 }
572 }
573 }
574
575 require.True(t, foundGlobResult, "Expected to find glob tool result")
576 require.True(t, foundLSResult, "Expected to find ls tool result")
577 })
578 })
579 }
580}