1package agent
  2
  3import (
  4	"net/http"
  5	"os"
  6	"testing"
  7
  8	"github.com/charmbracelet/catwalk/pkg/catwalk"
  9	"github.com/charmbracelet/crush/internal/agent/tools"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/csync"
 12	"github.com/charmbracelet/crush/internal/db"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/lsp"
 15	"github.com/charmbracelet/crush/internal/message"
 16	"github.com/charmbracelet/crush/internal/permission"
 17	"github.com/charmbracelet/crush/internal/session"
 18	"github.com/charmbracelet/fantasy/ai"
 19	"github.com/charmbracelet/fantasy/anthropic"
 20	"github.com/stretchr/testify/assert"
 21	"github.com/stretchr/testify/require"
 22	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 23
 24	_ "github.com/joho/godotenv/autoload"
 25)
 26
// env bundles the per-test dependencies shared by the agent tests: a
// temporary working directory, the session, message and history services
// (all backed by the same test database), a permission service scoped to the
// working directory, and an initially empty map of LSP clients that the file
// tools receive. Instances are built by testEnv.
//
// NOTE: testEnv initialises this struct positionally, so the field order
// below matters.
type env struct {
	workingDir  string
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]
}
 35
// builderFunc constructs a language model whose HTTP traffic is routed
// through the given go-vcr recorder, so tests can record and (depending on
// the recorder's mode) replay provider API interactions.
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
 37
 38func TestSessionAgent(t *testing.T) {
 39	t.Run("simple test", func(t *testing.T) {
 40		r := newRecorder(t)
 41		sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
 42		require.NoError(t, err)
 43		haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
 44		require.NoError(t, err)
 45
 46		env := testEnv(t)
 47		agent := testSessionAgent(env, sonnet, haiku, "You are a helpful assistant")
 48		session, err := env.sessions.Create(t.Context(), "New Session")
 49		require.NoError(t, err)
 50
 51		res, err := agent.Run(t.Context(), SessionAgentCall{
 52			Prompt:          "Hello",
 53			SessionID:       session.ID,
 54			MaxOutputTokens: 10000,
 55		})
 56
 57		require.NoError(t, err)
 58		assert.NotNil(t, res)
 59
 60		t.Run("should create session messages", func(t *testing.T) {
 61			msgs, err := env.messages.List(t.Context(), session.ID)
 62			require.NoError(t, err)
 63			// Should have the agent and user message
 64			assert.Equal(t, len(msgs), 2)
 65		})
 66	})
 67}
 68
 69func TestCoderAgent(t *testing.T) {
 70	t.Run("simple test", func(t *testing.T) {
 71		r := newRecorder(t)
 72		sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
 73		require.NoError(t, err)
 74		haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
 75		require.NoError(t, err)
 76
 77		env := testEnv(t)
 78		agent, err := coderAgent(env, sonnet, haiku)
 79		require.NoError(t, err)
 80		session, err := env.sessions.Create(t.Context(), "New Session")
 81		require.NoError(t, err)
 82
 83		res, err := agent.Run(t.Context(), SessionAgentCall{
 84			Prompt:          "Hello",
 85			SessionID:       session.ID,
 86			MaxOutputTokens: 10000,
 87		})
 88
 89		require.NoError(t, err)
 90		assert.NotNil(t, res)
 91
 92		msgs, err := env.messages.List(t.Context(), session.ID)
 93		require.NoError(t, err)
 94		// Should have the agent and user message
 95		assert.Equal(t, len(msgs), 2)
 96	})
 97}
 98
 99func anthropicBuilder(model string) builderFunc {
100	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
101		provider := anthropic.New(
102			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
103			anthropic.WithHTTPClient(&http.Client{Transport: r}),
104		)
105		return provider.LanguageModel(model)
106	}
107}
108
109func testEnv(t *testing.T) env {
110	workingDir := t.TempDir()
111	conn, err := db.Connect(t.Context(), t.TempDir())
112	require.NoError(t, err)
113	q := db.New(conn)
114	sessions := session.NewService(q)
115	messages := message.NewService(q)
116	permissions := permission.NewPermissionService(workingDir, true, []string{})
117	history := history.NewService(q, conn)
118	lspClients := csync.NewMap[string, *lsp.Client]()
119	return env{
120		workingDir,
121		sessions,
122		messages,
123		permissions,
124		history,
125		lspClients,
126	}
127}
128
129func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
130	largeModel := Model{
131		model:  large,
132		config: catwalk.Model{
133			// todo: add values
134		},
135	}
136	smallModel := Model{
137		model:  small,
138		config: catwalk.Model{
139			// todo: add values
140		},
141	}
142	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
143	return agent
144}
145
146func coderAgent(env env, large, small ai.LanguageModel) (SessionAgent, error) {
147	prompt, err := coderPrompt()
148	if err != nil {
149		return nil, err
150	}
151	cfg, err := config.Init(env.workingDir, "", false)
152	if err != nil {
153		return nil, err
154	}
155
156	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
157	if err != nil {
158		return nil, err
159	}
160	allTools := []ai.AgentTool{
161		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
162		tools.NewDownloadTool(env.permissions, env.workingDir),
163		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
164		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
165		tools.NewFetchTool(env.permissions, env.workingDir),
166		tools.NewGlobTool(env.workingDir),
167		tools.NewGrepTool(env.workingDir),
168		tools.NewLsTool(env.permissions, env.workingDir),
169		tools.NewSourcegraphTool(),
170		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
171		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
172	}
173
174	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
175}