common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/crush/internal/agent/prompt"
 18	"github.com/charmbracelet/crush/internal/agent/tools"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/db"
 22	"github.com/charmbracelet/crush/internal/history"
 23	"github.com/charmbracelet/crush/internal/lsp"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
 33type env struct {
 34	workingDir  string
 35	sessions    session.Service
 36	messages    message.Service
 37	permissions permission.Service
 38	history     history.Service
 39	lspClients  *csync.Map[string, *lsp.Client]
 40}
 41
 42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		if err != nil {
 57			return nil, err
 58		}
 59		return provider.LanguageModel(t.Context(), model)
 60	}
 61}
 62
 63func openaiBuilder(model string) builderFunc {
 64	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 65		provider, err := openai.New(
 66			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 67			openai.WithHTTPClient(&http.Client{Transport: r}),
 68		)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return provider.LanguageModel(t.Context(), model)
 73	}
 74}
 75
 76func openRouterBuilder(model string) builderFunc {
 77	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 78		provider, err := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81		)
 82		if err != nil {
 83			return nil, err
 84		}
 85		return provider.LanguageModel(t.Context(), model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 91		provider, err := openaicompat.New(
 92			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 93			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 94			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 95		)
 96		if err != nil {
 97			return nil, err
 98		}
 99		return provider.LanguageModel(t.Context(), model)
100	}
101}
102
103func testEnv(t *testing.T) env {
104	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105	os.RemoveAll(workingDir)
106
107	err := os.MkdirAll(workingDir, 0o755)
108	require.NoError(t, err)
109
110	conn, err := db.Connect(t.Context(), t.TempDir())
111	require.NoError(t, err)
112
113	q := db.New(conn)
114	sessions := session.NewService(q)
115	messages := message.NewService(q)
116
117	permissions := permission.NewPermissionService(workingDir, true, []string{})
118	history := history.NewService(q, conn)
119	lspClients := csync.NewMap[string, *lsp.Client]()
120
121	t.Cleanup(func() {
122		conn.Close()
123		os.RemoveAll(workingDir)
124	})
125
126	return env{
127		workingDir,
128		sessions,
129		messages,
130		permissions,
131		history,
132		lspClients,
133	}
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137	largeModel := Model{
138		Model: large,
139		CatwalkCfg: catwalk.Model{
140			ContextWindow:    200000,
141			DefaultMaxTokens: 10000,
142		},
143	}
144	smallModel := Model{
145		Model: small,
146		CatwalkCfg: catwalk.Model{
147			ContextWindow:    200000,
148			DefaultMaxTokens: 10000,
149		},
150	}
151	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
152	return agent
153}
154
155func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
156	fixedTime := func() time.Time {
157		t, _ := time.Parse("1/2/2006", "1/1/2025")
158		return t
159	}
160	prompt, err := coderPrompt(
161		prompt.WithTimeFunc(fixedTime),
162		prompt.WithPlatform("linux"),
163		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
164	)
165	if err != nil {
166		return nil, err
167	}
168	cfg, err := config.Init(env.workingDir, "", false)
169	if err != nil {
170		return nil, err
171	}
172
173	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
174	if err != nil {
175		return nil, err
176	}
177	allTools := []fantasy.AgentTool{
178		tools.NewWebFetchTool(env.workingDir, r.GetDefaultClient()),
179		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
180		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
181		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
182		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
183		tools.NewGlobTool(env.workingDir),
184		tools.NewGrepTool(env.workingDir),
185		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
186		tools.NewSourcegraphTool(r.GetDefaultClient()),
187		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
188		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
189	}
190
191	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
192}
193
194// createSimpleGoProject creates a simple Go project structure in the given directory.
195// It creates a go.mod file and a main.go file with a basic hello world program.
196func createSimpleGoProject(t *testing.T, dir string) {
197	goMod := `module example.com/testproject
198
199go 1.23
200`
201	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
202	require.NoError(t, err)
203
204	mainGo := `package main
205
206import "fmt"
207
208func main() {
209	fmt.Println("Hello, World!")
210}
211`
212	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
213	require.NoError(t, err)
214}