1package agent
2
3import (
4 "net/http"
5 "os"
6 "path/filepath"
7 "testing"
8 "time"
9
10 "charm.land/fantasy"
11 "charm.land/fantasy/providers/anthropic"
12 "charm.land/fantasy/providers/openai"
13 "charm.land/fantasy/providers/openaicompat"
14 "charm.land/fantasy/providers/openrouter"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/db"
21 "github.com/charmbracelet/crush/internal/history"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/stretchr/testify/require"
27 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
28
29 _ "github.com/joho/godotenv/autoload"
30)
31
// env bundles the working directory and services that one agent test needs.
// It is produced by testEnv.
//
// NOTE: testEnv builds this struct with a positional composite literal, so the
// field order below must stay in sync with that constructor.
type env struct {
	workingDir  string                          // directory the agent and its tools operate in
	sessions    session.Service                 // backed by the per-test database
	messages    message.Service                 // backed by the per-test database
	permissions permission.Service              // permission service rooted at workingDir
	history     history.Service                 // file history backed by the per-test database
	lspClients  *csync.Map[string, *lsp.Client] // created empty by testEnv
}
40
// builderFunc constructs a language model for a test. The builders in this file
// route the provider's HTTP traffic through the given VCR recorder and use the
// test's context when resolving the model.
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
42
// modelPair describes a named combination of a large and a small model, each
// supplied as a builderFunc so it can be constructed against a recorder.
type modelPair struct {
	name       string // presumably labels the subtest/cassette — confirm at call sites
	largeModel builderFunc
	smallModel builderFunc
}
48
49func anthropicBuilder(model string) builderFunc {
50 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
51 provider, err := anthropic.New(
52 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
53 anthropic.WithHTTPClient(&http.Client{Transport: r}),
54 )
55 if err != nil {
56 return nil, err
57 }
58 return provider.LanguageModel(t.Context(), model)
59 }
60}
61
62func openaiBuilder(model string) builderFunc {
63 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
64 provider, err := openai.New(
65 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
66 openai.WithHTTPClient(&http.Client{Transport: r}),
67 )
68 if err != nil {
69 return nil, err
70 }
71 return provider.LanguageModel(t.Context(), model)
72 }
73}
74
75func openRouterBuilder(model string) builderFunc {
76 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
77 provider, err := openrouter.New(
78 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
79 openrouter.WithHTTPClient(&http.Client{Transport: r}),
80 )
81 if err != nil {
82 return nil, err
83 }
84 return provider.LanguageModel(t.Context(), model)
85 }
86}
87
88func zAIBuilder(model string) builderFunc {
89 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
90 provider, err := openaicompat.New(
91 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
92 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
93 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
94 )
95 if err != nil {
96 return nil, err
97 }
98 return provider.LanguageModel(t.Context(), model)
99 }
100}
101
102func testEnv(t *testing.T) env {
103 testDir := filepath.Join("/tmp/crush-test/", t.Name())
104 os.RemoveAll(testDir)
105 err := os.MkdirAll(testDir, 0o755)
106 t.Cleanup(func() {
107 os.RemoveAll(testDir)
108 })
109 require.NoError(t, err)
110 workingDir := testDir
111 conn, err := db.Connect(t.Context(), t.TempDir())
112 require.NoError(t, err)
113 q := db.New(conn)
114 sessions := session.NewService(q)
115 messages := message.NewService(q)
116 permissions := permission.NewPermissionService(workingDir, true, []string{})
117 history := history.NewService(q, conn)
118 lspClients := csync.NewMap[string, *lsp.Client]()
119 return env{
120 workingDir,
121 sessions,
122 messages,
123 permissions,
124 history,
125 lspClients,
126 }
127}
128
129func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
130 largeModel := Model{
131 Model: large,
132 CatwalkCfg: catwalk.Model{
133 ContextWindow: 200000,
134 DefaultMaxTokens: 10000,
135 },
136 }
137 smallModel := Model{
138 Model: small,
139 CatwalkCfg: catwalk.Model{
140 ContextWindow: 200000,
141 DefaultMaxTokens: 10000,
142 },
143 }
144 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
145 return agent
146}
147
148func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
149 fixedTime := func() time.Time {
150 t, _ := time.Parse("1/2/2006", "1/1/2025")
151 return t
152 }
153 prompt, err := coderPrompt(
154 prompt.WithTimeFunc(fixedTime),
155 prompt.WithPlatform("linux"),
156 prompt.WithWorkingDir(env.workingDir),
157 )
158 if err != nil {
159 return nil, err
160 }
161 cfg, err := config.Init(env.workingDir, "", false)
162 if err != nil {
163 return nil, err
164 }
165
166 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
167 if err != nil {
168 return nil, err
169 }
170 allTools := []fantasy.AgentTool{
171 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
172 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
173 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
174 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
175 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
176 tools.NewGlobTool(env.workingDir),
177 tools.NewGrepTool(env.workingDir),
178 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
179 tools.NewSourcegraphTool(r.GetDefaultClient()),
180 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
181 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
182 }
183
184 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
185}
186
187// createSimpleGoProject creates a simple Go project structure in the given directory.
188// It creates a go.mod file and a main.go file with a basic hello world program.
189func createSimpleGoProject(t *testing.T, dir string) {
190 goMod := `module example.com/testproject
191
192go 1.23
193`
194 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
195 require.NoError(t, err)
196
197 mainGo := `package main
198
199import "fmt"
200
201func main() {
202 fmt.Println("Hello, World!")
203}
204`
205 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
206 require.NoError(t, err)
207}