From cf42f88da036a51c8194e7ca05b6383523f9d92e Mon Sep 17 00:00:00 2001 From: kujtimiihoxha Date: Tue, 30 Sep 2025 13:25:06 +0200 Subject: [PATCH] test: initial test setup --- go.mod | 2 + go.sum | 4 + internal/agent/.env.sample | 1 + internal/agent/agent.go | 41 +++++-- internal/agent/agent_test.go | 85 ++++++++++++++ internal/agent/recorder_test.go | 104 ++++++++++++++++++ .../testdata/TestSessionSimpleAgent.yaml | 100 +++++++++++++++++ 7 files changed, 326 insertions(+), 11 deletions(-) create mode 100644 internal/agent/.env.sample create mode 100644 internal/agent/agent_test.go create mode 100644 internal/agent/recorder_test.go create mode 100644 internal/agent/testdata/TestSessionSimpleAgent.yaml diff --git a/go.mod b/go.mod index 37a5fb425c9a1927499e4224455ced273001507b..68b692b0317993f5a7b8cd8b5bcdf38ea2cf74c9 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,9 @@ require ( github.com/stretchr/testify v1.11.1 github.com/tidwall/sjson v1.2.5 github.com/zeebo/xxh3 v1.0.2 + go.yaml.in/yaml/v4 v4.0.0-rc.2 google.golang.org/genai v1.26.0 + gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117 gopkg.in/natefinch/lumberjack.v2 v2.2.1 mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5 ) diff --git a/go.sum b/go.sum index 181eb54b51bac24142061ee6ef40b4a3cc79d5fe..fb765cf868ad3bea981779b752db79d7599d9d65 100644 --- a/go.sum +++ b/go.sum @@ -367,6 +367,8 @@ go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mx go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.yaml.in/yaml/v4 v4.0.0-rc.2 h1:/FrI8D64VSr4HtGIlUtlFMGsm7H7pWTbj6vOLVZcA6s= +go.yaml.in/yaml/v4 v4.0.0-rc.2/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -514,6 +516,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117 h1:fbE/sTnBb9UNfE8cJsOzrYYPqVWVHb7jWH4SI1W//cM= +gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20250923044825-7b4892dd3117/go.mod h1:YuVT9NPq7t3oT2WpUemB0DbNL7djIjgajZycxoDLnqs= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= diff --git a/internal/agent/.env.sample b/internal/agent/.env.sample new file mode 100644 index 0000000000000000000000000000000000000000..c5260ede4b820f12690c611514ba847392c29ae8 --- /dev/null +++ b/internal/agent/.env.sample @@ -0,0 +1 @@ +CRUSH_ANTHROPIC_API_KEY= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ffb1d1ffde565a9f19574461f69c66633e84749a..2da0c3d65e9e31936175f13ec53a502172ee16e1 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "strings" + "sync" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" @@ -57,13 +58,12 @@ type Model struct { } type sessionAgent struct { - largeModel Model - smallModel Model - systemPrompt string - tools []ai.AgentTool - maxOutputTokens int64 - sessions session.Service - messages message.Service + largeModel Model + smallModel Model + systemPrompt string + tools []ai.AgentTool + sessions session.Service + messages message.Service messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] @@ -71,8 +71,24 @@ type sessionAgent struct { type SessionAgentOption func(*sessionAgent) -func NewSessionAgent() SessionAgent { - return &sessionAgent{} +func NewSessionAgent( + largeModel Model, + smallModel Model, + systemPrompt string, + sessions session.Service, + messages message.Service, + tools ...ai.AgentTool, +) SessionAgent { + return &sessionAgent{ + largeModel: largeModel, + smallModel: smallModel, + systemPrompt: systemPrompt, + sessions: sessions, + messages: messages, + tools: tools, + messageQueue: csync.NewMap[string, []SessionAgentCall](), + activeRequests: csync.NewMap[string, context.CancelFunc](), + } } func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) { @@ -103,7 +119,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen a.largeModel.model, ai.WithSystemPrompt(a.systemPrompt), ai.WithTools(a.tools...), - ai.WithMaxOutputTokens(a.maxOutputTokens), ) currentSession, err := a.sessions.Get(ctx, call.SessionID) @@ -116,9 +131,12 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen return nil, fmt.Errorf("failed to get session messages: %w", err) } + var wg sync.WaitGroup // Generate title if first message if len(msgs) == 0 { - go a.generateTitle(ctx, currentSession, call.Prompt) + wg.Go(func() { + a.generateTitle(ctx, currentSession, call.Prompt) + }) } // Add the user message to the session @@ -325,6 +343,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen if err != nil { return nil, err } + wg.Wait() queuedMessages, ok := a.messageQueue.Get(call.SessionID) if !ok || len(queuedMessages) == 0 { diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3980a965321c7615594a4d413e0d0f1c39b256cb --- /dev/null +++ b/internal/agent/agent_test.go @@ -0,0 +1,85 @@ +package agent + +import ( + "database/sql" + "net/http" + "os" + "testing" + + "github.com/charmbracelet/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/session" + "github.com/charmbracelet/fantasy/ai" + "github.com/charmbracelet/fantasy/anthropic" + "github.com/stretchr/testify/require" + "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" + + _ "github.com/joho/godotenv/autoload" +) + +type builderFunc func(r *recorder.Recorder) (ai.LanguageModel, error) + +func TestSessionSimpleAgent(t *testing.T) { + r := newRecorder(t) + sonnet, err := anthropicBuilder("claude-sonnet-4-5-20250929")(r) + require.Nil(t, err) + haiku, err := anthropicBuilder("claude-3-5-haiku-20241022")(r) + require.Nil(t, err) + agent, sessions, messages := testSessionAgent(t, sonnet, haiku, "You are a helpful assistant") + session, err := sessions.Create(t.Context(), "New Session") + require.Nil(t, err) + + res, err := agent.Run(t.Context(), SessionAgentCall{ + Prompt: "Hello", + SessionID: session.ID, + MaxOutputTokens: 10000, + }) + + require.Nil(t, err) + require.NotNil(t, res) + + t.Run("should create session messages", func(t *testing.T) { + msgs, err := messages.List(t.Context(), session.ID) + require.Nil(t, err) + // Should have the agent and user message + require.Equal(t, len(msgs), 2) + }) +} + +func anthropicBuilder(model string) builderFunc { + return func(r *recorder.Recorder) (ai.LanguageModel, error) { + provider := anthropic.New( + anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")), + anthropic.WithHTTPClient(&http.Client{Transport: r}), + ) + return provider.LanguageModel(model) + } +} + +func testDBConn(t *testing.T) (*sql.DB, error) { + return db.Connect(t.Context(), t.TempDir()) +} + +func testSessionAgent(t *testing.T, large, small ai.LanguageModel, systemPrompt string, tools ...ai.AgentTool) (SessionAgent, session.Service, message.Service) { + conn, err := testDBConn(t) + require.Nil(t, err) + q := db.New(conn) + sessions := session.NewService(q) + messages := message.NewService(q) + + largeModel := Model{ + model: large, + config: catwalk.Model{ + // todo: add values + }, + } + smallModel := Model{ + model: large, + config: catwalk.Model{ + // todo: add values + }, + } + agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessions, messages, tools...) + return agent, sessions, messages +} diff --git a/internal/agent/recorder_test.go b/internal/agent/recorder_test.go new file mode 100644 index 0000000000000000000000000000000000000000..34d8bae6b4abbd6cdc89626c226b3b29f93635a4 --- /dev/null +++ b/internal/agent/recorder_test.go @@ -0,0 +1,104 @@ +package agent + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "path/filepath" + "reflect" + "strings" + "testing" + + "go.yaml.in/yaml/v4" + "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette" + "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" +) + +func newRecorder(t *testing.T) *recorder.Recorder { + cassetteName := filepath.Join("testdata", t.Name()) + + r, err := recorder.New( + cassetteName, + recorder.WithMode(recorder.ModeRecordOnce), + recorder.WithMatcher(customMatcher(t)), + recorder.WithMarshalFunc(marshalFunc), + recorder.WithSkipRequestLatency(true), // disable sleep to simulate response time, makes tests faster + recorder.WithHook(hookRemoveHeaders, recorder.AfterCaptureHook), + ) + if err != nil { + t.Fatalf("recorder: failed to create recorder: %v", err) + } + + t.Cleanup(func() { + if err := r.Stop(); err != nil { + t.Errorf("recorder: failed to stop recorder: %v", err) + } + }) + + return r +} + +func customMatcher(t *testing.T) recorder.MatcherFunc { + return func(r *http.Request, i cassette.Request) bool { + if r.Body == nil || r.Body == http.NoBody { + return cassette.DefaultMatcher(r, i) + } + if r.Method != i.Method || r.URL.String() != i.URL { + return false + } + + reqBody, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("recorder: failed to read request body") + } + r.Body.Close() + r.Body = io.NopCloser(bytes.NewBuffer(reqBody)) + + // Some providers can sometimes generate JSON requests with keys in + // a different order, which means a direct string comparison will fail. + // Falling back to deserializing the content if we don't have a match. + if string(reqBody) == i.Body { // hot path + return true + } + var content1, content2 any + if err := json.Unmarshal(reqBody, &content1); err != nil { + return false + } + if err := json.Unmarshal([]byte(i.Body), &content2); err != nil { + return false + } + return reflect.DeepEqual(content1, content2) + } +} + +func marshalFunc(in any) ([]byte, error) { + var buff bytes.Buffer + enc := yaml.NewEncoder(&buff) + enc.SetIndent(2) + enc.CompactSeqIndent() + if err := enc.Encode(in); err != nil { + return nil, err + } + return buff.Bytes(), nil +} + +var headersToKeep = map[string]struct{}{ + "accept": {}, + "content-type": {}, + "user-agent": {}, +} + +func hookRemoveHeaders(i *cassette.Interaction) error { + for k := range i.Request.Headers { + if _, ok := headersToKeep[strings.ToLower(k)]; !ok { + delete(i.Request.Headers, k) + } + } + for k := range i.Response.Headers { + if _, ok := headersToKeep[strings.ToLower(k)]; !ok { + delete(i.Response.Headers, k) + } + } + return nil +} diff --git a/internal/agent/testdata/TestSessionSimpleAgent.yaml b/internal/agent/testdata/TestSessionSimpleAgent.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91a5f1e0232dcd5dcbddfd571dc2f908128fd275 --- /dev/null +++ b/internal/agent/testdata/TestSessionSimpleAgent.yaml @@ -0,0 +1,100 @@ +--- +version: 2 +interactions: +- id: 0 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 281 + host: "" + body: '{"max_tokens":10000,"messages":[{"content":[{"text":"Hello","cache_control":{"type":"ephemeral"},"type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"You are a helpful assistant","cache_control":{"type":"ephemeral"},"type":"text"}],"stream":true}' + headers: + Accept: + - application/json + Content-Type: + - application/json + User-Agent: + - Anthropic/Go 1.12.0 + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"id":"msg_01WGBTmd2Q5E2ajXUoHZYg6K","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":8,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello! How can I help you today"}} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"?"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":13,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12} } + + event: message_stop + data: {"type":"message_stop" } + + headers: + Content-Type: + - text/event-stream; charset=utf-8 + status: 200 OK + code: 200 + duration: 1.891647s +- id: 1 + request: + proto: HTTP/1.1 + proto_major: 1 + proto_minor: 1 + content_length: 621 + host: "" + body: '{"max_tokens":40,"messages":[{"content":[{"text":"Generate a concise title for the following content:\n\nHello","type":"text"}],"role":"user"}],"model":"claude-sonnet-4-5-20250929","system":[{"text":"you will generate a short title based on the first message a user begins a conversation with\n\n- ensure it is not more than 50 characters long\n- the title should be a summary of the user''s message\n- it should be one line long\n- do not use quotes or colons\n- the entire text you return will be used as the title\n- never return anything that is more than one sentence (one line) long\n","type":"text"}],"stream":true}' + headers: + Accept: + - application/json + Content-Type: + - application/json + User-Agent: + - Anthropic/Go 1.12.0 + url: https://api.anthropic.com/v1/messages + method: POST + response: + proto: HTTP/2.0 + proto_major: 2 + proto_minor: 0 + content_length: -1 + body: |+ + event: message_start + data: {"type":"message_start","message":{"id":"msg_01CZb9drep7yMKkc2wzNDt5G","type":"message","role":"assistant","model":"claude-sonnet-4-5-20250929","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":10,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Greeting or Starting a Conversation"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0} + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":109,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":10} } + + event: message_stop + data: {"type":"message_stop" } + + headers: + Content-Type: + - text/event-stream; charset=utf-8 + status: 200 OK + code: 200 + duration: 2.547489s