1package agent
  2
  3import (
  4	"context"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"strings"
  9	"testing"
 10	"time"
 11
 12	"charm.land/fantasy"
 13	"charm.land/fantasy/providers/anthropic"
 14	"charm.land/fantasy/providers/openai"
 15	"charm.land/fantasy/providers/openaicompat"
 16	"charm.land/fantasy/providers/openrouter"
 17	"github.com/charmbracelet/catwalk/pkg/catwalk"
 18	"github.com/charmbracelet/crush/internal/agent/prompt"
 19	"github.com/charmbracelet/crush/internal/agent/tools"
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/db"
 23	"github.com/charmbracelet/crush/internal/history"
 24	"github.com/charmbracelet/crush/internal/lsp"
 25	"github.com/charmbracelet/crush/internal/message"
 26	"github.com/charmbracelet/crush/internal/permission"
 27	"github.com/charmbracelet/crush/internal/session"
 28	"github.com/stretchr/testify/require"
 29	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 30
 31	_ "github.com/joho/godotenv/autoload"
 32)
 33
// env bundles the per-test dependencies used to build agents in these tests:
//   - workingDir: directory handed to the permission service and to every tool.
//   - sessions, messages, history: services backed by the test database.
//   - permissions: permission service scoped to workingDir.
//   - lspClients: LSP clients shared by the file tools. NOTE(review): presumably
//     keyed by LSP name — confirm; testEnv leaves it empty.
//
// Prefer keyed literals when constructing env. An unkeyed literal silently
// depends on the field order below.
type env struct {
	workingDir  string
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]
}
 42
// builderFunc lazily constructs a language model for a single test. The
// builders in this file send the model's HTTP traffic through the recorder r
// and bind the model to t.Context().
type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
 44
// modelPair names a combination of builders for the large and small models
// passed to the agent under test. NOTE(review): presumably used to table-drive
// tests across providers, with name as the subtest name — confirm against callers.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
 50
 51func anthropicBuilder(model string) builderFunc {
 52	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 53		provider, err := anthropic.New(
 54			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 55			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 56		)
 57		if err != nil {
 58			return nil, err
 59		}
 60		return provider.LanguageModel(t.Context(), model)
 61	}
 62}
 63
 64func openaiBuilder(model string) builderFunc {
 65	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 66		provider, err := openai.New(
 67			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 68			openai.WithHTTPClient(&http.Client{Transport: r}),
 69		)
 70		if err != nil {
 71			return nil, err
 72		}
 73		return provider.LanguageModel(t.Context(), model)
 74	}
 75}
 76
 77func openRouterBuilder(model string) builderFunc {
 78	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 79		provider, err := openrouter.New(
 80			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 81			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 82		)
 83		if err != nil {
 84			return nil, err
 85		}
 86		return provider.LanguageModel(t.Context(), model)
 87	}
 88}
 89
 90func zAIBuilder(model string) builderFunc {
 91	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
 92		provider, err := openaicompat.New(
 93			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 94			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 95			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 96		)
 97		if err != nil {
 98			return nil, err
 99		}
100		return provider.LanguageModel(t.Context(), model)
101	}
102}
103
104func testEnv(t *testing.T) env {
105	testDir := filepath.Join("/tmp/crush-test/", t.Name())
106	os.RemoveAll(testDir)
107	err := os.MkdirAll(testDir, 0o755)
108	require.NoError(t, err)
109	workingDir := testDir
110	conn, err := db.Connect(t.Context(), t.TempDir())
111	require.NoError(t, err)
112	q := db.New(conn)
113	sessions := session.NewService(q)
114	messages := message.NewService(q)
115	permissions := permission.NewPermissionService(workingDir, true, []string{})
116	history := history.NewService(q, conn)
117	lspClients := csync.NewMap[string, *lsp.Client]()
118
119	t.Cleanup(func() {
120		conn.Close()
121		os.RemoveAll(testDir)
122	})
123
124	return env{
125		workingDir,
126		sessions,
127		messages,
128		permissions,
129		history,
130		lspClients,
131	}
132}
133
134func testSessionAgent(env env, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
135	largeModel := Model{
136		Model: large,
137		CatwalkCfg: catwalk.Model{
138			ContextWindow:    200000,
139			DefaultMaxTokens: 10000,
140		},
141	}
142	smallModel := Model{
143		Model: small,
144		CatwalkCfg: catwalk.Model{
145			ContextWindow:    200000,
146			DefaultMaxTokens: 10000,
147		},
148	}
149	agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
150	return agent
151}
152
153func coderAgent(r *recorder.Recorder, env env, large, small fantasy.LanguageModel) (SessionAgent, error) {
154	fixedTime := func() time.Time {
155		t, _ := time.Parse("1/2/2006", "1/1/2025")
156		return t
157	}
158	path := filepath.ToSlash(env.workingDir)
159	// try fixing windows
160	if strings.Contains(path, ":") {
161		path = strings.Split(path, ":")[1]
162	}
163	prompt, err := coderPrompt(
164		prompt.WithTimeFunc(fixedTime),
165		prompt.WithPlatform("linux"),
166		prompt.WithWorkingDir(path),
167	)
168	if err != nil {
169		return nil, err
170	}
171	cfg, err := config.Init(env.workingDir, "", false)
172	if err != nil {
173		return nil, err
174	}
175
176	systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
177	if err != nil {
178		return nil, err
179	}
180	allTools := []fantasy.AgentTool{
181		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
182		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
183		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
184		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
185		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
186		tools.NewGlobTool(env.workingDir),
187		tools.NewGrepTool(env.workingDir),
188		tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
189		tools.NewSourcegraphTool(r.GetDefaultClient()),
190		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
191		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
192	}
193
194	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
195}
196
197// createSimpleGoProject creates a simple Go project structure in the given directory.
198// It creates a go.mod file and a main.go file with a basic hello world program.
199func createSimpleGoProject(t *testing.T, dir string) {
200	goMod := `module example.com/testproject
201
202go 1.23
203`
204	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
205	require.NoError(t, err)
206
207	mainGo := `package main
208
209import "fmt"
210
211func main() {
212	fmt.Println("Hello, World!")
213}
214`
215	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
216	require.NoError(t, err)
217}