agent_test.go

  1package agent
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"strings"
  7	"testing"
  8
  9	"github.com/charmbracelet/crush/internal/agent/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/charmbracelet/fantasy/ai"
 12	"github.com/stretchr/testify/assert"
 13	"github.com/stretchr/testify/require"
 14
 15	_ "github.com/joho/godotenv/autoload"
 16)
 17
 18var modelPairs = []modelPair{
 19	{"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
 20	{"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
 21	{"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
 22}
 23
 24func getModels(t *testing.T, pair modelPair) (ai.LanguageModel, ai.LanguageModel) {
 25	r := newRecorder(t)
 26	large, err := pair.largeModel(t, r)
 27	require.NoError(t, err)
 28	small, err := pair.smallModel(t, r)
 29	require.NoError(t, err)
 30	return large, small
 31}
 32
 33func setupAgent(t *testing.T, pair modelPair) (SessionAgent, env) {
 34	large, small := getModels(t, pair)
 35	env := testEnv(t)
 36
 37	createSimpleGoProject(t, env.workingDir)
 38	agent, err := coderAgent(env, large, small)
 39	require.NoError(t, err)
 40	return agent, env
 41}
 42
 43func TestCoderAgent(t *testing.T) {
 44	for _, pair := range modelPairs {
 45		t.Run(pair.name, func(t *testing.T) {
 46			t.Run("simple test", func(t *testing.T) {
 47				agent, env := setupAgent(t, pair)
 48
 49				session, err := env.sessions.Create(t.Context(), "New Session")
 50				require.NoError(t, err)
 51
 52				res, err := agent.Run(t.Context(), SessionAgentCall{
 53					Prompt:          "Hello",
 54					SessionID:       session.ID,
 55					MaxOutputTokens: 10000,
 56				})
 57				require.NoError(t, err)
 58				assert.NotNil(t, res)
 59
 60				msgs, err := env.messages.List(t.Context(), session.ID)
 61				require.NoError(t, err)
 62				// Should have the agent and user message
 63				assert.Equal(t, len(msgs), 2)
 64			})
 65			t.Run("read a file", func(t *testing.T) {
 66				agent, env := setupAgent(t, pair)
 67
 68				session, err := env.sessions.Create(t.Context(), "New Session")
 69				require.NoError(t, err)
 70				res, err := agent.Run(t.Context(), SessionAgentCall{
 71					Prompt:          "Read the go mod",
 72					SessionID:       session.ID,
 73					MaxOutputTokens: 10000,
 74				})
 75
 76				require.NoError(t, err)
 77				assert.NotNil(t, res)
 78
 79				msgs, err := env.messages.List(t.Context(), session.ID)
 80				require.NoError(t, err)
 81				foundFile := false
 82				var tcID string
 83			out:
 84				for _, msg := range msgs {
 85					data, _ := json.Marshal(msg)
 86					fmt.Println(string(data))
 87					if msg.Role == message.Assistant {
 88						for _, tc := range msg.ToolCalls() {
 89							if tc.Name == tools.ViewToolName {
 90								tcID = tc.ID
 91							}
 92						}
 93					}
 94					if msg.Role == message.Tool {
 95						for _, tr := range msg.ToolResults() {
 96							if tr.ToolCallID == tcID {
 97								if strings.Contains(tr.Content, "module example.com/testproject") {
 98									foundFile = true
 99									break out
100								}
101							}
102						}
103					}
104				}
105				require.True(t, foundFile)
106			})
107		})
108	}
109}