1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "charm.land/x/vcr"
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18 "github.com/charmbracelet/crush/internal/agent/prompt"
19 "github.com/charmbracelet/crush/internal/agent/tools"
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/db"
23 "github.com/charmbracelet/crush/internal/history"
24 "github.com/charmbracelet/crush/internal/lsp"
25 "github.com/charmbracelet/crush/internal/message"
26 "github.com/charmbracelet/crush/internal/permission"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/stretchr/testify/require"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// fakeEnv is an environment for testing.
//
// It bundles the working directory and the services that the agent and its
// tools are constructed with in these tests. testEnv is the constructor used
// in this file.
type fakeEnv struct {
	// workingDir is the directory handed to the tools and rendered into the
	// coder system prompt (see coderAgent).
	workingDir string
	// sessions and messages are passed to NewSessionAgent.
	sessions session.Service
	messages message.Service
	// permissions is passed to every tool that takes a permission service.
	permissions permission.Service
	// history is passed to the edit, multi-edit and write tools.
	history history.Service
	// lspClients is created empty by testEnv; it is passed to the tools that
	// accept LSP clients.
	lspClients *csync.Map[string, *lsp.Client]
}
42
// builderFunc constructs a language model for a test. The builders in this
// file route all provider HTTP traffic through the VCR recorder r and derive
// the request context from t.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)

// modelPair groups a display name with builders for a large and a small
// model. NOTE(review): presumably used to table-drive tests across
// providers; its consumers are not visible here.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
50
51func anthropicBuilder(model string) builderFunc {
52 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
53 provider, err := anthropic.New(
54 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
55 anthropic.WithHTTPClient(&http.Client{Transport: r}),
56 )
57 if err != nil {
58 return nil, err
59 }
60 return provider.LanguageModel(t.Context(), model)
61 }
62}
63
64func openaiBuilder(model string) builderFunc {
65 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
66 provider, err := openai.New(
67 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
68 openai.WithHTTPClient(&http.Client{Transport: r}),
69 )
70 if err != nil {
71 return nil, err
72 }
73 return provider.LanguageModel(t.Context(), model)
74 }
75}
76
77func openRouterBuilder(model string) builderFunc {
78 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
79 provider, err := openrouter.New(
80 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
81 openrouter.WithHTTPClient(&http.Client{Transport: r}),
82 )
83 if err != nil {
84 return nil, err
85 }
86 return provider.LanguageModel(t.Context(), model)
87 }
88}
89
90func zAIBuilder(model string) builderFunc {
91 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
92 provider, err := openaicompat.New(
93 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
94 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
95 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
96 )
97 if err != nil {
98 return nil, err
99 }
100 return provider.LanguageModel(t.Context(), model)
101 }
102}
103
104func testEnv(t *testing.T) fakeEnv {
105 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
106 os.RemoveAll(workingDir)
107
108 err := os.MkdirAll(workingDir, 0o755)
109 require.NoError(t, err)
110
111 conn, err := db.Connect(t.Context(), t.TempDir())
112 require.NoError(t, err)
113
114 q := db.New(conn)
115 sessions := session.NewService(q)
116 messages := message.NewService(q)
117
118 permissions := permission.NewPermissionService(workingDir, true, []string{})
119 history := history.NewService(q, conn)
120 lspClients := csync.NewMap[string, *lsp.Client]()
121
122 t.Cleanup(func() {
123 conn.Close()
124 os.RemoveAll(workingDir)
125 })
126
127 return fakeEnv{
128 workingDir,
129 sessions,
130 messages,
131 permissions,
132 history,
133 lspClients,
134 }
135}
136
137func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
138 largeModel := Model{
139 Model: large,
140 CatwalkCfg: catwalk.Model{
141 ContextWindow: 200000,
142 DefaultMaxTokens: 10000,
143 },
144 }
145 smallModel := Model{
146 Model: small,
147 CatwalkCfg: catwalk.Model{
148 ContextWindow: 200000,
149 DefaultMaxTokens: 10000,
150 },
151 }
152 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
153 return agent
154}
155
156func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
157 fixedTime := func() time.Time {
158 t, _ := time.Parse("1/2/2006", "1/1/2025")
159 return t
160 }
161 prompt, err := coderPrompt(
162 prompt.WithTimeFunc(fixedTime),
163 prompt.WithPlatform("linux"),
164 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
165 )
166 if err != nil {
167 return nil, err
168 }
169 cfg, err := config.Init(env.workingDir, "", false)
170 if err != nil {
171 return nil, err
172 }
173
174 // NOTE(@andreynering): Set a fixed config to ensure cassettes match
175 // independently of user config on `$HOME/.config/crush/crush.json`.
176 cfg.Options.Attribution = &config.Attribution{
177 TrailerStyle: "co-authored-by",
178 GeneratedWith: true,
179 }
180
181 // Clear skills paths to ensure test reproducibility - user's skills
182 // would be included in prompt and break VCR cassette matching.
183 cfg.Options.SkillsPaths = []string{}
184
185 // Clear LSP config to ensure test reproducibility - user's LSP config
186 // would be included in prompt and break VCR cassette matching.
187 cfg.LSP = nil
188
189 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
190 if err != nil {
191 return nil, err
192 }
193
194 // Get the model name for the bash tool
195 modelName := large.Model() // fallback to ID if Name not available
196 if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
197 modelName = model.Name
198 }
199
200 allTools := []fantasy.AgentTool{
201 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
202 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
203 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
204 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
205 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
206 tools.NewGlobTool(env.workingDir),
207 tools.NewGrepTool(env.workingDir),
208 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
209 tools.NewSourcegraphTool(r.GetDefaultClient()),
210 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
211 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
212 }
213
214 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
215}
216
// createSimpleGoProject writes a minimal Go module into dir:
//   - a go.mod declaring module example.com/testproject (go 1.23);
//   - a main.go whose main prints "Hello, World!".
//
// dir must already exist. The test fails immediately if a file cannot be
// written, and the failure is reported at the caller's line.
func createSimpleGoProject(t *testing.T, dir string) {
	t.Helper()

	files := []struct {
		name    string
		content string
	}{
		{
			name: "go.mod",
			content: `module example.com/testproject

go 1.23
`,
		},
		{
			name: "main.go",
			content: `package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`,
		},
	}

	for _, f := range files {
		// filepath.Join instead of dir+"/"+name keeps paths OS-correct.
		path := filepath.Join(dir, f.name)
		if err := os.WriteFile(path, []byte(f.content), 0o644); err != nil {
			t.Fatalf("writing %s: %v", path, err)
		}
	}
}