Detailed changes
@@ -0,0 +1,231 @@
+package agent
+
+import (
+ "context"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newCancelTestAgent builds a DB-backed sessionAgent with no model. The
+// tests here exercise the dispatch/cancel/persist paths, none of which
+// reach agent.Stream, so a model is unnecessary.
+func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
+ t.Helper()
+ env := testEnv(t)
+ sa := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+ return sa, env
+}
+
+func (a *sessionAgent) acceptedCount(sessionID string) int {
+ c, _ := a.acceptedRuns.Get(sessionID)
+ return c
+}
+
+func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
+ mark, ok := a.cancelMark.Get(sessionID)
+ return ok && mark > 0
+}
+
+func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 {
+ mark, _ := a.cancelMark.Get(sessionID)
+ return mark
+}
+
+func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ require.Equal(t, "sid", accept.SessionID())
+ require.Equal(t, 1, sa.acceptedCount("sid"))
+
+ accept.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+
+ // Repeated Close must not underflow the counter.
+ accept.Close()
+ accept.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+}
+
+func TestAcceptedRun_MultipleReservations(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ a1 := sa.BeginAccepted("sid")
+ a2 := sa.BeginAccepted("sid")
+ require.Equal(t, 2, sa.acceptedCount("sid"))
+
+ a1.Close()
+ require.Equal(t, 1, sa.acceptedCount("sid"))
+
+ a2.Close()
+ require.Equal(t, 0, sa.acceptedCount("sid"))
+}
+
+func TestAcceptedRun_NilSafe(t *testing.T) {
+ t.Parallel()
+ var accept *AcceptedRun
+ require.Equal(t, "", accept.SessionID())
+ // Must not panic.
+ accept.Close()
+}
+
+func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ // No accepted run, no active request: a true no-op.
+ sa.Cancel("sid")
+ require.False(t, sa.hasPendingCancel("sid"))
+}
+
+func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+}
+
+func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+
+ // A second cancel while a pending cancel is already recorded must
+ // remain a single pending cancel; one Run consumes exactly one.
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+}
+
+func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ accept := sa.BeginAccepted(sess.ID)
+ // A cancel arrives in the accepted-but-not-yet-active window.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "hello",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.Nil(t, result)
+
+ // The pending cancel was consumed and the accept released.
+ require.False(t, sa.hasPendingCancel(sess.ID))
+ require.Equal(t, 0, sa.acceptedCount(sess.ID))
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
+
+func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "hello",
+ }, false)
+ require.NoError(t, err)
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
+
+func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Simulate PrepareStep having already created the user message.
+ _, err = sa.createUserMessage(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "hello",
+ })
+ require.NoError(t, err)
+
+ err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "hello",
+ }, true)
+ require.NoError(t, err)
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
+
+func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Simulate workspace shutdown having already canceled the run
+ // context. WithoutCancel must let the writes through.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ err = sa.persistCanceledTurn(ctx, SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "hello",
+ }, false)
+ require.NoError(t, err)
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+}
+
+func TestClearPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ accept := sa.BeginAccepted("sid")
+ defer accept.Close()
+ sa.Cancel("sid")
+ require.True(t, sa.hasPendingCancel("sid"))
+
+ sa.clearPendingCancel("sid")
+ require.False(t, sa.hasPendingCancel("sid"))
+}
@@ -21,6 +21,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
"charm.land/catwalk/pkg/catwalk"
@@ -103,10 +104,28 @@ type SessionAgentCall struct {
// recursion drains, so falling back to the default broker
// publish keeps the event visible to subscribers.
OnComplete func(notify.RunComplete)
+ // Accepted, when non-nil, is the accept reservation taken by
+ // BeginAccepted before the call was dispatched onto a goroutine
+ // (the client/server fire-and-forget path). Run consumes it under
+ // dispatchMu[SessionID] once the accepted -> (cancel-on-entry |
+ // queued | active) transition has been chosen. When nil
+ // (in-process / local callers like AppWorkspace), behavior is
+ // unchanged and no accept tracking applies.
+ Accepted *AcceptedRun
+ // acceptSeq carries the accept sequence of the handle that produced
+ // this call after it has been enqueued and its Accepted handle
+ // stripped. The queue-drain paths compare it against a session's
+ // cancel mark so a follow-up queued before a cancel is dropped while
+ // one queued after the cancel survives. 0 means untracked (an
+ // in-process enqueue with no accept reservation), which the drain
+ // paths treat as covered by any present mark, preserving the
+ // pre-sequence behavior.
+ acceptSeq uint64
}
type SessionAgent interface {
Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
+ BeginAccepted(sessionID string) *AcceptedRun
SetModels(large Model, small Model)
SetTools(tools []fantasy.AgentTool)
SetSystemPrompt(systemPrompt string)
@@ -145,6 +164,43 @@ type sessionAgent struct {
messageQueue *csync.Map[string, []SessionAgentCall]
activeRequests *csync.Map[string, context.CancelFunc]
+
+ // dispatchMu holds a per-session mutex that serializes the
+ // accepted -> (cancel-on-entry | queued | active) transition in
+ // Run against a concurrent Cancel. The lock is held only during
+ // the brief handoff (no DB or LLM I/O under the lock).
+ dispatchMu *csync.Map[string, *sync.Mutex]
+ // acceptedRuns counts dispatched-but-not-yet-active runs per
+ // session. A counter > 0 means a dispatched prompt is in flight
+ // and has not yet completed the dispatch handoff in Run. Only
+ // BeginAccepted increments it; only AcceptedRun.Close decrements
+ // it.
+ acceptedRuns *csync.Map[string, int]
+ // cancelMark records, per session, a high-water accept sequence: an
+ // accepted handle is canceled by it iff the handle's sequence is at
+ // or below the mark. Cancel raises the mark to the latest sequence
+ // assigned at cancel time, so a single Cancel covers every prompt
+ // accepted-but-not-yet-active then, while a prompt accepted later
+ // (higher sequence) is never poisoned. Absent or 0 means no pending
+ // cancel. It is only raised by Cancel when acceptedRuns > 0, so an
+ // idle Escape never records a mark.
+ cancelMark *csync.Map[string, uint64]
+ // dispatchMuCreate guards lazy creation of per-session entries in
+ // dispatchMu so two goroutines can't race to lock different mutex
+ // instances for the same session.
+ dispatchMuCreate sync.Mutex
+ // acceptedMu serializes increments/decrements of acceptedRuns and
+ // the assignment of accept sequence numbers from acceptSeqGen. It
+ // is separate from dispatchMu so AcceptedRun.Close (which may run
+ // while Run holds dispatchMu for the same session) does not
+ // deadlock by re-entering the dispatch lock.
+ acceptedMu sync.Mutex
+ // acceptSeqGen is the monotonic source of accept sequence numbers.
+ // Each BeginAccepted increments it under acceptedMu and stamps the
+ // returned handle, so sequences strictly increase in accept order
+ // across the agent. Cancel uses its current value as the per-session
+ // high-water mark.
+ acceptSeqGen uint64
}
type SessionAgentOptions struct {
@@ -180,33 +236,403 @@ func NewSessionAgent(
runComplete: opts.RunComplete,
messageQueue: csync.NewMap[string, []SessionAgentCall](),
activeRequests: csync.NewMap[string, context.CancelFunc](),
+ dispatchMu: csync.NewMap[string, *sync.Mutex](),
+ acceptedRuns: csync.NewMap[string, int](),
+ cancelMark: csync.NewMap[string, uint64](),
}
}
-func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) {
+// AcceptedRun owns exactly one accept reservation taken by
+// BeginAccepted. It is the only carrier of accept-state across the
+// backend.runAgent / Coordinator.Run / sessionAgent.Run layers: a
+// counter > 0 means a dispatched prompt is in flight and has not yet
+// completed the dispatch handoff in Run. Close is the only way to
+// release the reservation and is idempotent.
+type AcceptedRun struct {
+ agent *sessionAgent
+ sessionID string
+ // seq is the monotonic accept sequence stamped by BeginAccepted. A
+ // cancel covers this handle iff seq is at or below the session's
+ // cancel mark, so a handle accepted after a cancel (higher seq) is
+ // never poisoned by it.
+ seq uint64
+ done atomic.Bool
+}
+
+// Close decrements the accept counter for this reservation. It is safe
+// to call multiple times; only the first call has effect.
+func (r *AcceptedRun) Close() {
+ if r == nil {
+ return
+ }
+ if !r.done.CompareAndSwap(false, true) {
+ return
+ }
+ r.agent.endAccepted(r.sessionID)
+}
+
+// SessionID exposes the session this reservation is for so the run path
+// can use it without an extra parameter.
+func (r *AcceptedRun) SessionID() string {
+ if r == nil {
+ return ""
+ }
+ return r.sessionID
+}
+
+// BeginAccepted increments the accept counter for sessionID and returns
+// a handle whose Close is the only way to decrement it. It is the only
+// entry point that mutates acceptedRuns.
+func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
+ a.acceptedMu.Lock()
+ defer a.acceptedMu.Unlock()
+ count, _ := a.acceptedRuns.Get(sessionID)
+ a.acceptedRuns.Set(sessionID, count+1)
+ a.acceptSeqGen++
+ return &AcceptedRun{agent: a, sessionID: sessionID, seq: a.acceptSeqGen}
+}
+
+// endAccepted decrements the accept counter for sessionID. It is only
+// called via AcceptedRun.Close. It uses a dedicated lock (not the
+// per-session dispatch mutex) so it can run while Run holds dispatchMu
+// for the same session without deadlocking.
+//
+// When the count reaches zero the session's cancel mark is dropped: no
+// accepted handle remains for it to cover, and any handle accepted later
+// gets a strictly higher sequence that the mark would not match anyway.
+// Handles canceled on entry never reach RunComplete, so this is the only
+// place that clears the mark for an all-canceled batch. Sibling handles
+// covered by the same mark are serialized on the per-session dispatch
+// mutex and read the mark before they Close, so this never clears it out
+// from under a covered handle still waiting to enter Run.
+func (a *sessionAgent) endAccepted(sessionID string) {
+ a.acceptedMu.Lock()
+ defer a.acceptedMu.Unlock()
+ count, ok := a.acceptedRuns.Get(sessionID)
+ if !ok || count <= 1 {
+ a.acceptedRuns.Del(sessionID)
+ a.cancelMark.Del(sessionID)
+ return
+ }
+ a.acceptedRuns.Set(sessionID, count-1)
+}
+
+// sessionMu returns the per-session dispatch mutex, creating it on first
+// use. Creation is guarded so concurrent callers always observe the same
+// mutex instance for a given session.
+func (a *sessionAgent) sessionMu(sessionID string) *sync.Mutex {
+ if mu, ok := a.dispatchMu.Get(sessionID); ok {
+ return mu
+ }
+ a.dispatchMuCreate.Lock()
+ defer a.dispatchMuCreate.Unlock()
+ if mu, ok := a.dispatchMu.Get(sessionID); ok {
+ return mu
+ }
+ mu := &sync.Mutex{}
+ a.dispatchMu.Set(sessionID, mu)
+ return mu
+}
+
+// enqueueCall appends call to the session's message queue. The
+// OnComplete hook is stripped: the caller that supplied it (typically
+// coordinator.Run) has its own retry/coalesce scope that ends when it
+// returns, so by the time the queue drains nobody is left to consume the
+// buffered terminal event. The recursive Run falls back to the default
+// broker publish, which is what existing subscribers expect for queued
+// turns.
+func (a *sessionAgent) enqueueCall(call SessionAgentCall) {
+ existing, ok := a.messageQueue.Get(call.SessionID)
+ if !ok {
+ existing = []SessionAgentCall{}
+ }
+ queued := call
+ if call.Accepted != nil {
+ // Preserve the accept sequence after the handle is stripped so
+ // the queue-drain paths can tell a follow-up queued before a
+ // cancel (covered by the mark) from one queued after it.
+ queued.acceptSeq = call.Accepted.seq
+ }
+ queued.OnComplete = nil
+ queued.Accepted = nil
+ existing = append(existing, queued)
+ a.messageQueue.Set(call.SessionID, existing)
+}
+
+// drainQueueForStep partitions the session's queued calls for the current
+// streaming step under the per-session dispatch mutex so the filtering is
+// atomic against a concurrent Cancel: canceledBySeq requires the caller to
+// hold that mutex, and evaluating it here (rather than after unlocking)
+// prevents a cancel recorded between the drain and the check from being
+// observed inconsistently.
+//
+// Calls covered by a pending cancel are dropped; the dropped ones that
+// carry a RunID are returned in canceledWithRunID so the caller can
+// publish their terminal cancelled RunComplete (a caller waiting on that
+// RunID, e.g. `crush run`, would otherwise hang). Uncanceled calls without
+// a RunID are returned in fold to be folded into the active turn,
+// preserving the existing follow-up behavior. Uncanceled calls that carry
+// a RunID are left in the queue so each runs as its own turn via the
+// recursive run path and publishes its own RunComplete, giving every
+// RunID-bearing prompt an explicit lifecycle instead of being silently
+// absorbed into another turn. fold is processed by the caller without the
+// lock held.
+func (a *sessionAgent) drainQueueForStep(sessionID string) (fold, canceledWithRunID []SessionAgentCall) {
+ dispatchLock := a.sessionMu(sessionID)
+ dispatchLock.Lock()
+ defer dispatchLock.Unlock()
+ queuedCalls, _ := a.messageQueue.Get(sessionID)
+ var keep []SessionAgentCall
+ for _, queued := range queuedCalls {
+ if a.canceledBySeq(sessionID, queued.acceptSeq) {
+ if queued.RunID != "" {
+ canceledWithRunID = append(canceledWithRunID, queued)
+ }
+ continue
+ }
+ if queued.RunID != "" {
+ keep = append(keep, queued)
+ continue
+ }
+ fold = append(fold, queued)
+ }
+ if len(keep) == 0 {
+ a.messageQueue.Del(sessionID)
+ } else {
+ a.messageQueue.Set(sessionID, keep)
+ }
+ return fold, canceledWithRunID
+}
+
+// publishCanceledQueueDrops emits a terminal cancelled RunComplete for
+// every dropped queued call that carries a RunID. A queued prompt removed
+// from the queue without ever running — covered by a pending cancel, or
+// cleared by Cancel/ClearQueue — would otherwise leave a caller blocked on
+// that RunID: `crush run` ignores live message events and exits only on a
+// RunComplete whose RunID matches. Calls without a RunID had no such waiter
+// and are dropped silently as before. A detached, bounded context keeps the
+// must-deliver publish alive even when the run context that triggered the
+// drop is already canceled.
+func (a *sessionAgent) publishCanceledQueueDrops(drops []SessionAgentCall) {
+ var hasRunID bool
+ for _, d := range drops {
+ if d.RunID != "" {
+ hasRunID = true
+ break
+ }
+ }
+ if !hasRunID {
+ return
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ for _, d := range drops {
+ if d.RunID == "" {
+ continue
+ }
+ a.publishRunComplete(ctx, d, notify.RunComplete{
+ SessionID: d.SessionID,
+ RunID: d.RunID,
+ Cancelled: true,
+ })
+ }
+}
+
+// clearQueueAndNotify removes all queued prompts for the session and
+// publishes a terminal cancelled RunComplete for any that carried a RunID,
+// so callers waiting on those RunIDs (e.g. `crush run`) are not left
+// hanging when their queued prompt is discarded without running.
+func (a *sessionAgent) clearQueueAndNotify(sessionID string) {
+ queued, ok := a.messageQueue.Get(sessionID)
+ a.messageQueue.Del(sessionID)
+ if !ok {
+ return
+ }
+ a.publishCanceledQueueDrops(queued)
+}
+
+// clearPendingCancel removes any pending-cancel mark for sessionID. It
+// takes the per-session dispatch lock so it is ordered against Cancel
+// and the dispatch handoff.
+func (a *sessionAgent) clearPendingCancel(sessionID string) {
+ mu := a.sessionMu(sessionID)
+ mu.Lock()
+ defer mu.Unlock()
+ a.cancelMark.Del(sessionID)
+}
+
+// canceledBySeq reports whether an accepted handle or queued call with
+// the given accept sequence is covered by a pending cancel for the
+// session. Callers must hold the session's dispatch mutex. A tracked
+// sequence (seq > 0) is covered only when it is at or below the cancel
+// high-water mark, so a prompt accepted after the cancel (higher seq) is
+// never poisoned. An untracked sequence (seq == 0, an in-process enqueue
+// with no accept reservation) is covered whenever any mark is present,
+// preserving the pre-sequence behavior. The mark is not consumed: it
+// stays so every sibling handle it covers observes the same cancel, and
+// a later handle (higher seq) ignores it regardless.
+func (a *sessionAgent) canceledBySeq(sessionID string, seq uint64) bool {
+ mark, ok := a.cancelMark.Get(sessionID)
+ if !ok || mark == 0 {
+ return false
+ }
+ return seq == 0 || seq <= mark
+}
+
+// persistCanceledTurn writes the user/assistant records for a turn that
+// was canceled before (or just as) streaming would have produced them.
+// It creates the user message only when it was not already created by an
+// earlier createUserMessage call (userMsgCreated), then writes an
+// assistant message with FinishReasonCanceled. Both writes use
+// context.WithoutCancel(ctx) so workspace shutdown (which cancels the run
+// context) can't drop them.
+func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgentCall, userMsgCreated bool) error {
+ writeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
+ defer cancel()
+ if !userMsgCreated {
+ if _, err := a.createUserMessage(writeCtx, call); err != nil {
+ return err
+ }
+ }
+ largeModel := a.largeModel.Get()
+ assistant, err := a.messages.Create(writeCtx, call.SessionID, message.CreateMessageParams{
+ Role: message.Assistant,
+ Parts: []message.ContentPart{},
+ Model: largeModel.ModelCfg.Model,
+ Provider: largeModel.ModelCfg.Provider,
+ })
+ if err != nil {
+ return err
+ }
+ assistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
+ return a.messages.Update(writeCtx, assistant)
+}
+
+// publishRunComplete emits the authoritative terminal event for a turn.
+// It honors the per-call OnComplete hook when set (so the coordinator can
+// coalesce retries) and otherwise falls back to the RunComplete broker.
+// ctx is used only for the bounded-blocking must-deliver publish; the
+// terminal payload is supplied by the caller. This is the single emit path
+// shared by the streaming defer and the cancel-on-entry early return so a
+// caller waiting on RunComplete (e.g. `crush run` with a RunID) always
+// observes exactly one terminal event regardless of which Run branch ends
+// the turn.
+func (a *sessionAgent) publishRunComplete(ctx context.Context, call SessionAgentCall, complete notify.RunComplete) {
+ if call.OnComplete != nil {
+ call.OnComplete(complete)
+ return
+ }
+ if a.runComplete == nil {
+ return
+ }
+ a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete)
+}
+
+// ValidateCall performs the cheap structural validation that
+// sessionAgent.Run requires before a call can be dispatched: a call must
+// carry either a non-empty prompt or a text attachment, and it must name a
+// session. It is exported so callers that accept a run before dispatching it
+// (e.g. backend.SendMessage) can apply the same checks and keep the error
+// contract consistent.
+func ValidateCall(call SessionAgentCall) error {
if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
- return nil, ErrEmptyPrompt
+ return ErrEmptyPrompt
}
if call.SessionID == "" {
- return nil, ErrSessionMissing
- }
-
- // Queue the message if busy. Strip OnComplete: the caller that
- // supplied the hook (typically coordinator.Run) has its own
- // retry/coalesce scope that ends when it returns, so by the time
- // the queue drains nobody is left to consume the buffered
- // terminal event. The recursive Run will fall back to the
- // default broker publish, which is what existing subscribers
- // expect for queued turns.
- if a.IsSessionBusy(call.SessionID) {
- existing, ok := a.messageQueue.Get(call.SessionID)
- if !ok {
- existing = []SessionAgentCall{}
+ return ErrSessionMissing
+ }
+ return nil
+}
+
+func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) {
+ if err := ValidateCall(call); err != nil {
+ return nil, err
+ }
+
+ // genCtx/cancel are the run context and its cancel func. For the
+ // accepted (fire-and-forget) dispatch path they are created under
+ // dispatchMu below so a concurrent Cancel can observe the
+ // activeRequests entry before the assistant message exists. For
+ // the in-process path they stay nil here and are created later,
+ // preserving the original ordering.
+ var (
+ genCtx context.Context
+ cancel context.CancelFunc
+ activeRegistered bool
+ userMsgCreated bool
+ )
+
+ if call.Accepted != nil {
+ // Serialize the accepted -> (cancel-on-entry | queued |
+ // active) transition against a concurrent Cancel. Cancel takes
+ // the same per-session lock, so every cancel observes at least
+ // one of: a cancel mark, an activeRequests entry, or a
+ // messageQueue entry it then clears.
+ mu := a.sessionMu(call.SessionID)
+ mu.Lock()
+
+ if a.canceledBySeq(call.SessionID, call.Accepted.seq) {
+ // Cancel-on-entry: a cancel arrived while this run was
+ // dispatched but not yet active, and this handle's accept
+ // sequence is at or below the session's cancel mark. The
+ // mark is left in place so sibling handles it also covers
+ // observe the same cancel; release the accept reservation,
+ // drop the lock, and persist a canceled turn without
+ // entering Stream.
+ //
+ // This path returns before the streaming defer that
+ // publishes RunComplete is installed, so emit the terminal
+ // event explicitly. Without it, a caller waiting on
+ // RunComplete for this RunID (e.g. `crush run`, which
+ // ignores message events and blocks on RunComplete) would
+ // hang on an immediately-canceled accepted run.
+ call.Accepted.Close()
+ mu.Unlock()
+ complete := notify.RunComplete{
+ SessionID: call.SessionID,
+ RunID: call.RunID,
+ Cancelled: true,
+ }
+ if err := a.persistCanceledTurn(ctx, call, false); err != nil {
+ complete.Error = err.Error()
+ a.publishRunComplete(ctx, call, complete)
+ return nil, err
+ }
+ a.publishRunComplete(ctx, call, complete)
+ return nil, nil
}
- queued := call
- queued.OnComplete = nil
- existing = append(existing, queued)
- a.messageQueue.Set(call.SessionID, existing)
+
+ if a.IsSessionBusy(call.SessionID) {
+ // Busy: an earlier prompt is active. Queue this call and
+ // release the accept reservation. A Cancel arriving after
+ // this point sees the active entry and clears the queue.
+ a.enqueueCall(call)
+ call.Accepted.Close()
+ mu.Unlock()
+ return nil, nil
+ }
+
+ // Idle: become the active run. Register the cancel func before
+ // dropping the lock so a Cancel that arrives between here and
+ // assistant creation is not lost.
+ runCtx := context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
+ genCtx, cancel = context.WithCancel(runCtx)
+ a.activeRequests.Set(call.SessionID, cancel)
+ activeRegistered = true
+ call.Accepted.Close()
+ mu.Unlock()
+
+ defer cancel()
+ defer a.activeRequests.Del(call.SessionID)
+ } else if a.IsSessionBusy(call.SessionID) {
+ // Queue the message if busy. Strip OnComplete: the caller that
+ // supplied the hook (typically coordinator.Run) has its own
+ // retry/coalesce scope that ends when it returns, so by the time
+ // the queue drains nobody is left to consume the buffered
+ // terminal event. The recursive Run will fall back to the
+ // default broker publish, which is what existing subscribers
+ // expect for queued turns.
+ a.enqueueCall(call)
return nil, nil
}
@@ -269,15 +695,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
if err != nil {
return nil, err
}
+ userMsgCreated = true
// Add the session to the context.
ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
- genCtx, cancel := context.WithCancel(ctx)
- a.activeRequests.Set(call.SessionID, cancel)
+ // For the accepted dispatch path the run context and cancel func
+ // were already created and registered under dispatchMu above; reuse
+ // them. For the in-process path create them here, preserving the
+ // original ordering.
+ if !activeRegistered {
+ genCtx, cancel = context.WithCancel(ctx)
+ a.activeRequests.Set(call.SessionID, cancel)
- defer cancel()
- defer a.activeRequests.Del(call.SessionID)
+ defer cancel()
+ defer a.activeRequests.Del(call.SessionID)
+ }
// skipRunComplete is set just before the queued-recursion path so
// the outer Run doesn't publish a RunComplete that would race
// with — and be superseded by — the recursive call's own
@@ -305,7 +738,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// (the pubsub broker fan-in does not serialize publishes from
// different upstream brokers).
defer func() {
- if flushErr := a.messages.FlushAll(ctx); flushErr != nil {
+ // Use a context detached from the run context: workspace
+ // shutdown cancels ctx before this goroutine returns, but the
+ // buffered streaming deltas must still land before the DB is
+ // closed. A short timeout bounds the flush.
+ flushCtx, flushCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
+ defer flushCancel()
+ if flushErr := a.messages.FlushAll(flushCtx); flushErr != nil {
slog.Error("Failed to flush pending message updates after run", "error", flushErr)
}
if skipRunComplete {
@@ -329,14 +768,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// the authoritative terminal event so a momentarily-full
// subscriber channel can't silently drop it and hang
// non-interactive clients waiting on RunComplete.
- if call.OnComplete != nil {
- call.OnComplete(complete)
- return
- }
- if a.runComplete == nil {
- return
- }
- a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete)
+ a.publishRunComplete(ctx, call, complete)
}()
history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...)
@@ -371,9 +803,20 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
// Use latest tools (updated by SetTools when MCP tools change).
prepared.Tools = a.tools.Copy()
- queuedCalls, _ := a.messageQueue.Get(call.SessionID)
- a.messageQueue.Del(call.SessionID)
- for _, queued := range queuedCalls {
+ // Drain queued follow-up prompts for this step. Calls covered
+ // by a cancel recorded while they sat in the queue are dropped:
+ // a cancel that arrived after a prompt was queued must not let
+ // it run as part of this step. Coverage is per-call by accept
+ // sequence so a follow-up queued after the cancel (higher seq)
+ // is not dropped. A dropped prompt carrying a RunID still gets
+ // its terminal cancelled RunComplete so a caller waiting on it
+ // does not hang. Uncanceled prompts without a RunID are folded
+ // into this turn; uncanceled prompts with a RunID are left
+ // queued so each runs as its own turn (with its own
+ // RunComplete) via the recursive run path below.
+ fold, canceledRunIDs := a.drainQueueForStep(call.SessionID)
+ a.publishCanceledQueueDrops(canceledRunIDs)
+ for _, queued := range fold {
userMessage, createErr := a.createUserMessage(callContext, queued)
if createErr != nil {
return callContext, prepared, createErr
@@ -575,13 +1018,34 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
isHyper := largeModel.ModelCfg.Provider == hyper.Name
isCancelErr := errors.Is(err, context.Canceled)
if currentAssistant == nil {
+ // Cancel-before-assistant-creation window: the run was
+ // canceled after activeRequests.Set but before PrepareStep
+ // created the assistant message. Without this, the turn
+ // would return with no FinishReasonCanceled marker and no
+ // user-visible record. The user message was already created
+ // above, so persistCanceledTurn only writes the assistant
+ // record.
+ if isCancelErr {
+ if persistErr := a.persistCanceledTurn(ctx, call, userMsgCreated); persistErr != nil {
+ return nil, persistErr
+ }
+ }
return result, err
}
+ // Persist final state with a context detached from the run
+ // context. The run context (ctx) is derived from the
+ // workspace context, which workspace shutdown cancels before
+ // agent goroutines finish; using ctx here would drop the
+ // final assistant state. WithoutCancel keeps the values
+ // (e.g. session ID) while ignoring cancellation, and a short
+ // timeout bounds the cleanup writes.
+ cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
+ defer cleanupCancel()
// Ensure we finish thinking on error to close the reasoning state.
currentAssistant.FinishThinking()
toolCalls := currentAssistant.ToolCalls()
- // INFO: we use the parent context here because the genCtx has been cancelled.
- msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
+ // INFO: we use the cleanup context here because the genCtx has been cancelled.
+ msgs, createErr := a.messages.List(cleanupCtx, currentAssistant.SessionID)
if createErr != nil {
return nil, createErr
}
@@ -590,7 +1054,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
tc.Finished = true
tc.Input = "{}"
currentAssistant.AddToolCall(tc)
- updateErr := a.messages.Update(ctx, *currentAssistant)
+ updateErr := a.messages.Update(cleanupCtx, *currentAssistant)
if updateErr != nil {
return nil, updateErr
}
@@ -623,7 +1087,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
Content: content,
IsError: true,
}
- _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
+ _, createErr = a.messages.Create(cleanupCtx, currentAssistant.SessionID, message.CreateMessageParams{
Role: message.Tool,
Parts: []message.ContentPart{
toolResult,
@@ -662,9 +1126,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
} else {
currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
}
- // Note: we use the parent context here because the genCtx has been
+ // Note: we use the cleanup context here because the genCtx has been
// cancelled.
- updateErr := a.messages.Update(ctx, *currentAssistant)
+ updateErr := a.messages.Update(cleanupCtx, *currentAssistant)
if updateErr != nil {
return nil, updateErr
}
@@ -705,18 +1169,106 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *
})
}
- queuedMessages, ok := a.messageQueue.Get(call.SessionID)
- if !ok || len(queuedMessages) == 0 {
+ // Hand off to the next queued prompt (if any) under dispatchMu so
+ // the transition from this finished run to the queued run is atomic
+ // against a concurrent Cancel. activeRequests for this session was
+ // just deleted above, so without the lock there is a window in
+ // which the session looks idle and a cancel becomes a no-op that
+ // fails to stop the queued prompt. Holding the lock lets us observe
+ // a pending cancel recorded against the session and drop the queue
+ // instead of running it, and (for the recursion) hand a fresh
+ // accept reservation to the dequeued call so acceptedRuns stays > 0
+ // across the recursive Run's own dispatch handoff — keeping the
+ // session observable to Cancel for the entire transition and
+ // closing the dequeue -> re-register window.
+ mu := a.sessionMu(call.SessionID)
+ mu.Lock()
+ queuedMessages, _ := a.messageQueue.Get(call.SessionID)
+ if mark, ok := a.cancelMark.Get(call.SessionID); ok && mark > 0 && len(queuedMessages) > 0 {
+ // A cancel was recorded for this session (e.g. it arrived while
+ // this run was active and follow-ups had been queued). Drop the
+ // queued prompts it covers (accept sequence at or below the
+ // mark, or untracked); keep any queued after the cancel (higher
+ // sequence) so they still run.
+ var kept []SessionAgentCall
+ var canceledRunIDDrops []SessionAgentCall
+ for _, q := range queuedMessages {
+ if q.acceptSeq == 0 || q.acceptSeq <= mark {
+ if q.RunID != "" {
+ canceledRunIDDrops = append(canceledRunIDDrops, q)
+ }
+ continue
+ }
+ kept = append(kept, q)
+ }
+ queuedMessages = kept
+ a.messageQueue.Set(call.SessionID, kept)
+ // A dropped prompt carrying a RunID must still publish its
+ // terminal cancelled RunComplete so a caller waiting on that
+ // RunID does not hang.
+ a.publishCanceledQueueDrops(canceledRunIDDrops)
+ }
+ if len(queuedMessages) == 0 {
+ // No queued work. Clear the cancel mark only when no accepted
+ // run remains in flight that it might still cover; otherwise a
+ // sibling prompt (sequence at or below the mark) waiting to
+ // enter Run would lose its cancellation. When accepted runs are
+ // gone, this also clears a stale mark so it can't catch a
+ // future run.
+ a.messageQueue.Del(call.SessionID)
+ a.acceptedMu.Lock()
+ inFlight, _ := a.acceptedRuns.Get(call.SessionID)
+ a.acceptedMu.Unlock()
+ if inFlight == 0 {
+ a.cancelMark.Del(call.SessionID)
+ }
+ mu.Unlock()
return result, err
}
- // There are queued messages restart the loop. The recursive Run
- // publishes its own RunComplete for the queued prompt, so suppress
- // the outer defer's emit to avoid a duplicate event whose Error
- // field would belong to the recursive turn but whose MessageID/Text
- // would belong to the outer turn.
+ // There are queued messages, restart the loop. Suppress the outer
+ // defer's emit: it would otherwise observe the recursive Run's retErr
+ // (named-return clobbering through the return below) against this
+ // turn's MessageID/Text and publish a mixed, racing event.
skipRunComplete = true
+ // Decide whether this turn still owes its own terminal RunComplete.
+ // Each submitted prompt with a RunID has its own lifecycle, so a turn
+ // that is finished and handing off to a *different* queued prompt must
+ // publish its own RunComplete here — leaving it to the recursive turn
+ // (which carries a different RunID) would hang a caller waiting on
+ // this turn's RunID. The exception is the summarize-continuation path,
+ // which re-queues this same call (same RunID) to resume after a
+ // summary; in that case the eventual terminal turn for this RunID
+ // publishes, so publishing now would double-emit.
+ outerOwesRunComplete := call.RunID != ""
+ if outerOwesRunComplete {
+ for _, q := range queuedMessages {
+ if q.RunID == call.RunID {
+ outerOwesRunComplete = false
+ break
+ }
+ }
+ }
firstQueuedMessage := queuedMessages[0]
a.messageQueue.Set(call.SessionID, queuedMessages[1:])
+ // Reserve a fresh accept for the dequeued prompt before dropping the
+ // lock so acceptedRuns > 0 across the handoff into the recursive
+ // Run. This closes the window between this dequeue and the recursive
+ // Run registering its activeRequests entry: a cancel arriving in
+ // that window now records a pending cancel (acceptedRuns > 0) that
+ // the recursive Run's accepted path observes as cancel-on-entry.
+ firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID)
+ mu.Unlock()
+ if outerOwesRunComplete {
+ complete := notify.RunComplete{SessionID: call.SessionID, RunID: call.RunID}
+ if currentAssistant != nil {
+ complete.MessageID = currentAssistant.ID
+ complete.Text = currentAssistant.Content().String()
+ }
+ if ctx.Err() != nil {
+ complete.Cancelled = true
+ }
+ a.publishRunComplete(ctx, call, complete)
+ }
return a.Run(ctx, firstQueuedMessage)
}
@@ -1269,6 +1821,16 @@ func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message
}
func (a *sessionAgent) Cancel(sessionID string) {
+ // Serialize against the dispatch handoff in Run so the accepted ->
+ // (cancel-on-entry | queued | active) transition is atomic against
+ // this cancel. Every cancel observes at least one of: an active
+ // request, an accepted run (recorded as a pending cancel), or a
+ // queue entry it then clears. If none of those hold, an idle Escape
+ // is a true no-op and must not poison the next prompt.
+ mu := a.sessionMu(sessionID)
+ mu.Lock()
+ defer mu.Unlock()
+
// Cancel regular requests. Don't use Take() here - we need the entry to
// remain in activeRequests so IsBusy() returns true until the goroutine
// fully completes (including error handling that may access the DB).
@@ -1284,16 +1846,40 @@ func (a *sessionAgent) Cancel(sessionID string) {
cancel()
}
+ // Record a pending cancel only when a dispatched-but-not-yet-active
+ // run exists. This catches runs still in the goroutine scheduler or
+ // about to enter Run's busy-queue branch, while leaving an idle
+ // session untouched. Active and accepted are not mutually exclusive:
+ // when a run is active and a follow-up has been accepted, both the
+ // cancel above and this pending record fire.
+ //
+ // Raise the session's cancel mark to the latest accept sequence
+ // assigned so far. Every prompt currently accepted-but-not-yet-
+ // active has a sequence at or below that value, so one cancel covers
+ // all of them; a prompt accepted after this cancel gets a strictly
+ // higher sequence and is never poisoned. Using max keeps repeated
+ // cancels idempotent while the same prompts are in flight and lets a
+ // later cancel extend coverage to prompts accepted since.
+ a.acceptedMu.Lock()
+ count, ok := a.acceptedRuns.Get(sessionID)
+ mark := a.acceptSeqGen
+ a.acceptedMu.Unlock()
+ if ok && count > 0 {
+ slog.Debug("Recording cancel mark for accepted runs", "session_id", sessionID, "count", count, "mark", mark)
+ existing, _ := a.cancelMark.Get(sessionID)
+ a.cancelMark.Set(sessionID, max(existing, mark))
+ }
+
if a.QueuedPrompts(sessionID) > 0 {
slog.Debug("Clearing queued prompts", "session_id", sessionID)
- a.messageQueue.Del(sessionID)
+ a.clearQueueAndNotify(sessionID)
}
}
func (a *sessionAgent) ClearQueue(sessionID string) {
if a.QueuedPrompts(sessionID) > 0 {
slog.Debug("Clearing queued prompts", "session_id", sessionID)
- a.messageQueue.Del(sessionID)
+ a.clearQueueAndNotify(sessionID)
}
}
@@ -0,0 +1,80 @@
+// Package agenttest provides test-only constructors for wiring a real
+// production agent.Coordinator without booting a full app.App. It is
+// imported only from _test.go files (e.g. internal/backend integration
+// tests) and is never referenced by production code, so it is compiled
+// only under tests and never ships in the production binary or API.
+package agenttest
+
+import (
+ "context"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "charm.land/fantasy/providers/openaicompat"
+ "github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/session"
+)
+
+// NewCoordinator builds a real agent.Coordinator through the production
+// agent.NewCoordinator constructor so the RunAccepted / BeginAccepted /
+// run path (including UpdateModels) is the actual code under test.
+//
+// It installs a minimal config with a single openai-compatible provider
+// whose model resolves offline. run rebuilds the model on every call, so
+// the provider must construct without network I/O; the cancel-on-entry
+// path this helper is built to exercise returns before any model call,
+// so no request is ever issued. The coder agent's allowed-tools list is
+// cleared to keep tool construction cheap and free of sub-agent wiring.
+//
+// The optional coordinator dependencies (history, filetracker, LSP,
+// notify, runComplete, skills) are nil: run guards the publisher fields
+// and the cancel-on-entry path never touches the others.
+func NewCoordinator(
+ ctx context.Context,
+ workingDir string,
+ sessions session.Service,
+ messages message.Service,
+) (agent.Coordinator, error) {
+ cfg, err := config.Init(workingDir, "", false)
+ if err != nil {
+ return nil, err
+ }
+
+ const (
+ providerID = "test-openai-compat"
+ modelID = "test-model"
+ )
+ cfg.Config().Providers.Set(providerID, config.ProviderConfig{
+ ID: providerID,
+ Name: "Test",
+ Type: openaicompat.Name,
+ BaseURL: "http://127.0.0.1:0/v1",
+ APIKey: "test",
+ Models: []catwalk.Model{{ID: modelID, DefaultMaxTokens: 4096}},
+ })
+ selected := config.SelectedModel{Provider: providerID, Model: modelID}
+ cfg.Config().Models[config.SelectedModelTypeLarge] = selected
+ cfg.Config().Models[config.SelectedModelTypeSmall] = selected
+ cfg.SetupAgents()
+
+ // Keep buildTools light: no sub-agent or agentic-fetch construction.
+ coderCfg := cfg.Config().Agents[config.AgentCoder]
+ coderCfg.AllowedTools = nil
+ cfg.Config().Agents[config.AgentCoder] = coderCfg
+
+ return agent.NewCoordinator(
+ ctx,
+ cfg,
+ sessions,
+ messages,
+ permission.NewPermissionService(workingDir, true, nil),
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+}
@@ -81,6 +81,15 @@ type Coordinator interface {
// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
// SetMainAgent(string)
Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
+ // RunAccepted runs a call that was already accepted via
+ // BeginAccepted on the fire-and-forget dispatch path. The handle is
+ // the only carrier of accept-state across the backend.runAgent /
+ // Coordinator / sessionAgent.Run layers: it reaches
+ // sessionAgent.Run as SessionAgentCall.Accepted, where it is
+ // consumed under dispatchMu once the accepted -> (cancel-on-entry |
+ // queued | active) transition is chosen.
+ RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
+ BeginAccepted(sessionID string) *AcceptedRun
Cancel(sessionID string)
CancelAll()
IsSessionBusy(sessionID string) bool
@@ -179,6 +188,20 @@ func NewCoordinator(
// Run implements Coordinator.
func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return c.run(ctx, nil, sessionID, prompt, attachments...)
+}
+
+// RunAccepted implements Coordinator.
+func (c *coordinator) RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return c.run(ctx, accept, sessionID, prompt, attachments...)
+}
+
+// run is the shared implementation behind Run and RunAccepted. When
+// accept is non-nil it is threaded onto the SessionAgentCall as
+// Accepted so sessionAgent.Run can consume the accept reservation under
+// dispatchMu; when nil (the in-process/local path) no accept tracking
+// applies.
+func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
if err := c.readyWg.Wait(); err != nil {
return nil, err
}
@@ -256,6 +279,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
FrequencyPenalty: freqPenalty,
PresencePenalty: presPenalty,
OnComplete: onComplete,
+ Accepted: accept,
})
}
beforeLoaded := c.skillTracker.LoadedNames()
@@ -278,6 +302,11 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string,
if hasLatest && c.runComplete != nil {
c.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, latest)
+ // Signal to the dispatcher (backend.runAgent) that the
+ // authoritative terminal RunComplete for this run was already
+ // emitted, so it does not publish a duplicate fallback for the
+ // error it is about to receive.
+ MarkRunCompletePublished(ctx)
}
return result, originalErr
}
@@ -1016,6 +1045,15 @@ func isExactoSupported(modelID string) bool {
return slices.Contains(supportedModels, modelID)
}
+// BeginAccepted reserves an accept slot for sessionID on the active
+// agent and returns the ownership handle. It is the fire-and-forget
+// dispatch path's only way to mark a run as accepted-but-not-yet-active
+// so a cancel arriving before the run registers in activeRequests is not
+// lost.
+func (c *coordinator) BeginAccepted(sessionID string) *AcceptedRun {
+ return c.currentAgent.BeginAccepted(sessionID)
+}
+
func (c *coordinator) Cancel(sessionID string) {
c.currentAgent.Cancel(sessionID)
}
@@ -25,6 +25,10 @@ func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fan
return m.runFunc(ctx, call)
}
+func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
+ return &AcceptedRun{sessionID: sessionID}
+}
+
func (m *mockSessionAgent) Model() Model { return m.model }
func (m *mockSessionAgent) SetModels(large, small Model) {}
func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
@@ -0,0 +1,447 @@
+package agent
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent/notify"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// finishStreamModel is a minimal fantasy.LanguageModel that streams a
+// single text part followed by a normal (FinishReasonStop) finish. It
+// is enough to drive sessionAgent.Run through PrepareStep and a clean
+// completion without a recorded provider cassette.
+type finishStreamModel struct {
+ text string
+}
+
+func (m *finishStreamModel) Provider() string { return "fake" }
+func (m *finishStreamModel) Model() string { return "fake-model" }
+
+func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
+ return &fantasy.Response{
+ Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
+ FinishReason: fantasy.FinishReasonStop,
+ }, nil
+}
+
+func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
+ text := m.text
+ return func(yield func(fantasy.StreamPart) bool) {
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
+ return
+ }
+ yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
+ }, nil
+}
+
+func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
+ t.Helper()
+ env := testEnv(t)
+ model := &finishStreamModel{text: "done"}
+ sa := testSessionAgent(env, model, model, "system").(*sessionAgent)
+ return sa, env
+}
+
+// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a
+// session is actively running (activeRequests set) AND a follow-up has
+// been accepted (acceptedRuns > 0). A single Cancel must fire both: it
+// invokes the active cancel func and records a pending cancel for the
+// accepted follow-up.
+func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
+ t.Parallel()
+ sa, _ := newCancelTestAgent(t)
+
+ const sid = "sid"
+ var activeCanceled atomic.Bool
+ sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) })
+
+ accept := sa.BeginAccepted(sid)
+ defer accept.Close()
+
+ sa.Cancel(sid)
+
+ require.True(t, activeCanceled.Load(), "active cancel func must fire")
+ require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
+}
+
+// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue
+// branch consulting pendingCancels: when the session is busy AND a
+// cancel has been recorded for an accepted follow-up, Run must take the
+// cancel-on-entry path (persist a canceled turn) instead of enqueueing
+// the call behind the active run.
+func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Make the session look busy: an earlier prompt is active.
+ sa.activeRequests.Set(sess.ID, func() {})
+
+ accept := sa.BeginAccepted(sess.ID)
+ // A cancel arrives while this follow-up is accepted-but-not-active.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "follow-up",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.Nil(t, result)
+
+ // The follow-up was canceled on entry, not enqueued.
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
+ "cancel-on-entry must not enqueue the follow-up behind the active run")
+ require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed")
+ require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released")
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
+
+// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the
+// queue drain inside PrepareStep skips queued follow-up prompts when a
+// cancel has been recorded for the session: the queued prompt must not
+// be folded into the active turn as an extra user message.
+func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // A follow-up prompt sits queued for the session.
+ sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
+ // A cancel was recorded for the session while it sat in the queue.
+ sa.cancelMark.Set(sess.ID, 1)
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "main",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ // Only the main prompt produced a user message; the queued
+ // follow-up was skipped, not folded into the turn.
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ var userMsgs []message.Message
+ for _, m := range msgs {
+ if m.Role == message.User {
+ userMsgs = append(userMsgs, m)
+ }
+ }
+ require.Len(t, userMsgs, 1, "queued follow-up must not create a user message")
+ assert.Equal(t, "main", userMsgs[0].Content().String())
+
+ // The queue was drained and the pending cancel consumed.
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID))
+ require.False(t, sa.hasPendingCancel(sess.ID))
+}
+
+// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run
+// which completes normally clears any stale pending-cancel entry for the
+// session, so it cannot catch a future run.
+func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // A stale cancel mark lingers (no queued work, no accepted run).
+ sa.cancelMark.Set(sess.ID, 1)
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "main",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result)
+
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "normal completion must clear the stale pending cancel")
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
+}
+
+// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired
+// to a RunComplete broker so tests can observe the terminal event a
+// RunID-bearing caller (e.g. `crush run`) blocks on.
+func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) {
+ t.Helper()
+ env := testEnv(t)
+ broker := pubsub.NewBroker[notify.RunComplete]()
+ t.Cleanup(broker.Shutdown)
+ sa := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ RunComplete: broker,
+ }).(*sessionAgent)
+ return sa, env, broker
+}
+
+// TestRun_CancelOnEntryPublishesRunComplete covers the first review
+// finding: the cancel-on-entry path returned before the streaming defer
+// that publishes RunComplete was installed. A caller that dispatches a
+// run with a RunID and blocks on RunComplete (ignoring message events,
+// like `crush run`) would hang on an immediately-canceled accepted run.
+// The cancel-on-entry path must publish a terminal RunComplete carrying
+// the originating RunID.
+func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) {
+ t.Parallel()
+ sa, env, broker := newCancelTestAgentWithRunComplete(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(t.Context())
+ defer cancel()
+ ch := broker.Subscribe(ctx)
+
+ accept := sa.BeginAccepted(sess.ID)
+ // A cancel arrives in the accepted-but-not-yet-active window.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ RunID: "run-cancel-on-entry",
+ Prompt: "hello",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.Nil(t, result)
+
+ select {
+ case got := <-ch:
+ assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID,
+ "RunComplete must echo the originating RunID")
+ assert.Equal(t, sess.ID, got.Payload.SessionID)
+ assert.True(t, got.Payload.Cancelled,
+ "cancel-on-entry RunComplete must be marked cancelled")
+ case <-time.After(2 * time.Second):
+ t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise")
+ }
+}
+
+// TestCancel_TwoAcceptedBothObserveCancellation covers the second review
+// finding: a single cancel with two accepted-not-yet-active prompts must
+// cancel both. The cancel raises the session's high-water mark to the
+// latest accept sequence, so every prompt accepted-but-not-yet-active at
+// cancel time is covered and both take the cancel-on-entry path.
+func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) {
+ t.Parallel()
+ sa, env := newCancelTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Two prompts are accepted-but-not-yet-active for the same session.
+ accept1 := sa.BeginAccepted(sess.ID)
+ accept2 := sa.BeginAccepted(sess.ID)
+ require.Equal(t, 2, sa.acceptedCount(sess.ID))
+
+ // A single cancel arrives before either becomes active.
+ sa.Cancel(sess.ID)
+ require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID),
+ "one cancel must mark every currently-accepted prompt as canceled")
+ require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq,
+ "the mark must cover the earlier accepted prompt too")
+
+ // Both prompts enter Run; each must take cancel-on-entry, not run.
+ r1, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "first",
+ Accepted: accept1,
+ })
+ require.NoError(t, err)
+ require.Nil(t, r1)
+
+ r2, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "second",
+ Accepted: accept2,
+ })
+ require.NoError(t, err)
+ require.Nil(t, r2)
+
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "both reserved units must be consumed")
+ require.Equal(t, 0, sa.acceptedCount(sess.ID),
+ "both accept reservations must be released")
+
+ // Each canceled-on-entry turn writes a user + canceled assistant
+ // message, and neither prompt was enqueued to run normally.
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
+ "neither accepted prompt may be enqueued to run normally")
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages")
+ var canceled int
+ for _, m := range msgs {
+ if m.Role == message.Assistant {
+ assert.Equal(t, message.FinishReasonCanceled, m.FinishReason())
+ canceled++
+ }
+ }
+ require.Equal(t, 2, canceled, "both turns must finish canceled")
+}
+
+// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel
+// no-op guarantee end-to-end: an Escape on an idle session must not
+// record a pending cancel that leaks into the next accepted prompt, which
+// must run normally to completion.
+func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // Idle Escape: no accepted run, no active request.
+ sa.Cancel(sess.ID)
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "idle cancel must not record a pending cancel")
+
+ // The next accepted prompt must run normally, not cancel on entry.
+ accept := sa.BeginAccepted(sess.ID)
+ result, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "next",
+ Accepted: accept,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, result, "next prompt must run normally after an idle cancel")
+
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ assert.Equal(t, message.User, msgs[0].Role)
+ assert.Equal(t, message.Assistant, msgs[1].Role)
+ assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(),
+ "the prompt must finish normally, not canceled")
+}
+
+// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for
+// the reviewer's finding: a counted session-level pending cancel let a
+// prompt accepted after the cancel enter Run first and consume a unit
+// reserved for the earlier prompts. With a sequence high-water mark, a
+// single cancel covers exactly the prompts accepted-but-not-yet-active at
+// cancel time (A and B); a prompt accepted afterwards (C) gets a higher
+// sequence and must run normally without consuming A or B's cancellation.
+// C is run first to prove it neither cancels nor drains the mark, then A
+// and B are run and must both cancel on entry.
+func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) {
+ t.Parallel()
+ sa, env := newStreamTestAgent(t)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ // A and B are accepted-but-not-yet-active.
+ acceptA := sa.BeginAccepted(sess.ID)
+ acceptB := sa.BeginAccepted(sess.ID)
+
+ // One cancel arrives covering both A and B.
+ sa.Cancel(sess.ID)
+ require.True(t, sa.hasPendingCancel(sess.ID))
+ require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID),
+ "the mark must cover every prompt accepted before the cancel")
+
+ // C is accepted AFTER the cancel; its sequence is above the mark.
+ acceptC := sa.BeginAccepted(sess.ID)
+ require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID),
+ "a prompt accepted after the cancel must not be covered by the mark")
+
+ // Run C first. It must run normally to completion and must not
+ // consume or clear the cancellation reserved for A and B.
+ rc, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "C",
+ Accepted: acceptC,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, rc, "C was accepted after the cancel and must run normally")
+ require.True(t, sa.hasPendingCancel(sess.ID),
+ "running C must not drain the cancellation reserved for A and B")
+
+ // Now A and B run. Both were accepted before the cancel and must
+ // take the cancel-on-entry path.
+ ra, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "A",
+ Accepted: acceptA,
+ })
+ require.NoError(t, err)
+ require.Nil(t, ra, "A must cancel on entry, not run")
+
+ rb, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ Prompt: "B",
+ Accepted: acceptB,
+ })
+ require.NoError(t, err)
+ require.Nil(t, rb, "B must cancel on entry, not run")
+
+ require.False(t, sa.hasPendingCancel(sess.ID),
+ "the mark clears once all covered handles are resolved")
+ require.Equal(t, 0, sa.acceptedCount(sess.ID))
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
+ "neither A nor B may be enqueued to run normally")
+
+ // C produced a normal turn; A and B each produced a canceled turn.
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant")
+
+ var normal, canceled int
+ for _, m := range msgs {
+ if m.Role != message.Assistant {
+ continue
+ }
+ switch m.FinishReason() {
+ case message.FinishReasonEndTurn:
+ normal++
+ case message.FinishReasonCanceled:
+ canceled++
+ }
+ }
+ require.Equal(t, 1, normal, "only C finished normally")
+ require.Equal(t, 2, canceled, "both A and B finished canceled")
+}
@@ -12,6 +12,9 @@ const (
// TypeReAuthenticate indicates the agent encountered an
// authentication error and the user needs to re-authenticate.
TypeReAuthenticate Type = "re_authenticate"
+ // TypeAgentError indicates the agent's turn terminated with an
+ // error. The error text is carried in Notification.Message.
+ TypeAgentError Type = "error"
)
// Notification represents a domain event published by the agent.
@@ -20,6 +23,15 @@ type Notification struct {
SessionTitle string
Type Type
ProviderID string
+ // RunID, when non-empty, is the caller-supplied correlator
+ // (proto.AgentMessage.RunID) for the run that produced this
+ // notification. It lets observers attribute a TypeAgentError to a
+ // specific request rather than to any in-flight run on the
+ // session. Empty when no caller set one.
+ RunID string
+ // Message carries the error text for TypeAgentError. Other
+ // notification types ignore it.
+ Message string
}
// RunComplete is the authoritative end-of-run signal for a session.
@@ -0,0 +1,181 @@
+package agent
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent/notify"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/stretchr/testify/require"
+)
+
+// gatedStreamModel streams a single text part followed by a clean finish,
+// but blocks the very first Stream call until its gate is released. That
+// lets a test hold a run "active" (past PrepareStep, inside Stream) just
+// long enough to enqueue a follow-up prompt behind the busy session.
+// Subsequent Stream calls (e.g. the recursive run draining the queue)
+// proceed immediately.
+type gatedStreamModel struct {
+ text string
+ gate chan struct{}
+ entered chan struct{}
+ calls atomic.Int64
+}
+
+func (m *gatedStreamModel) Provider() string { return "fake" }
+func (m *gatedStreamModel) Model() string { return "fake-model" }
+
+func (m *gatedStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
+ return &fantasy.Response{
+ Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
+ FinishReason: fantasy.FinishReasonStop,
+ }, nil
+}
+
+func (m *gatedStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
+ if m.calls.Add(1) == 1 {
+ close(m.entered)
+ select {
+ case <-m.gate:
+ case <-ctx.Done():
+ }
+ }
+ text := m.text
+ return func(yield func(fantasy.StreamPart) bool) {
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
+ return
+ }
+ if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
+ return
+ }
+ yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
+ }, nil
+}
+
+func (m *gatedStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (m *gatedStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
+ return nil, errors.New("not implemented")
+}
+
+// TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete is the
+// end-to-end proof of fix 2: a prompt carrying a RunID that is queued
+// behind a busy session must NOT be silently folded into the active turn.
+// It runs as its own turn via the recursive run path and publishes its
+// own terminal RunComplete, so a `crush run` caller blocking on that
+// RunID does not hang. The active turn keeps its own RunComplete too.
+func TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ broker := pubsub.NewBroker[notify.RunComplete]()
+ t.Cleanup(broker.Shutdown)
+
+ large := &gatedStreamModel{
+ text: "done",
+ gate: make(chan struct{}),
+ entered: make(chan struct{}),
+ }
+ small := &finishStreamModel{text: "title"}
+
+ sa := NewSessionAgent(SessionAgentOptions{
+ LargeModel: Model{Model: large, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
+ SmallModel: Model{Model: small, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
+ IsYolo: true,
+ Sessions: env.sessions,
+ Messages: env.messages,
+ RunComplete: broker,
+ }).(*sessionAgent)
+
+ sess, err := env.sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ subCtx, subCancel := context.WithCancel(t.Context())
+ defer subCancel()
+ ch := broker.Subscribe(subCtx)
+
+ // Start the main turn; it blocks inside Stream once active.
+ mainDone := make(chan error, 1)
+ go func() {
+ _, runErr := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ RunID: "run-main",
+ Prompt: "main",
+ })
+ mainDone <- runErr
+ }()
+
+ // Wait until the main turn is active (inside Stream).
+ select {
+ case <-large.entered:
+ case <-time.After(5 * time.Second):
+ t.Fatal("main run never entered Stream")
+ }
+ require.True(t, sa.IsSessionBusy(sess.ID), "main run must be active before enqueueing the follow-up")
+
+ // Enqueue a RunID-bearing follow-up behind the busy session.
+ res, err := sa.Run(t.Context(), SessionAgentCall{
+ SessionID: sess.ID,
+ RunID: "run-follow",
+ Prompt: "follow",
+ })
+ require.NoError(t, err)
+ require.Nil(t, res, "a busy-session follow-up must enqueue and return (nil, nil)")
+ require.Equal(t, 1, sa.QueuedPrompts(sess.ID), "the follow-up must be queued, not folded")
+
+ // Release the main turn so it completes and hands off to the queue.
+ close(large.gate)
+ require.NoError(t, <-mainDone)
+
+ // Both turns must publish their own terminal RunComplete.
+ got := map[string]notify.RunComplete{}
+ deadline := time.After(5 * time.Second)
+ for len(got) < 2 {
+ select {
+ case ev := <-ch:
+ got[ev.Payload.RunID] = ev.Payload
+ case <-deadline:
+ t.Fatalf("timed out waiting for both RunCompletes; got %v", got)
+ }
+ }
+
+ main, ok := got["run-main"]
+ require.True(t, ok, "the active turn must publish its own RunComplete")
+ require.Empty(t, main.Error)
+ require.False(t, main.Cancelled)
+
+ follow, ok := got["run-follow"]
+ require.True(t, ok,
+ "the queued RunID prompt must publish its own RunComplete instead of being folded silently")
+ require.Empty(t, follow.Error)
+ require.False(t, follow.Cancelled)
+ require.Equal(t, "done", follow.Text, "the queued prompt ran as its own turn")
+
+ // Two distinct assistant turns prove the follow-up was not folded.
+ msgs, err := env.messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ var assistants, follows int
+ for _, m := range msgs {
+ switch m.Role {
+ case message.Assistant:
+ assistants++
+ case message.User:
+ if m.Content().String() == "follow" {
+ follows++
+ }
+ }
+ }
+ require.Equal(t, 2, assistants, "the active turn and the recursive turn each produce one assistant message")
+ require.Equal(t, 1, follows, "the follow-up prompt is its own user turn")
+}
@@ -58,6 +58,138 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) {
"RunComplete still correlates with the originating SendMessage")
}
+// TestDrainQueueForStep_FiltersUnderDispatchLock verifies that the queue
+// drain evaluates the per-session cancel mark while holding the dispatch
+// mutex (canceledBySeq's documented precondition). Queued calls at or
+// below the cancel high-water mark are dropped, calls queued after the
+// cancel (higher seq) are folded, untracked enqueues (seq == 0) are
+// dropped whenever any mark is present, and the queue is cleared. These
+// calls carry no RunID, so all foldable survivors are returned for
+// folding (the existing follow-up behavior).
+func TestDrainQueueForStep_FiltersUnderDispatchLock(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-session"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "below", acceptSeq: 1},
+ {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2},
+ {SessionID: sessionID, Prompt: "after", acceptSeq: 3},
+ {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0},
+ })
+ // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered.
+ a.cancelMark.Set(sessionID, 2)
+
+ fold, canceledWithRunID := a.drainQueueForStep(sessionID)
+
+ require.Len(t, fold, 1,
+ "only the follow-up queued after the cancel (seq > mark) must be folded")
+ require.Equal(t, "after", fold[0].Prompt)
+ require.Empty(t, canceledWithRunID,
+ "no dropped call carried a RunID, so none need a terminal RunComplete")
+
+ _, ok := a.messageQueue.Get(sessionID)
+ require.False(t, ok, "drain must clear the session message queue when nothing is kept")
+}
+
+// TestDrainQueueForStep_NoMarkFoldsAllNonRunID verifies that with no
+// cancel mark recorded, every queued call without a RunID is folded.
+func TestDrainQueueForStep_NoMarkFoldsAllNonRunID(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-nomark"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "a", acceptSeq: 0},
+ {SessionID: sessionID, Prompt: "b", acceptSeq: 5},
+ })
+
+ fold, canceledWithRunID := a.drainQueueForStep(sessionID)
+ require.Len(t, fold, 2, "no cancel mark means all non-RunID queued calls are folded")
+ require.Empty(t, canceledWithRunID)
+}
+
+// TestDrainQueueForStep_KeepsRunIDPromptsQueued is the core of fix 2: a
+// queued prompt that carries a RunID must NOT be folded into the active
+// turn. Folding it would silently absorb it into another turn and never
+// publish a RunComplete for its RunID, hanging a `crush run` caller that
+// blocks on that event. Such prompts are left in the queue so the
+// recursive run path gives each its own turn and its own RunComplete.
+// Non-RunID prompts are still folded.
+func TestDrainQueueForStep_KeepsRunIDPromptsQueued(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-runid"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "fold-me", acceptSeq: 1},
+ {SessionID: sessionID, RunID: "run-a", Prompt: "keep-me", acceptSeq: 2},
+ {SessionID: sessionID, RunID: "run-b", Prompt: "keep-me-too", acceptSeq: 3},
+ })
+
+ fold, canceledWithRunID := a.drainQueueForStep(sessionID)
+
+ require.Len(t, fold, 1, "only the non-RunID prompt is folded into the active turn")
+ require.Equal(t, "fold-me", fold[0].Prompt)
+ require.Empty(t, canceledWithRunID)
+
+ kept, ok := a.messageQueue.Get(sessionID)
+ require.True(t, ok, "RunID-bearing prompts must remain queued for the recursive run path")
+ require.Len(t, kept, 2)
+ require.Equal(t, "run-a", kept[0].RunID)
+ require.Equal(t, "run-b", kept[1].RunID)
+}
+
+// TestDrainQueueForStep_ReportsCanceledRunIDDrops verifies that a queued
+// prompt carrying a RunID that is dropped because a cancel covers it is
+// reported in canceledWithRunID so the caller can publish its terminal
+// cancelled RunComplete. A canceled prompt without a RunID is dropped
+// silently as before.
+func TestDrainQueueForStep_ReportsCanceledRunIDDrops(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ }).(*sessionAgent)
+
+ const sessionID = "drain-cancel-runid"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, RunID: "run-canceled", Prompt: "canceled", acceptSeq: 1},
+ {SessionID: sessionID, Prompt: "canceled-no-runid", acceptSeq: 1},
+ {SessionID: sessionID, RunID: "run-survives", Prompt: "survives", acceptSeq: 5},
+ })
+ a.cancelMark.Set(sessionID, 2)
+
+ fold, canceledWithRunID := a.drainQueueForStep(sessionID)
+
+ require.Empty(t, fold, "no uncanceled non-RunID prompts to fold")
+ require.Len(t, canceledWithRunID, 1,
+ "only the dropped RunID-bearing prompt needs a terminal RunComplete")
+ require.Equal(t, "run-canceled", canceledWithRunID[0].RunID)
+
+ kept, ok := a.messageQueue.Get(sessionID)
+ require.True(t, ok)
+ require.Len(t, kept, 1, "the uncanceled RunID prompt stays queued")
+ require.Equal(t, "run-survives", kept[0].RunID)
+}
+
// TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the
// pubsub.Publisher interface change end-to-end: a Broker is the only
// concrete Publisher implementation and must satisfy both Publish and
@@ -87,3 +219,104 @@ func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) {
t.Fatal("PublishMustDeliver did not deliver event")
}
}
+
+// requireSingleCancelledRunComplete reads exactly one RunComplete from ch,
+// asserts it is the cancelled terminal event for runID, and verifies no
+// second event arrives. This observes the published pubsub event rather
+// than internal bookkeeping, which is the contract a `crush run` caller
+// blocking on the broker actually relies on.
+func requireSingleCancelledRunComplete(t *testing.T, ch <-chan pubsub.Event[notify.RunComplete], sessionID, runID string) {
+ t.Helper()
+ select {
+ case ev := <-ch:
+ require.Equal(t, runID, ev.Payload.RunID,
+ "the published RunComplete must carry the dropped queued prompt's RunID")
+ require.Equal(t, sessionID, ev.Payload.SessionID)
+ require.True(t, ev.Payload.Cancelled,
+ "a dropped queued prompt must publish a cancelled RunComplete")
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for the cancelled RunComplete")
+ }
+ select {
+ case extra := <-ch:
+ t.Fatalf("expected exactly one RunComplete, got a second: %+v", extra.Payload)
+ case <-time.After(100 * time.Millisecond):
+ }
+}
+
+// TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete proves the
+// terminal-event behavior end-to-end: a RunID-bearing prompt sitting in
+// the queue that is canceled while queued (via the public Cancel path,
+// which routes through clearQueueAndNotify -> publishCanceledQueueDrops)
+// must emit exactly one cancelled RunComplete on the broker for its
+// RunID. A queued prompt without a RunID is dropped silently. This is the
+// coverage the earlier drain test lacked: it asserted the returned
+// bookkeeping slice, not the published event a `crush run` caller awaits.
+func TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ broker := pubsub.NewBroker[notify.RunComplete]()
+ t.Cleanup(broker.Shutdown)
+
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ RunComplete: broker,
+ }).(*sessionAgent)
+
+ subCtx, subCancel := context.WithCancel(t.Context())
+ defer subCancel()
+ ch := broker.Subscribe(subCtx)
+
+ const sessionID = "cancel-queued-runid"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, Prompt: "no-runid", acceptSeq: 1},
+ {SessionID: sessionID, RunID: "run-queued", Prompt: "queued", acceptSeq: 2},
+ })
+
+ a.Cancel(sessionID)
+
+ requireSingleCancelledRunComplete(t, ch, sessionID, "run-queued")
+
+ _, ok := a.messageQueue.Get(sessionID)
+ require.False(t, ok, "Cancel must clear the queue")
+}
+
+// TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete drives
+// the production drain sequence (drainQueueForStep then
+// publishCanceledQueueDrops, mirroring the PrepareStep handoff) and
+// asserts the dropped RunID-bearing prompt actually publishes exactly one
+// cancelled RunComplete on the broker. The companion bookkeeping test
+// covers the returned slice; this one covers the observable terminal
+// event.
+func TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete(t *testing.T) {
+ t.Parallel()
+
+ env := testEnv(t)
+ broker := pubsub.NewBroker[notify.RunComplete]()
+ t.Cleanup(broker.Shutdown)
+
+ a := NewSessionAgent(SessionAgentOptions{
+ Sessions: env.sessions,
+ Messages: env.messages,
+ RunComplete: broker,
+ }).(*sessionAgent)
+
+ subCtx, subCancel := context.WithCancel(t.Context())
+ defer subCancel()
+ ch := broker.Subscribe(subCtx)
+
+ const sessionID = "drain-drop-runid"
+ a.messageQueue.Set(sessionID, []SessionAgentCall{
+ {SessionID: sessionID, RunID: "run-dropped", Prompt: "dropped", acceptSeq: 1},
+ {SessionID: sessionID, Prompt: "dropped-no-runid", acceptSeq: 1},
+ })
+ a.cancelMark.Set(sessionID, 2)
+
+ _, canceledWithRunID := a.drainQueueForStep(sessionID)
+ require.Len(t, canceledWithRunID, 1)
+ a.publishCanceledQueueDrops(canceledWithRunID)
+
+ requireSingleCancelledRunComplete(t, ch, sessionID, "run-dropped")
+}
@@ -0,0 +1,52 @@
+package agent
+
+import (
+ "context"
+ "sync/atomic"
+)
+
+// runCompleteMarkerKey is the unexported context key carrying a
+// [runCompleteMarker] from the dispatch boundary (backend.runAgent)
+// down into the coordinator. It lets the dispatcher learn whether the
+// coordinator already published the authoritative terminal
+// notify.RunComplete for the run, so a fallback terminal event is only
+// emitted when one is actually missing (e.g. an error returned before
+// sessionAgent.Run ever executed). It avoids a breaking change to the
+// Coordinator interface.
+type runCompleteMarkerKey struct{}
+
+// runCompleteMarker records whether a terminal RunComplete has been
+// published for a run. It is shared by pointer through the context so
+// a publish deep in the call stack is observable by the dispatcher
+// after the call returns.
+type runCompleteMarker struct {
+ published atomic.Bool
+}
+
+// WithRunCompleteMarker returns ctx carrying a fresh marker the
+// coordinator can flag via [MarkRunCompletePublished] once it emits the
+// run's terminal RunComplete. Callers read the result with
+// [RunCompletePublished]. Attaching the marker is optional: code paths
+// without one simply skip the dedup signal.
+func WithRunCompleteMarker(ctx context.Context) context.Context {
+ return context.WithValue(ctx, runCompleteMarkerKey{}, &runCompleteMarker{})
+}
+
+// MarkRunCompletePublished records that the authoritative terminal
+// RunComplete has been published for the run carried by ctx. It is a
+// no-op when no marker is present (e.g. the in-process/local Run path,
+// which is not dispatched through backend.runAgent).
+func MarkRunCompletePublished(ctx context.Context) {
+ if m, ok := ctx.Value(runCompleteMarkerKey{}).(*runCompleteMarker); ok {
+ m.published.Store(true)
+ }
+}
+
+// RunCompletePublished reports whether [MarkRunCompletePublished] was
+// called on ctx's marker. It returns false when no marker is present.
+func RunCompletePublished(ctx context.Context) bool {
+ if m, ok := ctx.Value(runCompleteMarkerKey{}).(*runCompleteMarker); ok {
+ return m.published.Load()
+ }
+ return false
+}
@@ -185,6 +185,14 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] {
return app.agentNotifications
}
+// RunCompletions returns the broker for the authoritative per-run
+// terminal RunComplete events. The dispatcher (backend.runAgent) uses
+// it to emit a reliable terminal event when a run fails before the
+// coordinator could publish one of its own.
+func (app *App) RunCompletions() *pubsub.Broker[notify.RunComplete] {
+ return app.runCompletions
+}
+
// resolveSession resolves which session to use for a non-interactive run
// If continueSessionID is set, it looks up that session by ID
// If useLast is set, it returns the most recently updated top-level session
@@ -0,0 +1,131 @@
+package backend
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/agent/agenttest"
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/session"
+ "github.com/stretchr/testify/require"
+)
+
+// gatedCoordinator wraps a real agent.Coordinator and parks RunAccepted
+// before delegating to it. Every method other than RunAccepted is
+// inherited from the embedded coordinator, so BeginAccepted (called by
+// Backend.SendMessage) and RunAccepted (called by the dispatched run)
+// are the production agent.Coordinator implementations under test, not
+// stubs. The gate only delays entry into the real RunAccepted so a
+// cancel can be made to land in the accepted-but-not-yet-active window
+// deterministically: the accept handle is not consumed by
+// sessionAgent.Run until the real RunAccepted runs after the gate opens.
+type gatedCoordinator struct {
+ agent.Coordinator
+ entered chan struct{}
+ gate chan struct{}
+}
+
+func (c *gatedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ close(c.entered)
+ <-c.gate
+ return c.Coordinator.RunAccepted(ctx, accept, sessionID, prompt, attachments...)
+}
+
+// newRealCoordinator builds a production agent.Coordinator over a
+// DB-backed session/message store, wrapped in a gate. It is constructed
+// through the real agent.NewCoordinator path (via the test-only
+// agenttest helper) with an offline-resolvable model: the
+// cancel-on-entry path under test persists a canceled turn and returns
+// before any model call, so no network I/O happens.
+func newRealCoordinator(t *testing.T) (*gatedCoordinator, session.Service, message.Service) {
+ t.Helper()
+ conn, err := db.Connect(t.Context(), t.TempDir())
+ require.NoError(t, err)
+ t.Cleanup(func() { conn.Close() })
+
+ q := db.New(conn)
+ sessions := session.NewService(q, conn)
+ messages := message.NewService(q)
+
+ coord, err := agenttest.NewCoordinator(t.Context(), t.TempDir(), sessions, messages)
+ require.NoError(t, err)
+
+ return &gatedCoordinator{
+ Coordinator: coord,
+ entered: make(chan struct{}),
+ gate: make(chan struct{}),
+ }, sessions, messages
+}
+
+// TestSendMessage_AcceptedCancelRace_RealMachinery exercises the
+// 202/cancel race end-to-end through Backend.SendMessage against the
+// production agent.Coordinator (BeginAccepted + RunAccepted), not a
+// stub. It asserts that a cancel arriving after the prompt is accepted
+// but before the run becomes active is not lost: the accepted handle
+// reaches sessionAgent.Run and drives cancel-on-entry, which persists a
+// canceled turn instead of streaming.
+//
+// This test would fail if Coordinator.BeginAccepted returned nil (Cancel
+// would find no accepted run and record no mark, and the run would
+// receive a nil Accepted handle and skip cancel-on-entry) or if
+// Coordinator.RunAccepted dropped the handle on its way into
+// sessionAgent.Run (the run would likewise skip cancel-on-entry and try
+// to stream the model). In either case no FinishReasonCanceled turn
+// would be persisted.
+func TestSendMessage_AcceptedCancelRace_RealMachinery(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+
+ coord, sessions, messages := newRealCoordinator(t)
+ sess, err := sessions.Create(t.Context(), "session")
+ require.NoError(t, err)
+
+ ws := insertAgentWorkspace(t, b, coord)
+
+ require.NoError(t, b.SendMessage(ws.ID, proto.AgentMessage{SessionID: sess.ID, Prompt: "hi"}))
+
+ // Coordinator.BeginAccepted ran synchronously inside SendMessage
+ // before dispatch; the dispatched run has now entered the gate but
+ // has not yet called the real RunAccepted, so the accept handle is
+ // not yet consumed: the prompt is accepted but not active.
+ select {
+ case <-coord.entered:
+ case <-time.After(2 * time.Second):
+ t.Fatal("dispatched run never entered RunAccepted")
+ }
+
+ // A cancel arriving now lands in the accepted-but-not-yet-active
+ // window and is only recorded because BeginAccepted incremented the
+ // accept counter.
+ require.NoError(t, b.CancelSession(ws.ID, sess.ID))
+
+ // Release the gate so the real RunAccepted threads the handle into
+ // sessionAgent.Run, which drives cancel-on-entry.
+ close(coord.gate)
+
+ // The dispatched run returns nil (cancel-on-entry), so runWG drains.
+ waited := make(chan struct{})
+ go func() {
+ ws.runWG.Wait()
+ close(waited)
+ }()
+ select {
+ case <-waited:
+ case <-time.After(5 * time.Second):
+ t.Fatal("runWG.Wait did not complete after the canceled run returned")
+ }
+
+ // The accepted-but-not-yet-active cancel persisted a canceled turn
+ // rather than streaming a real response.
+ msgs, err := messages.List(t.Context(), sess.ID)
+ require.NoError(t, err)
+ require.Len(t, msgs, 2)
+ require.Equal(t, message.User, msgs[0].Role)
+ require.Equal(t, message.Assistant, msgs[1].Role)
+ require.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
+}
@@ -2,22 +2,30 @@ package backend
import (
"context"
+ "errors"
"github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/agent/notify"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/pubsub"
)
-// SendMessage sends a prompt to the agent coordinator for the given
-// workspace and session.
+// SendMessage validates and accepts a prompt for the workspace's agent,
+// then dispatches the run on a goroutine bound to the workspace context
+// and returns immediately. It does not wait for the LLM turn to
+// complete: the run's lifetime is owned by the workspace, not by the
+// caller. Errors from the dispatched run reach observers through the
+// agent event channels (a notify.TypeAgentError notification), not
+// through this return value.
//
-// When msg.RunID is non-empty it is attached to the context via
-// agent.WithRunID so the coordinator can stamp the resulting
-// SessionAgentCall (and therefore the terminal notify.RunComplete
-// event) with that correlator. This is the only way for the
-// originating client to distinguish its own turn's RunComplete from
-// any concurrent turn that finishes on the same session.
-func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error {
+// SendMessage returns synchronously when the request cannot be accepted:
+// ErrWorkspaceNotFound if the workspace is missing, ErrAgentNotInitialized
+// if its coordinator is nil, the structural validation errors from
+// agent.ValidateCall (ErrEmptyPrompt, ErrSessionMissing) when the prompt
+// or session is missing, and ErrWorkspaceClosing if the workspace is
+// being torn down.
+func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error {
ws, err := b.GetWorkspace(workspaceID)
if err != nil {
return err
@@ -27,11 +35,88 @@ func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto
return ErrAgentNotInitialized
}
+ if err := agent.ValidateCall(agent.SessionAgentCall{
+ SessionID: msg.SessionID,
+ Prompt: msg.Prompt,
+ Attachments: proto.AttachmentsToMessage(msg.Attachments),
+ }); err != nil {
+ return err
+ }
+
+ accept := ws.AgentCoordinator.BeginAccepted(msg.SessionID)
+
+ ws.runMu.Lock()
+ if ws.closing {
+ ws.runMu.Unlock()
+ accept.Close()
+ return ErrWorkspaceClosing
+ }
+ ws.runWG.Add(1)
+ ws.runMu.Unlock()
+
+ go b.runAgent(ws, msg, accept)
+ return nil
+}
+
+// runAgent executes an accepted agent run for the workspace. It owns the
+// accept reservation (releasing it on return) and the runWG ticket added
+// by SendMessage. The run is bound to the workspace context so its
+// lifetime is independent of any client's HTTP request.
+//
+// On a non-cancel error it surfaces the failure to observers via a
+// notify.TypeAgentError notification (lossy, best-effort). That alone is
+// not a reliable terminal signal: the agent-event fan-in uses lossy
+// subscribers, so a `crush run` caller blocking on its RunID could hang
+// if the event is dropped. To guarantee termination, when msg.RunID is
+// non-empty and the coordinator did not already publish the run's
+// authoritative terminal RunComplete (e.g. the error was returned before
+// sessionAgent.Run executed, such as a readyWg or UpdateModels failure),
+// runAgent emits an errored RunComplete on the must-deliver
+// runCompletions broker so the waiter observes a deterministic terminal
+// event. context.Canceled is expected (sessionAgent.Run already
+// publishes the cancelled terminal marker) and produces no error
+// terminal event.
+//
+// When msg.RunID is non-empty it is attached to the context via
+// agent.WithRunID so the coordinator can stamp the terminal
+// notify.RunComplete event with that correlator. A run-complete marker
+// is also attached so the coordinator can report whether it published
+// the terminal event, letting runAgent avoid a duplicate fallback.
+func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) {
+ defer ws.runWG.Done()
+ defer accept.Close()
+
+ ctx := ws.ctx
if msg.RunID != "" {
ctx = agent.WithRunID(ctx, msg.RunID)
}
- _, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...)
- return err
+ ctx = agent.WithRunCompleteMarker(ctx)
+
+ _, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...)
+ if err == nil || errors.Is(err, context.Canceled) {
+ return
+ }
+
+ ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{
+ SessionID: msg.SessionID,
+ RunID: msg.RunID,
+ Type: notify.TypeAgentError,
+ Message: err.Error(),
+ })
+
+ // Reliable terminal fallback. Only needed when a RunID waiter
+ // exists and the coordinator has not already emitted the run's
+ // terminal RunComplete; otherwise this would be a duplicate.
+ if msg.RunID == "" || agent.RunCompletePublished(ctx) {
+ return
+ }
+ if rc := ws.RunCompletions(); rc != nil {
+ rc.PublishMustDeliver(ctx, pubsub.UpdatedEvent, notify.RunComplete{
+ SessionID: msg.SessionID,
+ RunID: msg.RunID,
+ Error: err.Error(),
+ })
+ }
}
// GetAgentInfo returns the agent's model and busy status.
@@ -0,0 +1,162 @@
+package backend
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+// errorCoordinator is a minimal agent.Coordinator whose RunAccepted
+// returns a configurable error. When markPublished is true it stamps
+// the run-complete marker on the context before returning, simulating a
+// real coordinator that already published the run's authoritative
+// terminal RunComplete (so runAgent must not emit a duplicate fallback).
+type errorCoordinator struct {
+ err error
+ markPublished bool
+}
+
+func (c *errorCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return nil, c.err
+}
+
+func (c *errorCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ if c.markPublished {
+ agent.MarkRunCompletePublished(ctx)
+ }
+ return nil, c.err
+}
+
+func (c *errorCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil }
+func (c *errorCoordinator) Cancel(string) {}
+func (c *errorCoordinator) CancelAll() {}
+func (c *errorCoordinator) IsBusy() bool { return false }
+func (c *errorCoordinator) IsSessionBusy(string) bool { return false }
+func (c *errorCoordinator) QueuedPrompts(string) int { return 0 }
+func (c *errorCoordinator) QueuedPromptsList(string) []string { return nil }
+func (c *errorCoordinator) ClearQueue(string) {}
+func (c *errorCoordinator) Summarize(context.Context, string) error { return nil }
+func (c *errorCoordinator) Model() agent.Model { return agent.Model{} }
+func (c *errorCoordinator) UpdateModels(context.Context) error { return nil }
+
+// insertRunCompleteWorkspace installs a workspace backed by a real
+// app.App (so the runCompletions broker exists) with the given
+// coordinator and a workspace run context derived from base.
+func insertRunCompleteWorkspace(t *testing.T, b *Backend, base context.Context, coord agent.Coordinator) *Workspace {
+ t.Helper()
+ a := app.NewForTest(base)
+ a.AgentCoordinator = coord
+ t.Cleanup(a.ShutdownForTest)
+ ws := &Workspace{
+ ID: uuid.New().String(),
+ Path: t.TempDir(),
+ resolvedPath: t.TempDir(),
+ clients: make(map[string]*clientState),
+ shutdownFn: func() {},
+ }
+ ws.App = a
+ ws.ctx, ws.cancel = context.WithCancel(base)
+ b.mu.Lock()
+ b.workspaces.Set(ws.ID, ws)
+ b.pathIndex[ws.resolvedPath] = ws.ID
+ b.mu.Unlock()
+ return ws
+}
+
+// TestRunAgent_PreRunErrorPublishesTerminalRunComplete proves that an
+// error returned from RunAccepted before the coordinator could publish
+// its own terminal event (e.g. a readyWg or UpdateModels failure,
+// modeled here by a stub coordinator) still yields a reliable terminal
+// RunComplete for the run's RunID. Without it, a `crush run` caller
+// blocking on that RunID would hang because the lossy TypeAgentError
+// event is not a guaranteed terminal signal.
+func TestRunAgent_PreRunErrorPublishesTerminalRunComplete(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ runErr := errors.New("update models failed")
+ ws := insertRunCompleteWorkspace(t, b, context.Background(), &errorCoordinator{err: runErr})
+
+ subCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ ch := ws.RunCompletions().Subscribe(subCtx)
+
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+ require.NoError(t, err)
+
+ select {
+ case ev := <-ch:
+ require.Equal(t, "run-1", ev.Payload.RunID,
+ "the terminal RunComplete must carry the dispatched RunID")
+ require.Equal(t, "S1", ev.Payload.SessionID)
+ require.Equal(t, runErr.Error(), ev.Payload.Error,
+ "the fallback terminal event must be marked errored")
+ require.False(t, ev.Payload.Cancelled)
+ case <-time.After(2 * time.Second):
+ t.Fatal("no terminal RunComplete published for a pre-run error; a run waiter would hang")
+ }
+}
+
+// TestRunAgent_NoFallbackWhenCoordinatorPublished ensures the fallback
+// is suppressed when the coordinator already emitted the run's
+// authoritative terminal RunComplete, so callers never observe a
+// duplicate terminal event for the same RunID.
+func TestRunAgent_NoFallbackWhenCoordinatorPublished(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ runErr := errors.New("stream failed after publishing terminal event")
+ ws := insertRunCompleteWorkspace(t, b, context.Background(),
+ &errorCoordinator{err: runErr, markPublished: true})
+
+ subCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ ch := ws.RunCompletions().Subscribe(subCtx)
+
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+ require.NoError(t, err)
+
+ // Wait for the dispatched run goroutine to return so any publish
+ // has already happened.
+ ws.runWG.Wait()
+
+ select {
+ case ev := <-ch:
+ t.Fatalf("runAgent published a duplicate terminal RunComplete: %+v", ev.Payload)
+ case <-time.After(200 * time.Millisecond):
+ }
+}
+
+// TestRunAgent_CancellationPublishesNoErrorTerminal verifies that a
+// context.Canceled result from RunAccepted produces no errored terminal
+// RunComplete from runAgent: cancellation is sessionAgent.Run's
+// responsibility (it publishes the cancelled marker) and the dispatcher
+// must not synthesize an error terminal for it.
+func TestRunAgent_CancellationPublishesNoErrorTerminal(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ ws := insertRunCompleteWorkspace(t, b, context.Background(),
+ &errorCoordinator{err: context.Canceled})
+
+ subCtx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ ch := ws.RunCompletions().Subscribe(subCtx)
+
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+ require.NoError(t, err)
+
+ ws.runWG.Wait()
+
+ select {
+ case ev := <-ch:
+ t.Fatalf("cancellation must not publish a terminal RunComplete: %+v", ev.Payload)
+ case <-time.After(200 * time.Millisecond):
+ }
+}
@@ -0,0 +1,163 @@
+package backend
+
+import (
+ "context"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted
+// blocks until release is closed. It records that RunAccepted was
+// entered so tests can observe the dispatched goroutine. Every other
+// method returns a zero value.
+type blockingCoordinator struct {
+ entered chan struct{}
+ release chan struct{}
+ runCount atomic.Int32
+}
+
+func newBlockingCoordinator() *blockingCoordinator {
+ return &blockingCoordinator{
+ entered: make(chan struct{}, 1),
+ release: make(chan struct{}),
+ }
+}
+
+func (c *blockingCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return nil, nil
+}
+
+func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ c.runCount.Add(1)
+ select {
+ case c.entered <- struct{}{}:
+ default:
+ }
+ <-c.release
+ return nil, nil
+}
+
+func (c *blockingCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil }
+func (c *blockingCoordinator) Cancel(string) {}
+func (c *blockingCoordinator) CancelAll() {}
+func (c *blockingCoordinator) IsBusy() bool { return false }
+func (c *blockingCoordinator) IsSessionBusy(string) bool { return false }
+func (c *blockingCoordinator) QueuedPrompts(string) int { return 0 }
+func (c *blockingCoordinator) QueuedPromptsList(string) []string { return nil }
+func (c *blockingCoordinator) ClearQueue(string) {}
+func (c *blockingCoordinator) Summarize(context.Context, string) error { return nil }
+func (c *blockingCoordinator) Model() agent.Model { return agent.Model{} }
+func (c *blockingCoordinator) UpdateModels(context.Context) error { return nil }
+
+// insertAgentWorkspace installs a synthetic workspace with the given
+// coordinator (or none) and a workspace run context, mirroring the
+// fields CreateWorkspace initializes.
+func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace {
+ t.Helper()
+ ws := &Workspace{
+ ID: uuid.New().String(),
+ Path: t.TempDir(),
+ resolvedPath: t.TempDir(),
+ clients: make(map[string]*clientState),
+ shutdownFn: func() {},
+ }
+ ws.App = &app.App{AgentCoordinator: coord}
+ ws.ctx, ws.cancel = context.WithCancel(b.ctx)
+ b.mu.Lock()
+ b.workspaces.Set(ws.ID, ws)
+ b.pathIndex[ws.resolvedPath] = ws.ID
+ b.mu.Unlock()
+ return ws
+}
+
+func TestSendMessage_WorkspaceNotFound(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ err := b.SendMessage("nope", proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
+ require.ErrorIs(t, err, ErrWorkspaceNotFound)
+}
+
+func TestSendMessage_AgentNotInitialized(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ ws := insertAgentWorkspace(t, b, nil)
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
+ require.ErrorIs(t, err, ErrAgentNotInitialized)
+}
+
+func TestSendMessage_EmptyPrompt(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: ""})
+ require.ErrorIs(t, err, agent.ErrEmptyPrompt)
+}
+
+func TestSendMessage_SessionMissing(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "", Prompt: "hi"})
+ require.ErrorIs(t, err, agent.ErrSessionMissing)
+}
+
+func TestSendMessage_WorkspaceClosing(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
+ ws.runMu.Lock()
+ ws.closing = true
+ ws.runMu.Unlock()
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
+ require.ErrorIs(t, err, ErrWorkspaceClosing)
+}
+
+// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns
+// nil synchronously and dispatches a tracked goroutine: while
+// RunAccepted blocks, runWG.Wait must not complete (the ticket is
+// outstanding); after release it drains.
+func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) {
+ t.Parallel()
+ b, _ := newTestBackend(t)
+ coord := newBlockingCoordinator()
+ ws := insertAgentWorkspace(t, b, coord)
+
+ err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
+ require.NoError(t, err)
+
+ select {
+ case <-coord.entered:
+ case <-time.After(2 * time.Second):
+ t.Fatal("dispatched goroutine never entered RunAccepted")
+ }
+ require.Equal(t, int32(1), coord.runCount.Load())
+
+ waited := make(chan struct{})
+ go func() {
+ ws.runWG.Wait()
+ close(waited)
+ }()
+
+ select {
+ case <-waited:
+ t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added")
+ case <-time.After(100 * time.Millisecond):
+ }
+
+ close(coord.release)
+
+ select {
+ case <-waited:
+ case <-time.After(2 * time.Second):
+ t.Fatal("runWG.Wait did not complete after the run returned")
+ }
+}
@@ -34,6 +34,7 @@ var (
ErrUnknownCommand = errors.New("unknown command")
ErrInvalidClientID = errors.New("invalid client_id")
ErrClientNotAttached = errors.New("client not attached")
+ ErrWorkspaceClosing = errors.New("workspace closing")
)
// DefaultCreateGrace is the window in which a client must open an SSE
@@ -108,6 +109,23 @@ type Workspace struct {
// with fallback to the cleaned absolute path.
resolvedPath string
+ // ctx is the workspace-scoped run context. It is derived from
+ // the backend context in CreateWorkspace and lives for the
+ // lifetime of the workspace; cancel tears it down. Agent runs
+ // dispatched on behalf of this workspace are bound to ctx so
+ // their lifetime is owned by the workspace, not by any single
+ // client's HTTP request.
+ ctx context.Context
+ cancel context.CancelFunc
+
+ // runMu guards closing and gates dispatch of new agent runs.
+ // closing is set by Shutdown so no new runs are accepted once
+ // teardown has begun. runWG tracks dispatched agent goroutines
+ // so Shutdown can wait for them to return before app cleanup.
+ runMu sync.Mutex
+ closing bool
+ runWG sync.WaitGroup
+
// clientsMu guards clients. It is held only briefly (no IO).
clientsMu sync.Mutex
// clients tracks each client's claim on this workspace. Refcount
@@ -122,7 +140,7 @@ type Workspace struct {
}
// invokeShutdown calls the workspace shutdown hook if set, falling
-// back to the embedded [app.App.Shutdown] when not.
+// back to the workspace [Workspace.Shutdown] wrapper when not.
func (w *Workspace) invokeShutdown() {
if w.shutdownFn != nil {
w.shutdownFn()
@@ -133,6 +151,40 @@ func (w *Workspace) invokeShutdown() {
}
}
+// Shutdown tears the workspace down in an order that is safe for
+// agent runs whose lifetime is bound to the workspace context. It
+// shadows the promoted [app.App.Shutdown] so callers reaching
+// ws.Shutdown() always observe this ordering:
+//
+// 1. Mark the workspace closing so no new agent runs are accepted.
+// 2. Cancel the workspace run context so any dispatched goroutine
+// that has not yet registered its per-session cancel still
+// observes cancellation.
+// 3. Cancel active coordinator work for runs that already
+// registered their per-session cancel function.
+// 4. Wait for dispatched agent goroutines to return.
+// 5. Run the embedded [app.App.Shutdown] cleanup (DB, LSP, etc).
+//
+// CancelAll is idempotent, so the second call inside app.App.Shutdown
+// is harmless; the important guarantee is that cancel -> CancelAll ->
+// runWG.Wait completes before the embedded cleanup touches the DB.
+func (w *Workspace) Shutdown() {
+ w.runMu.Lock()
+ w.closing = true
+ w.runMu.Unlock()
+
+ if w.cancel != nil {
+ w.cancel()
+ }
+ if w.App != nil && w.AgentCoordinator != nil {
+ w.AgentCoordinator.CancelAll()
+ }
+ w.runWG.Wait()
+ if w.App != nil {
+ w.App.Shutdown()
+ }
+}
+
// New creates a new [Backend].
func New(ctx context.Context, cfg *config.ConfigStore, shutdownFn ShutdownFunc) *Backend {
return &Backend{
@@ -247,6 +299,7 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works
return nil, proto.Workspace{}, fmt.Errorf("failed to create app workspace: %w", err)
}
+ wsCtx, wsCancel := context.WithCancel(b.ctx)
ws := &Workspace{
App: appWorkspace,
ID: id,
@@ -255,6 +308,8 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works
Env: args.Env,
Skills: skillsMgr,
resolvedPath: key,
+ ctx: wsCtx,
+ cancel: wsCancel,
clients: make(map[string]*clientState),
}
@@ -1,9 +1,16 @@
package backend
+import "context"
+
// InsertWorkspaceForTest registers ws with b under its current ID and
// path. It is intended for tests in other packages that need to drive
// HTTP handlers against a synthetic workspace without booting a real
// app.App. Production code should go through CreateWorkspace.
+//
+// If the workspace has no run context yet it is derived from the
+// backend context (falling back to context.Background), mirroring the
+// initialization CreateWorkspace performs, so dispatched agent runs
+// have a non-nil ws.ctx.
func InsertWorkspaceForTest(b *Backend, ws *Workspace) {
if ws.resolvedPath == "" {
ws.resolvedPath = ws.Path
@@ -11,6 +18,13 @@ func InsertWorkspaceForTest(b *Backend, ws *Workspace) {
if ws.clients == nil {
ws.clients = make(map[string]*clientState)
}
+ if ws.ctx == nil {
+ parent := b.ctx
+ if parent == nil {
+ parent = context.Background()
+ }
+ ws.ctx, ws.cancel = context.WithCancel(parent)
+ }
b.mu.Lock()
defer b.mu.Unlock()
b.workspaces.Set(ws.ID, ws)
@@ -423,12 +423,28 @@ func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, p
return fmt.Errorf("failed to send message to agent: %w", err)
}
defer rsp.Body.Close()
- if rsp.StatusCode != http.StatusOK {
+ if rsp.StatusCode != http.StatusOK && rsp.StatusCode != http.StatusAccepted {
+ if msg := decodeErrorMessage(rsp.Body); msg != "" {
+ return fmt.Errorf("failed to send message to agent: status code %d: %s", rsp.StatusCode, msg)
+ }
return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode)
}
return nil
}
+// decodeErrorMessage attempts to decode the response body as a
+// proto.Error and returns its message. It returns an empty string
+// when the body is empty or cannot be decoded into a proto.Error
+// with a non-empty message, letting callers fall back to a
+// status-only error.
+func decodeErrorMessage(body io.Reader) string {
+ var e proto.Error
+ if err := json.NewDecoder(body).Decode(&e); err != nil {
+ return ""
+ }
+ return e.Message
+}
+
// GetAgentSessionInfo retrieves the agent session info for a workspace.
func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil)
@@ -88,6 +88,76 @@ func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) {
}
}
+func TestSendMessageAcceptsStatusAccepted(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusAccepted)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
+}
+
+func TestSendMessageAcceptsStatusOK(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello"))
+}
+
+func TestSendMessageDecodesErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ _ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"})
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 400")
+ require.Contains(t, err.Error(), "session id is required")
+}
+
+func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("not json"))
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 500")
+ require.NotContains(t, err.Error(), "not json")
+}
+
+func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer srv.Close()
+
+ c := captureClient(t, srv)
+ err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status code 500")
+}
+
func marshalSSEPayload(t *testing.T) []byte {
t.Helper()
@@ -409,11 +409,27 @@ func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) {
return true, nil
case pubsub.Event[proto.AgentEvent]:
- if e.Payload.Error != nil {
- stop()
- return true, fmt.Errorf("agent error: %w", e.Payload.Error)
+ if e.Payload.Error == nil {
+ return false, nil
}
- return false, nil
+ // Attribute the error to our run before treating it as
+ // fatal. Async errors from an unrelated workspace run share
+ // this channel, so a foreign failure must not abort us:
+ // - if the event carries a RunID, it is the authoritative
+ // correlator: it must match our run exactly, otherwise it
+ // belongs to a different request and we ignore it.
+ // - if the event carries no RunID (older server), fall back
+ // to SessionID: it must be present and match our session,
+ // otherwise we ignore it.
+ if e.Payload.RunID != "" {
+ if e.Payload.RunID != s.runID {
+ return false, nil
+ }
+ } else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID {
+ return false, nil
+ }
+ stop()
+ return true, fmt.Errorf("agent error: %w", e.Payload.Error)
}
return false, nil
}
@@ -2,6 +2,7 @@ package cmd
import (
"bytes"
+ "errors"
"testing"
"time"
@@ -307,6 +308,93 @@ func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T)
require.Equal(t, "streamed prefix final", buf.String())
}
+// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async
+// agent error carrying a non-empty RunID is fatal only when it matches
+// our run. A foreign RunID is ignored regardless of the event's
+// SessionID, because RunID is the authoritative correlator and async
+// errors share the agent event channel: without strict RunID matching
+// an unrelated workspace failure would abort our run.
+func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) {
+ t.Parallel()
+
+ // Foreign RunID with a matching session is still foreign.
+ s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}}
+ done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ SessionID: "S",
+ RunID: "run-other",
+ Error: errors.New("foreign boom"),
+ }}, nil)
+ require.NoError(t, err, "foreign RunID error must not abort our run")
+ require.False(t, done)
+
+ // Foreign RunID with a different session is ignored.
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ SessionID: "other",
+ RunID: "run-other",
+ Error: errors.New("foreign boom"),
+ }}, nil)
+ require.NoError(t, err, "foreign RunID error must not abort our run")
+ require.False(t, done)
+
+ // Foreign RunID with a missing session is ignored.
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ RunID: "run-other",
+ Error: errors.New("foreign boom"),
+ }}, nil)
+ require.NoError(t, err, "foreign RunID error must not abort our run")
+ require.False(t, done)
+
+ // Matching RunID is fatal.
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ SessionID: "S",
+ RunID: "run-mine",
+ Error: errors.New("my boom"),
+ }}, nil)
+ require.Error(t, err, "matching RunID error must be fatal")
+ require.True(t, done)
+}
+
+// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the
+// compatibility fallback: when the event carries no RunID, attribution
+// falls back to SessionID. An error for another session or with an
+// empty session is ignored, while an error for our own session is fatal
+// so a real failure is never dropped.
+func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) {
+ t.Parallel()
+
+ s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
+
+ // Empty RunID for another session is ignored.
+ done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ SessionID: "other",
+ Error: errors.New("foreign boom"),
+ }}, nil)
+ require.NoError(t, err, "error for another session must not abort our run")
+ require.False(t, done)
+
+ // Empty RunID with an empty session is ignored.
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ Error: errors.New("foreign boom"),
+ }}, nil)
+ require.NoError(t, err, "error with no session must not abort our run")
+ require.False(t, done)
+
+ // Empty RunID with a matching session is fatal.
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
+ Type: proto.AgentEventTypeError,
+ SessionID: "S",
+ Error: errors.New("my boom"),
+ }}, nil)
+ require.Error(t, err, "error for our own session must be fatal")
+ require.True(t, done)
+}
+
// TestRunStream_NoRunIDFallsBackToSessionID preserves the older
// behaviour for callers (and tests) that don't supply a RunID:
// SessionID-only matching still terminates the stream on the
@@ -31,6 +31,13 @@ type AgentEvent struct {
Message Message `json:"message"`
Error error `json:"error,omitempty"`
+ // RunID echoes the caller-supplied AgentMessage.RunID for the run
+ // that produced this event. It lets observers (notably
+ // `crush run`) attribute an error event to a specific request
+ // instead of to any in-flight run on the session. Empty when no
+ // caller set one.
+ RunID string `json:"run_id,omitempty"`
+
// When summarizing.
SessionID string `json:"session_id,omitempty"`
SessionTitle string `json:"session_title,omitempty"`
@@ -60,6 +60,13 @@ func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, atta
return nil, s.returnFn(ctx)
}
+func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return s.Run(ctx, sessionID, prompt, attachments...)
+}
+
+func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
+ return nil
+}
func (s *runCoordinator) Cancel(string) {}
func (s *runCoordinator) CancelAll() {}
func (s *runCoordinator) IsBusy() bool { return false }
@@ -115,11 +122,12 @@ func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, session
}
// TestPostAgent_ReturnsOKOnContextCanceled verifies that when another
-// client cancels the session mid-turn, the prompting client's still
-// open POST receives 200 (not 500). The agent surfaces the
-// FinishReasonCanceled marker to every SSE subscriber via the
-// assistant message; the HTTP response from the prompter should not
-// double as an error signal.
+// client cancels the session mid-turn, the prompting client's POST is
+// unaffected: SendMessage is fire-and-forget, so the handler returns
+// 200 immediately without waiting for the turn. A run that later
+// returns context.Canceled never surfaces as a 500 to the prompter;
+// the FinishReasonCanceled marker reaches SSE subscribers via the
+// assistant message instead.
func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
t.Parallel()
@@ -128,33 +136,56 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
})
c, wsID := buildAgentWorkspace(t, coord)
- done := make(chan *httptest.ResponseRecorder, 1)
- go func() {
- done <- postAgent(t, c, t.Context(), wsID, "S1")
- }()
+ // The handler returns immediately, before the dispatched run is
+ // released, because the run no longer owns the HTTP response.
+ rec := postAgent(t, c, t.Context(), wsID, "S1")
+ require.Equal(t, http.StatusAccepted, rec.Code, "fire-and-forget SendMessage must return 202 without waiting for the run")
- // Wait until Run is in flight, then release it to return
- // context.Canceled.
+ // The run is dispatched on a goroutine; let it return
+ // context.Canceled. Nothing from that path reaches the (already
+ // returned) handler.
select {
case <-coord.entered:
case <-time.After(2 * time.Second):
- t.Fatal("coordinator Run was never entered")
+ t.Fatal("dispatched run was never entered")
}
close(coord.release)
- select {
- case rec := <-done:
- require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500")
- case <-time.After(2 * time.Second):
- t.Fatal("handler did not return after coordinator returned context.Canceled")
- }
+ // Wait for the dispatched run to fully return. Backend.runAgent
+ // swallows context.Canceled, so it must not publish a
+ // notify.TypeAgentError. Publishing would dereference the synthetic
+ // workspace's nil notification broker and crash this goroutine,
+ // which is the explicit guard that a cancel produces no top-level
+ // error event.
+ require.Eventually(t, func() bool {
+ return coord.ranCount.Load() == 1
+ }, 2*time.Second, 10*time.Millisecond)
}
-// TestPostAgent_DetachesRequestContext verifies that canceling the
-// prompting client's HTTP request context does not cancel the
-// in-flight agent run. The coordinator must observe a context whose
-// Done channel never fires from the request side; only the explicit
-// cancel endpoint may end the run.
+// TestHandleError_ContextCanceledFallsThroughTo500 documents the step 8
+// cleanup: the old context.Canceled special case in handleError was
+// removed because runtime cancellation of an agent run can no longer
+// reach handleError. The agent-prompt handler returns 202 before the run
+// starts (fire-and-forget SendMessage) and Backend.runAgent swallows
+// context.Canceled. Any context.Canceled that still reaches handleError
+// is therefore an unexpected synchronous error and falls through to the
+// default 500 like any other.
+func TestHandleError_ContextCanceledFallsThroughTo500(t *testing.T) {
+ t.Parallel()
+
+ c := &controllerV1{server: &Server{}}
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil)
+
+ c.handleError(rec, req, context.Canceled)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// TestPostAgent_DetachesRequestContext verifies that the dispatched run
+// is bound to the workspace context, not the prompting client's HTTP
+// request context. Canceling the request context must neither cancel
+// the run nor be observed by the coordinator.
func TestPostAgent_DetachesRequestContext(t *testing.T) {
t.Parallel()
@@ -165,46 +196,31 @@ func TestPostAgent_DetachesRequestContext(t *testing.T) {
reqCtx, cancelReq := context.WithCancel(context.Background())
- done := make(chan *httptest.ResponseRecorder, 1)
- go func() {
- done <- postAgent(t, c, reqCtx, wsID, "S1")
- }()
+ // The handler returns immediately; the run keeps executing on its
+ // own goroutine bound to the workspace context.
+ rec := postAgent(t, c, reqCtx, wsID, "S1")
+ require.Equal(t, http.StatusAccepted, rec.Code)
- // Wait until Run is in flight, then drop the prompting client.
select {
case <-coord.entered:
case <-time.After(2 * time.Second):
- t.Fatal("coordinator Run was never entered")
+ t.Fatal("dispatched run was never entered")
}
+
+ // Drop the prompting client. This must not reach the run.
cancelReq()
- // The captured ctx must be detached: context.WithoutCancel
- // returns a ctx with Done() == nil so request cancellation cannot
- // propagate.
got := coord.capturedCtx()
require.NotNil(t, got)
- require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel")
- require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request")
-
- // Confirm Run is still running: it should not have completed
- // just because the request ctx was canceled.
- select {
- case <-done:
- t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run")
- case <-time.After(50 * time.Millisecond):
- }
+ // Compare by identity (pointer), not reflect.DeepEqual: deep
+ // comparison would traverse context internals that the runtime
+ // mutates concurrently.
+ require.False(t, got == reqCtx, "run ctx must not be the request ctx")
+ require.NoError(t, got.Err(), "run ctx must not inherit cancellation from the dropped request")
- // Release the run; the handler should now complete cleanly.
+ // Release the run so it returns cleanly.
close(coord.release)
- select {
- case rec := <-done:
- // Writing to a recorder whose request ctx was canceled
- // still works; in production the TCP write would silently
- // fail, which is fine because the run already completed and
- // SSE subscribers have the result.
- require.Equal(t, http.StatusOK, rec.Code)
- case <-time.After(2 * time.Second):
- t.Fatal("handler did not return after release")
- }
- require.Equal(t, int32(1), coord.ranCount.Load())
+ require.Eventually(t, func() bool {
+ return coord.ranCount.Load() == 1
+ }, 2*time.Second, 10*time.Millisecond)
}
@@ -0,0 +1,741 @@
+package server
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/agent"
+ "github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/backend"
+ "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/proto"
+ "github.com/charmbracelet/crush/internal/pubsub"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+// scriptedCoordinator is an agent.Coordinator stub that mimics the
+// externally-observable contract of a real run over the SSE pipeline
+// without booting a real model, database, or scheduler. It publishes a
+// user message when a run begins and an assistant message (with the
+// appropriate FinishReason) when the run ends, exactly the way the real
+// sessionAgent.Run surfaces a turn to SSE subscribers.
+//
+// A run blocks until either its per-session context is canceled (via
+// Cancel, mirroring the explicit cancel endpoint) or the test releases
+// it. On cancel it emits a FinishReasonCanceled assistant message and
+// returns context.Canceled (which backend.runAgent swallows, so no
+// AgentEvent error is published). On normal release it emits a
+// FinishReasonEndTurn assistant message and returns nil.
+//
+// The internal scheduler signal points the PLAN's e2e cases reference
+// (e.g. "before registration in activeRequests", "between
+// activeRequests.Set and assistant create") are not exposed by the
+// codebase, so this stub reproduces the documented black-box outcome by
+// controlling run timing directly through blockEntered / release.
+type scriptedCoordinator struct {
+ app *app.App
+
+ // blockEntered, when non-nil, is signaled (once) right after a run
+ // is entered and before the user message is emitted, letting a test
+ // interleave a cancel with the dispatched goroutine.
+ blockEntered chan struct{}
+
+ mu sync.Mutex
+ // cancels holds the cancel func for every in-flight run, keyed by a
+ // monotonic id so concurrent runs for the same session each get their
+ // own entry (a map keyed only by sessionID would let a second run
+ // overwrite the first's cancel func and leak it).
+ cancels map[int64]sessionCancel
+ // pendingCancels counts cancels that arrived for a session while a run
+ // was in flight; a run for that session consumes one on entry and
+ // cancels itself, modeling the cancel-on-entry path a follow-up takes.
+ pendingCancels map[string]int
+ nextRunID int64
+ // entered carries the monotonic run id assigned to each run as it is
+ // entered, so a test can correlate a later assistant message back to a
+ // specific run (run 1 vs an accepted follow-up).
+ entered chan int64
+ runStarts atomic.Int32
+
+ release chan struct{}
+}
+
+type sessionCancel struct {
+ sessionID string
+ cancel context.CancelFunc
+}
+
+func newScriptedCoordinator(a *app.App) *scriptedCoordinator {
+ return &scriptedCoordinator{
+ app: a,
+ cancels: make(map[int64]sessionCancel),
+ pendingCancels: make(map[string]int),
+ entered: make(chan int64, 8),
+ release: make(chan struct{}),
+ }
+}
+
+func (c *scriptedCoordinator) emitUser(sessionID, id string) {
+ c.app.SendEvent(pubsub.Event[message.Message]{
+ Type: pubsub.CreatedEvent,
+ Payload: message.Message{
+ ID: id,
+ SessionID: sessionID,
+ Role: message.User,
+ Parts: []message.ContentPart{message.TextContent{Text: "hi"}},
+ },
+ })
+}
+
+func (c *scriptedCoordinator) emitAssistant(sessionID, id string, reason message.FinishReason) {
+ c.app.SendEvent(pubsub.Event[message.Message]{
+ Type: pubsub.CreatedEvent,
+ Payload: message.Message{
+ ID: id,
+ SessionID: sessionID,
+ Role: message.Assistant,
+ Parts: []message.ContentPart{message.Finish{Reason: reason}},
+ },
+ })
+}
+
+func (c *scriptedCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ c.runStarts.Add(1)
+ runCtx, cancel := context.WithCancel(ctx)
+
+ c.mu.Lock()
+ id := c.nextRunID
+ c.nextRunID++
+ c.cancels[id] = sessionCancel{sessionID: sessionID, cancel: cancel}
+ // Cancel-on-entry: if a cancel for this session arrived while this
+ // run was still being dispatched (no run yet in flight to receive
+ // it), consume the pending cancel now so the run takes the canceled
+ // path instead of streaming output.
+ if c.pendingCancels[sessionID] > 0 {
+ c.pendingCancels[sessionID]--
+ cancel()
+ }
+ c.mu.Unlock()
+
+ select {
+ case c.entered <- id:
+ default:
+ }
+
+ if c.blockEntered != nil {
+ select {
+ case <-c.blockEntered:
+ case <-runCtx.Done():
+ }
+ }
+
+ defer func() {
+ c.mu.Lock()
+ delete(c.cancels, id)
+ c.mu.Unlock()
+ cancel()
+ }()
+
+ // Qualify the emitted message ids with the run id so a test can
+ // attribute an assistant message to the exact run that produced it
+ // (run 1 vs an accepted follow-up sharing the same session).
+ userID := fmt.Sprintf("u-%s-%d", sessionID, id)
+ asstID := fmt.Sprintf("a-%s-%d", sessionID, id)
+
+ c.emitUser(sessionID, userID)
+
+ // Cancellation takes priority: if the run was already canceled it
+ // must take the canceled path even when release is closed, so a
+ // canceled run never races into a normal FinishReasonEndTurn.
+ select {
+ case <-runCtx.Done():
+ c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled)
+ return nil, context.Canceled
+ default:
+ }
+
+ select {
+ case <-c.release:
+ c.emitAssistant(sessionID, asstID, message.FinishReasonEndTurn)
+ return nil, nil
+ case <-runCtx.Done():
+ c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled)
+ return nil, context.Canceled
+ }
+}
+
+func (c *scriptedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return c.Run(ctx, sessionID, prompt, attachments...)
+}
+
+func (c *scriptedCoordinator) BeginAccepted(string) *agent.AcceptedRun { return nil }
+
+func (c *scriptedCoordinator) Cancel(sessionID string) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ // Cancel every in-flight run for this session. Concurrent runs for
+ // the same session (an active run plus an accepted follow-up still
+ // dispatching) each hold their own entry, so all of them are torn
+ // down by a single per-session cancel.
+ var canceled int
+ for _, sc := range c.cancels {
+ if sc.sessionID == sessionID {
+ sc.cancel()
+ canceled++
+ }
+ }
+ // If at least one run was in flight, arm a pending cancel so a
+ // follow-up that has been accepted but not yet entered Run takes the
+ // cancel-on-entry path. With no run in flight this is a no-op,
+ // mirroring the production guarantee that an idle cancel does not arm
+ // a pending cancel against the next prompt.
+ if canceled > 0 {
+ c.pendingCancels[sessionID]++
+ }
+}
+
+func (c *scriptedCoordinator) CancelAll() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ for _, sc := range c.cancels {
+ sc.cancel()
+ }
+}
+
+func (c *scriptedCoordinator) IsBusy() bool { return false }
+func (c *scriptedCoordinator) IsSessionBusy(string) bool { return false }
+func (c *scriptedCoordinator) QueuedPrompts(string) int { return 0 }
+func (c *scriptedCoordinator) QueuedPromptsList(string) []string { return nil }
+func (c *scriptedCoordinator) ClearQueue(string) {}
+func (c *scriptedCoordinator) Summarize(context.Context, string) error { return nil }
+func (c *scriptedCoordinator) Model() agent.Model { return agent.Model{} }
+func (c *scriptedCoordinator) UpdateModels(context.Context) error { return nil }
+
+// agentE2EHarness extends the SSE harness with a scripted coordinator
+// wired into the workspace's embedded app.App, so POST /agent drives a
+// real backend.SendMessage dispatch whose emitted user/assistant
+// messages fan out over the same SSE pipeline production uses.
+type agentE2EHarness struct {
+ *e2eHarness
+ coord *scriptedCoordinator
+}
+
+func newAgentE2EHarness(t *testing.T) *agentE2EHarness {
+ t.Helper()
+
+ h := &e2eHarness{}
+
+ appCtx, cancel := context.WithCancel(context.Background())
+ a := app.NewForTest(appCtx)
+ coord := newScriptedCoordinator(a)
+ a.AgentCoordinator = coord
+ t.Cleanup(func() {
+ cancel()
+ a.ShutdownForTest()
+ })
+
+ h.installServer(t)
+
+ ws := &backend.Workspace{
+ ID: uuid.New().String(),
+ Path: t.TempDir(),
+ App: a,
+ }
+ backend.SetWorkspaceShutdownFnForTest(ws, func() {})
+ backend.InsertWorkspaceForTest(h.backend, ws)
+
+ h.workspace = ws
+ h.app = a
+ return &agentE2EHarness{e2eHarness: h, coord: coord}
+}
+
+// postAgentHTTP drives POST /v1/workspaces/{id}/agent over the harness's
+// httptest server and returns the status code.
+func (h *agentE2EHarness) postAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int {
+ t.Helper()
+ body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"})
+ require.NoError(t, err)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+ h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent", bytes.NewReader(body))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := h.httpSrv.Client().Do(req)
+ require.NoError(t, err)
+ _, _ = io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ return resp.StatusCode
+}
+
+// cancelAgentHTTP drives POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel.
+func (h *agentE2EHarness) cancelAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int {
+ t.Helper()
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+ h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent/sessions/"+sessionID+"/cancel", nil)
+ require.NoError(t, err)
+ resp, err := h.httpSrv.Client().Do(req)
+ require.NoError(t, err)
+ _, _ = io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ return resp.StatusCode
+}
+
+// waitForRunEntered blocks until a dispatched run for any session has
+// been entered by the scripted coordinator, or fails the test. It
+// returns the monotonic run id assigned to that run so a caller can
+// correlate it with a later assistant message; callers that don't need
+// the id can ignore the return value.
+func (h *agentE2EHarness) waitForRunEntered(t *testing.T) int64 {
+ t.Helper()
+ select {
+ case id := <-h.coord.entered:
+ return id
+ case <-time.After(2 * time.Second):
+ t.Fatal("dispatched run was never entered")
+ return 0
+ }
+}
+
+// finishReason extracts the assistant message's FinishReason, if any.
+func finishReason(m proto.Message) (proto.FinishReason, bool) {
+ for _, p := range m.Parts {
+ if f, ok := p.(proto.Finish); ok {
+ return f.Reason, true
+ }
+ }
+ return "", false
+}
+
+// TestE2E_CancelByOtherClientDoesNotErrorPrompter covers PLAN Tests ->
+// New end-to-end coverage item 1: a second client canceling a run does
+// not surface a server error to the prompter; the run ends with a
+// FinishReasonCanceled assistant message and no AgentEvent carries a
+// non-nil Error.
+func TestE2E_CancelByOtherClientDoesNotErrorPrompter(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ cidA := uuid.New().String()
+ cidB := uuid.New().String()
+ evcA, cancelA := h.subscribeSSE(t, ctx, h.workspace.ID, cidA)
+ t.Cleanup(cancelA)
+ evcB, cancelB := h.subscribeSSE(t, ctx, h.workspace.ID, cidB)
+ t.Cleanup(cancelB)
+ h.waitForAttached(t, 2)
+
+ const sid = "s-cancel-other"
+
+ // A posts a long-running prompt; the handler must return 202
+ // immediately (the run blocks in the coordinator).
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ h.waitForRunEntered(t)
+
+ // B cancels.
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid))
+
+ // A's SSE stream receives the FinishReasonCanceled assistant
+ // message.
+ pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second)
+ defer pickCancel()
+ got, ok := drainUntil(pickCtx, evcA, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled
+ })
+ require.True(t, ok, "client A must observe a FinishReasonCanceled assistant message")
+ require.Equal(t, sid, got.Payload.SessionID)
+
+ // No AgentEvent error reaches A (cancel is not a server error).
+ errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond)
+ defer errCancel()
+ _, gotErrA := drainUntil(errCtx, evcA, func(e pubsub.Event[proto.AgentEvent]) bool {
+ return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil
+ })
+ require.False(t, gotErrA, "cancel must not surface an AgentEvent error to the prompter")
+
+ // And no AgentEvent error reaches the canceling client B either; the
+ // PLAN requires that *no* client observes a non-nil Error.
+ errCtxB, errCancelB := context.WithTimeout(ctx, 250*time.Millisecond)
+ defer errCancelB()
+ _, gotErrB := drainUntil(errCtxB, evcB, func(e pubsub.Event[proto.AgentEvent]) bool {
+ return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil
+ })
+ require.False(t, gotErrB, "cancel must not surface an AgentEvent error to any client")
+}
+
+// TestE2E_CancelImmediatelyAfter202IsNotLost covers PLAN item 1a: a
+// cancel that races a freshly-dispatched run (before it would emit any
+// output) is not lost. The run takes the cancel-on-entry path and emits
+// a user message followed by a FinishReasonCanceled assistant message
+// rather than streaming model output.
+func TestE2E_CancelImmediatelyAfter202IsNotLost(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ // Gate the run on a signal the test controls so the cancel can be
+ // observed while the dispatched goroutine is parked at entry.
+ h.coord.blockEntered = make(chan struct{})
+
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ cid := uuid.New().String()
+ evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid)
+ t.Cleanup(cancelSSE)
+ h.waitForAttached(t, 1)
+
+ const sid = "s-race-cancel"
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ h.waitForRunEntered(t)
+
+ // Cancel while the run is still blocked at entry, then release it.
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid))
+ close(h.coord.blockEntered)
+
+ pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second)
+ defer pickCancel()
+
+ gotUser, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ return e.Payload.Role == proto.User && e.Payload.SessionID == sid
+ })
+ require.True(t, okUser, "the canceled turn must still record a user message")
+ require.Equal(t, sid, gotUser.Payload.SessionID)
+
+ gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled
+ })
+ require.True(t, okAsst, "the canceled turn must end with a FinishReasonCanceled assistant message")
+ require.Equal(t, sid, gotAsst.Payload.SessionID)
+}
+
+// TestE2E_IdleCancelDoesNotPoisonNextPrompt covers PLAN item 1b: an
+// idle cancel (no active run) must not poison the next prompt. With the
+// scripted coordinator the cancel records a pending entry only if a run
+// is in flight; an idle cancel records one, but the documented
+// guarantee is that the *next* prompt's outcome is observable. Here we
+// assert the regression-relevant external behavior: after an idle
+// cancel, a subsequent normal prompt is able to run and emit output.
+//
+// NOTE: This is a simplified version. The real "idle Escape must not
+// poison" guarantee lives inside sessionAgent.Cancel's acceptedRuns
+// gating, which is covered by the agent unit tests; the e2e stub cannot
+// distinguish "truly idle" from "accepted but not yet running" without
+// the internal acceptedRuns signal. See test summary.
+func TestE2E_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ cid := uuid.New().String()
+ evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid)
+ t.Cleanup(cancelSSE)
+ h.waitForAttached(t, 1)
+
+ const sid = "s-idle-cancel"
+
+ // Idle cancel: no run in flight. The scripted coordinator drops it
+ // (no pending cancel recorded for a session that has no run), which
+ // models the production guarantee that an idle Escape does not arm
+ // a cancel against the next prompt.
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid))
+
+ // Now a normal prompt; release it so it finishes successfully.
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ h.waitForRunEntered(t)
+ close(h.coord.release)
+
+ pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second)
+ defer pickCancel()
+ got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn
+ })
+ require.True(t, ok, "the next prompt after an idle cancel must run to FinishReasonEndTurn")
+ require.Equal(t, sid, got.Payload.SessionID)
+
+ // And it must not be marked canceled.
+ canCtx, canCancel := context.WithTimeout(ctx, 200*time.Millisecond)
+ defer canCancel()
+ _, gotCanceled := drainUntil(canCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled
+ })
+ require.False(t, gotCanceled, "an idle cancel must not produce a FinishReasonCanceled marker on the next prompt")
+}
+
+// TestE2E_CancelBetweenActiveSetAndAssistantCreate covers PLAN item 1d:
+// a cancel that arrives after the run has begun but before it would
+// create the assistant message must still produce a user message and a
+// FinishReasonCanceled assistant message, never a silent return. The
+// blockEntered gate parks the run after entry (modeling the window
+// between activeRequests.Set and assistant creation).
+func TestE2E_CancelBetweenActiveSetAndAssistantCreate(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ h.coord.blockEntered = make(chan struct{})
+
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ cid := uuid.New().String()
+ evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid)
+ t.Cleanup(cancelSSE)
+ h.waitForAttached(t, 1)
+
+ const sid = "s-mid-window"
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ h.waitForRunEntered(t)
+
+ // Cancel while parked at entry; then release so the run proceeds
+ // into its cancel branch (runCtx already canceled).
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid))
+ close(h.coord.blockEntered)
+
+ pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second)
+ defer pickCancel()
+
+ _, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ return e.Payload.Role == proto.User && e.Payload.SessionID == sid
+ })
+ require.True(t, okUser, "a user message must be recorded for the canceled turn")
+
+ gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled
+ })
+ require.True(t, okAsst, "the run must not return silently; it must emit a FinishReasonCanceled assistant message")
+ require.Equal(t, sid, gotAsst.Payload.SessionID)
+
+ // No AgentEvent error is published: a cancel in the
+ // activeRequests.Set -> assistant-create window is not a server
+ // error.
+ errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond)
+ defer errCancel()
+ _, gotErr := drainUntil(errCtx, evc, func(e pubsub.Event[proto.AgentEvent]) bool {
+ return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil
+ })
+ require.False(t, gotErr, "no AgentEvent error must be published for the canceled turn")
+}
+
+// TestE2E_PromptRequestContextDoesNotOwnRun covers PLAN item 2: the
+// prompting client's HTTP request context does not own the run. A POST
+// with a very short request-context timeout still returns 202 before
+// that context would expire, and the run keeps going (observed via SSE
+// finishing normally after release).
+func TestE2E_PromptRequestContextDoesNotOwnRun(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ streamCtx, streamCancel := context.WithCancel(t.Context())
+ t.Cleanup(streamCancel)
+
+ cid := uuid.New().String()
+ evc, cancelSSE := h.subscribeSSE(t, streamCtx, h.workspace.ID, cid)
+ t.Cleanup(cancelSSE)
+ h.waitForAttached(t, 1)
+
+ const sid = "s-short-req"
+
+ // The POST request context times out almost immediately. The
+ // handler must still return 202 (fire-and-forget) and the run must
+ // survive past the request-context deadline.
+ reqCtx, reqCancel := context.WithTimeout(t.Context(), 50*time.Millisecond)
+ defer reqCancel()
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, reqCtx, sid))
+ h.waitForRunEntered(t)
+
+ // Let the request context expire, then release the run.
+ <-reqCtx.Done()
+ close(h.coord.release)
+
+ pickCtx, pickCancel := context.WithTimeout(streamCtx, 3*time.Second)
+ defer pickCancel()
+ got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn
+ })
+ require.True(t, ok, "the run must finish normally even after the prompting request context expired")
+ require.Equal(t, sid, got.Payload.SessionID)
+}
+
+// TestE2E_AgentRunSurvivesAcrossWorkspaceClaims covers PLAN item 3: a
+// run started by client A survives A detaching as long as another
+// client (B) keeps the workspace alive; B observes the run finish via
+// SSE.
+func TestE2E_AgentRunSurvivesAcrossWorkspaceClaims(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+
+ ctxA, cancelA := context.WithCancel(t.Context())
+ ctxB, cancelB := context.WithCancel(t.Context())
+ t.Cleanup(cancelB)
+
+ cidA := uuid.New().String()
+ cidB := uuid.New().String()
+ _, killA := h.subscribeSSE(t, ctxA, h.workspace.ID, cidA)
+ t.Cleanup(killA)
+ evcB, killB := h.subscribeSSE(t, ctxB, h.workspace.ID, cidB)
+ t.Cleanup(killB)
+ h.waitForAttached(t, 2)
+
+ const sid = "s-survive"
+ // A is the poster; the run must outlive A detaching as long as B
+ // keeps the workspace alive.
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctxA, sid))
+ h.waitForRunEntered(t)
+
+ // A detaches; B is still attached so the workspace stays alive.
+ cancelA()
+ killA()
+ require.Eventually(t, func() bool {
+ return backend.WorkspaceLiveStreamCountForTest(h.workspace) == 1
+ }, 3*time.Second, 10*time.Millisecond,
+ "A detaching must leave B as the sole attached client")
+ require.False(t, h.shutdownHit.Load(), "workspace must stay alive while B is attached")
+
+ // Release the run; B must still observe it finish.
+ close(h.coord.release)
+ pickCtx, pickCancel := context.WithTimeout(ctxB, 3*time.Second)
+ defer pickCancel()
+ got, ok := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn
+ })
+ require.True(t, ok, "B must observe the run finish after A detaches")
+ require.Equal(t, sid, got.Payload.SessionID)
+}
+
+// TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp covers PLAN item
+// 1c at the externally-observable level: while session sid has an active
+// run, a second prompt for sid is accepted; a cancel for sid must cancel
+// the active run and must not let the follow-up stream a normal
+// FinishReasonEndTurn.
+//
+// The sequence follows the PLAN exactly: prompt 1 becomes the active
+// run, prompt 2 for the same sid is accepted, then a cancel for sid
+// fires, and only afterwards are any signals released. The scripted
+// coordinator models the externally-observable contract of the
+// busy-queue branch and pendingCancels (which depend on internal
+// scheduler signals the codebase does not expose): a per-session cancel
+// tears down every in-flight run for sid and arms a cancel-on-entry for
+// a follow-up still dispatching. The invariant asserted is the one that
+// matters: after the cancel, the active run ends canceled and the
+// follow-up never streams a normal FinishReasonEndTurn.
+func TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp(t *testing.T) {
+ t.Parallel()
+ h := newAgentE2EHarness(t)
+ ctx, cancel := context.WithCancel(t.Context())
+ t.Cleanup(cancel)
+
+ cid := uuid.New().String()
+ evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid)
+ t.Cleanup(cancelSSE)
+ h.waitForAttached(t, 1)
+
+ const sid = "s-followup"
+
+ // (a) Prompt 1 for sid becomes the active run. Capture its run id so
+ // the canceled assistant message below can be attributed to run 1
+ // unambiguously.
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ run1 := h.waitForRunEntered(t)
+
+ // (b) Prompt 2 for the *same* sid is accepted while the active run
+ // is still in flight; it is the follow-up the PLAN describes
+ // (acceptedRuns > 0, either still dispatching or about to enter the
+ // busy-queue branch).
+ require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid))
+ run2 := h.waitForRunEntered(t)
+ require.NotEqual(t, run1, run2, "the follow-up must be a distinct run from the active one")
+
+ // (c) B cancels sid. This tears down every in-flight run for the
+ // session and arms a pending cancel for any follow-up that has not
+ // yet entered Run.
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid))
+
+ // (d) Open the coordinator gate so any run that is NOT canceled would
+ // be free to proceed straight into the normal FinishReasonEndTurn
+ // branch. The scripted Run checks runCtx.Done() before the release
+ // select, so a canceled run still takes the canceled path even with
+ // release closed; only a non-canceled run reaches FinishReasonEndTurn.
+ // Releasing here is therefore what makes the assertions below
+ // meaningful: if the cancel had failed to tear down run 1 or arm the
+ // cancel-on-entry for the follow-up, the freed gate would let that run
+ // stream a normal FinishReasonEndTurn and the test would fail.
+ close(h.coord.release)
+
+ pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second)
+ defer pickCancel()
+
+ // (e) Run 1 (the active run) must end with FinishReasonCanceled. The
+ // assistant message id is qualified with the run id, so matching on
+ // run1's id proves the cancellation is attributed to the FIRST run
+ // and not to the follow-up.
+ //
+ // The single drain below is also the negative assertion for run 2:
+ // the match closure inspects every assistant event for sid as it
+ // scans, and if it ever observes the follow-up (run 2) streaming a
+ // normal FinishReasonEndTurn it records that violation immediately.
+ // This is what makes the run-2 check sound: a previous two-phase
+ // approach could let this very drain consume and discard a run-2
+ // EndTurn while still hunting for run 1's canceled message, leaving a
+ // later no-EndTurn check unable to prove run 2 stayed canceled.
+ // Folding the negative check into the same scan means a run-2 EndTurn
+ // can never slip past unobserved, whether it arrives before or after
+ // run 1's canceled message.
+ run1AsstID := fmt.Sprintf("a-%s-%d", sid, run1)
+ run2AsstID := fmt.Sprintf("a-%s-%d", sid, run2)
+ var followUpEndTurn bool
+ got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ if e.Payload.SessionID != sid || e.Payload.Role != proto.Assistant {
+ return false
+ }
+ r, has := finishReason(e.Payload)
+ if !has {
+ return false
+ }
+ // Any normal model output for sid after the cancel is a
+ // violation. The follow-up (run 2) must never reach the
+ // FinishReasonEndTurn branch; flag it the moment it is seen so
+ // the assertion below fails even if this event arrives while we
+ // are still waiting for run 1's canceled message.
+ if r == proto.FinishReasonEndTurn {
+ if e.Payload.ID == run2AsstID || e.Payload.ID != run1AsstID {
+ followUpEndTurn = true
+ }
+ // Stop draining; the EndTurn observation is decisive and the
+ // require.False below will surface the failure.
+ return true
+ }
+ return e.Payload.ID == run1AsstID && r == proto.FinishReasonCanceled
+ })
+ require.False(t, followUpEndTurn, "the accepted follow-up must not stream a normal FinishReasonEndTurn after the cancel")
+ require.True(t, ok, "the first (active) run must end with FinishReasonCanceled")
+ require.Equal(t, run1AsstID, got.Payload.ID, "the canceled message must belong to the first (active) run")
+ gotReason, gotHas := finishReason(got.Payload)
+ require.True(t, gotHas)
+ require.Equal(t, proto.FinishReasonCanceled, gotReason, "the matched run-1 message must be canceled, not a normal end turn")
+ require.Equal(t, sid, got.Payload.SessionID)
+
+ // Confirm no normal FinishReasonEndTurn for sid is still in flight.
+ // By this point the scan above has already ruled out a run-2 EndTurn
+ // arriving before run 1's canceled message; this guards against one
+ // arriving afterward.
+ endCtx, endCancel := context.WithTimeout(ctx, 300*time.Millisecond)
+ defer endCancel()
+ _, gotEnd := drainUntil(endCtx, evc, func(e pubsub.Event[proto.Message]) bool {
+ r, has := finishReason(e.Payload)
+ return e.Payload.SessionID == sid && e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn
+ })
+ require.False(t, gotEnd, "the accepted follow-up must not stream model output after the cancel")
+}
@@ -240,6 +240,18 @@ func decodeSSEEnvelope(p pubsub.Payload) (any, bool) {
return nil, false
}
return e, true
+ case pubsub.PayloadTypeAgentEvent:
+ var e pubsub.Event[proto.AgentEvent]
+ if err := json.Unmarshal(p.Payload, &e); err != nil {
+ return nil, false
+ }
+ return e, true
+ case pubsub.PayloadTypeRunComplete:
+ var e pubsub.Event[proto.RunComplete]
+ if err := json.Unmarshal(p.Payload, &e); err != nil {
+ return nil, false
+ }
+ return e, true
}
return nil, false
}
@@ -2,6 +2,7 @@ package server
import (
"encoding/json"
+ "errors"
"fmt"
"log/slog"
@@ -85,13 +86,19 @@ func wrapEvent(ev any) *pubsub.Payload {
Payload: fileToProto(e.Payload),
})
case pubsub.Event[notify.Notification]:
+ payload := proto.AgentEvent{
+ SessionID: e.Payload.SessionID,
+ SessionTitle: e.Payload.SessionTitle,
+ RunID: e.Payload.RunID,
+ Type: proto.AgentEventType(e.Payload.Type),
+ }
+ if e.Payload.Type == notify.TypeAgentError {
+ payload.Type = proto.AgentEventTypeError
+ payload.Error = errors.New(e.Payload.Message)
+ }
return envelope(pubsub.PayloadTypeAgentEvent, pubsub.Event[proto.AgentEvent]{
- Type: e.Type,
- Payload: proto.AgentEvent{
- SessionID: e.Payload.SessionID,
- SessionTitle: e.Payload.SessionTitle,
- Type: proto.AgentEventType(e.Payload.Type),
- },
+ Type: e.Type,
+ Payload: payload,
})
case pubsub.Event[notify.RunComplete]:
return envelope(pubsub.PayloadTypeRunComplete, pubsub.Event[proto.RunComplete]{
@@ -123,6 +123,38 @@ func TestRunCompleteToProto_RoundTrip(t *testing.T) {
require.False(t, decoded.Payload.Cancelled)
}
+// TestAgentErrorToProto_PreservesRunID verifies that an async agent
+// error notification carries its originating RunID (and SessionID)
+// through the SSE envelope. Without these correlators, `crush run`
+// cannot tell whether an error event belongs to its own run and
+// would abort on any unrelated workspace failure.
+func TestAgentErrorToProto_PreservesRunID(t *testing.T) {
+ t.Parallel()
+
+ src := pubsub.Event[notify.Notification]{
+ Type: pubsub.CreatedEvent,
+ Payload: notify.Notification{
+ SessionID: "S",
+ RunID: "run-99",
+ Type: notify.TypeAgentError,
+ Message: "boom",
+ },
+ }
+
+ env := wrapEvent(src)
+ require.NotNil(t, env)
+ require.Equal(t, pubsub.PayloadTypeAgentEvent, env.Type)
+
+ var decoded pubsub.Event[proto.AgentEvent]
+ require.NoError(t, json.Unmarshal(env.Payload, &decoded))
+ require.Equal(t, proto.AgentEventTypeError, decoded.Payload.Type)
+ require.Equal(t, "S", decoded.Payload.SessionID)
+ require.Equal(t, "run-99", decoded.Payload.RunID,
+ "RunID must survive so observers can attribute the error to its run")
+ require.NotNil(t, decoded.Payload.Error)
+ require.Equal(t, "boom", decoded.Payload.Error.Error())
+}
+
// TestRunCompleteToProto_Error verifies that error- and cancel-shaped
// RunComplete events round-trip cleanly so clients can distinguish
// "agent failed" (returns non-zero from `crush run`) from "agent
@@ -1,7 +1,6 @@
package server
import (
- "context"
"encoding/json"
"errors"
"fmt"
@@ -740,9 +739,10 @@ func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Re
// @Accept json
// @Param id path string true "Workspace ID"
// @Param request body proto.AgentMessage true "Agent message"
-// @Success 200
+// @Success 202
// @Failure 400 {object} proto.Error
// @Failure 404 {object} proto.Error
+// @Failure 409 {object} proto.Error
// @Failure 500 {object} proto.Error
// @Router /workspaces/{id}/agent [post]
func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.Request) {
@@ -755,18 +755,19 @@ func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.R
return
}
- // Detach the run's lifetime from the prompting client's HTTP
- // request. Without this, A dropping its TCP connection (network
- // blip, TUI restart) or B canceling the session via the explicit
- // cancel endpoint would also cancel A's request context and tear
- // down a turn that other subscribed clients are still watching.
- // Only the explicit cancel endpoint should be able to end a run.
- ctx := context.WithoutCancel(r.Context())
- if err := c.backend.SendMessage(ctx, id, msg); err != nil {
+ // The run's lifetime is detached from the prompting client's HTTP
+ // request: SendMessage validates and accepts the prompt, dispatches
+ // the run on a goroutine bound to the workspace context, and returns
+ // immediately. A dropping its TCP connection (network blip, TUI
+ // restart) or B canceling the session via the explicit cancel
+ // endpoint can no longer tear down a turn that other subscribed
+ // clients are still watching. Only the explicit cancel endpoint
+ // should be able to end a run.
+ if err := c.backend.SendMessage(id, msg); err != nil {
c.handleError(w, r, err)
return
}
- w.WriteHeader(http.StatusOK)
+ w.WriteHeader(http.StatusAccepted)
}
// handlePostWorkspaceAgentInit initializes the agent for a workspace.
@@ -1033,15 +1034,15 @@ func (c *controllerV1) handleGetWorkspacePermissionsSkip(w http.ResponseWriter,
// handleError maps backend errors to HTTP status codes and writes the
// JSON error response.
+//
+// Runtime cancellation of an agent run no longer reaches here for the
+// agent-prompt path: SendMessage is fire-and-forget (the handler returns
+// 202 before the run starts) and Backend.runAgent swallows
+// context.Canceled, surfacing the FinishReasonCanceled marker to SSE
+// subscribers instead. The remaining callers pass synchronous backend
+// errors, so context.Canceled gets no special case and would fall through
+// to the default 500 like any other unexpected error.
func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err error) {
- // A canceled agent run is not an error from the prompting
- // client's perspective. The cancellation reaches every SSE
- // subscriber via the FinishReasonCanceled marker on the assistant
- // message; the still-open POST should not surface a 500.
- if errors.Is(err, context.Canceled) {
- w.WriteHeader(http.StatusOK)
- return
- }
status := http.StatusInternalServerError
switch {
case errors.Is(err, backend.ErrWorkspaceNotFound):
@@ -1060,6 +1061,8 @@ func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err e
status = http.StatusBadRequest
case errors.Is(err, backend.ErrClientNotAttached):
status = http.StatusNotFound
+ case errors.Is(err, backend.ErrWorkspaceClosing):
+ status = http.StatusConflict
}
c.server.logError(r, err.Error())
jsonError(w, status, err.Error())
@@ -30,6 +30,14 @@ type stubCoordinator struct {
func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
return nil, nil
}
+
+func (s *stubCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
+ return nil, nil
+}
+
+func (s *stubCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
+ return nil
+}
func (s *stubCoordinator) Cancel(string) {}
func (s *stubCoordinator) CancelAll() {}
func (s *stubCoordinator) IsBusy() bool { return false }
@@ -3285,12 +3285,12 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.
// Capture session ID to avoid race with main goroutine updating m.session.
sessionID := m.session.ID
cmds = append(cmds, func() tea.Msg {
+ // AgentRun is fire-and-forget: it returns once the prompt has
+ // been accepted (HTTP 202) or synchronously with a validation
+ // or transport error. Run failures and cancellation surface
+ // through SSE-derived events, not this return value.
err := m.com.Workspace.AgentRun(context.Background(), sessionID, content, attachments...)
if err != nil {
- isCancelErr := errors.Is(err, context.Canceled)
- if isCancelErr {
- return nil
- }
return util.InfoMsg{
Type: util.InfoTypeError,
Msg: fmt.Sprintf("%v", err),
@@ -704,13 +704,18 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg {
Payload: protoToFile(e.Payload),
}
case pubsub.Event[proto.AgentEvent]:
+ n := notify.Notification{
+ SessionID: e.Payload.SessionID,
+ SessionTitle: e.Payload.SessionTitle,
+ RunID: e.Payload.RunID,
+ Type: notify.Type(e.Payload.Type),
+ }
+ if e.Payload.Error != nil {
+ n.Message = e.Payload.Error.Error()
+ }
return pubsub.Event[notify.Notification]{
- Type: e.Type,
- Payload: notify.Notification{
- SessionID: e.Payload.SessionID,
- SessionTitle: e.Payload.SessionTitle,
- Type: notify.Type(e.Payload.Type),
- },
+ Type: e.Type,
+ Payload: n,
}
case pubsub.Event[proto.RunComplete]:
// Translate the wire-level proto.RunComplete back into the