diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 393587a111ad806dc26bbf8a80f52dba49ce0397..6728864527b2e4a7f970262f25c3953168a0b5bc 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -359,6 +359,30 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { a.messageQueue.Set(call.SessionID, existing) } +// drainUncanceledQueue copies and clears the session's message queue and +// returns the queued calls not covered by a pending cancel. The queue +// copy/delete and the cancel-mark check happen under the per-session +// dispatch mutex so the filtering is atomic against a concurrent Cancel: +// canceledBySeq requires the caller to hold that mutex, and evaluating it +// here (rather than after unlocking) prevents a cancel recorded between +// the drain and the check from being observed inconsistently. The +// returned survivors are processed by the caller without the lock held. +func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall { + dispatchLock := a.sessionMu(sessionID) + dispatchLock.Lock() + defer dispatchLock.Unlock() + queuedCalls, _ := a.messageQueue.Get(sessionID) + a.messageQueue.Del(sessionID) + survivors := queuedCalls[:0] + for _, queued := range queuedCalls { + if a.canceledBySeq(sessionID, queued.acceptSeq) { + continue + } + survivors = append(survivors, queued) + } + return survivors +} + // clearPendingCancel removes any pending-cancel mark for sessionID. It // takes the per-session dispatch lock so it is ordered against Cancel // and the dispatch handoff. @@ -717,15 +741,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // part of this step. Coverage is per-call by accept sequence // so a follow-up queued after the cancel (higher seq) is // still folded in. - dispatchLock := a.sessionMu(call.SessionID) - dispatchLock.Lock() - queuedCalls, _ := a.messageQueue.Get(call.SessionID) - a.messageQueue.Del(call.SessionID) - dispatchLock.Unlock() - for _, queued := range queuedCalls { - if a.canceledBySeq(call.SessionID, queued.acceptSeq) { - continue - } + survivors := a.drainUncanceledQueue(call.SessionID) + for _, queued := range survivors { userMessage, createErr := a.createUserMessage(callContext, queued) if createErr != nil { return callContext, prepared, createErr diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go index 74f9232a0946b24d38f05873fa39066dcae40c27..50c5fdf9355c11123d91d26b7ec2e90b219df9d4 100644 --- a/internal/agent/run_complete_test.go +++ b/internal/agent/run_complete_test.go @@ -58,6 +58,62 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { "RunComplete still correlates with the originating SendMessage") } +// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the +// queue drain evaluates the per-session cancel mark while holding the +// dispatch mutex (canceledBySeq's documented precondition). Queued calls +// at or below the cancel high-water mark are dropped, calls queued after +// the cancel (higher seq) survive, untracked enqueues (seq == 0) are +// dropped whenever any mark is present, and the queue is cleared. +func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-session" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "below", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2}, + {SessionID: sessionID, Prompt: "after", acceptSeq: 3}, + {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0}, + }) + // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered. + a.cancelMark.Set(sessionID, 2) + + survivors := a.drainUncanceledQueue(sessionID) + + require.Len(t, survivors, 1, + "only the follow-up queued after the cancel (seq > mark) must survive") + require.Equal(t, "after", survivors[0].Prompt) + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "drain must clear the session message queue") +} + +// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel +// mark recorded, every queued call survives the drain. +func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-nomark" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "a", acceptSeq: 0}, + {SessionID: sessionID, Prompt: "b", acceptSeq: 5}, + }) + + survivors := a.drainUncanceledQueue(sessionID) + require.Len(t, survivors, 2, "no cancel mark means all queued calls survive") +} + // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the // pubsub.Publisher interface change end-to-end: a Broker is the only // concrete Publisher implementation and must satisfy both Publish and