From 34995e9333082f6f8a6437e4bc75a55fca45c981 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 17:59:26 -0400 Subject: [PATCH] fix(server): prevent cancels from affecting future prompts Apply a cancel only to prompts that were already accepted when the cancel request arrived. Immediately canceled accepted prompts also publish completion so callers waiting on that prompt do not hang. Co-Authored-By: Charm Crush --- internal/agent/accepted_run_test.go | 9 +- internal/agent/agent.go | 227 ++++++++++++++++------ internal/agent/dispatch_cancel_test.go | 255 ++++++++++++++++++++++++- 3 files changed, 432 insertions(+), 59 deletions(-) diff --git a/internal/agent/accepted_run_test.go b/internal/agent/accepted_run_test.go index d62422a9f02bec68a8da1a08c6e6d6b52e7d7699..14aec44265d44fa0f5b055ef14af3086e13a0cf3 100644 --- a/internal/agent/accepted_run_test.go +++ b/internal/agent/accepted_run_test.go @@ -28,8 +28,13 @@ func (a *sessionAgent) acceptedCount(sessionID string) int { } func (a *sessionAgent) hasPendingCancel(sessionID string) bool { - _, ok := a.pendingCancels.Get(sessionID) - return ok + mark, ok := a.cancelMark.Get(sessionID) + return ok && mark > 0 +} + +func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 { + mark, _ := a.cancelMark.Get(sessionID) + return mark } func TestAcceptedRun_CloseIsIdempotent(t *testing.T) { diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 97bd7a21af4c28f30e46fcaf23c23348467e90ac..393587a111ad806dc26bbf8a80f52dba49ce0397 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -112,6 +112,15 @@ type SessionAgentCall struct { // (in-process / local callers like AppWorkspace), behavior is // unchanged and no accept tracking applies. Accepted *AcceptedRun + // acceptSeq carries the accept sequence of the handle that produced + // this call after it has been enqueued and its Accepted handle + // stripped. The queue-drain paths compare it against a session's + // cancel mark so a follow-up queued before a cancel is dropped while + // one queued after the cancel survives. 0 means untracked (an + // in-process enqueue with no accept reservation), which the drain + // paths treat as covered by any present mark, preserving the + // pre-sequence behavior. + acceptSeq uint64 } type SessionAgent interface { @@ -167,20 +176,31 @@ type sessionAgent struct { // BeginAccepted increments it; only AcceptedRun.Close decrements // it. acceptedRuns *csync.Map[string, int] - // pendingCancels records sessions whose dispatched-but-not-yet- - // running call should observe a cancellation request. It is only - // set by Cancel when acceptedRuns > 0, so an idle Escape never - // poisons the next prompt. - pendingCancels *csync.Map[string, struct{}] + // cancelMark records, per session, a high-water accept sequence: an + // accepted handle is canceled by it iff the handle's sequence is at + // or below the mark. Cancel raises the mark to the latest sequence + // assigned at cancel time, so a single Cancel covers every prompt + // accepted-but-not-yet-active then, while a prompt accepted later + // (higher sequence) is never poisoned. Absent or 0 means no pending + // cancel. It is only raised by Cancel when acceptedRuns > 0, so an + // idle Escape never records a mark. + cancelMark *csync.Map[string, uint64] // dispatchMuCreate guards lazy creation of per-session entries in // dispatchMu so two goroutines can't race to lock different mutex // instances for the same session. dispatchMuCreate sync.Mutex - // acceptedMu serializes increments/decrements of acceptedRuns. It + // acceptedMu serializes increments/decrements of acceptedRuns and + // the assignment of accept sequence numbers from acceptSeqGen. It // is separate from dispatchMu so AcceptedRun.Close (which may run // while Run holds dispatchMu for the same session) does not // deadlock by re-entering the dispatch lock. acceptedMu sync.Mutex + // acceptSeqGen is the monotonic source of accept sequence numbers. + // Each BeginAccepted increments it under acceptedMu and stamps the + // returned handle, so sequences strictly increase in accept order + // across the agent. Cancel uses its current value as the per-session + // high-water mark. + acceptSeqGen uint64 } type SessionAgentOptions struct { @@ -218,7 +238,7 @@ func NewSessionAgent( activeRequests: csync.NewMap[string, context.CancelFunc](), dispatchMu: csync.NewMap[string, *sync.Mutex](), acceptedRuns: csync.NewMap[string, int](), - pendingCancels: csync.NewMap[string, struct{}](), + cancelMark: csync.NewMap[string, uint64](), } } @@ -231,7 +251,12 @@ func NewSessionAgent( type AcceptedRun struct { agent *sessionAgent sessionID string - done atomic.Bool + // seq is the monotonic accept sequence stamped by BeginAccepted. A + // cancel covers this handle iff seq is at or below the session's + // cancel mark, so a handle accepted after a cancel (higher seq) is + // never poisoned by it. + seq uint64 + done atomic.Bool } // Close decrements the accept counter for this reservation. It is safe @@ -263,19 +288,30 @@ func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun { defer a.acceptedMu.Unlock() count, _ := a.acceptedRuns.Get(sessionID) a.acceptedRuns.Set(sessionID, count+1) - return &AcceptedRun{agent: a, sessionID: sessionID} + a.acceptSeqGen++ + return &AcceptedRun{agent: a, sessionID: sessionID, seq: a.acceptSeqGen} } // endAccepted decrements the accept counter for sessionID. It is only // called via AcceptedRun.Close. It uses a dedicated lock (not the // per-session dispatch mutex) so it can run while Run holds dispatchMu // for the same session without deadlocking. +// +// When the count reaches zero the session's cancel mark is dropped: no +// accepted handle remains for it to cover, and any handle accepted later +// gets a strictly higher sequence that the mark would not match anyway. +// Handles canceled on entry never reach RunComplete, so this is the only +// place that clears the mark for an all-canceled batch. Sibling handles +// covered by the same mark are serialized on the per-session dispatch +// mutex and read the mark before they Close, so this never clears it out +// from under a covered handle still waiting to enter Run. func (a *sessionAgent) endAccepted(sessionID string) { a.acceptedMu.Lock() defer a.acceptedMu.Unlock() count, ok := a.acceptedRuns.Get(sessionID) if !ok || count <= 1 { a.acceptedRuns.Del(sessionID) + a.cancelMark.Del(sessionID) return } a.acceptedRuns.Set(sessionID, count-1) @@ -311,20 +347,44 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { existing = []SessionAgentCall{} } queued := call + if call.Accepted != nil { + // Preserve the accept sequence after the handle is stripped so + // the queue-drain paths can tell a follow-up queued before a + // cancel (covered by the mark) from one queued after it. + queued.acceptSeq = call.Accepted.seq + } queued.OnComplete = nil queued.Accepted = nil existing = append(existing, queued) a.messageQueue.Set(call.SessionID, existing) } -// clearPendingCancel removes any pending-cancel record for sessionID. It -// takes the per-session dispatch lock so it is ordered against Cancel and -// the dispatch handoff. +// clearPendingCancel removes any pending-cancel mark for sessionID. It +// takes the per-session dispatch lock so it is ordered against Cancel +// and the dispatch handoff. func (a *sessionAgent) clearPendingCancel(sessionID string) { mu := a.sessionMu(sessionID) mu.Lock() defer mu.Unlock() - a.pendingCancels.Del(sessionID) + a.cancelMark.Del(sessionID) +} + +// canceledBySeq reports whether an accepted handle or queued call with +// the given accept sequence is covered by a pending cancel for the +// session. Callers must hold the session's dispatch mutex. A tracked +// sequence (seq > 0) is covered only when it is at or below the cancel +// high-water mark, so a prompt accepted after the cancel (higher seq) is +// never poisoned. An untracked sequence (seq == 0, an in-process enqueue +// with no accept reservation) is covered whenever any mark is present, +// preserving the pre-sequence behavior. The mark is not consumed: it +// stays so every sibling handle it covers observes the same cancel, and +// a later handle (higher seq) ignores it regardless. +func (a *sessionAgent) canceledBySeq(sessionID string, seq uint64) bool { + mark, ok := a.cancelMark.Get(sessionID) + if !ok || mark == 0 { + return false + } + return seq == 0 || seq <= mark } // persistCanceledTurn writes the user/assistant records for a turn that @@ -356,6 +416,26 @@ func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgen return a.messages.Update(writeCtx, assistant) } +// publishRunComplete emits the authoritative terminal event for a turn. +// It honors the per-call OnComplete hook when set (so the coordinator can +// coalesce retries) and otherwise falls back to the RunComplete broker. +// ctx is used only for the bounded-blocking must-deliver publish; the +// terminal payload is supplied by the caller. This is the single emit path +// shared by the streaming defer and the cancel-on-entry early return so a +// caller waiting on RunComplete (e.g. `crush run` with a RunID) always +// observes exactly one terminal event regardless of which Run branch ends +// the turn. +func (a *sessionAgent) publishRunComplete(ctx context.Context, call SessionAgentCall, complete notify.RunComplete) { + if call.OnComplete != nil { + call.OnComplete(complete) + return + } + if a.runComplete == nil { + return + } + a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) +} + // ValidateCall performs the cheap structural validation that // sessionAgent.Run requires before a call can be dispatched: a call must // carry either a non-empty prompt or a text attachment, and it must name a @@ -394,22 +474,39 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Serialize the accepted -> (cancel-on-entry | queued | // active) transition against a concurrent Cancel. Cancel takes // the same per-session lock, so every cancel observes at least - // one of: pendingCancels, an activeRequests entry, or a + // one of: a cancel mark, an activeRequests entry, or a // messageQueue entry it then clears. mu := a.sessionMu(call.SessionID) mu.Lock() - if _, pending := a.pendingCancels.Get(call.SessionID); pending { + if a.canceledBySeq(call.SessionID, call.Accepted.seq) { // Cancel-on-entry: a cancel arrived while this run was - // dispatched but not yet active. Consume the pending - // cancel, release the accept reservation, drop the lock, - // and persist a canceled turn without entering Stream. - a.pendingCancels.Del(call.SessionID) + // dispatched but not yet active, and this handle's accept + // sequence is at or below the session's cancel mark. The + // mark is left in place so sibling handles it also covers + // observe the same cancel; release the accept reservation, + // drop the lock, and persist a canceled turn without + // entering Stream. + // + // This path returns before the streaming defer that + // publishes RunComplete is installed, so emit the terminal + // event explicitly. Without it, a caller waiting on + // RunComplete for this RunID (e.g. `crush run`, which + // ignores message events and blocks on RunComplete) would + // hang on an immediately-canceled accepted run. call.Accepted.Close() mu.Unlock() + complete := notify.RunComplete{ + SessionID: call.SessionID, + RunID: call.RunID, + Cancelled: true, + } if err := a.persistCanceledTurn(ctx, call, false); err != nil { + complete.Error = err.Error() + a.publishRunComplete(ctx, call, complete) return nil, err } + a.publishRunComplete(ctx, call, complete) return nil, nil } @@ -579,14 +676,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // the authoritative terminal event so a momentarily-full // subscriber channel can't silently drop it and hang // non-interactive clients waiting on RunComplete. - if call.OnComplete != nil { - call.OnComplete(complete) - return - } - if a.runComplete == nil { - return - } - a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) + a.publishRunComplete(ctx, call, complete) }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -621,24 +711,26 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() - // Drain queued follow-up prompts, but skip them if a cancel - // was recorded for the session while they sat in the queue: - // a cancel that arrived after the queue insertion must not - // let the queued prompt run as part of this step. + // Drain queued follow-up prompts, but skip any covered by a + // cancel recorded while they sat in the queue: a cancel that + // arrived after a prompt was queued must not let it run as + // part of this step. Coverage is per-call by accept sequence + // so a follow-up queued after the cancel (higher seq) is + // still folded in. dispatchLock := a.sessionMu(call.SessionID) dispatchLock.Lock() - _, canceled := a.pendingCancels.Get(call.SessionID) queuedCalls, _ := a.messageQueue.Get(call.SessionID) a.messageQueue.Del(call.SessionID) dispatchLock.Unlock() - if !canceled { - for _, queued := range queuedCalls { - userMessage, createErr := a.createUserMessage(callContext, queued) - if createErr != nil { - return callContext, prepared, createErr - } - prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) + for _, queued := range queuedCalls { + if a.canceledBySeq(call.SessionID, queued.acceptSeq) { + continue } + userMessage, createErr := a.createUserMessage(callContext, queued) + if createErr != nil { + return callContext, prepared, createErr + } + prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) @@ -1008,20 +1100,37 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // closing the dequeue -> re-register window. mu := a.sessionMu(call.SessionID) mu.Lock() - if _, pending := a.pendingCancels.Get(call.SessionID); pending { + queuedMessages, _ := a.messageQueue.Get(call.SessionID) + if mark, ok := a.cancelMark.Get(call.SessionID); ok && mark > 0 && len(queuedMessages) > 0 { // A cancel was recorded for this session (e.g. it arrived while - // this run was active and a follow-up had been accepted). Drop - // the queue instead of running it and consume the marker. - a.pendingCancels.Del(call.SessionID) + // this run was active and follow-ups had been queued). Drop the + // queued prompts it covers (accept sequence at or below the + // mark, or untracked); keep any queued after the cancel (higher + // sequence) so they still run. + var kept []SessionAgentCall + for _, q := range queuedMessages { + if q.acceptSeq == 0 || q.acceptSeq <= mark { + continue + } + kept = append(kept, q) + } + queuedMessages = kept + a.messageQueue.Set(call.SessionID, kept) + } + if len(queuedMessages) == 0 { + // No queued work. Clear the cancel mark only when no accepted + // run remains in flight that it might still cover; otherwise a + // sibling prompt (sequence at or below the mark) waiting to + // enter Run would lose its cancellation. When accepted runs are + // gone, this also clears a stale mark so it can't catch a + // future run. a.messageQueue.Del(call.SessionID) - mu.Unlock() - return result, err - } - queuedMessages, ok := a.messageQueue.Get(call.SessionID) - if !ok || len(queuedMessages) == 0 { - // No queued work. Clear any stale pending-cancel entry as a - // safety net so it can't catch a future run (no-op when unset). - a.pendingCancels.Del(call.SessionID) + a.acceptedMu.Lock() + inFlight, _ := a.acceptedRuns.Get(call.SessionID) + a.acceptedMu.Unlock() + if inFlight == 0 { + a.cancelMark.Del(call.SessionID) + } mu.Unlock() return result, err } @@ -1619,17 +1728,27 @@ func (a *sessionAgent) Cancel(sessionID string) { } // Record a pending cancel only when a dispatched-but-not-yet-active - // run exists. This catches a run still in the goroutine scheduler or + // run exists. This catches runs still in the goroutine scheduler or // about to enter Run's busy-queue branch, while leaving an idle // session untouched. Active and accepted are not mutually exclusive: // when a run is active and a follow-up has been accepted, both the // cancel above and this pending record fire. + // + // Raise the session's cancel mark to the latest accept sequence + // assigned so far. Every prompt currently accepted-but-not-yet- + // active has a sequence at or below that value, so one cancel covers + // all of them; a prompt accepted after this cancel gets a strictly + // higher sequence and is never poisoned. Using max keeps repeated + // cancels idempotent while the same prompts are in flight and lets a + // later cancel extend coverage to prompts accepted since. a.acceptedMu.Lock() count, ok := a.acceptedRuns.Get(sessionID) + mark := a.acceptSeqGen a.acceptedMu.Unlock() if ok && count > 0 { - slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID) - a.pendingCancels.Set(sessionID, struct{}{}) + slog.Debug("Recording cancel mark for accepted runs", "session_id", sessionID, "count", count, "mark", mark) + existing, _ := a.cancelMark.Get(sessionID) + a.cancelMark.Set(sessionID, max(existing, mark)) } if a.QueuedPrompts(sessionID) > 0 { diff --git a/internal/agent/dispatch_cancel_test.go b/internal/agent/dispatch_cancel_test.go index f66de252e63559239c1d577fe51c0650589aa5b4..f1b0faad21da845f16728fd2d1101b64c569f2dc 100644 --- a/internal/agent/dispatch_cancel_test.go +++ b/internal/agent/dispatch_cancel_test.go @@ -5,9 +5,12 @@ import ( "errors" "sync/atomic" "testing" + "time" "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -140,7 +143,7 @@ func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) { // A follow-up prompt sits queued for the session. sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"}) // A cancel was recorded for the session while it sat in the queue. - sa.pendingCancels.Set(sess.ID, struct{}{}) + sa.cancelMark.Set(sess.ID, 1) result, err := sa.Run(t.Context(), SessionAgentCall{ SessionID: sess.ID, @@ -177,8 +180,8 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { sess, err := env.sessions.Create(t.Context(), "session") require.NoError(t, err) - // A stale pending cancel lingers (no queued work, no accepted run). - sa.pendingCancels.Set(sess.ID, struct{}{}) + // A stale cancel mark lingers (no queued work, no accepted run). + sa.cancelMark.Set(sess.ID, 1) result, err := sa.Run(t.Context(), SessionAgentCall{ SessionID: sess.ID, @@ -196,3 +199,249 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { assert.Equal(t, message.Assistant, msgs[1].Role) assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason()) } + +// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired +// to a RunComplete broker so tests can observe the terminal event a +// RunID-bearing caller (e.g. `crush run`) blocks on. +func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) { + t.Helper() + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + return sa, env, broker +} + +// TestRun_CancelOnEntryPublishesRunComplete covers the first review +// finding: the cancel-on-entry path returned before the streaming defer +// that publishes RunComplete was installed. A caller that dispatches a +// run with a RunID and blocks on RunComplete (ignoring message events, +// like `crush run`) would hang on an immediately-canceled accepted run. +// The cancel-on-entry path must publish a terminal RunComplete carrying +// the originating RunID. +func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) { + t.Parallel() + sa, env, broker := newCancelTestAgentWithRunComplete(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + ch := broker.Subscribe(ctx) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-cancel-on-entry", + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + select { + case got := <-ch: + assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID, + "RunComplete must echo the originating RunID") + assert.Equal(t, sess.ID, got.Payload.SessionID) + assert.True(t, got.Payload.Cancelled, + "cancel-on-entry RunComplete must be marked cancelled") + case <-time.After(2 * time.Second): + t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise") + } +} + +// TestCancel_TwoAcceptedBothObserveCancellation covers the second review +// finding: a single cancel with two accepted-not-yet-active prompts must +// cancel both. The cancel raises the session's high-water mark to the +// latest accept sequence, so every prompt accepted-but-not-yet-active at +// cancel time is covered and both take the cancel-on-entry path. +func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Two prompts are accepted-but-not-yet-active for the same session. + accept1 := sa.BeginAccepted(sess.ID) + accept2 := sa.BeginAccepted(sess.ID) + require.Equal(t, 2, sa.acceptedCount(sess.ID)) + + // A single cancel arrives before either becomes active. + sa.Cancel(sess.ID) + require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID), + "one cancel must mark every currently-accepted prompt as canceled") + require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq, + "the mark must cover the earlier accepted prompt too") + + // Both prompts enter Run; each must take cancel-on-entry, not run. + r1, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "first", + Accepted: accept1, + }) + require.NoError(t, err) + require.Nil(t, r1) + + r2, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "second", + Accepted: accept2, + }) + require.NoError(t, err) + require.Nil(t, r2) + + require.False(t, sa.hasPendingCancel(sess.ID), + "both reserved units must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), + "both accept reservations must be released") + + // Each canceled-on-entry turn writes a user + canceled assistant + // message, and neither prompt was enqueued to run normally. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither accepted prompt may be enqueued to run normally") + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages") + var canceled int + for _, m := range msgs { + if m.Role == message.Assistant { + assert.Equal(t, message.FinishReasonCanceled, m.FinishReason()) + canceled++ + } + } + require.Equal(t, 2, canceled, "both turns must finish canceled") +} + +// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel +// no-op guarantee end-to-end: an Escape on an idle session must not +// record a pending cancel that leaks into the next accepted prompt, which +// must run normally to completion. +func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Idle Escape: no accepted run, no active request. + sa.Cancel(sess.ID) + require.False(t, sa.hasPendingCancel(sess.ID), + "idle cancel must not record a pending cancel") + + // The next accepted prompt must run normally, not cancel on entry. + accept := sa.BeginAccepted(sess.ID) + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "next", + Accepted: accept, + }) + require.NoError(t, err) + require.NotNil(t, result, "next prompt must run normally after an idle cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(), + "the prompt must finish normally, not canceled") +} + +// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for +// the reviewer's finding: a counted session-level pending cancel let a +// prompt accepted after the cancel enter Run first and consume a unit +// reserved for the earlier prompts. With a sequence high-water mark, a +// single cancel covers exactly the prompts accepted-but-not-yet-active at +// cancel time (A and B); a prompt accepted afterwards (C) gets a higher +// sequence and must run normally without consuming A or B's cancellation. +// C is run first to prove it neither cancels nor drains the mark, then A +// and B are run and must both cancel on entry. +func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A and B are accepted-but-not-yet-active. + acceptA := sa.BeginAccepted(sess.ID) + acceptB := sa.BeginAccepted(sess.ID) + + // One cancel arrives covering both A and B. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID), + "the mark must cover every prompt accepted before the cancel") + + // C is accepted AFTER the cancel; its sequence is above the mark. + acceptC := sa.BeginAccepted(sess.ID) + require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID), + "a prompt accepted after the cancel must not be covered by the mark") + + // Run C first. It must run normally to completion and must not + // consume or clear the cancellation reserved for A and B. + rc, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "C", + Accepted: acceptC, + }) + require.NoError(t, err) + require.NotNil(t, rc, "C was accepted after the cancel and must run normally") + require.True(t, sa.hasPendingCancel(sess.ID), + "running C must not drain the cancellation reserved for A and B") + + // Now A and B run. Both were accepted before the cancel and must + // take the cancel-on-entry path. + ra, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "A", + Accepted: acceptA, + }) + require.NoError(t, err) + require.Nil(t, ra, "A must cancel on entry, not run") + + rb, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "B", + Accepted: acceptB, + }) + require.NoError(t, err) + require.Nil(t, rb, "B must cancel on entry, not run") + + require.False(t, sa.hasPendingCancel(sess.ID), + "the mark clears once all covered handles are resolved") + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither A nor B may be enqueued to run normally") + + // C produced a normal turn; A and B each produced a canceled turn. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant") + + var normal, canceled int + for _, m := range msgs { + if m.Role != message.Assistant { + continue + } + switch m.FinishReason() { + case message.FinishReasonEndTurn: + normal++ + case message.FinishReasonCanceled: + canceled++ + } + } + require.Equal(t, 1, normal, "only C finished normally") + require.Equal(t, 2, canceled, "both A and B finished canceled") +}