1package agent
  2
  3import (
  4	"net/http"
  5	"os"
  6	"path/filepath"
  7	"testing"
  8	"time"
  9
 10	"github.com/charmbracelet/catwalk/pkg/catwalk"
 11	"github.com/charmbracelet/crush/internal/agent/prompt"
 12	"github.com/charmbracelet/crush/internal/agent/tools"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/db"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/lsp"
 18	"github.com/charmbracelet/crush/internal/message"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/session"
 21	"github.com/charmbracelet/fantasy/ai"
 22	"github.com/charmbracelet/fantasy/anthropic"
 23	"github.com/charmbracelet/fantasy/openai"
 24	"github.com/charmbracelet/fantasy/openaicompat"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26	"github.com/stretchr/testify/require"
 27	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 28
 29	_ "github.com/joho/godotenv/autoload"
 30)
 31
// env bundles the per-test dependencies that testEnv creates and that
// testSessionAgent and coderAgent consume.
type env struct {
	// workingDir is the directory the coder prompt and the file tools are
	// pointed at.
	workingDir  string
	// sessions is handed to NewSessionAgent.
	sessions    session.Service
	// messages is handed to NewSessionAgent.
	messages    message.Service
	// permissions is passed to every tool constructor that takes one.
	permissions permission.Service
	// history is passed to the edit, multi-edit and write tools.
	history     history.Service
	// lspClients is shared with the edit, multi-edit, view and write tools;
	// testEnv creates it empty. NOTE(review): key is presumably the LSP
	// server name — confirm.
	lspClients  *csync.Map[string, *lsp.Client]
}
 40
// builderFunc constructs a language model for a test. The builders in this
// file ignore the *testing.T and route every HTTP request through the
// recorder r.
type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)
 42
// modelPair is a named pair of model builders: one for the large model and
// one for the small model (see testSessionAgent, which takes both).
type modelPair struct {
	// name labels the pair. NOTE(review): presumably used for subtest or
	// cassette naming — confirm against callers.
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
 48
 49func anthropicBuilder(model string) builderFunc {
 50	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 51		provider := anthropic.New(
 52			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 53			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 54		)
 55		return provider.LanguageModel(model)
 56	}
 57}
 58
 59func openaiBuilder(model string) builderFunc {
 60	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 61		provider := openai.New(
 62			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 63			openai.WithHTTPClient(&http.Client{Transport: r}),
 64		)
 65		return provider.LanguageModel(model)
 66	}
 67}
 68
 69func openRouterBuilder(model string) builderFunc {
 70	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 71		provider := openrouter.New(
 72			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 73			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 74		)
 75		return provider.LanguageModel(model)
 76	}
 77}
 78
 79func zAIBuilder(model string) builderFunc {
 80	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 81		provider := openaicompat.New(
 82			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 83			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 84			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 85		)
 86		return provider.LanguageModel(model)
 87	}
 88}
 89
 90func testEnv(t *testing.T) env {
 91	testDir := filepath.Join("/tmp/crush-test/", t.Name())
 92	os.RemoveAll(testDir)
 93	err := os.MkdirAll(testDir, 0o755)
 94	t.Cleanup(func() {
 95		os.RemoveAll(testDir)
 96	})
 97	require.NoError(t, err)
 98	workingDir := testDir
 99	conn, err := db.Connect(t.Context(), t.TempDir())
100	require.NoError(t, err)
101	q := db.New(conn)
102	sessions := session.NewService(q)
103	messages := message.NewService(q)
104	permissions := permission.NewPermissionService(workingDir, true, []string{})
105	history := history.NewService(q, conn)
106	lspClients := csync.NewMap[string, *lsp.Client]()
107	return env{
108		workingDir,
109		sessions,
110		messages,
111		permissions,
112		history,
113		lspClients,
114	}
115}
116
117func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
118	largeModel := Model{
119		Model: large,
120		CatwalkCfg: catwalk.Model{
121			ContextWindow:    200000,
122			DefaultMaxTokens: 10000,
123		},
124	}
125	smallModel := Model{
126		Model: small,
127		CatwalkCfg: catwalk.Model{
128			ContextWindow:    200000,
129			DefaultMaxTokens: 10000,
130		},
131	}
132	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools})
133	return agent
134}
135
136func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
137	fixedTime := func() time.Time {
138		t, _ := time.Parse("1/2/2006", "1/1/2025")
139		return t
140	}
141	prompt, err := coderPrompt(
142		prompt.WithTimeFunc(fixedTime),
143		prompt.WithPlatform("linux"),
144		prompt.WithWorkingDir(env.workingDir),
145	)
146	if err != nil {
147		return nil, err
148	}
149	cfg, err := config.Init(env.workingDir, "", false)
150	if err != nil {
151		return nil, err
152	}
153
154	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
155	if err != nil {
156		return nil, err
157	}
158	allTools := []ai.AgentTool{
159		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
160		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
161		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
162		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
163		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
164		tools.NewGlobTool(env.workingDir),
165		tools.NewGrepTool(env.workingDir),
166		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
167		tools.NewSourcegraphTool(r.GetDefaultClient()),
168		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
169		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
170	}
171
172	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
173}
174
175// createSimpleGoProject creates a simple Go project structure in the given directory.
176// It creates a go.mod file and a main.go file with a basic hello world program.
177func createSimpleGoProject(t *testing.T, dir string) {
178	goMod := `module example.com/testproject
179
180go 1.23
181`
182	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
183	require.NoError(t, err)
184
185	mainGo := `package main
186
187import "fmt"
188
189func main() {
190	fmt.Println("Hello, World!")
191}
192`
193	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
194	require.NoError(t, err)
195}