From 6938dedd6cde0378f428b2a90d1591f69dbce6eb Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 12 May 2026 10:29:19 -0400 Subject: [PATCH] perf: batch streaming message updates Group rapid streaming updates into one save and one notification per short window instead of one per token. Important updates like finishes, tool calls, and errors still go through immediately. Cuts database writes and UI redraws by orders of magnitude during long responses. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 14 + internal/app/app.go | 16 +- internal/backend/session.go | 6 + internal/message/message.go | 341 +++++++++++++- internal/message/message_test.go | 695 ++++++++++++++++++++++++++++ internal/pubsub/broker.go | 143 +++++- internal/workspace/app_workspace.go | 6 + 7 files changed, 1189 insertions(+), 32 deletions(-) create mode 100644 internal/message/message_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 66062270c669ffaba298980ff997c6e2c2e04c2e..ba1215e9a22ca425f745e6e00c85908737d50a0e 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -245,6 +245,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) + // Drain any debounced message updates before returning. message.Service + // already flushes synchronously on terminal updates, but a defer here + // guarantees the contract at every Run exit (success, error, panic + // recovery upstream) without callers needing to know. + defer func() { + if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + slog.Error("Failed to flush pending message updates after run", "error", flushErr) + } + }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -653,6 +662,11 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan a.activeRequests.Set(sessionID, cancel) defer a.activeRequests.Del(sessionID) defer cancel() + defer func() { + if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + slog.Error("Failed to flush pending message updates after summarize", "error", flushErr) + } + }() agent := fantasy.NewAgent(largeModel.Model, fantasy.WithSystemPrompt(string(summaryPrompt)), diff --git a/internal/app/app.go b/internal/app/app.go index a167ca8638c8497a6d6f4260782ba334c6dbe0c3..a5414246b7819cbafccc452c3b79580564c9b68c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -582,13 +582,23 @@ func (app *App) Shutdown() { app.AgentCoordinator.CancelAll() } - // Now run remaining cleanup tasks in parallel. - var wg sync.WaitGroup - // Shared shutdown context for all timeout-bounded cleanup. shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + // Drain any debounced message updates before the DB-close cleanup + // runs in the parallel block below. message.Service buffers + // streaming deltas (see internal/message/message.go) and we must + // land them while the connection is still open. + if app.Messages != nil { + if err := app.Messages.FlushAll(shutdownCtx); err != nil { + slog.Error("Failed to flush pending message updates on shutdown", "error", err) + } + } + + // Now run remaining cleanup tasks in parallel. + var wg sync.WaitGroup + // Send exit event wg.Go(func() { event.AppExited() diff --git a/internal/backend/session.go b/internal/backend/session.go index 10e21ed8932ccbc990a525785166517cd231595c..9542282319f766e384d9ffe7942e9c235ae63e29 100644 --- a/internal/backend/session.go +++ b/internal/backend/session.go @@ -72,6 +72,12 @@ func (b *Backend) ListSessionMessages(ctx context.Context, workspaceID, sessionI return nil, err } + // Drain debounced updates so HTTP clients (and the TUI on session + // switch) observe the latest in-memory state rather than racing the + // debounce timer in message.Service. + if err := ws.Messages.FlushAll(ctx); err != nil { + return nil, err + } return ws.Messages.List(ctx, sessionID) } diff --git a/internal/message/message.go b/internal/message/message.go index 6da8827b72227602dc36c39b6a2254aba18d2b0d..9e8d258c543d9f9aed31518b848439456b5e4336 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sync" "time" "github.com/charmbracelet/crush/internal/db" @@ -12,6 +13,13 @@ import ( "github.com/google/uuid" ) +// defaultUpdateDebounce is the default debounce window for [Service.Update]. +// Streaming deltas that arrive within the window are coalesced into a +// single SQL write and a single pubsub event. Terminal updates +// (finish/error/cancel/tool-call structural changes) bypass the +// debounce and flush synchronously. +const defaultUpdateDebounce = 33 * time.Millisecond + type CreateMessageParams struct { Role MessageRole Parts []ContentPart @@ -20,6 +28,21 @@ type CreateMessageParams struct { IsSummaryMessage bool } +// Service is the public interface to the message store. +// +// [Service.Update] is eventually consistent: it accepts new state into +// an in-memory buffer and writes it to SQLite plus publishes a +// [pubsub.UpdatedEvent] on the next debounce tick (default +// [defaultUpdateDebounce]) or on the next terminal-state update, +// whichever comes first. Terminal-state updates — those that finish +// the message, add or finish a tool call, or end a reasoning section — +// flush synchronously before [Service.Update] returns. +// +// Callers that need stronger ordering (e.g. tests, shutdown, +// session-switch reads) must use [Service.Flush] or [Service.FlushAll] +// before reading via [Service.Get] / [Service.List]. Without an +// explicit flush, a read can race the debounce timer and miss the +// most recent in-memory state. type Service interface { pubsub.Subscriber[Message] Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) @@ -30,20 +53,87 @@ type Service interface { ListAllUserMessages(ctx context.Context) ([]Message, error) Delete(ctx context.Context, id string) error DeleteSessionMessages(ctx context.Context, sessionID string) error + + // Flush synchronously drains any pending debounced state for the + // given message ID, performs the SQL write, and publishes the + // resulting [pubsub.UpdatedEvent]. Idempotent; cheap no-op if no + // updates are pending. Use this before any read that must observe + // the latest [Service.Update]. + Flush(ctx context.Context, id string) error + + // FlushAll synchronously drains pending debounced state for every + // message known to the service. Intended for shutdown and + // session-switch paths. + FlushAll(ctx context.Context) error +} + +// pendingState holds the in-memory coalescing buffer for a single +// message ID. All fields except where noted are guarded by +// service.mu. The flushing flag serializes concurrent flushers for +// the same ID so SQL writes never reorder. +type pendingState struct { + // latest is the most recent [Message] passed to [Service.Update] + // that has not yet been flushed. + latest Message + + // dirty is true when latest contains state that has not been + // written to SQL since the last successful flush. + dirty bool + + // flushing is true while a goroutine is performing the SQL write + // for this ID. New updates are still accepted (and re-mark dirty) + // but other flushers must back off. + flushing bool + + // timer is the active debounce timer, or nil if no flush is + // scheduled. Stopped and reset when a terminal update preempts + // the debounce window. + timer *time.Timer + + // lastFlushed is the snapshot most recently written to SQL. Used + // as the baseline for terminal-state detection. + lastFlushed Message + + // hasFlushed is false until the first successful write for this + // ID; until then lastFlushed is the zero value and must not be + // treated as a real prior state. + hasFlushed bool } type service struct { *pubsub.Broker[Message] - q db.Querier + q db.Querier + debounce time.Duration + + mu sync.Mutex + pending map[string]*pendingState } -func NewService(q db.Querier) Service { - return &service{ - Broker: pubsub.NewBroker[Message](), - q: q, +// ServiceOption configures a [Service] at construction. +type ServiceOption func(*service) + +// WithDebounce overrides the debounce window for [Service.Update]. A +// zero or negative value disables debouncing entirely (every update +// flushes synchronously). Intended primarily for tests. +func WithDebounce(d time.Duration) ServiceOption { + return func(s *service) { + s.debounce = d } } +func NewService(q db.Querier, opts ...ServiceOption) Service { + s := &service{ + Broker: pubsub.NewBroker[Message](), + q: q, + debounce: defaultUpdateDebounce, + pending: make(map[string]*pendingState), + } + for _, opt := range opts { + opt(s) + } + return s +} + func (s *service) Delete(ctx context.Context, id string) error { message, err := s.Get(ctx, id) if err != nil { @@ -53,6 +143,16 @@ func (s *service) Delete(ctx context.Context, id string) error { if err != nil { return err } + // Drop any pending coalesced state for this ID. We never want to + // flush back over a deleted row. + s.mu.Lock() + if p, ok := s.pending[id]; ok { + if p.timer != nil { + p.timer.Stop() + } + delete(s.pending, id) + } + s.mu.Unlock() // Clone the message before publishing to avoid race conditions with // concurrent modifications to the Parts slice. s.Publish(pubsub.DeletedEvent, message.Clone()) @@ -111,31 +211,240 @@ func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) e return nil } -func (s *service) Update(ctx context.Context, message Message) error { - parts, err := marshalParts(message.Parts) +// Update accepts a new state for a message and either flushes +// synchronously (terminal updates, debounce <= 0) or buffers it until +// the next debounce tick. See [Service] for the contract. +func (s *service) Update(ctx context.Context, msg Message) error { + cloned := msg.Clone() + + // Zero or negative debounce: flush every update synchronously. This + // preserves the pre-coalescing behaviour for tests and any caller + // that explicitly opted out via [WithDebounce]. + if s.debounce <= 0 { + s.mu.Lock() + p, ok := s.pending[msg.ID] + if !ok { + p = &pendingState{} + s.pending[msg.ID] = p + } + p.latest = cloned + p.dirty = true + s.mu.Unlock() + return s.flushOne(ctx, msg.ID, true) + } + + s.mu.Lock() + p, ok := s.pending[msg.ID] + if !ok { + p = &pendingState{} + s.pending[msg.ID] = p + } + p.latest = cloned + p.dirty = true + + var prev *Message + if p.hasFlushed { + prev = &p.lastFlushed + } + terminal := shouldFlushNow(prev, &cloned) + + if terminal { + if p.timer != nil { + p.timer.Stop() + p.timer = nil + } + s.mu.Unlock() + return s.flushOne(ctx, msg.ID, true) + } + + // Debounce: schedule a single flush per pending state. If a flush + // is already running we let it finish; the trailing dirty bit will + // be picked up by the next Update or by Flush. + if p.timer == nil && !p.flushing { + id := msg.ID + p.timer = time.AfterFunc(s.debounce, func() { + // Detached from caller ctx so a cancelled stream context + // does not strand the buffered write. + _ = s.flushOne(context.Background(), id, false) + }) + } + s.mu.Unlock() + return nil +} + +// Flush implements [Service.Flush]. +func (s *service) Flush(ctx context.Context, id string) error { + return s.flushOne(ctx, id, true) +} + +// FlushAll implements [Service.FlushAll]. It snapshots every ID with +// outstanding work — either dirty buffered state or a flush already in +// flight — then drains each one. Picking up in-flight IDs ensures +// FlushAll cannot return while a timer-fired write is still mid-SQL, +// which is what shutdown and session-switch callers rely on. +func (s *service) FlushAll(ctx context.Context) error { + s.mu.Lock() + ids := make([]string, 0, len(s.pending)) + for id, p := range s.pending { + if p.dirty || p.flushing { + ids = append(ids, id) + } + } + s.mu.Unlock() + var firstErr error + for _, id := range ids { + if err := s.flushOne(ctx, id, true); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// flushOne drains a single message ID. When syncCaller is true the +// caller is willing to wait through a concurrent in-flight flush so +// that, on return, lastFlushed equals latest at the moment of return. +// When false (timer-fired path) we bail if another flusher is already +// running; that flusher will pick up the trailing dirty bit. +// +// Order matters: a sync caller must wait for any in-flight flush to +// drain even when the buffer is currently clean — that in-flight +// write has not yet updated the SQL row, so returning early would +// violate the contract that on success lastFlushed reflects the most +// recent state. +func (s *service) flushOne(ctx context.Context, id string, syncCaller bool) error { + for { + s.mu.Lock() + p, ok := s.pending[id] + if !ok { + s.mu.Unlock() + return nil + } + if p.flushing { + if !syncCaller { + s.mu.Unlock() + return nil + } + s.mu.Unlock() + // Brief yield; in-flight write should land in <1ms typical. + time.Sleep(time.Millisecond) + continue + } + if !p.dirty { + s.mu.Unlock() + return nil + } + + if p.timer != nil { + p.timer.Stop() + p.timer = nil + } + snap := p.latest + // Decide whether this snapshot represents a terminal event + // against the prior baseline. We must do this before resetting + // dirty/flushing because shouldFlushNow looks at p.lastFlushed + // (which is what was on disk before this write). + var prev *Message + if p.hasFlushed { + prev = &p.lastFlushed + } + isTerminal := shouldFlushNow(prev, &snap) + p.flushing = true + p.dirty = false + s.mu.Unlock() + + err := s.write(ctx, snap) + + s.mu.Lock() + p.flushing = false + if err == nil { + p.lastFlushed = snap + p.hasFlushed = true + } else { + // Restore dirty so the next caller retries. + p.dirty = true + } + // If a delta arrived during the SQL write and we are a sync + // caller, the user expects that delta to land too. + wasDirty := p.dirty + s.mu.Unlock() + + if err != nil { + return err + } + + // Terminal events — message finished, tool call added or + // finished, reasoning ended — use the bounded must-deliver + // path so they never get dropped under channel contention. + if isTerminal { + s.PublishMustDeliver(ctx, pubsub.UpdatedEvent, snap) + } else { + s.Publish(pubsub.UpdatedEvent, snap) + } + + if wasDirty && syncCaller { + continue + } + return nil + } +} + +// write performs the unguarded SQL write + UpdatedAt stamp. Caller +// owns publishing. +func (s *service) write(ctx context.Context, msg Message) error { + parts, err := marshalParts(msg.Parts) if err != nil { return err } finishedAt := sql.NullInt64{} - if f := message.FinishPart(); f != nil { + if f := msg.FinishPart(); f != nil { finishedAt.Int64 = f.Time finishedAt.Valid = true } - err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{ - ID: message.ID, + if err := s.q.UpdateMessage(ctx, db.UpdateMessageParams{ + ID: msg.ID, Parts: string(parts), FinishedAt: finishedAt, - }) - if err != nil { + }); err != nil { return err } - message.UpdatedAt = time.Now().Unix() - // Clone the message before publishing to avoid race conditions with - // concurrent modifications to the Parts slice. - s.Publish(pubsub.UpdatedEvent, message.Clone()) return nil } +// shouldFlushNow returns true when next represents a structural +// change that must not be silently coalesced: the message just +// finished, the tool-call set grew, a tool call transitioned to +// finished, or reasoning just finished. prev is the last-flushed +// snapshot (or nil if no write has landed yet). +func shouldFlushNow(prev, next *Message) bool { + if next.IsFinished() { + return true + } + + var prevCalls []ToolCall + var prevReasoningFinishedAt int64 + if prev != nil { + prevCalls = prev.ToolCalls() + prevReasoningFinishedAt = prev.ReasoningContent().FinishedAt + } + nextCalls := next.ToolCalls() + if len(nextCalls) != len(prevCalls) { + return true + } + for i := range nextCalls { + // Bounds-safe: lengths are equal here. + if nextCalls[i].Finished != prevCalls[i].Finished { + return true + } + // A tool call's input only matters once it has landed (Finished + // flips true). Earlier deltas to Input are debounced with the + // rest of the streaming state. + } + if next.ReasoningContent().FinishedAt > 0 && prevReasoningFinishedAt == 0 { + return true + } + return false +} + func (s *service) Get(ctx context.Context, id string) (Message, error) { dbMessage, err := s.q.GetMessage(ctx, id) if err != nil { diff --git a/internal/message/message_test.go b/internal/message/message_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0a046337542606ae10aad16ef8cb36d1eb879c03 --- /dev/null +++ b/internal/message/message_test.go @@ -0,0 +1,695 @@ +package message + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +// slowUpdateQuerier wraps a [db.Querier] and forces UpdateMessage to +// hang on a release channel. Used to simulate an in-flight SQL write. +type slowUpdateQuerier struct { + db.Querier + release chan struct{} + started chan struct{} + startOnce sync.Once +} + +func (s *slowUpdateQuerier) UpdateMessage(ctx context.Context, arg db.UpdateMessageParams) error { + s.startOnce.Do(func() { close(s.started) }) + select { + case <-s.release: + case <-ctx.Done(): + return ctx.Err() + } + return s.Querier.UpdateMessage(ctx, arg) +} + +// newTestService spins up a fresh in-memory message.Service backed by a +// temporary on-disk SQLite database. Returns the service plus a session +// ID to attach messages to. +func newTestService(t *testing.T, opts ...ServiceOption) (Service, string) { + t.Helper() + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + sess, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + + svc := NewService(q, opts...) + return svc, sess.ID +} + +// eventCollector consumes broker events into a slice in a goroutine +// and exposes thread-safe Snapshot / Reset helpers for assertions. +type eventCollector struct { + mu sync.Mutex + events []pubsub.Event[Message] +} + +func collect(ctx context.Context, sub <-chan pubsub.Event[Message]) *eventCollector { + c := &eventCollector{} + go func() { + for { + select { + case <-ctx.Done(): + return + case ev, ok := <-sub: + if !ok { + return + } + c.mu.Lock() + c.events = append(c.events, ev) + c.mu.Unlock() + } + } + }() + return c +} + +func (c *eventCollector) snapshot() []pubsub.Event[Message] { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]pubsub.Event[Message], len(c.events)) + copy(out, c.events) + return out +} + +func (c *eventCollector) reset() { + c.mu.Lock() + defer c.mu.Unlock() + c.events = nil +} + +func TestUpdate_DebouncesTextDeltas(t *testing.T) { + t.Parallel() + + // Long-enough debounce that we can verify nothing flushes prematurely. + svc, sessionID := newTestService(t, WithDebounce(50*time.Millisecond)) + + subCtx, cancelSub := context.WithCancel(t.Context()) + defer cancelSub() + sub := svc.Subscribe(subCtx) + collector := collect(subCtx, sub) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{ + Role: Assistant, + }) + require.NoError(t, err) + // Drop the CreatedEvent emitted by Create. + time.Sleep(5 * time.Millisecond) + collector.reset() + + // Push 5 deltas inside a single debounce window. + for i := 0; i < 5; i++ { + msg.AppendContent("a") + require.NoError(t, svc.Update(t.Context(), msg)) + } + + // Before the debounce expires no UpdatedEvent should have landed. + time.Sleep(10 * time.Millisecond) + require.Empty(t, collector.snapshot(), "no events should land before debounce window expires") + + // Wait for the debounce timer to fire. + require.Eventually(t, func() bool { + return len(collector.snapshot()) >= 1 + }, time.Second, 5*time.Millisecond) + events := collector.snapshot() + require.Len(t, events, 1, "5 deltas should coalesce into 1 UpdatedEvent") + require.Equal(t, pubsub.UpdatedEvent, events[0].Type) + require.Equal(t, "aaaaa", events[0].Payload.Content().Text) + + // Final state must be persisted. + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Equal(t, "aaaaa", got.Content().Text) +} + +func TestUpdate_TerminalUpdatesFlushSynchronously(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + subCtx, cancelSub := context.WithCancel(t.Context()) + defer cancelSub() + sub := svc.Subscribe(subCtx) + collector := collect(subCtx, sub) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + time.Sleep(5 * time.Millisecond) + collector.reset() + + // AddFinish makes the message terminal; Update must flush + // synchronously even with a 1-hour debounce. + msg.AppendContent("done") + msg.AddFinish(FinishReasonEndTurn, "", "") + require.NoError(t, svc.Update(t.Context(), msg)) + + require.Eventually(t, func() bool { + return len(collector.snapshot()) >= 1 + }, time.Second, 5*time.Millisecond, + "terminal update must publish without waiting for debounce") + events := collector.snapshot() + require.Len(t, events, 1) + require.True(t, events[0].Payload.IsFinished()) + + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.True(t, got.IsFinished()) +} + +func TestUpdate_ToolCallStructuralChangeFlushes(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + + // Adding a new tool call is a structural change → sync flush. + msg.AddToolCall(ToolCall{ID: "tc1", Name: "view", Finished: false}) + require.NoError(t, svc.Update(t.Context(), msg)) + + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Len(t, got.ToolCalls(), 1) + require.Equal(t, "tc1", got.ToolCalls()[0].ID) + + // Marking the tool call finished is also a structural change. + msg.AddToolCall(ToolCall{ID: "tc1", Name: "view", Input: "{}", Finished: true}) + require.NoError(t, svc.Update(t.Context(), msg)) + + got, err = svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.True(t, got.ToolCalls()[0].Finished) +} + +func TestUpdate_ReasoningEndFlushes(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + + // Reasoning deltas alone debounce. + msg.AppendReasoningContent("hmm") + require.NoError(t, svc.Update(t.Context(), msg)) + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Empty(t, got.ReasoningContent().Thinking, "reasoning delta should still be in the debounce buffer") + + // FinishThinking sets FinishedAt → terminal flush. + msg.FinishThinking() + require.NoError(t, svc.Update(t.Context(), msg)) + + got, err = svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Equal(t, "hmm", got.ReasoningContent().Thinking) + require.NotZero(t, got.ReasoningContent().FinishedAt) +} + +func TestFlush_DrainsPendingDebouncedUpdates(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendContent("buffered") + require.NoError(t, svc.Update(t.Context(), msg)) + + // Without a flush the SQL row is unchanged from Create. + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Empty(t, got.Content().Text) + + require.NoError(t, svc.Flush(t.Context(), msg.ID)) + + got, err = svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Equal(t, "buffered", got.Content().Text) + + // Subsequent Flush is a no-op. + require.NoError(t, svc.Flush(t.Context(), msg.ID)) +} + +func TestFlushAll_DrainsAllPending(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + const n = 5 + msgs := make([]Message, n) + for i := range msgs { + m, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + m.AppendContent("hi") + require.NoError(t, svc.Update(t.Context(), m)) + msgs[i] = m + } + + require.NoError(t, svc.FlushAll(t.Context())) + + for _, m := range msgs { + got, err := svc.Get(t.Context(), m.ID) + require.NoError(t, err) + require.Equal(t, "hi", got.Content().Text, "FlushAll should drain every pending message") + } +} + +func TestUpdate_OrderingMatchesNonCoalesced(t *testing.T) { + t.Parallel() + + // Compare the final state after coalesced vs zero-debounce updates. + // A sequence of interleaved text/reasoning/tool-call updates must + // converge to the same final DB row either way. + build := func(svc Service, sessionID string) Message { + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendReasoningContent("thinking 1 ") + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AppendReasoningContent("thinking 2") + require.NoError(t, svc.Update(t.Context(), msg)) + msg.FinishThinking() + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AppendContent("hello ") + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AppendContent("world") + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AddToolCall(ToolCall{ID: "tc", Name: "x", Finished: false}) + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AddToolCall(ToolCall{ID: "tc", Name: "x", Input: "{}", Finished: true}) + require.NoError(t, svc.Update(t.Context(), msg)) + msg.AddFinish(FinishReasonEndTurn, "", "") + require.NoError(t, svc.Update(t.Context(), msg)) + return msg + } + + coalesced, sid1 := newTestService(t, WithDebounce(20*time.Millisecond)) + a := build(coalesced, sid1) + require.NoError(t, coalesced.FlushAll(t.Context())) + gotA, err := coalesced.Get(t.Context(), a.ID) + require.NoError(t, err) + + immediate, sid2 := newTestService(t, WithDebounce(0)) + b := build(immediate, sid2) + gotB, err := immediate.Get(t.Context(), b.ID) + require.NoError(t, err) + + require.Equal(t, gotA.Content().Text, gotB.Content().Text) + require.Equal(t, gotA.ReasoningContent().Thinking, gotB.ReasoningContent().Thinking) + require.Equal(t, len(gotA.ToolCalls()), len(gotB.ToolCalls())) + require.Equal(t, gotA.IsFinished(), gotB.IsFinished()) +} + +func TestDelete_DropsPendingState(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendContent("dropped") + require.NoError(t, svc.Update(t.Context(), msg)) + + require.NoError(t, svc.Delete(t.Context(), msg.ID)) + + // FlushAll after Delete must not write to the deleted row. + require.NoError(t, svc.FlushAll(t.Context())) + + _, err = svc.Get(t.Context(), msg.ID) + require.Error(t, err, "deleted message must remain deleted") +} + +func TestBroker_PublishLossyDropCounter(t *testing.T) { + t.Parallel() + + // Tiny channel buffer so we can saturate from a single sender. + b := pubsub.NewBrokerWithOptions[int](1, 1000) + defer b.Shutdown() + + subCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + sub := b.Subscribe(subCtx) + require.NotNil(t, sub) + + // Don't read from sub. Saturate the buffer. + for range 100 { + b.Publish(pubsub.UpdatedEvent, 1) + } + + require.GreaterOrEqual(t, b.DropCount(), uint64(1), + "lossy Publish must increment the drop counter under contention") +} + +func TestBroker_PublishMustDeliverHonorsTimeout(t *testing.T) { + t.Parallel() + + b := pubsub.NewBrokerWithOptions[int](1, 1000) + b.SetMustDeliverTimeout(20 * time.Millisecond) + defer b.Shutdown() + + subCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + sub := b.Subscribe(subCtx) + require.NotNil(t, sub) + + // Saturate: one event sits in the buffer, the second must wait. + b.Publish(pubsub.UpdatedEvent, 1) + + // PublishMustDeliver should block up to 20ms then drop. + start := time.Now() + b.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, 2) + elapsed := time.Since(start) + + require.GreaterOrEqual(t, elapsed, 20*time.Millisecond, + "PublishMustDeliver should block at least the timeout under contention") + require.Less(t, elapsed, 200*time.Millisecond, + "PublishMustDeliver must not block indefinitely") + require.GreaterOrEqual(t, b.MustDeliverDropCount(), uint64(1), + "timeout must increment the must-deliver drop counter") +} + +func TestBroker_PublishMustDeliverWithReader(t *testing.T) { + t.Parallel() + + b := pubsub.NewBrokerWithOptions[int](1, 1000) + b.SetMustDeliverTimeout(50 * time.Millisecond) + defer b.Shutdown() + + subCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + sub := b.Subscribe(subCtx) + + var received atomic.Uint64 + done := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-subCtx.Done(): + return + case _, ok := <-sub: + if !ok { + return + } + received.Add(1) + } + } + }() + + for i := range 10 { + b.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, i) + } + + // All 10 should land within the must-deliver timeout window. + require.Eventually(t, func() bool { return received.Load() == 10 }, + time.Second, 5*time.Millisecond, + "all must-deliver events should reach an active subscriber") + require.Zero(t, b.MustDeliverDropCount(), + "no drops expected when subscriber drains promptly") +} + +func TestUpdate_TerminalEventUsesMustDeliver(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(time.Hour)) + + subCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + sub := svc.Subscribe(subCtx) + + var seenFinish atomic.Bool + done := make(chan struct{}) + go func() { + defer close(done) + for { + select { + case <-subCtx.Done(): + return + case ev, ok := <-sub: + if !ok { + return + } + if ev.Type == pubsub.UpdatedEvent && ev.Payload.IsFinished() { + seenFinish.Store(true) + } + } + } + }() + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendContent("final") + msg.AddFinish(FinishReasonEndTurn, "", "") + require.NoError(t, svc.Update(t.Context(), msg)) + + require.Eventually(t, func() bool { return seenFinish.Load() }, + time.Second, 10*time.Millisecond, + "terminal update must reach subscribers via the must-deliver path") +} + +func TestUpdate_ZeroDebounceFlushesEveryUpdate(t *testing.T) { + t.Parallel() + + svc, sessionID := newTestService(t, WithDebounce(0)) + + msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + msg.AppendContent("x") + require.NoError(t, svc.Update(t.Context(), msg)) + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Len(t, got.Content().Text, i+1, "every update must land synchronously when debounce is 0") + } +} + +// TestFlush_WaitsForInFlightWrite reproduces the failure where Flush +// or FlushAll could return before a concurrent in-flight SQL write +// completed. We block UpdateMessage on a release channel, fire the +// debounce timer, then call Flush and assert it does not return until +// the in-flight write actually lands. +func TestFlush_WaitsForInFlightWrite(t *testing.T) { + t.Parallel() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + sess, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + + slow := &slowUpdateQuerier{ + Querier: q, + release: make(chan struct{}), + started: make(chan struct{}), + } + // Short debounce so the timer fires quickly. + svc := NewService(slow, WithDebounce(10*time.Millisecond)) + + msg, err := svc.Create(t.Context(), sess.ID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendContent("payload") + require.NoError(t, svc.Update(t.Context(), msg)) + + // Wait for the timer-fired flush to enter UpdateMessage. + select { + case <-slow.started: + case <-time.After(time.Second): + t.Fatal("timer-fired flush never reached UpdateMessage") + } + + // At this point the buffer is dirty=false but flushing=true. A + // naive Flush would early-return on !dirty. Spawn Flush in a + // goroutine and assert it has not returned while the write is + // still blocked. + flushDone := make(chan error, 1) + go func() { flushDone <- svc.Flush(t.Context(), msg.ID) }() + + select { + case err := <-flushDone: + t.Fatalf("Flush returned %v while in-flight write was still blocked", err) + case <-time.After(50 * time.Millisecond): + // Expected: Flush is correctly waiting. + } + + // Release the slow write; Flush must now return cleanly. + close(slow.release) + select { + case err := <-flushDone: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("Flush did not return after in-flight write completed") + } + + // The SQL row should now reflect the buffered payload. + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Equal(t, "payload", got.Content().Text) +} + +// TestFlushAll_WaitsForInFlightWrite asserts FlushAll picks up IDs +// whose buffer is currently flushing (dirty=false) so shutdown and +// session-switch callers don't return while an SQL write is mid-flight. +func TestFlushAll_WaitsForInFlightWrite(t *testing.T) { + t.Parallel() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + sess, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + + slow := &slowUpdateQuerier{ + Querier: q, + release: make(chan struct{}), + started: make(chan struct{}), + } + svc := NewService(slow, WithDebounce(10*time.Millisecond)) + + msg, err := svc.Create(t.Context(), sess.ID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + msg.AppendContent("payload") + require.NoError(t, svc.Update(t.Context(), msg)) + + select { + case <-slow.started: + case <-time.After(time.Second): + t.Fatal("timer-fired flush never reached UpdateMessage") + } + + flushDone := make(chan error, 1) + go func() { flushDone <- svc.FlushAll(t.Context()) }() + + select { + case err := <-flushDone: + t.Fatalf("FlushAll returned %v while in-flight write was still blocked", err) + case <-time.After(50 * time.Millisecond): + } + + close(slow.release) + select { + case err := <-flushDone: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("FlushAll did not return after in-flight write completed") + } + + got, err := svc.Get(t.Context(), msg.ID) + require.NoError(t, err) + require.Equal(t, "payload", got.Content().Text) +} + +// TestUpdate_StructuralFlushUsesMustDeliver covers the second review +// finding: structural terminal events (tool-call add, tool-call +// finish, reasoning end) must publish via the must-deliver path even +// when the message itself is not yet IsFinished. +// +// We detect which path was taken by saturating a subscriber's channel +// buffer with no reader. With a short must-deliver timeout, the +// must-deliver path increments [pubsub.Broker.MustDeliverDropCount] +// after the timeout expires; the lossy path increments +// [pubsub.Broker.DropCount] immediately. The two counters are +// disjoint, so they precisely identify which call site published the +// event. +func TestUpdate_StructuralFlushUsesMustDeliver(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + mut func(*Message) + }{ + { + name: "tool call add", + mut: func(m *Message) { + m.AddToolCall(ToolCall{ID: "tc1", Name: "view"}) + }, + }, + { + name: "tool call finish", + mut: func(m *Message) { + m.AddToolCall(ToolCall{ID: "tc1", Name: "view", Input: "{}", Finished: true}) + }, + }, + { + name: "reasoning end", + mut: func(m *Message) { + m.AppendReasoningContent("hmm") + m.FinishThinking() + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { _ = conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + sess, err := sessions.Create(t.Context(), "test") + require.NoError(t, err) + + // Replace the default broker with a tiny buffer + short + // must-deliver timeout so we can fully saturate from a + // single sender and observe drops without long waits. + svc := NewService(q, WithDebounce(time.Hour)) + impl := svc.(*service) + impl.Shutdown() + impl.Broker = pubsub.NewBrokerWithOptions[Message](1, 1000) + impl.SetMustDeliverTimeout(40 * time.Millisecond) + + subCtx, cancel := context.WithCancel(t.Context()) + defer cancel() + sub := svc.Subscribe(subCtx) + + msg, err := svc.Create(t.Context(), sess.ID, CreateMessageParams{Role: Assistant}) + require.NoError(t, err) + + // Saturate the subscriber's buffer (capacity 1). The + // CreatedEvent from Create above already left one event + // in the buffer; we never read sub, so the next publish + // has nowhere to go. + _ = sub // intentionally not drained. + + // Drive the structural change. With debounce=1h, Update + // flushes synchronously and routes through whichever + // publish path the service chose for structural events. + tc.mut(&msg) + require.NoError(t, svc.Update(t.Context(), msg)) + + // Must-deliver timeout (40ms) should have expired with + // no drain. If structural events are routed through + // must-deliver: MustDeliverDropCount > 0, DropCount + // unchanged. If routed through lossy Publish: + // DropCount > 0, MustDeliverDropCount == 0. + require.Eventually(t, func() bool { + return impl.MustDeliverDropCount() >= 1 + }, time.Second, 5*time.Millisecond, + "structural terminal event should publish via must-deliver, not lossy Publish") + require.Zero(t, impl.DropCount(), + "structural terminal event must not be silently dropped via lossy Publish") + }) + } +} diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 52e827c2d6306fa2365372b9867810d9b99d0227..03af2e676f67bee68a757eb85a600d7b09838c4a 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -1,19 +1,54 @@ +// Package pubsub provides a lightweight in-process broker for fan-out +// event delivery between services and the UI. +// +// Delivery semantics: +// +// - [Broker.Publish] is best-effort and lossy under contention. If a +// subscriber's channel is full, the event is dropped for that +// subscriber and a counter is incremented. This is the right choice +// for high-frequency intermediate updates (e.g. streaming token +// deltas) where only the latest state matters. +// +// - [Broker.PublishMustDeliver] is bounded-blocking. For each +// subscriber it first tries a non-blocking send, then falls back to +// a per-subscriber blocking send with a hard timeout. On timeout the +// event is dropped for that subscriber, an error is logged, and the +// must-deliver drop counter is incremented. The publisher never +// blocks indefinitely. This is the right choice for terminal events +// (finish, tool result, error, cancel) that must not be silently +// coalesced away. +// +// Drop counters ([Broker.DropCount], [Broker.MustDeliverDropCount]) are +// exposed so callers can surface saturation in telemetry. package pubsub import ( "context" + "log/slog" "sync" + "sync/atomic" + "time" ) -const bufferSize = 64 +const ( + bufferSize = 64 + + // defaultMustDeliverTimeout is the per-subscriber upper bound on how + // long [Broker.PublishMustDeliver] will block waiting for buffer + // space before giving up on that subscriber. + defaultMustDeliverTimeout = 50 * time.Millisecond +) type Broker[T any] struct { - subs map[chan Event[T]]struct{} - mu sync.RWMutex - done chan struct{} - subCount int - maxEvents int - channelBufferSize int + subs map[chan Event[T]]struct{} + mu sync.RWMutex + done chan struct{} + subCount int + maxEvents int + channelBufferSize int + mustDeliverTimeout time.Duration + dropCount atomic.Uint64 + mustDeliverDropCount atomic.Uint64 } func NewBroker[T any]() *Broker[T] { @@ -22,13 +57,27 @@ func NewBroker[T any]() *Broker[T] { func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] { return &Broker[T]{ - subs: make(map[chan Event[T]]struct{}), - done: make(chan struct{}), - maxEvents: maxEvents, - channelBufferSize: channelBufferSize, + subs: make(map[chan Event[T]]struct{}), + done: make(chan struct{}), + maxEvents: maxEvents, + channelBufferSize: channelBufferSize, + mustDeliverTimeout: defaultMustDeliverTimeout, } } +// SetMustDeliverTimeout overrides the per-subscriber timeout used by +// [Broker.PublishMustDeliver]. A zero or negative value resets to the +// default. Intended primarily for tests. +func (b *Broker[T]) SetMustDeliverTimeout(d time.Duration) { + b.mu.Lock() + defer b.mu.Unlock() + if d <= 0 { + b.mustDeliverTimeout = defaultMustDeliverTimeout + return + } + b.mustDeliverTimeout = d +} + func (b *Broker[T]) Shutdown() { select { case <-b.done: // Already closed @@ -90,6 +139,25 @@ func (b *Broker[T]) GetSubscriberCount() int { return b.subCount } +// DropCount returns the cumulative number of events dropped by +// [Broker.Publish] because a subscriber's channel was full. +func (b *Broker[T]) DropCount() uint64 { + return b.dropCount.Load() +} + +// MustDeliverDropCount returns the cumulative number of events dropped +// by [Broker.PublishMustDeliver] after the per-subscriber timeout +// expired. +func (b *Broker[T]) MustDeliverDropCount() uint64 { + return b.mustDeliverDropCount.Load() +} + +// Publish delivers an event to every active subscriber. +// +// Delivery is non-blocking and lossy: if a subscriber's channel is full +// the event is dropped for that subscriber and [Broker.DropCount] is +// incremented. Use [Broker.PublishMustDeliver] for events that must not +// be silently dropped. func (b *Broker[T]) Publish(t EventType, payload T) { b.mu.RLock() defer b.mu.RUnlock() @@ -106,8 +174,57 @@ func (b *Broker[T]) Publish(t EventType, payload T) { select { case sub <- event: default: - // Channel is full, subscriber is slow - skip this event - // This prevents blocking the publisher + // Channel is full, subscriber is slow - skip this event. + // Lossy by design; counted so saturation is observable. + b.dropCount.Add(1) + } + } +} + +// PublishMustDeliver delivers an event with bounded-blocking semantics. +// For each subscriber it first attempts a non-blocking send, then falls +// back to a blocking send bounded by a per-subscriber timeout (default +// [defaultMustDeliverTimeout]). On timeout the event is dropped for +// that subscriber, [Broker.MustDeliverDropCount] is incremented, and an +// error is logged. The publisher never blocks indefinitely. +// +// Use this for terminal events that must reach subscribers (finish, +// tool result, error, cancel). Callers must still tolerate rare drops +// after timeout — recovery is the subscriber's responsibility (e.g. a +// re-fetch on the next session-visible event). +func (b *Broker[T]) PublishMustDeliver(ctx context.Context, t EventType, payload T) { + b.mu.RLock() + defer b.mu.RUnlock() + + select { + case <-b.done: + return + default: + } + + event := Event[T]{Type: t, Payload: payload} + timeout := b.mustDeliverTimeout + + for sub := range b.subs { + // Fast path: non-blocking send. + select { + case sub <- event: + continue + default: + } + + // Slow path: bounded blocking send. + timer := time.NewTimer(timeout) + select { + case sub <- event: + timer.Stop() + case <-timer.C: + b.mustDeliverDropCount.Add(1) + slog.Error("PublishMustDeliver timed out delivering event", + "type", t, "timeout", timeout) + case <-ctx.Done(): + timer.Stop() + return } } } diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go index 57b1228e7eacb28a16141283ee2703a33511bd18..d4e5ed790e3a1cf7a4bcc299e4e4e63bedfbacd5 100644 --- a/internal/workspace/app_workspace.go +++ b/internal/workspace/app_workspace.go @@ -70,6 +70,12 @@ func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string // -- Messages -- func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { + // Drain any debounced updates so the caller observes the latest + // in-memory state. message.Service buffers streaming deltas and a + // cold List would otherwise miss them at session-switch time. + if err := w.app.Messages.FlushAll(ctx); err != nil { + return nil, err + } return w.app.Messages.List(ctx, sessionID) }