common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/crush/internal/agent/prompt"
 18	"github.com/charmbracelet/crush/internal/agent/tools"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/db"
 22	"github.com/charmbracelet/crush/internal/history"
 23	"github.com/charmbracelet/crush/internal/lsp"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
 33type env struct {
 34	workingDir  string
 35	sessions    session.Service
 36	messages    message.Service
 37	permissions permission.Service
 38	history     history.Service
 39	lspClients  *csync.Map[string, *lsp.Client]
 40}
 41
 42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		if err != nil {
 57			return nil, err
 58		}
 59		return provider.LanguageModel(t.Context(), model)
 60	}
 61}
 62
 63func openaiBuilder(model string) builderFunc {
 64	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 65		provider, err := openai.New(
 66			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 67			openai.WithHTTPClient(&http.Client{Transport: r}),
 68		)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return provider.LanguageModel(t.Context(), model)
 73	}
 74}
 75
 76func openRouterBuilder(model string) builderFunc {
 77	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 78		provider, err := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81		)
 82		if err != nil {
 83			return nil, err
 84		}
 85		return provider.LanguageModel(t.Context(), model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 91		provider, err := openaicompat.New(
 92			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 93			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 94			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 95		)
 96		if err != nil {
 97			return nil, err
 98		}
 99		return provider.LanguageModel(t.Context(), model)
100	}
101}
102
103func testEnv(t *testing.T) env {
104	testDir := filepath.Join("/tmp/crush-test/", t.Name())
105	os.RemoveAll(testDir)
106	err := os.MkdirAll(testDir, 0o755)
107	t.Cleanup(func() {
108		os.RemoveAll(testDir)
109	})
110	require.NoError(t, err)
111	workingDir := testDir
112	conn, err := db.Connect(t.Context(), t.TempDir())
113	require.NoError(t, err)
114	q := db.New(conn)
115	sessions := session.NewService(q)
116	messages := message.NewService(q)
117	permissions := permission.NewPermissionService(workingDir, true, []string{})
118	history := history.NewService(q, conn)
119	lspClients := csync.NewMap[string, *lsp.Client]()
120	return env{
121		workingDir,
122		sessions,
123		messages,
124		permissions,
125		history,
126		lspClients,
127	}
128}
129
130func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
131	largeModel := Model{
132		Model: large,
133		CatwalkCfg: catwalk.Model{
134			ContextWindow:    200000,
135			DefaultMaxTokens: 10000,
136		},
137	}
138	smallModel := Model{
139		Model: small,
140		CatwalkCfg: catwalk.Model{
141			ContextWindow:    200000,
142			DefaultMaxTokens: 10000,
143		},
144	}
145	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
146	return agent
147}
148
149func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
150	fixedTime := func() time.Time {
151		t, _ := time.Parse("1/2/2006", "1/1/2025")
152		return t
153	}
154	prompt, err := coderPrompt(
155		prompt.WithTimeFunc(fixedTime),
156		prompt.WithPlatform("linux"),
157		prompt.WithWorkingDir(env.workingDir),
158	)
159	if err != nil {
160		return nil, err
161	}
162	cfg, err := config.Init(env.workingDir, "", false)
163	if err != nil {
164		return nil, err
165	}
166
167	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
168	if err != nil {
169		return nil, err
170	}
171	allTools := []fantasy.AgentTool{
172		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
173		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
174		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
175		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
176		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
177		tools.NewGlobTool(env.workingDir),
178		tools.NewGrepTool(env.workingDir),
179		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
180		tools.NewSourcegraphTool(r.GetDefaultClient()),
181		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
182		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
183	}
184
185	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
186}
187
188// createSimpleGoProject creates a simple Go project structure in the given directory.
189// It creates a go.mod file and a main.go file with a basic hello world program.
190func createSimpleGoProject(t *testing.T, dir string) {
191	goMod := `module example.com/testproject
192
193go 1.23
194`
195	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
196	require.NoError(t, err)
197
198	mainGo := `package main
199
200import "fmt"
201
202func main() {
203	fmt.Println("Hello, World!")
204}
205`
206	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
207	require.NoError(t, err)
208}