agent_test.go

  1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"runtime"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/message"
 13	"github.com/charmbracelet/crush/internal/shell"
 14	"github.com/stretchr/testify/assert"
 15	"github.com/stretchr/testify/require"
 16	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 17
 18	_ "github.com/joho/godotenv/autoload"
 19)
 20
 21var modelPairs = []modelPair{
 22	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 23	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 24	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 25	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 26}
 27
 28func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 29	large, err := pair.largeModel(t, r)
 30	require.NoError(t, err)
 31	small, err := pair.smallModel(t, r)
 32	require.NoError(t, err)
 33	return large, small
 34}
 35
 36func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 37	r := newRecorder(t)
 38	large, small := getModels(t, r, pair)
 39	env := testEnv(t)
 40
 41	createSimpleGoProject(t, env.workingDir)
 42	agent, err := coderAgent(r, env, large, small)
 43	shell.Reset(env.workingDir)
 44	require.NoError(t, err)
 45	return agent, env
 46}
 47
 48func TestCoderAgent(t *testing.T) {
 49	if runtime.GOOS == "windows" {
 50		t.Skip("skipping on windows for now")
 51	}
 52
 53	for _, pair := range modelPairs {
 54		t.Run(pair.name, func(t *testing.T) {
 55			t.Run("simple test", func(t *testing.T) {
 56				agent, env := setupAgent(t, pair)
 57
 58				session, err := env.sessions.Create(t.Context(), "New Session")
 59				require.NoError(t, err)
 60
 61				res, err := agent.Run(t.Context(), SessionAgentCall{
 62					Prompt:          "Hello",
 63					SessionID:       session.ID,
 64					MaxOutputTokens: 10000,
 65				})
 66				require.NoError(t, err)
 67				assert.NotNil(t, res)
 68
 69				msgs, err := env.messages.List(t.Context(), session.ID)
 70				require.NoError(t, err)
 71				// Should have the agent and user message
 72				assert.Equal(t, len(msgs), 2)
 73			})
 74			t.Run("read a file", func(t *testing.T) {
 75				agent, env := setupAgent(t, pair)
 76
 77				session, err := env.sessions.Create(t.Context(), "New Session")
 78				require.NoError(t, err)
 79				res, err := agent.Run(t.Context(), SessionAgentCall{
 80					Prompt:          "Read the go mod",
 81					SessionID:       session.ID,
 82					MaxOutputTokens: 10000,
 83				})
 84
 85				require.NoError(t, err)
 86				assert.NotNil(t, res)
 87
 88				msgs, err := env.messages.List(t.Context(), session.ID)
 89				require.NoError(t, err)
 90				foundFile := false
 91				var tcID string
 92			out:
 93				for _, msg := range msgs {
 94					if msg.Role == message.Assistant {
 95						for _, tc := range msg.ToolCalls() {
 96							if tc.Name == tools.ViewToolName {
 97								tcID = tc.ID
 98							}
 99						}
100					}
101					if msg.Role == message.Tool {
102						for _, tr := range msg.ToolResults() {
103							if tr.ToolCallID == tcID {
104								if strings.Contains(tr.Content, "module example.com/testproject") {
105									foundFile = true
106									break out
107								}
108							}
109						}
110					}
111				}
112				require.True(t, foundFile)
113			})
114			t.Run("update a file", func(t *testing.T) {
115				agent, env := setupAgent(t, pair)
116
117				session, err := env.sessions.Create(t.Context(), "New Session")
118				require.NoError(t, err)
119
120				res, err := agent.Run(t.Context(), SessionAgentCall{
121					Prompt:          "update the main.go file by changing the print to say hello from crush",
122					SessionID:       session.ID,
123					MaxOutputTokens: 10000,
124				})
125				require.NoError(t, err)
126				assert.NotNil(t, res)
127
128				msgs, err := env.messages.List(t.Context(), session.ID)
129				require.NoError(t, err)
130
131				foundRead := false
132				foundWrite := false
133				var readTCID, writeTCID string
134
135				for _, msg := range msgs {
136					if msg.Role == message.Assistant {
137						for _, tc := range msg.ToolCalls() {
138							if tc.Name == tools.ViewToolName {
139								readTCID = tc.ID
140							}
141							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
142								writeTCID = tc.ID
143							}
144						}
145					}
146					if msg.Role == message.Tool {
147						for _, tr := range msg.ToolResults() {
148							if tr.ToolCallID == readTCID {
149								foundRead = true
150							}
151							if tr.ToolCallID == writeTCID {
152								foundWrite = true
153							}
154						}
155					}
156				}
157
158				require.True(t, foundRead, "Expected to find a read operation")
159				require.True(t, foundWrite, "Expected to find a write operation")
160
161				mainGoPath := filepath.Join(env.workingDir, "main.go")
162				content, err := os.ReadFile(mainGoPath)
163				require.NoError(t, err)
164				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
165			})
166			t.Run("bash tool", func(t *testing.T) {
167				agent, env := setupAgent(t, pair)
168
169				session, err := env.sessions.Create(t.Context(), "New Session")
170				require.NoError(t, err)
171
172				res, err := agent.Run(t.Context(), SessionAgentCall{
173					Prompt:          "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
174					SessionID:       session.ID,
175					MaxOutputTokens: 10000,
176				})
177				require.NoError(t, err)
178				assert.NotNil(t, res)
179
180				msgs, err := env.messages.List(t.Context(), session.ID)
181				require.NoError(t, err)
182
183				foundBash := false
184				var bashTCID string
185
186				for _, msg := range msgs {
187					if msg.Role == message.Assistant {
188						for _, tc := range msg.ToolCalls() {
189							if tc.Name == tools.BashToolName {
190								bashTCID = tc.ID
191							}
192						}
193					}
194					if msg.Role == message.Tool {
195						for _, tr := range msg.ToolResults() {
196							if tr.ToolCallID == bashTCID {
197								foundBash = true
198							}
199						}
200					}
201				}
202
203				require.True(t, foundBash, "Expected to find a bash operation")
204
205				testFilePath := filepath.Join(env.workingDir, "test.txt")
206				content, err := os.ReadFile(testFilePath)
207				require.NoError(t, err)
208				require.Contains(t, string(content), "hello bash")
209			})
210			t.Run("download tool", func(t *testing.T) {
211				agent, env := setupAgent(t, pair)
212
213				session, err := env.sessions.Create(t.Context(), "New Session")
214				require.NoError(t, err)
215
216				res, err := agent.Run(t.Context(), SessionAgentCall{
217					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
218					SessionID:       session.ID,
219					MaxOutputTokens: 10000,
220				})
221				require.NoError(t, err)
222				assert.NotNil(t, res)
223
224				msgs, err := env.messages.List(t.Context(), session.ID)
225				require.NoError(t, err)
226
227				foundDownload := false
228				var downloadTCID string
229
230				for _, msg := range msgs {
231					if msg.Role == message.Assistant {
232						for _, tc := range msg.ToolCalls() {
233							if tc.Name == tools.DownloadToolName {
234								downloadTCID = tc.ID
235							}
236						}
237					}
238					if msg.Role == message.Tool {
239						for _, tr := range msg.ToolResults() {
240							if tr.ToolCallID == downloadTCID {
241								foundDownload = true
242							}
243						}
244					}
245				}
246
247				require.True(t, foundDownload, "Expected to find a download operation")
248
249				examplePath := filepath.Join(env.workingDir, "example.txt")
250				_, err = os.Stat(examplePath)
251				require.NoError(t, err, "Expected example.txt file to exist")
252			})
253			t.Run("glob tool", func(t *testing.T) {
254				agent, env := setupAgent(t, pair)
255
256				session, err := env.sessions.Create(t.Context(), "New Session")
257				require.NoError(t, err)
258
259				res, err := agent.Run(t.Context(), SessionAgentCall{
260					Prompt:          "use glob to find all .go files in the current directory",
261					SessionID:       session.ID,
262					MaxOutputTokens: 10000,
263				})
264				require.NoError(t, err)
265				assert.NotNil(t, res)
266
267				msgs, err := env.messages.List(t.Context(), session.ID)
268				require.NoError(t, err)
269
270				foundGlob := false
271				var globTCID string
272
273				for _, msg := range msgs {
274					if msg.Role == message.Assistant {
275						for _, tc := range msg.ToolCalls() {
276							if tc.Name == tools.GlobToolName {
277								globTCID = tc.ID
278							}
279						}
280					}
281					if msg.Role == message.Tool {
282						for _, tr := range msg.ToolResults() {
283							if tr.ToolCallID == globTCID {
284								foundGlob = true
285								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
286							}
287						}
288					}
289				}
290
291				require.True(t, foundGlob, "Expected to find a glob operation")
292			})
293			t.Run("grep tool", func(t *testing.T) {
294				agent, env := setupAgent(t, pair)
295
296				session, err := env.sessions.Create(t.Context(), "New Session")
297				require.NoError(t, err)
298
299				res, err := agent.Run(t.Context(), SessionAgentCall{
300					Prompt:          "use grep to search for the word 'package' in go files",
301					SessionID:       session.ID,
302					MaxOutputTokens: 10000,
303				})
304				require.NoError(t, err)
305				assert.NotNil(t, res)
306
307				msgs, err := env.messages.List(t.Context(), session.ID)
308				require.NoError(t, err)
309
310				foundGrep := false
311				var grepTCID string
312
313				for _, msg := range msgs {
314					if msg.Role == message.Assistant {
315						for _, tc := range msg.ToolCalls() {
316							if tc.Name == tools.GrepToolName {
317								grepTCID = tc.ID
318							}
319						}
320					}
321					if msg.Role == message.Tool {
322						for _, tr := range msg.ToolResults() {
323							if tr.ToolCallID == grepTCID {
324								foundGrep = true
325								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
326							}
327						}
328					}
329				}
330
331				require.True(t, foundGrep, "Expected to find a grep operation")
332			})
333			t.Run("ls tool", func(t *testing.T) {
334				agent, env := setupAgent(t, pair)
335
336				session, err := env.sessions.Create(t.Context(), "New Session")
337				require.NoError(t, err)
338
339				res, err := agent.Run(t.Context(), SessionAgentCall{
340					Prompt:          "use ls to list the files in the current directory",
341					SessionID:       session.ID,
342					MaxOutputTokens: 10000,
343				})
344				require.NoError(t, err)
345				assert.NotNil(t, res)
346
347				msgs, err := env.messages.List(t.Context(), session.ID)
348				require.NoError(t, err)
349
350				foundLS := false
351				var lsTCID string
352
353				for _, msg := range msgs {
354					if msg.Role == message.Assistant {
355						for _, tc := range msg.ToolCalls() {
356							if tc.Name == tools.LSToolName {
357								lsTCID = tc.ID
358							}
359						}
360					}
361					if msg.Role == message.Tool {
362						for _, tr := range msg.ToolResults() {
363							if tr.ToolCallID == lsTCID {
364								foundLS = true
365								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
366								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
367							}
368						}
369					}
370				}
371
372				require.True(t, foundLS, "Expected to find an ls operation")
373			})
374			t.Run("multiedit tool", func(t *testing.T) {
375				agent, env := setupAgent(t, pair)
376
377				session, err := env.sessions.Create(t.Context(), "New Session")
378				require.NoError(t, err)
379
380				res, err := agent.Run(t.Context(), SessionAgentCall{
381					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
382					SessionID:       session.ID,
383					MaxOutputTokens: 10000,
384				})
385				require.NoError(t, err)
386				assert.NotNil(t, res)
387
388				msgs, err := env.messages.List(t.Context(), session.ID)
389				require.NoError(t, err)
390
391				foundMultiEdit := false
392				var multiEditTCID string
393
394				for _, msg := range msgs {
395					if msg.Role == message.Assistant {
396						for _, tc := range msg.ToolCalls() {
397							if tc.Name == tools.MultiEditToolName {
398								multiEditTCID = tc.ID
399							}
400						}
401					}
402					if msg.Role == message.Tool {
403						for _, tr := range msg.ToolResults() {
404							if tr.ToolCallID == multiEditTCID {
405								foundMultiEdit = true
406							}
407						}
408					}
409				}
410
411				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
412
413				mainGoPath := filepath.Join(env.workingDir, "main.go")
414				content, err := os.ReadFile(mainGoPath)
415				require.NoError(t, err)
416				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
417			})
418			t.Run("sourcegraph tool", func(t *testing.T) {
419				agent, env := setupAgent(t, pair)
420
421				session, err := env.sessions.Create(t.Context(), "New Session")
422				require.NoError(t, err)
423
424				res, err := agent.Run(t.Context(), SessionAgentCall{
425					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
426					SessionID:       session.ID,
427					MaxOutputTokens: 10000,
428				})
429				require.NoError(t, err)
430				assert.NotNil(t, res)
431
432				msgs, err := env.messages.List(t.Context(), session.ID)
433				require.NoError(t, err)
434
435				foundSourcegraph := false
436				var sourcegraphTCID string
437
438				for _, msg := range msgs {
439					if msg.Role == message.Assistant {
440						for _, tc := range msg.ToolCalls() {
441							if tc.Name == tools.SourcegraphToolName {
442								sourcegraphTCID = tc.ID
443							}
444						}
445					}
446					if msg.Role == message.Tool {
447						for _, tr := range msg.ToolResults() {
448							if tr.ToolCallID == sourcegraphTCID {
449								foundSourcegraph = true
450							}
451						}
452					}
453				}
454
455				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
456			})
457			t.Run("write tool", func(t *testing.T) {
458				agent, env := setupAgent(t, pair)
459
460				session, err := env.sessions.Create(t.Context(), "New Session")
461				require.NoError(t, err)
462
463				res, err := agent.Run(t.Context(), SessionAgentCall{
464					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
465					SessionID:       session.ID,
466					MaxOutputTokens: 10000,
467				})
468				require.NoError(t, err)
469				assert.NotNil(t, res)
470
471				msgs, err := env.messages.List(t.Context(), session.ID)
472				require.NoError(t, err)
473
474				foundWrite := false
475				var writeTCID string
476
477				for _, msg := range msgs {
478					if msg.Role == message.Assistant {
479						for _, tc := range msg.ToolCalls() {
480							if tc.Name == tools.WriteToolName {
481								writeTCID = tc.ID
482							}
483						}
484					}
485					if msg.Role == message.Tool {
486						for _, tr := range msg.ToolResults() {
487							if tr.ToolCallID == writeTCID {
488								foundWrite = true
489							}
490						}
491					}
492				}
493
494				require.True(t, foundWrite, "Expected to find a write operation")
495
496				configPath := filepath.Join(env.workingDir, "config.json")
497				content, err := os.ReadFile(configPath)
498				require.NoError(t, err)
499				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
500				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
501			})
502			t.Run("parallel tool calls", func(t *testing.T) {
503				agent, env := setupAgent(t, pair)
504
505				session, err := env.sessions.Create(t.Context(), "New Session")
506				require.NoError(t, err)
507
508				res, err := agent.Run(t.Context(), SessionAgentCall{
509					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
510					SessionID:       session.ID,
511					MaxOutputTokens: 10000,
512				})
513				require.NoError(t, err)
514				assert.NotNil(t, res)
515
516				msgs, err := env.messages.List(t.Context(), session.ID)
517				require.NoError(t, err)
518
519				var assistantMsg *message.Message
520				var toolMsgs []message.Message
521
522				for _, msg := range msgs {
523					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
524						assistantMsg = &msg
525					}
526					if msg.Role == message.Tool {
527						toolMsgs = append(toolMsgs, msg)
528					}
529				}
530
531				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
532				require.NotNil(t, toolMsgs, "Expected to find a tool message")
533
534				toolCalls := assistantMsg.ToolCalls()
535				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
536
537				foundGlob := false
538				foundLS := false
539				var globTCID, lsTCID string
540
541				for _, tc := range toolCalls {
542					if tc.Name == tools.GlobToolName {
543						foundGlob = true
544						globTCID = tc.ID
545					}
546					if tc.Name == tools.LSToolName {
547						foundLS = true
548						lsTCID = tc.ID
549					}
550				}
551
552				require.True(t, foundGlob, "Expected to find a glob tool call")
553				require.True(t, foundLS, "Expected to find an ls tool call")
554
555				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
556
557				foundGlobResult := false
558				foundLSResult := false
559
560				for _, msg := range toolMsgs {
561					for _, tr := range msg.ToolResults() {
562						if tr.ToolCallID == globTCID {
563							foundGlobResult = true
564							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
565							require.False(t, tr.IsError, "Expected glob result to not be an error")
566						}
567						if tr.ToolCallID == lsTCID {
568							foundLSResult = true
569							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
570							require.False(t, tr.IsError, "Expected ls result to not be an error")
571						}
572					}
573				}
574
575				require.True(t, foundGlobResult, "Expected to find glob tool result")
576				require.True(t, foundLSResult, "Expected to find ls tool result")
577			})
578		})
579	}
580}