1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/prompt"
18 "github.com/charmbracelet/crush/internal/agent/tools"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/db"
22 "github.com/charmbracelet/crush/internal/history"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// env bundles the per-test dependencies used by the agent test helpers in
// this file (testSessionAgent, coderAgent and the tools it constructs).
type env struct {
	// workingDir is the directory passed to the tools and the permission
	// service as their working directory.
	workingDir string
	// sessions and messages are handed to the SessionAgent.
	sessions session.Service
	messages message.Service
	// permissions is consulted by the file/network tools built in coderAgent.
	permissions permission.Service
	// history is passed to the editing tools (edit, multiedit, write).
	history history.Service
	// lspClients is shared with the edit/view/write tools; keys are
	// presumably LSP server names — confirm against lsp package usage.
	lspClients *csync.Map[string, *lsp.Client]
}
41
42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
43
44type modelPair struct {
45 name string
46 largeModel builderFunc
47 smallModel builderFunc
48}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
52 provider, err := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 if err != nil {
57 return nil, err
58 }
59 return provider.LanguageModel(t.Context(), model)
60 }
61}
62
63func openaiBuilder(model string) builderFunc {
64 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
65 provider, err := openai.New(
66 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
67 openai.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 if err != nil {
70 return nil, err
71 }
72 return provider.LanguageModel(t.Context(), model)
73 }
74}
75
76func openRouterBuilder(model string) builderFunc {
77 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
78 provider, err := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 )
82 if err != nil {
83 return nil, err
84 }
85 return provider.LanguageModel(t.Context(), model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
91 provider, err := openaicompat.New(
92 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
93 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
94 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
95 )
96 if err != nil {
97 return nil, err
98 }
99 return provider.LanguageModel(t.Context(), model)
100 }
101}
102
103func testEnv(t *testing.T) env {
104 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105 os.RemoveAll(workingDir)
106
107 err := os.MkdirAll(workingDir, 0o755)
108 require.NoError(t, err)
109
110 conn, err := db.Connect(t.Context(), t.TempDir())
111 require.NoError(t, err)
112
113 q := db.New(conn)
114 sessions := session.NewService(q)
115 messages := message.NewService(q)
116
117 permissions := permission.NewPermissionService(workingDir, true, []string{})
118 history := history.NewService(q, conn)
119 lspClients := csync.NewMap[string, *lsp.Client]()
120
121 t.Cleanup(func() {
122 conn.Close()
123 os.RemoveAll(workingDir)
124 })
125
126 return env{
127 workingDir,
128 sessions,
129 messages,
130 permissions,
131 history,
132 lspClients,
133 }
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137 largeModel := Model{
138 Model: large,
139 CatwalkCfg: catwalk.Model{
140 ContextWindow: 200000,
141 DefaultMaxTokens: 10000,
142 },
143 }
144 smallModel := Model{
145 Model: small,
146 CatwalkCfg: catwalk.Model{
147 ContextWindow: 200000,
148 DefaultMaxTokens: 10000,
149 },
150 }
151 agent := NewSessionAgent(SessionAgentOptions{
152 LargeModel: largeModel,
153 SmallModel: smallModel,
154 SystemPromptPrefix: "",
155 SystemPrompt: systemPrompt,
156 DisableAutoSummarize: false,
157 IsYolo: true,
158 Sessions: env.sessions,
159 Messages: env.messages,
160 Tools: tools,
161 Hooks: nil,
162 })
163 return agent
164}
165
166func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
167 fixedTime := func() time.Time {
168 t, _ := time.Parse("1/2/2006", "1/1/2025")
169 return t
170 }
171 prompt, err := coderPrompt(
172 prompt.WithTimeFunc(fixedTime),
173 prompt.WithPlatform("linux"),
174 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
175 )
176 if err != nil {
177 return nil, err
178 }
179 cfg, err := config.Init(env.workingDir, "", false)
180 if err != nil {
181 return nil, err
182 }
183
184 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
185 if err != nil {
186 return nil, err
187 }
188 allTools := []fantasy.AgentTool{
189 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
190 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
191 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
192 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
193 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
194 tools.NewGlobTool(env.workingDir),
195 tools.NewGrepTool(env.workingDir),
196 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
197 tools.NewSourcegraphTool(r.GetDefaultClient()),
198 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
199 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
200 }
201
202 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
203}
204
205// createSimpleGoProject creates a simple Go project structure in the given directory.
206// It creates a go.mod file and a main.go file with a basic hello world program.
207func createSimpleGoProject(t *testing.T, dir string) {
208 goMod := `module example.com/testproject
209
210go 1.23
211`
212 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
213 require.NoError(t, err)
214
215 mainGo := `package main
216
217import "fmt"
218
219func main() {
220 fmt.Println("Hello, World!")
221}
222`
223 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
224 require.NoError(t, err)
225}