common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/catwalk/pkg/catwalk"
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/providers/anthropic"
 14	"charm.land/fantasy/providers/openai"
 15	"charm.land/fantasy/providers/openaicompat"
 16	"charm.land/fantasy/providers/openrouter"
 17	"charm.land/x/vcr"
 18	"github.com/charmbracelet/crush/internal/agent/prompt"
 19	"github.com/charmbracelet/crush/internal/agent/tools"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/db"
 23	"github.com/charmbracelet/crush/internal/filetracker"
 24	"github.com/charmbracelet/crush/internal/history"
 25	"github.com/charmbracelet/crush/internal/lsp"
 26	"github.com/charmbracelet/crush/internal/message"
 27	"github.com/charmbracelet/crush/internal/permission"
 28	"github.com/charmbracelet/crush/internal/session"
 29	"github.com/stretchr/testify/require"
 30
 31	_ "github.com/joho/godotenv/autoload"
 32)
 33
 34// fakeEnv is an environment for testing.
 35type fakeEnv struct {
 36	workingDir  string
 37	sessions    session.Service
 38	messages    message.Service
 39	permissions permission.Service
 40	history     history.Service
 41	filetracker *filetracker.Service
 42	lspClients  *csync.Map[string, *lsp.Client]
 43}
 44
 45type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 46
 47type modelPair struct {
 48	name       string
 49	largeModel builderFunc
 50	smallModel builderFunc
 51}
 52
 53func anthropicBuilder(model string) builderFunc {
 54	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 55		provider, err := anthropic.New(
 56			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 57			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 58		)
 59		if err != nil {
 60			return nil, err
 61		}
 62		return provider.LanguageModel(t.Context(), model)
 63	}
 64}
 65
 66func openaiBuilder(model string) builderFunc {
 67	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 68		provider, err := openai.New(
 69			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 70			openai.WithHTTPClient(&http.Client{Transport: r}),
 71		)
 72		if err != nil {
 73			return nil, err
 74		}
 75		return provider.LanguageModel(t.Context(), model)
 76	}
 77}
 78
 79func openRouterBuilder(model string) builderFunc {
 80	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 81		provider, err := openrouter.New(
 82			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 83			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 84		)
 85		if err != nil {
 86			return nil, err
 87		}
 88		return provider.LanguageModel(t.Context(), model)
 89	}
 90}
 91
 92func zAIBuilder(model string) builderFunc {
 93	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 94		provider, err := openaicompat.New(
 95			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 96			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 97			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 98		)
 99		if err != nil {
100			return nil, err
101		}
102		return provider.LanguageModel(t.Context(), model)
103	}
104}
105
// testEnv creates a fresh fakeEnv for t.
//
// The working directory is /tmp/crush-test/<test name>. It is wiped before use
// so files from an earlier run cannot leak into this one, and it is removed
// again in t.Cleanup. All services share one database connection opened in
// t.TempDir(), which is closed in t.Cleanup.
//
// NOTE(review): the fixed path (instead of t.TempDir()) is presumably
// deliberate: coderAgent embeds the working directory in the system prompt,
// which has to stay stable for recorded VCR cassettes to match.
func testEnv(t *testing.T) fakeEnv {
	t.Helper()

	workingDir := filepath.Join("/tmp/crush-test/", t.Name())

	// RemoveAll returns nil when the path does not exist, so any error here
	// means stale state could survive into this test; fail instead.
	require.NoError(t, os.RemoveAll(workingDir))
	require.NoError(t, os.MkdirAll(workingDir, 0o755))

	conn, err := db.Connect(t.Context(), t.TempDir())
	require.NoError(t, err)

	q := db.New(conn)
	sessions := session.NewService(q, conn)
	messages := message.NewService(q)
	permissions := permission.NewPermissionService(workingDir, true, []string{})
	// Named hist so the history package is not shadowed.
	hist := history.NewService(q, conn)
	filetrackerService := filetracker.NewService(q)
	lspClients := csync.NewMap[string, *lsp.Client]()

	t.Cleanup(func() {
		// Best-effort teardown: the test outcome is already decided here.
		conn.Close()
		os.RemoveAll(workingDir)
	})

	return fakeEnv{
		workingDir:  workingDir,
		sessions:    sessions,
		messages:    messages,
		permissions: permissions,
		history:     hist,
		filetracker: &filetrackerService,
		lspClients:  lspClients,
	}
}
140
141func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
142	largeModel := Model{
143		Model: large,
144		CatwalkCfg: catwalk.Model{
145			ContextWindow:    200000,
146			DefaultMaxTokens: 10000,
147		},
148	}
149	smallModel := Model{
150		Model: small,
151		CatwalkCfg: catwalk.Model{
152			ContextWindow:    200000,
153			DefaultMaxTokens: 10000,
154		},
155	}
156	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
157	return agent
158}
159
// coderAgent builds a coder SessionAgent with the standard tool set for VCR
// tests.
//
// The system prompt is rendered deterministically:
//   - the clock is pinned to 2025-01-01 00:00 UTC;
//   - the platform is pinned to "linux";
//   - the working directory is env.workingDir with forward slashes.
//
// Attribution, skills paths and LSP settings are overridden so that user
// config does not change the prompt. Tools that make HTTP requests use r's
// default client (presumably so their traffic is covered by the cassette).
// Any error from rendering the prompt or loading the config is returned.
func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
	// Same instant time.Parse("1/2/2006", "1/1/2025") yields, with no error
	// to discard.
	fixedTime := func() time.Time {
		return time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)
	}
	// Named coder so the prompt package is not shadowed.
	coder, err := coderPrompt(
		prompt.WithTimeFunc(fixedTime),
		prompt.WithPlatform("linux"),
		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
	)
	if err != nil {
		return nil, err
	}
	cfg, err := config.Init(env.workingDir, "", false)
	if err != nil {
		return nil, err
	}

	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
	// independently of user config on `$HOME/.config/crush/crush.json`.
	cfg.Options.Attribution = &config.Attribution{
		TrailerStyle:  "co-authored-by",
		GeneratedWith: true,
	}

	// Clear skills paths to ensure test reproducibility - user's skills
	// would be included in prompt and break VCR cassette matching.
	cfg.Options.SkillsPaths = []string{}

	// Clear LSP config to ensure test reproducibility - user's LSP config
	// would be included in prompt and break VCR cassette matching.
	cfg.LSP = nil

	systemPrompt, err := coder.Build(context.Background(), large.Provider(), large.Model(), *cfg)
	if err != nil {
		return nil, err
	}

	// Model name for the bash tool: use the configured display name when
	// it is known and non-empty, otherwise fall back to the model ID.
	modelName := large.Model()
	if m := cfg.GetModel(large.Provider(), large.Model()); m != nil && m.Name != "" {
		modelName = m.Name
	}

	allTools := []fantasy.AgentTool{
		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
		tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
		tools.NewGlobTool(env.workingDir),
		tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
		tools.NewSourcegraphTool(r.GetDefaultClient()),
		tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
		tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
	}

	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
}
220
// createSimpleGoProject writes a minimal Go module into dir: a go.mod for
// module example.com/testproject (go 1.23) and a main.go whose main function
// prints "Hello, World!". Existing files with those names are overwritten.
//
// Paths are built with filepath.Join so the helper is OS-independent. A write
// failure aborts the calling test and names the offending path.
func createSimpleGoProject(t *testing.T, dir string) {
	t.Helper()

	goMod := `module example.com/testproject

go 1.23
`

	mainGo := `package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`

	for _, f := range []struct{ name, content string }{
		{"go.mod", goMod},
		{"main.go", mainGo},
	} {
		path := filepath.Join(dir, f.name)
		if err := os.WriteFile(path, []byte(f.content), 0o644); err != nil {
			t.Fatalf("writing %s: %v", path, err)
		}
	}
}