agent_test.go

  1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/crush/internal/agent/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/charmbracelet/crush/internal/shell"
 12	"github.com/charmbracelet/fantasy/ai"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16
 17	_ "github.com/joho/godotenv/autoload"
 18)
 19
 20var modelPairs = []modelPair{
 21	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 22	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 23	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 24}
 25
 26func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (ai.LanguageModel, ai.LanguageModel) {
 27	large, err := pair.largeModel(t, r)
 28	require.NoError(t, err)
 29	small, err := pair.smallModel(t, r)
 30	require.NoError(t, err)
 31	return large, small
 32}
 33
 34func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 35	r := newRecorder(t)
 36	large, small := getModels(t, r, pair)
 37	env := testEnv(t)
 38
 39	createSimpleGoProject(t, env.workingDir)
 40	agent, err := coderAgent(r, env, large, small)
 41	shell.Reset(env.workingDir)
 42	require.NoError(t, err)
 43	return agent, env
 44}
 45
 46func TestCoderAgent(t *testing.T) {
 47	for _, pair := range modelPairs {
 48		t.Run(pair.name, func(t *testing.T) {
 49			t.Run("simple test", func(t *testing.T) {
 50				agent, env := setupAgent(t, pair)
 51
 52				session, err := env.sessions.Create(t.Context(), "New Session")
 53				require.NoError(t, err)
 54
 55				res, err := agent.Run(t.Context(), SessionAgentCall{
 56					Prompt:          "Hello",
 57					SessionID:       session.ID,
 58					MaxOutputTokens: 10000,
 59				})
 60				require.NoError(t, err)
 61				assert.NotNil(t, res)
 62
 63				msgs, err := env.messages.List(t.Context(), session.ID)
 64				require.NoError(t, err)
 65				// Should have the agent and user message
 66				assert.Equal(t, len(msgs), 2)
 67			})
 68			t.Run("read a file", func(t *testing.T) {
 69				agent, env := setupAgent(t, pair)
 70
 71				session, err := env.sessions.Create(t.Context(), "New Session")
 72				require.NoError(t, err)
 73				res, err := agent.Run(t.Context(), SessionAgentCall{
 74					Prompt:          "Read the go mod",
 75					SessionID:       session.ID,
 76					MaxOutputTokens: 10000,
 77				})
 78
 79				require.NoError(t, err)
 80				assert.NotNil(t, res)
 81
 82				msgs, err := env.messages.List(t.Context(), session.ID)
 83				require.NoError(t, err)
 84				foundFile := false
 85				var tcID string
 86			out:
 87				for _, msg := range msgs {
 88					if msg.Role == message.Assistant {
 89						for _, tc := range msg.ToolCalls() {
 90							if tc.Name == tools.ViewToolName {
 91								tcID = tc.ID
 92							}
 93						}
 94					}
 95					if msg.Role == message.Tool {
 96						for _, tr := range msg.ToolResults() {
 97							if tr.ToolCallID == tcID {
 98								if strings.Contains(tr.Content, "module example.com/testproject") {
 99									foundFile = true
100									break out
101								}
102							}
103						}
104					}
105				}
106				require.True(t, foundFile)
107			})
108			t.Run("update a file", func(t *testing.T) {
109				agent, env := setupAgent(t, pair)
110
111				session, err := env.sessions.Create(t.Context(), "New Session")
112				require.NoError(t, err)
113
114				res, err := agent.Run(t.Context(), SessionAgentCall{
115					Prompt:          "update the main.go file by changing the print to say hello from crush",
116					SessionID:       session.ID,
117					MaxOutputTokens: 10000,
118				})
119				require.NoError(t, err)
120				assert.NotNil(t, res)
121
122				msgs, err := env.messages.List(t.Context(), session.ID)
123				require.NoError(t, err)
124
125				foundRead := false
126				foundWrite := false
127				var readTCID, writeTCID string
128
129				for _, msg := range msgs {
130					if msg.Role == message.Assistant {
131						for _, tc := range msg.ToolCalls() {
132							if tc.Name == tools.ViewToolName {
133								readTCID = tc.ID
134							}
135							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
136								writeTCID = tc.ID
137							}
138						}
139					}
140					if msg.Role == message.Tool {
141						for _, tr := range msg.ToolResults() {
142							if tr.ToolCallID == readTCID {
143								foundRead = true
144							}
145							if tr.ToolCallID == writeTCID {
146								foundWrite = true
147							}
148						}
149					}
150				}
151
152				require.True(t, foundRead, "Expected to find a read operation")
153				require.True(t, foundWrite, "Expected to find a write operation")
154
155				mainGoPath := filepath.Join(env.workingDir, "main.go")
156				content, err := os.ReadFile(mainGoPath)
157				require.NoError(t, err)
158				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
159			})
160			t.Run("bash tool", func(t *testing.T) {
161				agent, env := setupAgent(t, pair)
162
163				session, err := env.sessions.Create(t.Context(), "New Session")
164				require.NoError(t, err)
165
166				res, err := agent.Run(t.Context(), SessionAgentCall{
167					Prompt:          "use bash to create a file named test.txt with content 'hello bash'",
168					SessionID:       session.ID,
169					MaxOutputTokens: 10000,
170				})
171				require.NoError(t, err)
172				assert.NotNil(t, res)
173
174				msgs, err := env.messages.List(t.Context(), session.ID)
175				require.NoError(t, err)
176
177				foundBash := false
178				var bashTCID string
179
180				for _, msg := range msgs {
181					if msg.Role == message.Assistant {
182						for _, tc := range msg.ToolCalls() {
183							if tc.Name == tools.BashToolName {
184								bashTCID = tc.ID
185							}
186						}
187					}
188					if msg.Role == message.Tool {
189						for _, tr := range msg.ToolResults() {
190							if tr.ToolCallID == bashTCID {
191								foundBash = true
192							}
193						}
194					}
195				}
196
197				require.True(t, foundBash, "Expected to find a bash operation")
198
199				testFilePath := filepath.Join(env.workingDir, "test.txt")
200				content, err := os.ReadFile(testFilePath)
201				require.NoError(t, err)
202				require.Contains(t, string(content), "hello bash")
203			})
204			t.Run("download tool", func(t *testing.T) {
205				agent, env := setupAgent(t, pair)
206
207				session, err := env.sessions.Create(t.Context(), "New Session")
208				require.NoError(t, err)
209
210				res, err := agent.Run(t.Context(), SessionAgentCall{
211					Prompt:          "download the file from https://httpbin.org/robots.txt and save it as robots.txt",
212					SessionID:       session.ID,
213					MaxOutputTokens: 10000,
214				})
215				require.NoError(t, err)
216				assert.NotNil(t, res)
217
218				msgs, err := env.messages.List(t.Context(), session.ID)
219				require.NoError(t, err)
220
221				foundDownload := false
222				var downloadTCID string
223
224				for _, msg := range msgs {
225					if msg.Role == message.Assistant {
226						for _, tc := range msg.ToolCalls() {
227							if tc.Name == tools.DownloadToolName {
228								downloadTCID = tc.ID
229							}
230						}
231					}
232					if msg.Role == message.Tool {
233						for _, tr := range msg.ToolResults() {
234							if tr.ToolCallID == downloadTCID {
235								foundDownload = true
236							}
237						}
238					}
239				}
240
241				require.True(t, foundDownload, "Expected to find a download operation")
242
243				robotsPath := filepath.Join(env.workingDir, "robots.txt")
244				_, err = os.Stat(robotsPath)
245				require.NoError(t, err, "Expected robots.txt file to exist")
246			})
247			t.Run("fetch tool", func(t *testing.T) {
248				agent, env := setupAgent(t, pair)
249
250				session, err := env.sessions.Create(t.Context(), "New Session")
251				require.NoError(t, err)
252
253				res, err := agent.Run(t.Context(), SessionAgentCall{
254					Prompt:          "fetch the content from https://httpbin.org/html and tell me if it contains the word 'Herman'",
255					SessionID:       session.ID,
256					MaxOutputTokens: 10000,
257				})
258				require.NoError(t, err)
259				assert.NotNil(t, res)
260
261				msgs, err := env.messages.List(t.Context(), session.ID)
262				require.NoError(t, err)
263
264				foundFetch := false
265				var fetchTCID string
266
267				for _, msg := range msgs {
268					if msg.Role == message.Assistant {
269						for _, tc := range msg.ToolCalls() {
270							if tc.Name == tools.FetchToolName {
271								fetchTCID = tc.ID
272							}
273						}
274					}
275					if msg.Role == message.Tool {
276						for _, tr := range msg.ToolResults() {
277							if tr.ToolCallID == fetchTCID {
278								foundFetch = true
279							}
280						}
281					}
282				}
283
284				require.True(t, foundFetch, "Expected to find a fetch operation")
285			})
286			t.Run("glob tool", func(t *testing.T) {
287				agent, env := setupAgent(t, pair)
288
289				session, err := env.sessions.Create(t.Context(), "New Session")
290				require.NoError(t, err)
291
292				res, err := agent.Run(t.Context(), SessionAgentCall{
293					Prompt:          "use glob to find all .go files in the current directory",
294					SessionID:       session.ID,
295					MaxOutputTokens: 10000,
296				})
297				require.NoError(t, err)
298				assert.NotNil(t, res)
299
300				msgs, err := env.messages.List(t.Context(), session.ID)
301				require.NoError(t, err)
302
303				foundGlob := false
304				var globTCID string
305
306				for _, msg := range msgs {
307					if msg.Role == message.Assistant {
308						for _, tc := range msg.ToolCalls() {
309							if tc.Name == tools.GlobToolName {
310								globTCID = tc.ID
311							}
312						}
313					}
314					if msg.Role == message.Tool {
315						for _, tr := range msg.ToolResults() {
316							if tr.ToolCallID == globTCID {
317								foundGlob = true
318								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
319							}
320						}
321					}
322				}
323
324				require.True(t, foundGlob, "Expected to find a glob operation")
325			})
326			t.Run("grep tool", func(t *testing.T) {
327				agent, env := setupAgent(t, pair)
328
329				session, err := env.sessions.Create(t.Context(), "New Session")
330				require.NoError(t, err)
331
332				res, err := agent.Run(t.Context(), SessionAgentCall{
333					Prompt:          "use grep to search for the word 'package' in go files",
334					SessionID:       session.ID,
335					MaxOutputTokens: 10000,
336				})
337				require.NoError(t, err)
338				assert.NotNil(t, res)
339
340				msgs, err := env.messages.List(t.Context(), session.ID)
341				require.NoError(t, err)
342
343				foundGrep := false
344				var grepTCID string
345
346				for _, msg := range msgs {
347					if msg.Role == message.Assistant {
348						for _, tc := range msg.ToolCalls() {
349							if tc.Name == tools.GrepToolName {
350								grepTCID = tc.ID
351							}
352						}
353					}
354					if msg.Role == message.Tool {
355						for _, tr := range msg.ToolResults() {
356							if tr.ToolCallID == grepTCID {
357								foundGrep = true
358								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
359							}
360						}
361					}
362				}
363
364				require.True(t, foundGrep, "Expected to find a grep operation")
365			})
366			t.Run("ls tool", func(t *testing.T) {
367				agent, env := setupAgent(t, pair)
368
369				session, err := env.sessions.Create(t.Context(), "New Session")
370				require.NoError(t, err)
371
372				res, err := agent.Run(t.Context(), SessionAgentCall{
373					Prompt:          "use ls to list the files in the current directory",
374					SessionID:       session.ID,
375					MaxOutputTokens: 10000,
376				})
377				require.NoError(t, err)
378				assert.NotNil(t, res)
379
380				msgs, err := env.messages.List(t.Context(), session.ID)
381				require.NoError(t, err)
382
383				foundLS := false
384				var lsTCID string
385
386				for _, msg := range msgs {
387					if msg.Role == message.Assistant {
388						for _, tc := range msg.ToolCalls() {
389							if tc.Name == tools.LSToolName {
390								lsTCID = tc.ID
391							}
392						}
393					}
394					if msg.Role == message.Tool {
395						for _, tr := range msg.ToolResults() {
396							if tr.ToolCallID == lsTCID {
397								foundLS = true
398								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
399								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
400							}
401						}
402					}
403				}
404
405				require.True(t, foundLS, "Expected to find an ls operation")
406			})
407			t.Run("multiedit tool", func(t *testing.T) {
408				agent, env := setupAgent(t, pair)
409
410				session, err := env.sessions.Create(t.Context(), "New Session")
411				require.NoError(t, err)
412
413				res, err := agent.Run(t.Context(), SessionAgentCall{
414					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
415					SessionID:       session.ID,
416					MaxOutputTokens: 10000,
417				})
418				require.NoError(t, err)
419				assert.NotNil(t, res)
420
421				msgs, err := env.messages.List(t.Context(), session.ID)
422				require.NoError(t, err)
423
424				foundMultiEdit := false
425				var multiEditTCID string
426
427				for _, msg := range msgs {
428					if msg.Role == message.Assistant {
429						for _, tc := range msg.ToolCalls() {
430							if tc.Name == tools.MultiEditToolName {
431								multiEditTCID = tc.ID
432							}
433						}
434					}
435					if msg.Role == message.Tool {
436						for _, tr := range msg.ToolResults() {
437							if tr.ToolCallID == multiEditTCID {
438								foundMultiEdit = true
439							}
440						}
441					}
442				}
443
444				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
445
446				mainGoPath := filepath.Join(env.workingDir, "main.go")
447				content, err := os.ReadFile(mainGoPath)
448				require.NoError(t, err)
449				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
450			})
451			t.Run("sourcegraph tool", func(t *testing.T) {
452				agent, env := setupAgent(t, pair)
453
454				session, err := env.sessions.Create(t.Context(), "New Session")
455				require.NoError(t, err)
456
457				res, err := agent.Run(t.Context(), SessionAgentCall{
458					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
459					SessionID:       session.ID,
460					MaxOutputTokens: 10000,
461				})
462				require.NoError(t, err)
463				assert.NotNil(t, res)
464
465				msgs, err := env.messages.List(t.Context(), session.ID)
466				require.NoError(t, err)
467
468				foundSourcegraph := false
469				var sourcegraphTCID string
470
471				for _, msg := range msgs {
472					if msg.Role == message.Assistant {
473						for _, tc := range msg.ToolCalls() {
474							if tc.Name == tools.SourcegraphToolName {
475								sourcegraphTCID = tc.ID
476							}
477						}
478					}
479					if msg.Role == message.Tool {
480						for _, tr := range msg.ToolResults() {
481							if tr.ToolCallID == sourcegraphTCID {
482								foundSourcegraph = true
483							}
484						}
485					}
486				}
487
488				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
489			})
490			t.Run("write tool", func(t *testing.T) {
491				agent, env := setupAgent(t, pair)
492
493				session, err := env.sessions.Create(t.Context(), "New Session")
494				require.NoError(t, err)
495
496				res, err := agent.Run(t.Context(), SessionAgentCall{
497					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
498					SessionID:       session.ID,
499					MaxOutputTokens: 10000,
500				})
501				require.NoError(t, err)
502				assert.NotNil(t, res)
503
504				msgs, err := env.messages.List(t.Context(), session.ID)
505				require.NoError(t, err)
506
507				foundWrite := false
508				var writeTCID string
509
510				for _, msg := range msgs {
511					if msg.Role == message.Assistant {
512						for _, tc := range msg.ToolCalls() {
513							if tc.Name == tools.WriteToolName {
514								writeTCID = tc.ID
515							}
516						}
517					}
518					if msg.Role == message.Tool {
519						for _, tr := range msg.ToolResults() {
520							if tr.ToolCallID == writeTCID {
521								foundWrite = true
522							}
523						}
524					}
525				}
526
527				require.True(t, foundWrite, "Expected to find a write operation")
528
529				configPath := filepath.Join(env.workingDir, "config.json")
530				content, err := os.ReadFile(configPath)
531				require.NoError(t, err)
532				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
533				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
534			})
535			t.Run("parallel tool calls", func(t *testing.T) {
536				agent, env := setupAgent(t, pair)
537
538				session, err := env.sessions.Create(t.Context(), "New Session")
539				require.NoError(t, err)
540
541				res, err := agent.Run(t.Context(), SessionAgentCall{
542					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
543					SessionID:       session.ID,
544					MaxOutputTokens: 10000,
545				})
546				require.NoError(t, err)
547				assert.NotNil(t, res)
548
549				msgs, err := env.messages.List(t.Context(), session.ID)
550				require.NoError(t, err)
551
552				var assistantMsg *message.Message
553				var toolMsgs []message.Message
554
555				for _, msg := range msgs {
556					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
557						assistantMsg = &msg
558					}
559					if msg.Role == message.Tool {
560						toolMsgs = append(toolMsgs, msg)
561					}
562				}
563
564				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
565				require.NotNil(t, toolMsgs, "Expected to find a tool message")
566
567				toolCalls := assistantMsg.ToolCalls()
568				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
569
570				foundGlob := false
571				foundLS := false
572				var globTCID, lsTCID string
573
574				for _, tc := range toolCalls {
575					if tc.Name == tools.GlobToolName {
576						foundGlob = true
577						globTCID = tc.ID
578					}
579					if tc.Name == tools.LSToolName {
580						foundLS = true
581						lsTCID = tc.ID
582					}
583				}
584
585				require.True(t, foundGlob, "Expected to find a glob tool call")
586				require.True(t, foundLS, "Expected to find an ls tool call")
587
588				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
589
590				foundGlobResult := false
591				foundLSResult := false
592
593				for _, msg := range toolMsgs {
594					for _, tr := range msg.ToolResults() {
595						if tr.ToolCallID == globTCID {
596							foundGlobResult = true
597							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
598							require.False(t, tr.IsError, "Expected glob result to not be an error")
599						}
600						if tr.ToolCallID == lsTCID {
601							foundLSResult = true
602							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
603							require.False(t, tr.IsError, "Expected ls result to not be an error")
604						}
605					}
606				}
607
608				require.True(t, foundGlobResult, "Expected to find glob tool result")
609				require.True(t, foundLSResult, "Expected to find ls tool result")
610			})
611		})
612	}
613}