1package agent
2
3import (
4 "net/http"
5 "os"
6 "path/filepath"
7 "testing"
8 "time"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/agent/prompt"
12 "github.com/charmbracelet/crush/internal/agent/tools"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/db"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/lsp"
18 "github.com/charmbracelet/crush/internal/message"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/session"
21 "github.com/charmbracelet/fantasy/ai"
22 "github.com/charmbracelet/fantasy/anthropic"
23 "github.com/charmbracelet/fantasy/openai"
24 "github.com/charmbracelet/fantasy/openaicompat"
25 "github.com/charmbracelet/fantasy/openrouter"
26 "github.com/stretchr/testify/require"
27 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
28
29 _ "github.com/joho/godotenv/autoload"
30)
31
// env bundles the working directory and the services shared by the agent
// tests. Instances are produced by testEnv.
type env struct {
	// workingDir is the directory the agent, its prompt and its tools operate in.
	workingDir string
	// sessions and messages are the persistence services handed to the session agent.
	sessions session.Service
	messages message.Service
	// permissions is consulted by tools that need approval (bash, edit, fetch, ...).
	permissions permission.Service
	// history records file versions for the edit/write tools.
	history history.Service
	// lspClients is passed to the file tools; testEnv creates it empty.
	lspClients *csync.Map[string, *lsp.Client]
}
40
// builderFunc constructs a language model for a test. The builders in this file
// install r as the HTTP transport so provider traffic goes through the go-vcr
// recorder; t is available to builders that need it.
type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
42
43type modelPair struct {
44 name string
45 largeModel builderFunc
46 smallModel builderFunc
47}
48
49func anthropicBuilder(model string) builderFunc {
50 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
51 provider := anthropic.New(
52 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
53 anthropic.WithHTTPClient(&http.Client{Transport: r}),
54 )
55 return provider.LanguageModel(model)
56 }
57}
58
59func openaiBuilder(model string) builderFunc {
60 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
61 provider := openai.New(
62 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
63 openai.WithHTTPClient(&http.Client{Transport: r}),
64 )
65 return provider.LanguageModel(model)
66 }
67}
68
69func openRouterBuilder(model string) builderFunc {
70 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
71 provider := openrouter.New(
72 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
73 openrouter.WithHTTPClient(&http.Client{Transport: r}),
74 )
75 return provider.LanguageModel(model)
76 }
77}
78
79func zAIBuilder(model string) builderFunc {
80 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
81 provider := openaicompat.New(
82 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
83 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
84 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
85 )
86 return provider.LanguageModel(model)
87 }
88}
89
90func testEnv(t *testing.T) env {
91 testDir := filepath.Join("/tmp/crush-test/", t.Name())
92 os.RemoveAll(testDir)
93 err := os.MkdirAll(testDir, 0o755)
94 t.Cleanup(func() {
95 os.RemoveAll(testDir)
96 })
97 require.NoError(t, err)
98 workingDir := testDir
99 conn, err := db.Connect(t.Context(), t.TempDir())
100 require.NoError(t, err)
101 q := db.New(conn)
102 sessions := session.NewService(q)
103 messages := message.NewService(q)
104 permissions := permission.NewPermissionService(workingDir, true, []string{})
105 history := history.NewService(q, conn)
106 lspClients := csync.NewMap[string, *lsp.Client]()
107 return env{
108 workingDir,
109 sessions,
110 messages,
111 permissions,
112 history,
113 lspClients,
114 }
115}
116
117func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
118 largeModel := Model{
119 Model: large,
120 CatwalkCfg: catwalk.Model{
121 ContextWindow: 200000,
122 DefaultMaxTokens: 10000,
123 },
124 }
125 smallModel := Model{
126 Model: small,
127 CatwalkCfg: catwalk.Model{
128 ContextWindow: 200000,
129 DefaultMaxTokens: 10000,
130 },
131 }
132 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
133 return agent
134}
135
136func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
137 fixedTime := func() time.Time {
138 t, _ := time.Parse("1/2/2006", "1/1/2025")
139 return t
140 }
141 prompt, err := coderPrompt(
142 prompt.WithTimeFunc(fixedTime),
143 prompt.WithPlatform("linux"),
144 prompt.WithWorkingDir(env.workingDir),
145 )
146 if err != nil {
147 return nil, err
148 }
149 cfg, err := config.Init(env.workingDir, "", false)
150 if err != nil {
151 return nil, err
152 }
153
154 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
155 if err != nil {
156 return nil, err
157 }
158 allTools := []ai.AgentTool{
159 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
160 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
161 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
162 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
163 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
164 tools.NewGlobTool(env.workingDir),
165 tools.NewGrepTool(env.workingDir),
166 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
167 tools.NewSourcegraphTool(r.GetDefaultClient()),
168 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
169 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
170 }
171
172 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
173}
174
175// createSimpleGoProject creates a simple Go project structure in the given directory.
176// It creates a go.mod file and a main.go file with a basic hello world program.
177func createSimpleGoProject(t *testing.T, dir string) {
178 goMod := `module example.com/testproject
179
180go 1.23
181`
182 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
183 require.NoError(t, err)
184
185 mainGo := `package main
186
187import "fmt"
188
189func main() {
190 fmt.Println("Hello, World!")
191}
192`
193 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
194 require.NoError(t, err)
195}