1package agent
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "testing"
10
11 "charm.land/fantasy"
12 "charm.land/x/vcr"
13 "github.com/charmbracelet/crush/internal/agent/tools"
14 "github.com/charmbracelet/crush/internal/message"
15 "github.com/charmbracelet/crush/internal/session"
16 "github.com/stretchr/testify/assert"
17 "github.com/stretchr/testify/require"
18
19 _ "github.com/joho/godotenv/autoload"
20)
21
// modelPairs lists the provider/model combinations that TestCoderAgent runs
// every scenario against. Each entry is a positional modelPair literal: the
// first value is the name used as the subtest name, followed by two model
// builders.
// NOTE(review): the two builders are presumably the large model followed by
// the small model — confirm against the modelPair definition.
var modelPairs = []modelPair{
	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
}
28
29func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
30 large, err := pair.largeModel(t, r)
31 require.NoError(t, err)
32 small, err := pair.smallModel(t, r)
33 require.NoError(t, err)
34 return large, small
35}
36
37func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
38 r := vcr.NewRecorder(t)
39 large, small := getModels(t, r, pair)
40 env := testEnv(t)
41
42 createSimpleGoProject(t, env.workingDir)
43 agent, err := coderAgent(r, env, large, small)
44 require.NoError(t, err)
45 return agent, env
46}
47
48func TestCoderAgent(t *testing.T) {
49 if runtime.GOOS == "windows" {
50 t.Skip("skipping on windows for now")
51 }
52
53 for _, pair := range modelPairs {
54 t.Run(pair.name, func(t *testing.T) {
55 t.Run("simple test", func(t *testing.T) {
56 agent, env := setupAgent(t, pair)
57
58 session, err := env.sessions.Create(t.Context(), "New Session")
59 require.NoError(t, err)
60
61 res, err := agent.Run(t.Context(), SessionAgentCall{
62 Prompt: "Hello",
63 SessionID: session.ID,
64 MaxOutputTokens: 10000,
65 })
66 require.NoError(t, err)
67 assert.NotNil(t, res)
68
69 msgs, err := env.messages.List(t.Context(), session.ID)
70 require.NoError(t, err)
71 // Should have the agent and user message
72 assert.Equal(t, len(msgs), 2)
73 })
74 t.Run("read a file", func(t *testing.T) {
75 agent, env := setupAgent(t, pair)
76
77 session, err := env.sessions.Create(t.Context(), "New Session")
78 require.NoError(t, err)
79 res, err := agent.Run(t.Context(), SessionAgentCall{
80 Prompt: "Read the go mod",
81 SessionID: session.ID,
82 MaxOutputTokens: 10000,
83 })
84
85 require.NoError(t, err)
86 assert.NotNil(t, res)
87
88 msgs, err := env.messages.List(t.Context(), session.ID)
89 require.NoError(t, err)
90 foundFile := false
91 var tcID string
92 out:
93 for _, msg := range msgs {
94 if msg.Role == message.Assistant {
95 for _, tc := range msg.ToolCalls() {
96 if tc.Name == tools.ViewToolName {
97 tcID = tc.ID
98 }
99 }
100 }
101 if msg.Role == message.Tool {
102 for _, tr := range msg.ToolResults() {
103 if tr.ToolCallID == tcID {
104 if strings.Contains(tr.Content, "module example.com/testproject") {
105 foundFile = true
106 break out
107 }
108 }
109 }
110 }
111 }
112 require.True(t, foundFile)
113 })
114 t.Run("update a file", func(t *testing.T) {
115 agent, env := setupAgent(t, pair)
116
117 session, err := env.sessions.Create(t.Context(), "New Session")
118 require.NoError(t, err)
119
120 res, err := agent.Run(t.Context(), SessionAgentCall{
121 Prompt: "update the main.go file by changing the print to say hello from crush",
122 SessionID: session.ID,
123 MaxOutputTokens: 10000,
124 })
125 require.NoError(t, err)
126 assert.NotNil(t, res)
127
128 msgs, err := env.messages.List(t.Context(), session.ID)
129 require.NoError(t, err)
130
131 foundRead := false
132 foundWrite := false
133 var readTCID, writeTCID string
134
135 for _, msg := range msgs {
136 if msg.Role == message.Assistant {
137 for _, tc := range msg.ToolCalls() {
138 if tc.Name == tools.ViewToolName {
139 readTCID = tc.ID
140 }
141 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142 writeTCID = tc.ID
143 }
144 }
145 }
146 if msg.Role == message.Tool {
147 for _, tr := range msg.ToolResults() {
148 if tr.ToolCallID == readTCID {
149 foundRead = true
150 }
151 if tr.ToolCallID == writeTCID {
152 foundWrite = true
153 }
154 }
155 }
156 }
157
158 require.True(t, foundRead, "Expected to find a read operation")
159 require.True(t, foundWrite, "Expected to find a write operation")
160
161 mainGoPath := filepath.Join(env.workingDir, "main.go")
162 content, err := os.ReadFile(mainGoPath)
163 require.NoError(t, err)
164 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165 })
166 t.Run("bash tool", func(t *testing.T) {
167 agent, env := setupAgent(t, pair)
168
169 session, err := env.sessions.Create(t.Context(), "New Session")
170 require.NoError(t, err)
171
172 res, err := agent.Run(t.Context(), SessionAgentCall{
173 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174 SessionID: session.ID,
175 MaxOutputTokens: 10000,
176 })
177 require.NoError(t, err)
178 assert.NotNil(t, res)
179
180 msgs, err := env.messages.List(t.Context(), session.ID)
181 require.NoError(t, err)
182
183 foundBash := false
184 var bashTCID string
185
186 for _, msg := range msgs {
187 if msg.Role == message.Assistant {
188 for _, tc := range msg.ToolCalls() {
189 if tc.Name == tools.BashToolName {
190 bashTCID = tc.ID
191 }
192 }
193 }
194 if msg.Role == message.Tool {
195 for _, tr := range msg.ToolResults() {
196 if tr.ToolCallID == bashTCID {
197 foundBash = true
198 }
199 }
200 }
201 }
202
203 require.True(t, foundBash, "Expected to find a bash operation")
204
205 testFilePath := filepath.Join(env.workingDir, "test.txt")
206 content, err := os.ReadFile(testFilePath)
207 require.NoError(t, err)
208 require.Contains(t, string(content), "hello bash")
209 })
210 t.Run("download tool", func(t *testing.T) {
211 agent, env := setupAgent(t, pair)
212
213 session, err := env.sessions.Create(t.Context(), "New Session")
214 require.NoError(t, err)
215
216 res, err := agent.Run(t.Context(), SessionAgentCall{
217 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218 SessionID: session.ID,
219 MaxOutputTokens: 10000,
220 })
221 require.NoError(t, err)
222 assert.NotNil(t, res)
223
224 msgs, err := env.messages.List(t.Context(), session.ID)
225 require.NoError(t, err)
226
227 foundDownload := false
228 var downloadTCID string
229
230 for _, msg := range msgs {
231 if msg.Role == message.Assistant {
232 for _, tc := range msg.ToolCalls() {
233 if tc.Name == tools.DownloadToolName {
234 downloadTCID = tc.ID
235 }
236 }
237 }
238 if msg.Role == message.Tool {
239 for _, tr := range msg.ToolResults() {
240 if tr.ToolCallID == downloadTCID {
241 foundDownload = true
242 }
243 }
244 }
245 }
246
247 require.True(t, foundDownload, "Expected to find a download operation")
248
249 examplePath := filepath.Join(env.workingDir, "example.txt")
250 _, err = os.Stat(examplePath)
251 require.NoError(t, err, "Expected example.txt file to exist")
252 })
253 t.Run("fetch tool", func(t *testing.T) {
254 agent, env := setupAgent(t, pair)
255
256 session, err := env.sessions.Create(t.Context(), "New Session")
257 require.NoError(t, err)
258
259 res, err := agent.Run(t.Context(), SessionAgentCall{
260 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
261 SessionID: session.ID,
262 MaxOutputTokens: 10000,
263 })
264 require.NoError(t, err)
265 assert.NotNil(t, res)
266
267 msgs, err := env.messages.List(t.Context(), session.ID)
268 require.NoError(t, err)
269
270 foundFetch := false
271 var fetchTCID string
272
273 for _, msg := range msgs {
274 if msg.Role == message.Assistant {
275 for _, tc := range msg.ToolCalls() {
276 if tc.Name == tools.FetchToolName {
277 fetchTCID = tc.ID
278 }
279 }
280 }
281 if msg.Role == message.Tool {
282 for _, tr := range msg.ToolResults() {
283 if tr.ToolCallID == fetchTCID {
284 foundFetch = true
285 }
286 }
287 }
288 }
289
290 require.True(t, foundFetch, "Expected to find a fetch operation")
291 })
292 t.Run("glob tool", func(t *testing.T) {
293 agent, env := setupAgent(t, pair)
294
295 session, err := env.sessions.Create(t.Context(), "New Session")
296 require.NoError(t, err)
297
298 res, err := agent.Run(t.Context(), SessionAgentCall{
299 Prompt: "use glob to find all .go files in the current directory",
300 SessionID: session.ID,
301 MaxOutputTokens: 10000,
302 })
303 require.NoError(t, err)
304 assert.NotNil(t, res)
305
306 msgs, err := env.messages.List(t.Context(), session.ID)
307 require.NoError(t, err)
308
309 foundGlob := false
310 var globTCID string
311
312 for _, msg := range msgs {
313 if msg.Role == message.Assistant {
314 for _, tc := range msg.ToolCalls() {
315 if tc.Name == tools.GlobToolName {
316 globTCID = tc.ID
317 }
318 }
319 }
320 if msg.Role == message.Tool {
321 for _, tr := range msg.ToolResults() {
322 if tr.ToolCallID == globTCID {
323 foundGlob = true
324 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
325 }
326 }
327 }
328 }
329
330 require.True(t, foundGlob, "Expected to find a glob operation")
331 })
332 t.Run("grep tool", func(t *testing.T) {
333 agent, env := setupAgent(t, pair)
334
335 session, err := env.sessions.Create(t.Context(), "New Session")
336 require.NoError(t, err)
337
338 res, err := agent.Run(t.Context(), SessionAgentCall{
339 Prompt: "use grep to search for the word 'package' in go files",
340 SessionID: session.ID,
341 MaxOutputTokens: 10000,
342 })
343 require.NoError(t, err)
344 assert.NotNil(t, res)
345
346 msgs, err := env.messages.List(t.Context(), session.ID)
347 require.NoError(t, err)
348
349 foundGrep := false
350 var grepTCID string
351
352 for _, msg := range msgs {
353 if msg.Role == message.Assistant {
354 for _, tc := range msg.ToolCalls() {
355 if tc.Name == tools.GrepToolName {
356 grepTCID = tc.ID
357 }
358 }
359 }
360 if msg.Role == message.Tool {
361 for _, tr := range msg.ToolResults() {
362 if tr.ToolCallID == grepTCID {
363 foundGrep = true
364 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
365 }
366 }
367 }
368 }
369
370 require.True(t, foundGrep, "Expected to find a grep operation")
371 })
372 t.Run("ls tool", func(t *testing.T) {
373 agent, env := setupAgent(t, pair)
374
375 session, err := env.sessions.Create(t.Context(), "New Session")
376 require.NoError(t, err)
377
378 res, err := agent.Run(t.Context(), SessionAgentCall{
379 Prompt: "use ls to list the files in the current directory",
380 SessionID: session.ID,
381 MaxOutputTokens: 10000,
382 })
383 require.NoError(t, err)
384 assert.NotNil(t, res)
385
386 msgs, err := env.messages.List(t.Context(), session.ID)
387 require.NoError(t, err)
388
389 foundLS := false
390 var lsTCID string
391
392 for _, msg := range msgs {
393 if msg.Role == message.Assistant {
394 for _, tc := range msg.ToolCalls() {
395 if tc.Name == tools.LSToolName {
396 lsTCID = tc.ID
397 }
398 }
399 }
400 if msg.Role == message.Tool {
401 for _, tr := range msg.ToolResults() {
402 if tr.ToolCallID == lsTCID {
403 foundLS = true
404 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
405 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
406 }
407 }
408 }
409 }
410
411 require.True(t, foundLS, "Expected to find an ls operation")
412 })
413 t.Run("multiedit tool", func(t *testing.T) {
414 agent, env := setupAgent(t, pair)
415
416 session, err := env.sessions.Create(t.Context(), "New Session")
417 require.NoError(t, err)
418
419 res, err := agent.Run(t.Context(), SessionAgentCall{
420 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
421 SessionID: session.ID,
422 MaxOutputTokens: 10000,
423 })
424 require.NoError(t, err)
425 assert.NotNil(t, res)
426
427 msgs, err := env.messages.List(t.Context(), session.ID)
428 require.NoError(t, err)
429
430 foundMultiEdit := false
431 var multiEditTCID string
432
433 for _, msg := range msgs {
434 if msg.Role == message.Assistant {
435 for _, tc := range msg.ToolCalls() {
436 if tc.Name == tools.MultiEditToolName {
437 multiEditTCID = tc.ID
438 }
439 }
440 }
441 if msg.Role == message.Tool {
442 for _, tr := range msg.ToolResults() {
443 if tr.ToolCallID == multiEditTCID {
444 foundMultiEdit = true
445 }
446 }
447 }
448 }
449
450 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
451
452 mainGoPath := filepath.Join(env.workingDir, "main.go")
453 content, err := os.ReadFile(mainGoPath)
454 require.NoError(t, err)
455 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
456 })
457 t.Run("sourcegraph tool", func(t *testing.T) {
458 if runtime.GOOS == "darwin" {
459 t.Skip("skipping flacky test on macos for now")
460 }
461
462 agent, env := setupAgent(t, pair)
463
464 session, err := env.sessions.Create(t.Context(), "New Session")
465 require.NoError(t, err)
466
467 res, err := agent.Run(t.Context(), SessionAgentCall{
468 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
469 SessionID: session.ID,
470 MaxOutputTokens: 10000,
471 })
472 require.NoError(t, err)
473 assert.NotNil(t, res)
474
475 msgs, err := env.messages.List(t.Context(), session.ID)
476 require.NoError(t, err)
477
478 foundSourcegraph := false
479 var sourcegraphTCID string
480
481 for _, msg := range msgs {
482 if msg.Role == message.Assistant {
483 for _, tc := range msg.ToolCalls() {
484 if tc.Name == tools.SourcegraphToolName {
485 sourcegraphTCID = tc.ID
486 }
487 }
488 }
489 if msg.Role == message.Tool {
490 for _, tr := range msg.ToolResults() {
491 if tr.ToolCallID == sourcegraphTCID {
492 foundSourcegraph = true
493 }
494 }
495 }
496 }
497
498 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
499 })
500 t.Run("write tool", func(t *testing.T) {
501 agent, env := setupAgent(t, pair)
502
503 session, err := env.sessions.Create(t.Context(), "New Session")
504 require.NoError(t, err)
505
506 res, err := agent.Run(t.Context(), SessionAgentCall{
507 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
508 SessionID: session.ID,
509 MaxOutputTokens: 10000,
510 })
511 require.NoError(t, err)
512 assert.NotNil(t, res)
513
514 msgs, err := env.messages.List(t.Context(), session.ID)
515 require.NoError(t, err)
516
517 foundWrite := false
518 var writeTCID string
519
520 for _, msg := range msgs {
521 if msg.Role == message.Assistant {
522 for _, tc := range msg.ToolCalls() {
523 if tc.Name == tools.WriteToolName {
524 writeTCID = tc.ID
525 }
526 }
527 }
528 if msg.Role == message.Tool {
529 for _, tr := range msg.ToolResults() {
530 if tr.ToolCallID == writeTCID {
531 foundWrite = true
532 }
533 }
534 }
535 }
536
537 require.True(t, foundWrite, "Expected to find a write operation")
538
539 configPath := filepath.Join(env.workingDir, "config.json")
540 content, err := os.ReadFile(configPath)
541 require.NoError(t, err)
542 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
543 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
544 })
545 t.Run("parallel tool calls", func(t *testing.T) {
546 agent, env := setupAgent(t, pair)
547
548 session, err := env.sessions.Create(t.Context(), "New Session")
549 require.NoError(t, err)
550
551 res, err := agent.Run(t.Context(), SessionAgentCall{
552 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
553 SessionID: session.ID,
554 MaxOutputTokens: 10000,
555 })
556 require.NoError(t, err)
557 assert.NotNil(t, res)
558
559 msgs, err := env.messages.List(t.Context(), session.ID)
560 require.NoError(t, err)
561
562 var assistantMsg *message.Message
563 var toolMsgs []message.Message
564
565 for _, msg := range msgs {
566 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
567 assistantMsg = &msg
568 }
569 if msg.Role == message.Tool {
570 toolMsgs = append(toolMsgs, msg)
571 }
572 }
573
574 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
575 require.NotNil(t, toolMsgs, "Expected to find a tool message")
576
577 toolCalls := assistantMsg.ToolCalls()
578 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
579
580 foundGlob := false
581 foundLS := false
582 var globTCID, lsTCID string
583
584 for _, tc := range toolCalls {
585 if tc.Name == tools.GlobToolName {
586 foundGlob = true
587 globTCID = tc.ID
588 }
589 if tc.Name == tools.LSToolName {
590 foundLS = true
591 lsTCID = tc.ID
592 }
593 }
594
595 require.True(t, foundGlob, "Expected to find a glob tool call")
596 require.True(t, foundLS, "Expected to find an ls tool call")
597
598 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
599
600 foundGlobResult := false
601 foundLSResult := false
602
603 for _, msg := range toolMsgs {
604 for _, tr := range msg.ToolResults() {
605 if tr.ToolCallID == globTCID {
606 foundGlobResult = true
607 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
608 require.False(t, tr.IsError, "Expected glob result to not be an error")
609 }
610 if tr.ToolCallID == lsTCID {
611 foundLSResult = true
612 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
613 require.False(t, tr.IsError, "Expected ls result to not be an error")
614 }
615 }
616 }
617
618 require.True(t, foundGlobResult, "Expected to find glob tool result")
619 require.True(t, foundLSResult, "Expected to find ls tool result")
620 })
621 })
622 }
623}
624
625func makeTestTodos(n int) []session.Todo {
626 todos := make([]session.Todo, n)
627 for i := range n {
628 todos[i] = session.Todo{
629 Status: session.TodoStatusPending,
630 Content: fmt.Sprintf("Task %d: Implement feature with some description that makes it realistic", i),
631 }
632 }
633 return todos
634}
635
636func BenchmarkBuildSummaryPrompt(b *testing.B) {
637 cases := []struct {
638 name string
639 numTodos int
640 }{
641 {"0todos", 0},
642 {"5todos", 5},
643 {"10todos", 10},
644 {"50todos", 50},
645 }
646
647 for _, tc := range cases {
648 todos := makeTestTodos(tc.numTodos)
649
650 b.Run(tc.name, func(b *testing.B) {
651 b.ReportAllocs()
652 for range b.N {
653 _ = buildSummaryPrompt(todos)
654 }
655 })
656 }
657}