Detailed changes
@@ -0,0 +1,226 @@
+package agent
+
+import (
+ "context"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newCancelTestAgent returns a sessionAgent wired to the test
+// environment's session and message stores, together with that
+// environment. No model is configured: the tests that use this helper
+// stay on the dispatch/cancel/persist paths and are not expected to
+// reach agent.Stream.
+func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
+	t.Helper()
+	env := testEnv(t)
+	opts := SessionAgentOptions{
+		Sessions: env.sessions,
+		Messages: env.messages,
+	}
+	built := NewSessionAgent(opts)
+	return built.(*sessionAgent), env
+}
+
+// acceptedCount is a test-only accessor that reports the accept counter
+// stored for sessionID, or 0 when the session has no entry.
+func (a *sessionAgent) acceptedCount(sessionID string) int {
+	if count, found := a.acceptedRuns.Get(sessionID); found {
+		return count
+	}
+	return 0
+}
+
+// hasPendingCancel is a test-only accessor that reports whether a
+// pending-cancel marker is recorded for sessionID.
+func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
+	_, recorded := a.pendingCancels.Get(sessionID)
+	return recorded
+}
+
+func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ require.Equal(t, "sid", accept.SessionID())
+ require.Equal(t, 1, sa.acceptedCount("sid"))
+
+ accept.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+
+ // Repeated Close must not underflow the counter.
+ accept.Close()
+ accept.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+}
+
+func TestAcceptedRun_MultipleReservations(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ a1 := sa.BeginAccepted("sid")
+ a2 := sa.BeginAccepted("sid")
+ require.Equal(t, 2, sa.acceptedCount("sid"))
+
+ a1.Close()
+ require.Equal(t, 1, sa.acceptedCount("sid"))
+
+ a2.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+}
+
+// TestAcceptedRun_NilSafe checks that a nil *AcceptedRun reports an empty
+// session ID and that calling Close on it does not panic.
+func TestAcceptedRun_NilSafe(t *testing.T) {
+	t.Parallel()
+	var none *AcceptedRun
+	require.Equal(t, "", none.SessionID())
+	// A panic here would fail the test.
+	none.Close()
+}
+
+func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ // No accepted run, no active request: a true no-op.
+ sa.Cancel("sid")
+ require.False(t, sa.hasPendingCancel("sid"))
+}
+
+func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+}
+
+func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+
+ // A second cancel while a pending cancel is already recorded must
+ // remain a single pending cancel; one Run consumes exactly one.
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+}
+
+// TestRun_CancelOnEntryPersistsCanceledTurn checks that Run, when it
+// finds a pending cancel for its accepted call, returns a nil result,
+// consumes the marker, releases the reservation, and stores a user
+// message followed by an assistant message finished as canceled.
+func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newCancelTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+	sid := session.ID
+
+	reservation := ag.BeginAccepted(sid)
+	// The cancel lands after the run was accepted but before Run starts.
+	ag.Cancel(sid)
+	require.True(t, ag.hasPendingCancel(sid))
+
+	call := SessionAgentCall{
+		SessionID: sid,
+		Prompt:    "hello",
+		Accepted:  reservation,
+	}
+	result, err := ag.Run(ctx, call)
+	require.NoError(t, err)
+	require.Nil(t, result)
+
+	// Run consumed the marker and released the reservation.
+	require.False(t, ag.hasPendingCancel(sid))
+	require.Equal(t, 0, ag.acceptedCount(sid))
+
+	stored, err := env.messages.List(ctx, sid)
+	require.NoError(t, err)
+	require.Len(t, stored, 2)
+	userMsg, assistantMsg := stored[0], stored[1]
+	assert.Equal(t, message.User, userMsg.Role)
+	assert.Equal(t, message.Assistant, assistantMsg.Role)
+	assert.Equal(t, message.FinishReasonCanceled, assistantMsg.FinishReason())
+}
+
+// TestPersistCanceledTurn_WritesBothWhenUserMissing checks that
+// persistCanceledTurn, told the user message does not exist yet, stores
+// both the user message and a canceled assistant message.
+func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newCancelTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+
+	call := SessionAgentCall{SessionID: session.ID, Prompt: "hello"}
+	require.NoError(t, ag.persistCanceledTurn(ctx, call, false))
+
+	stored, err := env.messages.List(ctx, session.ID)
+	require.NoError(t, err)
+	require.Len(t, stored, 2)
+	userMsg, assistantMsg := stored[0], stored[1]
+	assert.Equal(t, message.User, userMsg.Role)
+	assert.Equal(t, message.Assistant, assistantMsg.Role)
+	assert.Equal(t, message.FinishReasonCanceled, assistantMsg.FinishReason())
+}
+
+// TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated checks that
+// persistCanceledTurn, told the user message already exists, adds only
+// the canceled assistant message, so the session ends with two messages.
+func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newCancelTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+
+	call := SessionAgentCall{SessionID: session.ID, Prompt: "hello"}
+
+	// Stand in for the run having stored the user message before the
+	// cancel was handled.
+	_, err = ag.createUserMessage(ctx, call)
+	require.NoError(t, err)
+
+	require.NoError(t, ag.persistCanceledTurn(ctx, call, true))
+
+	stored, err := env.messages.List(ctx, session.ID)
+	require.NoError(t, err)
+	require.Len(t, stored, 2)
+	userMsg, assistantMsg := stored[0], stored[1]
+	assert.Equal(t, message.User, userMsg.Role)
+	assert.Equal(t, message.Assistant, assistantMsg.Role)
+	assert.Equal(t, message.FinishReasonCanceled, assistantMsg.FinishReason())
+}
+
+// TestPersistCanceledTurn_SucceedsWithCanceledContext checks that
+// persistCanceledTurn still stores the full canceled turn when the
+// context it receives has already been canceled.
+func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
+	t.Parallel()
+	sa, env := newCancelTestAgent(t)
+
+	sess, err := env.sessions.Create(t.Context(), "session")
+	require.NoError(t, err)
+
+	// Simulate workspace shutdown having already canceled the run
+	// context. WithoutCancel must let the writes through.
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel()
+
+	err = sa.persistCanceledTurn(ctx, SessionAgentCall{
+		SessionID: sess.ID,
+		Prompt:    "hello",
+	}, false)
+	require.NoError(t, err)
+
+	msgs, err := env.messages.List(t.Context(), sess.ID)
+	require.NoError(t, err)
+	require.Len(t, msgs, 2)
+	// Counting rows alone would accept a turn stored with the wrong
+	// roles or without the canceled finish marker, so pin the content.
+	assert.Equal(t, message.User, msgs[0].Role)
+	assert.Equal(t, message.Assistant, msgs[1].Role)
+	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
+
+func TestClearPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+
+ sa.clearPendingCancel("sid")
+ require.False(t, sa.hasPendingCancel("sid"))
+}
@@ -21,6 +21,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
"charm.land/catwalk/pkg/catwalk"
@@ -103,10 +104,19 @@ type SessionAgentCall struct {
// recursion drains, so falling back to the default broker
// publish keeps the event visible to subscribers.
OnComplete func(notify.RunComplete)
+ // Accepted, when non-nil, is the accept reservation taken by
+ // BeginAccepted before the call was dispatched onto a goroutine
+ // (the client/server fire-and-forget path). Run consumes it under
+ // dispatchMu[SessionID] once the accepted -> (cancel-on-entry |
+ // queued | active) transition has been chosen. When nil
+ // (in-process / local callers like AppWorkspace), behavior is
+ // unchanged and no accept tracking applies.
+ Accepted *AcceptedRun
}
type SessionAgent interface {
Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
+ BeginAccepted(sessionID string) *AcceptedRun
SetModels(large Model, small Model)
SetTools(tools []fantasy.AgentTool)
SetSystemPrompt(systemPrompt string)
@@ -145,6 +155,32 @@ type sessionAgent struct {
messageQueue *csync.Map[string, []SessionAgentCall]
activeRequests *csync.Map[string, context.CancelFunc]
+
+ // dispatchMu holds a per-session mutex that serializes the
+ // accepted -> (cancel-on-entry | queued | active) transition in
+ // Run against a concurrent Cancel. The lock is held only during
+ // the brief handoff (no DB or LLM I/O under the lock).
+ dispatchMu *csync.Map[string, *sync.Mutex]
+ // acceptedRuns counts dispatched-but-not-yet-active runs per
+ // session. A counter > 0 means a dispatched prompt is in flight
+ // and has not yet completed the dispatch handoff in Run. Only
+ // BeginAccepted increments it; only AcceptedRun.Close decrements
+ // it.
+ acceptedRuns *csync.Map[string, int]
+ // pendingCancels records sessions whose dispatched-but-not-yet-
+ // running call should observe a cancellation request. It is only
+ // set by Cancel when acceptedRuns > 0, so an idle Escape never
+ // poisons the next prompt.
+ pendingCancels *csync.Map[string, struct{}]
+ // dispatchMuCreate guards lazy creation of per-session entries in
+ // dispatchMu so two goroutines can't race to lock different mutex
+ // instances for the same session.
+ dispatchMuCreate sync.Mutex
+ // acceptedMu serializes increments/decrements of acceptedRuns. It
+ // is separate from dispatchMu so AcceptedRun.Close (which may run
+ // while Run holds dispatchMu for the same session) does not
+ // deadlock by re-entering the dispatch lock.
+ acceptedMu sync.Mutex
}
type SessionAgentOptions struct {
@@ -180,7 +216,144 @@ func NewSessionAgent(
runComplete: opts.RunComplete,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
activeRequests: csync.NewMap[string, context.CancelFunc](),
+ dispatchMu: csync.NewMap[string, *sync.Mutex](),
+ acceptedRuns: csync.NewMap[string, int](),
+ pendingCancels: csync.NewMap[string, struct{}](),
+ }
+}
+
+// AcceptedRun owns exactly one accept reservation taken by
+// BeginAccepted. It is the only carrier of accept-state across the
+// backend.runAgent / Coordinator.Run / sessionAgent.Run layers: a
+// counter > 0 means a dispatched prompt is in flight and has not yet
+// completed the dispatch handoff in Run. Close is the only way to
+// release the reservation and is idempotent.
+type AcceptedRun struct {
+ // agent is the sessionAgent whose accept counter Close decrements.
+ agent *sessionAgent
+ // sessionID is the session this reservation was taken for.
+ sessionID string
+ // done flips to true on the first Close so later calls are no-ops.
+ done atomic.Bool
+}
+
+// Close releases this reservation by decrementing the agent's accept
+// counter for its session. It is safe to call on a nil receiver and safe
+// to call repeatedly: only the first call has any effect. A handle with
+// no backing agent (for example one constructed directly by a test
+// double) is marked done without touching any counter.
+func (r *AcceptedRun) Close() {
+	if r == nil {
+		return
+	}
+	// The compare-and-swap lets exactly one caller through, so repeated
+	// or concurrent Close calls cannot decrement the counter twice.
+	if !r.done.CompareAndSwap(false, true) {
+		return
+	}
+	// endAccepted locks a mutex on the agent, so calling it through a nil
+	// agent would panic; there is nothing to release in that case.
+	if r.agent == nil {
+		return
+	}
+	r.agent.endAccepted(r.sessionID)
+}
+
+// SessionID returns the session this reservation belongs to, or the
+// empty string when called on a nil handle.
+func (r *AcceptedRun) SessionID() string {
+	if r != nil {
+		return r.sessionID
+	}
+	return ""
+}
+
+// BeginAccepted adds one to the accept counter for sessionID and hands
+// back the reservation handle. It is the only place the counter is
+// incremented; the matching decrement happens through the handle's
+// Close.
+func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
+	reservation := &AcceptedRun{agent: a, sessionID: sessionID}
+
+	a.acceptedMu.Lock()
+	defer a.acceptedMu.Unlock()
+	current, _ := a.acceptedRuns.Get(sessionID)
+	a.acceptedRuns.Set(sessionID, current+1)
+	return reservation
+}
+
+// endAccepted removes one reservation from the accept counter for
+// sessionID, deleting the entry once the last reservation is gone. It is
+// reached through AcceptedRun.Close and is guarded by acceptedMu rather
+// than the per-session dispatch mutex, so it can run while Run holds
+// that dispatch mutex for the same session.
+func (a *sessionAgent) endAccepted(sessionID string) {
+	a.acceptedMu.Lock()
+	defer a.acceptedMu.Unlock()
+
+	remaining, found := a.acceptedRuns.Get(sessionID)
+	if found && remaining > 1 {
+		a.acceptedRuns.Set(sessionID, remaining-1)
+		return
+	}
+	// Last reservation, or no entry at all: drop the key instead of
+	// storing a zero or negative count.
+	a.acceptedRuns.Del(sessionID)
+}
+
+// sessionMu hands back the dispatch mutex for sessionID, allocating and
+// storing it the first time the session is seen. Allocation happens
+// under dispatchMuCreate so every caller ends up with the same mutex
+// instance for a given session.
+func (a *sessionAgent) sessionMu(sessionID string) *sync.Mutex {
+	// Common case: the mutex is already stored.
+	if existing, found := a.dispatchMu.Get(sessionID); found {
+		return existing
+	}
+
+	a.dispatchMuCreate.Lock()
+	defer a.dispatchMuCreate.Unlock()
+	// Look again now that the creation lock is held: another goroutine
+	// may have stored the mutex since the lookup above.
+	existing, found := a.dispatchMu.Get(sessionID)
+	if !found {
+		existing = new(sync.Mutex)
+		a.dispatchMu.Set(sessionID, existing)
+	}
+	return existing
+}
+
+// enqueueCall appends call to the session's message queue. The
+// OnComplete hook is stripped: the caller that supplied it (typically
+// coordinator.Run) has its own retry/coalesce scope that ends when it
+// returns, so by the time the queue drains nobody is left to consume the
+// buffered terminal event. The recursive Run falls back to the default
+// broker publish, which is what existing subscribers expect for queued
+// turns.
+func (a *sessionAgent) enqueueCall(call SessionAgentCall) {
+ existing, ok := a.messageQueue.Get(call.SessionID)
+ if !ok {
+ existing = []SessionAgentCall{}
}
+ queued := call
+ queued.OnComplete = nil
+ queued.Accepted = nil
+ existing = append(existing, queued)
+ a.messageQueue.Set(call.SessionID, existing)
+}
+
+// clearPendingCancel deletes the pending-cancel marker for sessionID, if
+// one exists. The deletion happens while holding the session's dispatch
+// mutex, the same one Cancel and Run's dispatch handoff take.
+func (a *sessionAgent) clearPendingCancel(sessionID string) {
+	lock := a.sessionMu(sessionID)
+	lock.Lock()
+	defer lock.Unlock()
+
+	a.pendingCancels.Del(sessionID)
+}
+
+// persistCanceledTurn stores the records for a turn that was canceled
+// before streaming produced them. When userMsgCreated is false it first
+// stores the user message for call; it then always creates an empty
+// assistant message tagged with the large model's name and provider and
+// marks it finished with FinishReasonCanceled.
+//
+// All writes run on a context detached from ctx's cancellation (via
+// context.WithoutCancel) and bounded by a 5-second timeout, so a run
+// context that is already canceled cannot drop them. The first write
+// error is returned as-is.
+func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgentCall, userMsgCreated bool) error {
+	detached := context.WithoutCancel(ctx)
+	writeCtx, release := context.WithTimeout(detached, 5*time.Second)
+	defer release()
+
+	if !userMsgCreated {
+		_, err := a.createUserMessage(writeCtx, call)
+		if err != nil {
+			return err
+		}
+	}
+
+	largeModel := a.largeModel.Get()
+	params := message.CreateMessageParams{
+		Role:     message.Assistant,
+		Parts:    []message.ContentPart{},
+		Model:    largeModel.ModelCfg.Model,
+		Provider: largeModel.ModelCfg.Provider,
+	}
+	assistant, err := a.messages.Create(writeCtx, call.SessionID, params)
+	if err != nil {
+		return err
+	}
+
+	// The finish marker is added in memory and saved with a second write.
+	assistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
+	return a.messages.Update(writeCtx, assistant)
+}
// ValidateCall performs the cheap structural validation that
@@ -204,22 +377,73 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
return nil, err
}
- // Queue the message if busy. Strip OnComplete: the caller that
- // supplied the hook (typically coordinator.Run) has its own
- // retry/coalesce scope that ends when it returns, so by the time
- // the queue drains nobody is left to consume the buffered
- // terminal event. The recursive Run will fall back to the
- // default broker publish, which is what existing subscribers
- // expect for queued turns.
- if a.IsSessionBusy(call.SessionID) {
- existing, ok := a.messageQueue.Get(call.SessionID)
- if !ok {
- existing = []SessionAgentCall{}
+ // genCtx/cancel are the run context and its cancel func. For the
+ // accepted (fire-and-forget) dispatch path they are created under
+ // dispatchMu below so a concurrent Cancel can observe the
+ // activeRequests entry before the assistant message exists. For
+ // the in-process path they stay nil here and are created later,
+ // preserving the original ordering.
+ var (
+ genCtx context.Context
+ cancel context.CancelFunc
+ activeRegistered bool
+ userMsgCreated bool
+ )
+
+ if call.Accepted != nil {
+ // Serialize the accepted -> (cancel-on-entry | queued |
+ // active) transition against a concurrent Cancel. Cancel takes
+ // the same per-session lock, so every cancel observes at least
+ // one of: pendingCancels, an activeRequests entry, or a
+ // messageQueue entry it then clears.
+ mu := a.sessionMu(call.SessionID)
+ mu.Lock()
+
+ if _, pending := a.pendingCancels.Get(call.SessionID); pending {
+ // Cancel-on-entry: a cancel arrived while this run was
+ // dispatched but not yet active. Consume the pending
+ // cancel, release the accept reservation, drop the lock,
+ // and persist a canceled turn without entering Stream.
+ a.pendingCancels.Del(call.SessionID)
+ call.Accepted.Close()
+ mu.Unlock()
+ if err := a.persistCanceledTurn(ctx, call, false); err != nil {
+ return nil, err
+ }
+ return nil, nil
+ }
+
+ if a.IsSessionBusy(call.SessionID) {
+ // Busy: an earlier prompt is active. Queue this call and
+ // release the accept reservation. A Cancel arriving after
+ // this point sees the active entry and clears the queue.
+ a.enqueueCall(call)
+ call.Accepted.Close()
+ mu.Unlock()
+ return nil, nil
}
- queued := call
- queued.OnComplete = nil
- existing = append(existing, queued)
- a.messageQueue.Set(call.SessionID, existing)
+
+ // Idle: become the active run. Register the cancel func before
+ // dropping the lock so a Cancel that arrives between here and
+ // assistant creation is not lost.
+ runCtx := context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
+ genCtx, cancel = context.WithCancel(runCtx)
+ a.activeRequests.Set(call.SessionID, cancel)
+ activeRegistered = true
+ call.Accepted.Close()
+ mu.Unlock()
+
+ defer cancel()
+ defer a.activeRequests.Del(call.SessionID)
+ } else if a.IsSessionBusy(call.SessionID) {
+ // Queue the message if busy. Strip OnComplete: the caller that
+ // supplied the hook (typically coordinator.Run) has its own
+ // retry/coalesce scope that ends when it returns, so by the time
+ // the queue drains nobody is left to consume the buffered
+ // terminal event. The recursive Run will fall back to the
+ // default broker publish, which is what existing subscribers
+ // expect for queued turns.
+ a.enqueueCall(call)
return nil, nil
}
@@ -282,15 +506,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
if err != nil {
return nil, err
}
+ userMsgCreated = true
// Add the session to the context.
ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
- genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Set(call.SessionID, cancel)
+ // For the accepted dispatch path the run context and cancel func
+ // were already created and registered under dispatchMu above; reuse
+ // them. For the in-process path create them here, preserving the
+ // original ordering.
+ if !activeRegistered {
+ genCtx, cancel = context.WithCancel(ctx)
+ a.activeRequests.Set(call.SessionID, cancel)
- defer cancel()
- defer a.activeRequests.Del(call.SessionID)
+ defer cancel()
+ defer a.activeRequests.Del(call.SessionID)
+ }
// skipRunComplete is set just before the queued-recursion path so
// the outer Run doesn't publish a RunComplete that would race
// with — and be superseded by — the recursive call's own
@@ -390,14 +621,24 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// Use latest tools (updated by SetTools when MCP tools change).
prepared.Tools = a.tools.Copy()
+ // Drain queued follow-up prompts, but skip them if a cancel
+ // was recorded for the session while they sat in the queue:
+ // a cancel that arrived after the queue insertion must not
+ // let the queued prompt run as part of this step.
+ dispatchLock := a.sessionMu(call.SessionID)
+ dispatchLock.Lock()
+ _, canceled := a.pendingCancels.Get(call.SessionID)
queuedCalls, _ := a.messageQueue.Get(call.SessionID)
a.messageQueue.Del(call.SessionID)
- for _, queued := range queuedCalls {
- userMessage, createErr := a.createUserMessage(callContext, queued)
- if createErr != nil {
- return callContext, prepared, createErr
+ dispatchLock.Unlock()
+ if !canceled {
+ for _, queued := range queuedCalls {
+ userMessage, createErr := a.createUserMessage(callContext, queued)
+ if createErr != nil {
+ return callContext, prepared, createErr
+ }
+ prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
- prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
}
prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)
@@ -594,6 +835,18 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
isHyper := largeModel.ModelCfg.Provider == hyper.Name
isCancelErr := errors.Is(err, context.Canceled)
if currentAssistant == nil {
+ // Cancel-before-assistant-creation window: the run was
+ // canceled after activeRequests.Set but before PrepareStep
+ // created the assistant message. Without this, the turn
+ // would return with no FinishReasonCanceled marker and no
+ // user-visible record. The user message was already created
+ // above, so persistCanceledTurn only writes the assistant
+ // record.
+ if isCancelErr {
+ if persistErr := a.persistCanceledTurn(ctx, call, userMsgCreated); persistErr != nil {
+ return nil, persistErr
+ }
+ }
return result, err
}
// Persist final state with a context detached from the run
@@ -741,8 +994,35 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
})
}
+ // Hand off to the next queued prompt (if any) under dispatchMu so
+ // the transition from this finished run to the queued run is atomic
+ // against a concurrent Cancel. activeRequests for this session was
+ // just deleted above, so without the lock there is a window in
+ // which the session looks idle and a cancel becomes a no-op that
+ // fails to stop the queued prompt. Holding the lock lets us observe
+ // a pending cancel recorded against the session and drop the queue
+ // instead of running it, and (for the recursion) hand a fresh
+ // accept reservation to the dequeued call so acceptedRuns stays > 0
+ // across the recursive Run's own dispatch handoff — keeping the
+ // session observable to Cancel for the entire transition and
+ // closing the dequeue -> re-register window.
+ mu := a.sessionMu(call.SessionID)
+ mu.Lock()
+ if _, pending := a.pendingCancels.Get(call.SessionID); pending {
+ // A cancel was recorded for this session (e.g. it arrived while
+ // this run was active and a follow-up had been accepted). Drop
+ // the queue instead of running it and consume the marker.
+ a.pendingCancels.Del(call.SessionID)
+ a.messageQueue.Del(call.SessionID)
+ mu.Unlock()
+ return result, err
+ }
queuedMessages, ok := a.messageQueue.Get(call.SessionID)
if !ok || len(queuedMessages) == 0 {
+ // No queued work. Clear any stale pending-cancel entry as a
+ // safety net so it can't catch a future run (no-op when unset).
+ a.pendingCancels.Del(call.SessionID)
+ mu.Unlock()
return result, err
}
// There are queued messages restart the loop. The recursive Run
@@ -753,6 +1033,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
skipRunComplete = true
firstQueuedMessage := queuedMessages[0]
a.messageQueue.Set(call.SessionID, queuedMessages[1:])
+ // Reserve a fresh accept for the dequeued prompt before dropping the
+ // lock so acceptedRuns > 0 across the handoff into the recursive
+ // Run. This closes the window between this dequeue and the recursive
+ // Run registering its activeRequests entry: a cancel arriving in
+ // that window now records a pending cancel (acceptedRuns > 0) that
+ // the recursive Run's accepted path observes as cancel-on-entry.
+ firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID)
+ mu.Unlock()
return a.Run(ctx, firstQueuedMessage)
}
@@ -1305,6 +1593,16 @@ func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message
}
func (a *sessionAgent) Cancel(sessionID string) {
+ // Serialize against the dispatch handoff in Run so the accepted ->
+ // (cancel-on-entry | queued | active) transition is atomic against
+ // this cancel. Every cancel observes at least one of: an active
+ // request, an accepted run (recorded as a pending cancel), or a
+ // queue entry it then clears. If none of those hold, an idle Escape
+ // is a true no-op and must not poison the next prompt.
+ mu := a.sessionMu(sessionID)
+ mu.Lock()
+ defer mu.Unlock()
+
// Cancel regular requests. Don't use Take() here - we need the entry to
// remain in activeRequests so IsBusy() returns true until the goroutine
// fully completes (including error handling that may access the DB).
@@ -1320,6 +1618,20 @@ func (a *sessionAgent) Cancel(sessionID string) {
cancel()
}
+ // Record a pending cancel only when a dispatched-but-not-yet-active
+ // run exists. This catches a run still in the goroutine scheduler or
+ // about to enter Run's busy-queue branch, while leaving an idle
+ // session untouched. Active and accepted are not mutually exclusive:
+ // when a run is active and a follow-up has been accepted, both the
+ // cancel above and this pending record fire.
+ a.acceptedMu.Lock()
+ count, ok := a.acceptedRuns.Get(sessionID)
+ a.acceptedMu.Unlock()
+ if ok && count > 0 {
+ slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID)
+ a.pendingCancels.Set(sessionID, struct{}{})
+ }
+
if a.QueuedPrompts(sessionID) > 0 {
slog.Debug("Clearing queued prompts", "session_id", sessionID)
a.messageQueue.Del(sessionID)
@@ -81,6 +81,15 @@ type Coordinator interface {
// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
// SetMainAgent(string)
Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
+ // RunAccepted runs a call that was already accepted via
+ // BeginAccepted on the fire-and-forget dispatch path. The handle is
+ // the only carrier of accept-state across the backend.runAgent /
+ // Coordinator / sessionAgent.Run layers: it reaches
+ // sessionAgent.Run as SessionAgentCall.Accepted, where it is
+ // consumed under dispatchMu once the accepted -> (cancel-on-entry |
+ // queued | active) transition is chosen.
+ RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
+ BeginAccepted(sessionID string) *AcceptedRun
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
@@ -179,6 +188,20 @@ func NewCoordinator(
// Run implements Coordinator.
+// It delegates to run with a nil accept handle, so no accept
+// reservation is attached to the call.
func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return c.run(ctx, nil, sessionID, prompt, attachments...)
+}
+
+// RunAccepted implements Coordinator. It delegates to run and passes
+// accept through unchanged; all other arguments are forwarded as given.
+func (c *coordinator) RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return c.run(ctx, accept, sessionID, prompt, attachments...)
+}
+
+// run is the shared implementation behind Run and RunAccepted. When
+// accept is non-nil it is threaded onto the SessionAgentCall as
+// Accepted so sessionAgent.Run can consume the accept reservation under
+// dispatchMu; when nil (the in-process/local path) no accept tracking
+// applies.
+func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
if err := c.readyWg.Wait(); err != nil {
return nil, err
}
@@ -256,6 +279,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
FrequencyPenalty: freqPenalty,
PresencePenalty: presPenalty,
OnComplete: onComplete,
+ Accepted: accept,
})
}
beforeLoaded := c.skillTracker.LoadedNames()
@@ -989,6 +1013,15 @@ func isExactoSupported(modelID string) bool {
return slices.Contains(supportedModels, modelID)
}
+// BeginAccepted reserves an accept slot for sessionID on the active
+// agent and returns the ownership handle. It is the fire-and-forget
+// dispatch path's only way to mark a run as accepted-but-not-yet-active
+// so a cancel arriving before the run registers in activeRequests is not
+// lost.
+func (c *coordinator) BeginAccepted(sessionID string) *AcceptedRun {
+ // Forward to the current agent, which returns the handle.
+ return c.currentAgent.BeginAccepted(sessionID)
+}
+
+// Cancel forwards the cancellation request for sessionID to the current
+// agent.
func (c *coordinator) Cancel(sessionID string) {
c.currentAgent.Cancel(sessionID)
}
@@ -25,6 +25,10 @@ func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fan
return m.runFunc(ctx, call)
}
+// BeginAccepted returns a handle that records only sessionID. The handle
+// has no backing agent, so the mock keeps no accept counter for it.
+func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
+ return &AcceptedRun{sessionID: sessionID}
+}
+
func (m *mockSessionAgent) Model() Model { return m.model }
func (m *mockSessionAgent) SetModels(large, small Model) {}
func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
@@ -0,0 +1,198 @@
+package agent
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// finishStreamModel is a minimal fantasy.LanguageModel that streams a
+// single text part followed by a normal (FinishReasonStop) finish. It
+// is enough to drive sessionAgent.Run through PrepareStep and a clean
+// completion without a recorded provider cassette.
+type finishStreamModel struct {
+ // text is the content returned by Generate and emitted as the text
+ // delta by Stream.
+ text string
+}
+
+// Provider returns the fixed provider name "fake".
+func (m *finishStreamModel) Provider() string { return "fake" }
+// Model returns the fixed model name "fake-model".
+func (m *finishStreamModel) Model() string { return "fake-model" }
+
+// Generate returns a response whose only content item is m.text,
+// finished with FinishReasonStop. ctx and call are not consulted and the
+// error is always nil.
+func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
+	resp := new(fantasy.Response)
+	resp.Content = fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}
+	resp.FinishReason = fantasy.FinishReasonStop
+	return resp, nil
+}
+
+func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
+ text := m.text
+ return func(yield func(fantasy.StreamPart) bool) {
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
+ return
+ }
+ yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
+ }, nil
+}
+
+// GenerateObject is not supported by this fake; it always returns a nil
+// response and a "not implemented" error.
+func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+// StreamObject is not supported by this fake; it always returns a nil
+// stream and a "not implemented" error.
+func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+// newStreamTestAgent returns a sessionAgent built by testSessionAgent
+// over the test environment, using one finishStreamModel (text "done")
+// for both model arguments and "system" as the system prompt, together
+// with that environment.
+func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
+	t.Helper()
+	env := testEnv(t)
+	fake := &finishStreamModel{text: "done"}
+	built := testSessionAgent(env, fake, fake, "system")
+	return built.(*sessionAgent), env
+}
+
+// TestCancel_ActiveAndAcceptedFiresBothBranches sets up a session that
+// has both a registered active cancel func and an outstanding accept
+// reservation, then checks that one Cancel call invokes the cancel func
+// and also records a pending cancel for the accepted follow-up.
+func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
+	t.Parallel()
+	const sid = "sid"
+	ag, _ := newCancelTestAgent(t)
+
+	// Register a stand-in cancel func that only records being called.
+	var fired atomic.Bool
+	markFired := func() { fired.Store(true) }
+	ag.activeRequests.Set(sid, markFired)
+
+	reservation := ag.BeginAccepted(sid)
+	defer reservation.Close()
+
+	ag.Cancel(sid)
+
+	require.True(t, fired.Load(), "active cancel func must fire")
+	require.True(t, ag.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
+}
+
+// TestRun_BusyWithPendingCancelTakesCancelOnEntry checks the ordering of
+// Run's dispatch decisions for an accepted call: when the session is
+// busy and a pending cancel is recorded, Run must persist a canceled
+// turn instead of queueing the call behind the active run.
+func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newCancelTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+	sid := session.ID
+
+	// Registering a cancel func makes the session count as busy.
+	ag.activeRequests.Set(sid, func() {})
+
+	reservation := ag.BeginAccepted(sid)
+	// The cancel lands while the follow-up is accepted but not running.
+	ag.Cancel(sid)
+	require.True(t, ag.hasPendingCancel(sid))
+
+	call := SessionAgentCall{
+		SessionID: sid,
+		Prompt:    "follow-up",
+		Accepted:  reservation,
+	}
+	result, err := ag.Run(ctx, call)
+	require.NoError(t, err)
+	require.Nil(t, result)
+
+	// Canceled on entry: nothing queued, marker and reservation gone.
+	require.Equal(t, 0, ag.QueuedPrompts(sid),
+		"cancel-on-entry must not enqueue the follow-up behind the active run")
+	require.False(t, ag.hasPendingCancel(sid), "pending cancel must be consumed")
+	require.Equal(t, 0, ag.acceptedCount(sid), "accept reservation must be released")
+
+	stored, err := env.messages.List(ctx, sid)
+	require.NoError(t, err)
+	require.Len(t, stored, 2)
+	userMsg, assistantMsg := stored[0], stored[1]
+	assert.Equal(t, message.User, userMsg.Role)
+	assert.Equal(t, message.Assistant, assistantMsg.Role)
+	assert.Equal(t, message.FinishReasonCanceled, assistantMsg.FinishReason())
+}
+
+// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel queues a follow-up
+// prompt, records a pending cancel for the session, and then runs a main
+// prompt. It checks that only the main prompt produced a user message
+// and that the queue and the pending-cancel marker are both gone
+// afterwards.
+func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newStreamTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+	sid := session.ID
+
+	// Seed the state directly: one queued follow-up plus a cancel marker
+	// recorded while it waited.
+	ag.enqueueCall(SessionAgentCall{SessionID: sid, Prompt: "queued-followup"})
+	ag.pendingCancels.Set(sid, struct{}{})
+
+	result, err := ag.Run(ctx, SessionAgentCall{SessionID: sid, Prompt: "main"})
+	require.NoError(t, err)
+	require.NotNil(t, result)
+
+	// Collect the user-role messages; the skipped follow-up must not
+	// appear among them.
+	stored, err := env.messages.List(ctx, sid)
+	require.NoError(t, err)
+	var fromUser []message.Message
+	for _, msg := range stored {
+		if msg.Role != message.User {
+			continue
+		}
+		fromUser = append(fromUser, msg)
+	}
+	require.Len(t, fromUser, 1, "queued follow-up must not create a user message")
+	assert.Equal(t, "main", fromUser[0].Content().String())
+
+	// Nothing is left queued and the marker was consumed.
+	require.Equal(t, 0, ag.QueuedPrompts(sid))
+	require.False(t, ag.hasPendingCancel(sid))
+}
+
+// TestRun_NormalCompletionClearsStalePendingCancel plants a pending
+// cancel on a session with no queued or accepted work, runs a prompt to
+// normal completion, and checks that the marker is gone and the
+// assistant message finished with FinishReasonEndTurn.
+func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+	ag, env := newStreamTestAgent(t)
+
+	session, err := env.sessions.Create(ctx, "session")
+	require.NoError(t, err)
+	sid := session.ID
+
+	// Leftover marker with nothing queued and nothing accepted.
+	ag.pendingCancels.Set(sid, struct{}{})
+
+	result, err := ag.Run(ctx, SessionAgentCall{SessionID: sid, Prompt: "main"})
+	require.NoError(t, err)
+	require.NotNil(t, result)
+
+	require.False(t, ag.hasPendingCancel(sid),
+		"normal completion must clear the stale pending cancel")
+
+	stored, err := env.messages.List(ctx, sid)
+	require.NoError(t, err)
+	require.Len(t, stored, 2)
+	assistantMsg := stored[1]
+	assert.Equal(t, message.Assistant, assistantMsg.Role)
+	assert.Equal(t, message.FinishReasonEndTurn, assistantMsg.FinishReason())
+}
@@ -60,6 +60,13 @@ func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, atta
return nil, s.returnFn(ctx)
}
+// RunAccepted ignores accept and behaves exactly like Run with the same
+// remaining arguments.
+func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return s.Run(ctx, sessionID, prompt, attachments...)
+}
+
+// BeginAccepted always returns nil; this test coordinator takes no
+// accept reservations.
+func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
+ return nil
+}
func (s *runCoordinator) Cancel(string) {}
func (s *runCoordinator) CancelAll() {}
func (s *runCoordinator) IsBusy() bool { return false }
@@ -30,6 +30,14 @@ type stubCoordinator struct {
func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
return nil, nil
}
+
+// RunAccepted is a no-op stub: it ignores every argument and returns a
+// nil result with a nil error.
+func (s *stubCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return nil, nil
+}
+
+// BeginAccepted always returns nil; this stub takes no accept
+// reservations.
+func (s *stubCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
+ return nil
+}
func (s *stubCoordinator) Cancel(string) {}
func (s *stubCoordinator) CancelAll() {}
func (s *stubCoordinator) IsBusy() bool { return false }