common_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"git.secluded.site/crush/internal/agent/prompt"
 17	"git.secluded.site/crush/internal/agent/tools"
 18	"git.secluded.site/crush/internal/config"
 19	"git.secluded.site/crush/internal/csync"
 20	"git.secluded.site/crush/internal/db"
 21	"git.secluded.site/crush/internal/history"
 22	"git.secluded.site/crush/internal/lsp"
 23	"git.secluded.site/crush/internal/message"
 24	"git.secluded.site/crush/internal/permission"
 25	"git.secluded.site/crush/internal/session"
 26	"github.com/charmbracelet/catwalk/pkg/catwalk"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
 33type env struct {
 34	workingDir  string
 35	sessions    session.Service
 36	messages    message.Service
 37	permissions permission.Service
 38	history     history.Service
 39	lspClients  *csync.Map[string, *lsp.Client]
 40}
 41
 42type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		if err != nil {
 57			return nil, err
 58		}
 59		return provider.LanguageModel(t.Context(), model)
 60	}
 61}
 62
 63func openaiBuilder(model string) builderFunc {
 64	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 65		provider, err := openai.New(
 66			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 67			openai.WithHTTPClient(&http.Client{Transport: r}),
 68		)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return provider.LanguageModel(t.Context(), model)
 73	}
 74}
 75
 76func openRouterBuilder(model string) builderFunc {
 77	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 78		provider, err := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81		)
 82		if err != nil {
 83			return nil, err
 84		}
 85		return provider.LanguageModel(t.Context(), model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 91		provider, err := openaicompat.New(
 92			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 93			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 94			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 95		)
 96		if err != nil {
 97			return nil, err
 98		}
 99		return provider.LanguageModel(t.Context(), model)
100	}
101}
102
103func testEnv(t *testing.T) env {
104	workingDir := filepath.Join("/tmp/crush-test/", t.Name())
105	os.RemoveAll(workingDir)
106
107	err := os.MkdirAll(workingDir, 0o755)
108	require.NoError(t, err)
109
110	conn, err := db.Connect(t.Context(), t.TempDir())
111	require.NoError(t, err)
112
113	q := db.New(conn)
114	sessions := session.NewService(q)
115	messages := message.NewService(q)
116
117	permissions := permission.NewPermissionService(t.Context(), workingDir, true, []string{}, nil)
118	history := history.NewService(q, conn)
119	lspClients := csync.NewMap[string, *lsp.Client]()
120
121	t.Cleanup(func() {
122		conn.Close()
123		os.RemoveAll(workingDir)
124	})
125
126	return env{
127		workingDir,
128		sessions,
129		messages,
130		permissions,
131		history,
132		lspClients,
133	}
134}
135
136func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
137	largeModel := Model{
138		Model: large,
139		CatwalkCfg: catwalk.Model{
140			ContextWindow:    200000,
141			DefaultMaxTokens: 10000,
142		},
143	}
144	smallModel := Model{
145		Model: small,
146		CatwalkCfg: catwalk.Model{
147			ContextWindow:    200000,
148			DefaultMaxTokens: 10000,
149		},
150	}
151	agent := NewSessionAgent(SessionAgentOptions{
152		LargeModel:           largeModel,
153		SmallModel:           smallModel,
154		SystemPromptPrefix:   "",
155		SystemPrompt:         systemPrompt,
156		DisableAutoSummarize: false,
157		IsYolo:               true,
158		Sessions:             env.sessions,
159		Messages:             env.messages,
160		Tools:                tools,
161		Notifier:             nil,
162		NotificationCtx:      context.Background(),
163	})
164	return agent
165}
166
167func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
168	fixedTime := func() time.Time {
169		t, _ := time.Parse("1/2/2006", "1/1/2025")
170		return t
171	}
172	prompt, err := coderPrompt(
173		prompt.WithTimeFunc(fixedTime),
174		prompt.WithPlatform("linux"),
175		prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
176	)
177	if err != nil {
178		return nil, err
179	}
180	cfg, err := config.Init(env.workingDir, "", false)
181	if err != nil {
182		return nil, err
183	}
184
185	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
186	if err != nil {
187		return nil, err
188	}
189	allTools := []fantasy.AgentTool{
190		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
191		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
192		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
193		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
194		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
195		tools.NewGlobTool(env.workingDir),
196		tools.NewGrepTool(env.workingDir),
197		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
198		tools.NewSourcegraphTool(r.GetDefaultClient()),
199		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
200		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
201	}
202
203	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
204}
205
206// createSimpleGoProject creates a simple Go project structure in the given directory.
207// It creates a go.mod file and a main.go file with a basic hello world program.
208func createSimpleGoProject(t *testing.T, dir string) {
209	goMod := `module example.com/testproject
210
211go 1.23
212`
213	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
214	require.NoError(t, err)
215
216	mainGo := `package main
217
218import "fmt"
219
220func main() {
221	fmt.Println("Hello, World!")
222}
223`
224	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
225	require.NoError(t, err)
226}