Detailed changes
@@ -41,7 +41,9 @@ require (
github.com/stretchr/testify v1.11.1
github.com/tidwall/sjson v1.2.5
github.com/zeebo/xxh3 v1.0.2
+ go.yaml.in/yaml/v4 v4.0.0-rc.2
google.golang.org/genai v1.26.0
+ gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117
gopkg.in/natefinch/lumberjack.v2 v2.2.1
mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5
)
@@ -367,6 +367,8 @@ go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mx
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+go.yaml.in/yaml/v4 v4.0.0-rc.2 h1:/FrI8D64VSr4HtGIlUtlFMGsm7H7pWTbj6vOLVZcA6s=
+go.yaml.in/yaml/v4 v4.0.0-rc.2/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
@@ -514,6 +516,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
+gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117 h1:fbE/sTnBb9UNfE8cJsOzrYYPqVWVHb7jWH4SI1W//cM=
+gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117/go.mod h1:YuVT9NPq7t3oT2WpUemB0DbNL7djIjgajZycxoDLnqs=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
@@ -0,0 +1 @@
+CRUSH_ANTHROPIC_API_KEY=
@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"strings"
+ "sync"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -57,13 +58,12 @@ type Model struct {
}
type sessionAgent struct {
- largeModel Model
- smallModel Model
- systemPrompt string
- tools []ai.AgentTool
- maxOutputTokens int64
- sessions session.Service
- messages message.Service
+ largeModel Model
+ smallModel Model
+ systemPrompt string
+ tools []ai.AgentTool
+ sessions session.Service
+ messages message.Service
messageQueue *csync.Map[string, []SessionAgentCall]
activeRequests *csync.Map[string, context.CancelFunc]
@@ -71,8 +71,24 @@ type sessionAgent struct {
type SessionAgentOption func(*sessionAgent)
-func NewSessionAgent() SessionAgent {
- return &sessionAgent{}
+func NewSessionAgent(
+ largeModel Model,
+ smallModel Model,
+ systemPrompt string,
+ sessions session.Service,
+ messages message.Service,
+ tools ...ai.AgentTool,
+) SessionAgent {
+ return &sessionAgent{
+ largeModel: largeModel,
+ smallModel: smallModel,
+ systemPrompt: systemPrompt,
+ sessions: sessions,
+ messages: messages,
+ tools: tools,
+ messageQueue: csync.NewMap[string, []SessionAgentCall](),
+ activeRequests: csync.NewMap[string, context.CancelFunc](),
+ }
}
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
@@ -103,7 +119,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
a.largeModel.model,
ai.WithSystemPrompt(a.systemPrompt),
ai.WithTools(a.tools...),
- ai.WithMaxOutputTokens(a.maxOutputTokens),
)
currentSession, err := a.sessions.Get(ctx, call.SessionID)
@@ -116,9 +131,12 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
return nil, fmt.Errorf("failed to get session messages: %w", err)
}
+ var wg sync.WaitGroup
// Generate title if first message
if len(msgs) == 0 {
- go a.generateTitle(ctx, currentSession, call.Prompt)
+ wg.Go(func() {
+ a.generateTitle(ctx, currentSession, call.Prompt)
+ })
}
// Add the user message to the session
@@ -325,6 +343,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen
if err != nil {
return nil, err
}
+ wg.Wait()
queuedMessages, ok := a.messageQueue.Get(call.SessionID)
if !ok || len(queuedMessages) == 0 {
@@ -0,0 +1,85 @@
+package agent
+
+import (
+ "database/sql"
+ "net/http"
+ "os"
+ "testing"
+
+ "github.com/charmbracelet/catwalk/pkg/catwalk"
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/charmbracelet/fantasy/ai"
+ "github.com/charmbracelet/fantasy/anthropic"
+ "github.com/stretchr/testify/require"
+ "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
+
+ _ "github.com/joho/godotenv/autoload"
+)
+
+type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error)
+
+func TestSessionSimpleAgent(t *testing.T) {
+ r := newRecorder(t)
+ sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r)
+ require.Nil(t, err)
+ haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r)
+ require.Nil(t, err)
+ agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant")
+ session, err := sessions.Create(t.Context(), "New Session")
+ require.Nil(t, err)
+
+ res, err := agent.Run(t.Context(), SessionAgentCall{
+ Prompt: "Hello",
+ SessionID: session.ID,
+ MaxOutputTokens: 10000,
+ })
+
+ require.Nil(t, err)
+ require.NotNil(t, res)
+
+ t.Run("should create session messages", func(t *testing.T) {
+ msgs, err := messages.List(t.Context(), session.ID)
+ require.Nil(t, err)
+ // Should have the agent and user message
+ require.Equal(t, len(msgs), 2)
+ })
+}
+
+func anthropicBuilder(model string) builderFunc {
+ return func(r *recorder.Recorder) (ai.LanguageModel, error) {
+ provider := anthropic.New(
+ anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
+ anthropic.WithHTTPClient(&http.Client{Transport: r}),
+ )
+ return provider.LanguageModel(model)
+ }
+}
+
+func testDBConn(t *testing.T) (*sql.DB, error) {
+ return db.Connect(t.Context(), t.TempDir())
+}
+
+func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) {
+ conn, err := testDBConn(t)
+ require.Nil(t, err)
+ q := db.New(conn)
+ sessions := session.NewService(q)
+ messages := message.NewService(q)
+
+ largeModel := Model{
+ model: large,
+ config: catwalk.Model{
+ // todo: add values
+ },
+ }
+ smallModel := Model{
+ model: large,
+ config: catwalk.Model{
+ // todo: add values
+ },
+ }
+ agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...)
+ return agent, sessions, messages
+}
@@ -0,0 +1,104 @@
+package agent
+
+import (
+ "bytes"
+ "encoding/json"
+ "io"
+ "net/http"
+ "path/filepath"
+ "reflect"
+ "strings"
+ "testing"
+
+ "go.yaml.in/yaml/v4"
+ "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
+ "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
+)
+
+func newRecorder(t *testing.T) *recorder.Recorder {
+ cassetteName := filepath.Join("testdata", t.Name())
+
+ r, err := recorder.New(
+ cassetteName,
+ recorder.WithMode(recorder.ModeRecordOnce),
+ recorder.WithMatcher(customMatcher(t)),
+ recorder.WithMarshalFunc(marshalFunc),
+ recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster
+ recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook),
+ )
+ if err != nil {
+ t.Fatalf("recorder: failed to create recorder: %v", err)
+ }
+
+ t.Cleanup(func() {
+ if err := r.Stop(); err != nil {
+ t.Errorf("recorder: failed to stop recorder: %v", err)
+ }
+ })
+
+ return r
+}
+
+func customMatcher(t *testing.T) recorder.MatcherFunc {
+ return func(r *http.Request, i cassette.Request) bool {
+ if r.Body == nil || r.Body == http.NoBody {
+ return cassette.DefaultMatcher(r, i)
+ }
+ if r.Method != i.Method || r.URL.String() != i.URL {
+ return false
+ }
+
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("recorder: failed to read request body")
+ }
+ r.Body.Close()
+ r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
+
+ // Some providers can sometimes generate JSON requests with keys in
+ // a different order, which means a direct string comparison will fail.
+ // Falling back to deserializing the content if we don't have a match.
+ if string(reqBody) == i.Body { // hot path
+ return true
+ }
+ var content1, content2 any
+ if err := json.Unmarshal(reqBody, &content1); err != nil {
+ return false
+ }
+ if err := json.Unmarshal([]byte(i.Body), &content2); err != nil {
+ return false
+ }
+ return reflect.DeepEqual(content1, content2)
+ }
+}
+
+func marshalFunc(in any) ([]byte, error) {
+ var buff bytes.Buffer
+ enc := yaml.NewEncoder(&buff)
+ enc.SetIndent(2)
+ enc.CompactSeqIndent()
+ if err := enc.Encode(in); err != nil {
+ return nil, err
+ }
+ return buff.Bytes(), nil
+}
+
+var headersToKeep = map[string]struct{}{
+ "accept": {},
+ "content-type": {},
+ "user-agent": {},
+}
+
+func hookRemoveHeaders(i *cassette.Interaction) error {
+ for k := range i.Request.Headers {
+ if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
+ delete(i.Request.Headers, k)
+ }
+ }
+ for k := range i.Response.Headers {
+ if _, ok := headersToKeep[strings.ToLower(k)]; !ok {
+ delete(i.Response.Headers, k)
+ }
+ }
+ return nil
+}
@@ -0,0 +1,100 @@
+---
+version: 2
+interactions:
+- id: 0
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 281
+ host: ""
+ body: '{"max_tokens":10000,"messages":[{"content":[{"text":"Hello","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"You are a helpful assistant","cache_control":{"type":"ephemeral"},"type":"text"}],"stream":true}'
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - Anthropic/Go 1.12.0
+ url: https://api.anthropic.com/v1/messages
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
+ body: |+
+ event: message_start
+ data: {"type":"message_start","message":{"id":"msg_01WGBTmd2Q5E2ajXUoHZYg6K","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":8,"service_tier":"standard"}} }
+
+ event: content_block_start
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello! How can I help you today"}}
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"?"} }
+
+ event: content_block_stop
+ data: {"type":"content_block_stop","index":0 }
+
+ event: message_delta
+ data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12} }
+
+ event: message_stop
+ data: {"type":"message_stop" }
+
+ headers:
+ Content-Type:
+ - text/event-stream; charset=utf-8
+ status: 200 OK
+ code: 200
+ duration: 1.891647s
+- id: 1
+ request:
+ proto: HTTP/1.1
+ proto_major: 1
+ proto_minor: 1
+ content_length: 621
+ host: ""
+ body: '{"max_tokens":40,"messages":[{"content":[{"text":"Generate a concise title for the following content:\n\nHello","type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"you will generate a short title based on the first message a user begins a conversation with\n\n- ensure it is not more than 50 characters long\n- the title should be a summary of the user''s message\n- it should be one line long\n- do not use quotes or colons\n- the entire text you return will be used as the title\n- never return anything that is more than one sentence (one line) long\n","type":"text"}],"stream":true}'
+ headers:
+ Accept:
+ - application/json
+ Content-Type:
+ - application/json
+ User-Agent:
+ - Anthropic/Go 1.12.0
+ url: https://api.anthropic.com/v1/messages
+ method: POST
+ response:
+ proto: HTTP/2.0
+ proto_major: 2
+ proto_minor: 0
+ content_length: -1
+ body: |+
+ event: message_start
+ data: {"type":"message_start","message":{"id":"msg_01CZb9drep7yMKkc2wzNDt5G","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":10,"service_tier":"standard"}} }
+
+ event: content_block_start
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }
+
+ event: content_block_delta
+ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Greeting or Starting a Conversation"} }
+
+ event: content_block_stop
+ data: {"type":"content_block_stop","index":0}
+
+ event: message_delta
+ data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":10} }
+
+ event: message_stop
+ data: {"type":"message_stop" }
+
+ headers:
+ Content-Type:
+ - text/event-stream; charset=utf-8
+ status: 200 OK
+ code: 200
+ duration: 2.547489s