agent_test.go

  1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"runtime"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/message"
 13	"github.com/charmbracelet/crush/internal/shell"
 14	"github.com/stretchr/testify/assert"
 15	"github.com/stretchr/testify/require"
 16	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 17
 18	_ "github.com/joho/godotenv/autoload"
 19)
 20
 21var modelPairs = []modelPair{
 22	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 23	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 24	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 25	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 26}
 27
 28func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 29	large, err := pair.largeModel(t, r)
 30	require.NoError(t, err)
 31	small, err := pair.smallModel(t, r)
 32	require.NoError(t, err)
 33	return large, small
 34}
 35
 36func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
 37	r := newRecorder(t)
 38	large, small := getModels(t, r, pair)
 39	env := testEnv(t)
 40
 41	createSimpleGoProject(t, env.workingDir)
 42	agent, err := coderAgent(r, env, large, small)
 43	shell.Reset(env.workingDir)
 44	require.NoError(t, err)
 45	return agent, env
 46}
 47
 48func TestCoderAgent(t *testing.T) {
 49	if runtime.GOOS == "windows" {
 50		t.Skip("skipping on windows for now")
 51	}
 52
 53	for _, pair := range modelPairs {
 54		t.Run(pair.name, func(t *testing.T) {
 55			t.Run("simple test", func(t *testing.T) {
 56				agent, env := setupAgent(t, pair)
 57
 58				session, err := env.sessions.Create(t.Context(), "New Session")
 59				require.NoError(t, err)
 60
 61				res, err := agent.Run(t.Context(), SessionAgentCall{
 62					Prompt:          "Hello",
 63					SessionID:       session.ID,
 64					MaxOutputTokens: 10000,
 65				})
 66				require.NoError(t, err)
 67				assert.NotNil(t, res)
 68
 69				msgs, err := env.messages.List(t.Context(), session.ID)
 70				require.NoError(t, err)
 71				// Should have the agent and user message
 72				assert.Equal(t, len(msgs), 2)
 73			})
 74			t.Run("read a file", func(t *testing.T) {
 75				agent, env := setupAgent(t, pair)
 76
 77				session, err := env.sessions.Create(t.Context(), "New Session")
 78				require.NoError(t, err)
 79				res, err := agent.Run(t.Context(), SessionAgentCall{
 80					Prompt:          "Read the go mod",
 81					SessionID:       session.ID,
 82					MaxOutputTokens: 10000,
 83				})
 84
 85				require.NoError(t, err)
 86				assert.NotNil(t, res)
 87
 88				msgs, err := env.messages.List(t.Context(), session.ID)
 89				require.NoError(t, err)
 90				foundFile := false
 91				var tcID string
 92			out:
 93				for _, msg := range msgs {
 94					if msg.Role == message.Assistant {
 95						for _, tc := range msg.ToolCalls() {
 96							if tc.Name == tools.ViewToolName {
 97								tcID = tc.ID
 98							}
 99						}
100					}
101					if msg.Role == message.Tool {
102						for _, tr := range msg.ToolResults() {
103							if tr.ToolCallID == tcID {
104								if strings.Contains(tr.Content, "module example.com/testproject") {
105									foundFile = true
106									break out
107								}
108							}
109						}
110					}
111				}
112				require.True(t, foundFile)
113			})
114			t.Run("update a file", func(t *testing.T) {
115				agent, env := setupAgent(t, pair)
116
117				session, err := env.sessions.Create(t.Context(), "New Session")
118				require.NoError(t, err)
119
120				res, err := agent.Run(t.Context(), SessionAgentCall{
121					Prompt:          "update the main.go file by changing the print to say hello from crush",
122					SessionID:       session.ID,
123					MaxOutputTokens: 10000,
124				})
125				require.NoError(t, err)
126				assert.NotNil(t, res)
127
128				msgs, err := env.messages.List(t.Context(), session.ID)
129				require.NoError(t, err)
130
131				foundRead := false
132				foundWrite := false
133				var readTCID, writeTCID string
134
135				for _, msg := range msgs {
136					if msg.Role == message.Assistant {
137						for _, tc := range msg.ToolCalls() {
138							if tc.Name == tools.ViewToolName {
139								readTCID = tc.ID
140							}
141							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142								writeTCID = tc.ID
143							}
144						}
145					}
146					if msg.Role == message.Tool {
147						for _, tr := range msg.ToolResults() {
148							if tr.ToolCallID == readTCID {
149								foundRead = true
150							}
151							if tr.ToolCallID == writeTCID {
152								foundWrite = true
153							}
154						}
155					}
156				}
157
158				require.True(t, foundRead, "Expected to find a read operation")
159				require.True(t, foundWrite, "Expected to find a write operation")
160
161				mainGoPath := filepath.Join(env.workingDir, "main.go")
162				content, err := os.ReadFile(mainGoPath)
163				require.NoError(t, err)
164				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165			})
166			t.Run("bash tool", func(t *testing.T) {
167				agent, env := setupAgent(t, pair)
168
169				session, err := env.sessions.Create(t.Context(), "New Session")
170				require.NoError(t, err)
171
172				res, err := agent.Run(t.Context(), SessionAgentCall{
173					Prompt:          "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174					SessionID:       session.ID,
175					MaxOutputTokens: 10000,
176				})
177				require.NoError(t, err)
178				assert.NotNil(t, res)
179
180				msgs, err := env.messages.List(t.Context(), session.ID)
181				require.NoError(t, err)
182
183				foundBash := false
184				var bashTCID string
185
186				for _, msg := range msgs {
187					if msg.Role == message.Assistant {
188						for _, tc := range msg.ToolCalls() {
189							if tc.Name == tools.BashToolName {
190								bashTCID = tc.ID
191							}
192						}
193					}
194					if msg.Role == message.Tool {
195						for _, tr := range msg.ToolResults() {
196							if tr.ToolCallID == bashTCID {
197								foundBash = true
198							}
199						}
200					}
201				}
202
203				require.True(t, foundBash, "Expected to find a bash operation")
204
205				testFilePath := filepath.Join(env.workingDir, "test.txt")
206				content, err := os.ReadFile(testFilePath)
207				require.NoError(t, err)
208				require.Contains(t, string(content), "hello bash")
209			})
210			t.Run("download tool", func(t *testing.T) {
211				agent, env := setupAgent(t, pair)
212
213				session, err := env.sessions.Create(t.Context(), "New Session")
214				require.NoError(t, err)
215
216				res, err := agent.Run(t.Context(), SessionAgentCall{
217					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218					SessionID:       session.ID,
219					MaxOutputTokens: 10000,
220				})
221				require.NoError(t, err)
222				assert.NotNil(t, res)
223
224				msgs, err := env.messages.List(t.Context(), session.ID)
225				require.NoError(t, err)
226
227				foundDownload := false
228				var downloadTCID string
229
230				for _, msg := range msgs {
231					if msg.Role == message.Assistant {
232						for _, tc := range msg.ToolCalls() {
233							if tc.Name == tools.DownloadToolName {
234								downloadTCID = tc.ID
235							}
236						}
237					}
238					if msg.Role == message.Tool {
239						for _, tr := range msg.ToolResults() {
240							if tr.ToolCallID == downloadTCID {
241								foundDownload = true
242							}
243						}
244					}
245				}
246
247				require.True(t, foundDownload, "Expected to find a download operation")
248
249				examplePath := filepath.Join(env.workingDir, "example.txt")
250				_, err = os.Stat(examplePath)
251				require.NoError(t, err, "Expected example.txt file to exist")
252			})
253			t.Run("fetch tool", func(t *testing.T) {
254				agent, env := setupAgent(t, pair)
255
256				session, err := env.sessions.Create(t.Context(), "New Session")
257				require.NoError(t, err)
258
259				res, err := agent.Run(t.Context(), SessionAgentCall{
260					Prompt:          "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
261					SessionID:       session.ID,
262					MaxOutputTokens: 10000,
263				})
264				require.NoError(t, err)
265				assert.NotNil(t, res)
266
267				msgs, err := env.messages.List(t.Context(), session.ID)
268				require.NoError(t, err)
269
270				foundFetch := false
271				var fetchTCID string
272
273				for _, msg := range msgs {
274					if msg.Role == message.Assistant {
275						for _, tc := range msg.ToolCalls() {
276							if tc.Name == tools.FetchToolName {
277								fetchTCID = tc.ID
278							}
279						}
280					}
281					if msg.Role == message.Tool {
282						for _, tr := range msg.ToolResults() {
283							if tr.ToolCallID == fetchTCID {
284								foundFetch = true
285							}
286						}
287					}
288				}
289
290				require.True(t, foundFetch, "Expected to find a fetch operation")
291			})
292			t.Run("glob tool", func(t *testing.T) {
293				agent, env := setupAgent(t, pair)
294
295				session, err := env.sessions.Create(t.Context(), "New Session")
296				require.NoError(t, err)
297
298				res, err := agent.Run(t.Context(), SessionAgentCall{
299					Prompt:          "use glob to find all .go files in the current directory",
300					SessionID:       session.ID,
301					MaxOutputTokens: 10000,
302				})
303				require.NoError(t, err)
304				assert.NotNil(t, res)
305
306				msgs, err := env.messages.List(t.Context(), session.ID)
307				require.NoError(t, err)
308
309				foundGlob := false
310				var globTCID string
311
312				for _, msg := range msgs {
313					if msg.Role == message.Assistant {
314						for _, tc := range msg.ToolCalls() {
315							if tc.Name == tools.GlobToolName {
316								globTCID = tc.ID
317							}
318						}
319					}
320					if msg.Role == message.Tool {
321						for _, tr := range msg.ToolResults() {
322							if tr.ToolCallID == globTCID {
323								foundGlob = true
324								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
325							}
326						}
327					}
328				}
329
330				require.True(t, foundGlob, "Expected to find a glob operation")
331			})
332			t.Run("grep tool", func(t *testing.T) {
333				agent, env := setupAgent(t, pair)
334
335				session, err := env.sessions.Create(t.Context(), "New Session")
336				require.NoError(t, err)
337
338				res, err := agent.Run(t.Context(), SessionAgentCall{
339					Prompt:          "use grep to search for the word 'package' in go files",
340					SessionID:       session.ID,
341					MaxOutputTokens: 10000,
342				})
343				require.NoError(t, err)
344				assert.NotNil(t, res)
345
346				msgs, err := env.messages.List(t.Context(), session.ID)
347				require.NoError(t, err)
348
349				foundGrep := false
350				var grepTCID string
351
352				for _, msg := range msgs {
353					if msg.Role == message.Assistant {
354						for _, tc := range msg.ToolCalls() {
355							if tc.Name == tools.GrepToolName {
356								grepTCID = tc.ID
357							}
358						}
359					}
360					if msg.Role == message.Tool {
361						for _, tr := range msg.ToolResults() {
362							if tr.ToolCallID == grepTCID {
363								foundGrep = true
364								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
365							}
366						}
367					}
368				}
369
370				require.True(t, foundGrep, "Expected to find a grep operation")
371			})
372			t.Run("ls tool", func(t *testing.T) {
373				agent, env := setupAgent(t, pair)
374
375				session, err := env.sessions.Create(t.Context(), "New Session")
376				require.NoError(t, err)
377
378				res, err := agent.Run(t.Context(), SessionAgentCall{
379					Prompt:          "use ls to list the files in the current directory",
380					SessionID:       session.ID,
381					MaxOutputTokens: 10000,
382				})
383				require.NoError(t, err)
384				assert.NotNil(t, res)
385
386				msgs, err := env.messages.List(t.Context(), session.ID)
387				require.NoError(t, err)
388
389				foundLS := false
390				var lsTCID string
391
392				for _, msg := range msgs {
393					if msg.Role == message.Assistant {
394						for _, tc := range msg.ToolCalls() {
395							if tc.Name == tools.LSToolName {
396								lsTCID = tc.ID
397							}
398						}
399					}
400					if msg.Role == message.Tool {
401						for _, tr := range msg.ToolResults() {
402							if tr.ToolCallID == lsTCID {
403								foundLS = true
404								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
405								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
406							}
407						}
408					}
409				}
410
411				require.True(t, foundLS, "Expected to find an ls operation")
412			})
413			t.Run("multiedit tool", func(t *testing.T) {
414				agent, env := setupAgent(t, pair)
415
416				session, err := env.sessions.Create(t.Context(), "New Session")
417				require.NoError(t, err)
418
419				res, err := agent.Run(t.Context(), SessionAgentCall{
420					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
421					SessionID:       session.ID,
422					MaxOutputTokens: 10000,
423				})
424				require.NoError(t, err)
425				assert.NotNil(t, res)
426
427				msgs, err := env.messages.List(t.Context(), session.ID)
428				require.NoError(t, err)
429
430				foundMultiEdit := false
431				var multiEditTCID string
432
433				for _, msg := range msgs {
434					if msg.Role == message.Assistant {
435						for _, tc := range msg.ToolCalls() {
436							if tc.Name == tools.MultiEditToolName {
437								multiEditTCID = tc.ID
438							}
439						}
440					}
441					if msg.Role == message.Tool {
442						for _, tr := range msg.ToolResults() {
443							if tr.ToolCallID == multiEditTCID {
444								foundMultiEdit = true
445							}
446						}
447					}
448				}
449
450				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
451
452				mainGoPath := filepath.Join(env.workingDir, "main.go")
453				content, err := os.ReadFile(mainGoPath)
454				require.NoError(t, err)
455				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
456			})
457			t.Run("sourcegraph tool", func(t *testing.T) {
458				agent, env := setupAgent(t, pair)
459
460				session, err := env.sessions.Create(t.Context(), "New Session")
461				require.NoError(t, err)
462
463				res, err := agent.Run(t.Context(), SessionAgentCall{
464					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
465					SessionID:       session.ID,
466					MaxOutputTokens: 10000,
467				})
468				require.NoError(t, err)
469				assert.NotNil(t, res)
470
471				msgs, err := env.messages.List(t.Context(), session.ID)
472				require.NoError(t, err)
473
474				foundSourcegraph := false
475				var sourcegraphTCID string
476
477				for _, msg := range msgs {
478					if msg.Role == message.Assistant {
479						for _, tc := range msg.ToolCalls() {
480							if tc.Name == tools.SourcegraphToolName {
481								sourcegraphTCID = tc.ID
482							}
483						}
484					}
485					if msg.Role == message.Tool {
486						for _, tr := range msg.ToolResults() {
487							if tr.ToolCallID == sourcegraphTCID {
488								foundSourcegraph = true
489							}
490						}
491					}
492				}
493
494				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
495			})
496			t.Run("write tool", func(t *testing.T) {
497				agent, env := setupAgent(t, pair)
498
499				session, err := env.sessions.Create(t.Context(), "New Session")
500				require.NoError(t, err)
501
502				res, err := agent.Run(t.Context(), SessionAgentCall{
503					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
504					SessionID:       session.ID,
505					MaxOutputTokens: 10000,
506				})
507				require.NoError(t, err)
508				assert.NotNil(t, res)
509
510				msgs, err := env.messages.List(t.Context(), session.ID)
511				require.NoError(t, err)
512
513				foundWrite := false
514				var writeTCID string
515
516				for _, msg := range msgs {
517					if msg.Role == message.Assistant {
518						for _, tc := range msg.ToolCalls() {
519							if tc.Name == tools.WriteToolName {
520								writeTCID = tc.ID
521							}
522						}
523					}
524					if msg.Role == message.Tool {
525						for _, tr := range msg.ToolResults() {
526							if tr.ToolCallID == writeTCID {
527								foundWrite = true
528							}
529						}
530					}
531				}
532
533				require.True(t, foundWrite, "Expected to find a write operation")
534
535				configPath := filepath.Join(env.workingDir, "config.json")
536				content, err := os.ReadFile(configPath)
537				require.NoError(t, err)
538				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
539				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
540			})
541			t.Run("parallel tool calls", func(t *testing.T) {
542				agent, env := setupAgent(t, pair)
543
544				session, err := env.sessions.Create(t.Context(), "New Session")
545				require.NoError(t, err)
546
547				res, err := agent.Run(t.Context(), SessionAgentCall{
548					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
549					SessionID:       session.ID,
550					MaxOutputTokens: 10000,
551				})
552				require.NoError(t, err)
553				assert.NotNil(t, res)
554
555				msgs, err := env.messages.List(t.Context(), session.ID)
556				require.NoError(t, err)
557
558				var assistantMsg *message.Message
559				var toolMsgs []message.Message
560
561				for _, msg := range msgs {
562					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
563						assistantMsg = &msg
564					}
565					if msg.Role == message.Tool {
566						toolMsgs = append(toolMsgs, msg)
567					}
568				}
569
570				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
571				require.NotNil(t, toolMsgs, "Expected to find a tool message")
572
573				toolCalls := assistantMsg.ToolCalls()
574				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
575
576				foundGlob := false
577				foundLS := false
578				var globTCID, lsTCID string
579
580				for _, tc := range toolCalls {
581					if tc.Name == tools.GlobToolName {
582						foundGlob = true
583						globTCID = tc.ID
584					}
585					if tc.Name == tools.LSToolName {
586						foundLS = true
587						lsTCID = tc.ID
588					}
589				}
590
591				require.True(t, foundGlob, "Expected to find a glob tool call")
592				require.True(t, foundLS, "Expected to find an ls tool call")
593
594				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
595
596				foundGlobResult := false
597				foundLSResult := false
598
599				for _, msg := range toolMsgs {
600					for _, tr := range msg.ToolResults() {
601						if tr.ToolCallID == globTCID {
602							foundGlobResult = true
603							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
604							require.False(t, tr.IsError, "Expected glob result to not be an error")
605						}
606						if tr.ToolCallID == lsTCID {
607							foundLSResult = true
608							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
609							require.False(t, tr.IsError, "Expected ls result to not be an error")
610						}
611					}
612				}
613
614				require.True(t, foundGlobResult, "Expected to find glob tool result")
615				require.True(t, foundLSResult, "Expected to find ls tool result")
616			})
617		})
618	}
619}