common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"charm.land/x/vcr"
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18	"github.com/charmbracelet/crush/internal/agent/prompt"
 19	"github.com/charmbracelet/crush/internal/agent/tools"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/db"
 23	"github.com/charmbracelet/crush/internal/history"
 24	"github.com/charmbracelet/crush/internal/lsp"
 25	"github.com/charmbracelet/crush/internal/message"
 26	"github.com/charmbracelet/crush/internal/permission"
 27	"github.com/charmbracelet/crush/internal/session"
 28	"github.com/stretchr/testify/require"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
 33// fakeEnv is an environment for testing.
 34type fakeEnv struct {
 35	workingDir  string
 36	sessions    session.Service
 37	messages    message.Service
 38	permissions permission.Service
 39	history     history.Service
 40	lspClients  *csync.Map[string, *lsp.Client]
 41}
 42
 43type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
 44
 45type modelPair struct {
 46	name       string
 47	largeModel builderFunc
 48	smallModel builderFunc
 49}
 50
 51func anthropicBuilder(model string) builderFunc {
 52	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 53		provider, err := anthropic.New(
 54			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 55			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 56		)
 57		if err != nil {
 58			return nil, err
 59		}
 60		return provider.LanguageModel(t.Context(), model)
 61	}
 62}
 63
 64func openaiBuilder(model string) builderFunc {
 65	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 66		provider, err := openai.New(
 67			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 68			openai.WithHTTPClient(&http.Client{Transport: r}),
 69		)
 70		if err != nil {
 71			return nil, err
 72		}
 73		return provider.LanguageModel(t.Context(), model)
 74	}
 75}
 76
 77func openRouterBuilder(model string) builderFunc {
 78	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 79		provider, err := openrouter.New(
 80			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 81			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 82		)
 83		if err != nil {
 84			return nil, err
 85		}
 86		return provider.LanguageModel(t.Context(), model)
 87	}
 88}
 89
 90func zAIBuilder(model string) builderFunc {
 91	return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
 92		provider, err := openaicompat.New(
 93			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 94			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 95			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 96		)
 97		if err != nil {
 98			return nil, err
 99		}
100		return provider.LanguageModel(t.Context(), model)
101	}
102}
103
// testEnv builds a fresh fakeEnv for t.
//
// The working directory is /tmp/crush-test/<test name>. It is wiped and
// recreated before the test and removed again during cleanup. It is a fixed
// path rather than t.TempDir(); presumably this keeps the path (which
// coderAgent feeds into the system prompt) stable across runs for cassette
// matching — confirm before changing it.
//
// The database lives in a t.TempDir() directory and is closed during cleanup.
// Services are backed by that database, permissions are created for the
// working directory, and the LSP client map starts empty.
func testEnv(t *testing.T) fakeEnv {
	t.Helper()

	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
	// Leftovers from a previous run would leak into this test, so failing to
	// clear them must fail the test rather than be silently ignored.
	require.NoError(t, os.RemoveAll(workingDir))
	require.NoError(t, os.MkdirAll(workingDir, 0o755))

	conn, err := db.Connect(t.Context(), t.TempDir())
	require.NoError(t, err)

	q := db.New(conn)
	sessions := session.NewService(q)
	messages := message.NewService(q)
	permissions := permission.NewPermissionService(workingDir, true, []string{})
	hist := history.NewService(q, conn) // named hist to avoid shadowing the history package
	lspClients := csync.NewMap[string, *lsp.Client]()

	// Cleanup is best-effort: report problems without failing the test.
	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
		if err := os.RemoveAll(workingDir); err != nil {
			t.Logf("removing working dir %q: %v", workingDir, err)
		}
	})

	return fakeEnv{
		workingDir:  workingDir,
		sessions:    sessions,
		messages:    messages,
		permissions: permissions,
		history:     hist,
		lspClients:  lspClients,
	}
}
136
137func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
138	largeModel := Model{
139		Model: large,
140		CatwalkCfg: catwalk.Model{
141			ContextWindow:    200000,
142			DefaultMaxTokens: 10000,
143		},
144	}
145	smallModel := Model{
146		Model: small,
147		CatwalkCfg: catwalk.Model{
148			ContextWindow:    200000,
149			DefaultMaxTokens: 10000,
150		},
151	}
152	agent := NewSessionAgent(SessionAgentOptions{
153		LargeModel:           largeModel,
154		SmallModel:           smallModel,
155		SystemPromptPrefix:   "",
156		SystemPrompt:         systemPrompt,
157		DisableAutoSummarize: false,
158		IsYolo:               true,
159		IsSubAgent:           false,
160		HooksManager:         nil,
161		WorkingDir:           "",
162		Sessions:             env.sessions,
163		Messages:             env.messages,
164		Tools:                tools,
165	})
166	return agent
167}
168
// coderAgent builds a coder SessionAgent for tests.
//
// Inputs that end up in the system prompt are pinned so the rendered prompt
// does not depend on the machine or the user's config:
//   - the date is fixed to 2025-01-01 UTC;
//   - the platform is "linux";
//   - the working directory is env.workingDir with forward slashes;
//   - attribution settings are fixed.
//
// The agent gets the full coder tool set. The tools that take an HTTP client
// (download, fetch, sourcegraph) all receive r's default client.
//
// It returns an error if building the prompt, loading the config, or
// rendering the system prompt fails.
func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
	// Fixed "now" for the prompt template; equivalent to parsing 1/1/2025,
	// which yields midnight UTC, but with no error to discard.
	fixedTime := func() time.Time {
		return time.Date(2025, time.January, 1, 0, 0, 0, 0, time.UTC)
	}
	// Named coderTmpl (not prompt) so it does not shadow the prompt package.
	coderTmpl, err := coderPrompt(
		prompt.WithTimeFunc(fixedTime),
		prompt.WithPlatform("linux"),
		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
	)
	if err != nil {
		return nil, err
	}
	cfg, err := config.Init(env.workingDir, "", false)
	if err != nil {
		return nil, err
	}

	// NOTE(@andreynering): Set a fixed config to ensure cassettes match
	// independently of user config on `$HOME/.config/crush/crush.json`.
	cfg.Options.Attribution = &config.Attribution{
		TrailerStyle:  "co-authored-by",
		GeneratedWith: true,
	}

	systemPrompt, err := coderTmpl.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
	if err != nil {
		return nil, err
	}

	// The bash tool receives a model name alongside the attribution settings.
	// Prefer the configured display name and fall back to the model ID when
	// the model is unknown or its name is empty.
	modelName := large.Model()
	if model := cfg.GetModel(large.Provider(), large.Model()); model != nil && model.Name != "" {
		modelName = model.Name
	}

	httpClient := r.GetDefaultClient()
	allTools := []fantasy.AgentTool{
		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
		tools.NewDownloadTool(env.permissions, env.workingDir, httpClient),
		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
		tools.NewFetchTool(env.permissions, env.workingDir, httpClient),
		tools.NewGlobTool(env.workingDir),
		tools.NewGrepTool(env.workingDir),
		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
		tools.NewSourcegraphTool(httpClient),
		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
	}

	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
}
221
// createSimpleGoProject writes a minimal Go module into dir. It writes two
// files:
//   - go.mod, declaring module example.com/testproject with go 1.23;
//   - main.go, a hello-world program that prints "Hello, World!".
//
// dir must already exist (os.WriteFile does not create parent directories).
// The test fails immediately if either file cannot be written.
func createSimpleGoProject(t *testing.T, dir string) {
	t.Helper()

	goMod := `module example.com/testproject

go 1.23
`
	err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0o644)
	require.NoError(t, err)

	mainGo := `package main

import "fmt"

func main() {
	fmt.Println("Hello, World!")
}
`
	err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(mainGo), 0o644)
	require.NoError(t, err)
}