1package agent
2
3import (
4 "database/sql"
5 "net/http"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/db"
11 "github.com/charmbracelet/crush/internal/message"
12 "github.com/charmbracelet/crush/internal/session"
13 "github.com/charmbracelet/fantasy/ai"
14 "github.com/charmbracelet/fantasy/anthropic"
15 "github.com/stretchr/testify/assert"
16 "github.com/stretchr/testify/require"
17 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
18
19 _ "github.com/joho/godotenv/autoload"
20)
21
22type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
23
// TestSessionSimpleAgent sends a single "Hello" prompt through a session agent
// whose Anthropic models talk HTTP via the recorder from newRecorder, then
// checks that the exchange was persisted as session messages.
func TestSessionSimpleAgent(t *testing.T) {
	r := newRecorder(t)
	sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
	require.NoError(t, err)
	haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
	require.NoError(t, err)
	agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")

	// Named sess (not session) so the local does not shadow the session package.
	sess, err := sessions.Create(t.Context(), "New Session")
	require.NoError(t, err)

	res, err := agent.Run(t.Context(), SessionAgentCall{
		Prompt:          "Hello",
		SessionID:       sess.ID,
		MaxOutputTokens: 10000,
	})
	require.NoError(t, err)
	assert.NotNil(t, res)

	t.Run("should create session messages", func(t *testing.T) {
		msgs, err := messages.List(t.Context(), sess.ID)
		require.NoError(t, err)
		// Should have the user prompt and the agent's reply. assert.Len keeps
		// expected/actual in the right order for failure messages.
		assert.Len(t, msgs, 2)
	})
}
50
// anthropicBuilder returns a builderFunc for the named Anthropic model.
// The API key comes from the CRUSH_ANTHROPIC_API_KEY environment variable,
// and the supplied recorder is used as the HTTP client's transport.
func anthropicBuilder(model string) builderFunc {
	return func(rec *recorder.Recorder) (ai.LanguageModel, error) {
		apiKey := os.Getenv("CRUSH_ANTHROPIC_API_KEY")
		client := &http.Client{Transport: rec}
		anth := anthropic.New(
			anthropic.WithAPIKey(apiKey),
			anthropic.WithHTTPClient(client),
		)
		return anth.LanguageModel(model)
	}
}
60
// testDBConn opens a database in a per-test temporary directory via
// db.Connect. On success the connection is closed automatically when the
// test (and its subtests) finish; on failure a nil *sql.DB is returned.
func testDBConn(t *testing.T) (*sql.DB, error) {
	t.Helper()
	conn, err := db.Connect(t.Context(), t.TempDir())
	if err != nil {
		return nil, err
	}
	// Cleanups run last-added-first, so this close happens before the
	// t.TempDir removal registered above, releasing the database files first.
	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("closing test database: %v", err)
		}
	})
	return conn, nil
}
64
// testSessionAgent builds a SessionAgent on top of a fresh test database
// (see testDBConn) and returns it together with the session and message
// services that share that database, so tests can inspect persisted state.
//
// large and small become the agent's large and small models respectively;
// their catwalk.Model configs are currently left empty.
func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
	t.Helper()
	conn, err := testDBConn(t)
	require.NoError(t, err)
	q := db.New(conn)
	sessions := session.NewService(q)
	messages := message.NewService(q)

	largeModel := Model{
		model: large,
		config: catwalk.Model{
			// TODO: add values
		},
	}
	smallModel := Model{
		// Must wrap small (not large); otherwise the caller's small model is ignored.
		model: small,
		config: catwalk.Model{
			// TODO: add values
		},
	}
	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
	return agent, sessions, messages
}