common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/catwalk/pkg/catwalk"
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/providers/openaicompat"
 14	"charm.land/x/vcr"
 15	"github.com/charmbracelet/crush/internal/agent/prompt"
 16	"github.com/charmbracelet/crush/internal/agent/tools"
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/db"
 20	"github.com/charmbracelet/crush/internal/filetracker"
 21	"github.com/charmbracelet/crush/internal/history"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26	"github.com/stretchr/testify/require"
 27
 28	_ "github.com/joho/godotenv/autoload"
 29)
 30
// fakeEnv is an environment for testing.
//
// It bundles the working directory and the service instances that the test
// helpers in this file hand to the agent and to its tools. testEnv fills it
// with a positional composite literal, so the field order below must not
// change without updating that literal.
type fakeEnv struct {
	workingDir  string                          // directory the tools operate in
	sessions    session.Service                 // session service passed to the session agent
	messages    message.Service                 // message service passed to the session agent
	permissions permission.Service              // permission service passed to the tools
	history     history.Service                 // history service passed to the edit/write tools
	filetracker *filetracker.Service            // dereferenced when constructing the file tools
	lspClients  *csync.Map[string, *lsp.Client] // LSP clients keyed by string; created empty by testEnv
}
 41
// builderFunc constructs a language model for a test. It receives the running
// test and the VCR recorder, and returns the model or the error that prevented
// it from being built.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func hyperBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := openaicompat.New(
 53			openaicompat.WithBaseURL("https://hyper.charm.land/v1"),
 54			openaicompat.WithAPIKey(os.Getenv("CRUSH_HYPER_API_KEY")),
 55			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 56		)
 57		if err != nil {
 58			return nil, err
 59		}
 60		return provider.LanguageModel(t.Context(), model)
 61	}
 62}
 63
 64func testEnv(t *testing.T) fakeEnv {
 65	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
 66	os.RemoveAll(workingDir)
 67
 68	err := os.MkdirAll(workingDir, 0o755)
 69	require.NoError(t, err)
 70
 71	conn, err := db.Connect(t.Context(), t.TempDir())
 72	require.NoError(t, err)
 73
 74	q := db.New(conn)
 75	sessions := session.NewService(q, conn)
 76	messages := message.NewService(q)
 77
 78	permissions := permission.NewPermissionService(workingDir, true, []string{})
 79	history := history.NewService(q, conn)
 80	filetrackerService := filetracker.NewService(q)
 81	lspClients := csync.NewMap[string, *lsp.Client]()
 82
 83	t.Cleanup(func() {
 84		conn.Close()
 85		os.RemoveAll(workingDir)
 86	})
 87
 88	return fakeEnv{
 89		workingDir,
 90		sessions,
 91		messages,
 92		permissions,
 93		history,
 94		&filetrackerService,
 95		lspClients,
 96	}
 97}
 98
 99func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
100	largeModel := Model{
101		Model: large,
102		CatwalkCfg: catwalk.Model{
103			ContextWindow:    200000,
104			DefaultMaxTokens: 10000,
105		},
106	}
107	smallModel := Model{
108		Model: small,
109		CatwalkCfg: catwalk.Model{
110			ContextWindow:    200000,
111			DefaultMaxTokens: 10000,
112		},
113	}
114	agent := NewSessionAgent(SessionAgentOptions{
115		LargeModel:   largeModel,
116		SmallModel:   smallModel,
117		SystemPrompt: systemPrompt,
118		IsYolo:       true,
119		Sessions:     env.sessions,
120		Messages:     env.messages,
121		Tools:        tools,
122	})
123	return agent
124}
125
126func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
127	fixedTime := func() time.Time {
128		t, _ := time.Parse("1/2/2006", "1/1/2025")
129		return t
130	}
131	prompt, err := coderPrompt(
132		prompt.WithTimeFunc(fixedTime),
133		prompt.WithPlatform("linux"),
134		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
135	)
136	if err != nil {
137		return nil, err
138	}
139	cfg, err := config.Init(env.workingDir, "", false)
140	if err != nil {
141		return nil, err
142	}
143
144	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
145	// independently of user config on `$HOME/.config/crush/crush.json`.
146	cfg.Config().Options.Attribution = &config.Attribution{
147		TrailerStyle:  "co-authored-by",
148		GeneratedWith: true,
149	}
150
151	// Clear some fields to avoid issues with VCR cassette matching.
152	cfg.Config().Options.SkillsPaths = nil
153	cfg.Config().Options.DisabledSkills = []string{"crush-config"}
154	cfg.Config().Options.ContextPaths = nil
155	cfg.Config().LSP = nil
156
157	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), cfg)
158	if err != nil {
159		return nil, err
160	}
161
162	// Get the model name for the bash tool
163	modelName := large.Model() // fallback to ID if Name not available
164	if model := cfg.Config().GetModel(large.Provider(), large.Model()); model != nil {
165		modelName = model.Name
166	}
167
168	allTools := []fantasy.AgentTool{
169		tools.NewBashTool(env.permissions, env.workingDir, cfg.Config().Options.Attribution, modelName),
170		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
171		tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
172		tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
173		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
174		tools.NewGlobTool(env.workingDir),
175		tools.NewGrepTool(env.workingDir, cfg.Config().Tools.Grep),
176		tools.NewLsTool(env.permissions, env.workingDir, cfg.Config().Tools.Ls),
177		tools.NewSourcegraphTool(r.GetDefaultClient()),
178		tools.NewViewTool(nil, env.permissions, *env.filetracker, nil, env.workingDir),
179		tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
180	}
181
182	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
183}
184
185// createSimpleGoProject creates a simple Go project structure in the given directory.
186// It creates a go.mod file and a main.go file with a basic hello world program.
187func createSimpleGoProject(t *testing.T, dir string) {
188	goMod := `module example.com/testproject
189
190go 1.23
191`
192	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
193	require.NoError(t, err)
194
195	mainGo := `package main
196
197import "fmt"
198
199func main() {
200	fmt.Println("Hello, World!")
201}
202`
203	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
204	require.NoError(t, err)
205}