test: initial test setup

kujtimiihoxha created

Change summary

go.mod                                              |   2 
go.sum                                              |   4 
internal/agent/.env.sample                          |   1 
internal/agent/agent.go                             |  41 ++++-
internal/agent/agent_test.go                        |  85 ++++++++++++
internal/agent/recorder_test.go                     | 104 +++++++++++++++
internal/agent/testdata/TestSessionSimpleAgent.yaml | 100 ++++++++++++++
7 files changed, 326 insertions(+), 11 deletions(-)

Detailed changes

go.mod 🔗

@@ -41,7 +41,9 @@ require (
 	github.com/stretchr/testify v1.11.1
 	github.com/tidwall/sjson v1.2.5
 	github.com/zeebo/xxh3 v1.0.2
+	go.yaml.in/yaml/v4 v4.0.0-rc.2
 	google.golang.org/genai v1.26.0
+	gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117
 	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 	mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5
 )

go.sum 🔗

@@ -367,6 +367,8 @@ go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mx
 go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
 go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+go.yaml.in/yaml/v4 v4.0.0-rc.2 h1:/FrI8D64VSr4HtGIlUtlFMGsm7H7pWTbj6vOLVZcA6s=
+go.yaml.in/yaml/v4 v4.0.0-rc.2/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
@@ -514,6 +516,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117 h1:fbE/sTnBb9UNfE8cJsOzrYYPqVWVHb7jWH4SI1W//cM=
+gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117/go.mod h1:YuVT9NPq7t3oT2WpUemB0DbNL7djIjgajZycxoDLnqs=
 gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
 gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
 gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=

