agent_test.go

  1package agent
  2
  3import (
  4	"fmt"
  5	"os"
  6	"path/filepath"
  7	"runtime"
  8	"strings"
  9	"testing"
 10
 11	"charm.land/fantasy"
 12	"charm.land/x/vcr"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/message"
 15	"github.com/charmbracelet/crush/internal/session"
 16	"github.com/stretchr/testify/assert"
 17	"github.com/stretchr/testify/require"
 18
 19	_ "github.com/joho/godotenv/autoload"
 20)
 21
 22var modelPairs = []modelPair{
 23	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 24	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 25	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 26	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 27}
 28
 29func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 30	large, err := pair.largeModel(t, r)
 31	require.NoError(t, err)
 32	small, err := pair.smallModel(t, r)
 33	require.NoError(t, err)
 34	return large, small
 35}
 36
 37func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
 38	r := vcr.NewRecorder(t)
 39	large, small := getModels(t, r, pair)
 40	env := testEnv(t)
 41
 42	createSimpleGoProject(t, env.workingDir)
 43	agent, err := coderAgent(r, env, large, small)
 44	require.NoError(t, err)
 45	return agent, env
 46}
 47
 48func TestCoderAgent(t *testing.T) {
 49	if runtime.GOOS == "windows" {
 50		t.Skip("skipping on windows for now")
 51	}
 52
 53	for _, pair := range modelPairs {
 54		t.Run(pair.name, func(t *testing.T) {
 55			t.Run("simple test", func(t *testing.T) {
 56				agent, env := setupAgent(t, pair)
 57
 58				session, err := env.sessions.Create(t.Context(), "New Session")
 59				require.NoError(t, err)
 60
 61				res, err := agent.Run(t.Context(), SessionAgentCall{
 62					Prompt:          "Hello",
 63					SessionID:       session.ID,
 64					MaxOutputTokens: 10000,
 65				})
 66				require.NoError(t, err)
 67				assert.NotNil(t, res)
 68
 69				msgs, err := env.messages.List(t.Context(), session.ID)
 70				require.NoError(t, err)
 71				// Should have the agent and user message
 72				assert.Equal(t, len(msgs), 2)
 73			})
 74			t.Run("read a file", func(t *testing.T) {
 75				agent, env := setupAgent(t, pair)
 76
 77				session, err := env.sessions.Create(t.Context(), "New Session")
 78				require.NoError(t, err)
 79				res, err := agent.Run(t.Context(), SessionAgentCall{
 80					Prompt:          "Read the go mod",
 81					SessionID:       session.ID,
 82					MaxOutputTokens: 10000,
 83				})
 84
 85				require.NoError(t, err)
 86				assert.NotNil(t, res)
 87
 88				msgs, err := env.messages.List(t.Context(), session.ID)
 89				require.NoError(t, err)
 90				foundFile := false
 91				var tcID string
 92			out:
 93				for _, msg := range msgs {
 94					if msg.Role == message.Assistant {
 95						for _, tc := range msg.ToolCalls() {
 96							if tc.Name == tools.ViewToolName {
 97								tcID = tc.ID
 98							}
 99						}
100					}
101					if msg.Role == message.Tool {
102						for _, tr := range msg.ToolResults() {
103							if tr.ToolCallID == tcID {
104								if strings.Contains(tr.Content, "module example.com/testproject") {
105									foundFile = true
106									break out
107								}
108							}
109						}
110					}
111				}
112				require.True(t, foundFile)
113			})
114			t.Run("update a file", func(t *testing.T) {
115				agent, env := setupAgent(t, pair)
116
117				session, err := env.sessions.Create(t.Context(), "New Session")
118				require.NoError(t, err)
119
120				res, err := agent.Run(t.Context(), SessionAgentCall{
121					Prompt:          "update the main.go file by changing the print to say hello from crush",
122					SessionID:       session.ID,
123					MaxOutputTokens: 10000,
124				})
125				require.NoError(t, err)
126				assert.NotNil(t, res)
127
128				msgs, err := env.messages.List(t.Context(), session.ID)
129				require.NoError(t, err)
130
131				foundRead := false
132				foundWrite := false
133				var readTCID, writeTCID string
134
135				for _, msg := range msgs {
136					if msg.Role == message.Assistant {
137						for _, tc := range msg.ToolCalls() {
138							if tc.Name == tools.ViewToolName {
139								readTCID = tc.ID
140							}
141							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142								writeTCID = tc.ID
143							}
144						}
145					}
146					if msg.Role == message.Tool {
147						for _, tr := range msg.ToolResults() {
148							if tr.ToolCallID == readTCID {
149								foundRead = true
150							}
151							if tr.ToolCallID == writeTCID {
152								foundWrite = true
153							}
154						}
155					}
156				}
157
158				require.True(t, foundRead, "Expected to find a read operation")
159				require.True(t, foundWrite, "Expected to find a write operation")
160
161				mainGoPath := filepath.Join(env.workingDir, "main.go")
162				content, err := os.ReadFile(mainGoPath)
163				require.NoError(t, err)
164				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165			})
166			t.Run("bash tool", func(t *testing.T) {
167				agent, env := setupAgent(t, pair)
168
169				session, err := env.sessions.Create(t.Context(), "New Session")
170				require.NoError(t, err)
171
172				res, err := agent.Run(t.Context(), SessionAgentCall{
173					Prompt:          "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174					SessionID:       session.ID,
175					MaxOutputTokens: 10000,
176				})
177				require.NoError(t, err)
178				assert.NotNil(t, res)
179
180				msgs, err := env.messages.List(t.Context(), session.ID)
181				require.NoError(t, err)
182
183				foundBash := false
184				var bashTCID string
185
186				for _, msg := range msgs {
187					if msg.Role == message.Assistant {
188						for _, tc := range msg.ToolCalls() {
189							if tc.Name == tools.BashToolName {
190								bashTCID = tc.ID
191							}
192						}
193					}
194					if msg.Role == message.Tool {
195						for _, tr := range msg.ToolResults() {
196							if tr.ToolCallID == bashTCID {
197								foundBash = true
198							}
199						}
200					}
201				}
202
203				require.True(t, foundBash, "Expected to find a bash operation")
204
205				testFilePath := filepath.Join(env.workingDir, "test.txt")
206				content, err := os.ReadFile(testFilePath)
207				require.NoError(t, err)
208				require.Contains(t, string(content), "hello bash")
209			})
210			t.Run("download tool", func(t *testing.T) {
211				agent, env := setupAgent(t, pair)
212
213				session, err := env.sessions.Create(t.Context(), "New Session")
214				require.NoError(t, err)
215
216				res, err := agent.Run(t.Context(), SessionAgentCall{
217					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218					SessionID:       session.ID,
219					MaxOutputTokens: 10000,
220				})
221				require.NoError(t, err)
222				assert.NotNil(t, res)
223
224				msgs, err := env.messages.List(t.Context(), session.ID)
225				require.NoError(t, err)
226
227				foundDownload := false
228				var downloadTCID string
229
230				for _, msg := range msgs {
231					if msg.Role == message.Assistant {
232						for _, tc := range msg.ToolCalls() {
233							if tc.Name == tools.DownloadToolName {
234								downloadTCID = tc.ID
235							}
236						}
237					}
238					if msg.Role == message.Tool {
239						for _, tr := range msg.ToolResults() {
240							if tr.ToolCallID == downloadTCID {
241								foundDownload = true
242							}
243						}
244					}
245				}
246
247				require.True(t, foundDownload, "Expected to find a download operation")
248
249				examplePath := filepath.Join(env.workingDir, "example.txt")
250				_, err = os.Stat(examplePath)
251				require.NoError(t, err, "Expected example.txt file to exist")
252			})
253			t.Run("fetch tool", func(t *testing.T) {
254				agent, env := setupAgent(t, pair)
255
256				session, err := env.sessions.Create(t.Context(), "New Session")
257				require.NoError(t, err)
258
259				res, err := agent.Run(t.Context(), SessionAgentCall{
260					Prompt:          "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
261					SessionID:       session.ID,
262					MaxOutputTokens: 10000,
263				})
264				require.NoError(t, err)
265				assert.NotNil(t, res)
266
267				msgs, err := env.messages.List(t.Context(), session.ID)
268				require.NoError(t, err)
269
270				foundFetch := false
271				var fetchTCID string
272
273				for _, msg := range msgs {
274					if msg.Role == message.Assistant {
275						for _, tc := range msg.ToolCalls() {
276							if tc.Name == tools.FetchToolName {
277								fetchTCID = tc.ID
278							}
279						}
280					}
281					if msg.Role == message.Tool {
282						for _, tr := range msg.ToolResults() {
283							if tr.ToolCallID == fetchTCID {
284								foundFetch = true
285							}
286						}
287					}
288				}
289
290				require.True(t, foundFetch, "Expected to find a fetch operation")
291			})
292			t.Run("glob tool", func(t *testing.T) {
293				agent, env := setupAgent(t, pair)
294
295				session, err := env.sessions.Create(t.Context(), "New Session")
296				require.NoError(t, err)
297
298				res, err := agent.Run(t.Context(), SessionAgentCall{
299					Prompt:          "use glob to find all .go files in the current directory",
300					SessionID:       session.ID,
301					MaxOutputTokens: 10000,
302				})
303				require.NoError(t, err)
304				assert.NotNil(t, res)
305
306				msgs, err := env.messages.List(t.Context(), session.ID)
307				require.NoError(t, err)
308
309				foundGlob := false
310				var globTCID string
311
312				for _, msg := range msgs {
313					if msg.Role == message.Assistant {
314						for _, tc := range msg.ToolCalls() {
315							if tc.Name == tools.GlobToolName {
316								globTCID = tc.ID
317							}
318						}
319					}
320					if msg.Role == message.Tool {
321						for _, tr := range msg.ToolResults() {
322							if tr.ToolCallID == globTCID {
323								foundGlob = true
324								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
325							}
326						}
327					}
328				}
329
330				require.True(t, foundGlob, "Expected to find a glob operation")
331			})
332			t.Run("grep tool", func(t *testing.T) {
333				agent, env := setupAgent(t, pair)
334
335				session, err := env.sessions.Create(t.Context(), "New Session")
336				require.NoError(t, err)
337
338				res, err := agent.Run(t.Context(), SessionAgentCall{
339					Prompt:          "use grep to search for the word 'package' in go files",
340					SessionID:       session.ID,
341					MaxOutputTokens: 10000,
342				})
343				require.NoError(t, err)
344				assert.NotNil(t, res)
345
346				msgs, err := env.messages.List(t.Context(), session.ID)
347				require.NoError(t, err)
348
349				foundGrep := false
350				var grepTCID string
351
352				for _, msg := range msgs {
353					if msg.Role == message.Assistant {
354						for _, tc := range msg.ToolCalls() {
355							if tc.Name == tools.GrepToolName {
356								grepTCID = tc.ID
357							}
358						}
359					}
360					if msg.Role == message.Tool {
361						for _, tr := range msg.ToolResults() {
362							if tr.ToolCallID == grepTCID {
363								foundGrep = true
364								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
365							}
366						}
367					}
368				}
369
370				require.True(t, foundGrep, "Expected to find a grep operation")
371			})
372			t.Run("ls tool", func(t *testing.T) {
373				agent, env := setupAgent(t, pair)
374
375				session, err := env.sessions.Create(t.Context(), "New Session")
376				require.NoError(t, err)
377
378				res, err := agent.Run(t.Context(), SessionAgentCall{
379					Prompt:          "use ls to list the files in the current directory",
380					SessionID:       session.ID,
381					MaxOutputTokens: 10000,
382				})
383				require.NoError(t, err)
384				assert.NotNil(t, res)
385
386				msgs, err := env.messages.List(t.Context(), session.ID)
387				require.NoError(t, err)
388
389				foundLS := false
390				var lsTCID string
391
392				for _, msg := range msgs {
393					if msg.Role == message.Assistant {
394						for _, tc := range msg.ToolCalls() {
395							if tc.Name == tools.LSToolName {
396								lsTCID = tc.ID
397							}
398						}
399					}
400					if msg.Role == message.Tool {
401						for _, tr := range msg.ToolResults() {
402							if tr.ToolCallID == lsTCID {
403								foundLS = true
404								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
405								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
406							}
407						}
408					}
409				}
410
411				require.True(t, foundLS, "Expected to find an ls operation")
412			})
413			t.Run("multiedit tool", func(t *testing.T) {
414				agent, env := setupAgent(t, pair)
415
416				session, err := env.sessions.Create(t.Context(), "New Session")
417				require.NoError(t, err)
418
419				res, err := agent.Run(t.Context(), SessionAgentCall{
420					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
421					SessionID:       session.ID,
422					MaxOutputTokens: 10000,
423				})
424				require.NoError(t, err)
425				assert.NotNil(t, res)
426
427				msgs, err := env.messages.List(t.Context(), session.ID)
428				require.NoError(t, err)
429
430				foundMultiEdit := false
431				var multiEditTCID string
432
433				for _, msg := range msgs {
434					if msg.Role == message.Assistant {
435						for _, tc := range msg.ToolCalls() {
436							if tc.Name == tools.MultiEditToolName {
437								multiEditTCID = tc.ID
438							}
439						}
440					}
441					if msg.Role == message.Tool {
442						for _, tr := range msg.ToolResults() {
443							if tr.ToolCallID == multiEditTCID {
444								foundMultiEdit = true
445							}
446						}
447					}
448				}
449
450				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
451
452				mainGoPath := filepath.Join(env.workingDir, "main.go")
453				content, err := os.ReadFile(mainGoPath)
454				require.NoError(t, err)
455				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
456			})
457			t.Run("sourcegraph tool", func(t *testing.T) {
458				if runtime.GOOS == "darwin" {
459					t.Skip("skipping flacky test on macos for now")
460				}
461
462				agent, env := setupAgent(t, pair)
463
464				session, err := env.sessions.Create(t.Context(), "New Session")
465				require.NoError(t, err)
466
467				res, err := agent.Run(t.Context(), SessionAgentCall{
468					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
469					SessionID:       session.ID,
470					MaxOutputTokens: 10000,
471				})
472				require.NoError(t, err)
473				assert.NotNil(t, res)
474
475				msgs, err := env.messages.List(t.Context(), session.ID)
476				require.NoError(t, err)
477
478				foundSourcegraph := false
479				var sourcegraphTCID string
480
481				for _, msg := range msgs {
482					if msg.Role == message.Assistant {
483						for _, tc := range msg.ToolCalls() {
484							if tc.Name == tools.SourcegraphToolName {
485								sourcegraphTCID = tc.ID
486							}
487						}
488					}
489					if msg.Role == message.Tool {
490						for _, tr := range msg.ToolResults() {
491							if tr.ToolCallID == sourcegraphTCID {
492								foundSourcegraph = true
493							}
494						}
495					}
496				}
497
498				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
499			})
500			t.Run("write tool", func(t *testing.T) {
501				agent, env := setupAgent(t, pair)
502
503				session, err := env.sessions.Create(t.Context(), "New Session")
504				require.NoError(t, err)
505
506				res, err := agent.Run(t.Context(), SessionAgentCall{
507					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
508					SessionID:       session.ID,
509					MaxOutputTokens: 10000,
510				})
511				require.NoError(t, err)
512				assert.NotNil(t, res)
513
514				msgs, err := env.messages.List(t.Context(), session.ID)
515				require.NoError(t, err)
516
517				foundWrite := false
518				var writeTCID string
519
520				for _, msg := range msgs {
521					if msg.Role == message.Assistant {
522						for _, tc := range msg.ToolCalls() {
523							if tc.Name == tools.WriteToolName {
524								writeTCID = tc.ID
525							}
526						}
527					}
528					if msg.Role == message.Tool {
529						for _, tr := range msg.ToolResults() {
530							if tr.ToolCallID == writeTCID {
531								foundWrite = true
532							}
533						}
534					}
535				}
536
537				require.True(t, foundWrite, "Expected to find a write operation")
538
539				configPath := filepath.Join(env.workingDir, "config.json")
540				content, err := os.ReadFile(configPath)
541				require.NoError(t, err)
542				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
543				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
544			})
545			t.Run("parallel tool calls", func(t *testing.T) {
546				agent, env := setupAgent(t, pair)
547
548				session, err := env.sessions.Create(t.Context(), "New Session")
549				require.NoError(t, err)
550
551				res, err := agent.Run(t.Context(), SessionAgentCall{
552					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
553					SessionID:       session.ID,
554					MaxOutputTokens: 10000,
555				})
556				require.NoError(t, err)
557				assert.NotNil(t, res)
558
559				msgs, err := env.messages.List(t.Context(), session.ID)
560				require.NoError(t, err)
561
562				var assistantMsg *message.Message
563				var toolMsgs []message.Message
564
565				for _, msg := range msgs {
566					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
567						assistantMsg = &msg
568					}
569					if msg.Role == message.Tool {
570						toolMsgs = append(toolMsgs, msg)
571					}
572				}
573
574				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
575				require.NotNil(t, toolMsgs, "Expected to find a tool message")
576
577				toolCalls := assistantMsg.ToolCalls()
578				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
579
580				foundGlob := false
581				foundLS := false
582				var globTCID, lsTCID string
583
584				for _, tc := range toolCalls {
585					if tc.Name == tools.GlobToolName {
586						foundGlob = true
587						globTCID = tc.ID
588					}
589					if tc.Name == tools.LSToolName {
590						foundLS = true
591						lsTCID = tc.ID
592					}
593				}
594
595				require.True(t, foundGlob, "Expected to find a glob tool call")
596				require.True(t, foundLS, "Expected to find an ls tool call")
597
598				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
599
600				foundGlobResult := false
601				foundLSResult := false
602
603				for _, msg := range toolMsgs {
604					for _, tr := range msg.ToolResults() {
605						if tr.ToolCallID == globTCID {
606							foundGlobResult = true
607							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
608							require.False(t, tr.IsError, "Expected glob result to not be an error")
609						}
610						if tr.ToolCallID == lsTCID {
611							foundLSResult = true
612							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
613							require.False(t, tr.IsError, "Expected ls result to not be an error")
614						}
615					}
616				}
617
618				require.True(t, foundGlobResult, "Expected to find glob tool result")
619				require.True(t, foundLSResult, "Expected to find ls tool result")
620			})
621		})
622	}
623}
624
// makeTestTodos returns n todos, all with status session.TodoStatusPending.
// Each has a distinct, realistically sized description, which makes them
// suitable as benchmark input. For n == 0 it returns an empty, non-nil slice.
func makeTestTodos(n int) []session.Todo {
	out := make([]session.Todo, 0, n)
	for idx := 0; idx < n; idx++ {
		desc := fmt.Sprintf("Task %d: Implement feature with some description that makes it realistic", idx)
		out = append(out, session.Todo{
			Content: desc,
			Status:  session.TodoStatusPending,
		})
	}
	return out
}
635
// BenchmarkBuildSummaryPrompt measures buildSummaryPrompt over todo lists of
// increasing size and reports allocations per call. Each case's todo list is
// built once, outside the timed loop.
func BenchmarkBuildSummaryPrompt(b *testing.B) {
	cases := []struct {
		name     string
		numTodos int
	}{
		{"0todos", 0},
		{"5todos", 5},
		{"10todos", 10},
		{"50todos", 50},
	}

	for _, tc := range cases {
		todos := makeTestTodos(tc.numTodos)

		b.Run(tc.name, func(b *testing.B) {
			b.ReportAllocs()
			// b.Loop (Go 1.24+, already required by t.Context in this file)
			// keeps the call's arguments and result alive. The compiler
			// therefore cannot optimize the measured work away, which the old
			// `_ = f()` inside a b.N loop did not guarantee.
			for b.Loop() {
				buildSummaryPrompt(todos)
			}
		})
	}
}