1package agent
2
3import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "strings"
9 "testing"
10 "time"
11
12 "charm.land/fantasy"
13 "charm.land/x/vcr"
14 "github.com/charmbracelet/crush/internal/agent/tools"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/charmbracelet/crush/internal/session"
17 "github.com/stretchr/testify/assert"
18 "github.com/stretchr/testify/require"
19
20 _ "github.com/joho/godotenv/autoload"
21)
22
// modelPairs lists the provider/model combinations that TestCoderAgent runs
// against. Each entry is {subtest name, large-model builder, small-model
// builder}. getModels invokes the builders with the test's vcr.Recorder.
// NOTE(review): the large/small order is inferred from these positional
// literals and from getModels' use of pair.largeModel / pair.smallModel.
// Confirm it against the modelPair declaration.
var modelPairs = []modelPair{
	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
}
29
30func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
31 large, err := pair.largeModel(t, r)
32 require.NoError(t, err)
33 small, err := pair.smallModel(t, r)
34 require.NoError(t, err)
35 return large, small
36}
37
38func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
39 r := vcr.NewRecorder(t)
40 large, small := getModels(t, r, pair)
41 env := testEnv(t)
42
43 createSimpleGoProject(t, env.workingDir)
44 agent, err := coderAgent(r, env, large, small)
45 require.NoError(t, err)
46 return agent, env
47}
48
49func TestCoderAgent(t *testing.T) {
50 if runtime.GOOS == "windows" {
51 t.Skip("skipping on windows for now")
52 }
53
54 for _, pair := range modelPairs {
55 t.Run(pair.name, func(t *testing.T) {
56 t.Run("simple test", func(t *testing.T) {
57 agent, env := setupAgent(t, pair)
58
59 session, err := env.sessions.Create(t.Context(), "New Session")
60 require.NoError(t, err)
61
62 res, err := agent.Run(t.Context(), SessionAgentCall{
63 Prompt: "Hello",
64 SessionID: session.ID,
65 MaxOutputTokens: 10000,
66 })
67 require.NoError(t, err)
68 assert.NotNil(t, res)
69
70 msgs, err := env.messages.List(t.Context(), session.ID)
71 require.NoError(t, err)
72 // Should have the agent and user message
73 assert.Equal(t, len(msgs), 2)
74 })
75 t.Run("read a file", func(t *testing.T) {
76 agent, env := setupAgent(t, pair)
77
78 session, err := env.sessions.Create(t.Context(), "New Session")
79 require.NoError(t, err)
80 res, err := agent.Run(t.Context(), SessionAgentCall{
81 Prompt: "Read the go mod",
82 SessionID: session.ID,
83 MaxOutputTokens: 10000,
84 })
85
86 require.NoError(t, err)
87 assert.NotNil(t, res)
88
89 msgs, err := env.messages.List(t.Context(), session.ID)
90 require.NoError(t, err)
91 foundFile := false
92 var tcID string
93 out:
94 for _, msg := range msgs {
95 if msg.Role == message.Assistant {
96 for _, tc := range msg.ToolCalls() {
97 if tc.Name == tools.ViewToolName {
98 tcID = tc.ID
99 }
100 }
101 }
102 if msg.Role == message.Tool {
103 for _, tr := range msg.ToolResults() {
104 if tr.ToolCallID == tcID {
105 if strings.Contains(tr.Content, "module example.com/testproject") {
106 foundFile = true
107 break out
108 }
109 }
110 }
111 }
112 }
113 require.True(t, foundFile)
114 })
115 t.Run("update a file", func(t *testing.T) {
116 agent, env := setupAgent(t, pair)
117
118 session, err := env.sessions.Create(t.Context(), "New Session")
119 require.NoError(t, err)
120
121 res, err := agent.Run(t.Context(), SessionAgentCall{
122 Prompt: "update the main.go file by changing the print to say hello from crush",
123 SessionID: session.ID,
124 MaxOutputTokens: 10000,
125 })
126 require.NoError(t, err)
127 assert.NotNil(t, res)
128
129 msgs, err := env.messages.List(t.Context(), session.ID)
130 require.NoError(t, err)
131
132 foundRead := false
133 foundWrite := false
134 var readTCID, writeTCID string
135
136 for _, msg := range msgs {
137 if msg.Role == message.Assistant {
138 for _, tc := range msg.ToolCalls() {
139 if tc.Name == tools.ViewToolName {
140 readTCID = tc.ID
141 }
142 if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
143 writeTCID = tc.ID
144 }
145 }
146 }
147 if msg.Role == message.Tool {
148 for _, tr := range msg.ToolResults() {
149 if tr.ToolCallID == readTCID {
150 foundRead = true
151 }
152 if tr.ToolCallID == writeTCID {
153 foundWrite = true
154 }
155 }
156 }
157 }
158
159 require.True(t, foundRead, "Expected to find a read operation")
160 require.True(t, foundWrite, "Expected to find a write operation")
161
162 mainGoPath := filepath.Join(env.workingDir, "main.go")
163 content, err := os.ReadFile(mainGoPath)
164 require.NoError(t, err)
165 require.Contains(t, strings.ToLower(string(content)), "hello from crush")
166 })
167 t.Run("bash tool", func(t *testing.T) {
168 agent, env := setupAgent(t, pair)
169
170 session, err := env.sessions.Create(t.Context(), "New Session")
171 require.NoError(t, err)
172
173 res, err := agent.Run(t.Context(), SessionAgentCall{
174 Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
175 SessionID: session.ID,
176 MaxOutputTokens: 10000,
177 })
178 require.NoError(t, err)
179 assert.NotNil(t, res)
180
181 msgs, err := env.messages.List(t.Context(), session.ID)
182 require.NoError(t, err)
183
184 foundBash := false
185 var bashTCID string
186
187 for _, msg := range msgs {
188 if msg.Role == message.Assistant {
189 for _, tc := range msg.ToolCalls() {
190 if tc.Name == tools.BashToolName {
191 bashTCID = tc.ID
192 }
193 }
194 }
195 if msg.Role == message.Tool {
196 for _, tr := range msg.ToolResults() {
197 if tr.ToolCallID == bashTCID {
198 foundBash = true
199 }
200 }
201 }
202 }
203
204 require.True(t, foundBash, "Expected to find a bash operation")
205
206 testFilePath := filepath.Join(env.workingDir, "test.txt")
207 content, err := os.ReadFile(testFilePath)
208 require.NoError(t, err)
209 require.Contains(t, string(content), "hello bash")
210 })
211 t.Run("download tool", func(t *testing.T) {
212 agent, env := setupAgent(t, pair)
213
214 session, err := env.sessions.Create(t.Context(), "New Session")
215 require.NoError(t, err)
216
217 res, err := agent.Run(t.Context(), SessionAgentCall{
218 Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
219 SessionID: session.ID,
220 MaxOutputTokens: 10000,
221 })
222 require.NoError(t, err)
223 assert.NotNil(t, res)
224
225 msgs, err := env.messages.List(t.Context(), session.ID)
226 require.NoError(t, err)
227
228 foundDownload := false
229 var downloadTCID string
230
231 for _, msg := range msgs {
232 if msg.Role == message.Assistant {
233 for _, tc := range msg.ToolCalls() {
234 if tc.Name == tools.DownloadToolName {
235 downloadTCID = tc.ID
236 }
237 }
238 }
239 if msg.Role == message.Tool {
240 for _, tr := range msg.ToolResults() {
241 if tr.ToolCallID == downloadTCID {
242 foundDownload = true
243 }
244 }
245 }
246 }
247
248 require.True(t, foundDownload, "Expected to find a download operation")
249
250 examplePath := filepath.Join(env.workingDir, "example.txt")
251 _, err = os.Stat(examplePath)
252 require.NoError(t, err, "Expected example.txt file to exist")
253 })
254 t.Run("fetch tool", func(t *testing.T) {
255 agent, env := setupAgent(t, pair)
256
257 session, err := env.sessions.Create(t.Context(), "New Session")
258 require.NoError(t, err)
259
260 res, err := agent.Run(t.Context(), SessionAgentCall{
261 Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
262 SessionID: session.ID,
263 MaxOutputTokens: 10000,
264 })
265 require.NoError(t, err)
266 assert.NotNil(t, res)
267
268 msgs, err := env.messages.List(t.Context(), session.ID)
269 require.NoError(t, err)
270
271 foundFetch := false
272 var fetchTCID string
273
274 for _, msg := range msgs {
275 if msg.Role == message.Assistant {
276 for _, tc := range msg.ToolCalls() {
277 if tc.Name == tools.FetchToolName {
278 fetchTCID = tc.ID
279 }
280 }
281 }
282 if msg.Role == message.Tool {
283 for _, tr := range msg.ToolResults() {
284 if tr.ToolCallID == fetchTCID {
285 foundFetch = true
286 }
287 }
288 }
289 }
290
291 require.True(t, foundFetch, "Expected to find a fetch operation")
292 })
293 t.Run("glob tool", func(t *testing.T) {
294 agent, env := setupAgent(t, pair)
295
296 session, err := env.sessions.Create(t.Context(), "New Session")
297 require.NoError(t, err)
298
299 res, err := agent.Run(t.Context(), SessionAgentCall{
300 Prompt: "use glob to find all .go files in the current directory",
301 SessionID: session.ID,
302 MaxOutputTokens: 10000,
303 })
304 require.NoError(t, err)
305 assert.NotNil(t, res)
306
307 msgs, err := env.messages.List(t.Context(), session.ID)
308 require.NoError(t, err)
309
310 foundGlob := false
311 var globTCID string
312
313 for _, msg := range msgs {
314 if msg.Role == message.Assistant {
315 for _, tc := range msg.ToolCalls() {
316 if tc.Name == tools.GlobToolName {
317 globTCID = tc.ID
318 }
319 }
320 }
321 if msg.Role == message.Tool {
322 for _, tr := range msg.ToolResults() {
323 if tr.ToolCallID == globTCID {
324 foundGlob = true
325 require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
326 }
327 }
328 }
329 }
330
331 require.True(t, foundGlob, "Expected to find a glob operation")
332 })
333 t.Run("grep tool", func(t *testing.T) {
334 agent, env := setupAgent(t, pair)
335
336 session, err := env.sessions.Create(t.Context(), "New Session")
337 require.NoError(t, err)
338
339 res, err := agent.Run(t.Context(), SessionAgentCall{
340 Prompt: "use grep to search for the word 'package' in go files",
341 SessionID: session.ID,
342 MaxOutputTokens: 10000,
343 })
344 require.NoError(t, err)
345 assert.NotNil(t, res)
346
347 msgs, err := env.messages.List(t.Context(), session.ID)
348 require.NoError(t, err)
349
350 foundGrep := false
351 var grepTCID string
352
353 for _, msg := range msgs {
354 if msg.Role == message.Assistant {
355 for _, tc := range msg.ToolCalls() {
356 if tc.Name == tools.GrepToolName {
357 grepTCID = tc.ID
358 }
359 }
360 }
361 if msg.Role == message.Tool {
362 for _, tr := range msg.ToolResults() {
363 if tr.ToolCallID == grepTCID {
364 foundGrep = true
365 require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
366 }
367 }
368 }
369 }
370
371 require.True(t, foundGrep, "Expected to find a grep operation")
372 })
373 t.Run("ls tool", func(t *testing.T) {
374 agent, env := setupAgent(t, pair)
375
376 session, err := env.sessions.Create(t.Context(), "New Session")
377 require.NoError(t, err)
378
379 res, err := agent.Run(t.Context(), SessionAgentCall{
380 Prompt: "use ls to list the files in the current directory",
381 SessionID: session.ID,
382 MaxOutputTokens: 10000,
383 })
384 require.NoError(t, err)
385 assert.NotNil(t, res)
386
387 msgs, err := env.messages.List(t.Context(), session.ID)
388 require.NoError(t, err)
389
390 foundLS := false
391 var lsTCID string
392
393 for _, msg := range msgs {
394 if msg.Role == message.Assistant {
395 for _, tc := range msg.ToolCalls() {
396 if tc.Name == tools.LSToolName {
397 lsTCID = tc.ID
398 }
399 }
400 }
401 if msg.Role == message.Tool {
402 for _, tr := range msg.ToolResults() {
403 if tr.ToolCallID == lsTCID {
404 foundLS = true
405 require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
406 require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
407 }
408 }
409 }
410 }
411
412 require.True(t, foundLS, "Expected to find an ls operation")
413 })
414 t.Run("multiedit tool", func(t *testing.T) {
415 agent, env := setupAgent(t, pair)
416
417 session, err := env.sessions.Create(t.Context(), "New Session")
418 require.NoError(t, err)
419
420 res, err := agent.Run(t.Context(), SessionAgentCall{
421 Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
422 SessionID: session.ID,
423 MaxOutputTokens: 10000,
424 })
425 require.NoError(t, err)
426 assert.NotNil(t, res)
427
428 msgs, err := env.messages.List(t.Context(), session.ID)
429 require.NoError(t, err)
430
431 foundMultiEdit := false
432 var multiEditTCID string
433
434 for _, msg := range msgs {
435 if msg.Role == message.Assistant {
436 for _, tc := range msg.ToolCalls() {
437 if tc.Name == tools.MultiEditToolName {
438 multiEditTCID = tc.ID
439 }
440 }
441 }
442 if msg.Role == message.Tool {
443 for _, tr := range msg.ToolResults() {
444 if tr.ToolCallID == multiEditTCID {
445 foundMultiEdit = true
446 }
447 }
448 }
449 }
450
451 require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
452
453 mainGoPath := filepath.Join(env.workingDir, "main.go")
454 content, err := os.ReadFile(mainGoPath)
455 require.NoError(t, err)
456 require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
457 })
458 t.Run("sourcegraph tool", func(t *testing.T) {
459 agent, env := setupAgent(t, pair)
460
461 session, err := env.sessions.Create(t.Context(), "New Session")
462 require.NoError(t, err)
463
464 res, err := agent.Run(t.Context(), SessionAgentCall{
465 Prompt: "use sourcegraph to search for 'func main' in Go repositories",
466 SessionID: session.ID,
467 MaxOutputTokens: 10000,
468 })
469 require.NoError(t, err)
470 assert.NotNil(t, res)
471
472 msgs, err := env.messages.List(t.Context(), session.ID)
473 require.NoError(t, err)
474
475 foundSourcegraph := false
476 var sourcegraphTCID string
477
478 for _, msg := range msgs {
479 if msg.Role == message.Assistant {
480 for _, tc := range msg.ToolCalls() {
481 if tc.Name == tools.SourcegraphToolName {
482 sourcegraphTCID = tc.ID
483 }
484 }
485 }
486 if msg.Role == message.Tool {
487 for _, tr := range msg.ToolResults() {
488 if tr.ToolCallID == sourcegraphTCID {
489 foundSourcegraph = true
490 }
491 }
492 }
493 }
494
495 require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
496 })
497 t.Run("write tool", func(t *testing.T) {
498 agent, env := setupAgent(t, pair)
499
500 session, err := env.sessions.Create(t.Context(), "New Session")
501 require.NoError(t, err)
502
503 res, err := agent.Run(t.Context(), SessionAgentCall{
504 Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
505 SessionID: session.ID,
506 MaxOutputTokens: 10000,
507 })
508 require.NoError(t, err)
509 assert.NotNil(t, res)
510
511 msgs, err := env.messages.List(t.Context(), session.ID)
512 require.NoError(t, err)
513
514 foundWrite := false
515 var writeTCID string
516
517 for _, msg := range msgs {
518 if msg.Role == message.Assistant {
519 for _, tc := range msg.ToolCalls() {
520 if tc.Name == tools.WriteToolName {
521 writeTCID = tc.ID
522 }
523 }
524 }
525 if msg.Role == message.Tool {
526 for _, tr := range msg.ToolResults() {
527 if tr.ToolCallID == writeTCID {
528 foundWrite = true
529 }
530 }
531 }
532 }
533
534 require.True(t, foundWrite, "Expected to find a write operation")
535
536 configPath := filepath.Join(env.workingDir, "config.json")
537 content, err := os.ReadFile(configPath)
538 require.NoError(t, err)
539 require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
540 require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
541 })
542 t.Run("parallel tool calls", func(t *testing.T) {
543 agent, env := setupAgent(t, pair)
544
545 session, err := env.sessions.Create(t.Context(), "New Session")
546 require.NoError(t, err)
547
548 res, err := agent.Run(t.Context(), SessionAgentCall{
549 Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
550 SessionID: session.ID,
551 MaxOutputTokens: 10000,
552 })
553 require.NoError(t, err)
554 assert.NotNil(t, res)
555
556 msgs, err := env.messages.List(t.Context(), session.ID)
557 require.NoError(t, err)
558
559 var assistantMsg *message.Message
560 var toolMsgs []message.Message
561
562 for _, msg := range msgs {
563 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
564 assistantMsg = &msg
565 }
566 if msg.Role == message.Tool {
567 toolMsgs = append(toolMsgs, msg)
568 }
569 }
570
571 require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
572 require.NotNil(t, toolMsgs, "Expected to find a tool message")
573
574 toolCalls := assistantMsg.ToolCalls()
575 require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
576
577 foundGlob := false
578 foundLS := false
579 var globTCID, lsTCID string
580
581 for _, tc := range toolCalls {
582 if tc.Name == tools.GlobToolName {
583 foundGlob = true
584 globTCID = tc.ID
585 }
586 if tc.Name == tools.LSToolName {
587 foundLS = true
588 lsTCID = tc.ID
589 }
590 }
591
592 require.True(t, foundGlob, "Expected to find a glob tool call")
593 require.True(t, foundLS, "Expected to find an ls tool call")
594
595 require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
596
597 foundGlobResult := false
598 foundLSResult := false
599
600 for _, msg := range toolMsgs {
601 for _, tr := range msg.ToolResults() {
602 if tr.ToolCallID == globTCID {
603 foundGlobResult = true
604 require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
605 require.False(t, tr.IsError, "Expected glob result to not be an error")
606 }
607 if tr.ToolCallID == lsTCID {
608 foundLSResult = true
609 require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
610 require.False(t, tr.IsError, "Expected ls result to not be an error")
611 }
612 }
613 }
614
615 require.True(t, foundGlobResult, "Expected to find glob tool result")
616 require.True(t, foundLSResult, "Expected to find ls tool result")
617 })
618 })
619 }
620}
621
622func makeTestTodos(n int) []session.Todo {
623 todos := make([]session.Todo, n)
624 for i := range n {
625 todos[i] = session.Todo{
626 Status: session.TodoStatusPending,
627 Content: fmt.Sprintf("Task %d: Implement feature with some description that makes it realistic", i),
628 }
629 }
630 return todos
631}
632
633func BenchmarkBuildSummaryPrompt(b *testing.B) {
634 cases := []struct {
635 name string
636 numTodos int
637 }{
638 {"0todos", 0},
639 {"5todos", 5},
640 {"10todos", 10},
641 {"50todos", 50},
642 }
643
644 for _, tc := range cases {
645 todos := makeTestTodos(tc.numTodos)
646
647 b.Run(tc.name, func(b *testing.B) {
648 b.ReportAllocs()
649 for range b.N {
650 _ = buildSummaryPrompt("test-session-id", todos)
651 }
652 })
653 }
654}
655
656func TestSerializeTranscript(t *testing.T) {
657 now := time.Now().Unix()
658
659 msgs := []message.Message{
660 {
661 ID: "msg1",
662 Role: message.User,
663 SessionID: "sess1",
664 CreatedAt: now,
665 Parts: []message.ContentPart{
666 message.TextContent{Text: "Hello, can you help me?"},
667 },
668 },
669 {
670 ID: "msg2",
671 Role: message.Assistant,
672 SessionID: "sess1",
673 Model: "claude-sonnet-4-20250514",
674 Provider: "anthropic",
675 CreatedAt: now + 1,
676 Parts: []message.ContentPart{
677 message.TextContent{Text: "Of course! What do you need help with?"},
678 message.ToolCall{
679 ID: "tc1",
680 Name: "view",
681 Input: `{"file_path": "/test/file.go"}`,
682 },
683 },
684 },
685 {
686 ID: "msg3",
687 Role: message.Tool,
688 SessionID: "sess1",
689 CreatedAt: now + 2,
690 Parts: []message.ContentPart{
691 message.ToolResult{
692 ToolCallID: "tc1",
693 Name: "view",
694 Content: "package main\n\nfunc main() {}",
695 IsError: false,
696 },
697 },
698 },
699 }
700
701 transcript := serializeTranscript(msgs)
702
703 // Verify structure.
704 require.Contains(t, transcript, "# Session Transcript")
705 require.Contains(t, transcript, "## User")
706 require.Contains(t, transcript, "## Assistant")
707 require.Contains(t, transcript, "## Tool Results")
708
709 // Verify user message.
710 require.Contains(t, transcript, "Hello, can you help me?")
711
712 // Verify assistant message.
713 require.Contains(t, transcript, "claude-sonnet-4-20250514")
714 require.Contains(t, transcript, "Of course! What do you need help with?")
715 require.Contains(t, transcript, "**Tool:** `view`")
716 require.Contains(t, transcript, `"file_path": "/test/file.go"`)
717
718 // Verify tool result.
719 require.Contains(t, transcript, "**Status:** Success")
720 require.Contains(t, transcript, "package main")
721}
722
723func TestSerializeTranscript_TruncatesLongToolResults(t *testing.T) {
724 now := time.Now().Unix()
725
726 // Create a tool result with content larger than the truncation threshold.
727 longContent := strings.Repeat("x", 15000)
728
729 msgs := []message.Message{
730 {
731 ID: "msg1",
732 Role: message.Tool,
733 SessionID: "sess1",
734 CreatedAt: now,
735 Parts: []message.ContentPart{
736 message.ToolResult{
737 ToolCallID: "tc1",
738 Name: "bash",
739 Content: longContent,
740 IsError: false,
741 },
742 },
743 },
744 }
745
746 transcript := serializeTranscript(msgs)
747
748 // Verify truncation happened.
749 require.Contains(t, transcript, "... (truncated)")
750 require.Less(t, len(transcript), 15000, "Transcript should be smaller than original content")
751}