1package agent
2
3import (
4 "fmt"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/agent/prompt"
13 "github.com/charmbracelet/crush/internal/agent/tools"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/db"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/session"
22 "github.com/charmbracelet/fantasy/ai"
23 "github.com/charmbracelet/fantasy/anthropic"
24 "github.com/charmbracelet/fantasy/openai"
25 "github.com/charmbracelet/fantasy/openaicompat"
26 "github.com/charmbracelet/fantasy/openrouter"
27 "github.com/stretchr/testify/require"
28 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
29
30 _ "github.com/joho/godotenv/autoload"
31)
32
// env bundles the per-test dependencies that the agent helpers in this file
// wire together; testEnv constructs one for each test.
type env struct {
	// workingDir is the directory handed to the coder prompt and to the tools
	// built in coderAgent.
	workingDir string
	// sessions and messages are the services the SessionAgent is built on
	// (see testSessionAgent).
	sessions session.Service
	messages message.Service
	// permissions is passed to most of the tools built in coderAgent.
	permissions permission.Service
	// history is handed to the edit, multi-edit and write tools.
	history history.Service
	// lspClients is shared with the edit, multi-edit, view and write tools;
	// testEnv creates it empty.
	lspClients *csync.Map[string, *lsp.Client]
}
41
// builderFunc constructs the language model under test. The builders in this
// file route the provider's HTTP traffic through r (a go-vcr recorder), so
// the interactions can be recorded and replayed.
type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
43
44type modelPair struct {
45 name string
46 largeModel builderFunc
47 smallModel builderFunc
48}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
52 provider := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 return provider.LanguageModel(model)
57 }
58}
59
60func openaiBuilder(model string) builderFunc {
61 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
62 provider := openai.New(
63 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
64 openai.WithHTTPClient(&http.Client{Transport: r}),
65 )
66 return provider.LanguageModel(model)
67 }
68}
69
70func openRouterBuilder(model string) builderFunc {
71 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
72 tf := func() func() string {
73 id := 0
74 return func() string {
75 id += 1
76 return fmt.Sprintf("%s-%d", t.Name(), id)
77 }
78 }
79 provider := openrouter.New(
80 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
81 openrouter.WithHTTPClient(&http.Client{Transport: r}),
82 openrouter.WithLanguageUniqueToolCallIds(),
83 openrouter.WithLanguageModelGenerateIDFunc(tf()),
84 )
85 return provider.LanguageModel(model)
86 }
87}
88
89func zAIBuilder(model string) builderFunc {
90 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
91 tf := func() func() string {
92 id := 0
93 return func() string {
94 id += 1
95 return fmt.Sprintf("%s-%d", t.Name(), id)
96 }
97 }
98 provider := openaicompat.New(
99 "https://api.z.ai/api/coding/paas/v4",
100 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
101 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
102 openaicompat.WithLanguageUniqueToolCallIds(),
103 openaicompat.WithLanguageModelGenerateIDFunc(tf()),
104 )
105 return provider.LanguageModel(model)
106 }
107}
108
109func testEnv(t *testing.T) env {
110 testDir := filepath.Join("/tmp/crush-test/", t.Name())
111 os.RemoveAll(testDir)
112 err := os.MkdirAll(testDir, 0o755)
113 t.Cleanup(func() {
114 os.RemoveAll(testDir)
115 })
116 require.NoError(t, err)
117 workingDir := testDir
118 conn, err := db.Connect(t.Context(), t.TempDir())
119 require.NoError(t, err)
120 q := db.New(conn)
121 sessions := session.NewService(q)
122 messages := message.NewService(q)
123 permissions := permission.NewPermissionService(workingDir, true, []string{})
124 history := history.NewService(q, conn)
125 lspClients := csync.NewMap[string, *lsp.Client]()
126 return env{
127 workingDir,
128 sessions,
129 messages,
130 permissions,
131 history,
132 lspClients,
133 }
134}
135
136func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
137 largeModel := Model{
138 Model: large,
139 CatwalkCfg: catwalk.Model{
140 // todo: add values
141 },
142 }
143 smallModel := Model{
144 Model: small,
145 CatwalkCfg: catwalk.Model{
146 // todo: add values
147 },
148 }
149 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
150 return agent
151}
152
153func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
154 fixedTime := func() time.Time {
155 t, _ := time.Parse("1/2/2006", "1/1/2025")
156 return t
157 }
158 prompt, err := coderPrompt(
159 prompt.WithTimeFunc(fixedTime),
160 prompt.WithPlatform("linux"),
161 prompt.WithWorkingDir(env.workingDir),
162 )
163 if err != nil {
164 return nil, err
165 }
166 cfg, err := config.Init(env.workingDir, "", false)
167 if err != nil {
168 return nil, err
169 }
170
171 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
172 if err != nil {
173 return nil, err
174 }
175 allTools := []ai.AgentTool{
176 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
177 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
178 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
179 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
180 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
181 tools.NewGlobTool(env.workingDir),
182 tools.NewGrepTool(env.workingDir),
183 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
184 tools.NewSourcegraphTool(r.GetDefaultClient()),
185 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
186 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
187 }
188
189 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
190}
191
192// createSimpleGoProject creates a simple Go project structure in the given directory.
193// It creates a go.mod file and a main.go file with a basic hello world program.
194func createSimpleGoProject(t *testing.T, dir string) {
195 goMod := `module example.com/testproject
196
197go 1.23
198`
199 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
200 require.NoError(t, err)
201
202 mainGo := `package main
203
204import "fmt"
205
206func main() {
207 fmt.Println("Hello, World!")
208}
209`
210 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
211 require.NoError(t, err)
212}