1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"charm.land/fantasy"
 12	"charm.land/fantasy/providers/anthropic"
 13	"charm.land/fantasy/providers/openai"
 14	"charm.land/fantasy/providers/openaicompat"
 15	"charm.land/fantasy/providers/openrouter"
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/crush/internal/agent/prompt"
 18	"github.com/charmbracelet/crush/internal/agent/tools"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/db"
 22	"github.com/charmbracelet/crush/internal/history"
 23	"github.com/charmbracelet/crush/internal/lsp"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27	"github.com/stretchr/testify/require"
 28	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 29
 30	_ "github.com/joho/godotenv/autoload"
 31)
 32
// env bundles the per-test working directory and the services that test
// agents and their tools are wired to.
type env struct {
	// workingDir is the directory the coder prompt and the file/shell tools
	// are pointed at.
	workingDir  string
	// sessions and messages persist agent state; testEnv builds both from
	// the per-test database.
	sessions    session.Service
	messages    message.Service
	// permissions is handed to the tools; NOTE(review): testEnv passes `true`
	// as its second argument, whose meaning is not visible here — confirm.
	permissions permission.Service
	// history is given to the edit, multi-edit and write tools.
	history     history.Service
	// lspClients is shared with tools that accept LSP clients; presumably
	// keyed by LSP name — verify. testEnv creates it empty.
	lspClients  *csync.Map[string, *lsp.Client]
}
 41
// builderFunc constructs a fantasy.LanguageModel for a test. The builders in
// this file give the provider an http.Client whose Transport is r, the go-vcr
// recorder, and derive the model's context from t.
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 43
 44type modelPair struct {
 45	name       string
 46	largeModel builderFunc
 47	smallModel builderFunc
 48}
 49
 50func anthropicBuilder(model string) builderFunc {
 51	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 52		provider, err := anthropic.New(
 53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 55		)
 56		if err != nil {
 57			return nil, err
 58		}
 59		return provider.LanguageModel(t.Context(), model)
 60	}
 61}
 62
 63func openaiBuilder(model string) builderFunc {
 64	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 65		provider, err := openai.New(
 66			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 67			openai.WithHTTPClient(&http.Client{Transport: r}),
 68		)
 69		if err != nil {
 70			return nil, err
 71		}
 72		return provider.LanguageModel(t.Context(), model)
 73	}
 74}
 75
 76func openRouterBuilder(model string) builderFunc {
 77	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 78		provider, err := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81		)
 82		if err != nil {
 83			return nil, err
 84		}
 85		return provider.LanguageModel(t.Context(), model)
 86	}
 87}
 88
 89func zAIBuilder(model string) builderFunc {
 90	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 91		provider, err := openaicompat.New(
 92			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 93			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 94			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 95		)
 96		if err != nil {
 97			return nil, err
 98		}
 99		return provider.LanguageModel(t.Context(), model)
100	}
101}
102
103func testEnv(t *testing.T) env {
104	testDir := filepath.Join("/tmp/crush-test/", t.Name())
105	os.RemoveAll(testDir)
106	err := os.MkdirAll(testDir, 0o755)
107	require.NoError(t, err)
108	workingDir := testDir
109	conn, err := db.Connect(t.Context(), t.TempDir())
110	require.NoError(t, err)
111	q := db.New(conn)
112	sessions := session.NewService(q)
113	messages := message.NewService(q)
114	permissions := permission.NewPermissionService(workingDir, true, []string{})
115	history := history.NewService(q, conn)
116	lspClients := csync.NewMap[string, *lsp.Client]()
117
118	t.Cleanup(func() {
119		conn.Close()
120		os.RemoveAll(testDir)
121	})
122
123	return env{
124		workingDir,
125		sessions,
126		messages,
127		permissions,
128		history,
129		lspClients,
130	}
131}
132
133func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
134	largeModel := Model{
135		Model: large,
136		CatwalkCfg: catwalk.Model{
137			ContextWindow:    200000,
138			DefaultMaxTokens: 10000,
139		},
140	}
141	smallModel := Model{
142		Model: small,
143		CatwalkCfg: catwalk.Model{
144			ContextWindow:    200000,
145			DefaultMaxTokens: 10000,
146		},
147	}
148	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
149	return agent
150}
151
152func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
153	fixedTime := func() time.Time {
154		t, _ := time.Parse("1/2/2006", "1/1/2025")
155		return t
156	}
157	prompt, err := coderPrompt(
158		prompt.WithTimeFunc(fixedTime),
159		prompt.WithPlatform("linux"),
160		prompt.WithWorkingDir(env.workingDir),
161	)
162	if err != nil {
163		return nil, err
164	}
165	cfg, err := config.Init(env.workingDir, "", false)
166	if err != nil {
167		return nil, err
168	}
169
170	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
171	if err != nil {
172		return nil, err
173	}
174	allTools := []fantasy.AgentTool{
175		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
176		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
177		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
178		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
179		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
180		tools.NewGlobTool(env.workingDir),
181		tools.NewGrepTool(env.workingDir),
182		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
183		tools.NewSourcegraphTool(r.GetDefaultClient()),
184		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
185		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
186	}
187
188	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
189}
190
191// createSimpleGoProject creates a simple Go project structure in the given directory.
192// It creates a go.mod file and a main.go file with a basic hello world program.
193func createSimpleGoProject(t *testing.T, dir string) {
194	goMod := `module example.com/testproject
195
196go 1.23
197`
198	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
199	require.NoError(t, err)
200
201	mainGo := `package main
202
203import "fmt"
204
205func main() {
206	fmt.Println("Hello, World!")
207}
208`
209	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
210	require.NoError(t, err)
211}