1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "git.secluded.site/crush/internal/agent/prompt"
17 "git.secluded.site/crush/internal/agent/tools"
18 "git.secluded.site/crush/internal/config"
19 "git.secluded.site/crush/internal/csync"
20 "git.secluded.site/crush/internal/db"
21 "git.secluded.site/crush/internal/history"
22 "git.secluded.site/crush/internal/lsp"
23 "git.secluded.site/crush/internal/message"
24 "git.secluded.site/crush/internal/permission"
25 "git.secluded.site/crush/internal/session"
26 "github.com/charmbracelet/catwalk/pkg/catwalk"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// env bundles the per-test dependencies that agents and tools are built from:
// a scratch working directory plus the session, message, permission, history
// and LSP-client services. Instances are produced by testEnv.
type env struct {
	// workingDir is the directory the tools are rooted at.
	workingDir string
	sessions   session.Service
	messages   message.Service
	// permissions is handed to the tools that take a permission service.
	permissions permission.Service
	// history is handed to the edit, multi-edit and write tools.
	history history.Service
	// lspClients is created empty by testEnv; tests add clients if needed.
	lspClients *csync.Map[string, *lsp.Client]
}
41
// builderFunc constructs a language model for a test. The provided go-vcr
// recorder is meant to be used as the HTTP transport, so provider API calls
// can be recorded and replayed.
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
43
// modelPair names a model combination under test together with the builders
// for its large and small models (the two models a SessionAgent is given).
type modelPair struct {
	// name identifies the combination; presumably used as a subtest name — confirm.
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
52 provider, err := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 if err != nil {
57 return nil, err
58 }
59 return provider.LanguageModel(t.Context(), model)
60 }
61}
62
63func openaiBuilder(model string) builderFunc {
64 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
65 provider, err := openai.New(
66 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
67 openai.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 if err != nil {
70 return nil, err
71 }
72 return provider.LanguageModel(t.Context(), model)
73 }
74}
75
76func openRouterBuilder(model string) builderFunc {
77 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
78 provider, err := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 )
82 if err != nil {
83 return nil, err
84 }
85 return provider.LanguageModel(t.Context(), model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
91 provider, err := openaicompat.New(
92 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
93 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
94 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
95 )
96 if err != nil {
97 return nil, err
98 }
99 return provider.LanguageModel(t.Context(), model)
100 }
101}
102
103func testEnv(t *testing.T) env {
104 workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105 os.RemoveAll(workingDir)
106
107 err := os.MkdirAll(workingDir, 0o755)
108 require.NoError(t, err)
109
110 conn, err := db.Connect(t.Context(), t.TempDir())
111 require.NoError(t, err)
112
113 q := db.New(conn)
114 sessions := session.NewService(q)
115 messages := message.NewService(q)
116
117 permissions := permission.NewPermissionService(t.Context(), workingDir, true, []string{}, nil)
118 history := history.NewService(q, conn)
119 lspClients := csync.NewMap[string, *lsp.Client]()
120
121 t.Cleanup(func() {
122 conn.Close()
123 os.RemoveAll(workingDir)
124 })
125
126 return env{
127 workingDir,
128 sessions,
129 messages,
130 permissions,
131 history,
132 lspClients,
133 }
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137 largeModel := Model{
138 Model: large,
139 CatwalkCfg: catwalk.Model{
140 ContextWindow: 200000,
141 DefaultMaxTokens: 10000,
142 },
143 }
144 smallModel := Model{
145 Model: small,
146 CatwalkCfg: catwalk.Model{
147 ContextWindow: 200000,
148 DefaultMaxTokens: 10000,
149 },
150 }
151 agent := NewSessionAgent(SessionAgentOptions{
152 LargeModel: largeModel,
153 SmallModel: smallModel,
154 SystemPromptPrefix: "",
155 SystemPrompt: systemPrompt,
156 DisableAutoSummarize: false,
157 IsYolo: true,
158 Sessions: env.sessions,
159 Messages: env.messages,
160 Tools: tools,
161 Notifier: nil,
162 NotificationCtx: context.Background(),
163 })
164 return agent
165}
166
167func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
168 fixedTime := func() time.Time {
169 t, _ := time.Parse("1/2/2006", "1/1/2025")
170 return t
171 }
172 prompt, err := coderPrompt(
173 prompt.WithTimeFunc(fixedTime),
174 prompt.WithPlatform("linux"),
175 prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
176 )
177 if err != nil {
178 return nil, err
179 }
180 cfg, err := config.Init(env.workingDir, "", false)
181 if err != nil {
182 return nil, err
183 }
184
185 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
186 if err != nil {
187 return nil, err
188 }
189 allTools := []fantasy.AgentTool{
190 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
191 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
192 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
193 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
194 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
195 tools.NewGlobTool(env.workingDir),
196 tools.NewGrepTool(env.workingDir),
197 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
198 tools.NewSourcegraphTool(r.GetDefaultClient()),
199 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
200 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
201 }
202
203 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
204}
205
206// createSimpleGoProject creates a simple Go project structure in the given directory.
207// It creates a go.mod file and a main.go file with a basic hello world program.
208func createSimpleGoProject(t *testing.T, dir string) {
209 goMod := `module example.com/testproject
210
211go 1.23
212`
213 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
214 require.NoError(t, err)
215
216 mainGo := `package main
217
218import "fmt"
219
220func main() {
221 fmt.Println("Hello, World!")
222}
223`
224 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
225 require.NoError(t, err)
226}