@@ -112,6 +112,15 @@ type SessionAgentCall struct {
// (in-process / local callers like AppWorkspace), behavior is
// unchanged and no accept tracking applies.
Accepted *AcceptedRun
+ // acceptSeq carries the accept sequence of the handle that produced
+ // this call after it has been enqueued and its Accepted handle
+ // stripped. The queue-drain paths compare it against a session's
+ // cancel mark so a follow-up queued before a cancel is dropped while
+ // one queued after the cancel survives. 0 means untracked (an
+ // in-process enqueue with no accept reservation), which the drain
+ // paths treat as covered by any present mark, preserving the
+ // pre-sequence behavior.
+ acceptSeq uint64
}
type SessionAgent interface {
@@ -167,20 +176,31 @@ type sessionAgent struct {
// BeginAccepted increments it; only AcceptedRun.Close decrements
// it.
acceptedRuns *csync.Map[string, int]
- // pendingCancels records sessions whose dispatched-but-not-yet-
- // running call should observe a cancellation request. It is only
- // set by Cancel when acceptedRuns > 0, so an idle Escape never
- // poisons the next prompt.
- pendingCancels *csync.Map[string, struct{}]
+ // cancelMark records, per session, a high-water accept sequence: an
+ // accepted handle is canceled by it iff the handle's sequence is at
+ // or below the mark. Cancel raises the mark to the latest sequence
+ // assigned at cancel time, so a single Cancel covers every prompt
+ // accepted-but-not-yet-active then, while a prompt accepted later
+ // (higher sequence) is never poisoned. Absent or 0 means no pending
+ // cancel. It is only raised by Cancel when acceptedRuns > 0, so an
+ // idle Escape never records a mark.
+ cancelMark *csync.Map[string, uint64]
// dispatchMuCreate guards lazy creation of per-session entries in
// dispatchMu so two goroutines can't race to lock different mutex
// instances for the same session.
dispatchMuCreate sync.Mutex
- // acceptedMu serializes increments/decrements of acceptedRuns. It
+ // acceptedMu serializes increments/decrements of acceptedRuns and
+ // the assignment of accept sequence numbers from acceptSeqGen. It
// is separate from dispatchMu so AcceptedRun.Close (which may run
// while Run holds dispatchMu for the same session) does not
// deadlock by re-entering the dispatch lock.
acceptedMu sync.Mutex
+ // acceptSeqGen is the monotonic source of accept sequence numbers.
+ // Each BeginAccepted increments it under acceptedMu and stamps the
+ // returned handle, so sequences strictly increase in accept order
+ // across the agent. Cancel uses its current value as the per-session
+ // high-water mark.
+ acceptSeqGen uint64
}
type SessionAgentOptions struct {
@@ -218,7 +238,7 @@ func NewSessionAgent(
activeRequests: csync.NewMap[string, context.CancelFunc](),
dispatchMu: csync.NewMap[string, *sync.Mutex](),
acceptedRuns: csync.NewMap[string, int](),
- pendingCancels: csync.NewMap[string, struct{}](),
+ cancelMark: csync.NewMap[string, uint64](),
}
}
@@ -231,7 +251,12 @@ func NewSessionAgent(
type AcceptedRun struct {
agent *sessionAgent
sessionID string
- done atomic.Bool
+ // seq is the monotonic accept sequence stamped by BeginAccepted. A
+ // cancel covers this handle iff seq is at or below the session's
+ // cancel mark, so a handle accepted after a cancel (higher seq) is
+ // never poisoned by it.
+ seq uint64
+ done atomic.Bool
}
// Close decrements the accept counter for this reservation. It is safe
@@ -263,19 +288,30 @@ func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
defer a.acceptedMu.Unlock()
count, _ := a.acceptedRuns.Get(sessionID)
a.acceptedRuns.Set(sessionID, count+1)
- return &AcceptedRun{agent: a, sessionID: sessionID}
+ a.acceptSeqGen++
+ return &AcceptedRun{agent: a, sessionID: sessionID, seq: a.acceptSeqGen}
}
// endAccepted decrements the accept counter for sessionID. It is only
// called via AcceptedRun.Close. It uses a dedicated lock (not the
// per-session dispatch mutex) so it can run while Run holds dispatchMu
// for the same session without deadlocking.
+//
+// When the count reaches zero the session's cancel mark is dropped: no
+// accepted handle remains for it to cover, and any handle accepted later
+// gets a strictly higher sequence that the mark would not match anyway.
+// Handles canceled on entry never reach RunComplete, so this is the only
+// place that clears the mark for an all-canceled batch. Sibling handles
+// covered by the same mark are serialized on the per-session dispatch
+// mutex and read the mark before they Close, so this never clears it out
+// from under a covered handle still waiting to enter Run.
func (a *sessionAgent) endAccepted(sessionID string) {
a.acceptedMu.Lock()
defer a.acceptedMu.Unlock()
count, ok := a.acceptedRuns.Get(sessionID)
if !ok || count <= 1 {
a.acceptedRuns.Del(sessionID)
+ a.cancelMark.Del(sessionID)
return
}
a.acceptedRuns.Set(sessionID, count-1)
@@ -311,20 +347,44 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) {
existing = []SessionAgentCall{}
}
queued := call
+ if call.Accepted != nil {
+ // Preserve the accept sequence after the handle is stripped so
+ // the queue-drain paths can tell a follow-up queued before a
+ // cancel (covered by the mark) from one queued after it.
+ queued.acceptSeq = call.Accepted.seq
+ }
queued.OnComplete = nil
queued.Accepted = nil
existing = append(existing, queued)
a.messageQueue.Set(call.SessionID, existing)
}
-// clearPendingCancel removes any pending-cancel record for sessionID. It
-// takes the per-session dispatch lock so it is ordered against Cancel and
-// the dispatch handoff.
+// clearPendingCancel removes any pending-cancel mark for sessionID. It
+// takes the per-session dispatch lock so it is ordered against Cancel
+// and the dispatch handoff.
func (a *sessionAgent) clearPendingCancel(sessionID string) {
mu := a.sessionMu(sessionID)
mu.Lock()
defer mu.Unlock()
- a.pendingCancels.Del(sessionID)
+ a.cancelMark.Del(sessionID)
+}
+
+// canceledBySeq reports whether an accepted handle or queued call with
+// the given accept sequence is covered by a pending cancel for the
+// session. Callers must hold the session's dispatch mutex. A tracked
+// sequence (seq > 0) is covered only when it is at or below the cancel
+// high-water mark, so a prompt accepted after the cancel (higher seq) is
+// never poisoned. An untracked sequence (seq == 0, an in-process enqueue
+// with no accept reservation) is covered whenever any mark is present,
+// preserving the pre-sequence behavior. The mark is not consumed: it
+// stays so every sibling handle it covers observes the same cancel, and
+// a later handle (higher seq) ignores it regardless.
+func (a *sessionAgent) canceledBySeq(sessionID string, seq uint64) bool {
+ mark, ok := a.cancelMark.Get(sessionID)
+ if !ok || mark == 0 {
+ return false
+ }
+ return seq == 0 || seq <= mark
}
// persistCanceledTurn writes the user/assistant records for a turn that
@@ -356,6 +416,26 @@ func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgen
return a.messages.Update(writeCtx, assistant)
}
+// publishRunComplete emits the authoritative terminal event for a turn.
+// It honors the per-call OnComplete hook when set (so the coordinator can
+// coalesce retries) and otherwise falls back to the RunComplete broker.
+// ctx is used only for the bounded-blocking must-deliver publish; the
+// terminal payload is supplied by the caller. This is the single emit path
+// shared by the streaming defer and the cancel-on-entry early return so a
+// caller waiting on RunComplete (e.g. `crush run` with a RunID) always
+// observes exactly one terminal event regardless of which Run branch ends
+// the turn.
+func (a *sessionAgent) publishRunComplete(ctx context.Context, call SessionAgentCall, complete notify.RunComplete) {
+ if call.OnComplete != nil {
+ call.OnComplete(complete)
+ return
+ }
+ if a.runComplete == nil {
+ return
+ }
+ a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete)
+}
+
// ValidateCall performs the cheap structural validation that
// sessionAgent.Run requires before a call can be dispatched: a call must
// carry either a non-empty prompt or a text attachment, and it must name a
@@ -394,22 +474,39 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// Serialize the accepted -> (cancel-on-entry | queued |
// active) transition against a concurrent Cancel. Cancel takes
// the same per-session lock, so every cancel observes at least
- // one of: pendingCancels, an activeRequests entry, or a
+ // one of: a cancel mark, an activeRequests entry, or a
// messageQueue entry it then clears.
mu := a.sessionMu(call.SessionID)
mu.Lock()
- if _, pending := a.pendingCancels.Get(call.SessionID); pending {
+ if a.canceledBySeq(call.SessionID, call.Accepted.seq) {
// Cancel-on-entry: a cancel arrived while this run was
- // dispatched but not yet active. Consume the pending
- // cancel, release the accept reservation, drop the lock,
- // and persist a canceled turn without entering Stream.
- a.pendingCancels.Del(call.SessionID)
+ // dispatched but not yet active, and this handle's accept
+ // sequence is at or below the session's cancel mark. The
+ // mark is left in place so sibling handles it also covers
+ // observe the same cancel; release the accept reservation,
+ // drop the lock, and persist a canceled turn without
+ // entering Stream.
+ //
+ // This path returns before the streaming defer that
+ // publishes RunComplete is installed, so emit the terminal
+ // event explicitly. Without it, a caller waiting on
+ // RunComplete for this RunID (e.g. `crush run`, which
+ // ignores message events and blocks on RunComplete) would
+ // hang on an immediately-canceled accepted run.
call.Accepted.Close()
mu.Unlock()
+ complete := notify.RunComplete{
+ SessionID: call.SessionID,
+ RunID: call.RunID,
+ Cancelled: true,
+ }
if err := a.persistCanceledTurn(ctx, call, false); err != nil {
+ complete.Error = err.Error()
+ a.publishRunComplete(ctx, call, complete)
return nil, err
}
+ a.publishRunComplete(ctx, call, complete)
return nil, nil
}
@@ -579,14 +676,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// the authoritative terminal event so a momentarily-full
// subscriber channel can't silently drop it and hang
// non-interactive clients waiting on RunComplete.
- if call.OnComplete != nil {
- call.OnComplete(complete)
- return
- }
- if a.runComplete == nil {
- return
- }
- a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete)
+ a.publishRunComplete(ctx, call, complete)
}()
history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)
@@ -621,24 +711,26 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// Use latest tools (updated by SetTools when MCP tools change).
prepared.Tools = a.tools.Copy()
- // Drain queued follow-up prompts, but skip them if a cancel
- // was recorded for the session while they sat in the queue:
- // a cancel that arrived after the queue insertion must not
- // let the queued prompt run as part of this step.
+ // Drain queued follow-up prompts, but skip any covered by a
+ // cancel recorded while they sat in the queue: a cancel that
+ // arrived after a prompt was queued must not let it run as
+ // part of this step. Coverage is per-call by accept sequence
+ // so a follow-up queued after the cancel (higher seq) is
+ // still folded in.
dispatchLock := a.sessionMu(call.SessionID)
dispatchLock.Lock()
- _, canceled := a.pendingCancels.Get(call.SessionID)
queuedCalls, _ := a.messageQueue.Get(call.SessionID)
a.messageQueue.Del(call.SessionID)
dispatchLock.Unlock()
- if !canceled {
- for _, queued := range queuedCalls {
- userMessage, createErr := a.createUserMessage(callContext, queued)
- if createErr != nil {
- return callContext, prepared, createErr
- }
- prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
+ for _, queued := range queuedCalls {
+ if a.canceledBySeq(call.SessionID, queued.acceptSeq) {
+ continue
}
+ userMessage, createErr := a.createUserMessage(callContext, queued)
+ if createErr != nil {
+ return callContext, prepared, createErr
+ }
+ prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
@@ -1008,20 +1100,37 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// closing the dequeue -> re-register window.
mu := a.sessionMu(call.SessionID)
mu.Lock()
- if _, pending := a.pendingCancels.Get(call.SessionID); pending {
+ queuedMessages, _ := a.messageQueue.Get(call.SessionID)
+ if mark, ok := a.cancelMark.Get(call.SessionID); ok && mark > 0 && len(queuedMessages) > 0 {
// A cancel was recorded for this session (e.g. it arrived while
- // this run was active and a follow-up had been accepted). Drop
- // the queue instead of running it and consume the marker.
- a.pendingCancels.Del(call.SessionID)
+ // this run was active and follow-ups had been queued). Drop the
+ // queued prompts it covers (accept sequence at or below the
+ // mark, or untracked); keep any queued after the cancel (higher
+ // sequence) so they still run.
+ var kept []SessionAgentCall
+ for _, q := range queuedMessages {
+ if q.acceptSeq == 0 || q.acceptSeq <= mark {
+ continue
+ }
+ kept = append(kept, q)
+ }
+ queuedMessages = kept
+ a.messageQueue.Set(call.SessionID, kept)
+ }
+ if len(queuedMessages) == 0 {
+ // No queued work. Clear the cancel mark only when no accepted
+ // run remains in flight that it might still cover; otherwise a
+ // sibling prompt (sequence at or below the mark) waiting to
+ // enter Run would lose its cancellation. When accepted runs are
+ // gone, this also clears a stale mark so it can't catch a
+ // future run.
a.messageQueue.Del(call.SessionID)
- mu.Unlock()
- return result, err
- }
- queuedMessages, ok := a.messageQueue.Get(call.SessionID)
- if !ok || len(queuedMessages) == 0 {
- // No queued work. Clear any stale pending-cancel entry as a
- // safety net so it can't catch a future run (no-op when unset).
- a.pendingCancels.Del(call.SessionID)
+ a.acceptedMu.Lock()
+ inFlight, _ := a.acceptedRuns.Get(call.SessionID)
+ a.acceptedMu.Unlock()
+ if inFlight == 0 {
+ a.cancelMark.Del(call.SessionID)
+ }
mu.Unlock()
return result, err
}
@@ -1619,17 +1728,27 @@ func (a *sessionAgent) Cancel(sessionID string) {
}
// Record a pending cancel only when a dispatched-but-not-yet-active
- // run exists. This catches a run still in the goroutine scheduler or
+ // run exists. This catches runs still in the goroutine scheduler or
// about to enter Run's busy-queue branch, while leaving an idle
// session untouched. Active and accepted are not mutually exclusive:
// when a run is active and a follow-up has been accepted, both the
// cancel above and this pending record fire.
+ //
+ // Raise the session's cancel mark to the latest accept sequence
+ // assigned so far. Every prompt currently accepted-but-not-yet-
+ // active has a sequence at or below that value, so one cancel covers
+ // all of them; a prompt accepted after this cancel gets a strictly
+ // higher sequence and is never poisoned. Using max keeps repeated
+ // cancels idempotent while the same prompts are in flight and lets a
+ // later cancel extend coverage to prompts accepted since.
a.acceptedMu.Lock()
count, ok := a.acceptedRuns.Get(sessionID)
+ mark := a.acceptSeqGen
a.acceptedMu.Unlock()
if ok && count > 0 {
- slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID)
- a.pendingCancels.Set(sessionID, struct{}{})
+ slog.Debug("Recording cancel mark for accepted runs", "session_id", sessionID, "count", count, "mark", mark)
+ existing, _ := a.cancelMark.Get(sessionID)
+ a.cancelMark.Set(sessionID, max(existing, mark))
}
if a.QueuedPrompts(sessionID) > 0 {
@@ -5,9 +5,12 @@ import (
"errors"
"sync/atomic"
"testing"
+ "time"
"charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent/notify"
"github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -140,7 +143,7 @@ func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
// A follow-up prompt sits queued for the session.
sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
// A cancel was recorded for the session while it sat in the queue.
- sa.pendingCancels.Set(sess.ID, struct{}{})
+ sa.cancelMark.Set(sess.ID, 1)
result, err := sa.Run(t.Context(), SessionAgentCall{
SessionID: sess.ID,
@@ -177,8 +180,8 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
sess, err := env.sessions.Create(t.Context(), "session")
require.NoError(t, err)
- // A stale pending cancel lingers (no queued work, no accepted run).
- sa.pendingCancels.Set(sess.ID, struct{}{})
+ // A stale cancel mark lingers (no queued work, no accepted run).
+ sa.cancelMark.Set(sess.ID, 1)
result, err := sa.Run(t.Context(), SessionAgentCall{
SessionID: sess.ID,
@@ -196,3 +199,249 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
assert.Equal(t, message.Assistant, msgs[1].Role)
assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
}
+
+// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired
+// to a RunComplete broker so tests can observe the terminal event a
+// RunID-bearing caller (e.g. `crush run`) blocks on.
+func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) {
+ t.Helper()
+ env := testEnv(t)
+ broker := pubsub.NewBroker[notify.RunComplete]()
+ t.Cleanup(broker.Shutdown)
+ sa := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ RunComplete: broker,
+ }).(*sessionAgent)
+ return sa, env, broker
+}
+
+// TestRun_CancelOnEntryPublishesRunComplete covers the first review
+// finding: the cancel-on-entry path returned before the streaming defer
+// that publishes RunComplete was installed. A caller that dispatches a
+// run with a RunID and blocks on RunComplete (ignoring message events,
+// like `crush run`) would hang on an immediately-canceled accepted run.
+// The cancel-on-entry path must publish a terminal RunComplete carrying
+// the originating RunID.
+func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) {
+ t.Parallel()
+ sa, env, broker := newCancelTestAgentWithRunComplete(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(t.Context())
+ defer cancel()
+ ch := broker.Subscribe(ctx)
+
+ accept := sa.BeginAccepted(sess.ID)
+ // A cancel arrives in the accepted-but-not-yet-active window.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ RunID: "run-cancel-on-entry",
+ Prompt: "hello",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.Nil(t, result)
+
+ select {
+ case got := <-ch:
+ assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID,
+ "RunComplete must echo the originating RunID")
+ assert.Equal(t, sess.ID, got.Payload.SessionID)
+ assert.True(t, got.Payload.Cancelled,
+ "cancel-on-entry RunComplete must be marked cancelled")
+ case <-time.After(2 * time.Second):
+ t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise")
+ }
+}
+
+// TestCancel_TwoAcceptedBothObserveCancellation covers the second review
+// finding: a single cancel with two accepted-not-yet-active prompts must
+// cancel both. The cancel raises the session's high-water mark to the
+// latest accept sequence, so every prompt accepted-but-not-yet-active at
+// cancel time is covered and both take the cancel-on-entry path.
+func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Two prompts are accepted-but-not-yet-active for the same session.
+ accept1 := sa.BeginAccepted(sess.ID)
+ accept2 := sa.BeginAccepted(sess.ID)
+ require.Equal(t, 2, sa.acceptedCount(sess.ID))
+
+ // A single cancel arrives before either becomes active.
+ sa.Cancel(sess.ID)
+ require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID),
+ "one cancel must mark every currently-accepted prompt as canceled")
+ require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq,
+ "the mark must cover the earlier accepted prompt too")
+
+ // Both prompts enter Run; each must take cancel-on-entry, not run.
+ r1, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "first",
+ Accepted: accept1,
+ })
+ require.NoError(t, err)
+ require.Nil(t, r1)
+
+ r2, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "second",
+ Accepted: accept2,
+ })
+ require.NoError(t, err)
+ require.Nil(t, r2)
+
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "both reserved units must be consumed")
+ require.Equal(t, 0, sa.acceptedCount(sess.ID),
+ "both accept reservations must be released")
+
+ // Each canceled-on-entry turn writes a user + canceled assistant
+ // message, and neither prompt was enqueued to run normally.
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
+ "neither accepted prompt may be enqueued to run normally")
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages")
+ var canceled int
+ for _, m := range msgs {
+ if m.Role == message.Assistant {
+ assert.Equal(t, message.FinishReasonCanceled, m.FinishReason())
+ canceled++
+ }
+ }
+ require.Equal(t, 2, canceled, "both turns must finish canceled")
+}
+
+// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel
+// no-op guarantee end-to-end: an Escape on an idle session must not
+// record a pending cancel that leaks into the next accepted prompt, which
+// must run normally to completion.
+func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Idle Escape: no accepted run, no active request.
+ sa.Cancel(sess.ID)
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "idle cancel must not record a pending cancel")
+
+ // The next accepted prompt must run normally, not cancel on entry.
+ accept := sa.BeginAccepted(sess.ID)
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "next",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result, "next prompt must run normally after an idle cancel")
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(),
+ "the prompt must finish normally, not canceled")
+}
+
+// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for
+// the reviewer's finding: a counted session-level pending cancel let a
+// prompt accepted after the cancel enter Run first and consume a unit
+// reserved for the earlier prompts. With a sequence high-water mark, a
+// single cancel covers exactly the prompts accepted-but-not-yet-active at
+// cancel time (A and B); a prompt accepted afterwards (C) gets a higher
+// sequence and must run normally without consuming A or B's cancellation.
+// C is run first to prove it neither cancels nor drains the mark, then A
+// and B are run and must both cancel on entry.
+func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // A and B are accepted-but-not-yet-active.
+ acceptA := sa.BeginAccepted(sess.ID)
+ acceptB := sa.BeginAccepted(sess.ID)
+
+ // One cancel arrives covering both A and B.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+ require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID),
+ "the mark must cover every prompt accepted before the cancel")
+
+ // C is accepted AFTER the cancel; its sequence is above the mark.
+ acceptC := sa.BeginAccepted(sess.ID)
+ require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID),
+ "a prompt accepted after the cancel must not be covered by the mark")
+
+ // Run C first. It must run normally to completion and must not
+ // consume or clear the cancellation reserved for A and B.
+ rc, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "C",
+ Accepted: acceptC,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, rc, "C was accepted after the cancel and must run normally")
+ require.True(t, sa.hasPendingCancel(sess.ID),
+ "running C must not drain the cancellation reserved for A and B")
+
+ // Now A and B run. Both were accepted before the cancel and must
+ // take the cancel-on-entry path.
+ ra, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "A",
+ Accepted: acceptA,
+ })
+ require.NoError(t, err)
+ require.Nil(t, ra, "A must cancel on entry, not run")
+
+ rb, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "B",
+ Accepted: acceptB,
+ })
+ require.NoError(t, err)
+ require.Nil(t, rb, "B must cancel on entry, not run")
+
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "the mark clears once all covered handles are resolved")
+ require.Equal(t, 0, sa.acceptedCount(sess.ID))
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
+ "neither A nor B may be enqueued to run normally")
+
+ // C produced a normal turn; A and B each produced a canceled turn.
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant")
+
+ var normal, canceled int
+ for _, m := range msgs {
+ if m.Role != message.Assistant {
+ continue
+ }
+ switch m.FinishReason() {
+ case message.FinishReasonEndTurn:
+ normal++
+ case message.FinishReasonCanceled:
+ canceled++
+ }
+ }
+ require.Equal(t, 1, normal, "only C finished normally")
+ require.Equal(t, 2, canceled, "both A and B finished canceled")
+}