1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "charm.land/x/vcr"
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18 "github.com/charmbracelet/crush/internal/agent/prompt"
19 "github.com/charmbracelet/crush/internal/agent/tools"
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/db"
23 "github.com/charmbracelet/crush/internal/history"
24 "github.com/charmbracelet/crush/internal/lsp"
25 "github.com/charmbracelet/crush/internal/message"
26 "github.com/charmbracelet/crush/internal/permission"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/stretchr/testify/require"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// fakeEnv is an environment for testing.
//
// It bundles the working directory and the services that a test agent and
// its tools are wired against.
type fakeEnv struct {
	// workingDir is the directory the tools operate in.
	workingDir string
	// sessions is the session store handed to the session agent.
	sessions session.Service
	// messages is the message store handed to the session agent.
	messages message.Service
	// permissions is passed to the tools that perform permission checks.
	permissions permission.Service
	// history is the file-history service used by the edit/multi-edit/write tools.
	history history.Service
	// lspClients holds LSP clients keyed by string (presumably the server
	// name — TODO confirm); testEnv creates it empty.
	lspClients *csync.Map[string, *lsp.Client]
}
42
43type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
44
45type modelPair struct {
46 name string
47 largeModel builderFunc
48 smallModel builderFunc
49}
50
51func anthropicBuilder(model string) builderFunc {
52 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
53 provider, err := anthropic.New(
54 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
55 anthropic.WithHTTPClient(&http.Client{Transport: r}),
56 )
57 if err != nil {
58 return nil, err
59 }
60 return provider.LanguageModel(t.Context(), model)
61 }
62}
63
64func openaiBuilder(model string) builderFunc {
65 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
66 provider, err := openai.New(
67 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
68 openai.WithHTTPClient(&http.Client{Transport: r}),
69 )
70 if err != nil {
71 return nil, err
72 }
73 return provider.LanguageModel(t.Context(), model)
74 }
75}
76
77func openRouterBuilder(model string) builderFunc {
78 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
79 provider, err := openrouter.New(
80 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
81 openrouter.WithHTTPClient(&http.Client{Transport: r}),
82 )
83 if err != nil {
84 return nil, err
85 }
86 return provider.LanguageModel(t.Context(), model)
87 }
88}
89
90func zAIBuilder(model string) builderFunc {
91 return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
92 provider, err := openaicompat.New(
93 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
94 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
95 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
96 )
97 if err != nil {
98 return nil, err
99 }
100 return provider.LanguageModel(t.Context(), model)
101 }
102}
103
104func testEnv(t *testing.T) fakeEnv {
105 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
106 os.RemoveAll(workingDir)
107
108 err := os.MkdirAll(workingDir, 0o755)
109 require.NoError(t, err)
110
111 conn, err := db.Connect(t.Context(), t.TempDir())
112 require.NoError(t, err)
113
114 q := db.New(conn)
115 sessions := session.NewService(q)
116 messages := message.NewService(q)
117
118 permissions := permission.NewPermissionService(workingDir, true, []string{})
119 history := history.NewService(q, conn)
120 lspClients := csync.NewMap[string, *lsp.Client]()
121
122 t.Cleanup(func() {
123 conn.Close()
124 os.RemoveAll(workingDir)
125 })
126
127 return fakeEnv{
128 workingDir,
129 sessions,
130 messages,
131 permissions,
132 history,
133 lspClients,
134 }
135}
136
137func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
138 largeModel := Model{
139 Model: large,
140 CatwalkCfg: catwalk.Model{
141 ContextWindow: 200000,
142 DefaultMaxTokens: 10000,
143 },
144 }
145 smallModel := Model{
146 Model: small,
147 CatwalkCfg: catwalk.Model{
148 ContextWindow: 200000,
149 DefaultMaxTokens: 10000,
150 },
151 }
152 agent := NewSessionAgent(SessionAgentOptions{
153 LargeModel: largeModel,
154 SmallModel: smallModel,
155 SystemPromptPrefix: "",
156 SystemPrompt: systemPrompt,
157 DisableAutoSummarize: false,
158 IsYolo: true,
159 IsSubAgent: false,
160 HooksManager: nil,
161 WorkingDir: "",
162 Sessions: env.sessions,
163 Messages: env.messages,
164 Tools: tools,
165 })
166 return agent
167}
168
169func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
170 fixedTime := func() time.Time {
171 t, _ := time.Parse("1/2/2006", "1/1/2025")
172 return t
173 }
174 prompt, err := coderPrompt(
175 prompt.WithTimeFunc(fixedTime),
176 prompt.WithPlatform("linux"),
177 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
178 )
179 if err != nil {
180 return nil, err
181 }
182 cfg, err := config.Init(env.workingDir, "", false)
183 if err != nil {
184 return nil, err
185 }
186
187 // NOTE(@andreynering): Set a fixed config to ensure cassettes match
188 // independently of user config on `$HOME/.config/crush/crush.json`.
189 cfg.Options.Attribution = &config.Attribution{
190 TrailerStyle: "co-authored-by",
191 GeneratedWith: true,
192 }
193
194 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
195 if err != nil {
196 return nil, err
197 }
198
199 // Get the model name for the bash tool
200 modelName := large.Model() // fallback to ID if Name not available
201 if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
202 modelName = model.Name
203 }
204
205 allTools := []fantasy.AgentTool{
206 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
207 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
208 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
209 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
210 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
211 tools.NewGlobTool(env.workingDir),
212 tools.NewGrepTool(env.workingDir),
213 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
214 tools.NewSourcegraphTool(r.GetDefaultClient()),
215 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
216 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
217 }
218
219 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
220}
221
222// createSimpleGoProject creates a simple Go project structure in the given directory.
223// It creates a go.mod file and a main.go file with a basic hello world program.
224func createSimpleGoProject(t *testing.T, dir string) {
225 goMod := `module example.com/testproject
226
227go 1.23
228`
229 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
230 require.NoError(t, err)
231
232 mainGo := `package main
233
234import "fmt"
235
236func main() {
237 fmt.Println("Hello, World!")
238}
239`
240 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
241 require.NoError(t, err)
242}