agent_test.go

 1package agent
 2
 3import (
 4	"database/sql"
 5	"net/http"
 6	"os"
 7	"testing"
 8
 9	"github.com/charmbracelet/catwalk/pkg/catwalk"
10	"github.com/charmbracelet/crush/internal/db"
11	"github.com/charmbracelet/crush/internal/message"
12	"github.com/charmbracelet/crush/internal/session"
13	"github.com/charmbracelet/fantasy/ai"
14	"github.com/charmbracelet/fantasy/anthropic"
15	"github.com/stretchr/testify/assert"
16	"github.com/stretchr/testify/require"
17	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
18
19	_ "github.com/joho/godotenv/autoload"
20)
21
22type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
23
24func TestSessionSimpleAgent(t *testing.T) {
25	r := newRecorder(t)
26	sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
27	require.NoError(t, err)
28	haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
29	require.NoError(t, err)
30	agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")
31	session, err := sessions.Create(t.Context(), "New Session")
32	require.NoError(t, err)
33
34	res, err := agent.Run(t.Context(), SessionAgentCall{
35		Prompt:          "Hello",
36		SessionID:       session.ID,
37		MaxOutputTokens: 10000,
38	})
39
40	require.NoError(t, err)
41	assert.NotNil(t, res)
42
43	t.Run("should create session messages", func(t *testing.T) {
44		msgs, err := messages.List(t.Context(), session.ID)
45		require.NoError(t, err)
46		// Should have the agent and user message
47		assert.Equal(t, len(msgs), 2)
48	})
49}
50
51func anthropicBuilder(model string) builderFunc {
52	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
53		provider := anthropic.New(
54			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
55			anthropic.WithHTTPClient(&http.Client{Transport: r}),
56		)
57		return provider.LanguageModel(model)
58	}
59}
60
61func testDBConn(t *testing.T) (*sql.DB, error) {
62	return db.Connect(t.Context(), t.TempDir())
63}
64
65func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
66	conn, err := testDBConn(t)
67	require.Nil(t, err)
68	q := db.New(conn)
69	sessions := session.NewService(q)
70	messages := message.NewService(q)
71
72	largeModel := Model{
73		model:  large,
74		config: catwalk.Model{
75			// todo: add values
76		},
77	}
78	smallModel := Model{
79		model:  large,
80		config: catwalk.Model{
81			// todo: add values
82		},
83	}
84	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
85	return agent, sessions, messages
86}