internal/agent/agent.go 🔗

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log/slog"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -57,13 +58,12 @@ type Model struct {
 }
 
 type sessionAgent struct {
-	largeModel      Model
-	smallModel      Model
-	systemPrompt    string
-	tools           []ai.AgentTool
-	maxOutputTokens int64
-	sessions        session.Service
-	messages        message.Service
+	largeModel   Model
+	smallModel   Model
+	systemPrompt string
+	tools        []ai.AgentTool
+	sessions     session.Service
+	messages     message.Service
 
 	messageQueue   *csync.Map[string, []SessionAgentCall]
 	activeRequests *csync.Map[string, context.CancelFunc]
@@ -71,8 +71,24 @@ type sessionAgent struct {
 
 type SessionAgentOption func(*sessionAgent)
 
-func NewSessionAgent() SessionAgent {
-	return &sessionAgent{}
+func NewSessionAgent(
+	largeModel Model,
+	smallModel Model,
+	systemPrompt string,
+	sessions session.Service,
+	messages message.Service,
+	tools ...ai.AgentTool,
+) SessionAgent {
+	return &sessionAgent{
+		largeModel:     largeModel,
+		smallModel:     smallModel,
+		systemPrompt:   systemPrompt,
+		sessions:       sessions,
+		messages:       messages,
+		tools:          tools,
+		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
+		activeRequests: csync.NewMap[string, context.CancelFunc](),
+	}
 }
 
 func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
@@ -103,7 +119,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
 		a.largeModel.model,
 		ai.WithSystemPrompt(a.systemPrompt),
 		ai.WithTools(a.tools...),
-		ai.WithMaxOutputTokens(a.maxOutputTokens),
 	)
 
 	currentSession, err := a.sessions.Get(ctx, call.SessionID)
@@ -116,9 +131,12 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
 		return nil, fmt.Errorf("failed to get session messages: %w", err)
 	}
 
+	var wg sync.WaitGroup
 	// Generate title if first message
 	if len(msgs) == 0 {
-		go a.generateTitle(ctx, currentSession, call.Prompt)
+		wg.Go(func() {
+			a.generateTitle(ctx, currentSession, call.Prompt)
+		})
 	}
 
 	// Add the user message to the session
@@ -325,6 +343,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
 	if err != nil {
 		return nil, err
 	}
+	wg.Wait()
 
 	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
 	if !ok || len(queuedMessages) == 0 {

internal/agent/agent_test.go 🔗

@@ -0,0 +1,85 @@
+package agent
+
+import (
+	"database/sql"
+	"net/http"
+	"os"
+	"testing"
+
+	"github.com/charmbracelet/catwalk/pkg/catwalk"
+	"github.com/charmbracelet/crush/internal/db"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/charmbracelet/crush/internal/session"
+	"github.com/charmbracelet/fantasy/ai"
+	"github.com/charmbracelet/fantasy/anthropic"
+	"github.com/stretchr/testify/require"
+	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
+
+	_ "github.com/joho/godotenv/autoload"
+)
+
+type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
+
+func TestSessionSimpleAgent(t *testing.T) {
+	r := newRecorder(t)
+	sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
+	require.Nil(t, err)
+	haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
+	require.Nil(t, err)
+	agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")
+	session, err := sessions.Create(t.Context(), "New Session")
+	require.Nil(t, err)
+
+	res, err := agent.Run(t.Context(), SessionAgentCall{
+		Prompt:          "Hello",
+		SessionID:       session.ID,
+		MaxOutputTokens: 10000,
+	})
+
+	require.Nil(t, err)
+	require.NotNil(t, res)
+
+	t.Run("should create session messages", func(t *testing.T) {
+		msgs, err := messages.List(t.Context(), session.ID)
+		require.Nil(t, err)
+		// Should have the agent and user message
+		require.Equal(t, len(msgs), 2)
+	})
+}
+
+func anthropicBuilder(model string) builderFunc {
+	return func(r *recorder.Recorder) (ai.LanguageModel, error) {
+		provider := anthropic.New(
+			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
+			anthropic.WithHTTPClient(&http.Client{Transport: r}),
+		)
+		return provider.LanguageModel(model)
+	}
+}
+
+func testDBConn(t *testing.T) (*sql.DB, error) {
+	return db.Connect(t.Context(), t.TempDir())
+}
+
+func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
+	conn, err := testDBConn(t)
+	require.Nil(t, err)
+	q := db.New(conn)
+	sessions := session.NewService(q)
+	messages := message.NewService(q)
+
+	largeModel := Model{
+		model:  large,
+		config: catwalk.Model{
+			// todo: add values
+		},
+	}
+	smallModel := Model{
+		model:  large,
+		config: catwalk.Model{
+			// todo: add values
+		},
+	}
+	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
+	return agent, sessions, messages
+}

internal/agent/recorder_test.go 🔗

@@ -0,0 +1,104 @@
+package agent
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net/http"
+	"path/filepath"
+	"reflect"
+	"strings"
+	"testing"
+
+	"go.yaml.in/yaml/v4"
+	"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
+	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
+)
+
+func newRecorder(t *testing.T) *recorder.Recorder {
+	cassetteName := filepath.Join("testdata", t.Name())
+
+	r, err := recorder.New(
+		cassetteName,
+		recorder.WithMode(recorder.ModeRecordOnce),
+		recorder.WithMatcher(customMatcher(t)),
+		recorder.WithMarshalFunc(marshalFunc),
+		recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
+		recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
+	)
+	if err != nil {
+		t.Fatalf("recorder: failed to create recorder: %v", err)
+	}
+
+	t.Cleanup(func() {
+		if err := r.Stop(); err != nil {
+			t.Errorf("recorder: failed to stop recorder: %v", err)
+		}
+	})
+
+	return r
+}
+
+func customMatcher(t *testing.T) recorder.MatcherFunc {
+	return func(r *http.Request, i cassette.Request) bool {
+		if r.Body == nil || r.Body == http.NoBody {
+			return cassette.DefaultMatcher(r, i)
+		}
+		if r.Method != i.Method || r.URL.String() != i.URL {
+			return false
+		}
+
+		reqBody, err := io.ReadAll(r.Body)
+		if err != nil {
+			t.Fatalf("recorder: failed to read request body")
+		}
+		r.Body.Close()
+		r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
+
+		// Some providers can sometimes generate JSON requests with keys in
+		// a different order, which means a direct string comparison will fail.
+		// Falling back to deserializing the content if we don't have a match.
+		if string(reqBody) == i.Body { // hot path
+			return true
+		}
+		var content1, content2 any
+		if err := json.Unmarshal(reqBody, &content1); err != nil {
+			return false
+		}
+		if err := json.Unmarshal([]byte(i.Body), &content2); err != nil {
+			return false
+		}
+		return reflect.DeepEqual(content1, content2)
+	}
+}
+
+func marshalFunc(in any) ([]byte, error) {
+	var buff bytes.Buffer
+	enc := yaml.NewEncoder(&buff)
+	enc.SetIndent(2)
+	enc.CompactSeqIndent()
+	if err := enc.Encode(in); err != nil {
+		return nil, err
+	}
+	return buff.Bytes(), nil
+}
+
+var headersToKeep = map[string]struct{}{
+	"accept":       {},
+	"content-type": {},
+	"user-agent":   {},
+}
+
+func hookRemoveHeaders(i *cassette.Interaction) error {
+	for k := range i.Request.Headers {
+		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
+			delete(i.Request.Headers, k)
+		}
+	}
+	for k := range i.Response.Headers {
+		if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
+			delete(i.Response.Headers, k)
+		}
+	}
+	return nil
+}

