1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/catwalk/pkg/catwalk"
12 "charm.land/fantasy"
13 "charm.land/fantasy/providers/anthropic"
14 "charm.land/fantasy/providers/openai"
15 "charm.land/fantasy/providers/openaicompat"
16 "charm.land/fantasy/providers/openrouter"
17 "charm.land/x/vcr"
18 "github.com/charmbracelet/crush/internal/agent/prompt"
19 "github.com/charmbracelet/crush/internal/agent/tools"
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/db"
23 "github.com/charmbracelet/crush/internal/filetracker"
24 "github.com/charmbracelet/crush/internal/history"
25 "github.com/charmbracelet/crush/internal/lsp"
26 "github.com/charmbracelet/crush/internal/message"
27 "github.com/charmbracelet/crush/internal/permission"
28 "github.com/charmbracelet/crush/internal/session"
29 "github.com/stretchr/testify/require"
30
31 _ "github.com/joho/godotenv/autoload"
32)
33
// fakeEnv is an environment for testing.
//
// It bundles the working directory and the services that the agents and
// tools under test are wired with.
type fakeEnv struct {
	// workingDir is the directory handed to the tools and the permission
	// service.
	workingDir string
	// sessions is passed to the session agent.
	sessions session.Service
	// messages is passed to the session agent.
	messages message.Service
	// permissions is passed to tools that need permission checks.
	permissions permission.Service
	// history is passed to the file-modifying tools (edit, multi-edit, write).
	history history.Service
	// filetracker is stored as a pointer; tools receive a dereferenced copy.
	filetracker *filetracker.Service
	// lspClients holds LSP clients by string key (presumably the LSP server
	// name — TODO confirm at the usage sites).
	lspClients *csync.Map[string, *lsp.Client]
}
44
// builderFunc constructs a language model for test t. Implementations route
// the provider's HTTP traffic through the VCR recorder r (used as the HTTP
// transport) so requests can be recorded and replayed.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
46
// modelPair names a combination of large and small models to run agent tests
// against.
type modelPair struct {
	// name identifies the pair (presumably used as a subtest name — confirm at
	// call sites).
	name string
	// largeModel builds the model used as the agent's large model.
	largeModel builderFunc
	// smallModel builds the model used as the agent's small model.
	smallModel builderFunc
}
52
53func anthropicBuilder(model string) builderFunc {
54 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
55 provider, err := anthropic.New(
56 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
57 anthropic.WithHTTPClient(&http.Client{Transport: r}),
58 )
59 if err != nil {
60 return nil, err
61 }
62 return provider.LanguageModel(t.Context(), model)
63 }
64}
65
66func openaiBuilder(model string) builderFunc {
67 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
68 provider, err := openai.New(
69 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
70 openai.WithHTTPClient(&http.Client{Transport: r}),
71 )
72 if err != nil {
73 return nil, err
74 }
75 return provider.LanguageModel(t.Context(), model)
76 }
77}
78
79func openRouterBuilder(model string) builderFunc {
80 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
81 provider, err := openrouter.New(
82 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
83 openrouter.WithHTTPClient(&http.Client{Transport: r}),
84 )
85 if err != nil {
86 return nil, err
87 }
88 return provider.LanguageModel(t.Context(), model)
89 }
90}
91
92func zAIBuilder(model string) builderFunc {
93 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
94 provider, err := openaicompat.New(
95 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
96 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
97 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
98 )
99 if err != nil {
100 return nil, err
101 }
102 return provider.LanguageModel(t.Context(), model)
103 }
104}
105
106func testEnv(t *testing.T) fakeEnv {
107 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
108 os.RemoveAll(workingDir)
109
110 err := os.MkdirAll(workingDir, 0o755)
111 require.NoError(t, err)
112
113 conn, err := db.Connect(t.Context(), t.TempDir())
114 require.NoError(t, err)
115
116 q := db.New(conn)
117 sessions := session.NewService(q, conn)
118 messages := message.NewService(q)
119
120 permissions := permission.NewPermissionService(workingDir, true, []string{})
121 history := history.NewService(q, conn)
122 filetrackerService := filetracker.NewService(q)
123 lspClients := csync.NewMap[string, *lsp.Client]()
124
125 t.Cleanup(func() {
126 conn.Close()
127 os.RemoveAll(workingDir)
128 })
129
130 return fakeEnv{
131 workingDir,
132 sessions,
133 messages,
134 permissions,
135 history,
136 &filetrackerService,
137 lspClients,
138 }
139}
140
141func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
142 largeModel := Model{
143 Model: large,
144 CatwalkCfg: catwalk.Model{
145 ContextWindow: 200000,
146 DefaultMaxTokens: 10000,
147 },
148 }
149 smallModel := Model{
150 Model: small,
151 CatwalkCfg: catwalk.Model{
152 ContextWindow: 200000,
153 DefaultMaxTokens: 10000,
154 },
155 }
156 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
157 return agent
158}
159
160func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
161 fixedTime := func() time.Time {
162 t, _ := time.Parse("1/2/2006", "1/1/2025")
163 return t
164 }
165 prompt, err := coderPrompt(
166 prompt.WithTimeFunc(fixedTime),
167 prompt.WithPlatform("linux"),
168 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
169 )
170 if err != nil {
171 return nil, err
172 }
173 cfg, err := config.Init(env.workingDir, "", false)
174 if err != nil {
175 return nil, err
176 }
177
178 // NOTE(@andreynering): Set a fixed config to ensure cassettes match
179 // independently of user config on `$HOME/.config/crush/crush.json`.
180 cfg.Options.Attribution = &config.Attribution{
181 TrailerStyle: "co-authored-by",
182 GeneratedWith: true,
183 }
184
185 // Clear skills paths to ensure test reproducibility - user's skills
186 // would be included in prompt and break VCR cassette matching.
187 cfg.Options.SkillsPaths = []string{}
188
189 // Clear LSP config to ensure test reproducibility - user's LSP config
190 // would be included in prompt and break VCR cassette matching.
191 cfg.LSP = nil
192
193 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
194 if err != nil {
195 return nil, err
196 }
197
198 // Get the model name for the bash tool
199 modelName := large.Model() // fallback to ID if Name not available
200 if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
201 modelName = model.Name
202 }
203
204 allTools := []fantasy.AgentTool{
205 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
206 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
207 tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
208 tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
209 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
210 tools.NewGlobTool(env.workingDir),
211 tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
212 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
213 tools.NewSourcegraphTool(r.GetDefaultClient()),
214 tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
215 tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
216 }
217
218 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
219}
220
221// createSimpleGoProject creates a simple Go project structure in the given directory.
222// It creates a go.mod file and a main.go file with a basic hello world program.
223func createSimpleGoProject(t *testing.T, dir string) {
224 goMod := `module example.com/testproject
225
226go 1.23
227`
228 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
229 require.NoError(t, err)
230
231 mainGo := `package main
232
233import "fmt"
234
235func main() {
236 fmt.Println("Hello, World!")
237}
238`
239 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
240 require.NoError(t, err)
241}