1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/prompt"
18 "github.com/charmbracelet/crush/internal/agent/tools"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/db"
22 "github.com/charmbracelet/crush/internal/history"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
33type env struct {
34 workingDir string
35 sessions session.Service
36 messages message.Service
37 permissions permission.Service
38 history history.Service
39 lspClients *csync.Map[string, *lsp.Client]
40}
41
42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
43
44type modelPair struct {
45 name string
46 largeModel builderFunc
47 smallModel builderFunc
48}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
52 provider, err := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 if err != nil {
57 return nil, err
58 }
59 return provider.LanguageModel(t.Context(), model)
60 }
61}
62
63func openaiBuilder(model string) builderFunc {
64 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
65 provider, err := openai.New(
66 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
67 openai.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 if err != nil {
70 return nil, err
71 }
72 return provider.LanguageModel(t.Context(), model)
73 }
74}
75
76func openRouterBuilder(model string) builderFunc {
77 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
78 provider, err := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 )
82 if err != nil {
83 return nil, err
84 }
85 return provider.LanguageModel(t.Context(), model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
91 provider, err := openaicompat.New(
92 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
93 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
94 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
95 )
96 if err != nil {
97 return nil, err
98 }
99 return provider.LanguageModel(t.Context(), model)
100 }
101}
102
103func testEnv(t *testing.T) env {
104 testDir := filepath.Join("/tmp/crush-test/", t.Name())
105 os.RemoveAll(testDir)
106 err := os.MkdirAll(testDir, 0o755)
107 t.Cleanup(func() {
108 os.RemoveAll(testDir)
109 })
110 require.NoError(t, err)
111 workingDir := testDir
112 conn, err := db.Connect(t.Context(), t.TempDir())
113 require.NoError(t, err)
114 q := db.New(conn)
115 sessions := session.NewService(q)
116 messages := message.NewService(q)
117 permissions := permission.NewPermissionService(workingDir, true, []string{})
118 history := history.NewService(q, conn)
119 lspClients := csync.NewMap[string, *lsp.Client]()
120 return env{
121 workingDir,
122 sessions,
123 messages,
124 permissions,
125 history,
126 lspClients,
127 }
128}
129
130func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
131 largeModel := Model{
132 Model: large,
133 CatwalkCfg: catwalk.Model{
134 ContextWindow: 200000,
135 DefaultMaxTokens: 10000,
136 },
137 }
138 smallModel := Model{
139 Model: small,
140 CatwalkCfg: catwalk.Model{
141 ContextWindow: 200000,
142 DefaultMaxTokens: 10000,
143 },
144 }
145 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
146 return agent
147}
148
149func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
150 fixedTime := func() time.Time {
151 t, _ := time.Parse("1/2/2006", "1/1/2025")
152 return t
153 }
154 prompt, err := coderPrompt(
155 prompt.WithTimeFunc(fixedTime),
156 prompt.WithPlatform("linux"),
157 prompt.WithWorkingDir(env.workingDir),
158 )
159 if err != nil {
160 return nil, err
161 }
162 cfg, err := config.Init(env.workingDir, "", false)
163 if err != nil {
164 return nil, err
165 }
166
167 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
168 if err != nil {
169 return nil, err
170 }
171 allTools := []fantasy.AgentTool{
172 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
173 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
174 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
175 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
176 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
177 tools.NewGlobTool(env.workingDir),
178 tools.NewGrepTool(env.workingDir),
179 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
180 tools.NewSourcegraphTool(r.GetDefaultClient()),
181 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
182 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
183 }
184
185 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
186}
187
188// createSimpleGoProject creates a simple Go project structure in the given directory.
189// It creates a go.mod file and a main.go file with a basic hello world program.
190func createSimpleGoProject(t *testing.T, dir string) {
191 goMod := `module example.com/testproject
192
193go 1.23
194`
195 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
196 require.NoError(t, err)
197
198 mainGo := `package main
199
200import "fmt"
201
202func main() {
203 fmt.Println("Hello, World!")
204}
205`
206 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
207 require.NoError(t, err)
208}