1package agent
2
3import (
4 "database/sql"
5 "net/http"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/message"
12 "github.com/charmbracelet/crush/internal/session"
13 "github.com/charmbracelet/fantasy/ai"
14 "github.com/charmbracelet/fantasy/anthropic"
15 "github.com/stretchr/testify/require"
16 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
17
18 _ "github.com/joho/godotenv/autoload"
19)
20
21type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
22
23func TestSessionSimpleAgent(t *testing.T) {
24 r := newRecorder(t)
25 sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
26 require.Nil(t, err)
27 haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
28 require.Nil(t, err)
29 agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")
30 session, err := sessions.Create(t.Context(), "New Session")
31 require.Nil(t, err)
32
33 res, err := agent.Run(t.Context(), SessionAgentCall{
34 Prompt: "Hello",
35 SessionID: session.ID,
36 MaxOutputTokens: 10000,
37 })
38
39 require.Nil(t, err)
40 require.NotNil(t, res)
41
42 t.Run("should create session messages", func(t *testing.T) {
43 msgs, err := messages.List(t.Context(), session.ID)
44 require.Nil(t, err)
45 // Should have the agent and user message
46 require.Equal(t, len(msgs), 2)
47 })
48}
49
50func anthropicBuilder(model string) builderFunc {
51 return func(r *recorder.Recorder) (ai.LanguageModel, error) {
52 provider := anthropic.New(
53 anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54 anthropic.WithHTTPClient(&http.Client{Transport: r}),
55 )
56 return provider.LanguageModel(model)
57 }
58}
59
60func testDBConn(t *testing.T) (*sql.DB, error) {
61 return db.Connect(t.Context(), t.TempDir())
62}
63
64func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
65 conn, err := testDBConn(t)
66 require.Nil(t, err)
67 q := db.New(conn)
68 sessions := session.NewService(q)
69 messages := message.NewService(q)
70
71 largeModel := Model{
72 model: large,
73 config: catwalk.Model{
74 // todo: add values
75 },
76 }
77 smallModel := Model{
78 model: large,
79 config: catwalk.Model{
80 // todo: add values
81 },
82 }
83 agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
84 return agent, sessions, messages
85}