common_test.go

  1package agent
  2
  3import (
  4	"net/http"
  5	"os"
  6	"path/filepath"
  7	"testing"
  8	"time"
  9
 10	"charm.land/fantasy"
 11	"charm.land/fantasy/providers/anthropic"
 12	"charm.land/fantasy/providers/openai"
 13	"charm.land/fantasy/providers/openaicompat"
 14	"charm.land/fantasy/providers/openrouter"
 15	"github.com/charmbracelet/catwalk/pkg/catwalk"
 16	"github.com/charmbracelet/crush/internal/agent/prompt"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/db"
 21	"github.com/charmbracelet/crush/internal/history"
 22	"github.com/charmbracelet/crush/internal/lsp"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26	"github.com/stretchr/testify/require"
 27	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 28
 29	_ "github.com/joho/godotenv/autoload"
 30)
 31
 32type env struct {
 33	workingDir  string
 34	sessions    session.Service
 35	messages    message.Service
 36	permissions permission.Service
 37	history     history.Service
 38	lspClients  *csync.Map[string, *lsp.Client]
 39}
 40
 41type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 42
 43type modelPair struct {
 44	name       string
 45	largeModel builderFunc
 46	smallModel builderFunc
 47}
 48
 49func anthropicBuilder(model string) builderFunc {
 50	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 51		provider, err := anthropic.New(
 52			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 53			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 54		)
 55		if err != nil {
 56			return nil, err
 57		}
 58		return provider.LanguageModel(t.Context(), model)
 59	}
 60}
 61
 62func openaiBuilder(model string) builderFunc {
 63	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 64		provider, err := openai.New(
 65			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 66			openai.WithHTTPClient(&http.Client{Transport: r}),
 67		)
 68		if err != nil {
 69			return nil, err
 70		}
 71		return provider.LanguageModel(t.Context(), model)
 72	}
 73}
 74
 75func openRouterBuilder(model string) builderFunc {
 76	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 77		provider, err := openrouter.New(
 78			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 79			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 80		)
 81		if err != nil {
 82			return nil, err
 83		}
 84		return provider.LanguageModel(t.Context(), model)
 85	}
 86}
 87
 88func zAIBuilder(model string) builderFunc {
 89	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 90		provider, err := openaicompat.New(
 91			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 92			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 93			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 94		)
 95		if err != nil {
 96			return nil, err
 97		}
 98		return provider.LanguageModel(t.Context(), model)
 99	}
100}
101
102func testEnv(t *testing.T) env {
103	testDir := filepath.Join("/tmp/crush-test/", t.Name())
104	os.RemoveAll(testDir)
105	err := os.MkdirAll(testDir, 0o755)
106	t.Cleanup(func() {
107		os.RemoveAll(testDir)
108	})
109	require.NoError(t, err)
110	workingDir := testDir
111	conn, err := db.Connect(t.Context(), t.TempDir())
112	require.NoError(t, err)
113	q := db.New(conn)
114	sessions := session.NewService(q)
115	messages := message.NewService(q)
116	permissions := permission.NewPermissionService(workingDir, true, []string{})
117	history := history.NewService(q, conn)
118	lspClients := csync.NewMap[string, *lsp.Client]()
119	return env{
120		workingDir,
121		sessions,
122		messages,
123		permissions,
124		history,
125		lspClients,
126	}
127}
128
129func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
130	largeModel := Model{
131		Model: large,
132		CatwalkCfg: catwalk.Model{
133			ContextWindow:    200000,
134			DefaultMaxTokens: 10000,
135		},
136	}
137	smallModel := Model{
138		Model: small,
139		CatwalkCfg: catwalk.Model{
140			ContextWindow:    200000,
141			DefaultMaxTokens: 10000,
142		},
143	}
144	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
145	return agent
146}
147
148func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
149	fixedTime := func() time.Time {
150		t, _ := time.Parse("1/2/2006", "1/1/2025")
151		return t
152	}
153	prompt, err := coderPrompt(
154		prompt.WithTimeFunc(fixedTime),
155		prompt.WithPlatform("linux"),
156		prompt.WithWorkingDir(env.workingDir),
157	)
158	if err != nil {
159		return nil, err
160	}
161	cfg, err := config.Init(env.workingDir, "", false)
162	if err != nil {
163		return nil, err
164	}
165
166	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
167	if err != nil {
168		return nil, err
169	}
170	allTools := []fantasy.AgentTool{
171		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
172		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
173		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
174		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
175		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
176		tools.NewGlobTool(env.workingDir),
177		tools.NewGrepTool(env.workingDir),
178		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
179		tools.NewSourcegraphTool(r.GetDefaultClient()),
180		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
181		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
182	}
183
184	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
185}
186
187// createSimpleGoProject creates a simple Go project structure in the given directory.
188// It creates a go.mod file and a main.go file with a basic hello world program.
189func createSimpleGoProject(t *testing.T, dir string) {
190	goMod := `module example.com/testproject
191
192go 1.23
193`
194	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
195	require.NoError(t, err)
196
197	mainGo := `package main
198
199import "fmt"
200
201func main() {
202	fmt.Println("Hello, World!")
203}
204`
205	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
206	require.NoError(t, err)
207}