diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 66062270c669ffaba298980ff997c6e2c2e04c2e..ba1215e9a22ca425f745e6e00c85908737d50a0e 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -245,6 +245,15 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) + // Drain any debounced message updates before returning. message.Service + // already flushes synchronously on terminal updates, but a defer here + // guarantees the contract at every Run exit (success, error, panic + // recovery upstream) without callers needing to know. + defer func() { + if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + slog.Error("Failed to flush pending message updates after run", "error", flushErr) + } + }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -653,6 +662,11 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan a.activeRequests.Set(sessionID, cancel) defer a.activeRequests.Del(sessionID) defer cancel() + defer func() { + if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + slog.Error("Failed to flush pending message updates after summarize", "error", flushErr) + } + }() agent := fantasy.NewAgent(largeModel.Model, fantasy.WithSystemPrompt(string(summaryPrompt)), diff --git a/internal/app/app.go b/internal/app/app.go index a167ca8638c8497a6d6f4260782ba334c6dbe0c3..a5414246b7819cbafccc452c3b79580564c9b68c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -582,13 +582,23 @@ func (app *App) Shutdown() { app.AgentCoordinator.CancelAll() } - // Now run remaining cleanup tasks in parallel. - var wg sync.WaitGroup - // Shared shutdown context for all timeout-bounded cleanup. 
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + // Drain any debounced message updates before the DB-close cleanup + // runs in the parallel block below. message.Service buffers + // streaming deltas (see internal/message/message.go) and we must + // land them while the connection is still open. + if app.Messages != nil { + if err := app.Messages.FlushAll(shutdownCtx); err != nil { + slog.Error("Failed to flush pending message updates on shutdown", "error", err) + } + } + + // Now run remaining cleanup tasks in parallel. + var wg sync.WaitGroup + // Send exit event wg.Go(func() { event.AppExited() diff --git a/internal/backend/session.go b/internal/backend/session.go index 10e21ed8932ccbc990a525785166517cd231595c..9542282319f766e384d9ffe7942e9c235ae63e29 100644 --- a/internal/backend/session.go +++ b/internal/backend/session.go @@ -72,6 +72,12 @@ func (b *Backend) ListSessionMessages(ctx context.Context, workspaceID, sessionI return nil, err } + // Drain debounced updates so HTTP clients (and the TUI on session + // switch) observe the latest in-memory state rather than racing the + // debounce timer in message.Service. + if err := ws.Messages.FlushAll(ctx); err != nil { + return nil, err + } return ws.Messages.List(ctx, sessionID) } diff --git a/internal/message/message.go b/internal/message/message.go index 6da8827b72227602dc36c39b6a2254aba18d2b0d..9e8d258c543d9f9aed31518b848439456b5e4336 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sync" "time" "github.com/charmbracelet/crush/internal/db" @@ -12,6 +13,13 @@ import ( "github.com/google/uuid" ) +// defaultUpdateDebounce is the default debounce window for [Service.Update]. +// Streaming deltas that arrive within the window are coalesced into a +// single SQL write and a single pubsub event. 
Terminal updates
+// (message finished, tool call added or finished, reasoning ended) bypass
+// the debounce and flush synchronously.
+const defaultUpdateDebounce = 33 * time.Millisecond
+
 type CreateMessageParams struct {
 	Role             MessageRole
 	Parts            []ContentPart
@@ -20,6 +28,21 @@ type CreateMessageParams struct {
 	IsSummaryMessage bool
 }
 
+// Service is the public interface to the message store.
+//
+// [Service.Update] is eventually consistent: it accepts new state into
+// an in-memory buffer and writes it to SQLite plus publishes a
+// [pubsub.UpdatedEvent] on the next debounce tick (default
+// [defaultUpdateDebounce]) or on the next terminal-state update,
+// whichever comes first. Terminal-state updates — those that finish
+// the message, add or finish a tool call, or end a reasoning section —
+// flush synchronously before [Service.Update] returns.
+//
+// Callers that need stronger ordering (e.g. tests, shutdown,
+// session-switch reads) must use [Service.Flush] or [Service.FlushAll]
+// before reading via [Service.Get] / [Service.List]. Without an
+// explicit flush, a read can race the debounce timer and miss the
+// most recent in-memory state.
 type Service interface {
 	pubsub.Subscriber[Message]
 	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
@@ -30,20 +53,87 @@ type Service interface {
 	ListAllUserMessages(ctx context.Context) ([]Message, error)
 	Delete(ctx context.Context, id string) error
 	DeleteSessionMessages(ctx context.Context, sessionID string) error
+
+	// Flush synchronously drains any pending debounced state for the
+	// given message ID, performs the SQL write, and publishes the
+	// resulting [pubsub.UpdatedEvent]. It also waits for a write that
+	// is already in flight for that ID. A no-op when the ID has no
+	// buffered state. Use this before any read that must observe the
+	// latest [Service.Update].
+	Flush(ctx context.Context, id string) error
+
+	// FlushAll synchronously drains every message ID that has buffered
+	// or in-flight state. It attempts every such ID and returns the
+	// first error encountered. Intended for shutdown and
+	// session-switch paths.
+	FlushAll(ctx context.Context) error
+}
+
+// pendingState holds the in-memory coalescing buffer for a single
+// message ID. All fields are guarded by service.mu. The flushing flag
+// serializes concurrent flushers for the same ID so SQL writes never
+// reorder.
+type pendingState struct {
+	// latest is the most recent [Message] passed to [Service.Update]
+	// that has not yet been flushed.
+	latest Message
+
+	// dirty is true when latest contains state that has not been
+	// written to SQL since the last successful flush.
+	dirty bool
+
+	// flushing is true while a goroutine is performing the SQL write
+	// for this ID. New updates are still accepted (and re-mark dirty)
+	// but other flushers must back off.
+	flushing bool
+
+	// timer is the active debounce timer, or nil if no flush is
+	// scheduled. It is stopped and cleared (set back to nil) when a
+	// flush starts or a terminal update preempts the debounce window.
+	timer *time.Timer
+
+	// lastFlushed is the snapshot most recently written to SQL. Used
+	// as the baseline for terminal-state detection.
+	lastFlushed Message
+
+	// hasFlushed is false until the first successful write for this
+	// ID; until then lastFlushed is the zero value and must not be
+	// treated as a real prior state.
+	hasFlushed bool
 }
 
 type service struct {
 	*pubsub.Broker[Message]
-	q db.Querier
+	q        db.Querier
+	// debounce is the coalescing window; <= 0 means write-through.
+	debounce time.Duration
+
+	// mu guards pending and every pendingState stored in it.
+	mu      sync.Mutex
+	// pending maps message ID to its coalescing buffer.
+	pending map[string]*pendingState
 }
 
-func NewService(q db.Querier) Service {
-	return &service{
-		Broker: pubsub.NewBroker[Message](),
-		q:      q,
+// ServiceOption configures a [Service] at construction.
+type ServiceOption func(*service)
+
+// WithDebounce overrides the debounce window for [Service.Update]. A
+// zero or negative value disables debouncing entirely (every update
+// flushes synchronously). Intended primarily for tests.
+func WithDebounce(d time.Duration) ServiceOption {
+	return func(s *service) {
+		s.debounce = d
 	}
 }
 
+// NewService builds a message [Service] on top of q. The debounce
+// window starts at [defaultUpdateDebounce]; opts are applied in order
+// after the defaults are in place.
+func NewService(q db.Querier, opts ...ServiceOption) Service {
+	svc := &service{
+		Broker:   pubsub.NewBroker[Message](),
+		q:        q,
+		debounce: defaultUpdateDebounce,
+		pending:  make(map[string]*pendingState),
+	}
+	for _, apply := range opts {
+		apply(svc)
+	}
+	return svc
+}
+
 func (s *service) Delete(ctx context.Context, id string) error {
 	message, err := s.Get(ctx, id)
 	if err != nil {
@@ -53,6 +143,16 @@ func (s *service) Delete(ctx context.Context, id string) error {
 	if err != nil {
 		return err
 	}
+	// Drop any pending coalesced state for this ID. We never want to
+	// flush back over a deleted row.
+	s.mu.Lock()
+	if p, ok := s.pending[id]; ok {
+		if p.timer != nil {
+			p.timer.Stop()
+		}
+		delete(s.pending, id)
+	}
+	s.mu.Unlock()
 	// Clone the message before publishing to avoid race conditions with
 	// concurrent modifications to the Parts slice.
 	s.Publish(pubsub.DeletedEvent, message.Clone())
@@ -111,31 +211,240 @@ func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) e
 	return nil
 }
 
-func (s *service) Update(ctx context.Context, message Message) error {
-	parts, err := marshalParts(message.Parts)
+// Update records msg as the newest state for its ID. The write happens
+// before Update returns when debouncing is disabled (debounce <= 0) or
+// when shouldFlushNow reports a terminal change; otherwise it is left
+// to a debounce timer. See [Service] for the contract.
+func (s *service) Update(ctx context.Context, msg Message) error {
+	id := msg.ID
+	snapshot := msg.Clone()
+
+	s.mu.Lock()
+	state, found := s.pending[id]
+	if !found {
+		state = &pendingState{}
+		s.pending[id] = state
+	}
+	state.latest = snapshot
+	state.dirty = true
+
+	// With debouncing switched off via [WithDebounce] every update is
+	// written straight through, matching the pre-coalescing behaviour.
+	if s.debounce <= 0 {
+		s.mu.Unlock()
+		return s.flushOne(ctx, id, true)
+	}
+
+	// Compare against the last snapshot that reached SQL, if any.
+	var baseline *Message
+	if state.hasFlushed {
+		baseline = &state.lastFlushed
+	}
+	if shouldFlushNow(baseline, &snapshot) {
+		// Terminal change: cancel the pending tick and write inline.
+		if state.timer != nil {
+			state.timer.Stop()
+			state.timer = nil
+		}
+		s.mu.Unlock()
+		return s.flushOne(ctx, id, true)
+	}
+
+	// Ordinary delta: arm at most one timer per ID. While a flush is
+	// in flight no timer is armed here; the in-flight flusher, a later
+	// Update, or an explicit Flush deals with the trailing dirty state.
+	if state.timer == nil && !state.flushing {
+		state.timer = time.AfterFunc(s.debounce, func() {
+			// Background context: cancelling the caller's stream must
+			// not strand the buffered write. The error is dropped on
+			// purpose; flushOne re-marks the buffer dirty on failure.
+			_ = s.flushOne(context.Background(), id, false)
+		})
+	}
+	s.mu.Unlock()
+	return nil
+}
+
+// Flush implements [Service.Flush].
+func (s *service) Flush(ctx context.Context, id string) error {
+	return s.flushOne(ctx, id, true)
+}
+
+// FlushAll implements [Service.FlushAll]. It snapshots every ID with
+// outstanding work — either dirty buffered state or a flush already in
+// flight — then drains each one. Picking up in-flight IDs ensures
+// FlushAll cannot return while a timer-fired write is still mid-SQL,
+// which is what shutdown and session-switch callers rely on.
+func (s *service) FlushAll(ctx context.Context) error {
+	s.mu.Lock()
+	ids := make([]string, 0, len(s.pending))
+	for id, p := range s.pending {
+		if p.dirty || p.flushing {
+			ids = append(ids, id)
+		}
+	}
+	s.mu.Unlock()
+	var firstErr error
+	for _, id := range ids {
+		if err := s.flushOne(ctx, id, true); err != nil && firstErr == nil {
+			firstErr = err
+		}
+	}
+	return firstErr
+}
+
+// flushOne drains a single message ID. When syncCaller is true the
+// caller is willing to wait through a concurrent in-flight flush so
+// that, on return, lastFlushed equals latest at the moment of return.
+// When false (timer-fired path) we bail if another flusher is already
+// running; that flusher will pick up the trailing dirty bit.
+//
+// Order matters: a sync caller must wait for any in-flight flush to
+// drain even when the buffer is currently clean — that in-flight
+// write has not yet updated the SQL row, so returning early would
+// violate the contract that on success lastFlushed reflects the most
+// recent state.
+//
+// The wait honours ctx: a sync caller returns ctx.Err() once ctx is
+// done instead of spinning behind a write it cannot cancel (timer-fired
+// writes run on context.Background()). An SQL failure is returned
+// after re-marking the buffer dirty so a later flush retries.
+func (s *service) flushOne(ctx context.Context, id string, syncCaller bool) error {
+	for {
+		s.mu.Lock()
+		p, ok := s.pending[id]
+		if !ok {
+			s.mu.Unlock()
+			return nil
+		}
+		if p.flushing {
+			s.mu.Unlock()
+			if !syncCaller {
+				return nil
+			}
+			// Brief yield; in-flight write should land in <1ms typical.
+			// Bounded by ctx so shutdown deadlines are respected.
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-time.After(time.Millisecond):
+			}
+			continue
+		}
+		if !p.dirty {
+			s.mu.Unlock()
+			return nil
+		}
+
+		if p.timer != nil {
+			p.timer.Stop()
+			p.timer = nil
+		}
+		snap := p.latest
+		// Decide whether this snapshot represents a terminal event
+		// against the prior baseline. We must do this before resetting
+		// dirty/flushing because shouldFlushNow looks at p.lastFlushed
+		// (which is what was on disk before this write).
+		var prev *Message
+		if p.hasFlushed {
+			prev = &p.lastFlushed
+		}
+		isTerminal := shouldFlushNow(prev, &snap)
+		p.flushing = true
+		p.dirty = false
+		s.mu.Unlock()
+
+		err := s.write(ctx, snap)
+
+		s.mu.Lock()
+		p.flushing = false
+		// trailing is true when an Update landed while the SQL write
+		// was in flight; read it before the error path re-marks dirty.
+		trailing := p.dirty
+		if err != nil {
+			// Restore dirty so the next caller retries.
+			p.dirty = true
+			s.mu.Unlock()
+			return err
+		}
+		// Stamp the snapshot the way the pre-coalescing Update did, so
+		// subscribers keep seeing a fresh UpdatedAt on every event.
+		snap.UpdatedAt = time.Now().Unix()
+		p.lastFlushed = snap
+		p.hasFlushed = true
+		// Timer-fired path: Update does not arm a timer while a flush
+		// is running, so re-arm here or the trailing delta would sit in
+		// the buffer until some later Update/Flush. Skip when the entry
+		// was dropped from the map (Delete) during the write.
+		if trailing && !syncCaller && p.timer == nil && s.pending[id] == p {
+			p.timer = time.AfterFunc(s.debounce, func() {
+				// Best effort; a failure leaves the buffer dirty.
+				_ = s.flushOne(context.Background(), id, false)
+			})
+		}
+		s.mu.Unlock()
+
+		// Terminal events — message finished, tool call added or
+		// finished, reasoning ended — use the bounded must-deliver
+		// path so they never get dropped under channel contention.
+		// Publish a clone so subscribers never share the Parts backing
+		// store with lastFlushed.
+		if isTerminal {
+			s.PublishMustDeliver(ctx, pubsub.UpdatedEvent, snap.Clone())
+		} else {
+			s.Publish(pubsub.UpdatedEvent, snap.Clone())
+		}
+
+		// If a delta arrived during the SQL write and we are a sync
+		// caller, the user expects that delta to land too.
+		if trailing && syncCaller {
+			continue
+		}
+		return nil
+	}
+}
+
+// write performs the unguarded SQL write for msg. It neither stamps
+// UpdatedAt nor publishes; the caller (flushOne) owns both.
+func (s *service) write(ctx context.Context, msg Message) error {
+	parts, err := marshalParts(msg.Parts)
 	if err != nil {
 		return err
 	}
 	finishedAt := sql.NullInt64{}
-	if f := message.FinishPart(); f != nil {
+	if f := msg.FinishPart(); f != nil {
 		finishedAt.Int64 = f.Time
 		finishedAt.Valid = true
 	}
-	err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
-		ID:         message.ID,
+	if err := s.q.UpdateMessage(ctx, db.UpdateMessageParams{
+		ID:         msg.ID,
 		Parts:      string(parts),
 		FinishedAt: finishedAt,
-	})
-	if err != nil {
+	}); err != nil {
 		return err
 	}
-	message.UpdatedAt = time.Now().Unix()
-	// Clone the message before publishing to avoid race conditions with
-	// concurrent modifications to the Parts slice.
-	s.Publish(pubsub.UpdatedEvent, message.Clone())
 	return nil
 }
 
+// shouldFlushNow reports whether next must be written immediately
+// instead of being coalesced. That is the case when next is finished,
+// when its tool-call count differs from prev, when any tool call's
+// Finished flag differs from the call at the same index in prev, or
+// when reasoning has a finish time in next but had none in prev. A nil
+// prev means nothing has been written yet and is treated as having no
+// tool calls and unfinished reasoning.
+func shouldFlushNow(prev, next *Message) bool {
+	if next.IsFinished() {
+		return true
+	}
+
+	var (
+		before             []ToolCall
+		reasoningWasClosed bool
+	)
+	if prev != nil {
+		before = prev.ToolCalls()
+		reasoningWasClosed = prev.ReasoningContent().FinishedAt != 0
+	}
+
+	after := next.ToolCalls()
+	if len(after) != len(before) {
+		return true
+	}
+	// Lengths match, so indexing before[i] is in range. Only the
+	// Finished flag is compared; Input deltas on an unfinished call
+	// are debounced with the rest of the streaming state.
+	for i, call := range after {
+		if call.Finished != before[i].Finished {
+			return true
+		}
+	}
+
+	return next.ReasoningContent().FinishedAt > 0 && !reasoningWasClosed
+}
+
 func (s *service) Get(ctx context.Context, id string) (Message, error) {
 	dbMessage, err := s.q.GetMessage(ctx, id)
 	if err != nil {
diff --git a/internal/message/message_test.go b/internal/message/message_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..0a046337542606ae10aad16ef8cb36d1eb879c03
--- /dev/null
+++ b/internal/message/message_test.go
@@ -0,0 +1,695 @@
+package message
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/db"
+	"github.com/charmbracelet/crush/internal/pubsub"
+	"github.com/charmbracelet/crush/internal/session"
+	"github.com/stretchr/testify/require"
+)
+
+// slowUpdateQuerier wraps a [db.Querier] and forces UpdateMessage to
+// hang on a release channel. Used to simulate an in-flight SQL write.
+type slowUpdateQuerier struct {
+	db.Querier
+	release   chan struct{}
+	started   chan struct{}
+	startOnce sync.Once
+}
+
+// UpdateMessage closes started on its first call, then blocks until
+// release is closed or ctx is done. On release it delegates to the
+// wrapped Querier; on ctx cancellation it returns ctx.Err().
+func (s *slowUpdateQuerier) UpdateMessage(ctx context.Context, arg db.UpdateMessageParams) error {
+	s.startOnce.Do(func() { close(s.started) })
+	select {
+	case <-s.release:
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+	return s.Querier.UpdateMessage(ctx, arg)
+}
+
+// newTestService spins up a fresh message.Service backed by a database
+// opened under t.TempDir(). Returns the service plus a session ID to
+// attach messages to. The connection is closed via t.Cleanup.
+func newTestService(t *testing.T, opts ...ServiceOption) (Service, string) {
+	t.Helper()
+	conn, err := db.Connect(t.Context(), t.TempDir())
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = conn.Close() })
+
+	q := db.New(conn)
+	sessions := session.NewService(q, conn)
+	sess, err := sessions.Create(t.Context(), "test")
+	require.NoError(t, err)
+
+	svc := NewService(q, opts...)
+	return svc, sess.ID
+}
+
+// eventCollector consumes broker events into a slice in a goroutine
+// and exposes mutex-guarded snapshot / reset helpers for assertions.
+type eventCollector struct {
+	mu     sync.Mutex
+	events []pubsub.Event[Message]
+}
+
+// collect starts a goroutine that appends every event received on sub
+// to the returned collector until ctx is done or sub is closed.
+func collect(ctx context.Context, sub <-chan pubsub.Event[Message]) *eventCollector {
+	c := &eventCollector{}
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case ev, ok := <-sub:
+				if !ok {
+					return
+				}
+				c.mu.Lock()
+				c.events = append(c.events, ev)
+				c.mu.Unlock()
+			}
+		}
+	}()
+	return c
+}
+
+// snapshot returns a copy of the events collected so far.
+func (c *eventCollector) snapshot() []pubsub.Event[Message] {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	out := make([]pubsub.Event[Message], len(c.events))
+	copy(out, c.events)
+	return out
+}
+
+// reset discards every event collected so far.
+func (c *eventCollector) reset() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.events = nil
+}
+
+func TestUpdate_DebouncesTextDeltas(t *testing.T) {
+	t.Parallel()
+
+	// Long-enough debounce that we can verify nothing flushes prematurely.
+	svc, sessionID := newTestService(t, WithDebounce(50*time.Millisecond))
+
+	subCtx, cancelSub := context.WithCancel(t.Context())
+	defer cancelSub()
+	sub := svc.Subscribe(subCtx)
+	collector := collect(subCtx, sub)
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{
+		Role: Assistant,
+	})
+	require.NoError(t, err)
+	// Drop the CreatedEvent emitted by Create.
+	time.Sleep(5 * time.Millisecond)
+	collector.reset()
+
+	// Push 5 deltas inside a single debounce window.
+	for i := 0; i < 5; i++ {
+		msg.AppendContent("a")
+		require.NoError(t, svc.Update(t.Context(), msg))
+	}
+
+	// Before the debounce expires no UpdatedEvent should have landed.
+	time.Sleep(10 * time.Millisecond)
+	require.Empty(t, collector.snapshot(), "no events should land before debounce window expires")
+
+	// Wait for the debounce timer to fire.
+	require.Eventually(t, func() bool {
+		return len(collector.snapshot()) >= 1
+	}, time.Second, 5*time.Millisecond)
+	events := collector.snapshot()
+	require.Len(t, events, 1, "5 deltas should coalesce into 1 UpdatedEvent")
+	require.Equal(t, pubsub.UpdatedEvent, events[0].Type)
+	require.Equal(t, "aaaaa", events[0].Payload.Content().Text)
+
+	// Final state must be persisted.
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Equal(t, "aaaaa", got.Content().Text)
+}
+
+func TestUpdate_TerminalUpdatesFlushSynchronously(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	subCtx, cancelSub := context.WithCancel(t.Context())
+	defer cancelSub()
+	sub := svc.Subscribe(subCtx)
+	collector := collect(subCtx, sub)
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	time.Sleep(5 * time.Millisecond)
+	collector.reset()
+
+	// AddFinish makes the message terminal; Update must flush
+	// synchronously even with a 1-hour debounce.
+	msg.AppendContent("done")
+	msg.AddFinish(FinishReasonEndTurn, "", "")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	require.Eventually(t, func() bool {
+		return len(collector.snapshot()) >= 1
+	}, time.Second, 5*time.Millisecond,
+		"terminal update must publish without waiting for debounce")
+	events := collector.snapshot()
+	require.Len(t, events, 1)
+	require.True(t, events[0].Payload.IsFinished())
+
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.True(t, got.IsFinished())
+}
+
+func TestUpdate_ToolCallStructuralChangeFlushes(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+
+	// Adding a new tool call is a structural change → sync flush.
+	msg.AddToolCall(ToolCall{ID: "tc1", Name: "view", Finished: false})
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Len(t, got.ToolCalls(), 1)
+	require.Equal(t, "tc1", got.ToolCalls()[0].ID)
+
+	// Marking the tool call finished is also a structural change.
+	msg.AddToolCall(ToolCall{ID: "tc1", Name: "view", Input: "{}", Finished: true})
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	got, err = svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.True(t, got.ToolCalls()[0].Finished)
+}
+
+func TestUpdate_ReasoningEndFlushes(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+
+	// Reasoning deltas alone debounce.
+	msg.AppendReasoningContent("hmm")
+	require.NoError(t, svc.Update(t.Context(), msg))
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Empty(t, got.ReasoningContent().Thinking, "reasoning delta should still be in the debounce buffer")
+
+	// FinishThinking sets FinishedAt → terminal flush.
+	msg.FinishThinking()
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	got, err = svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Equal(t, "hmm", got.ReasoningContent().Thinking)
+	require.NotZero(t, got.ReasoningContent().FinishedAt)
+}
+
+func TestFlush_DrainsPendingDebouncedUpdates(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	msg.AppendContent("buffered")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	// Without a flush the SQL row is unchanged from Create.
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Empty(t, got.Content().Text)
+
+	require.NoError(t, svc.Flush(t.Context(), msg.ID))
+
+	got, err = svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Equal(t, "buffered", got.Content().Text)
+
+	// Subsequent Flush is a no-op.
+	require.NoError(t, svc.Flush(t.Context(), msg.ID))
+}
+
+func TestFlushAll_DrainsAllPending(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	const n = 5
+	msgs := make([]Message, n)
+	for i := range msgs {
+		m, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+		require.NoError(t, err)
+		m.AppendContent("hi")
+		require.NoError(t, svc.Update(t.Context(), m))
+		msgs[i] = m
+	}
+
+	require.NoError(t, svc.FlushAll(t.Context()))
+
+	for _, m := range msgs {
+		got, err := svc.Get(t.Context(), m.ID)
+		require.NoError(t, err)
+		require.Equal(t, "hi", got.Content().Text, "FlushAll should drain every pending message")
+	}
+}
+
+func TestUpdate_OrderingMatchesNonCoalesced(t *testing.T) {
+	t.Parallel()
+
+	// Compare the final state after coalesced vs zero-debounce updates.
+	// A sequence of interleaved text/reasoning/tool-call updates must
+	// converge to the same final DB row either way.
+	build := func(svc Service, sessionID string) Message {
+		msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+		require.NoError(t, err)
+		msg.AppendReasoningContent("thinking 1 ")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AppendReasoningContent("thinking 2")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.FinishThinking()
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AppendContent("hello ")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AppendContent("world")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AddToolCall(ToolCall{ID: "tc", Name: "x", Finished: false})
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AddToolCall(ToolCall{ID: "tc", Name: "x", Input: "{}", Finished: true})
+		require.NoError(t, svc.Update(t.Context(), msg))
+		msg.AddFinish(FinishReasonEndTurn, "", "")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		return msg
+	}
+
+	coalesced, sid1 := newTestService(t, WithDebounce(20*time.Millisecond))
+	a := build(coalesced, sid1)
+	require.NoError(t, coalesced.FlushAll(t.Context()))
+	gotA, err := coalesced.Get(t.Context(), a.ID)
+	require.NoError(t, err)
+
+	immediate, sid2 := newTestService(t, WithDebounce(0))
+	b := build(immediate, sid2)
+	gotB, err := immediate.Get(t.Context(), b.ID)
+	require.NoError(t, err)
+
+	require.Equal(t, gotA.Content().Text, gotB.Content().Text)
+	require.Equal(t, gotA.ReasoningContent().Thinking, gotB.ReasoningContent().Thinking)
+	require.Equal(t, len(gotA.ToolCalls()), len(gotB.ToolCalls()))
+	require.Equal(t, gotA.IsFinished(), gotB.IsFinished())
+}
+
+func TestDelete_DropsPendingState(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	msg.AppendContent("dropped")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	require.NoError(t, svc.Delete(t.Context(), msg.ID))
+
+	// FlushAll after Delete must not write to the deleted row.
+	require.NoError(t, svc.FlushAll(t.Context()))
+
+	_, err = svc.Get(t.Context(), msg.ID)
+	require.Error(t, err, "deleted message must remain deleted")
+}
+
+func TestBroker_PublishLossyDropCounter(t *testing.T) {
+	t.Parallel()
+
+	// Tiny channel buffer so we can saturate from a single sender.
+	b := pubsub.NewBrokerWithOptions[int](1, 1000)
+	defer b.Shutdown()
+
+	subCtx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+	sub := b.Subscribe(subCtx)
+	require.NotNil(t, sub)
+
+	// Don't read from sub. Saturate the buffer.
+	for range 100 {
+		b.Publish(pubsub.UpdatedEvent, 1)
+	}
+
+	require.GreaterOrEqual(t, b.DropCount(), uint64(1),
+		"lossy Publish must increment the drop counter under contention")
+}
+
+func TestBroker_PublishMustDeliverHonorsTimeout(t *testing.T) {
+	t.Parallel()
+
+	b := pubsub.NewBrokerWithOptions[int](1, 1000)
+	b.SetMustDeliverTimeout(20 * time.Millisecond)
+	defer b.Shutdown()
+
+	subCtx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+	sub := b.Subscribe(subCtx)
+	require.NotNil(t, sub)
+
+	// Saturate: one event sits in the buffer, the second must wait.
+	b.Publish(pubsub.UpdatedEvent, 1)
+
+	// PublishMustDeliver should block up to 20ms then drop.
+	start := time.Now()
+	b.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, 2)
+	elapsed := time.Since(start)
+
+	require.GreaterOrEqual(t, elapsed, 20*time.Millisecond,
+		"PublishMustDeliver should block at least the timeout under contention")
+	require.Less(t, elapsed, 200*time.Millisecond,
+		"PublishMustDeliver must not block indefinitely")
+	require.GreaterOrEqual(t, b.MustDeliverDropCount(), uint64(1),
+		"timeout must increment the must-deliver drop counter")
+}
+
+func TestBroker_PublishMustDeliverWithReader(t *testing.T) {
+	t.Parallel()
+
+	b := pubsub.NewBrokerWithOptions[int](1, 1000)
+	b.SetMustDeliverTimeout(50 * time.Millisecond)
+	defer b.Shutdown()
+
+	subCtx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+	sub := b.Subscribe(subCtx)
+
+	var received atomic.Uint64
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		for {
+			select {
+			case <-subCtx.Done():
+				return
+			case _, ok := <-sub:
+				if !ok {
+					return
+				}
+				received.Add(1)
+			}
+		}
+	}()
+
+	for i := range 10 {
+		b.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, i)
+	}
+
+	// All 10 should land within the must-deliver timeout window.
+	require.Eventually(t, func() bool { return received.Load() == 10 },
+		time.Second, 5*time.Millisecond,
+		"all must-deliver events should reach an active subscriber")
+	require.Zero(t, b.MustDeliverDropCount(),
+		"no drops expected when subscriber drains promptly")
+}
+
+func TestUpdate_TerminalEventUsesMustDeliver(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(time.Hour))
+
+	subCtx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+	sub := svc.Subscribe(subCtx)
+
+	var seenFinish atomic.Bool
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		for {
+			select {
+			case <-subCtx.Done():
+				return
+			case ev, ok := <-sub:
+				if !ok {
+					return
+				}
+				if ev.Type == pubsub.UpdatedEvent && ev.Payload.IsFinished() {
+					seenFinish.Store(true)
+				}
+			}
+		}
+	}()
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	msg.AppendContent("final")
+	msg.AddFinish(FinishReasonEndTurn, "", "")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	require.Eventually(t, func() bool { return seenFinish.Load() },
+		time.Second, 10*time.Millisecond,
+		"terminal update must reach subscribers via the must-deliver path")
+}
+
+func TestUpdate_ZeroDebounceFlushesEveryUpdate(t *testing.T) {
+	t.Parallel()
+
+	svc, sessionID := newTestService(t, WithDebounce(0))
+
+	msg, err := svc.Create(t.Context(), sessionID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+
+	for i := 0; i < 3; i++ {
+		msg.AppendContent("x")
+		require.NoError(t, svc.Update(t.Context(), msg))
+		got, err := svc.Get(t.Context(), msg.ID)
+		require.NoError(t, err)
+		require.Len(t, got.Content().Text, i+1, "every update must land synchronously when debounce is 0")
+	}
+}
+
+// TestFlush_WaitsForInFlightWrite reproduces the failure where Flush
+// or FlushAll could return before a concurrent in-flight SQL write
+// completed.
We block UpdateMessage on a release channel, fire the
+// debounce timer, then call Flush and assert it does not return until
+// the in-flight write actually lands.
+func TestFlush_WaitsForInFlightWrite(t *testing.T) {
+	t.Parallel()
+
+	conn, err := db.Connect(t.Context(), t.TempDir())
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = conn.Close() })
+
+	q := db.New(conn)
+	sessions := session.NewService(q, conn)
+	sess, err := sessions.Create(t.Context(), "test")
+	require.NoError(t, err)
+
+	slow := &slowUpdateQuerier{
+		Querier: q,
+		release: make(chan struct{}),
+		started: make(chan struct{}),
+	}
+	// If an assertion fails before the explicit release below, still
+	// unblock the writer so it is not left parked in UpdateMessage.
+	// Registered after the conn.Close cleanup, so it runs before it.
+	released := false
+	t.Cleanup(func() {
+		if !released {
+			close(slow.release)
+		}
+	})
+	// Short debounce so the timer fires quickly.
+	svc := NewService(slow, WithDebounce(10*time.Millisecond))
+
+	msg, err := svc.Create(t.Context(), sess.ID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	msg.AppendContent("payload")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	// Wait for the timer-fired flush to enter UpdateMessage.
+	select {
+	case <-slow.started:
+	case <-time.After(time.Second):
+		t.Fatal("timer-fired flush never reached UpdateMessage")
+	}
+
+	// At this point the buffer is dirty=false but flushing=true. A
+	// naive Flush would early-return on !dirty. Spawn Flush in a
+	// goroutine and assert it has not returned while the write is
+	// still blocked.
+	flushDone := make(chan error, 1)
+	go func() { flushDone <- svc.Flush(t.Context(), msg.ID) }()
+
+	select {
+	case err := <-flushDone:
+		t.Fatalf("Flush returned %v while in-flight write was still blocked", err)
+	case <-time.After(50 * time.Millisecond):
+		// Expected: Flush is correctly waiting.
+	}
+
+	// Release the slow write; Flush must now return cleanly.
+	released = true
+	close(slow.release)
+	select {
+	case err := <-flushDone:
+		require.NoError(t, err)
+	case <-time.After(time.Second):
+		t.Fatal("Flush did not return after in-flight write completed")
+	}
+
+	// The SQL row should now reflect the buffered payload.
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Equal(t, "payload", got.Content().Text)
+}
+
+// TestFlushAll_WaitsForInFlightWrite asserts FlushAll picks up IDs
+// whose buffer is currently flushing (dirty=false) so shutdown and
+// session-switch callers don't return while an SQL write is mid-flight.
+func TestFlushAll_WaitsForInFlightWrite(t *testing.T) {
+	t.Parallel()
+
+	conn, err := db.Connect(t.Context(), t.TempDir())
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = conn.Close() })
+
+	q := db.New(conn)
+	sessions := session.NewService(q, conn)
+	sess, err := sessions.Create(t.Context(), "test")
+	require.NoError(t, err)
+
+	slow := &slowUpdateQuerier{
+		Querier: q,
+		release: make(chan struct{}),
+		started: make(chan struct{}),
+	}
+	// Same failure-path guard as in TestFlush_WaitsForInFlightWrite:
+	// never leave the writer blocked once the test is over.
+	released := false
+	t.Cleanup(func() {
+		if !released {
+			close(slow.release)
+		}
+	})
+	svc := NewService(slow, WithDebounce(10*time.Millisecond))
+
+	msg, err := svc.Create(t.Context(), sess.ID, CreateMessageParams{Role: Assistant})
+	require.NoError(t, err)
+	msg.AppendContent("payload")
+	require.NoError(t, svc.Update(t.Context(), msg))
+
+	// Wait for the timer-fired flush to enter UpdateMessage.
+	select {
+	case <-slow.started:
+	case <-time.After(time.Second):
+		t.Fatal("timer-fired flush never reached UpdateMessage")
+	}
+
+	flushDone := make(chan error, 1)
+	go func() { flushDone <- svc.FlushAll(t.Context()) }()
+
+	// FlushAll must still be waiting while the write is blocked.
+	select {
+	case err := <-flushDone:
+		t.Fatalf("FlushAll returned %v while in-flight write was still blocked", err)
+	case <-time.After(50 * time.Millisecond):
+	}
+
+	released = true
+	close(slow.release)
+	select {
+	case err := <-flushDone:
+		require.NoError(t, err)
+	case <-time.After(time.Second):
+		t.Fatal("FlushAll did not return after in-flight write completed")
+	}
+
+	got, err := svc.Get(t.Context(), msg.ID)
+	require.NoError(t, err)
+	require.Equal(t, "payload", got.Content().Text)
+}
+
+// TestUpdate_StructuralFlushUsesMustDeliver covers the second review
+// finding: structural terminal events (tool-call add, tool-call
+// finish, reasoning end) must publish via the must-deliver path even
+// when the message itself is not
yet IsFinished.
+//
+// The publish path is identified by leaving a subscriber's one-slot
+// channel full and unread. A must-deliver publish then waits out the
+// (shortened) timeout and bumps [pubsub.Broker.MustDeliverDropCount],
+// whereas a lossy publish bumps [pubsub.Broker.DropCount] right away.
+// Because the counters never overlap, whichever one moved tells us
+// which call site emitted the event.
+func TestUpdate_StructuralFlushUsesMustDeliver(t *testing.T) {
+	t.Parallel()
+
+	type scenario struct {
+		name   string
+		mutate func(*Message)
+	}
+	scenarios := []scenario{
+		{"tool call add", func(m *Message) {
+			m.AddToolCall(ToolCall{ID: "tc1", Name: "view"})
+		}},
+		{"tool call finish", func(m *Message) {
+			m.AddToolCall(ToolCall{ID: "tc1", Name: "view", Input: "{}", Finished: true})
+		}},
+		{"reasoning end", func(m *Message) {
+			m.AppendReasoningContent("hmm")
+			m.FinishThinking()
+		}},
+	}
+
+	for _, sc := range scenarios {
+		t.Run(sc.name, func(t *testing.T) {
+			t.Parallel()
+
+			conn, err := db.Connect(t.Context(), t.TempDir())
+			require.NoError(t, err)
+			t.Cleanup(func() { _ = conn.Close() })
+
+			queries := db.New(conn)
+			owner, err := session.NewService(queries, conn).Create(t.Context(), "test")
+			require.NoError(t, err)
+
+			// Swap in a broker with a single-slot channel and a short
+			// must-deliver timeout: one sender can then fill it, and
+			// drops show up without a long wait.
+			svc := NewService(queries, WithDebounce(time.Hour))
+			concrete := svc.(*service)
+			concrete.Shutdown()
+			concrete.Broker = pubsub.NewBrokerWithOptions[Message](1, 1000)
+			concrete.SetMustDeliverTimeout(40 * time.Millisecond)
+
+			subCtx, cancel := context.WithCancel(t.Context())
+			defer cancel()
+			// The subscription is registered but deliberately never
+			// read, so its channel stays full once it holds one event.
+			_ = svc.Subscribe(subCtx)
+
+			// Create publishes a CreatedEvent, which occupies the only
+			// slot; the next publish therefore finds the channel full.
+			draft, err := svc.Create(t.Context(), owner.ID, CreateMessageParams{Role: Assistant})
+			require.NoError(t, err)
+
+			// Apply the structural change. Update is expected to flush
+			// synchronously here despite the 1h debounce, using
+			// whichever publish path the service picks for it.
+			sc.mutate(&draft)
+			require.NoError(t, svc.Update(t.Context(), draft))
+
+			// By now the 40ms must-deliver timeout has elapsed with
+			// nobody draining. Must-deliver routing shows up as
+			// MustDeliverDropCount > 0 with DropCount untouched; lossy
+			// routing would show the opposite.
+			mustDeliverDropped := func() bool {
+				return concrete.MustDeliverDropCount() >= 1
+			}
+			require.Eventually(t, mustDeliverDropped, time.Second, 5*time.Millisecond,
+				"structural terminal event should publish via must-deliver, not lossy Publish")
+			require.Zero(t, concrete.DropCount(),
+				"structural terminal event must not be silently dropped via lossy Publish")
+		})
+	}
+}
diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go
index 52e827c2d6306fa2365372b9867810d9b99d0227..03af2e676f67bee68a757eb85a600d7b09838c4a 100644
--- a/internal/pubsub/broker.go
+++ b/internal/pubsub/broker.go
@@ -1,19 +1,54 @@
+// Package pubsub provides a lightweight in-process broker for fan-out
+// event delivery between services and the UI.
+//
+// Delivery semantics:
+//
+//   - [Broker.Publish] is best-effort and lossy under contention. If a
+//     subscriber's channel is full, the event is dropped for that
+//     subscriber and a counter is incremented. This is the right choice
+//     for high-frequency intermediate updates (e.g. streaming token
+//     deltas) where only the latest state matters.
+//
+//   - [Broker.PublishMustDeliver] is bounded-blocking. For each
+//     subscriber it first tries a non-blocking send, then falls back to
+//     a per-subscriber blocking send with a hard timeout.
On timeout the
+//     event is dropped for that subscriber, an error is logged, and the
+//     must-deliver drop counter is incremented. The publisher never
+//     blocks indefinitely. This is the right choice for terminal events
+//     (finish, tool result, error, cancel) that must not be silently
+//     coalesced away.
+//
+// Drop counters ([Broker.DropCount], [Broker.MustDeliverDropCount]) are
+// exposed so callers can surface saturation in telemetry.
 package pubsub

 import (
 	"context"
+	"log/slog"
 	"sync"
+	"sync/atomic"
+	"time"
 )

