common_test.go

  1package agent
  2
  3import (
  4	"fmt"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/db"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/session"
 22	"github.com/charmbracelet/fantasy/ai"
 23	"github.com/charmbracelet/fantasy/anthropic"
 24	"github.com/charmbracelet/fantasy/openai"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26	"github.com/stretchr/testify/require"
 27	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 28
 29	_ "github.com/joho/godotenv/autoload"
 30)
 31
 32type env struct {
 33	workingDir  string
 34	sessions    session.Service
 35	messages    message.Service
 36	permissions permission.Service
 37	history     history.Service
 38	lspClients  *csync.Map[string, *lsp.Client]
 39}
 40
 41type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
 42
 43type modelPair struct {
 44	name       string
 45	largeModel builderFunc
 46	smallModel builderFunc
 47}
 48
 49func anthropicBuilder(model string) builderFunc {
 50	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 51		provider := anthropic.New(
 52			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 53			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 54		)
 55		return provider.LanguageModel(model)
 56	}
 57}
 58
 59func openaiBuilder(model string) builderFunc {
 60	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 61		provider := openai.New(
 62			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 63			openai.WithHTTPClient(&http.Client{Transport: r}),
 64		)
 65		return provider.LanguageModel(model)
 66	}
 67}
 68
 69func openRouterBuilder(model string) builderFunc {
 70	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 71		tf := func() func() string {
 72			id := 0
 73			return func() string {
 74				id += 1
 75				return fmt.Sprintf("%s-%d", t.Name(), id)
 76			}
 77		}
 78		provider := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81			openrouter.WithLanguageUniqueToolCallIds(),
 82			openrouter.WithLanguageModelGenerateIDFunc(tf()),
 83		)
 84		return provider.LanguageModel(model)
 85	}
 86}
 87
 88func testEnv(t *testing.T) env {
 89	testDir := filepath.Join("/tmp/crush-test/", t.Name())
 90	os.RemoveAll(testDir)
 91	err := os.MkdirAll(testDir, 0o755)
 92	t.Cleanup(func() {
 93		os.RemoveAll(testDir)
 94	})
 95	require.NoError(t, err)
 96	workingDir := testDir
 97	conn, err := db.Connect(t.Context(), t.TempDir())
 98	require.NoError(t, err)
 99	q := db.New(conn)
100	sessions := session.NewService(q)
101	messages := message.NewService(q)
102	permissions := permission.NewPermissionService(workingDir, true, []string{})
103	history := history.NewService(q, conn)
104	lspClients := csync.NewMap[string, *lsp.Client]()
105	return env{
106		workingDir,
107		sessions,
108		messages,
109		permissions,
110		history,
111		lspClients,
112	}
113}
114
115func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
116	largeModel := Model{
117		model:  large,
118		config: catwalk.Model{
119			// todo: add values
120		},
121	}
122	smallModel := Model{
123		model:  small,
124		config: catwalk.Model{
125			// todo: add values
126		},
127	}
128	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
129	return agent
130}
131
132func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
133	fixedTime := func() time.Time {
134		t, _ := time.Parse("1/2/2006", "1/1/2025")
135		return t
136	}
137	prompt, err := coderPrompt(
138		prompt.WithTimeFunc(fixedTime),
139		prompt.WithPlatform("linux"),
140		prompt.WithWorkingDir(env.workingDir),
141	)
142	if err != nil {
143		return nil, err
144	}
145	cfg, err := config.Init(env.workingDir, "", false)
146	if err != nil {
147		return nil, err
148	}
149
150	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
151	if err != nil {
152		return nil, err
153	}
154	allTools := []ai.AgentTool{
155		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
156		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
157		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
158		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
159		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
160		tools.NewGlobTool(env.workingDir),
161		tools.NewGrepTool(env.workingDir),
162		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
163		tools.NewSourcegraphTool(r.GetDefaultClient()),
164		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
165		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
166	}
167
168	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
169}
170
171// createSimpleGoProject creates a simple Go project structure in the given directory.
172// It creates a go.mod file and a main.go file with a basic hello world program.
173func createSimpleGoProject(t *testing.T, dir string) {
174	goMod := `module example.com/testproject
175
176go 1.23
177`
178	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
179	require.NoError(t, err)
180
181	mainGo := `package main
182
183import "fmt"
184
185func main() {
186	fmt.Println("Hello, World!")
187}
188`
189	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
190	require.NoError(t, err)
191}