common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/crush/internal/agent/prompt"
 18	"github.com/charmbracelet/crush/internal/agent/tools"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/db"
 22	"github.com/charmbracelet/crush/internal/history"
 23	"github.com/charmbracelet/crush/internal/lsp"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
 33type env struct {
 34	workingDir  string
 35	sessions    session.Service
 36	messages    message.Service
 37	permissions permission.Service
 38	history     history.Service
 39	lspClients  *csync.Map[string, *lsp.Client]
 40}
 41
 42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		if err != nil {
 57			return nil, err
 58		}
 59		return provider.LanguageModel(t.Context(), model)
 60	}
 61}
 62
 63func openaiBuilder(model string) builderFunc {
 64	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 65		provider, err := openai.New(
 66			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 67			openai.WithHTTPClient(&http.Client{Transport: r}),
 68		)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return provider.LanguageModel(t.Context(), model)
 73	}
 74}
 75
 76func openRouterBuilder(model string) builderFunc {
 77	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 78		provider, err := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81		)
 82		if err != nil {
 83			return nil, err
 84		}
 85		return provider.LanguageModel(t.Context(), model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 91		provider, err := openaicompat.New(
 92			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 93			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 94			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 95		)
 96		if err != nil {
 97			return nil, err
 98		}
 99		return provider.LanguageModel(t.Context(), model)
100	}
101}
102
103func testEnv(t *testing.T) env {
104	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105	os.RemoveAll(workingDir)
106
107	err := os.MkdirAll(workingDir, 0o755)
108	require.NoError(t, err)
109
110	conn, err := db.Connect(t.Context(), t.TempDir())
111	require.NoError(t, err)
112
113	q := db.New(conn)
114	sessions := session.NewService(q)
115	messages := message.NewService(q)
116
117	permissions := permission.NewPermissionService(workingDir, true, []string{})
118	history := history.NewService(q, conn)
119	lspClients := csync.NewMap[string, *lsp.Client]()
120
121	t.Cleanup(func() {
122		conn.Close()
123		os.RemoveAll(workingDir)
124	})
125
126	return env{
127		workingDir,
128		sessions,
129		messages,
130		permissions,
131		history,
132		lspClients,
133	}
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137	largeModel := Model{
138		Model: large,
139		CatwalkCfg: catwalk.Model{
140			ContextWindow:    200000,
141			DefaultMaxTokens: 10000,
142		},
143	}
144	smallModel := Model{
145		Model: small,
146		CatwalkCfg: catwalk.Model{
147			ContextWindow:    200000,
148			DefaultMaxTokens: 10000,
149		},
150	}
151	agent := NewSessionAgent(SessionAgentOptions{
152		LargeModel:           largeModel,
153		SmallModel:           smallModel,
154		SystemPromptPrefix:   "",
155		SystemPrompt:         systemPrompt,
156		DisableAutoSummarize: false,
157		IsYolo:               true,
158		Sessions:             env.sessions,
159		Messages:             env.messages,
160		Tools:                tools,
161		Hooks:                nil,
162	})
163	return agent
164}
165
166func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
167	fixedTime := func() time.Time {
168		t, _ := time.Parse("1/2/2006", "1/1/2025")
169		return t
170	}
171	prompt, err := coderPrompt(
172		prompt.WithTimeFunc(fixedTime),
173		prompt.WithPlatform("linux"),
174		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
175	)
176	if err != nil {
177		return nil, err
178	}
179	cfg, err := config.Init(env.workingDir, "", false)
180	if err != nil {
181		return nil, err
182	}
183
184	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
185	if err != nil {
186		return nil, err
187	}
188	allTools := []fantasy.AgentTool{
189		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
190		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
191		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
192		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
193		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
194		tools.NewGlobTool(env.workingDir),
195		tools.NewGrepTool(env.workingDir),
196		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
197		tools.NewSourcegraphTool(r.GetDefaultClient()),
198		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
199		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
200	}
201
202	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
203}
204
205// createSimpleGoProject creates a simple Go project structure in the given directory.
206// It creates a go.mod file and a main.go file with a basic hello world program.
207func createSimpleGoProject(t *testing.T, dir string) {
208	goMod := `module example.com/testproject
209
210go 1.23
211`
212	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
213	require.NoError(t, err)
214
215	mainGo := `package main
216
217import "fmt"
218
219func main() {
220	fmt.Println("Hello, World!")
221}
222`
223	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
224	require.NoError(t, err)
225}