1package agent
2
3import (
4 "fmt"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "github.com/charmbracelet/catwalk/pkg/catwalk"
12 "github.com/charmbracelet/crush/internal/agent/prompt"
13 "github.com/charmbracelet/crush/internal/agent/tools"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/db"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/session"
22 "github.com/charmbracelet/fantasy/ai"
23 "github.com/charmbracelet/fantasy/anthropic"
24 "github.com/charmbracelet/fantasy/openai"
25 "github.com/charmbracelet/fantasy/openrouter"
26 "github.com/stretchr/testify/require"
27 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
28
29 _ "github.com/joho/godotenv/autoload"
30)
31
// env bundles the per-test dependencies used by the agent tests: the
// working directory handed to the tools and the services backing sessions,
// messages, permissions, file history and LSP clients. testEnv builds one
// per test.
type env struct {
	// workingDir is the root directory passed to the tool constructors and
	// to config.Init in coderAgent.
	workingDir string
	// sessions and messages are passed to NewSessionAgent.
	sessions session.Service
	messages message.Service
	// permissions and history are consumed by the file/shell tools.
	permissions permission.Service
	history     history.Service
	// lspClients is shared with the edit/view/write tools; testEnv leaves it
	// empty. NOTE(review): keys presumably identify the LSP server — confirm.
	lspClients *csync.Map[string, *lsp.Client]
}
40
41type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
42
43type modelPair struct {
44 name string
45 largeModel builderFunc
46 smallModel builderFunc
47}
48
49func anthropicBuilder(model string) builderFunc {
50 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
51 provider := anthropic.New(
52 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
53 anthropic.WithHTTPClient(&http.Client{Transport: r}),
54 )
55 return provider.LanguageModel(model)
56 }
57}
58
59func openaiBuilder(model string) builderFunc {
60 return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
61 provider := openai.New(
62 openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
63 openai.WithHTTPClient(&http.Client{Transport: r}),
64 )
65 return provider.LanguageModel(model)
66 }
67}
68
69func openRouterBuilder(model string) builderFunc {
70 return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
71 tf := func() func() string {
72 id := 0
73 return func() string {
74 id += 1
75 return fmt.Sprintf("%s-%d", t.Name(), id)
76 }
77 }
78 provider := openrouter.New(
79 openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
80 openrouter.WithHTTPClient(&http.Client{Transport: r}),
81 openrouter.WithLanguageUniqueToolCallIds(),
82 openrouter.WithLanguageModelGenerateIDFunc(tf()),
83 )
84 return provider.LanguageModel(model)
85 }
86}
87
88func testEnv(t *testing.T) env {
89 testDir := filepath.Join("/tmp/crush-test/", t.Name())
90 os.RemoveAll(testDir)
91 err := os.MkdirAll(testDir, 0o755)
92 t.Cleanup(func() {
93 os.RemoveAll(testDir)
94 })
95 require.NoError(t, err)
96 workingDir := testDir
97 conn, err := db.Connect(t.Context(), t.TempDir())
98 require.NoError(t, err)
99 q := db.New(conn)
100 sessions := session.NewService(q)
101 messages := message.NewService(q)
102 permissions := permission.NewPermissionService(workingDir, true, []string{})
103 history := history.NewService(q, conn)
104 lspClients := csync.NewMap[string, *lsp.Client]()
105 return env{
106 workingDir,
107 sessions,
108 messages,
109 permissions,
110 history,
111 lspClients,
112 }
113}
114
115func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
116 largeModel := Model{
117 model: large,
118 config: catwalk.Model{
119 // todo: add values
120 },
121 }
122 smallModel := Model{
123 model: small,
124 config: catwalk.Model{
125 // todo: add values
126 },
127 }
128 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
129 return agent
130}
131
132func coderAgent(env env, large, small ai.LanguageModel) (SessionAgent, error) {
133 fixedTime := func() time.Time {
134 t, _ := time.Parse("1/2/2006", "1/1/2025")
135 return t
136 }
137 prompt, err := coderPrompt(prompt.WithTimeFunc(fixedTime))
138 if err != nil {
139 return nil, err
140 }
141 cfg, err := config.Init(env.workingDir, "", false)
142 if err != nil {
143 return nil, err
144 }
145
146 systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
147 if err != nil {
148 return nil, err
149 }
150 allTools := []ai.AgentTool{
151 tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
152 tools.NewDownloadTool(env.permissions, env.workingDir),
153 tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
154 tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
155 tools.NewFetchTool(env.permissions, env.workingDir),
156 tools.NewGlobTool(env.workingDir),
157 tools.NewGrepTool(env.workingDir),
158 tools.NewLsTool(env.permissions, env.workingDir),
159 tools.NewSourcegraphTool(),
160 tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
161 tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
162 }
163
164 return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
165}
166
167// createSimpleGoProject creates a simple Go project structure in the given directory.
168// It creates a go.mod file and a main.go file with a basic hello world program.
169func createSimpleGoProject(t *testing.T, dir string) {
170 goMod := `module example.com/testproject
171
172go 1.23
173`
174 err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
175 require.NoError(t, err)
176
177 mainGo := `package main
178
179import "fmt"
180
181func main() {
182 fmt.Println("Hello, World!")
183}
184`
185 err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
186 require.NoError(t, err)
187}