From 4347015ab7669b1376e6cb814c10f89aef3481ad Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 22:32:39 -0400 Subject: [PATCH] chore(server): honor cancels immediately after prompt acceptance Track prompts that have been accepted but have not started running yet so a cancel issued right after acceptance applies to that prompt. Idle-session cancels remain a no-op, preventing one client's cancel from poisoning the next prompt. Co-Authored-By: Charm Crush --- internal/agent/accepted_run_test.go | 226 +++++++++++++++ internal/agent/agent.go | 360 ++++++++++++++++++++++-- internal/agent/coordinator.go | 33 +++ internal/agent/coordinator_test.go | 4 + internal/agent/dispatch_cancel_test.go | 198 +++++++++++++ internal/server/agent_cancel_test.go | 7 + internal/server/sessions_isbusy_test.go | 8 + 7 files changed, 812 insertions(+), 24 deletions(-) create mode 100644 internal/agent/accepted_run_test.go create mode 100644 internal/agent/dispatch_cancel_test.go diff --git a/internal/agent/accepted_run_test.go b/internal/agent/accepted_run_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d62422a9f02bec68a8da1a08c6e6d6b52e7d7699 --- /dev/null +++ b/internal/agent/accepted_run_test.go @@ -0,0 +1,226 @@ +package agent + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newCancelTestAgent builds a DB-backed sessionAgent with no model. The +// tests here exercise the dispatch/cancel/persist paths, none of which +// reach agent.Stream, so a model is unnecessary. +func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + return sa, env +} + +func (a *sessionAgent) acceptedCount(sessionID string) int { + c, _ := a.acceptedRuns.Get(sessionID) + return c +} + +func (a *sessionAgent) hasPendingCancel(sessionID string) bool { + _, ok := a.pendingCancels.Get(sessionID) + return ok +} + +func TestAcceptedRun_CloseIsIdempotent(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + require.Equal(t, "sid", accept.SessionID()) + require.Equal(t, 1, sa.acceptedCount("sid")) + + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) + + // Repeated Close must not underflow the counter. + accept.Close() + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_MultipleReservations(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + a1 := sa.BeginAccepted("sid") + a2 := sa.BeginAccepted("sid") + require.Equal(t, 2, sa.acceptedCount("sid")) + + a1.Close() + require.Equal(t, 1, sa.acceptedCount("sid")) + + a2.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_NilSafe(t *testing.T) { + t.Parallel() + var accept *AcceptedRun + require.Equal(t, "", accept.SessionID()) + // Must not panic. + accept.Close() +} + +func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + // No accepted run, no active request: a true no-op. + sa.Cancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + // A second cancel while a pending cancel is already recorded must + // remain a single pending cancel; one Run consumes exactly one. + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The pending cancel was consumed and the accept released. + require.False(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate PrepareStep having already created the user message. + _, err = sa.createUserMessage(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }) + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, true) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate workspace shutdown having already canceled the run + // context. WithoutCancel must let the writes through. + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + err = sa.persistCanceledTurn(ctx, SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) +} + +func TestClearPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + sa.clearPendingCancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index bc3df59e7626943ca31cbc293c0b76a814c05fae..97bd7a21af4c28f30e46fcaf23c23348467e90ac 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -21,6 +21,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "charm.land/catwalk/pkg/catwalk" @@ -103,10 +104,19 @@ type SessionAgentCall struct { // recursion drains, so falling back to the default broker // publish keeps the event visible to subscribers. OnComplete func(notify.RunComplete) + // Accepted, when non-nil, is the accept reservation taken by + // BeginAccepted before the call was dispatched onto a goroutine + // (the client/server fire-and-forget path). Run consumes it under + // dispatchMu[SessionID] once the accepted -> (cancel-on-entry | + // queued | active) transition has been chosen. When nil + // (in-process / local callers like AppWorkspace), behavior is + // unchanged and no accept tracking applies. + Accepted *AcceptedRun } type SessionAgent interface { Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun SetModels(large Model, small Model) SetTools(tools []fantasy.AgentTool) SetSystemPrompt(systemPrompt string) @@ -145,6 +155,32 @@ type sessionAgent struct { messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] + + // dispatchMu holds a per-session mutex that serializes the + // accepted -> (cancel-on-entry | queued | active) transition in + // Run against a concurrent Cancel. The lock is held only during + // the brief handoff (no DB or LLM I/O under the lock). + dispatchMu *csync.Map[string, *sync.Mutex] + // acceptedRuns counts dispatched-but-not-yet-active runs per + // session. A counter > 0 means a dispatched prompt is in flight + // and has not yet completed the dispatch handoff in Run. Only + // BeginAccepted increments it; only AcceptedRun.Close decrements + // it. + acceptedRuns *csync.Map[string, int] + // pendingCancels records sessions whose dispatched-but-not-yet- + // running call should observe a cancellation request. It is only + // set by Cancel when acceptedRuns > 0, so an idle Escape never + // poisons the next prompt. + pendingCancels *csync.Map[string, struct{}] + // dispatchMuCreate guards lazy creation of per-session entries in + // dispatchMu so two goroutines can't race to lock different mutex + // instances for the same session. + dispatchMuCreate sync.Mutex + // acceptedMu serializes increments/decrements of acceptedRuns. It + // is separate from dispatchMu so AcceptedRun.Close (which may run + // while Run holds dispatchMu for the same session) does not + // deadlock by re-entering the dispatch lock. + acceptedMu sync.Mutex } type SessionAgentOptions struct { @@ -180,7 +216,144 @@ func NewSessionAgent( runComplete: opts.RunComplete, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), + dispatchMu: csync.NewMap[string, *sync.Mutex](), + acceptedRuns: csync.NewMap[string, int](), + pendingCancels: csync.NewMap[string, struct{}](), + } +} + +// AcceptedRun owns exactly one accept reservation taken by +// BeginAccepted. It is the only carrier of accept-state across the +// backend.runAgent / Coordinator.Run / sessionAgent.Run layers: a +// counter > 0 means a dispatched prompt is in flight and has not yet +// completed the dispatch handoff in Run. Close is the only way to +// release the reservation and is idempotent. +type AcceptedRun struct { + agent *sessionAgent + sessionID string + done atomic.Bool +} + +// Close decrements the accept counter for this reservation. It is safe +// to call multiple times; only the first call has effect. +func (r *AcceptedRun) Close() { + if r == nil { + return + } + if !r.done.CompareAndSwap(false, true) { + return + } + r.agent.endAccepted(r.sessionID) +} + +// SessionID exposes the session this reservation is for so the run path +// can use it without an extra parameter. +func (r *AcceptedRun) SessionID() string { + if r == nil { + return "" + } + return r.sessionID +} + +// BeginAccepted increments the accept counter for sessionID and returns +// a handle whose Close is the only way to decrement it. It is the only +// entry point that mutates acceptedRuns. +func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, _ := a.acceptedRuns.Get(sessionID) + a.acceptedRuns.Set(sessionID, count+1) + return &AcceptedRun{agent: a, sessionID: sessionID} +} + +// endAccepted decrements the accept counter for sessionID. It is only +// called via AcceptedRun.Close. It uses a dedicated lock (not the +// per-session dispatch mutex) so it can run while Run holds dispatchMu +// for the same session without deadlocking. +func (a *sessionAgent) endAccepted(sessionID string) { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, ok := a.acceptedRuns.Get(sessionID) + if !ok || count <= 1 { + a.acceptedRuns.Del(sessionID) + return + } + a.acceptedRuns.Set(sessionID, count-1) +} + +// sessionMu returns the per-session dispatch mutex, creating it on first +// use. Creation is guarded so concurrent callers always observe the same +// mutex instance for a given session. +func (a *sessionAgent) sessionMu(sessionID string) *sync.Mutex { + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + a.dispatchMuCreate.Lock() + defer a.dispatchMuCreate.Unlock() + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + mu := &sync.Mutex{} + a.dispatchMu.Set(sessionID, mu) + return mu +} + +// enqueueCall appends call to the session's message queue. The +// OnComplete hook is stripped: the caller that supplied it (typically +// coordinator.Run) has its own retry/coalesce scope that ends when it +// returns, so by the time the queue drains nobody is left to consume the +// buffered terminal event. The recursive Run falls back to the default +// broker publish, which is what existing subscribers expect for queued +// turns. +func (a *sessionAgent) enqueueCall(call SessionAgentCall) { + existing, ok := a.messageQueue.Get(call.SessionID) + if !ok { + existing = []SessionAgentCall{} } + queued := call + queued.OnComplete = nil + queued.Accepted = nil + existing = append(existing, queued) + a.messageQueue.Set(call.SessionID, existing) +} + +// clearPendingCancel removes any pending-cancel record for sessionID. It +// takes the per-session dispatch lock so it is ordered against Cancel and +// the dispatch handoff. +func (a *sessionAgent) clearPendingCancel(sessionID string) { + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + a.pendingCancels.Del(sessionID) +} + +// persistCanceledTurn writes the user/assistant records for a turn that +// was canceled before (or just as) streaming would have produced them. +// It creates the user message only when it was not already created by an +// earlier createUserMessage call (userMsgCreated), then writes an +// assistant message with FinishReasonCanceled. Both writes use +// context.WithoutCancel(ctx) so workspace shutdown (which cancels the run +// context) can't drop them. +func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgentCall, userMsgCreated bool) error { + writeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + if !userMsgCreated { + if _, err := a.createUserMessage(writeCtx, call); err != nil { + return err + } + } + largeModel := a.largeModel.Get() + assistant, err := a.messages.Create(writeCtx, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, + }) + if err != nil { + return err + } + assistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") + return a.messages.Update(writeCtx, assistant) } // ValidateCall performs the cheap structural validation that @@ -204,22 +377,73 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * return nil, err } - // Queue the message if busy. Strip OnComplete: the caller that - // supplied the hook (typically coordinator.Run) has its own - // retry/coalesce scope that ends when it returns, so by the time - // the queue drains nobody is left to consume the buffered - // terminal event. The recursive Run will fall back to the - // default broker publish, which is what existing subscribers - // expect for queued turns. - if a.IsSessionBusy(call.SessionID) { - existing, ok := a.messageQueue.Get(call.SessionID) - if !ok { - existing = []SessionAgentCall{} + // genCtx/cancel are the run context and its cancel func. For the + // accepted (fire-and-forget) dispatch path they are created under + // dispatchMu below so a concurrent Cancel can observe the + // activeRequests entry before the assistant message exists. For + // the in-process path they stay nil here and are created later, + // preserving the original ordering. + var ( + genCtx context.Context + cancel context.CancelFunc + activeRegistered bool + userMsgCreated bool + ) + + if call.Accepted != nil { + // Serialize the accepted -> (cancel-on-entry | queued | + // active) transition against a concurrent Cancel. Cancel takes + // the same per-session lock, so every cancel observes at least + // one of: pendingCancels, an activeRequests entry, or a + // messageQueue entry it then clears. + mu := a.sessionMu(call.SessionID) + mu.Lock() + + if _, pending := a.pendingCancels.Get(call.SessionID); pending { + // Cancel-on-entry: a cancel arrived while this run was + // dispatched but not yet active. Consume the pending + // cancel, release the accept reservation, drop the lock, + // and persist a canceled turn without entering Stream. + a.pendingCancels.Del(call.SessionID) + call.Accepted.Close() + mu.Unlock() + if err := a.persistCanceledTurn(ctx, call, false); err != nil { + return nil, err + } + return nil, nil + } + + if a.IsSessionBusy(call.SessionID) { + // Busy: an earlier prompt is active. Queue this call and + // release the accept reservation. A Cancel arriving after + // this point sees the active entry and clears the queue. + a.enqueueCall(call) + call.Accepted.Close() + mu.Unlock() + return nil, nil } - queued := call - queued.OnComplete = nil - existing = append(existing, queued) - a.messageQueue.Set(call.SessionID, existing) + + // Idle: become the active run. Register the cancel func before + // dropping the lock so a Cancel that arrives between here and + // assistant creation is not lost. + runCtx := context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) + genCtx, cancel = context.WithCancel(runCtx) + a.activeRequests.Set(call.SessionID, cancel) + activeRegistered = true + call.Accepted.Close() + mu.Unlock() + + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } else if a.IsSessionBusy(call.SessionID) { + // Queue the message if busy. Strip OnComplete: the caller that + // supplied the hook (typically coordinator.Run) has its own + // retry/coalesce scope that ends when it returns, so by the time + // the queue drains nobody is left to consume the buffered + // terminal event. The recursive Run will fall back to the + // default broker publish, which is what existing subscribers + // expect for queued turns. + a.enqueueCall(call) return nil, nil } @@ -282,15 +506,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * if err != nil { return nil, err } + userMsgCreated = true // Add the session to the context. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) - genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Set(call.SessionID, cancel) + // For the accepted dispatch path the run context and cancel func + // were already created and registered under dispatchMu above; reuse + // them. For the in-process path create them here, preserving the + // original ordering. + if !activeRegistered { + genCtx, cancel = context.WithCancel(ctx) + a.activeRequests.Set(call.SessionID, cancel) - defer cancel() - defer a.activeRequests.Del(call.SessionID) + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } // skipRunComplete is set just before the queued-recursion path so // the outer Run doesn't publish a RunComplete that would race // with — and be superseded by — the recursive call's own @@ -390,14 +621,24 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() + // Drain queued follow-up prompts, but skip them if a cancel + // was recorded for the session while they sat in the queue: + // a cancel that arrived after the queue insertion must not + // let the queued prompt run as part of this step. + dispatchLock := a.sessionMu(call.SessionID) + dispatchLock.Lock() + _, canceled := a.pendingCancels.Get(call.SessionID) queuedCalls, _ := a.messageQueue.Get(call.SessionID) a.messageQueue.Del(call.SessionID) - for _, queued := range queuedCalls { - userMessage, createErr := a.createUserMessage(callContext, queued) - if createErr != nil { - return callContext, prepared, createErr + dispatchLock.Unlock() + if !canceled { + for _, queued := range queuedCalls { + userMessage, createErr := a.createUserMessage(callContext, queued) + if createErr != nil { + return callContext, prepared, createErr + } + prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } - prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) @@ -594,6 +835,18 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * isHyper := largeModel.ModelCfg.Provider == hyper.Name isCancelErr := errors.Is(err, context.Canceled) if currentAssistant == nil { + // Cancel-before-assistant-creation window: the run was + // canceled after activeRequests.Set but before PrepareStep + // created the assistant message. Without this, the turn + // would return with no FinishReasonCanceled marker and no + // user-visible record. The user message was already created + // above, so persistCanceledTurn only writes the assistant + // record. + if isCancelErr { + if persistErr := a.persistCanceledTurn(ctx, call, userMsgCreated); persistErr != nil { + return nil, persistErr + } + } return result, err } // Persist final state with a context detached from the run @@ -741,8 +994,35 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * }) } + // Hand off to the next queued prompt (if any) under dispatchMu so + // the transition from this finished run to the queued run is atomic + // against a concurrent Cancel. activeRequests for this session was + // just deleted above, so without the lock there is a window in + // which the session looks idle and a cancel becomes a no-op that + // fails to stop the queued prompt. Holding the lock lets us observe + // a pending cancel recorded against the session and drop the queue + // instead of running it, and (for the recursion) hand a fresh + // accept reservation to the dequeued call so acceptedRuns stays > 0 + // across the recursive Run's own dispatch handoff — keeping the + // session observable to Cancel for the entire transition and + // closing the dequeue -> re-register window. + mu := a.sessionMu(call.SessionID) + mu.Lock() + if _, pending := a.pendingCancels.Get(call.SessionID); pending { + // A cancel was recorded for this session (e.g. it arrived while + // this run was active and a follow-up had been accepted). Drop + // the queue instead of running it and consume the marker. + a.pendingCancels.Del(call.SessionID) + a.messageQueue.Del(call.SessionID) + mu.Unlock() + return result, err + } queuedMessages, ok := a.messageQueue.Get(call.SessionID) if !ok || len(queuedMessages) == 0 { + // No queued work. Clear any stale pending-cancel entry as a + // safety net so it can't catch a future run (no-op when unset). + a.pendingCancels.Del(call.SessionID) + mu.Unlock() return result, err } // There are queued messages restart the loop. The recursive Run @@ -753,6 +1033,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * skipRunComplete = true firstQueuedMessage := queuedMessages[0] a.messageQueue.Set(call.SessionID, queuedMessages[1:]) + // Reserve a fresh accept for the dequeued prompt before dropping the + // lock so acceptedRuns > 0 across the handoff into the recursive + // Run. This closes the window between this dequeue and the recursive + // Run registering its activeRequests entry: a cancel arriving in + // that window now records a pending cancel (acceptedRuns > 0) that + // the recursive Run's accepted path observes as cancel-on-entry. + firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID) + mu.Unlock() return a.Run(ctx, firstQueuedMessage) } @@ -1305,6 +1593,16 @@ func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message } func (a *sessionAgent) Cancel(sessionID string) { + // Serialize against the dispatch handoff in Run so the accepted -> + // (cancel-on-entry | queued | active) transition is atomic against + // this cancel. Every cancel observes at least one of: an active + // request, an accepted run (recorded as a pending cancel), or a + // queue entry it then clears. If none of those hold, an idle Escape + // is a true no-op and must not poison the next prompt. + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + // Cancel regular requests. Don't use Take() here - we need the entry to // remain in activeRequests so IsBusy() returns true until the goroutine // fully completes (including error handling that may access the DB). @@ -1320,6 +1618,20 @@ func (a *sessionAgent) Cancel(sessionID string) { cancel() } + // Record a pending cancel only when a dispatched-but-not-yet-active + // run exists. This catches a run still in the goroutine scheduler or + // about to enter Run's busy-queue branch, while leaving an idle + // session untouched. Active and accepted are not mutually exclusive: + // when a run is active and a follow-up has been accepted, both the + // cancel above and this pending record fire. + a.acceptedMu.Lock() + count, ok := a.acceptedRuns.Get(sessionID) + a.acceptedMu.Unlock() + if ok && count > 0 { + slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID) + a.pendingCancels.Set(sessionID, struct{}{}) + } + if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) a.messageQueue.Del(sessionID) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index a26aa111eeb8e45a699a6aab90774f04a1aca4bb..f5ca831e60cdb54edf0c0d7bfde83702a79701f1 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -81,6 +81,15 @@ type Coordinator interface { // INFO: (kujtim) this is not used yet we will use this when we have multiple agents // SetMainAgent(string) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + // RunAccepted runs a call that was already accepted via + // BeginAccepted on the fire-and-forget dispatch path. The handle is + // the only carrier of accept-state across the backend.runAgent / + // Coordinator / sessionAgent.Run layers: it reaches + // sessionAgent.Run as SessionAgentCall.Accepted, where it is + // consumed under dispatchMu once the accepted -> (cancel-on-entry | + // queued | active) transition is chosen. + RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool @@ -179,6 +188,20 @@ func NewCoordinator( // Run implements Coordinator. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, nil, sessionID, prompt, attachments...) +} + +// RunAccepted implements Coordinator. +func (c *coordinator) RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, accept, sessionID, prompt, attachments...) +} + +// run is the shared implementation behind Run and RunAccepted. When +// accept is non-nil it is threaded onto the SessionAgentCall as +// Accepted so sessionAgent.Run can consume the accept reservation under +// dispatchMu; when nil (the in-process/local path) no accept tracking +// applies. +func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { if err := c.readyWg.Wait(); err != nil { return nil, err } @@ -256,6 +279,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, FrequencyPenalty: freqPenalty, PresencePenalty: presPenalty, OnComplete: onComplete, + Accepted: accept, }) } beforeLoaded := c.skillTracker.LoadedNames() @@ -989,6 +1013,15 @@ func isExactoSupported(modelID string) bool { return slices.Contains(supportedModels, modelID) } +// BeginAccepted reserves an accept slot for sessionID on the active +// agent and returns the ownership handle. It is the fire-and-forget +// dispatch path's only way to mark a run as accepted-but-not-yet-active +// so a cancel arriving before the run registers in activeRequests is not +// lost. +func (c *coordinator) BeginAccepted(sessionID string) *AcceptedRun { + return c.currentAgent.BeginAccepted(sessionID) +} + func (c *coordinator) Cancel(sessionID string) { c.currentAgent.Cancel(sessionID) } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index da0ddd0db2bf77c3c1e3eb6463549875a989a4ca..c522ef5de1061435e4cf9df1789bc3c92d9152a4 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -25,6 +25,10 @@ func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fan return m.runFunc(ctx, call) } +func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + return &AcceptedRun{sessionID: sessionID} +} + func (m *mockSessionAgent) Model() Model { return m.model } func (m *mockSessionAgent) SetModels(large, small Model) {} func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {} diff --git a/internal/agent/dispatch_cancel_test.go b/internal/agent/dispatch_cancel_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f66de252e63559239c1d577fe51c0650589aa5b4 --- /dev/null +++ b/internal/agent/dispatch_cancel_test.go @@ -0,0 +1,198 @@ +package agent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// finishStreamModel is a minimal fantasy.LanguageModel that streams a +// single text part followed by a normal (FinishReasonStop) finish. It +// is enough to drive sessionAgent.Run through PrepareStep and a clean +// completion without a recorded provider cassette. +type finishStreamModel struct { + text string +} + +func (m *finishStreamModel) Provider() string { return "fake" } +func (m *finishStreamModel) Model() string { return "fake-model" } + +func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}, + FinishReason: fantasy.FinishReasonStop, + }, nil +} + +func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + text := m.text + return func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) { + return + } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil +} + +func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return nil, errors.New("not implemented") +} + +func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + model := &finishStreamModel{text: "done"} + sa := testSessionAgent(env, model, model, "system").(*sessionAgent) + return sa, env +} + +// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a +// session is actively running (activeRequests set) AND a follow-up has +// been accepted (acceptedRuns > 0). A single Cancel must fire both: it +// invokes the active cancel func and records a pending cancel for the +// accepted follow-up. +func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + const sid = "sid" + var activeCanceled atomic.Bool + sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) }) + + accept := sa.BeginAccepted(sid) + defer accept.Close() + + sa.Cancel(sid) + + require.True(t, activeCanceled.Load(), "active cancel func must fire") + require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel") +} + +// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue +// branch consulting pendingCancels: when the session is busy AND a +// cancel has been recorded for an accepted follow-up, Run must take the +// cancel-on-entry path (persist a canceled turn) instead of enqueueing +// the call behind the active run. +func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Make the session look busy: an earlier prompt is active. + sa.activeRequests.Set(sess.ID, func() {}) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives while this follow-up is accepted-but-not-active. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "follow-up", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The follow-up was canceled on entry, not enqueued. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "cancel-on-entry must not enqueue the follow-up behind the active run") + require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the +// queue drain inside PrepareStep skips queued follow-up prompts when a +// cancel has been recorded for the session: the queued prompt must not +// be folded into the active turn as an extra user message. +func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A follow-up prompt sits queued for the session. + sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"}) + // A cancel was recorded for the session while it sat in the queue. + sa.pendingCancels.Set(sess.ID, struct{}{}) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Only the main prompt produced a user message; the queued + // follow-up was skipped, not folded into the turn. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + var userMsgs []message.Message + for _, m := range msgs { + if m.Role == message.User { + userMsgs = append(userMsgs, m) + } + } + require.Len(t, userMsgs, 1, "queued follow-up must not create a user message") + assert.Equal(t, "main", userMsgs[0].Content().String()) + + // The queue was drained and the pending cancel consumed. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID)) + require.False(t, sa.hasPendingCancel(sess.ID)) +} + +// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run +// which completes normally clears any stale pending-cancel entry for the +// session, so it cannot catch a future run. +func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A stale pending cancel lingers (no queued work, no accepted run). + sa.pendingCancels.Set(sess.ID, struct{}{}) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + require.False(t, sa.hasPendingCancel(sess.ID), + "normal completion must clear the stale pending cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason()) +} diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 1bbb05e511a26f02cc1dad0e1da77454af0f8905..18ea8046d0647d3576a50769b7b2146d9aa103e9 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -60,6 +60,13 @@ func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, atta return nil, s.returnFn(ctx) } +func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return s.Run(ctx, sessionID, prompt, attachments...) +} + +func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *runCoordinator) Cancel(string) {} func (s *runCoordinator) CancelAll() {} func (s *runCoordinator) IsBusy() bool { return false } diff --git a/internal/server/sessions_isbusy_test.go b/internal/server/sessions_isbusy_test.go index 060c00abe9367dc7162bdb50dd77fe951041aa51..615f4a5b58cde3f5779b053ca9ce92e0d0d253a4 100644 --- a/internal/server/sessions_isbusy_test.go +++ b/internal/server/sessions_isbusy_test.go @@ -30,6 +30,14 @@ type stubCoordinator struct { func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { return nil, nil } + +func (s *stubCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (s *stubCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *stubCoordinator) Cancel(string) {} func (s *stubCoordinator) CancelAll() {} func (s *stubCoordinator) IsBusy() bool { return false }