1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/prompt"
18 "github.com/charmbracelet/crush/internal/agent/tools"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/db"
22 "github.com/charmbracelet/crush/internal/history"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
33type env struct {
34 workingDir string
35 sessions session.Service
36 messages message.Service
37 permissions permission.Service
38 history history.Service
39 lspClients *csync.Map[string, *lsp.Client]
40}
41
42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
43
44type modelPair struct {
45 name string
46 largeModel builderFunc
47 smallModel builderFunc
48}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
52 provider, err := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 if err != nil {
57 return nil, err
58 }
59 return provider.LanguageModel(t.Context(), model)
60 }
61}
62
63func openaiBuilder(model string) builderFunc {
64 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
65 provider, err := openai.New(
66 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
67 openai.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 if err != nil {
70 return nil, err
71 }
72 return provider.LanguageModel(t.Context(), model)
73 }
74}
75
76func openRouterBuilder(model string) builderFunc {
77 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
78 provider, err := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 )
82 if err != nil {
83 return nil, err
84 }
85 return provider.LanguageModel(t.Context(), model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
91 provider, err := openaicompat.New(
92 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
93 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
94 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
95 )
96 if err != nil {
97 return nil, err
98 }
99 return provider.LanguageModel(t.Context(), model)
100 }
101}
102
103func testEnv(t *testing.T) env {
104 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105 os.RemoveAll(workingDir)
106
107 err := os.MkdirAll(workingDir, 0o755)
108 require.NoError(t, err)
109
110 conn, err := db.Connect(t.Context(), t.TempDir())
111 require.NoError(t, err)
112
113 q := db.New(conn)
114 sessions := session.NewService(q)
115 messages := message.NewService(q)
116
117 permissions := permission.NewPermissionService(workingDir, true, []string{})
118 history := history.NewService(q, conn)
119 lspClients := csync.NewMap[string, *lsp.Client]()
120
121 t.Cleanup(func() {
122 conn.Close()
123 os.RemoveAll(workingDir)
124 })
125
126 return env{
127 workingDir,
128 sessions,
129 messages,
130 permissions,
131 history,
132 lspClients,
133 }
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137 largeModel := Model{
138 Model: large,
139 CatwalkCfg: catwalk.Model{
140 ContextWindow: 200000,
141 DefaultMaxTokens: 10000,
142 },
143 }
144 smallModel := Model{
145 Model: small,
146 CatwalkCfg: catwalk.Model{
147 ContextWindow: 200000,
148 DefaultMaxTokens: 10000,
149 },
150 }
151 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
152 return agent
153}
154
155func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
156 fixedTime := func() time.Time {
157 t, _ := time.Parse("1/2/2006", "1/1/2025")
158 return t
159 }
160 prompt, err := coderPrompt(
161 prompt.WithTimeFunc(fixedTime),
162 prompt.WithPlatform("linux"),
163 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
164 )
165 if err != nil {
166 return nil, err
167 }
168 cfg, err := config.Init(env.workingDir, "", false)
169 if err != nil {
170 return nil, err
171 }
172
173 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
174 if err != nil {
175 return nil, err
176 }
177 allTools := []fantasy.AgentTool{
178 tools.NewWebFetchTool(env.workingDir, r.GetDefaultClient()),
179 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
180 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
181 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
182 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
183 tools.NewGlobTool(env.workingDir),
184 tools.NewGrepTool(env.workingDir),
185 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
186 tools.NewSourcegraphTool(r.GetDefaultClient()),
187 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
188 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
189 }
190
191 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
192}
193
194// createSimpleGoProject creates a simple Go project structure in the given directory.
195// It creates a go.mod file and a main.go file with a basic hello world program.
196func createSimpleGoProject(t *testing.T, dir string) {
197 goMod := `module example.com/testproject
198
199go 1.23
200`
201 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
202 require.NoError(t, err)
203
204 mainGo := `package main
205
206import "fmt"
207
208func main() {
209 fmt.Println("Hello, World!")
210}
211`
212 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
213 require.NoError(t, err)
214}