1package agent
2
3import (
4 "context"
5 "net/http"
6 "os"
7 "path/filepath"
8 "testing"
9 "time"
10
11 "charm.land/catwalk/pkg/catwalk"
12 "charm.land/fantasy"
13 "charm.land/fantasy/providers/openaicompat"
14 "charm.land/x/vcr"
15 "github.com/charmbracelet/crush/internal/agent/prompt"
16 "github.com/charmbracelet/crush/internal/agent/tools"
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/db"
20 "github.com/charmbracelet/crush/internal/filetracker"
21 "github.com/charmbracelet/crush/internal/history"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/stretchr/testify/require"
27
28 _ "github.com/joho/godotenv/autoload"
29)
30
// fakeEnv is an environment for testing.
//
// It bundles the working directory and services that the helpers in this
// file wire into session agents and tools. Field order matters for any
// positional (unkeyed) composite literal of this type; prefer keyed literals.
type fakeEnv struct {
	// workingDir is the directory tools and permissions are rooted at.
	workingDir  string
	// sessions is the session service backed by the test database.
	sessions    session.Service
	// messages is the message service backed by the test database.
	messages    message.Service
	// permissions is the permission service rooted at workingDir.
	permissions permission.Service
	// history is the file history service handed to editing tools.
	history     history.Service
	// filetracker is stored as a pointer; tool constructors receive the
	// dereferenced value.
	filetracker *filetracker.Service
	// lspClients holds LSP clients keyed by string (presumably server name —
	// verify against callers).
	lspClients  *csync.Map[string, *lsp.Client]
}
41
42type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
43
44type modelPair struct {
45 name string
46 largeModel builderFunc
47 smallModel builderFunc
48}
49
// hyperBuilder returns a builderFunc that creates the named model on the
// Hyper OpenAI-compatible endpoint (https://hyper.charm.land/v1). The builder
// authenticates with the CRUSH_HYPER_API_KEY environment variable and sends
// every HTTP request through the VCR recorder it is given.
func hyperBuilder(model string) builderFunc {
	build := func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
		apiKey := os.Getenv("CRUSH_HYPER_API_KEY")
		client := &http.Client{Transport: r}

		p, err := openaicompat.New(
			openaicompat.WithBaseURL("https://hyper.charm.land/v1"),
			openaicompat.WithAPIKey(apiKey),
			openaicompat.WithHTTPClient(client),
		)
		if err != nil {
			return nil, err
		}
		return p.LanguageModel(t.Context(), model)
	}
	return build
}
63
// testEnv builds a fakeEnv backed by a fresh database (in t.TempDir()) and a
// clean working directory. Both the database connection and the working
// directory are torn down via t.Cleanup.
//
// The working directory lives at a fixed path under /tmp/crush-test/, keyed by
// the test name, rather than in t.TempDir(). The path is passed into the
// system prompt (see coderAgent), so presumably it must stay stable for
// recorded VCR cassettes to match. NOTE(review): this path is Unix-specific.
func testEnv(t *testing.T) fakeEnv {
	t.Helper()

	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
	// Start from an empty directory in case a previous run left files behind.
	// RemoveAll returns nil when the path does not exist, so any error here is real.
	require.NoError(t, os.RemoveAll(workingDir))
	require.NoError(t, os.MkdirAll(workingDir, 0o755))

	conn, err := db.Connect(t.Context(), t.TempDir())
	require.NoError(t, err)

	q := db.New(conn)
	sessions := session.NewService(q, conn)
	messages := message.NewService(q)

	permissions := permission.NewPermissionService(workingDir, true, []string{})
	// Named historySvc so it does not shadow the history package.
	historySvc := history.NewService(q, conn)
	filetrackerService := filetracker.NewService(q)
	lspClients := csync.NewMap[string, *lsp.Client]()

	t.Cleanup(func() {
		// Best-effort teardown: failures here should not fail the test.
		_ = conn.Close()
		_ = os.RemoveAll(workingDir)
	})

	return fakeEnv{
		workingDir:  workingDir,
		sessions:    sessions,
		messages:    messages,
		permissions: permissions,
		history:     historySvc,
		filetracker: &filetrackerService,
		lspClients:  lspClients,
	}
}
98
99func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
100 largeModel := Model{
101 Model: large,
102 CatwalkCfg: catwalk.Model{
103 ContextWindow: 200000,
104 DefaultMaxTokens: 10000,
105 },
106 }
107 smallModel := Model{
108 Model: small,
109 CatwalkCfg: catwalk.Model{
110 ContextWindow: 200000,
111 DefaultMaxTokens: 10000,
112 },
113 }
114 agent := NewSessionAgent(SessionAgentOptions{
115 LargeModel: largeModel,
116 SmallModel: smallModel,
117 SystemPrompt: systemPrompt,
118 IsYolo: true,
119 Sessions: env.sessions,
120 Messages: env.messages,
121 Tools: tools,
122 })
123 return agent
124}
125
// coderAgent builds a SessionAgent from the coder prompt and the standard tool
// set. It pins inputs that end up in model requests so recorded VCR cassettes
// match:
//   - a fixed clock (2025-01-01 00:00 UTC);
//   - platform "linux";
//   - a fixed attribution config;
//   - cleared skills paths, context paths and LSP config.
//
// r supplies the HTTP client for the download, fetch and sourcegraph tools.
// env supplies the working directory and services. An error is returned if
// the prompt template, config, or system prompt cannot be built.
func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
	// Same instant that time.Parse("1/2/2006", "1/1/2025") yields, but without
	// a silently discarded parse error.
	fixedNow := time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)
	fixedTime := func() time.Time { return fixedNow }

	// Named coderTmpl rather than "prompt" so it does not shadow the prompt
	// package for the remainder of the function.
	coderTmpl, err := coderPrompt(
		prompt.WithTimeFunc(fixedTime),
		prompt.WithPlatform("linux"),
		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
	)
	if err != nil {
		return nil, err
	}
	cfg, err := config.Init(env.workingDir, "", false)
	if err != nil {
		return nil, err
	}

	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
	// independently of user config on `$HOME/.config/crush/crush.json`.
	cfg.Config().Options.Attribution = &config.Attribution{
		TrailerStyle:  "co-authored-by",
		GeneratedWith: true,
	}

	// Clear some fields to avoid issues with VCR cassette matching.
	cfg.Config().Options.SkillsPaths = nil
	cfg.Config().Options.DisabledSkills = []string{"crush-config"}
	cfg.Config().Options.ContextPaths = nil
	cfg.Config().LSP = nil

	systemPrompt, err := coderTmpl.Build(context.TODO(), large.Provider(), large.Model(), cfg)
	if err != nil {
		return nil, err
	}

	// Model name passed to the bash tool. Use the configured display name when
	// one exists; otherwise fall back to the model ID. This covers both an
	// unknown model and a known model with an empty name.
	modelName := large.Model()
	if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil && model.Name != "" {
		modelName = model.Name
	}

	allTools := []fantasy.AgentTool{
		tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName),
		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
		tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewGlobTool(env.workingDir),
		tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep),
		tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls),
		tools.NewSourcegraphTool(r.GetDefaultClient()),
		tools.NewViewTool(nil, env.permissions, *env.filetracker, nil, env.workingDir),
		tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
	}

	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
}
184
// createSimpleGoProject creates a minimal Go module in dir. It writes a go.mod
// declaring module example.com/testproject (go 1.23) and a main.go containing
// a hello-world program. The test fails immediately if either file cannot be
// written.
func createSimpleGoProject(t *testing.T, dir string) {
	// Report write failures at the caller's line, not inside this helper.
	t.Helper()

	goMod := `module example.com/testproject

go 1.23
`
	err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0o644)
	require.NoError(t, err)

	mainGo := `package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`
	err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(mainGo), 0o644)
	require.NoError(t, err)
}