common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"charm.land/x/vcr"
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18	"github.com/charmbracelet/crush/internal/agent/prompt"
 19	"github.com/charmbracelet/crush/internal/agent/tools"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/db"
 23	"github.com/charmbracelet/crush/internal/history"
 24	"github.com/charmbracelet/crush/internal/lsp"
 25	"github.com/charmbracelet/crush/internal/message"
 26	"github.com/charmbracelet/crush/internal/permission"
 27	"github.com/charmbracelet/crush/internal/session"
 28	"github.com/stretchr/testify/require"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
// fakeEnv is an environment for testing. It bundles everything testEnv sets
// up for a single test:
//
//   - workingDir: the per-test directory created under /tmp/crush-test.
//   - sessions, messages, history: services built on the test database
//     connection.
//   - permissions: the permission service scoped to workingDir.
//   - lspClients: the LSP client map, created empty by testEnv.
//
// NOTE: testEnv fills this struct with a positional literal, so the field
// order below must not change without updating it.
type fakeEnv struct {
	workingDir  string
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]
}
 42
// builderFunc constructs a fantasy.LanguageModel for the given test. The
// builders in this file hand the recorder to the provider as its HTTP
// transport and use the test's context to resolve the model.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 44
// modelPair groups a large-model builder and a small-model builder under a
// name.
// NOTE(review): presumably consumed by table-driven tests elsewhere in the
// package; no use is visible in this file.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
 50
 51func anthropicBuilder(model string) builderFunc {
 52	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 53		provider, err := anthropic.New(
 54			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 55			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 56		)
 57		if err != nil {
 58			return nil, err
 59		}
 60		return provider.LanguageModel(t.Context(), model)
 61	}
 62}
 63
 64func openaiBuilder(model string) builderFunc {
 65	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 66		provider, err := openai.New(
 67			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 68			openai.WithHTTPClient(&http.Client{Transport: r}),
 69		)
 70		if err != nil {
 71			return nil, err
 72		}
 73		return provider.LanguageModel(t.Context(), model)
 74	}
 75}
 76
 77func openRouterBuilder(model string) builderFunc {
 78	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 79		provider, err := openrouter.New(
 80			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 81			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 82		)
 83		if err != nil {
 84			return nil, err
 85		}
 86		return provider.LanguageModel(t.Context(), model)
 87	}
 88}
 89
 90func zAIBuilder(model string) builderFunc {
 91	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 92		provider, err := openaicompat.New(
 93			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 94			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 95			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 96		)
 97		if err != nil {
 98			return nil, err
 99		}
100		return provider.LanguageModel(t.Context(), model)
101	}
102}
103
104func testEnv(t *testing.T) fakeEnv {
105	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
106	os.RemoveAll(workingDir)
107
108	err := os.MkdirAll(workingDir, 0o755)
109	require.NoError(t, err)
110
111	conn, err := db.Connect(t.Context(), t.TempDir())
112	require.NoError(t, err)
113
114	q := db.New(conn)
115	sessions := session.NewService(q)
116	messages := message.NewService(q)
117
118	permissions := permission.NewPermissionService(workingDir, true, []string{})
119	history := history.NewService(q, conn)
120	lspClients := csync.NewMap[string, *lsp.Client]()
121
122	t.Cleanup(func() {
123		conn.Close()
124		os.RemoveAll(workingDir)
125	})
126
127	return fakeEnv{
128		workingDir,
129		sessions,
130		messages,
131		permissions,
132		history,
133		lspClients,
134	}
135}
136
137func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
138	largeModel := Model{
139		Model: large,
140		CatwalkCfg: catwalk.Model{
141			ContextWindow:    200000,
142			DefaultMaxTokens: 10000,
143		},
144	}
145	smallModel := Model{
146		Model: small,
147		CatwalkCfg: catwalk.Model{
148			ContextWindow:    200000,
149			DefaultMaxTokens: 10000,
150		},
151	}
152	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
153	return agent
154}
155
156func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
157	fixedTime := func() time.Time {
158		t, _ := time.Parse("1/2/2006", "1/1/2025")
159		return t
160	}
161	prompt, err := coderPrompt(
162		prompt.WithTimeFunc(fixedTime),
163		prompt.WithPlatform("linux"),
164		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
165	)
166	if err != nil {
167		return nil, err
168	}
169	cfg, err := config.Init(env.workingDir, "", false)
170	if err != nil {
171		return nil, err
172	}
173
174	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
175	// independently of user config on `$HOME/.config/crush/crush.json`.
176	cfg.Options.Attribution = &config.Attribution{
177		TrailerStyle:  "co-authored-by",
178		GeneratedWith: true,
179	}
180
181	// Clear skills paths to ensure test reproducibility - user's skills
182	// would be included in prompt and break VCR cassette matching.
183	cfg.Options.SkillsPaths = []string{}
184
185	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
186	if err != nil {
187		return nil, err
188	}
189
190	// Get the model name for the bash tool
191	modelName := large.Model() // fallback to ID if Name not available
192	if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
193		modelName = model.Name
194	}
195
196	allTools := []fantasy.AgentTool{
197		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
198		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
199		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
200		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
201		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
202		tools.NewGlobTool(env.workingDir),
203		tools.NewGrepTool(env.workingDir),
204		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
205		tools.NewSourcegraphTool(r.GetDefaultClient()),
206		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
207		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
208	}
209
210	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
211}
212
213// createSimpleGoProject creates a simple Go project structure in the given directory.
214// It creates a go.mod file and a main.go file with a basic hello world program.
215func createSimpleGoProject(t *testing.T, dir string) {
216	goMod := `module example.com/testproject
217
218go 1.23
219`
220	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
221	require.NoError(t, err)
222
223	mainGo := `package main
224
225import "fmt"
226
227func main() {
228	fmt.Println("Hello, World!")
229}
230`
231	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
232	require.NoError(t, err)
233}