1package agent
2
3import (
4 "fmt"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/agent/prompt"
13 "github.com/charmbracelet/crush/internal/agent/tools"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/db"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/session"
22 "github.com/charmbracelet/fantasy/ai"
23 "github.com/charmbracelet/fantasy/anthropic"
24 "github.com/charmbracelet/fantasy/openai"
25 "github.com/charmbracelet/fantasy/openrouter"
26 "github.com/stretchr/testify/require"
27 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
28
29 _ "github.com/joho/godotenv/autoload"
30)
31
32type env struct {
33 workingDir string
34 sessions session.Service
35 messages message.Service
36 permissions permission.Service
37 history history.Service
38 lspClients *csync.Map[string, *lsp.Client]
39}
40
41type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
42
43type modelPair struct {
44 name string
45 largeModel builderFunc
46 smallModel builderFunc
47}
48
// anthropicBuilder returns a builderFunc that creates the given Anthropic
// model. The API key comes from CRUSH_ANTHROPIC_API_KEY, and every HTTP
// request goes through the recorder r.
func anthropicBuilder(model string) builderFunc {
	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
		apiKey := os.Getenv("CRUSH_ANTHROPIC_API_KEY")
		client := &http.Client{Transport: r}
		return anthropic.New(
			anthropic.WithAPIKey(apiKey),
			anthropic.WithHTTPClient(client),
		).LanguageModel(model)
	}
}
58
// openaiBuilder returns a builderFunc that creates the given OpenAI model.
// The API key comes from CRUSH_OPENAI_API_KEY, and every HTTP request goes
// through the recorder r.
func openaiBuilder(model string) builderFunc {
	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
		apiKey := os.Getenv("CRUSH_OPENAI_API_KEY")
		client := &http.Client{Transport: r}
		return openai.New(
			openai.WithAPIKey(apiKey),
			openai.WithHTTPClient(client),
		).LanguageModel(model)
	}
}
68
// openRouterBuilder returns a builderFunc that creates the given OpenRouter
// model. The API key comes from CRUSH_OPENROUTER_API_KEY, and every HTTP
// request goes through the recorder r. Tool call IDs are unique, and
// generated IDs take the form "<test name>-<n>", with n counting up from 1
// for each built model.
func openRouterBuilder(model string) builderFunc {
	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
		// A counter rather than random IDs. Presumably this keeps recorded
		// interactions reproducible across runs; confirm before changing.
		counter := 0
		nextID := func() string {
			counter++
			return fmt.Sprintf("%s-%d", t.Name(), counter)
		}

		apiKey := os.Getenv("CRUSH_OPENROUTER_API_KEY")
		client := &http.Client{Transport: r}
		return openrouter.New(
			openrouter.WithAPIKey(apiKey),
			openrouter.WithHTTPClient(client),
			openrouter.WithLanguageUniqueToolCallIds(),
			openrouter.WithLanguageModelGenerateIDFunc(nextID),
		).LanguageModel(model)
	}
}
87
// testEnv prepares an isolated environment for a single test:
//   - a fresh working directory under /tmp/crush-test/<test name>, removed
//     when the test finishes;
//   - a database in t.TempDir(), closed when the test finishes;
//   - session, message, permission and history services backed by that
//     database, plus an empty LSP client map.
func testEnv(t *testing.T) env {
	// The working directory is derived from the test name instead of
	// t.TempDir(). It is passed to the coder prompt, so the path is part of
	// the generated system prompt. NOTE(review): presumably it must stay
	// stable so recorded cassettes keep matching; confirm before changing.
	testDir := filepath.Join("/tmp/crush-test/", t.Name())
	// Clear leftovers from an earlier run. RemoveAll returns nil when the
	// path does not exist, so only real failures stop the test.
	require.NoError(t, os.RemoveAll(testDir))
	err := os.MkdirAll(testDir, 0o755)
	// Registered before checking err so that a partially created tree is
	// still cleaned up.
	t.Cleanup(func() {
		// Best-effort: the next run removes any leftovers before starting.
		_ = os.RemoveAll(testDir)
	})
	require.NoError(t, err)

	conn, err := db.Connect(t.Context(), t.TempDir())
	require.NoError(t, err)
	// Cleanups run in LIFO order. This one is registered after t.TempDir's,
	// so the database is closed before its directory is deleted.
	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Errorf("closing test database: %v", err)
		}
	})

	q := db.New(conn)
	// Keyed fields are robust to reordering env. Building the services
	// inline also avoids a local named "history" that would shadow the
	// imported package.
	return env{
		workingDir:  testDir,
		sessions:    session.NewService(q),
		messages:    message.NewService(q),
		permissions: permission.NewPermissionService(testDir, true, []string{}),
		history:     history.NewService(q, conn),
		lspClients:  csync.NewMap[string, *lsp.Client](),
	}
}
114
115func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
116 largeModel := Model{
117 model: large,
118 config: catwalk.Model{
119 // todo: add values
120 },
121 }
122 smallModel := Model{
123 model: small,
124 config: catwalk.Model{
125 // todo: add values
126 },
127 }
128 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
129 return agent
130}
131
// coderAgent builds a SessionAgent from the coder system prompt and the full
// set of file, shell and network tools.
//
// The system prompt is generated with a fixed clock (2025-01-01 UTC), the
// "linux" platform and env.workingDir, so it is identical on every run.
// Tools that make HTTP requests use clients from r.GetDefaultClient(), so
// their traffic goes through the recorder.
//
// Errors from prompt creation, config initialization or prompt building are
// returned wrapped with context.
func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
	// time.Date is identical to the previous time.Parse("1/2/2006", "1/1/2025")
	// (Parse yields UTC when no zone is given) but cannot fail.
	fixedTime := func() time.Time {
		return time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)
	}
	// Named p (not prompt) so the imported prompt package is not shadowed.
	p, err := coderPrompt(
		prompt.WithTimeFunc(fixedTime),
		prompt.WithPlatform("linux"),
		prompt.WithWorkingDir(env.workingDir),
	)
	if err != nil {
		return nil, fmt.Errorf("creating coder prompt: %w", err)
	}
	cfg, err := config.Init(env.workingDir, "", false)
	if err != nil {
		return nil, fmt.Errorf("initializing config in %q: %w", env.workingDir, err)
	}

	systemPrompt, err := p.Build(large.Provider(), large.Model(), *cfg)
	if err != nil {
		return nil, fmt.Errorf("building system prompt: %w", err)
	}
	// Each network tool gets its own client, as before, so no tool can
	// affect another's client settings.
	allTools := []ai.AgentTool{
		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewGlobTool(env.workingDir),
		tools.NewGrepTool(env.workingDir),
		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
		tools.NewSourcegraphTool(r.GetDefaultClient()),
		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
	}

	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
}
170
// createSimpleGoProject creates a minimal Go module in dir. It writes a
// go.mod declaring module example.com/testproject (go 1.23) and a main.go
// that prints "Hello, World!". Files are written with mode 0644. The test
// fails immediately if either write fails.
func createSimpleGoProject(t *testing.T, dir string) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	goMod := `module example.com/testproject

go 1.23
`
	// filepath.Join builds OS-correct paths instead of hard-coding "/".
	err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0o644)
	require.NoError(t, err)

	mainGo := `package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`
	err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(mainGo), 0o644)
	require.NoError(t, err)
}