1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "strings"
9 "testing"
10 "time"
11
12 "charm.land/fantasy"
13 "charm.land/fantasy/providers/anthropic"
14 "charm.land/fantasy/providers/openai"
15 "charm.land/fantasy/providers/openaicompat"
16 "charm.land/fantasy/providers/openrouter"
17 "github.com/charmbracelet/catwalk/pkg/catwalk"
18 "github.com/charmbracelet/crush/internal/agent/prompt"
19 "github.com/charmbracelet/crush/internal/agent/tools"
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/db"
23 "github.com/charmbracelet/crush/internal/history"
24 "github.com/charmbracelet/crush/internal/lsp"
25 "github.com/charmbracelet/crush/internal/message"
26 "github.com/charmbracelet/crush/internal/permission"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/stretchr/testify/require"
29 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
30
31 _ "github.com/joho/godotenv/autoload"
32)
33
// env bundles the services and state that agent tests share. Instances are
// produced by testEnv and consumed by testSessionAgent and coderAgent.
type env struct {
	// workingDir is the directory handed to the permission service and to
	// every file/shell tool the coder agent is given.
	workingDir string
	// sessions and messages are backed by the per-test database opened in testEnv.
	sessions session.Service
	messages message.Service
	// permissions is the permission service scoped to workingDir.
	permissions permission.Service
	// history is the history service wired to the same database as sessions/messages.
	history history.Service
	// lspClients is the LSP client registry passed to the edit/view/write tools.
	// testEnv creates it empty; presumably keyed by client/language name — TODO confirm.
	lspClients *csync.Map[string, *lsp.Client]
}
42
// builderFunc constructs a fantasy.LanguageModel for a test. The builders in
// this file route the provider's HTTP traffic through the supplied go-vcr
// recorder so interactions can be recorded and replayed.
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
44
// modelPair names a combination of a large and a small model builder that a
// test can run against.
type modelPair struct {
	// name identifies the pair; presumably used as a subtest name — TODO confirm.
	name string
	// largeModel builds the primary model.
	largeModel builderFunc
	// smallModel builds the secondary (small) model.
	smallModel builderFunc
}
50
51func anthropicBuilder(model string) builderFunc {
52 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
53 provider, err := anthropic.New(
54 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
55 anthropic.WithHTTPClient(&http.Client{Transport: r}),
56 )
57 if err != nil {
58 return nil, err
59 }
60 return provider.LanguageModel(t.Context(), model)
61 }
62}
63
64func openaiBuilder(model string) builderFunc {
65 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
66 provider, err := openai.New(
67 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
68 openai.WithHTTPClient(&http.Client{Transport: r}),
69 )
70 if err != nil {
71 return nil, err
72 }
73 return provider.LanguageModel(t.Context(), model)
74 }
75}
76
77func openRouterBuilder(model string) builderFunc {
78 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
79 provider, err := openrouter.New(
80 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
81 openrouter.WithHTTPClient(&http.Client{Transport: r}),
82 )
83 if err != nil {
84 return nil, err
85 }
86 return provider.LanguageModel(t.Context(), model)
87 }
88}
89
90func zAIBuilder(model string) builderFunc {
91 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
92 provider, err := openaicompat.New(
93 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
94 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
95 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
96 )
97 if err != nil {
98 return nil, err
99 }
100 return provider.LanguageModel(t.Context(), model)
101 }
102}
103
104func testEnv(t *testing.T) env {
105 testDir := filepath.Join("/tmp/crush-test/", t.Name())
106 os.RemoveAll(testDir)
107 err := os.MkdirAll(testDir, 0o755)
108 require.NoError(t, err)
109 workingDir := testDir
110 conn, err := db.Connect(t.Context(), t.TempDir())
111 require.NoError(t, err)
112 q := db.New(conn)
113 sessions := session.NewService(q)
114 messages := message.NewService(q)
115 permissions := permission.NewPermissionService(workingDir, true, []string{})
116 history := history.NewService(q, conn)
117 lspClients := csync.NewMap[string, *lsp.Client]()
118
119 t.Cleanup(func() {
120 conn.Close()
121 os.RemoveAll(testDir)
122 })
123
124 return env{
125 workingDir,
126 sessions,
127 messages,
128 permissions,
129 history,
130 lspClients,
131 }
132}
133
134func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
135 largeModel := Model{
136 Model: large,
137 CatwalkCfg: catwalk.Model{
138 ContextWindow: 200000,
139 DefaultMaxTokens: 10000,
140 },
141 }
142 smallModel := Model{
143 Model: small,
144 CatwalkCfg: catwalk.Model{
145 ContextWindow: 200000,
146 DefaultMaxTokens: 10000,
147 },
148 }
149 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
150 return agent
151}
152
153func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
154 fixedTime := func() time.Time {
155 t, _ := time.Parse("1/2/2006", "1/1/2025")
156 return t
157 }
158 path := filepath.ToSlash(env.workingDir)
159 // try fixing windows
160 if strings.Contains(path, ":") {
161 path = strings.Split(path, ":")[1]
162 }
163 prompt, err := coderPrompt(
164 prompt.WithTimeFunc(fixedTime),
165 prompt.WithPlatform("linux"),
166 prompt.WithWorkingDir(path),
167 )
168 if err != nil {
169 return nil, err
170 }
171 cfg, err := config.Init(env.workingDir, "", false)
172 if err != nil {
173 return nil, err
174 }
175
176 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
177 if err != nil {
178 return nil, err
179 }
180 allTools := []fantasy.AgentTool{
181 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
182 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
183 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
184 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
185 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
186 tools.NewGlobTool(env.workingDir),
187 tools.NewGrepTool(env.workingDir),
188 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
189 tools.NewSourcegraphTool(r.GetDefaultClient()),
190 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
191 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
192 }
193
194 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
195}
196
197// createSimpleGoProject creates a simple Go project structure in the given directory.
198// It creates a go.mod file and a main.go file with a basic hello world program.
199func createSimpleGoProject(t *testing.T, dir string) {
200 goMod := `module example.com/testproject
201
202go 1.23
203`
204 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
205 require.NoError(t, err)
206
207 mainGo := `package main
208
209import "fmt"
210
211func main() {
212 fmt.Println("Hello, World!")
213}
214`
215 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
216 require.NoError(t, err)
217}