1package agent
  2
  3import (
  4	"fmt"
  5	"net/http"
  6	"os"
  7	"path/filepath"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/catwalk/pkg/catwalk"
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/db"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/session"
 22	"github.com/charmbracelet/fantasy/ai"
 23	"github.com/charmbracelet/fantasy/anthropic"
 24	"github.com/charmbracelet/fantasy/openai"
 25	"github.com/charmbracelet/fantasy/openrouter"
 26	"github.com/stretchr/testify/require"
 27	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
 28
 29	_ "github.com/joho/godotenv/autoload"
 30)
 31
// env bundles the per-test dependencies that agents and tools are built from.
// workingDir is the directory handed to the tools. It is followed by the
// session, message, permission and history services, and a csync.Map of LSP
// clients (keyed by string — presumably the LSP server name; confirm).
//
// Keep the field order stable: positional composite literals of env depend
// on it.
type env struct {
	workingDir  string
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]
}
 40
// builderFunc constructs a language model for a test. The builders in this
// file install r as the HTTP transport of the provider client, so provider
// traffic goes through the recorder.
type builderFunc func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error)

// modelPair groups the builders for the large and small models under a
// label. largeModel and smallModel build the models that are later passed as
// the large/small arguments of the agent helpers.
// NOTE(review): name is presumably used as a subtest name — confirm at the
// call sites.
type modelPair struct {
	name       string
	largeModel builderFunc
	smallModel builderFunc
}
 48
 49func anthropicBuilder(model string) builderFunc {
 50	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 51		provider := anthropic.New(
 52			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 53			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 54		)
 55		return provider.LanguageModel(model)
 56	}
 57}
 58
 59func openaiBuilder(model string) builderFunc {
 60	return func(_ *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 61		provider := openai.New(
 62			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 63			openai.WithHTTPClient(&http.Client{Transport: r}),
 64		)
 65		return provider.LanguageModel(model)
 66	}
 67}
 68
 69func openRouterBuilder(model string) builderFunc {
 70	return func(t *testing.T, r *recorder.Recorder) (ai.LanguageModel, error) {
 71		tf := func() func() string {
 72			id := 0
 73			return func() string {
 74				id += 1
 75				return fmt.Sprintf("%s-%d", t.Name(), id)
 76			}
 77		}
 78		provider := openrouter.New(
 79			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 80			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 81			openrouter.WithLanguageUniqueToolCallIds(),
 82			openrouter.WithLanguageModelGenerateIDFunc(tf()),
 83		)
 84		return provider.LanguageModel(model)
 85	}
 86}
 87
 88func testEnv(t *testing.T) env {
 89	testDir := filepath.Join("/tmp/crush-test/", t.Name())
 90	os.RemoveAll(testDir)
 91	err := os.MkdirAll(testDir, 0o755)
 92	t.Cleanup(func() {
 93		os.RemoveAll(testDir)
 94	})
 95	require.NoError(t, err)
 96	workingDir := testDir
 97	conn, err := db.Connect(t.Context(), t.TempDir())
 98	require.NoError(t, err)
 99	q := db.New(conn)
100	sessions := session.NewService(q)
101	messages := message.NewService(q)
102	permissions := permission.NewPermissionService(workingDir, true, []string{})
103	history := history.NewService(q, conn)
104	lspClients := csync.NewMap[string, *lsp.Client]()
105	return env{
106		workingDir,
107		sessions,
108		messages,
109		permissions,
110		history,
111		lspClients,
112	}
113}
114
115func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) SessionAgent {
116	largeModel := Model{
117		model:  large,
118		config: catwalk.Model{
119			// todo: add values
120		},
121	}
122	smallModel := Model{
123		model:  small,
124		config: catwalk.Model{
125			// todo: add values
126		},
127	}
128	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...)
129	return agent
130}
131
132func coderAgent(r *recorder.Recorder, env env, large, small ai.LanguageModel) (SessionAgent, error) {
133	fixedTime := func() time.Time {
134		t, _ := time.Parse("1/2/2006", "1/1/2025")
135		return t
136	}
137	prompt, err := coderPrompt(prompt.WithTimeFunc(fixedTime))
138	if err != nil {
139		return nil, err
140	}
141	cfg, err := config.Init(env.workingDir, "", false)
142	if err != nil {
143		return nil, err
144	}
145
146	systemPrompt, err := prompt.Build(large.Provider(), large.Model(), *cfg)
147	if err != nil {
148		return nil, err
149	}
150	allTools := []ai.AgentTool{
151		tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
152		tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
153		tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
154		tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
155		tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
156		tools.NewGlobTool(env.workingDir),
157		tools.NewGrepTool(env.workingDir),
158		tools.NewLsTool(env.permissions, env.workingDir),
159		tools.NewSourcegraphTool(r.GetDefaultClient()),
160		tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
161		tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
162	}
163
164	return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
165}
166
167// createSimpleGoProject creates a simple Go project structure in the given directory.
168// It creates a go.mod file and a main.go file with a basic hello world program.
169func createSimpleGoProject(t *testing.T, dir string) {
170	goMod := `module example.com/testproject
171
172go 1.23
173`
174	err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
175	require.NoError(t, err)
176
177	mainGo := `package main
178
179import "fmt"
180
181func main() {
182	fmt.Println("Hello, World!")
183}
184`
185	err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
186	require.NoError(t, err)
187}