@@ -359,6 +359,30 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) {
a.messageQueue.Set(call.SessionID, existing)
}
+// drainUncanceledQueue copies and clears the session's message queue and
+// returns the queued calls not covered by a pending cancel. The queue
+// copy/delete and the cancel-mark check happen under the per-session
+// dispatch mutex so the filtering is atomic against a concurrent Cancel:
+// canceledBySeq requires the caller to hold that mutex, and evaluating it
+// here (rather than after unlocking) prevents a cancel recorded between
+// the drain and the check from being observed inconsistently. The
+// returned survivors are processed by the caller without the lock held.
+func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall {
+ dispatchLock := a.sessionMu(sessionID)
+ dispatchLock.Lock()
+ defer dispatchLock.Unlock()
+ queuedCalls, _ := a.messageQueue.Get(sessionID)
+ a.messageQueue.Del(sessionID)
+ survivors := queuedCalls[:0]
+ for _, queued := range queuedCalls {
+ if a.canceledBySeq(sessionID, queued.acceptSeq) {
+ continue
+ }
+ survivors = append(survivors, queued)
+ }
+ return survivors
+}
+
// clearPendingCancel removes any pending-cancel mark for sessionID. It
// takes the per-session dispatch lock so it is ordered against Cancel
// and the dispatch handoff.
@@ -717,15 +741,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// part of this step. Coverage is per-call by accept sequence
// so a follow-up queued after the cancel (higher seq) is
// still folded in.
- dispatchLock := a.sessionMu(call.SessionID)
- dispatchLock.Lock()
- queuedCalls, _ := a.messageQueue.Get(call.SessionID)
- a.messageQueue.Del(call.SessionID)
- dispatchLock.Unlock()
- for _, queued := range queuedCalls {
- if a.canceledBySeq(call.SessionID, queued.acceptSeq) {
- continue
- }
+ survivors := a.drainUncanceledQueue(call.SessionID)
+ for _, queued := range survivors {
userMessage, createErr := a.createUserMessage(callContext, queued)
if createErr != nil {
return callContext, prepared, createErr
@@ -58,6 +58,62 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) {
"RunComplete still correlates with the originating SendMessage")
}
+// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the
+// queue drain evaluates the per-session cancel mark while holding the
+// dispatch mutex (canceledBySeq's documented precondition). Queued calls
+// at or below the cancel high-water mark are dropped, calls queued after
+// the cancel (higher seq) survive, untracked enqueues (seq == 0) are
+// dropped whenever any mark is present, and the queue is cleared.
+func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-session"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "below", acceptSeq: 1},
+ {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2},
+ {SessionID: sessionID, Prompt: "after", acceptSeq: 3},
+ {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0},
+ })
+ // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered.
+ a.cancelMark.Set(sessionID, 2)
+
+ survivors := a.drainUncanceledQueue(sessionID)
+
+ require.Len(t, survivors, 1,
+ "only the follow-up queued after the cancel (seq > mark) must survive")
+ require.Equal(t, "after", survivors[0].Prompt)
+
+ _, ok := a.messageQueue.Get(sessionID)
+ require.False(t, ok, "drain must clear the session message queue")
+}
+
+// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel
+// mark recorded, every queued call survives the drain.
+func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-nomark"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "a", acceptSeq: 0},
+ {SessionID: sessionID, Prompt: "b", acceptSeq: 5},
+ })
+
+ survivors := a.drainUncanceledQueue(sessionID)
+ require.Len(t, survivors, 2, "no cancel mark means all queued calls survive")
+}
+
// TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the
// pubsub.Publisher interface change end-to-end: a Broker is the only
// concrete Publisher implementation and must satisfy both Publish and