1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/fantasy"
12 "charm.land/fantasy/providers/anthropic"
13 "charm.land/fantasy/providers/openai"
14 "charm.land/fantasy/providers/openaicompat"
15 "charm.land/fantasy/providers/openrouter"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/prompt"
18 "github.com/charmbracelet/crush/internal/agent/tools"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/db"
22 "github.com/charmbracelet/crush/internal/history"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// env bundles everything a test needs to drive an agent: the directory the
// test works in, the session, message, permission and file-history services,
// and the map of LSP clients keyed by string.
//
// testEnv fills this struct with a positional literal, so the field order
// here must stay in sync with that literal.
type env struct {
	workingDir  string
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]
}
41
// builderFunc constructs a fantasy.LanguageModel for a test. It receives the
// test handle and a recorder; the builders in this file install the recorder
// as the HTTP transport of the provider's client and use t.Context() when
// resolving the model.
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
43
// modelPair groups two model builders under a name: one builder for the
// large model and one for the small model.
// NOTE(review): presumably used as a table-test entry by the tests in this
// package — the consumers are not visible here.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
52 provider, err := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 if err != nil {
57 return nil, err
58 }
59 return provider.LanguageModel(t.Context(), model)
60 }
61}
62
63func openaiBuilder(model string) builderFunc {
64 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
65 provider, err := openai.New(
66 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
67 openai.WithHTTPClient(&http.Client{Transport: r}),
68 )
69 if err != nil {
70 return nil, err
71 }
72 return provider.LanguageModel(t.Context(), model)
73 }
74}
75
76func openRouterBuilder(model string) builderFunc {
77 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
78 provider, err := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 )
82 if err != nil {
83 return nil, err
84 }
85 return provider.LanguageModel(t.Context(), model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
91 provider, err := openaicompat.New(
92 openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
93 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
94 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
95 )
96 if err != nil {
97 return nil, err
98 }
99 return provider.LanguageModel(t.Context(), model)
100 }
101}
102
103func testEnv(t *testing.T) env {
104 testDir := filepath.Join("/tmp/crush-test/", t.Name())
105 os.RemoveAll(testDir)
106 err := os.MkdirAll(testDir, 0o755)
107 require.NoError(t, err)
108 workingDir := testDir
109 conn, err := db.Connect(t.Context(), t.TempDir())
110 require.NoError(t, err)
111 q := db.New(conn)
112 sessions := session.NewService(q)
113 messages := message.NewService(q)
114 permissions := permission.NewPermissionService(workingDir, true, []string{})
115 history := history.NewService(q, conn)
116 lspClients := csync.NewMap[string, *lsp.Client]()
117
118 t.Cleanup(func() {
119 conn.Close()
120 os.RemoveAll(testDir)
121 })
122
123 return env{
124 workingDir,
125 sessions,
126 messages,
127 permissions,
128 history,
129 lspClients,
130 }
131}
132
133func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
134 largeModel := Model{
135 Model: large,
136 CatwalkCfg: catwalk.Model{
137 ContextWindow: 200000,
138 DefaultMaxTokens: 10000,
139 },
140 }
141 smallModel := Model{
142 Model: small,
143 CatwalkCfg: catwalk.Model{
144 ContextWindow: 200000,
145 DefaultMaxTokens: 10000,
146 },
147 }
148 agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
149 return agent
150}
151
152func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
153 fixedTime := func() time.Time {
154 t, _ := time.Parse("1/2/2006", "1/1/2025")
155 return t
156 }
157 prompt, err := coderPrompt(
158 prompt.WithTimeFunc(fixedTime),
159 prompt.WithPlatform("linux"),
160 prompt.WithWorkingDir(env.workingDir),
161 )
162 if err != nil {
163 return nil, err
164 }
165 cfg, err := config.Init(env.workingDir, "", false)
166 if err != nil {
167 return nil, err
168 }
169
170 systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
171 if err != nil {
172 return nil, err
173 }
174 allTools := []fantasy.AgentTool{
175 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
176 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
177 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
178 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
179 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
180 tools.NewGlobTool(env.workingDir),
181 tools.NewGrepTool(env.workingDir),
182 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
183 tools.NewSourcegraphTool(r.GetDefaultClient()),
184 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
185 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
186 }
187
188 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
189}
190
191// createSimpleGoProject creates a simple Go project structure in the given directory.
192// It creates a go.mod file and a main.go file with a basic hello world program.
193func createSimpleGoProject(t *testing.T, dir string) {
194 goMod := `module example.com/testproject
195
196go 1.23
197`
198 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
199 require.NoError(t, err)
200
201 mainGo := `package main
202
203import "fmt"
204
205func main() {
206 fmt.Println("Hello, World!")
207}
208`
209 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
210 require.NoError(t, err)
211}