common_test.go

  1package agent
  2
  3import (
  4	"net/http"
  5	"os"
  6	"path/filepath"
  7	"testing"
  8	"time"
  9
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/agent/prompt"
 12	"github.com/charmbracelet/crush/internal/agent/tools"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/lsp"
 18	"github.com/charmbracelet/crush/internal/message"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/session"
 21	"github.com/charmbracelet/fantasy/ai"
 22	"github.com/charmbracelet/fantasy/anthropic"
 23	"github.com/charmbracelet/fantasy/openai"
 24	"github.com/charmbracelet/fantasy/openaicompat"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26	"github.com/stretchr/testify/require"
 27	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 28
 29	_ "github.com/joho/godotenv/autoload"
 30)
 31
// env bundles the per-test dependencies that the agent helpers in this file
// wire into a SessionAgent and its tools. Instances are produced by testEnv,
// which fills the fields positionally, so keep the field order stable.
type env struct {
	// workingDir is the directory handed to the prompt builder, config.Init
	// and every file-system tool.
	workingDir string
	// sessions and messages are passed straight to NewSessionAgent.
	sessions session.Service
	messages message.Service
	// permissions is shared by the tools that take a permission.Service.
	permissions permission.Service
	// history is shared by the edit, multi-edit and write tools.
	history history.Service
	// lspClients is shared by the edit, multi-edit, view and write tools;
	// presumably keyed by LSP server/language name — TODO confirm.
	lspClients *csync.Map[string, *lsp.Client]
}
 40
// builderFunc constructs a language model whose HTTP traffic is sent through
// the go-vcr recorder r. The *testing.T is available to builders that need it;
// the builders in this file ignore it.
type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
 42
// modelPair groups the builders for a large and a small model under one name.
// NOTE(review): name is presumably used as the subtest/cassette name — confirm
// against the callers.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
 48
 49func anthropicBuilder(model string) builderFunc {
 50	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 51		provider := anthropic.New(
 52			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 53			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 54		)
 55		return provider.LanguageModel(model)
 56	}
 57}
 58
 59func openaiBuilder(model string) builderFunc {
 60	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 61		provider := openai.New(
 62			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 63			openai.WithHTTPClient(&http.Client{Transport: r}),
 64		)
 65		return provider.LanguageModel(model)
 66	}
 67}
 68
 69func openRouterBuilder(model string) builderFunc {
 70	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 71		provider := openrouter.New(
 72			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 73			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 74		)
 75		return provider.LanguageModel(model)
 76	}
 77}
 78
 79func zAIBuilder(model string) builderFunc {
 80	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 81		provider := openaicompat.New(
 82			"https://api.z.ai/api/coding/paas/v4",
 83			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 84			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 85		)
 86		return provider.LanguageModel(model)
 87	}
 88}
 89
 90func testEnv(t *testing.T) env {
 91	testDir := filepath.Join("/tmp/crush-test/", t.Name())
 92	os.RemoveAll(testDir)
 93	err := os.MkdirAll(testDir, 0o755)
 94	t.Cleanup(func() {
 95		os.RemoveAll(testDir)
 96	})
 97	require.NoError(t, err)
 98	workingDir := testDir
 99	conn, err := db.Connect(t.Context(), t.TempDir())
100	require.NoError(t, err)
101	q := db.New(conn)
102	sessions := session.NewService(q)
103	messages := message.NewService(q)
104	permissions := permission.NewPermissionService(workingDir, true, []string{})
105	history := history.NewService(q, conn)
106	lspClients := csync.NewMap[string, *lsp.Client]()
107	return env{
108		workingDir,
109		sessions,
110		messages,
111		permissions,
112		history,
113		lspClients,
114	}
115}
116
117func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
118	largeModel := Model{
119		Model:      large,
120		CatwalkCfg: catwalk.Model{
121			// todo: add values
122		},
123	}
124	smallModel := Model{
125		Model:      small,
126		CatwalkCfg: catwalk.Model{
127			// todo: add values
128		},
129	}
130	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
131	return agent
132}
133
134func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
135	fixedTime := func() time.Time {
136		t, _ := time.Parse("1/2/2006", "1/1/2025")
137		return t
138	}
139	prompt, err := coderPrompt(
140		prompt.WithTimeFunc(fixedTime),
141		prompt.WithPlatform("linux"),
142		prompt.WithWorkingDir(env.workingDir),
143	)
144	if err != nil {
145		return nil, err
146	}
147	cfg, err := config.Init(env.workingDir, "", false)
148	if err != nil {
149		return nil, err
150	}
151
152	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
153	if err != nil {
154		return nil, err
155	}
156	allTools := []ai.AgentTool{
157		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
158		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
159		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
160		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
161		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
162		tools.NewGlobTool(env.workingDir),
163		tools.NewGrepTool(env.workingDir),
164		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
165		tools.NewSourcegraphTool(r.GetDefaultClient()),
166		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
167		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
168	}
169
170	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
171}
172
173// createSimpleGoProject creates a simple Go project structure in the given directory.
174// It creates a go.mod file and a main.go file with a basic hello world program.
175func createSimpleGoProject(t *testing.T, dir string) {
176	goMod := `module example.com/testproject
177
178go 1.23
179`
180	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
181	require.NoError(t, err)
182
183	mainGo := `package main
184
185import "fmt"
186
187func main() {
188	fmt.Println("Hello, World!")
189}
190`
191	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
192	require.NoError(t, err)
193}