1package agent
2
3import (
4 "net/http"
5 "os"
6 "testing"
7
8 "github.com/charmbracelet/catwalk/pkg/catwalk"
9 "github.com/charmbracelet/crush/internal/agent/tools"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/csync"
12 "github.com/charmbracelet/crush/internal/db"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/lsp"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/charmbracelet/crush/internal/permission"
17 "github.com/charmbracelet/crush/internal/session"
18 "github.com/charmbracelet/fantasy/ai"
19 "github.com/charmbracelet/fantasy/anthropic"
20 "github.com/stretchr/testify/assert"
21 "github.com/stretchr/testify/require"
22 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
23
24 _ "github.com/joho/godotenv/autoload"
25)
26
// env bundles the per-test dependencies that the agent helpers in this file
// share. testEnv constructs one for each test.
type env struct {
	// workingDir is the directory handed to the permission service and to
	// every tool built by coderAgent; testEnv sets it to a t.TempDir().
	workingDir string
	// sessions is used by the tests to create sessions and is passed to
	// NewSessionAgent.
	sessions session.Service
	// messages is used by the tests to list persisted messages and is passed
	// to NewSessionAgent.
	messages message.Service
	// permissions is the permission service given to the tools.
	permissions permission.Service
	// history is passed to the edit, multi-edit and write tools.
	history history.Service
	// lspClients is the LSP client map shared with the edit, multi-edit,
	// view and write tools; testEnv creates it empty.
	lspClients *csync.Map[string, *lsp.Client]
}
35
// builderFunc constructs a language model from a go-vcr recorder. Implementations
// are expected to route the model's HTTP traffic through r so that provider
// calls can be recorded or replayed (anthropicBuilder does this).
type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
37
38func TestSessionAgent(t *testing.T) {
39 t.Run("simple test", func(t *testing.T) {
40 r := newRecorder(t)
41 sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
42 require.NoError(t, err)
43 haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
44 require.NoError(t, err)
45
46 env := testEnv(t)
47 agent := testSessionAgent(env, sonnet, haiku, "You are a helpful assistant")
48 session, err := env.sessions.Create(t.Context(), "New Session")
49 require.NoError(t, err)
50
51 res, err := agent.Run(t.Context(), SessionAgentCall{
52 Prompt: "Hello",
53 SessionID: session.ID,
54 MaxOutputTokens: 10000,
55 })
56
57 require.NoError(t, err)
58 assert.NotNil(t, res)
59
60 t.Run("should create session messages", func(t *testing.T) {
61 msgs, err := env.messages.List(t.Context(), session.ID)
62 require.NoError(t, err)
63 // Should have the agent and user message
64 assert.Equal(t, len(msgs), 2)
65 })
66 })
67}
68
69func TestCoderAgent(t *testing.T) {
70 t.Run("simple test", func(t *testing.T) {
71 r := newRecorder(t)
72 sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
73 require.NoError(t, err)
74 haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
75 require.NoError(t, err)
76
77 env := testEnv(t)
78 agent, err := coderAgent(env, sonnet, haiku)
79 require.NoError(t, err)
80 session, err := env.sessions.Create(t.Context(), "New Session")
81 require.NoError(t, err)
82
83 res, err := agent.Run(t.Context(), SessionAgentCall{
84 Prompt: "Hello",
85 SessionID: session.ID,
86 MaxOutputTokens: 10000,
87 })
88
89 require.NoError(t, err)
90 assert.NotNil(t, res)
91
92 msgs, err := env.messages.List(t.Context(), session.ID)
93 require.NoError(t, err)
94 // Should have the agent and user message
95 assert.Equal(t, len(msgs), 2)
96 })
97}
98
99func anthropicBuilder(model string) builderFunc {
100 return func(r *recorder.Recorder) (ai.LanguageModel, error) {
101 provider := anthropic.New(
102 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
103 anthropic.WithHTTPClient(&http.Client{Transport: r}),
104 )
105 return provider.LanguageModel(model)
106 }
107}
108
109func testEnv(t *testing.T) env {
110 workingDir := t.TempDir()
111 conn, err := db.Connect(t.Context(), t.TempDir())
112 require.NoError(t, err)
113 q := db.New(conn)
114 sessions := session.NewService(q)
115 messages := message.NewService(q)
116 permissions := permission.NewPermissionService(workingDir, true, []string{})
117 history := history.NewService(q, conn)
118 lspClients := csync.NewMap[string, *lsp.Client]()
119 return env{
120 workingDir,
121 sessions,
122 messages,
123 permissions,
124 history,
125 lspClients,
126 }
127}
128
129func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
130 largeModel := Model{
131 model: large,
132 config: catwalk.Model{
133 // todo: add values
134 },
135 }
136 smallModel := Model{
137 model: small,
138 config: catwalk.Model{
139 // todo: add values
140 },
141 }
142 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
143 return agent
144}
145
146func coderAgent(env env, large, small ai.LanguageModel) (SessionAgent, error) {
147 prompt, err := coderPrompt()
148 if err != nil {
149 return nil, err
150 }
151 cfg, err := config.Init(env.workingDir, "", false)
152 if err != nil {
153 return nil, err
154 }
155
156 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
157 if err != nil {
158 return nil, err
159 }
160 allTools := []ai.AgentTool{
161 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
162 tools.NewDownloadTool(env.permissions, env.workingDir),
163 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
164 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
165 tools.NewFetchTool(env.permissions, env.workingDir),
166 tools.NewGlobTool(env.workingDir),
167 tools.NewGrepTool(env.workingDir),
168 tools.NewLsTool(env.permissions, env.workingDir),
169 tools.NewSourcegraphTool(),
170 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
171 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
172 }
173
174 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
175}