1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/catwalk/pkg/catwalk"
12 "charm.land/fantasy"
13 "charm.land/fantasy/providers/anthropic"
14 "charm.land/fantasy/providers/openai"
15 "charm.land/fantasy/providers/openaicompat"
16 "charm.land/fantasy/providers/openrouter"
17 "charm.land/x/vcr"
18 "github.com/charmbracelet/crush/internal/agent/prompt"
19 "github.com/charmbracelet/crush/internal/agent/tools"
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/db"
23 "github.com/charmbracelet/crush/internal/filetracker"
24 "github.com/charmbracelet/crush/internal/history"
25 "github.com/charmbracelet/crush/internal/lsp"
26 "github.com/charmbracelet/crush/internal/message"
27 "github.com/charmbracelet/crush/internal/permission"
28 "github.com/charmbracelet/crush/internal/session"
29 "github.com/stretchr/testify/require"
30
31 _ "github.com/joho/godotenv/autoload"
32)
33
34// fakeEnv is an environment for testing.
35type fakeEnv struct {
36 workingDir string
37 sessions session.Service
38 messages message.Service
39 permissions permission.Service
40 history history.Service
41 filetracker *filetracker.Service
42 lspClients *csync.Map[string, *lsp.Client]
43}
44
45type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
46
47type modelPair struct {
48 name string
49 largeModel builderFunc
50 smallModel builderFunc
51}
52
53func anthropicBuilder(model string) builderFunc {
54 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
55 provider, err := anthropic.New(
56 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
57 anthropic.WithHTTPClient(&http.Client{Transport: r}),
58 )
59 if err != nil {
60 return nil, err
61 }
62 return provider.LanguageModel(t.Context(), model)
63 }
64}
65
66func openaiBuilder(model string) builderFunc {
67 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
68 provider, err := openai.New(
69 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
70 openai.WithHTTPClient(&http.Client{Transport: r}),
71 )
72 if err != nil {
73 return nil, err
74 }
75 return provider.LanguageModel(t.Context(), model)
76 }
77}
78
79func openRouterBuilder(model string) builderFunc {
80 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
81 provider, err := openrouter.New(
82 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
83 openrouter.WithHTTPClient(&http.Client{Transport: r}),
84 )
85 if err != nil {
86 return nil, err
87 }
88 return provider.LanguageModel(t.Context(), model)
89 }
90}
91
92func zAIBuilder(model string) builderFunc {
93 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
94 provider, err := openaicompat.New(
95 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
96 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
97 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
98 )
99 if err != nil {
100 return nil, err
101 }
102 return provider.LanguageModel(t.Context(), model)
103 }
104}
105
106func testEnv(t *testing.T) fakeEnv {
107 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
108 os.RemoveAll(workingDir)
109
110 err := os.MkdirAll(workingDir, 0o755)
111 require.NoError(t, err)
112
113 conn, err := db.Connect(t.Context(), t.TempDir())
114 require.NoError(t, err)
115
116 q := db.New(conn)
117 sessions := session.NewService(q, conn)
118 messages := message.NewService(q)
119
120 permissions := permission.NewPermissionService(workingDir, true, []string{})
121 history := history.NewService(q, conn)
122 filetrackerService := filetracker.NewService(q)
123 lspClients := csync.NewMap[string, *lsp.Client]()
124
125 t.Cleanup(func() {
126 conn.Close()
127 os.RemoveAll(workingDir)
128 })
129
130 return fakeEnv{
131 workingDir,
132 sessions,
133 messages,
134 permissions,
135 history,
136 &filetrackerService,
137 lspClients,
138 }
139}
140
141func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
142 largeModel := Model{
143 Model: large,
144 CatwalkCfg: catwalk.Model{
145 ContextWindow: 200000,
146 DefaultMaxTokens: 10000,
147 },
148 }
149 smallModel := Model{
150 Model: small,
151 CatwalkCfg: catwalk.Model{
152 ContextWindow: 200000,
153 DefaultMaxTokens: 10000,
154 },
155 }
156 agent := NewSessionAgent(SessionAgentOptions{
157 LargeModel: largeModel,
158 SmallModel: smallModel,
159 SystemPrompt: systemPrompt,
160 IsYolo: true,
161 Sessions: env.sessions,
162 Messages: env.messages,
163 Tools: tools,
164 })
165 return agent
166}
167
168func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
169 fixedTime := func() time.Time {
170 t, _ := time.Parse("1/2/2006", "1/1/2025")
171 return t
172 }
173 prompt, err := coderPrompt(
174 prompt.WithTimeFunc(fixedTime),
175 prompt.WithPlatform("linux"),
176 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
177 )
178 if err != nil {
179 return nil, err
180 }
181 cfg, err := config.Init(env.workingDir, "", false)
182 if err != nil {
183 return nil, err
184 }
185
186 // NOTE(@andreynering): Set a fixed config to ensure cassettes match
187 // independently of user config on `$HOME/.config/crush/crush.json`.
188 cfg.Config().Options.Attribution = &config.Attribution{
189 TrailerStyle: "co-authored-by",
190 GeneratedWith: true,
191 }
192
193 // Clear some fields to avoid issues with VCR cassette matching.
194 cfg.Config().Options.SkillsPaths = nil
195 cfg.Config().Options.ContextPaths = nil
196 cfg.Config().LSP = nil
197
198 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), cfg)
199 if err != nil {
200 return nil, err
201 }
202
203 // Get the model name for the bash tool
204 modelName := large.Model() // fallback to ID if Name not available
205 if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil {
206 modelName = model.Name
207 }
208
209 allTools := []fantasy.AgentTool{
210 tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName),
211 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
212 tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
213 tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
214 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
215 tools.NewGlobTool(env.workingDir),
216 tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep),
217 tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls),
218 tools.NewSourcegraphTool(r.GetDefaultClient()),
219 tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
220 tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
221 }
222
223 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
224}
225
226// createSimpleGoProject creates a simple Go project structure in the given directory.
227// It creates a go.mod file and a main.go file with a basic hello world program.
228func createSimpleGoProject(t *testing.T, dir string) {
229 goMod := `module example.com/testproject
230
231go 1.23
232`
233 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
234 require.NoError(t, err)
235
236 mainGo := `package main
237
238import "fmt"
239
240func main() {
241 fmt.Println("Hello, World!")
242}
243`
244 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
245 require.NoError(t, err)
246}