From d7a814c540c258f06cb798c29beff676ed4b130b Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 20:11:26 -0400 Subject: [PATCH] fix(server): close a queued-prompt cancel race Check queued prompt cancellation while holding the same session handoff lock used for accepted and active prompts. Queue draining now follows the same ordering guarantees as prompt dispatch. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 35 +++++++++++++----- internal/agent/run_complete_test.go | 56 +++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 9 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 393587a111ad806dc26bbf8a80f52dba49ce0397..6728864527b2e4a7f970262f25c3953168a0b5bc 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -359,6 +359,30 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { a.messageQueue.Set(call.SessionID, existing) } +// drainUncanceledQueue copies and clears the session's message queue and +// returns the queued calls not covered by a pending cancel. The queue +// copy/delete and the cancel-mark check happen under the per-session +// dispatch mutex so the filtering is atomic against a concurrent Cancel: +// canceledBySeq requires the caller to hold that mutex, and evaluating it +// here (rather than after unlocking) prevents a cancel recorded between +// the drain and the check from being observed inconsistently. The +// returned survivors are processed by the caller without the lock held. +func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall { + dispatchLock := a.sessionMu(sessionID) + dispatchLock.Lock() + defer dispatchLock.Unlock() + queuedCalls, _ := a.messageQueue.Get(sessionID) + a.messageQueue.Del(sessionID) + survivors := queuedCalls[:0] + for _, queued := range queuedCalls { + if a.canceledBySeq(sessionID, queued.acceptSeq) { + continue + } + survivors = append(survivors, queued) + } + return survivors +} + // clearPendingCancel removes any pending-cancel mark for sessionID. It // takes the per-session dispatch lock so it is ordered against Cancel // and the dispatch handoff. @@ -717,15 +741,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // part of this step. Coverage is per-call by accept sequence // so a follow-up queued after the cancel (higher seq) is // still folded in. - dispatchLock := a.sessionMu(call.SessionID) - dispatchLock.Lock() - queuedCalls, _ := a.messageQueue.Get(call.SessionID) - a.messageQueue.Del(call.SessionID) - dispatchLock.Unlock() - for _, queued := range queuedCalls { - if a.canceledBySeq(call.SessionID, queued.acceptSeq) { - continue - } + survivors := a.drainUncanceledQueue(call.SessionID) + for _, queued := range survivors { userMessage, createErr := a.createUserMessage(callContext, queued) if createErr != nil { return callContext, prepared, createErr diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go index 74f9232a0946b24d38f05873fa39066dcae40c27..50c5fdf9355c11123d91d26b7ec2e90b219df9d4 100644 --- a/internal/agent/run_complete_test.go +++ b/internal/agent/run_complete_test.go @@ -58,6 +58,62 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { "RunComplete still correlates with the originating SendMessage") } +// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the +// queue drain evaluates the per-session cancel mark while holding the +// dispatch mutex (canceledBySeq's documented precondition). Queued calls +// at or below the cancel high-water mark are dropped, calls queued after +// the cancel (higher seq) survive, untracked enqueues (seq == 0) are +// dropped whenever any mark is present, and the queue is cleared. +func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-session" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "below", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2}, + {SessionID: sessionID, Prompt: "after", acceptSeq: 3}, + {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0}, + }) + // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered. + a.cancelMark.Set(sessionID, 2) + + survivors := a.drainUncanceledQueue(sessionID) + + require.Len(t, survivors, 1, + "only the follow-up queued after the cancel (seq > mark) must survive") + require.Equal(t, "after", survivors[0].Prompt) + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "drain must clear the session message queue") +} + +// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel +// mark recorded, every queued call survives the drain. +func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-nomark" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "a", acceptSeq: 0}, + {SessionID: sessionID, Prompt: "b", acceptSeq: 5}, + }) + + survivors := a.drainUncanceledQueue(sessionID) + require.Len(t, survivors, 2, "no cancel mark means all queued calls survive") +} + // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the // pubsub.Publisher interface change end-to-end: a Broker is the only // concrete Publisher implementation and must satisfy both Publish and