internal/agent/testdata/TestSessionSimpleAgent.yaml 🔗

@@ -0,0 +1,100 @@
+---
+version: 2
+interactions:
+- id: 0
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 281
+    host: ""
+    body: '{"max_tokens":10000,"messages":[{"content":[{"text":"Hello","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"You are a helpful assistant","cache_control":{"type":"ephemeral"},"type":"text"}],"stream":true}'
+    headers:
+      Accept:
+      - application/json
+      Content-Type:
+      - application/json
+      User-Agent:
+      - Anthropic/Go 1.12.0
+    url: https://api.anthropic.com/v1/messages
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    body: |+
+      event: message_start
+      data: {"type":"message_start","message":{"id":"msg_01WGBTmd2Q5E2ajXUoHZYg6K","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":8,"service_tier":"standard"}}      }
+
+      event: content_block_start
+      data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}         }
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello! How can I help you today"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"?"}    }
+
+      event: content_block_stop
+      data: {"type":"content_block_stop","index":0  }
+
+      event: message_delta
+      data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12}       }
+
+      event: message_stop
+      data: {"type":"message_stop"        }
+
+    headers:
+      Content-Type:
+      - text/event-stream; charset=utf-8
+    status: 200 OK
+    code: 200
+    duration: 1.891647s
+- id: 1
+  request:
+    proto: HTTP/1.1
+    proto_major: 1
+    proto_minor: 1
+    content_length: 621
+    host: ""
+    body: '{"max_tokens":40,"messages":[{"content":[{"text":"Generate a concise title for the following content:\n\nHello","type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"you will generate a short title based on the first message a user begins a conversation with\n\n- ensure it is not more than 50 characters long\n- the title should be a summary of the user''s message\n- it should be one line long\n- do not use quotes or colons\n- the entire text you return will be used as the title\n- never return anything that is more than one sentence (one line) long\n","type":"text"}],"stream":true}'
+    headers:
+      Accept:
+      - application/json
+      Content-Type:
+      - application/json
+      User-Agent:
+      - Anthropic/Go 1.12.0
+    url: https://api.anthropic.com/v1/messages
+    method: POST
+  response:
+    proto: HTTP/2.0
+    proto_major: 2
+    proto_minor: 0
+    content_length: -1
+    body: |+
+      event: message_start
+      data: {"type":"message_start","message":{"id":"msg_01CZb9drep7yMKkc2wzNDt5G","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":10,"service_tier":"standard"}}    }
+
+      event: content_block_start
+      data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}       }
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Greeting or Starting a Conversation"}           }
+
+      event: content_block_stop
+      data: {"type":"content_block_stop","index":0}
+
+      event: message_delta
+      data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":10}          }
+
+      event: message_stop
+      data: {"type":"message_stop"               }
+
+    headers:
+      Content-Type:
+      - text/event-stream; charset=utf-8
+    status: 200 OK
+    code: 200
+    duration: 2.547489s