1package agent
  2
  3import (
  4	"os"
  5	"path/filepath"
  6	"strings"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/agent/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12	"github.com/charmbracelet/crush/internal/shell"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 16
 17	_ "github.com/joho/godotenv/autoload"
 18)
 19
 20var modelPairs = []modelPair{
 21	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 22	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 23	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 24	{"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
 25}
 26
 27func getModels(t *testing.T, r *recorder.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
 28	large, err := pair.largeModel(t, r)
 29	require.NoError(t, err)
 30	small, err := pair.smallModel(t, r)
 31	require.NoError(t, err)
 32	return large, small
 33}
 34
 35func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 36	r := newRecorder(t)
 37	large, small := getModels(t, r, pair)
 38	env := testEnv(t)
 39
 40	createSimpleGoProject(t, env.workingDir)
 41	agent, err := coderAgent(r, env, large, small)
 42	shell.Reset(env.workingDir)
 43	require.NoError(t, err)
 44	return agent, env
 45}
 46
 47func TestCoderAgent(t *testing.T) {
 48	for _, pair := range modelPairs {
 49		t.Run(pair.name, func(t *testing.T) {
 50			t.Run("simple test", func(t *testing.T) {
 51				agent, env := setupAgent(t, pair)
 52
 53				session, err := env.sessions.Create(t.Context(), "New Session")
 54				require.NoError(t, err)
 55
 56				res, err := agent.Run(t.Context(), SessionAgentCall{
 57					Prompt:          "Hello",
 58					SessionID:       session.ID,
 59					MaxOutputTokens: 10000,
 60				})
 61				require.NoError(t, err)
 62				assert.NotNil(t, res)
 63
 64				msgs, err := env.messages.List(t.Context(), session.ID)
 65				require.NoError(t, err)
 66				// Should have the agent and user message
 67				assert.Equal(t, len(msgs), 2)
 68			})
 69			t.Run("read a file", func(t *testing.T) {
 70				agent, env := setupAgent(t, pair)
 71
 72				session, err := env.sessions.Create(t.Context(), "New Session")
 73				require.NoError(t, err)
 74				res, err := agent.Run(t.Context(), SessionAgentCall{
 75					Prompt:          "Read the go mod",
 76					SessionID:       session.ID,
 77					MaxOutputTokens: 10000,
 78				})
 79
 80				require.NoError(t, err)
 81				assert.NotNil(t, res)
 82
 83				msgs, err := env.messages.List(t.Context(), session.ID)
 84				require.NoError(t, err)
 85				foundFile := false
 86				var tcID string
 87			out:
 88				for _, msg := range msgs {
 89					if msg.Role == message.Assistant {
 90						for _, tc := range msg.ToolCalls() {
 91							if tc.Name == tools.ViewToolName {
 92								tcID = tc.ID
 93							}
 94						}
 95					}
 96					if msg.Role == message.Tool {
 97						for _, tr := range msg.ToolResults() {
 98							if tr.ToolCallID == tcID {
 99								if strings.Contains(tr.Content, "module example.com/testproject") {
100									foundFile = true
101									break out
102								}
103							}
104						}
105					}
106				}
107				require.True(t, foundFile)
108			})
109			t.Run("update a file", func(t *testing.T) {
110				agent, env := setupAgent(t, pair)
111
112				session, err := env.sessions.Create(t.Context(), "New Session")
113				require.NoError(t, err)
114
115				res, err := agent.Run(t.Context(), SessionAgentCall{
116					Prompt:          "update the main.go file by changing the print to say hello from crush",
117					SessionID:       session.ID,
118					MaxOutputTokens: 10000,
119				})
120				require.NoError(t, err)
121				assert.NotNil(t, res)
122
123				msgs, err := env.messages.List(t.Context(), session.ID)
124				require.NoError(t, err)
125
126				foundRead := false
127				foundWrite := false
128				var readTCID, writeTCID string
129
130				for _, msg := range msgs {
131					if msg.Role == message.Assistant {
132						for _, tc := range msg.ToolCalls() {
133							if tc.Name == tools.ViewToolName {
134								readTCID = tc.ID
135							}
136							if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
137								writeTCID = tc.ID
138							}
139						}
140					}
141					if msg.Role == message.Tool {
142						for _, tr := range msg.ToolResults() {
143							if tr.ToolCallID == readTCID {
144								foundRead = true
145							}
146							if tr.ToolCallID == writeTCID {
147								foundWrite = true
148							}
149						}
150					}
151				}
152
153				require.True(t, foundRead, "Expected to find a read operation")
154				require.True(t, foundWrite, "Expected to find a write operation")
155
156				mainGoPath := filepath.Join(env.workingDir, "main.go")
157				content, err := os.ReadFile(mainGoPath)
158				require.NoError(t, err)
159				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
160			})
161			t.Run("bash tool", func(t *testing.T) {
162				agent, env := setupAgent(t, pair)
163
164				session, err := env.sessions.Create(t.Context(), "New Session")
165				require.NoError(t, err)
166
167				res, err := agent.Run(t.Context(), SessionAgentCall{
168					Prompt:          "use bash to create a file named test.txt with content 'hello bash'",
169					SessionID:       session.ID,
170					MaxOutputTokens: 10000,
171				})
172				require.NoError(t, err)
173				assert.NotNil(t, res)
174
175				msgs, err := env.messages.List(t.Context(), session.ID)
176				require.NoError(t, err)
177
178				foundBash := false
179				var bashTCID string
180
181				for _, msg := range msgs {
182					if msg.Role == message.Assistant {
183						for _, tc := range msg.ToolCalls() {
184							if tc.Name == tools.BashToolName {
185								bashTCID = tc.ID
186							}
187						}
188					}
189					if msg.Role == message.Tool {
190						for _, tr := range msg.ToolResults() {
191							if tr.ToolCallID == bashTCID {
192								foundBash = true
193							}
194						}
195					}
196				}
197
198				require.True(t, foundBash, "Expected to find a bash operation")
199
200				testFilePath := filepath.Join(env.workingDir, "test.txt")
201				content, err := os.ReadFile(testFilePath)
202				require.NoError(t, err)
203				require.Contains(t, string(content), "hello bash")
204			})
205			t.Run("download tool", func(t *testing.T) {
206				agent, env := setupAgent(t, pair)
207
208				session, err := env.sessions.Create(t.Context(), "New Session")
209				require.NoError(t, err)
210
211				res, err := agent.Run(t.Context(), SessionAgentCall{
212					Prompt:          "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
213					SessionID:       session.ID,
214					MaxOutputTokens: 10000,
215				})
216				require.NoError(t, err)
217				assert.NotNil(t, res)
218
219				msgs, err := env.messages.List(t.Context(), session.ID)
220				require.NoError(t, err)
221
222				foundDownload := false
223				var downloadTCID string
224
225				for _, msg := range msgs {
226					if msg.Role == message.Assistant {
227						for _, tc := range msg.ToolCalls() {
228							if tc.Name == tools.DownloadToolName {
229								downloadTCID = tc.ID
230							}
231						}
232					}
233					if msg.Role == message.Tool {
234						for _, tr := range msg.ToolResults() {
235							if tr.ToolCallID == downloadTCID {
236								foundDownload = true
237							}
238						}
239					}
240				}
241
242				require.True(t, foundDownload, "Expected to find a download operation")
243
244				examplePath := filepath.Join(env.workingDir, "example.txt")
245				_, err = os.Stat(examplePath)
246				require.NoError(t, err, "Expected example.txt file to exist")
247			})
248			t.Run("glob tool", func(t *testing.T) {
249				agent, env := setupAgent(t, pair)
250
251				session, err := env.sessions.Create(t.Context(), "New Session")
252				require.NoError(t, err)
253
254				res, err := agent.Run(t.Context(), SessionAgentCall{
255					Prompt:          "use glob to find all .go files in the current directory",
256					SessionID:       session.ID,
257					MaxOutputTokens: 10000,
258				})
259				require.NoError(t, err)
260				assert.NotNil(t, res)
261
262				msgs, err := env.messages.List(t.Context(), session.ID)
263				require.NoError(t, err)
264
265				foundGlob := false
266				var globTCID string
267
268				for _, msg := range msgs {
269					if msg.Role == message.Assistant {
270						for _, tc := range msg.ToolCalls() {
271							if tc.Name == tools.GlobToolName {
272								globTCID = tc.ID
273							}
274						}
275					}
276					if msg.Role == message.Tool {
277						for _, tr := range msg.ToolResults() {
278							if tr.ToolCallID == globTCID {
279								foundGlob = true
280								require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
281							}
282						}
283					}
284				}
285
286				require.True(t, foundGlob, "Expected to find a glob operation")
287			})
288			t.Run("grep tool", func(t *testing.T) {
289				agent, env := setupAgent(t, pair)
290
291				session, err := env.sessions.Create(t.Context(), "New Session")
292				require.NoError(t, err)
293
294				res, err := agent.Run(t.Context(), SessionAgentCall{
295					Prompt:          "use grep to search for the word 'package' in go files",
296					SessionID:       session.ID,
297					MaxOutputTokens: 10000,
298				})
299				require.NoError(t, err)
300				assert.NotNil(t, res)
301
302				msgs, err := env.messages.List(t.Context(), session.ID)
303				require.NoError(t, err)
304
305				foundGrep := false
306				var grepTCID string
307
308				for _, msg := range msgs {
309					if msg.Role == message.Assistant {
310						for _, tc := range msg.ToolCalls() {
311							if tc.Name == tools.GrepToolName {
312								grepTCID = tc.ID
313							}
314						}
315					}
316					if msg.Role == message.Tool {
317						for _, tr := range msg.ToolResults() {
318							if tr.ToolCallID == grepTCID {
319								foundGrep = true
320								require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
321							}
322						}
323					}
324				}
325
326				require.True(t, foundGrep, "Expected to find a grep operation")
327			})
328			t.Run("ls tool", func(t *testing.T) {
329				agent, env := setupAgent(t, pair)
330
331				session, err := env.sessions.Create(t.Context(), "New Session")
332				require.NoError(t, err)
333
334				res, err := agent.Run(t.Context(), SessionAgentCall{
335					Prompt:          "use ls to list the files in the current directory",
336					SessionID:       session.ID,
337					MaxOutputTokens: 10000,
338				})
339				require.NoError(t, err)
340				assert.NotNil(t, res)
341
342				msgs, err := env.messages.List(t.Context(), session.ID)
343				require.NoError(t, err)
344
345				foundLS := false
346				var lsTCID string
347
348				for _, msg := range msgs {
349					if msg.Role == message.Assistant {
350						for _, tc := range msg.ToolCalls() {
351							if tc.Name == tools.LSToolName {
352								lsTCID = tc.ID
353							}
354						}
355					}
356					if msg.Role == message.Tool {
357						for _, tr := range msg.ToolResults() {
358							if tr.ToolCallID == lsTCID {
359								foundLS = true
360								require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
361								require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
362							}
363						}
364					}
365				}
366
367				require.True(t, foundLS, "Expected to find an ls operation")
368			})
369			t.Run("multiedit tool", func(t *testing.T) {
370				agent, env := setupAgent(t, pair)
371
372				session, err := env.sessions.Create(t.Context(), "New Session")
373				require.NoError(t, err)
374
375				res, err := agent.Run(t.Context(), SessionAgentCall{
376					Prompt:          "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
377					SessionID:       session.ID,
378					MaxOutputTokens: 10000,
379				})
380				require.NoError(t, err)
381				assert.NotNil(t, res)
382
383				msgs, err := env.messages.List(t.Context(), session.ID)
384				require.NoError(t, err)
385
386				foundMultiEdit := false
387				var multiEditTCID string
388
389				for _, msg := range msgs {
390					if msg.Role == message.Assistant {
391						for _, tc := range msg.ToolCalls() {
392							if tc.Name == tools.MultiEditToolName {
393								multiEditTCID = tc.ID
394							}
395						}
396					}
397					if msg.Role == message.Tool {
398						for _, tr := range msg.ToolResults() {
399							if tr.ToolCallID == multiEditTCID {
400								foundMultiEdit = true
401							}
402						}
403					}
404				}
405
406				require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
407
408				mainGoPath := filepath.Join(env.workingDir, "main.go")
409				content, err := os.ReadFile(mainGoPath)
410				require.NoError(t, err)
411				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
412			})
413			t.Run("sourcegraph tool", func(t *testing.T) {
414				agent, env := setupAgent(t, pair)
415
416				session, err := env.sessions.Create(t.Context(), "New Session")
417				require.NoError(t, err)
418
419				res, err := agent.Run(t.Context(), SessionAgentCall{
420					Prompt:          "use sourcegraph to search for 'func main' in Go repositories",
421					SessionID:       session.ID,
422					MaxOutputTokens: 10000,
423				})
424				require.NoError(t, err)
425				assert.NotNil(t, res)
426
427				msgs, err := env.messages.List(t.Context(), session.ID)
428				require.NoError(t, err)
429
430				foundSourcegraph := false
431				var sourcegraphTCID string
432
433				for _, msg := range msgs {
434					if msg.Role == message.Assistant {
435						for _, tc := range msg.ToolCalls() {
436							if tc.Name == tools.SourcegraphToolName {
437								sourcegraphTCID = tc.ID
438							}
439						}
440					}
441					if msg.Role == message.Tool {
442						for _, tr := range msg.ToolResults() {
443							if tr.ToolCallID == sourcegraphTCID {
444								foundSourcegraph = true
445							}
446						}
447					}
448				}
449
450				require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
451			})
452			t.Run("write tool", func(t *testing.T) {
453				agent, env := setupAgent(t, pair)
454
455				session, err := env.sessions.Create(t.Context(), "New Session")
456				require.NoError(t, err)
457
458				res, err := agent.Run(t.Context(), SessionAgentCall{
459					Prompt:          "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
460					SessionID:       session.ID,
461					MaxOutputTokens: 10000,
462				})
463				require.NoError(t, err)
464				assert.NotNil(t, res)
465
466				msgs, err := env.messages.List(t.Context(), session.ID)
467				require.NoError(t, err)
468
469				foundWrite := false
470				var writeTCID string
471
472				for _, msg := range msgs {
473					if msg.Role == message.Assistant {
474						for _, tc := range msg.ToolCalls() {
475							if tc.Name == tools.WriteToolName {
476								writeTCID = tc.ID
477							}
478						}
479					}
480					if msg.Role == message.Tool {
481						for _, tr := range msg.ToolResults() {
482							if tr.ToolCallID == writeTCID {
483								foundWrite = true
484							}
485						}
486					}
487				}
488
489				require.True(t, foundWrite, "Expected to find a write operation")
490
491				configPath := filepath.Join(env.workingDir, "config.json")
492				content, err := os.ReadFile(configPath)
493				require.NoError(t, err)
494				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
495				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
496			})
497			t.Run("parallel tool calls", func(t *testing.T) {
498				agent, env := setupAgent(t, pair)
499
500				session, err := env.sessions.Create(t.Context(), "New Session")
501				require.NoError(t, err)
502
503				res, err := agent.Run(t.Context(), SessionAgentCall{
504					Prompt:          "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
505					SessionID:       session.ID,
506					MaxOutputTokens: 10000,
507				})
508				require.NoError(t, err)
509				assert.NotNil(t, res)
510
511				msgs, err := env.messages.List(t.Context(), session.ID)
512				require.NoError(t, err)
513
514				var assistantMsg *message.Message
515				var toolMsgs []message.Message
516
517				for _, msg := range msgs {
518					if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
519						assistantMsg = &msg
520					}
521					if msg.Role == message.Tool {
522						toolMsgs = append(toolMsgs, msg)
523					}
524				}
525
526				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
527				require.NotNil(t, toolMsgs, "Expected to find a tool message")
528
529				toolCalls := assistantMsg.ToolCalls()
530				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
531
532				foundGlob := false
533				foundLS := false
534				var globTCID, lsTCID string
535
536				for _, tc := range toolCalls {
537					if tc.Name == tools.GlobToolName {
538						foundGlob = true
539						globTCID = tc.ID
540					}
541					if tc.Name == tools.LSToolName {
542						foundLS = true
543						lsTCID = tc.ID
544					}
545				}
546
547				require.True(t, foundGlob, "Expected to find a glob tool call")
548				require.True(t, foundLS, "Expected to find an ls tool call")
549
550				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
551
552				foundGlobResult := false
553				foundLSResult := false
554
555				for _, msg := range toolMsgs {
556					for _, tr := range msg.ToolResults() {
557						if tr.ToolCallID == globTCID {
558							foundGlobResult = true
559							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
560							require.False(t, tr.IsError, "Expected glob result to not be an error")
561						}
562						if tr.ToolCallID == lsTCID {
563							foundLSResult = true
564							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
565							require.False(t, tr.IsError, "Expected ls result to not be an error")
566						}
567					}
568				}
569
570				require.True(t, foundGlobResult, "Expected to find glob tool result")
571				require.True(t, foundLSResult, "Expected to find ls tool result")
572			})
573		})
574	}
575}