agent_test.go

 1package agent
 2
 3import (
 4	"database/sql"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"github.com/charmbracelet/catwalk/pkg/catwalk"
10	"github.com/charmbracelet/crush/internal/db"
11	"github.com/charmbracelet/crush/internal/message"
12	"github.com/charmbracelet/crush/internal/session"
13	"github.com/charmbracelet/fantasy/ai"
14	"github.com/charmbracelet/fantasy/anthropic"
15	"github.com/stretchr/testify/require"
16	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
17
18	_ "github.com/joho/godotenv/autoload"
19)
20
21type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
22
23func TestSessionSimpleAgent(t *testing.T) {
24	r := newRecorder(t)
25	sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
26	require.Nil(t, err)
27	haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
28	require.Nil(t, err)
29	agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")
30	session, err := sessions.Create(t.Context(), "New Session")
31	require.Nil(t, err)
32
33	res, err := agent.Run(t.Context(), SessionAgentCall{
34		Prompt:          "Hello",
35		SessionID:       session.ID,
36		MaxOutputTokens: 10000,
37	})
38
39	require.Nil(t, err)
40	require.NotNil(t, res)
41
42	t.Run("should create session messages", func(t *testing.T) {
43		msgs, err := messages.List(t.Context(), session.ID)
44		require.Nil(t, err)
45		// Should have the agent and user message
46		require.Equal(t, len(msgs), 2)
47	})
48}
49
50func anthropicBuilder(model string) builderFunc {
51	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
52		provider := anthropic.New(
53			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
54			anthropic.WithHTTPClient(&http.Client{Transport: r}),
55		)
56		return provider.LanguageModel(model)
57	}
58}
59
60func testDBConn(t *testing.T) (*sql.DB, error) {
61	return db.Connect(t.Context(), t.TempDir())
62}
63
64func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
65	conn, err := testDBConn(t)
66	require.Nil(t, err)
67	q := db.New(conn)
68	sessions := session.NewService(q)
69	messages := message.NewService(q)
70
71	largeModel := Model{
72		model:  large,
73		config: catwalk.Model{
74			// todo: add values
75		},
76	}
77	smallModel := Model{
78		model:  large,
79		config: catwalk.Model{
80			// todo: add values
81		},
82	}
83	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
84	return agent, sessions, messages
85}