common_test.go

  1package agent
  2
  3import (
  4	"fmt"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/db"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/session"
 22	"github.com/charmbracelet/fantasy/ai"
 23	"github.com/charmbracelet/fantasy/anthropic"
 24	"github.com/charmbracelet/fantasy/openai"
 25	"github.com/charmbracelet/fantasy/openaicompat"
 26	"github.com/charmbracelet/fantasy/openrouter"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
// env bundles the per-test dependencies shared by the agent test helpers in
// this file: the working directory the agent operates in plus the services and
// LSP client map that testEnv constructs for it.
//
// NOTE: testEnv builds this struct with a positional composite literal, so the
// field order below is load-bearing; reorder only together with that literal.
type env struct {
	// workingDir is the directory passed to the permission service, the tools
	// and the prompt builder.
	workingDir  string
	// sessions is the session service backed by the test database.
	sessions    session.Service
	// messages is the message service backed by the test database.
	messages    message.Service
	// permissions is the permission service scoped to workingDir.
	permissions permission.Service
	// history is the file-history service backed by the test database.
	history     history.Service
	// lspClients maps names to LSP clients; testEnv creates it empty.
	lspClients  *csync.Map[string, *lsp.Client]
}
 41
 42type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 52		provider := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		return provider.LanguageModel(model)
 57	}
 58}
 59
 60func openaiBuilder(model string) builderFunc {
 61	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 62		provider := openai.New(
 63			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 64			openai.WithHTTPClient(&http.Client{Transport: r}),
 65		)
 66		return provider.LanguageModel(model)
 67	}
 68}
 69
 70func openRouterBuilder(model string) builderFunc {
 71	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 72		tf := func() func() string {
 73			id := 0
 74			return func() string {
 75				id += 1
 76				return fmt.Sprintf("%s-%d", t.Name(), id)
 77			}
 78		}
 79		provider := openrouter.New(
 80			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 81			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 82			openrouter.WithLanguageUniqueToolCallIds(),
 83			openrouter.WithLanguageModelGenerateIDFunc(tf()),
 84		)
 85		return provider.LanguageModel(model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 91		tf := func() func() string {
 92			id := 0
 93			return func() string {
 94				id += 1
 95				return fmt.Sprintf("%s-%d", t.Name(), id)
 96			}
 97		}
 98		provider := openaicompat.New(
 99			"https://api.z.ai/api/coding/paas/v4",
100			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
101			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
102			openaicompat.WithLanguageUniqueToolCallIds(),
103			openaicompat.WithLanguageModelGenerateIDFunc(tf()),
104		)
105		return provider.LanguageModel(model)
106	}
107}
108
109func testEnv(t *testing.T) env {
110	testDir := filepath.Join("/tmp/crush-test/", t.Name())
111	os.RemoveAll(testDir)
112	err := os.MkdirAll(testDir, 0o755)
113	t.Cleanup(func() {
114		os.RemoveAll(testDir)
115	})
116	require.NoError(t, err)
117	workingDir := testDir
118	conn, err := db.Connect(t.Context(), t.TempDir())
119	require.NoError(t, err)
120	q := db.New(conn)
121	sessions := session.NewService(q)
122	messages := message.NewService(q)
123	permissions := permission.NewPermissionService(workingDir, true, []string{})
124	history := history.NewService(q, conn)
125	lspClients := csync.NewMap[string, *lsp.Client]()
126	return env{
127		workingDir,
128		sessions,
129		messages,
130		permissions,
131		history,
132		lspClients,
133	}
134}
135
136func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
137	largeModel := Model{
138		Model:      large,
139		CatwalkCfg: catwalk.Model{
140			// todo: add values
141		},
142	}
143	smallModel := Model{
144		Model:      small,
145		CatwalkCfg: catwalk.Model{
146			// todo: add values
147		},
148	}
149	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
150	return agent
151}
152
153func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
154	fixedTime := func() time.Time {
155		t, _ := time.Parse("1/2/2006", "1/1/2025")
156		return t
157	}
158	prompt, err := coderPrompt(
159		prompt.WithTimeFunc(fixedTime),
160		prompt.WithPlatform("linux"),
161		prompt.WithWorkingDir(env.workingDir),
162	)
163	if err != nil {
164		return nil, err
165	}
166	cfg, err := config.Init(env.workingDir, "", false)
167	if err != nil {
168		return nil, err
169	}
170
171	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
172	if err != nil {
173		return nil, err
174	}
175	allTools := []ai.AgentTool{
176		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
177		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
178		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
179		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
180		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
181		tools.NewGlobTool(env.workingDir),
182		tools.NewGrepTool(env.workingDir),
183		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
184		tools.NewSourcegraphTool(r.GetDefaultClient()),
185		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
186		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
187	}
188
189	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
190}
191
192// createSimpleGoProject creates a simple Go project structure in the given directory.
193// It creates a go.mod file and a main.go file with a basic hello world program.
194func createSimpleGoProject(t *testing.T, dir string) {
195	goMod := `module example.com/testproject
196
197go 1.23
198`
199	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
200	require.NoError(t, err)
201
202	mainGo := `package main
203
204import "fmt"
205
206func main() {
207	fmt.Println("Hello, World!")
208}
209`
210	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
211	require.NoError(t, err)
212}