agent_test.go

  1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"strings"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/agent/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12	"github.com/charmbracelet/crush/internal/shell"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16
 17	_ "github.com/joho/godotenv/autoload"
 18)
 19
 20var modelPairs = []modelPair{
 21	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 22	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 23	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 24	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 25}
 26
 27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 28	large, err := pair.largeModel(t, r)
 29	require.NoError(t, err)
 30	small, err := pair.smallModel(t, r)
 31	require.NoError(t, err)
 32	return large, small
 33}
 34
 35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 36	r := newRecorder(t)
 37	large, small := getModels(t, r, pair)
 38	env := testEnv(t)
 39
 40	createSimpleGoProject(t, env.workingDir)
 41	agent, err := coderAgent(r, env, large, small)
 42	shell.Reset(env.workingDir)
 43	require.NoError(t, err)
 44	return agent, env
 45}
 46
 47func TestCoderAgent(t *testing.T) {
 48	for _, pair := range modelPairs {
 49		t.Run(pair.name, func(t *testing.T) {
 50			t.Run("simple test", func(t *testing.T) {
 51				agent, env := setupAgent(t, pair)
 52
 53				session, err := env.sessions.Create(t.Context(), "New Session")
 54				require.NoError(t, err)
 55
 56				res, err := agent.Run(t.Context(), SessionAgentCall{
 57					Prompt:          "Hello",
 58					SessionID:       session.ID,
 59					MaxOutputTokens: 10000,
 60				})
 61				require.NoError(t, err)
 62				assert.NotNil(t, res)
 63
 64				msgs, err := env.messages.List(t.Context(), session.ID)
 65				require.NoError(t, err)
 66				// Should have the agent and user message
 67				assert.Equal(t, len(msgs), 2)
 68			})
 69			t.Run("read a file", func(t *testing.T) {
 70				agent, env := setupAgent(t, pair)
 71
 72				session, err := env.sessions.Create(t.Context(), "New Session")
 73				require.NoError(t, err)
 74				res, err := agent.Run(t.Context(), SessionAgentCall{
 75					Prompt:          "Read the go mod",
 76					SessionID:       session.ID,
 77					MaxOutputTokens: 10000,
 78				})
 79
 80				require.NoError(t, err)
 81				assert.NotNil(t, res)
 82
 83				msgs, err := env.messages.List(t.Context(), session.ID)
 84				require.NoError(t, err)
 85				foundFile := false
 86				var tcID string
 87			out:
 88				for _, msg := range msgs {
 89					if msg.Role == message.Assistant {
 90						for _, tc := range msg.ToolCalls() {
 91							if tc.Name == tools.ViewToolName {
 92								tcID = tc.ID
 93							}
 94						}
 95					}
 96					if msg.Role == message.Tool {
 97						for _, tr := range msg.ToolResults() {
 98							if tr.ToolCallID == tcID {
 99								if strings.Contains(tr.Content, "module example.com/testproject") {
100									foundFile = true
101									break out
102								}
103							}
104						}
105					}
106				}
107				require.True(t, foundFile)
108			})
109			t.Run("update a file", func(t *testing.T) {
110				agent, env := setupAgent(t, pair)
111
112				session, err := env.sessions.Create(t.Context(), "New Session")
113				require.NoError(t, err)
114
115				res, err := agent.Run(t.Context(), SessionAgentCall{
116					Prompt:          "update the main.go file by changing the print to say hello from crush",
117					SessionID:       session.ID,
118					MaxOutputTokens: 10000,
119				})
120				require.NoError(t, err)
121				assert.NotNil(t, res)
122
123				msgs, err := env.messages.List(t.Context(), session.ID)
124				require.NoError(t, err)
125
126				foundRead := false
127				foundWrite := false
128				var readTCID, writeTCID string
129
130				for _, msg := range msgs {
131					if msg.Role == message.Assistant {
132						for _, tc := range msg.ToolCalls() {
133							if tc.Name == tools.ViewToolName {
134								readTCID = tc.ID
135							}
136							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
137								writeTCID = tc.ID
138							}
139						}
140					}
141					if msg.Role == message.Tool {
142						for _, tr := range msg.ToolResults() {
143							if tr.ToolCallID == readTCID {
144								foundRead = true
145							}
146							if tr.ToolCallID == writeTCID {
147								foundWrite = true
148							}
149						}
150					}
151				}
152
153				require.True(t, foundRead, "Expected to find a read operation")
154				require.True(t, foundWrite, "Expected to find a write operation")
155
156				mainGoPath := filepath.Join(env.workingDir, "main.go")
157				content, err := os.ReadFile(mainGoPath)
158				require.NoError(t, err)
159				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
160			})
161			t.Run("bash tool", func(t *testing.T) {
162				agent, env := setupAgent(t, pair)
163
164				session, err := env.sessions.Create(t.Context(), "New Session")
165				require.NoError(t, err)
166
167				res, err := agent.Run(t.Context(), SessionAgentCall{
168					Prompt:          "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
169					SessionID:       session.ID,
170					MaxOutputTokens: 10000,
171				})
172				require.NoError(t, err)
173				assert.NotNil(t, res)
174
175				msgs, err := env.messages.List(t.Context(), session.ID)
176				require.NoError(t, err)
177
178				foundBash := false
179				var bashTCID string
180
181				for _, msg := range msgs {
182					if msg.Role == message.Assistant {
183						for _, tc := range msg.ToolCalls() {
184							if tc.Name == tools.BashToolName {
185								bashTCID = tc.ID
186							}
187						}
188					}
189					if msg.Role == message.Tool {
190						for _, tr := range msg.ToolResults() {
191							if tr.ToolCallID == bashTCID {
192								foundBash = true
193							}
194						}
195					}
196				}
197
198				require.True(t, foundBash, "Expected to find a bash operation")
199
200				testFilePath := filepath.Join(env.workingDir, "test.txt")
201				content, err := os.ReadFile(testFilePath)
202				require.NoError(t, err)
203				require.Contains(t, string(content), "hello bash")
204			})
205			t.Run("download tool", func(t *testing.T) {
206				agent, env := setupAgent(t, pair)
207
208				session, err := env.sessions.Create(t.Context(), "New Session")
209				require.NoError(t, err)
210
211				res, err := agent.Run(t.Context(), SessionAgentCall{
212					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
213					SessionID:       session.ID,
214					MaxOutputTokens: 10000,
215				})
216				require.NoError(t, err)
217				assert.NotNil(t, res)
218
219				msgs, err := env.messages.List(t.Context(), session.ID)
220				require.NoError(t, err)
221
222				foundDownload := false
223				var downloadTCID string
224
225				for _, msg := range msgs {
226					if msg.Role == message.Assistant {
227						for _, tc := range msg.ToolCalls() {
228							if tc.Name == tools.DownloadToolName {
229								downloadTCID = tc.ID
230							}
231						}
232					}
233					if msg.Role == message.Tool {
234						for _, tr := range msg.ToolResults() {
235							if tr.ToolCallID == downloadTCID {
236								foundDownload = true
237							}
238						}
239					}
240				}
241
242				require.True(t, foundDownload, "Expected to find a download operation")
243
244				examplePath := filepath.Join(env.workingDir, "example.txt")
245				_, err = os.Stat(examplePath)
246				require.NoError(t, err, "Expected example.txt file to exist")
247			})
248			t.Run("fetch tool", func(t *testing.T) {
249				agent, env := setupAgent(t, pair)
250
251				session, err := env.sessions.Create(t.Context(), "New Session")
252				require.NoError(t, err)
253
254				res, err := agent.Run(t.Context(), SessionAgentCall{
255					Prompt:          "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
256					SessionID:       session.ID,
257					MaxOutputTokens: 10000,
258				})
259				require.NoError(t, err)
260				assert.NotNil(t, res)
261
262				msgs, err := env.messages.List(t.Context(), session.ID)
263				require.NoError(t, err)
264
265				foundFetch := false
266				var fetchTCID string
267
268				for _, msg := range msgs {
269					if msg.Role == message.Assistant {
270						for _, tc := range msg.ToolCalls() {
271							if tc.Name == tools.FetchToolName {
272								fetchTCID = tc.ID
273							}
274						}
275					}
276					if msg.Role == message.Tool {
277						for _, tr := range msg.ToolResults() {
278							if tr.ToolCallID == fetchTCID {
279								foundFetch = true
280							}
281						}
282					}
283				}
284
285				require.True(t, foundFetch, "Expected to find a fetch operation")
286			})
287			t.Run("glob tool", func(t *testing.T) {
288				agent, env := setupAgent(t, pair)
289
290				session, err := env.sessions.Create(t.Context(), "New Session")
291				require.NoError(t, err)
292
293				res, err := agent.Run(t.Context(), SessionAgentCall{
294					Prompt:          "use glob to find all .go files in the current directory",
295					SessionID:       session.ID,
296					MaxOutputTokens: 10000,
297				})
298				require.NoError(t, err)
299				assert.NotNil(t, res)
300
301				msgs, err := env.messages.List(t.Context(), session.ID)
302				require.NoError(t, err)
303
304				foundGlob := false
305				var globTCID string
306
307				for _, msg := range msgs {
308					if msg.Role == message.Assistant {
309						for _, tc := range msg.ToolCalls() {
310							if tc.Name == tools.GlobToolName {
311								globTCID = tc.ID
312							}
313						}
314					}
315					if msg.Role == message.Tool {
316						for _, tr := range msg.ToolResults() {
317							if tr.ToolCallID == globTCID {
318								foundGlob = true
319								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
320							}
321						}
322					}
323				}
324
325				require.True(t, foundGlob, "Expected to find a glob operation")
326			})
327			t.Run("grep tool", func(t *testing.T) {
328				agent, env := setupAgent(t, pair)
329
330				session, err := env.sessions.Create(t.Context(), "New Session")
331				require.NoError(t, err)
332
333				res, err := agent.Run(t.Context(), SessionAgentCall{
334					Prompt:          "use grep to search for the word 'package' in go files",
335					SessionID:       session.ID,
336					MaxOutputTokens: 10000,
337				})
338				require.NoError(t, err)
339				assert.NotNil(t, res)
340
341				msgs, err := env.messages.List(t.Context(), session.ID)
342				require.NoError(t, err)
343
344				foundGrep := false
345				var grepTCID string
346
347				for _, msg := range msgs {
348					if msg.Role == message.Assistant {
349						for _, tc := range msg.ToolCalls() {
350							if tc.Name == tools.GrepToolName {
351								grepTCID = tc.ID
352							}
353						}
354					}
355					if msg.Role == message.Tool {
356						for _, tr := range msg.ToolResults() {
357							if tr.ToolCallID == grepTCID {
358								foundGrep = true
359								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
360							}
361						}
362					}
363				}
364
365				require.True(t, foundGrep, "Expected to find a grep operation")
366			})
367			t.Run("ls tool", func(t *testing.T) {
368				agent, env := setupAgent(t, pair)
369
370				session, err := env.sessions.Create(t.Context(), "New Session")
371				require.NoError(t, err)
372
373				res, err := agent.Run(t.Context(), SessionAgentCall{
374					Prompt:          "use ls to list the files in the current directory",
375					SessionID:       session.ID,
376					MaxOutputTokens: 10000,
377				})
378				require.NoError(t, err)
379				assert.NotNil(t, res)
380
381				msgs, err := env.messages.List(t.Context(), session.ID)
382				require.NoError(t, err)
383
384				foundLS := false
385				var lsTCID string
386
387				for _, msg := range msgs {
388					if msg.Role == message.Assistant {
389						for _, tc := range msg.ToolCalls() {
390							if tc.Name == tools.LSToolName {
391								lsTCID = tc.ID
392							}
393						}
394					}
395					if msg.Role == message.Tool {
396						for _, tr := range msg.ToolResults() {
397							if tr.ToolCallID == lsTCID {
398								foundLS = true
399								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
400								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
401							}
402						}
403					}
404				}
405
406				require.True(t, foundLS, "Expected to find an ls operation")
407			})
408			t.Run("multiedit tool", func(t *testing.T) {
409				agent, env := setupAgent(t, pair)
410
411				session, err := env.sessions.Create(t.Context(), "New Session")
412				require.NoError(t, err)
413
414				res, err := agent.Run(t.Context(), SessionAgentCall{
415					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
416					SessionID:       session.ID,
417					MaxOutputTokens: 10000,
418				})
419				require.NoError(t, err)
420				assert.NotNil(t, res)
421
422				msgs, err := env.messages.List(t.Context(), session.ID)
423				require.NoError(t, err)
424
425				foundMultiEdit := false
426				var multiEditTCID string
427
428				for _, msg := range msgs {
429					if msg.Role == message.Assistant {
430						for _, tc := range msg.ToolCalls() {
431							if tc.Name == tools.MultiEditToolName {
432								multiEditTCID = tc.ID
433							}
434						}
435					}
436					if msg.Role == message.Tool {
437						for _, tr := range msg.ToolResults() {
438							if tr.ToolCallID == multiEditTCID {
439								foundMultiEdit = true
440							}
441						}
442					}
443				}
444
445				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
446
447				mainGoPath := filepath.Join(env.workingDir, "main.go")
448				content, err := os.ReadFile(mainGoPath)
449				require.NoError(t, err)
450				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
451			})
452			t.Run("sourcegraph tool", func(t *testing.T) {
453				agent, env := setupAgent(t, pair)
454
455				session, err := env.sessions.Create(t.Context(), "New Session")
456				require.NoError(t, err)
457
458				res, err := agent.Run(t.Context(), SessionAgentCall{
459					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
460					SessionID:       session.ID,
461					MaxOutputTokens: 10000,
462				})
463				require.NoError(t, err)
464				assert.NotNil(t, res)
465
466				msgs, err := env.messages.List(t.Context(), session.ID)
467				require.NoError(t, err)
468
469				foundSourcegraph := false
470				var sourcegraphTCID string
471
472				for _, msg := range msgs {
473					if msg.Role == message.Assistant {
474						for _, tc := range msg.ToolCalls() {
475							if tc.Name == tools.SourcegraphToolName {
476								sourcegraphTCID = tc.ID
477							}
478						}
479					}
480					if msg.Role == message.Tool {
481						for _, tr := range msg.ToolResults() {
482							if tr.ToolCallID == sourcegraphTCID {
483								foundSourcegraph = true
484							}
485						}
486					}
487				}
488
489				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
490			})
491			t.Run("write tool", func(t *testing.T) {
492				agent, env := setupAgent(t, pair)
493
494				session, err := env.sessions.Create(t.Context(), "New Session")
495				require.NoError(t, err)
496
497				res, err := agent.Run(t.Context(), SessionAgentCall{
498					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
499					SessionID:       session.ID,
500					MaxOutputTokens: 10000,
501				})
502				require.NoError(t, err)
503				assert.NotNil(t, res)
504
505				msgs, err := env.messages.List(t.Context(), session.ID)
506				require.NoError(t, err)
507
508				foundWrite := false
509				var writeTCID string
510
511				for _, msg := range msgs {
512					if msg.Role == message.Assistant {
513						for _, tc := range msg.ToolCalls() {
514							if tc.Name == tools.WriteToolName {
515								writeTCID = tc.ID
516							}
517						}
518					}
519					if msg.Role == message.Tool {
520						for _, tr := range msg.ToolResults() {
521							if tr.ToolCallID == writeTCID {
522								foundWrite = true
523							}
524						}
525					}
526				}
527
528				require.True(t, foundWrite, "Expected to find a write operation")
529
530				configPath := filepath.Join(env.workingDir, "config.json")
531				content, err := os.ReadFile(configPath)
532				require.NoError(t, err)
533				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
534				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
535			})
536			t.Run("parallel tool calls", func(t *testing.T) {
537				agent, env := setupAgent(t, pair)
538
539				session, err := env.sessions.Create(t.Context(), "New Session")
540				require.NoError(t, err)
541
542				res, err := agent.Run(t.Context(), SessionAgentCall{
543					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
544					SessionID:       session.ID,
545					MaxOutputTokens: 10000,
546				})
547				require.NoError(t, err)
548				assert.NotNil(t, res)
549
550				msgs, err := env.messages.List(t.Context(), session.ID)
551				require.NoError(t, err)
552
553				var assistantMsg *message.Message
554				var toolMsgs []message.Message
555
556				for _, msg := range msgs {
557					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
558						assistantMsg = &msg
559					}
560					if msg.Role == message.Tool {
561						toolMsgs = append(toolMsgs, msg)
562					}
563				}
564
565				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
566				require.NotNil(t, toolMsgs, "Expected to find a tool message")
567
568				toolCalls := assistantMsg.ToolCalls()
569				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
570
571				foundGlob := false
572				foundLS := false
573				var globTCID, lsTCID string
574
575				for _, tc := range toolCalls {
576					if tc.Name == tools.GlobToolName {
577						foundGlob = true
578						globTCID = tc.ID
579					}
580					if tc.Name == tools.LSToolName {
581						foundLS = true
582						lsTCID = tc.ID
583					}
584				}
585
586				require.True(t, foundGlob, "Expected to find a glob tool call")
587				require.True(t, foundLS, "Expected to find an ls tool call")
588
589				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
590
591				foundGlobResult := false
592				foundLSResult := false
593
594				for _, msg := range toolMsgs {
595					for _, tr := range msg.ToolResults() {
596						if tr.ToolCallID == globTCID {
597							foundGlobResult = true
598							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
599							require.False(t, tr.IsError, "Expected glob result to not be an error")
600						}
601						if tr.ToolCallID == lsTCID {
602							foundLSResult = true
603							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
604							require.False(t, tr.IsError, "Expected ls result to not be an error")
605						}
606					}
607				}
608
609				require.True(t, foundGlobResult, "Expected to find glob tool result")
610				require.True(t, foundLSResult, "Expected to find ls tool result")
611			})
612		})
613	}
614}