-const bufferSize = 64
+const (
+	bufferSize = 64
+
+	// defaultMustDeliverTimeout is the per-subscriber upper bound on how
+	// long [Broker.PublishMustDeliver] will block waiting for buffer
+	// space before giving up on that subscriber.
+	defaultMustDeliverTimeout = 50 * time.Millisecond
+)

+// Broker fans events carrying a payload of type T out to the channels
+// in subs. mu guards subs and mustDeliverTimeout; dropCount and
+// mustDeliverDropCount are atomics, so publishers can bump them while
+// holding only the read lock.
 type Broker[T any] struct {
-	subs              map[chan Event[T]]struct{}
-	mu                sync.RWMutex
-	done              chan struct{}
-	subCount          int
-	maxEvents         int
-	channelBufferSize int
+	subs                 map[chan Event[T]]struct{}
+	mu                   sync.RWMutex
+	done                 chan struct{}
+	subCount             int
+	maxEvents            int
+	channelBufferSize    int
+	mustDeliverTimeout   time.Duration
+	dropCount            atomic.Uint64
+	mustDeliverDropCount atomic.Uint64
 }

 func NewBroker[T any]() *Broker[T] {
@@ -22,13 +57,27 @@ func NewBroker[T any]() *Broker[T] {

+// NewBrokerWithOptions returns a Broker configured with the given
+// channelBufferSize and maxEvents. The must-deliver timeout starts at
+// defaultMustDeliverTimeout; use SetMustDeliverTimeout to change it.
 func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] {
 	return &Broker[T]{
-		subs:              make(map[chan Event[T]]struct{}),
-		done:              make(chan struct{}),
-		maxEvents:         maxEvents,
-		channelBufferSize: channelBufferSize,
+		subs:               make(map[chan Event[T]]struct{}),
+		done:               make(chan struct{}),
+		maxEvents:          maxEvents,
+		channelBufferSize:  channelBufferSize,
+		mustDeliverTimeout: defaultMustDeliverTimeout,
 	}
 }

+// SetMustDeliverTimeout overrides the per-subscriber timeout used by
+// [Broker.PublishMustDeliver]. A zero or negative value resets to the
+// default. Intended primarily for tests.
+func (b *Broker[T]) SetMustDeliverTimeout(d time.Duration) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	if d <= 0 {
+		b.mustDeliverTimeout = defaultMustDeliverTimeout
+		return
+	}
+	b.mustDeliverTimeout = d
+}
+
 func (b *Broker[T]) Shutdown() {
 	select {
 	case <-b.done: // Already closed
@@ -90,6 +139,25 @@ func (b *Broker[T]) GetSubscriberCount() int {
 	return b.subCount
 }

+// DropCount returns the cumulative number of events dropped by
+// [Broker.Publish] because a subscriber's channel was full.
+func (b *Broker[T]) DropCount() uint64 {
+	return b.dropCount.Load()
+}
+
+// MustDeliverDropCount returns the cumulative number of events dropped
+// by [Broker.PublishMustDeliver], whether the per-subscriber timeout
+// expired or the caller's context was canceled first.
+func (b *Broker[T]) MustDeliverDropCount() uint64 {
+	return b.mustDeliverDropCount.Load()
+}
+
+// Publish delivers an event to every active subscriber.
+//
+// Delivery is non-blocking and lossy: if a subscriber's channel is full
+// the event is dropped for that subscriber and [Broker.DropCount] is
+// incremented. Use [Broker.PublishMustDeliver] for events that must not
+// be silently dropped.
 func (b *Broker[T]) Publish(t EventType, payload T) {
 	b.mu.RLock()
 	defer b.mu.RUnlock()
@@ -106,8 +174,74 @@ func (b *Broker[T]) Publish(t EventType, payload T) {
 		select {
 		case sub <- event:
 		default:
-			// Channel is full, subscriber is slow - skip this event
-			// This prevents blocking the publisher
+			// Channel is full, subscriber is slow - skip this event.
+			// Lossy by design; counted so saturation is observable.
+			b.dropCount.Add(1)
+		}
+	}
+}
+
+// PublishMustDeliver delivers an event with bounded-blocking semantics.
+// For each subscriber it first attempts a non-blocking send, then falls
+// back to a blocking send bounded by a per-subscriber timeout (default
+// [defaultMustDeliverTimeout]). On timeout the event is dropped for
+// that subscriber, [Broker.MustDeliverDropCount] is incremented, and an
+// error is logged. The publisher never blocks indefinitely.
+//
+// If ctx is done while a send is blocked, the publisher stops waiting
+// but does not abandon the fan-out: that subscriber is counted as a
+// must-deliver drop, and each remaining subscriber still gets one
+// non-blocking attempt (counted as a drop if its channel is full).
+//
+// Use this for terminal events that must reach subscribers (finish,
+// tool result, error, cancel). Callers must still tolerate rare drops
+// after timeout — recovery is the subscriber's responsibility (e.g. a
+// re-fetch on the next session-visible event).
+func (b *Broker[T]) PublishMustDeliver(ctx context.Context, t EventType, payload T) {
+	b.mu.RLock()
+	defer b.mu.RUnlock()
+
+	select {
+	case <-b.done:
+		return
+	default:
+	}
+
+	event := Event[T]{Type: t, Payload: payload}
+	timeout := b.mustDeliverTimeout
+	// canceled flips once ctx is done; from then on only the
+	// non-blocking fast path is used for the remaining subscribers.
+	canceled := false
+
+	for sub := range b.subs {
+		// Fast path: non-blocking send.
+		select {
+		case sub <- event:
+			continue
+		default:
+		}
+
+		if canceled {
+			// Full channel and no time left to wait: record the drop.
+			b.mustDeliverDropCount.Add(1)
+			continue
+		}
+
+		// Slow path: bounded blocking send.
+		timer := time.NewTimer(timeout)
+		select {
+		case sub <- event:
+			timer.Stop()
+		case <-timer.C:
+			b.mustDeliverDropCount.Add(1)
+			slog.Error("PublishMustDeliver timed out delivering event",
+				"type", t, "timeout", timeout)
+		case <-ctx.Done():
+			timer.Stop()
+			canceled = true
+			b.mustDeliverDropCount.Add(1)
+			slog.Warn("PublishMustDeliver canceled before delivering event",
+				"type", t, "error", ctx.Err())
 		}
 	}
 }
diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go
index 57b1228e7eacb28a16141283ee2703a33511bd18..d4e5ed790e3a1cf7a4bcc299e4e4e63bedfbacd5 100644
--- a/internal/workspace/app_workspace.go
+++ b/internal/workspace/app_workspace.go
@@ -70,6 +70,12 @@ func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string
 // -- Messages --

 func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) {
+	// Drain any debounced updates so the caller observes the latest
+	// in-memory state. message.Service buffers streaming deltas and a
+	// cold List would otherwise miss them at session-switch time.
+	if err := w.app.Messages.FlushAll(ctx); err != nil {
+		return nil, err
+	}
 	return w.app.Messages.List(ctx, sessionID)
 }