1package agent
2
3import (
4 "net/http"
5 "os"
6 "path/filepath"
7 "testing"
8 "time"
9
10 "github.com/charmbracelet/catwalk/pkg/catwalk"
11 "github.com/charmbracelet/crush/internal/agent/prompt"
12 "github.com/charmbracelet/crush/internal/agent/tools"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/db"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/lsp"
18 "github.com/charmbracelet/crush/internal/message"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/session"
21 "github.com/charmbracelet/fantasy/ai"
22 "github.com/charmbracelet/fantasy/anthropic"
23 "github.com/charmbracelet/fantasy/openai"
24 "github.com/charmbracelet/fantasy/openaicompat"
25 "github.com/charmbracelet/fantasy/openrouter"
26 "github.com/stretchr/testify/require"
27 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
28
29 _ "github.com/joho/godotenv/autoload"
30)
31
32type env struct {
33 workingDir string
34 sessions session.Service
35 messages message.Service
36 permissions permission.Service
37 history history.Service
38 lspClients *csync.Map[string, *lsp.Client]
39}
40
41type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
42
43type modelPair struct {
44 name string
45 largeModel builderFunc
46 smallModel builderFunc
47}
48
49func anthropicBuilder(model string) builderFunc {
50 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
51 provider := anthropic.New(
52 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
53 anthropic.WithHTTPClient(&http.Client{Transport: r}),
54 )
55 return provider.LanguageModel(model)
56 }
57}
58
59func openaiBuilder(model string) builderFunc {
60 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
61 provider := openai.New(
62 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
63 openai.WithHTTPClient(&http.Client{Transport: r}),
64 )
65 return provider.LanguageModel(model)
66 }
67}
68
69func openRouterBuilder(model string) builderFunc {
70 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
71 provider := openrouter.New(
72 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
73 openrouter.WithHTTPClient(&http.Client{Transport: r}),
74 )
75 return provider.LanguageModel(model)
76 }
77}
78
79func zAIBuilder(model string) builderFunc {
80 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
81 provider := openaicompat.New(
82 "https://api.z.ai/api/coding/paas/v4",
83 openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
84 openaicompat.WithHTTPClient(&http.Client{Transport: r}),
85 )
86 return provider.LanguageModel(model)
87 }
88}
89
90func testEnv(t *testing.T) env {
91 testDir := filepath.Join("/tmp/crush-test/", t.Name())
92 os.RemoveAll(testDir)
93 err := os.MkdirAll(testDir, 0o755)
94 t.Cleanup(func() {
95 os.RemoveAll(testDir)
96 })
97 require.NoError(t, err)
98 workingDir := testDir
99 conn, err := db.Connect(t.Context(), t.TempDir())
100 require.NoError(t, err)
101 q := db.New(conn)
102 sessions := session.NewService(q)
103 messages := message.NewService(q)
104 permissions := permission.NewPermissionService(workingDir, true, []string{})
105 history := history.NewService(q, conn)
106 lspClients := csync.NewMap[string, *lsp.Client]()
107 return env{
108 workingDir,
109 sessions,
110 messages,
111 permissions,
112 history,
113 lspClients,
114 }
115}
116
117func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
118 largeModel := Model{
119 Model: large,
120 CatwalkCfg: catwalk.Model{
121 // todo: add values
122 },
123 }
124 smallModel := Model{
125 Model: small,
126 CatwalkCfg: catwalk.Model{
127 // todo: add values
128 },
129 }
130 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
131 return agent
132}
133
134func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
135 fixedTime := func() time.Time {
136 t, _ := time.Parse("1/2/2006", "1/1/2025")
137 return t
138 }
139 prompt, err := coderPrompt(
140 prompt.WithTimeFunc(fixedTime),
141 prompt.WithPlatform("linux"),
142 prompt.WithWorkingDir(env.workingDir),
143 )
144 if err != nil {
145 return nil, err
146 }
147 cfg, err := config.Init(env.workingDir, "", false)
148 if err != nil {
149 return nil, err
150 }
151
152 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
153 if err != nil {
154 return nil, err
155 }
156 allTools := []ai.AgentTool{
157 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
158 tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
159 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
160 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
161 tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
162 tools.NewGlobTool(env.workingDir),
163 tools.NewGrepTool(env.workingDir),
164 tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
165 tools.NewSourcegraphTool(r.GetDefaultClient()),
166 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
167 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
168 }
169
170 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
171}
172
173// createSimpleGoProject creates a simple Go project structure in the given directory.
174// It creates a go.mod file and a main.go file with a basic hello world program.
175func createSimpleGoProject(t *testing.T, dir string) {
176 goMod := `module example.com/testproject
177
178go 1.23
179`
180 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
181 require.NoError(t, err)
182
183 mainGo := `package main
184
185import "fmt"
186
187func main() {
188 fmt.Println("Hello, World!")
189}
190`
191 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
192 require.NoError(t, err)
193}