fix(server): close a queued-prompt cancel race

Christian Rocha and Charm Crush created

Check queued prompt cancellation while holding the same session handoff
lock used for accepted and active prompts. Queue draining now follows
the same ordering guarantees as prompt dispatch.

Co-Authored-By: Charm Crush <crush@charm.land>

Change summary

internal/agent/agent.go             | 35 ++++++++++++++----
internal/agent/run_complete_test.go | 56 +++++++++++++++++++++++++++++++
2 files changed, 82 insertions(+), 9 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -359,6 +359,30 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) {
 	a.messageQueue.Set(call.SessionID, existing)
 }
 
+// drainUncanceledQueue copies and clears the session's message queue and
+// returns the queued calls not covered by a pending cancel. The queue
+// copy/delete and the cancel-mark check happen under the per-session
+// dispatch mutex so the filtering is atomic against a concurrent Cancel:
+// canceledBySeq requires the caller to hold that mutex, and evaluating it
+// here (rather than after unlocking) prevents a cancel recorded between
+// the drain and the check from being observed inconsistently. The
+// returned survivors are processed by the caller without the lock held.
+func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall {
+	dispatchLock := a.sessionMu(sessionID)
+	dispatchLock.Lock()
+	defer dispatchLock.Unlock()
+	queuedCalls, _ := a.messageQueue.Get(sessionID)
+	a.messageQueue.Del(sessionID)
+	survivors := queuedCalls[:0]
+	for _, queued := range queuedCalls {
+		if a.canceledBySeq(sessionID, queued.acceptSeq) {
+			continue
+		}
+		survivors = append(survivors, queued)
+	}
+	return survivors
+}
+
 // clearPendingCancel removes any pending-cancel mark for sessionID. It
 // takes the per-session dispatch lock so it is ordered against Cancel
 // and the dispatch handoff.
@@ -717,15 +741,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
 			// part of this step. Coverage is per-call by accept sequence
 			// so a follow-up queued after the cancel (higher seq) is
 			// still folded in.
-			dispatchLock := a.sessionMu(call.SessionID)
-			dispatchLock.Lock()
-			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
-			a.messageQueue.Del(call.SessionID)
-			dispatchLock.Unlock()
-			for _, queued := range queuedCalls {
-				if a.canceledBySeq(call.SessionID, queued.acceptSeq) {
-					continue
-				}
+			survivors := a.drainUncanceledQueue(call.SessionID)
+			for _, queued := range survivors {
 				userMessage, createErr := a.createUserMessage(callContext, queued)
 				if createErr != nil {
 					return callContext, prepared, createErr

internal/agent/run_complete_test.go 🔗

@@ -58,6 +58,62 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) {
 			"RunComplete still correlates with the originating SendMessage")
 }
 
+// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the
+// queue drain evaluates the per-session cancel mark while holding the
+// dispatch mutex (canceledBySeq's documented precondition). Queued calls
+// at or below the cancel high-water mark are dropped, calls queued after
+// the cancel (higher seq) survive, untracked enqueues (seq == 0) are
+// dropped whenever any mark is present, and the queue is cleared.
+func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) {
+	t.Parallel()
+
+	env := testEnv(t)
+	a := NewSessionAgent(SessionAgentOptions{
+		Sessions: env.sessions,
+		Messages: env.messages,
+	}).(*sessionAgent)
+
+	const sessionID = "drain-session"
+	a.messageQueue.Set(sessionID, []SessionAgentCall{
+		{SessionID: sessionID, Prompt: "below", acceptSeq: 1},
+		{SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2},
+		{SessionID: sessionID, Prompt: "after", acceptSeq: 3},
+		{SessionID: sessionID, Prompt: "untracked", acceptSeq: 0},
+	})
+	// Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered.
+	a.cancelMark.Set(sessionID, 2)
+
+	survivors := a.drainUncanceledQueue(sessionID)
+
+	require.Len(t, survivors, 1,
+		"only the follow-up queued after the cancel (seq > mark) must survive")
+	require.Equal(t, "after", survivors[0].Prompt)
+
+	_, ok := a.messageQueue.Get(sessionID)
+	require.False(t, ok, "drain must clear the session message queue")
+}
+
+// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel
+// mark recorded, every queued call survives the drain.
+func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) {
+	t.Parallel()
+
+	env := testEnv(t)
+	a := NewSessionAgent(SessionAgentOptions{
+		Sessions: env.sessions,
+		Messages: env.messages,
+	}).(*sessionAgent)
+
+	const sessionID = "drain-nomark"
+	a.messageQueue.Set(sessionID, []SessionAgentCall{
+		{SessionID: sessionID, Prompt: "a", acceptSeq: 0},
+		{SessionID: sessionID, Prompt: "b", acceptSeq: 5},
+	})
+
+	survivors := a.drainUncanceledQueue(sessionID)
+	require.Len(t, survivors, 2, "no cancel mark means all queued calls survive")
+}
+
 // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the
 // pubsub.Publisher interface change end-to-end: a Broker is the only
 // concrete Publisher implementation and must satisfy both Publish and