agent_test.go

  1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"runtime"
  7	"strings"
  8	"testing"
  9
 10	"charm.land/fantasy"
 11	"github.com/charmbracelet/crush/internal/agent/tools"
 12	"github.com/charmbracelet/crush/internal/message"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16
 17	_ "github.com/joho/godotenv/autoload"
 18)
 19
 20var modelPairs = []modelPair{
 21	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 22	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 23	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 24	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 25}
 26
 27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 28	large, err := pair.largeModel(t, r)
 29	require.NoError(t, err)
 30	small, err := pair.smallModel(t, r)
 31	require.NoError(t, err)
 32	return large, small
 33}
 34
 35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 36	r := newRecorder(t)
 37	large, small := getModels(t, r, pair)
 38	env := testEnv(t)
 39
 40	createSimpleGoProject(t, env.workingDir)
 41	agent, err := coderAgent(r, env, large, small)
 42	require.NoError(t, err)
 43	return agent, env
 44}
 45
 46func TestCoderAgent(t *testing.T) {
 47	if runtime.GOOS == "windows" {
 48		t.Skip("skipping on windows for now")
 49	}
 50
 51	for _, pair := range modelPairs {
 52		t.Run(pair.name, func(t *testing.T) {
 53			t.Run("simple test", func(t *testing.T) {
 54				agent, env := setupAgent(t, pair)
 55
 56				session, err := env.sessions.Create(t.Context(), "New Session")
 57				require.NoError(t, err)
 58
 59				res, err := agent.Run(t.Context(), SessionAgentCall{
 60					Prompt:          "Hello",
 61					SessionID:       session.ID,
 62					MaxOutputTokens: 10000,
 63				})
 64				require.NoError(t, err)
 65				assert.NotNil(t, res)
 66
 67				msgs, err := env.messages.List(t.Context(), session.ID)
 68				require.NoError(t, err)
 69				// Should have the agent and user message
 70				assert.Equal(t, len(msgs), 2)
 71			})
 72			t.Run("read a file", func(t *testing.T) {
 73				agent, env := setupAgent(t, pair)
 74
 75				session, err := env.sessions.Create(t.Context(), "New Session")
 76				require.NoError(t, err)
 77				res, err := agent.Run(t.Context(), SessionAgentCall{
 78					Prompt:          "Read the go mod",
 79					SessionID:       session.ID,
 80					MaxOutputTokens: 10000,
 81				})
 82
 83				require.NoError(t, err)
 84				assert.NotNil(t, res)
 85
 86				msgs, err := env.messages.List(t.Context(), session.ID)
 87				require.NoError(t, err)
 88				foundFile := false
 89				var tcID string
 90			out:
 91				for _, msg := range msgs {
 92					if msg.Role == message.Assistant {
 93						for _, tc := range msg.ToolCalls() {
 94							if tc.Name == tools.ViewToolName {
 95								tcID = tc.ID
 96							}
 97						}
 98					}
 99					if msg.Role == message.Tool {
100						for _, tr := range msg.ToolResults() {
101							if tr.ToolCallID == tcID {
102								if strings.Contains(tr.Content, "module example.com/testproject") {
103									foundFile = true
104									break out
105								}
106							}
107						}
108					}
109				}
110				require.True(t, foundFile)
111			})
112			t.Run("update a file", func(t *testing.T) {
113				agent, env := setupAgent(t, pair)
114
115				session, err := env.sessions.Create(t.Context(), "New Session")
116				require.NoError(t, err)
117
118				res, err := agent.Run(t.Context(), SessionAgentCall{
119					Prompt:          "update the main.go file by changing the print to say hello from crush",
120					SessionID:       session.ID,
121					MaxOutputTokens: 10000,
122				})
123				require.NoError(t, err)
124				assert.NotNil(t, res)
125
126				msgs, err := env.messages.List(t.Context(), session.ID)
127				require.NoError(t, err)
128
129				foundRead := false
130				foundWrite := false
131				var readTCID, writeTCID string
132
133				for _, msg := range msgs {
134					if msg.Role == message.Assistant {
135						for _, tc := range msg.ToolCalls() {
136							if tc.Name == tools.ViewToolName {
137								readTCID = tc.ID
138							}
139							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
140								writeTCID = tc.ID
141							}
142						}
143					}
144					if msg.Role == message.Tool {
145						for _, tr := range msg.ToolResults() {
146							if tr.ToolCallID == readTCID {
147								foundRead = true
148							}
149							if tr.ToolCallID == writeTCID {
150								foundWrite = true
151							}
152						}
153					}
154				}
155
156				require.True(t, foundRead, "Expected to find a read operation")
157				require.True(t, foundWrite, "Expected to find a write operation")
158
159				mainGoPath := filepath.Join(env.workingDir, "main.go")
160				content, err := os.ReadFile(mainGoPath)
161				require.NoError(t, err)
162				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
163			})
164			t.Run("bash tool", func(t *testing.T) {
165				agent, env := setupAgent(t, pair)
166
167				session, err := env.sessions.Create(t.Context(), "New Session")
168				require.NoError(t, err)
169
170				res, err := agent.Run(t.Context(), SessionAgentCall{
171					Prompt:          "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
172					SessionID:       session.ID,
173					MaxOutputTokens: 10000,
174				})
175				require.NoError(t, err)
176				assert.NotNil(t, res)
177
178				msgs, err := env.messages.List(t.Context(), session.ID)
179				require.NoError(t, err)
180
181				foundBash := false
182				var bashTCID string
183
184				for _, msg := range msgs {
185					if msg.Role == message.Assistant {
186						for _, tc := range msg.ToolCalls() {
187							if tc.Name == tools.BashToolName {
188								bashTCID = tc.ID
189							}
190						}
191					}
192					if msg.Role == message.Tool {
193						for _, tr := range msg.ToolResults() {
194							if tr.ToolCallID == bashTCID {
195								foundBash = true
196							}
197						}
198					}
199				}
200
201				require.True(t, foundBash, "Expected to find a bash operation")
202
203				testFilePath := filepath.Join(env.workingDir, "test.txt")
204				content, err := os.ReadFile(testFilePath)
205				require.NoError(t, err)
206				require.Contains(t, string(content), "hello bash")
207			})
208			t.Run("download tool", func(t *testing.T) {
209				agent, env := setupAgent(t, pair)
210
211				session, err := env.sessions.Create(t.Context(), "New Session")
212				require.NoError(t, err)
213
214				res, err := agent.Run(t.Context(), SessionAgentCall{
215					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
216					SessionID:       session.ID,
217					MaxOutputTokens: 10000,
218				})
219				require.NoError(t, err)
220				assert.NotNil(t, res)
221
222				msgs, err := env.messages.List(t.Context(), session.ID)
223				require.NoError(t, err)
224
225				foundDownload := false
226				var downloadTCID string
227
228				for _, msg := range msgs {
229					if msg.Role == message.Assistant {
230						for _, tc := range msg.ToolCalls() {
231							if tc.Name == tools.DownloadToolName {
232								downloadTCID = tc.ID
233							}
234						}
235					}
236					if msg.Role == message.Tool {
237						for _, tr := range msg.ToolResults() {
238							if tr.ToolCallID == downloadTCID {
239								foundDownload = true
240							}
241						}
242					}
243				}
244
245				require.True(t, foundDownload, "Expected to find a download operation")
246
247				examplePath := filepath.Join(env.workingDir, "example.txt")
248				_, err = os.Stat(examplePath)
249				require.NoError(t, err, "Expected example.txt file to exist")
250			})
251			t.Run("fetch tool", func(t *testing.T) {
252				agent, env := setupAgent(t, pair)
253
254				session, err := env.sessions.Create(t.Context(), "New Session")
255				require.NoError(t, err)
256
257				res, err := agent.Run(t.Context(), SessionAgentCall{
258					Prompt:          "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
259					SessionID:       session.ID,
260					MaxOutputTokens: 10000,
261				})
262				require.NoError(t, err)
263				assert.NotNil(t, res)
264
265				msgs, err := env.messages.List(t.Context(), session.ID)
266				require.NoError(t, err)
267
268				foundFetch := false
269				var fetchTCID string
270
271				for _, msg := range msgs {
272					if msg.Role == message.Assistant {
273						for _, tc := range msg.ToolCalls() {
274							if tc.Name == tools.FetchToolName {
275								fetchTCID = tc.ID
276							}
277						}
278					}
279					if msg.Role == message.Tool {
280						for _, tr := range msg.ToolResults() {
281							if tr.ToolCallID == fetchTCID {
282								foundFetch = true
283							}
284						}
285					}
286				}
287
288				require.True(t, foundFetch, "Expected to find a fetch operation")
289			})
290			t.Run("glob tool", func(t *testing.T) {
291				agent, env := setupAgent(t, pair)
292
293				session, err := env.sessions.Create(t.Context(), "New Session")
294				require.NoError(t, err)
295
296				res, err := agent.Run(t.Context(), SessionAgentCall{
297					Prompt:          "use glob to find all .go files in the current directory",
298					SessionID:       session.ID,
299					MaxOutputTokens: 10000,
300				})
301				require.NoError(t, err)
302				assert.NotNil(t, res)
303
304				msgs, err := env.messages.List(t.Context(), session.ID)
305				require.NoError(t, err)
306
307				foundGlob := false
308				var globTCID string
309
310				for _, msg := range msgs {
311					if msg.Role == message.Assistant {
312						for _, tc := range msg.ToolCalls() {
313							if tc.Name == tools.GlobToolName {
314								globTCID = tc.ID
315							}
316						}
317					}
318					if msg.Role == message.Tool {
319						for _, tr := range msg.ToolResults() {
320							if tr.ToolCallID == globTCID {
321								foundGlob = true
322								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
323							}
324						}
325					}
326				}
327
328				require.True(t, foundGlob, "Expected to find a glob operation")
329			})
330			t.Run("grep tool", func(t *testing.T) {
331				agent, env := setupAgent(t, pair)
332
333				session, err := env.sessions.Create(t.Context(), "New Session")
334				require.NoError(t, err)
335
336				res, err := agent.Run(t.Context(), SessionAgentCall{
337					Prompt:          "use grep to search for the word 'package' in go files",
338					SessionID:       session.ID,
339					MaxOutputTokens: 10000,
340				})
341				require.NoError(t, err)
342				assert.NotNil(t, res)
343
344				msgs, err := env.messages.List(t.Context(), session.ID)
345				require.NoError(t, err)
346
347				foundGrep := false
348				var grepTCID string
349
350				for _, msg := range msgs {
351					if msg.Role == message.Assistant {
352						for _, tc := range msg.ToolCalls() {
353							if tc.Name == tools.GrepToolName {
354								grepTCID = tc.ID
355							}
356						}
357					}
358					if msg.Role == message.Tool {
359						for _, tr := range msg.ToolResults() {
360							if tr.ToolCallID == grepTCID {
361								foundGrep = true
362								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
363							}
364						}
365					}
366				}
367
368				require.True(t, foundGrep, "Expected to find a grep operation")
369			})
370			t.Run("ls tool", func(t *testing.T) {
371				agent, env := setupAgent(t, pair)
372
373				session, err := env.sessions.Create(t.Context(), "New Session")
374				require.NoError(t, err)
375
376				res, err := agent.Run(t.Context(), SessionAgentCall{
377					Prompt:          "use ls to list the files in the current directory",
378					SessionID:       session.ID,
379					MaxOutputTokens: 10000,
380				})
381				require.NoError(t, err)
382				assert.NotNil(t, res)
383
384				msgs, err := env.messages.List(t.Context(), session.ID)
385				require.NoError(t, err)
386
387				foundLS := false
388				var lsTCID string
389
390				for _, msg := range msgs {
391					if msg.Role == message.Assistant {
392						for _, tc := range msg.ToolCalls() {
393							if tc.Name == tools.LSToolName {
394								lsTCID = tc.ID
395							}
396						}
397					}
398					if msg.Role == message.Tool {
399						for _, tr := range msg.ToolResults() {
400							if tr.ToolCallID == lsTCID {
401								foundLS = true
402								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
403								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
404							}
405						}
406					}
407				}
408
409				require.True(t, foundLS, "Expected to find an ls operation")
410			})
411			t.Run("multiedit tool", func(t *testing.T) {
412				agent, env := setupAgent(t, pair)
413
414				session, err := env.sessions.Create(t.Context(), "New Session")
415				require.NoError(t, err)
416
417				res, err := agent.Run(t.Context(), SessionAgentCall{
418					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
419					SessionID:       session.ID,
420					MaxOutputTokens: 10000,
421				})
422				require.NoError(t, err)
423				assert.NotNil(t, res)
424
425				msgs, err := env.messages.List(t.Context(), session.ID)
426				require.NoError(t, err)
427
428				foundMultiEdit := false
429				var multiEditTCID string
430
431				for _, msg := range msgs {
432					if msg.Role == message.Assistant {
433						for _, tc := range msg.ToolCalls() {
434							if tc.Name == tools.MultiEditToolName {
435								multiEditTCID = tc.ID
436							}
437						}
438					}
439					if msg.Role == message.Tool {
440						for _, tr := range msg.ToolResults() {
441							if tr.ToolCallID == multiEditTCID {
442								foundMultiEdit = true
443							}
444						}
445					}
446				}
447
448				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
449
450				mainGoPath := filepath.Join(env.workingDir, "main.go")
451				content, err := os.ReadFile(mainGoPath)
452				require.NoError(t, err)
453				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
454			})
455			t.Run("sourcegraph tool", func(t *testing.T) {
456				agent, env := setupAgent(t, pair)
457
458				session, err := env.sessions.Create(t.Context(), "New Session")
459				require.NoError(t, err)
460
461				res, err := agent.Run(t.Context(), SessionAgentCall{
462					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
463					SessionID:       session.ID,
464					MaxOutputTokens: 10000,
465				})
466				require.NoError(t, err)
467				assert.NotNil(t, res)
468
469				msgs, err := env.messages.List(t.Context(), session.ID)
470				require.NoError(t, err)
471
472				foundSourcegraph := false
473				var sourcegraphTCID string
474
475				for _, msg := range msgs {
476					if msg.Role == message.Assistant {
477						for _, tc := range msg.ToolCalls() {
478							if tc.Name == tools.SourcegraphToolName {
479								sourcegraphTCID = tc.ID
480							}
481						}
482					}
483					if msg.Role == message.Tool {
484						for _, tr := range msg.ToolResults() {
485							if tr.ToolCallID == sourcegraphTCID {
486								foundSourcegraph = true
487							}
488						}
489					}
490				}
491
492				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
493			})
494			t.Run("write tool", func(t *testing.T) {
495				agent, env := setupAgent(t, pair)
496
497				session, err := env.sessions.Create(t.Context(), "New Session")
498				require.NoError(t, err)
499
500				res, err := agent.Run(t.Context(), SessionAgentCall{
501					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
502					SessionID:       session.ID,
503					MaxOutputTokens: 10000,
504				})
505				require.NoError(t, err)
506				assert.NotNil(t, res)
507
508				msgs, err := env.messages.List(t.Context(), session.ID)
509				require.NoError(t, err)
510
511				foundWrite := false
512				var writeTCID string
513
514				for _, msg := range msgs {
515					if msg.Role == message.Assistant {
516						for _, tc := range msg.ToolCalls() {
517							if tc.Name == tools.WriteToolName {
518								writeTCID = tc.ID
519							}
520						}
521					}
522					if msg.Role == message.Tool {
523						for _, tr := range msg.ToolResults() {
524							if tr.ToolCallID == writeTCID {
525								foundWrite = true
526							}
527						}
528					}
529				}
530
531				require.True(t, foundWrite, "Expected to find a write operation")
532
533				configPath := filepath.Join(env.workingDir, "config.json")
534				content, err := os.ReadFile(configPath)
535				require.NoError(t, err)
536				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
537				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
538			})
539			t.Run("parallel tool calls", func(t *testing.T) {
540				agent, env := setupAgent(t, pair)
541
542				session, err := env.sessions.Create(t.Context(), "New Session")
543				require.NoError(t, err)
544
545				res, err := agent.Run(t.Context(), SessionAgentCall{
546					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
547					SessionID:       session.ID,
548					MaxOutputTokens: 10000,
549				})
550				require.NoError(t, err)
551				assert.NotNil(t, res)
552
553				msgs, err := env.messages.List(t.Context(), session.ID)
554				require.NoError(t, err)
555
556				var assistantMsg *message.Message
557				var toolMsgs []message.Message
558
559				for _, msg := range msgs {
560					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
561						assistantMsg = &msg
562					}
563					if msg.Role == message.Tool {
564						toolMsgs = append(toolMsgs, msg)
565					}
566				}
567
568				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
569				require.NotNil(t, toolMsgs, "Expected to find a tool message")
570
571				toolCalls := assistantMsg.ToolCalls()
572				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
573
574				foundGlob := false
575				foundLS := false
576				var globTCID, lsTCID string
577
578				for _, tc := range toolCalls {
579					if tc.Name == tools.GlobToolName {
580						foundGlob = true
581						globTCID = tc.ID
582					}
583					if tc.Name == tools.LSToolName {
584						foundLS = true
585						lsTCID = tc.ID
586					}
587				}
588
589				require.True(t, foundGlob, "Expected to find a glob tool call")
590				require.True(t, foundLS, "Expected to find an ls tool call")
591
592				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
593
594				foundGlobResult := false
595				foundLSResult := false
596
597				for _, msg := range toolMsgs {
598					for _, tr := range msg.ToolResults() {
599						if tr.ToolCallID == globTCID {
600							foundGlobResult = true
601							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
602							require.False(t, tr.IsError, "Expected glob result to not be an error")
603						}
604						if tr.ToolCallID == lsTCID {
605							foundLSResult = true
606							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
607							require.False(t, tr.IsError, "Expected ls result to not be an error")
608						}
609					}
610				}
611
612				require.True(t, foundGlobResult, "Expected to find glob tool result")
613				require.True(t, foundLSResult, "Expected to find ls tool result")
614			})
615		})
616	}
617}