common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/catwalk/pkg/catwalk"
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/providers/anthropic"
 14	"charm.land/fantasy/providers/openai"
 15	"charm.land/fantasy/providers/openaicompat"
 16	"charm.land/fantasy/providers/openrouter"
 17	"charm.land/x/vcr"
 18	"github.com/charmbracelet/crush/internal/agent/prompt"
 19	"github.com/charmbracelet/crush/internal/agent/tools"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/db"
 23	"github.com/charmbracelet/crush/internal/filetracker"
 24	"github.com/charmbracelet/crush/internal/history"
 25	"github.com/charmbracelet/crush/internal/lsp"
 26	"github.com/charmbracelet/crush/internal/message"
 27	"github.com/charmbracelet/crush/internal/permission"
 28	"github.com/charmbracelet/crush/internal/session"
 29	"github.com/stretchr/testify/require"
 30
 31	_ "github.com/joho/godotenv/autoload"
 32)
 33
// fakeEnv is an environment for testing: the working directory and the
// services a test agent and its tools are wired to. Instances are built by
// testEnv.
//
// NOTE: keep the field order stable if any caller constructs this with an
// unkeyed (positional) composite literal.
type fakeEnv struct {
	// workingDir is the directory the agent's tools operate in; coderAgent
	// also renders it into the system prompt.
	workingDir string
	// sessions and messages are handed to NewSessionAgent by testSessionAgent.
	sessions session.Service
	messages message.Service
	// permissions is passed to every tool that takes a permission service.
	permissions permission.Service
	// history is passed to the file-editing tools (edit, multiedit, write).
	history history.Service
	// filetracker is stored as a pointer and dereferenced when passed to the
	// file tools in coderAgent.
	filetracker *filetracker.Service
	// lspClients is created empty by testEnv; no helper in this file reads it.
	lspClients *csync.Map[string, *lsp.Client]
}
 44
// builderFunc constructs a language model for test t. The builders in this
// file route the provider's HTTP traffic through the VCR recorder r.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 46
// modelPair groups the builders for a large and a small model under one
// label. NOTE(review): presumably used to run the same test across several
// providers — confirm against callers.
type modelPair struct {
	// name labels the pair (presumably used as a subtest name — confirm).
	name string
	// largeModel and smallModel build the two models of the pair.
	largeModel builderFunc
	smallModel builderFunc
}
 52
 53func anthropicBuilder(model string) builderFunc {
 54	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 55		provider, err := anthropic.New(
 56			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 57			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 58		)
 59		if err != nil {
 60			return nil, err
 61		}
 62		return provider.LanguageModel(t.Context(), model)
 63	}
 64}
 65
 66func openaiBuilder(model string) builderFunc {
 67	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 68		provider, err := openai.New(
 69			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 70			openai.WithHTTPClient(&http.Client{Transport: r}),
 71		)
 72		if err != nil {
 73			return nil, err
 74		}
 75		return provider.LanguageModel(t.Context(), model)
 76	}
 77}
 78
 79func openRouterBuilder(model string) builderFunc {
 80	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 81		provider, err := openrouter.New(
 82			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 83			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 84		)
 85		if err != nil {
 86			return nil, err
 87		}
 88		return provider.LanguageModel(t.Context(), model)
 89	}
 90}
 91
 92func zAIBuilder(model string) builderFunc {
 93	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 94		provider, err := openaicompat.New(
 95			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 96			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 97			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 98		)
 99		if err != nil {
100			return nil, err
101		}
102		return provider.LanguageModel(t.Context(), model)
103	}
104}
105
106func testEnv(t *testing.T) fakeEnv {
107	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
108	os.RemoveAll(workingDir)
109
110	err := os.MkdirAll(workingDir, 0o755)
111	require.NoError(t, err)
112
113	conn, err := db.Connect(t.Context(), t.TempDir())
114	require.NoError(t, err)
115
116	q := db.New(conn)
117	sessions := session.NewService(q, conn)
118	messages := message.NewService(q)
119
120	permissions := permission.NewPermissionService(workingDir, true, []string{})
121	history := history.NewService(q, conn)
122	filetrackerService := filetracker.NewService(q)
123	lspClients := csync.NewMap[string, *lsp.Client]()
124
125	t.Cleanup(func() {
126		conn.Close()
127		os.RemoveAll(workingDir)
128	})
129
130	return fakeEnv{
131		workingDir,
132		sessions,
133		messages,
134		permissions,
135		history,
136		&filetrackerService,
137		lspClients,
138	}
139}
140
141func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
142	largeModel := Model{
143		Model: large,
144		CatwalkCfg: catwalk.Model{
145			ContextWindow:    200000,
146			DefaultMaxTokens: 10000,
147		},
148	}
149	smallModel := Model{
150		Model: small,
151		CatwalkCfg: catwalk.Model{
152			ContextWindow:    200000,
153			DefaultMaxTokens: 10000,
154		},
155	}
156	agent := NewSessionAgent(SessionAgentOptions{
157		LargeModel:   largeModel,
158		SmallModel:   smallModel,
159		SystemPrompt: systemPrompt,
160		IsYolo:       true,
161		Sessions:     env.sessions,
162		Messages:     env.messages,
163		Tools:        tools,
164	})
165	return agent
166}
167
168func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
169	fixedTime := func() time.Time {
170		t, _ := time.Parse("1/2/2006", "1/1/2025")
171		return t
172	}
173	prompt, err := coderPrompt(
174		prompt.WithTimeFunc(fixedTime),
175		prompt.WithPlatform("linux"),
176		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
177	)
178	if err != nil {
179		return nil, err
180	}
181	cfg, err := config.Init(env.workingDir, "", false)
182	if err != nil {
183		return nil, err
184	}
185
186	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
187	// independently of user config on `$HOME/.config/crush/crush.json`.
188	cfg.Options.Attribution = &config.Attribution{
189		TrailerStyle:  "co-authored-by",
190		GeneratedWith: true,
191	}
192
193	// Clear some fields to avoid issues with VCR cassette matching.
194	cfg.Options.SkillsPaths = nil
195	cfg.Options.ContextPaths = nil
196	cfg.LSP = nil
197
198	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
199	if err != nil {
200		return nil, err
201	}
202
203	// Get the model name for the bash tool
204	modelName := large.Model() // fallback to ID if Name not available
205	if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
206		modelName = model.Name
207	}
208
209	allTools := []fantasy.AgentTool{
210		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
211		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
212		tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
213		tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
214		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
215		tools.NewGlobTool(env.workingDir),
216		tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
217		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
218		tools.NewSourcegraphTool(r.GetDefaultClient()),
219		tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
220		tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
221	}
222
223	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
224}
225
226// createSimpleGoProject creates a simple Go project structure in the given directory.
227// It creates a go.mod file and a main.go file with a basic hello world program.
228func createSimpleGoProject(t *testing.T, dir string) {
229	goMod := `module example.com/testproject
230
231go 1.23
232`
233	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
234	require.NoError(t, err)
235
236	mainGo := `package main
237
238import "fmt"
239
240func main() {
241	fmt.Println("Hello, World!")
242}
243`
244	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
245	require.NoError(t, err)
246}