From 90b1a2a67bf70a2f5fc1f998ccd84317cab5d68b Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 21:48:57 -0400 Subject: [PATCH 01/15] feat(server): make server prompts independent of client connections Give each workspace its own run lifetime and shutdown gate so agent work is not tied to the HTTP request that submitted it. Final state can still be saved during shutdown, avoiding lost completion or error records when a workspace is canceled. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 29 ++++++++++++++----- internal/backend/backend.go | 57 ++++++++++++++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 53e63af3b95e8bb4fba6144675d97c3686e78546..ef37f164d711c608916e98dc15ddedcaf1694033 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -305,7 +305,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // (the pubsub broker fan-in does not serialize publishes from // different upstream brokers). defer func() { - if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + // Use a context detached from the run context: workspace + // shutdown cancels ctx before this goroutine returns, but the + // buffered streaming deltas must still land before the DB is + // closed. A short timeout bounds the flush. + flushCtx, flushCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer flushCancel() + if flushErr := a.messages.FlushAll(flushCtx); flushErr != nil { slog.Error("Failed to flush pending message updates after run", "error", flushErr) } if skipRunComplete { @@ -577,11 +583,20 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * if currentAssistant == nil { return result, err } + // Persist final state with a context detached from the run + // context. 
The run context (ctx) is derived from the + // workspace context, which workspace shutdown cancels before + // agent goroutines finish; using ctx here would drop the + // final assistant state. WithoutCancel keeps the values + // (e.g. session ID) while ignoring cancellation, and a short + // timeout bounds the cleanup writes. + cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cleanupCancel() // Ensure we finish thinking on error to close the reasoning state. currentAssistant.FinishThinking() toolCalls := currentAssistant.ToolCalls() - // INFO: we use the parent context here because the genCtx has been cancelled. - msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID) + // INFO: we use the cleanup context here because the genCtx has been cancelled. + msgs, createErr := a.messages.List(cleanupCtx, currentAssistant.SessionID) if createErr != nil { return nil, createErr } @@ -590,7 +605,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * tc.Finished = true tc.Input = "{}" currentAssistant.AddToolCall(tc) - updateErr := a.messages.Update(ctx, *currentAssistant) + updateErr := a.messages.Update(cleanupCtx, *currentAssistant) if updateErr != nil { return nil, updateErr } @@ -623,7 +638,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * Content: content, IsError: true, } - _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{ + _, createErr = a.messages.Create(cleanupCtx, currentAssistant.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: []message.ContentPart{ toolResult, @@ -670,9 +685,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * } else { currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error()) } - // Note: we use the parent context here because the genCtx has been + // Note: we use the cleanup context here because the 
genCtx has been // cancelled. - updateErr := a.messages.Update(ctx, *currentAssistant) + updateErr := a.messages.Update(cleanupCtx, *currentAssistant) if updateErr != nil { return nil, updateErr } diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 9c0f47f7ab8dfd48f90a31004d637dd6e54fd912..2ea24a86c71fbc7c4f58c5de680aeb83d421584b 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -34,6 +34,7 @@ var ( ErrUnknownCommand = errors.New("unknown command") ErrInvalidClientID = errors.New("invalid client_id") ErrClientNotAttached = errors.New("client not attached") + ErrWorkspaceClosing = errors.New("workspace closing") ) // DefaultCreateGrace is the window in which a client must open an SSE @@ -108,6 +109,23 @@ type Workspace struct { // with fallback to the cleaned absolute path. resolvedPath string + // ctx is the workspace-scoped run context. It is derived from + // the backend context in CreateWorkspace and lives for the + // lifetime of the workspace; cancel tears it down. Agent runs + // dispatched on behalf of this workspace are bound to ctx so + // their lifetime is owned by the workspace, not by any single + // client's HTTP request. + ctx context.Context + cancel context.CancelFunc + + // runMu guards closing and gates dispatch of new agent runs. + // closing is set by Shutdown so no new runs are accepted once + // teardown has begun. runWG tracks dispatched agent goroutines + // so Shutdown can wait for them to return before app cleanup. + runMu sync.Mutex + closing bool + runWG sync.WaitGroup + // clientsMu guards clients. It is held only briefly (no IO). clientsMu sync.Mutex // clients tracks each client's claim on this workspace. Refcount @@ -122,7 +140,7 @@ type Workspace struct { } // invokeShutdown calls the workspace shutdown hook if set, falling -// back to the embedded [app.App.Shutdown] when not. +// back to the workspace [Workspace.Shutdown] wrapper when not. 
func (w *Workspace) invokeShutdown() { if w.shutdownFn != nil { w.shutdownFn() @@ -133,6 +151,40 @@ func (w *Workspace) invokeShutdown() { } } +// Shutdown tears the workspace down in an order that is safe for +// agent runs whose lifetime is bound to the workspace context. It +// shadows the promoted [app.App.Shutdown] so callers reaching +// ws.Shutdown() always observe this ordering: +// +// 1. Mark the workspace closing so no new agent runs are accepted. +// 2. Cancel the workspace run context so any dispatched goroutine +// that has not yet registered its per-session cancel still +// observes cancellation. +// 3. Cancel active coordinator work for runs that already +// registered their per-session cancel function. +// 4. Wait for dispatched agent goroutines to return. +// 5. Run the embedded [app.App.Shutdown] cleanup (DB, LSP, etc). +// +// CancelAll is idempotent, so the second call inside app.App.Shutdown +// is harmless; the important guarantee is that cancel -> CancelAll -> +// runWG.Wait completes before the embedded cleanup touches the DB. +func (w *Workspace) Shutdown() { + w.runMu.Lock() + w.closing = true + w.runMu.Unlock() + + if w.cancel != nil { + w.cancel() + } + if w.App != nil && w.AgentCoordinator != nil { + w.AgentCoordinator.CancelAll() + } + w.runWG.Wait() + if w.App != nil { + w.App.Shutdown() + } +} + // New creates a new [Backend]. 
func New(ctx context.Context, cfg *config.ConfigStore, shutdownFn ShutdownFunc) *Backend { return &Backend{ @@ -247,6 +299,7 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works return nil, proto.Workspace{}, fmt.Errorf("failed to create app workspace: %w", err) } + wsCtx, wsCancel := context.WithCancel(b.ctx) ws := &Workspace{ App: appWorkspace, ID: id, @@ -255,6 +308,8 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works Env: args.Env, Skills: skillsMgr, resolvedPath: key, + ctx: wsCtx, + cancel: wsCancel, clients: make(map[string]*clientState), } From 8d7166e67ebe7c93b4ecca5ab9a2f02fa92d87a0 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 21:51:33 -0400 Subject: [PATCH 02/15] refactor(server): share prompt validation before background dispatch Use the same prompt and session checks before accepting background work that synchronous runs already use. This keeps rejected prompts consistent as the server begins accepting requests asynchronously. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ef37f164d711c608916e98dc15ddedcaf1694033..bc3df59e7626943ca31cbc293c0b76a814c05fae 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -183,12 +183,25 @@ func NewSessionAgent( } } -func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) { +// ValidateCall performs the cheap structural validation that +// sessionAgent.Run requires before a call can be dispatched: a call must +// carry either a non-empty prompt or a text attachment, and it must name a +// session. It is exported so callers that accept a run before dispatching it +// (e.g. backend.SendMessage) can apply the same checks and keep the error +// contract consistent. 
+func ValidateCall(call SessionAgentCall) error { if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) { - return nil, ErrEmptyPrompt + return ErrEmptyPrompt } if call.SessionID == "" { - return nil, ErrSessionMissing + return ErrSessionMissing + } + return nil +} + +func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) { + if err := ValidateCall(call); err != nil { + return nil, err } // Queue the message if busy. Strip OnComplete: the caller that From 4347015ab7669b1376e6cb814c10f89aef3481ad Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 22:32:39 -0400 Subject: [PATCH 03/15] chore(server): honor cancels immediately after prompt acceptance Track prompts that have been accepted but have not started running yet so a cancel issued right after acceptance applies to that prompt. Idle-session cancels remain a no-op, preventing one client's cancel from poisoning the next prompt. Co-Authored-By: Charm Crush --- internal/agent/accepted_run_test.go | 226 +++++++++++++++ internal/agent/agent.go | 360 ++++++++++++++++++++++-- internal/agent/coordinator.go | 33 +++ internal/agent/coordinator_test.go | 4 + internal/agent/dispatch_cancel_test.go | 198 +++++++++++++ internal/server/agent_cancel_test.go | 7 + internal/server/sessions_isbusy_test.go | 8 + 7 files changed, 812 insertions(+), 24 deletions(-) create mode 100644 internal/agent/accepted_run_test.go create mode 100644 internal/agent/dispatch_cancel_test.go diff --git a/internal/agent/accepted_run_test.go b/internal/agent/accepted_run_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d62422a9f02bec68a8da1a08c6e6d6b52e7d7699 --- /dev/null +++ b/internal/agent/accepted_run_test.go @@ -0,0 +1,226 @@ +package agent + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// 
newCancelTestAgent builds a DB-backed sessionAgent with no model. The +// tests here exercise the dispatch/cancel/persist paths, none of which +// reach agent.Stream, so a model is unnecessary. +func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + return sa, env +} + +func (a *sessionAgent) acceptedCount(sessionID string) int { + c, _ := a.acceptedRuns.Get(sessionID) + return c +} + +func (a *sessionAgent) hasPendingCancel(sessionID string) bool { + _, ok := a.pendingCancels.Get(sessionID) + return ok +} + +func TestAcceptedRun_CloseIsIdempotent(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + require.Equal(t, "sid", accept.SessionID()) + require.Equal(t, 1, sa.acceptedCount("sid")) + + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) + + // Repeated Close must not underflow the counter. + accept.Close() + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_MultipleReservations(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + a1 := sa.BeginAccepted("sid") + a2 := sa.BeginAccepted("sid") + require.Equal(t, 2, sa.acceptedCount("sid")) + + a1.Close() + require.Equal(t, 1, sa.acceptedCount("sid")) + + a2.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_NilSafe(t *testing.T) { + t.Parallel() + var accept *AcceptedRun + require.Equal(t, "", accept.SessionID()) + // Must not panic. + accept.Close() +} + +func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + // No accepted run, no active request: a true no-op. 
+ sa.Cancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + // A second cancel while a pending cancel is already recorded must + // remain a single pending cancel; one Run consumes exactly one. + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The pending cancel was consumed and the accept released. 
+ require.False(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate PrepareStep having already created the user message. 
+ _, err = sa.createUserMessage(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }) + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, true) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate workspace shutdown having already canceled the run + // context. WithoutCancel must let the writes through. + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + err = sa.persistCanceledTurn(ctx, SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) +} + +func TestClearPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + sa.clearPendingCancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index bc3df59e7626943ca31cbc293c0b76a814c05fae..97bd7a21af4c28f30e46fcaf23c23348467e90ac 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -21,6 +21,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "charm.land/catwalk/pkg/catwalk" @@ -103,10 +104,19 @@ type SessionAgentCall struct { // recursion drains, so falling back to the default broker // publish keeps the event visible to subscribers. 
OnComplete func(notify.RunComplete) + // Accepted, when non-nil, is the accept reservation taken by + // BeginAccepted before the call was dispatched onto a goroutine + // (the client/server fire-and-forget path). Run consumes it under + // dispatchMu[SessionID] once the accepted -> (cancel-on-entry | + // queued | active) transition has been chosen. When nil + // (in-process / local callers like AppWorkspace), behavior is + // unchanged and no accept tracking applies. + Accepted *AcceptedRun } type SessionAgent interface { Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun SetModels(large Model, small Model) SetTools(tools []fantasy.AgentTool) SetSystemPrompt(systemPrompt string) @@ -145,6 +155,32 @@ type sessionAgent struct { messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] + + // dispatchMu holds a per-session mutex that serializes the + // accepted -> (cancel-on-entry | queued | active) transition in + // Run against a concurrent Cancel. The lock is held only during + // the brief handoff (no DB or LLM I/O under the lock). + dispatchMu *csync.Map[string, *sync.Mutex] + // acceptedRuns counts dispatched-but-not-yet-active runs per + // session. A counter > 0 means a dispatched prompt is in flight + // and has not yet completed the dispatch handoff in Run. Only + // BeginAccepted increments it; only AcceptedRun.Close decrements + // it. + acceptedRuns *csync.Map[string, int] + // pendingCancels records sessions whose dispatched-but-not-yet- + // running call should observe a cancellation request. It is only + // set by Cancel when acceptedRuns > 0, so an idle Escape never + // poisons the next prompt. + pendingCancels *csync.Map[string, struct{}] + // dispatchMuCreate guards lazy creation of per-session entries in + // dispatchMu so two goroutines can't race to lock different mutex + // instances for the same session. 
+ dispatchMuCreate sync.Mutex + // acceptedMu serializes increments/decrements of acceptedRuns. It + // is separate from dispatchMu so AcceptedRun.Close (which may run + // while Run holds dispatchMu for the same session) does not + // deadlock by re-entering the dispatch lock. + acceptedMu sync.Mutex } type SessionAgentOptions struct { @@ -180,7 +216,144 @@ func NewSessionAgent( runComplete: opts.RunComplete, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), + dispatchMu: csync.NewMap[string, *sync.Mutex](), + acceptedRuns: csync.NewMap[string, int](), + pendingCancels: csync.NewMap[string, struct{}](), + } +} + +// AcceptedRun owns exactly one accept reservation taken by +// BeginAccepted. It is the only carrier of accept-state across the +// backend.runAgent / Coordinator.Run / sessionAgent.Run layers: a +// counter > 0 means a dispatched prompt is in flight and has not yet +// completed the dispatch handoff in Run. Close is the only way to +// release the reservation and is idempotent. +type AcceptedRun struct { + agent *sessionAgent + sessionID string + done atomic.Bool +} + +// Close decrements the accept counter for this reservation. It is safe +// to call multiple times; only the first call has effect. +func (r *AcceptedRun) Close() { + if r == nil { + return + } + if !r.done.CompareAndSwap(false, true) { + return + } + r.agent.endAccepted(r.sessionID) +} + +// SessionID exposes the session this reservation is for so the run path +// can use it without an extra parameter. +func (r *AcceptedRun) SessionID() string { + if r == nil { + return "" + } + return r.sessionID +} + +// BeginAccepted increments the accept counter for sessionID and returns +// a handle whose Close is the only way to decrement it. It is the only +// entry point that mutates acceptedRuns. 
+func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, _ := a.acceptedRuns.Get(sessionID) + a.acceptedRuns.Set(sessionID, count+1) + return &AcceptedRun{agent: a, sessionID: sessionID} +} + +// endAccepted decrements the accept counter for sessionID. It is only +// called via AcceptedRun.Close. It uses a dedicated lock (not the +// per-session dispatch mutex) so it can run while Run holds dispatchMu +// for the same session without deadlocking. +func (a *sessionAgent) endAccepted(sessionID string) { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, ok := a.acceptedRuns.Get(sessionID) + if !ok || count <= 1 { + a.acceptedRuns.Del(sessionID) + return + } + a.acceptedRuns.Set(sessionID, count-1) +} + +// sessionMu returns the per-session dispatch mutex, creating it on first +// use. Creation is guarded so concurrent callers always observe the same +// mutex instance for a given session. +func (a *sessionAgent) sessionMu(sessionID string) *sync.Mutex { + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + a.dispatchMuCreate.Lock() + defer a.dispatchMuCreate.Unlock() + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + mu := &sync.Mutex{} + a.dispatchMu.Set(sessionID, mu) + return mu +} + +// enqueueCall appends call to the session's message queue. The +// OnComplete hook is stripped: the caller that supplied it (typically +// coordinator.Run) has its own retry/coalesce scope that ends when it +// returns, so by the time the queue drains nobody is left to consume the +// buffered terminal event. The recursive Run falls back to the default +// broker publish, which is what existing subscribers expect for queued +// turns. 
+func (a *sessionAgent) enqueueCall(call SessionAgentCall) { + existing, ok := a.messageQueue.Get(call.SessionID) + if !ok { + existing = []SessionAgentCall{} } + queued := call + queued.OnComplete = nil + queued.Accepted = nil + existing = append(existing, queued) + a.messageQueue.Set(call.SessionID, existing) +} + +// clearPendingCancel removes any pending-cancel record for sessionID. It +// takes the per-session dispatch lock so it is ordered against Cancel and +// the dispatch handoff. +func (a *sessionAgent) clearPendingCancel(sessionID string) { + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + a.pendingCancels.Del(sessionID) +} + +// persistCanceledTurn writes the user/assistant records for a turn that +// was canceled before (or just as) streaming would have produced them. +// It creates the user message only when it was not already created by an +// earlier createUserMessage call (userMsgCreated), then writes an +// assistant message with FinishReasonCanceled. Both writes use +// context.WithoutCancel(ctx) so workspace shutdown (which cancels the run +// context) can't drop them. 
+func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgentCall, userMsgCreated bool) error { + writeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + if !userMsgCreated { + if _, err := a.createUserMessage(writeCtx, call); err != nil { + return err + } + } + largeModel := a.largeModel.Get() + assistant, err := a.messages.Create(writeCtx, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, + }) + if err != nil { + return err + } + assistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") + return a.messages.Update(writeCtx, assistant) } // ValidateCall performs the cheap structural validation that @@ -204,22 +377,73 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * return nil, err } - // Queue the message if busy. Strip OnComplete: the caller that - // supplied the hook (typically coordinator.Run) has its own - // retry/coalesce scope that ends when it returns, so by the time - // the queue drains nobody is left to consume the buffered - // terminal event. The recursive Run will fall back to the - // default broker publish, which is what existing subscribers - // expect for queued turns. - if a.IsSessionBusy(call.SessionID) { - existing, ok := a.messageQueue.Get(call.SessionID) - if !ok { - existing = []SessionAgentCall{} + // genCtx/cancel are the run context and its cancel func. For the + // accepted (fire-and-forget) dispatch path they are created under + // dispatchMu below so a concurrent Cancel can observe the + // activeRequests entry before the assistant message exists. For + // the in-process path they stay nil here and are created later, + // preserving the original ordering. 
+ var ( + genCtx context.Context + cancel context.CancelFunc + activeRegistered bool + userMsgCreated bool + ) + + if call.Accepted != nil { + // Serialize the accepted -> (cancel-on-entry | queued | + // active) transition against a concurrent Cancel. Cancel takes + // the same per-session lock, so every cancel observes at least + // one of: pendingCancels, an activeRequests entry, or a + // messageQueue entry it then clears. + mu := a.sessionMu(call.SessionID) + mu.Lock() + + if _, pending := a.pendingCancels.Get(call.SessionID); pending { + // Cancel-on-entry: a cancel arrived while this run was + // dispatched but not yet active. Consume the pending + // cancel, release the accept reservation, drop the lock, + // and persist a canceled turn without entering Stream. + a.pendingCancels.Del(call.SessionID) + call.Accepted.Close() + mu.Unlock() + if err := a.persistCanceledTurn(ctx, call, false); err != nil { + return nil, err + } + return nil, nil + } + + if a.IsSessionBusy(call.SessionID) { + // Busy: an earlier prompt is active. Queue this call and + // release the accept reservation. A Cancel arriving after + // this point sees the active entry and clears the queue. + a.enqueueCall(call) + call.Accepted.Close() + mu.Unlock() + return nil, nil } - queued := call - queued.OnComplete = nil - existing = append(existing, queued) - a.messageQueue.Set(call.SessionID, existing) + + // Idle: become the active run. Register the cancel func before + // dropping the lock so a Cancel that arrives between here and + // assistant creation is not lost. + runCtx := context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) + genCtx, cancel = context.WithCancel(runCtx) + a.activeRequests.Set(call.SessionID, cancel) + activeRegistered = true + call.Accepted.Close() + mu.Unlock() + + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } else if a.IsSessionBusy(call.SessionID) { + // Queue the message if busy. 
Strip OnComplete: the caller that + // supplied the hook (typically coordinator.Run) has its own + // retry/coalesce scope that ends when it returns, so by the time + // the queue drains nobody is left to consume the buffered + // terminal event. The recursive Run will fall back to the + // default broker publish, which is what existing subscribers + // expect for queued turns. + a.enqueueCall(call) return nil, nil } @@ -282,15 +506,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * if err != nil { return nil, err } + userMsgCreated = true // Add the session to the context. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) - genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Set(call.SessionID, cancel) + // For the accepted dispatch path the run context and cancel func + // were already created and registered under dispatchMu above; reuse + // them. For the in-process path create them here, preserving the + // original ordering. + if !activeRegistered { + genCtx, cancel = context.WithCancel(ctx) + a.activeRequests.Set(call.SessionID, cancel) - defer cancel() - defer a.activeRequests.Del(call.SessionID) + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } // skipRunComplete is set just before the queued-recursion path so // the outer Run doesn't publish a RunComplete that would race // with — and be superseded by — the recursive call's own @@ -390,14 +621,24 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() + // Drain queued follow-up prompts, but skip them if a cancel + // was recorded for the session while they sat in the queue: + // a cancel that arrived after the queue insertion must not + // let the queued prompt run as part of this step. 
+ dispatchLock := a.sessionMu(call.SessionID) + dispatchLock.Lock() + _, canceled := a.pendingCancels.Get(call.SessionID) queuedCalls, _ := a.messageQueue.Get(call.SessionID) a.messageQueue.Del(call.SessionID) - for _, queued := range queuedCalls { - userMessage, createErr := a.createUserMessage(callContext, queued) - if createErr != nil { - return callContext, prepared, createErr + dispatchLock.Unlock() + if !canceled { + for _, queued := range queuedCalls { + userMessage, createErr := a.createUserMessage(callContext, queued) + if createErr != nil { + return callContext, prepared, createErr + } + prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } - prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) @@ -594,6 +835,18 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * isHyper := largeModel.ModelCfg.Provider == hyper.Name isCancelErr := errors.Is(err, context.Canceled) if currentAssistant == nil { + // Cancel-before-assistant-creation window: the run was + // canceled after activeRequests.Set but before PrepareStep + // created the assistant message. Without this, the turn + // would return with no FinishReasonCanceled marker and no + // user-visible record. The user message was already created + // above, so persistCanceledTurn only writes the assistant + // record. + if isCancelErr { + if persistErr := a.persistCanceledTurn(ctx, call, userMsgCreated); persistErr != nil { + return nil, persistErr + } + } return result, err } // Persist final state with a context detached from the run @@ -741,8 +994,35 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * }) } + // Hand off to the next queued prompt (if any) under dispatchMu so + // the transition from this finished run to the queued run is atomic + // against a concurrent Cancel. 
activeRequests for this session was + // just deleted above, so without the lock there is a window in + // which the session looks idle and a cancel becomes a no-op that + // fails to stop the queued prompt. Holding the lock lets us observe + // a pending cancel recorded against the session and drop the queue + // instead of running it, and (for the recursion) hand a fresh + // accept reservation to the dequeued call so acceptedRuns stays > 0 + // across the recursive Run's own dispatch handoff — keeping the + // session observable to Cancel for the entire transition and + // closing the dequeue -> re-register window. + mu := a.sessionMu(call.SessionID) + mu.Lock() + if _, pending := a.pendingCancels.Get(call.SessionID); pending { + // A cancel was recorded for this session (e.g. it arrived while + // this run was active and a follow-up had been accepted). Drop + // the queue instead of running it and consume the marker. + a.pendingCancels.Del(call.SessionID) + a.messageQueue.Del(call.SessionID) + mu.Unlock() + return result, err + } queuedMessages, ok := a.messageQueue.Get(call.SessionID) if !ok || len(queuedMessages) == 0 { + // No queued work. Clear any stale pending-cancel entry as a + // safety net so it can't catch a future run (no-op when unset). + a.pendingCancels.Del(call.SessionID) + mu.Unlock() return result, err } // There are queued messages restart the loop. The recursive Run @@ -753,6 +1033,14 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * skipRunComplete = true firstQueuedMessage := queuedMessages[0] a.messageQueue.Set(call.SessionID, queuedMessages[1:]) + // Reserve a fresh accept for the dequeued prompt before dropping the + // lock so acceptedRuns > 0 across the handoff into the recursive + // Run. 
This closes the window between this dequeue and the recursive + // Run registering its activeRequests entry: a cancel arriving in + // that window now records a pending cancel (acceptedRuns > 0) that + // the recursive Run's accepted path observes as cancel-on-entry. + firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID) + mu.Unlock() return a.Run(ctx, firstQueuedMessage) } @@ -1305,6 +1593,16 @@ func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message } func (a *sessionAgent) Cancel(sessionID string) { + // Serialize against the dispatch handoff in Run so the accepted -> + // (cancel-on-entry | queued | active) transition is atomic against + // this cancel. Every cancel observes at least one of: an active + // request, an accepted run (recorded as a pending cancel), or a + // queue entry it then clears. If none of those hold, an idle Escape + // is a true no-op and must not poison the next prompt. + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + // Cancel regular requests. Don't use Take() here - we need the entry to // remain in activeRequests so IsBusy() returns true until the goroutine // fully completes (including error handling that may access the DB). @@ -1320,6 +1618,20 @@ func (a *sessionAgent) Cancel(sessionID string) { cancel() } + // Record a pending cancel only when a dispatched-but-not-yet-active + // run exists. This catches a run still in the goroutine scheduler or + // about to enter Run's busy-queue branch, while leaving an idle + // session untouched. Active and accepted are not mutually exclusive: + // when a run is active and a follow-up has been accepted, both the + // cancel above and this pending record fire. 
+ a.acceptedMu.Lock() + count, ok := a.acceptedRuns.Get(sessionID) + a.acceptedMu.Unlock() + if ok && count > 0 { + slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID) + a.pendingCancels.Set(sessionID, struct{}{}) + } + if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) a.messageQueue.Del(sessionID) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index a26aa111eeb8e45a699a6aab90774f04a1aca4bb..f5ca831e60cdb54edf0c0d7bfde83702a79701f1 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -81,6 +81,15 @@ type Coordinator interface { // INFO: (kujtim) this is not used yet we will use this when we have multiple agents // SetMainAgent(string) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + // RunAccepted runs a call that was already accepted via + // BeginAccepted on the fire-and-forget dispatch path. The handle is + // the only carrier of accept-state across the backend.runAgent / + // Coordinator / sessionAgent.Run layers: it reaches + // sessionAgent.Run as SessionAgentCall.Accepted, where it is + // consumed under dispatchMu once the accepted -> (cancel-on-entry | + // queued | active) transition is chosen. + RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool @@ -179,6 +188,20 @@ func NewCoordinator( // Run implements Coordinator. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, nil, sessionID, prompt, attachments...) +} + +// RunAccepted implements Coordinator. 
+func (c *coordinator) RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, accept, sessionID, prompt, attachments...) +} + +// run is the shared implementation behind Run and RunAccepted. When +// accept is non-nil it is threaded onto the SessionAgentCall as +// Accepted so sessionAgent.Run can consume the accept reservation under +// dispatchMu; when nil (the in-process/local path) no accept tracking +// applies. +func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { if err := c.readyWg.Wait(); err != nil { return nil, err } @@ -256,6 +279,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, FrequencyPenalty: freqPenalty, PresencePenalty: presPenalty, OnComplete: onComplete, + Accepted: accept, }) } beforeLoaded := c.skillTracker.LoadedNames() @@ -989,6 +1013,15 @@ func isExactoSupported(modelID string) bool { return slices.Contains(supportedModels, modelID) } +// BeginAccepted reserves an accept slot for sessionID on the active +// agent and returns the ownership handle. It is the fire-and-forget +// dispatch path's only way to mark a run as accepted-but-not-yet-active +// so a cancel arriving before the run registers in activeRequests is not +// lost. 
+func (c *coordinator) BeginAccepted(sessionID string) *AcceptedRun { + return c.currentAgent.BeginAccepted(sessionID) +} + func (c *coordinator) Cancel(sessionID string) { c.currentAgent.Cancel(sessionID) } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index da0ddd0db2bf77c3c1e3eb6463549875a989a4ca..c522ef5de1061435e4cf9df1789bc3c92d9152a4 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -25,6 +25,10 @@ func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fan return m.runFunc(ctx, call) } +func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + return &AcceptedRun{sessionID: sessionID} +} + func (m *mockSessionAgent) Model() Model { return m.model } func (m *mockSessionAgent) SetModels(large, small Model) {} func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {} diff --git a/internal/agent/dispatch_cancel_test.go b/internal/agent/dispatch_cancel_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f66de252e63559239c1d577fe51c0650589aa5b4 --- /dev/null +++ b/internal/agent/dispatch_cancel_test.go @@ -0,0 +1,198 @@ +package agent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// finishStreamModel is a minimal fantasy.LanguageModel that streams a +// single text part followed by a normal (FinishReasonStop) finish. It +// is enough to drive sessionAgent.Run through PrepareStep and a clean +// completion without a recorded provider cassette. 
+type finishStreamModel struct { + text string +} + +func (m *finishStreamModel) Provider() string { return "fake" } +func (m *finishStreamModel) Model() string { return "fake-model" } + +func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}, + FinishReason: fantasy.FinishReasonStop, + }, nil +} + +func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + text := m.text + return func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) { + return + } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil +} + +func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return nil, errors.New("not implemented") +} + +func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + model := &finishStreamModel{text: "done"} + sa := testSessionAgent(env, model, model, "system").(*sessionAgent) + return sa, env +} + +// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a +// session is actively running (activeRequests set) AND a follow-up has +// been accepted (acceptedRuns > 0). A single Cancel must fire both: it +// invokes the active cancel func and records a pending cancel for the +// accepted follow-up. 
+func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + const sid = "sid" + var activeCanceled atomic.Bool + sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) }) + + accept := sa.BeginAccepted(sid) + defer accept.Close() + + sa.Cancel(sid) + + require.True(t, activeCanceled.Load(), "active cancel func must fire") + require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel") +} + +// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue +// branch consulting pendingCancels: when the session is busy AND a +// cancel has been recorded for an accepted follow-up, Run must take the +// cancel-on-entry path (persist a canceled turn) instead of enqueueing +// the call behind the active run. +func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Make the session look busy: an earlier prompt is active. + sa.activeRequests.Set(sess.ID, func() {}) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives while this follow-up is accepted-but-not-active. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "follow-up", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The follow-up was canceled on entry, not enqueued. 
+ require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "cancel-on-entry must not enqueue the follow-up behind the active run") + require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the +// queue drain inside PrepareStep skips queued follow-up prompts when a +// cancel has been recorded for the session: the queued prompt must not +// be folded into the active turn as an extra user message. +func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A follow-up prompt sits queued for the session. + sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"}) + // A cancel was recorded for the session while it sat in the queue. + sa.pendingCancels.Set(sess.ID, struct{}{}) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Only the main prompt produced a user message; the queued + // follow-up was skipped, not folded into the turn. 
+ msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + var userMsgs []message.Message + for _, m := range msgs { + if m.Role == message.User { + userMsgs = append(userMsgs, m) + } + } + require.Len(t, userMsgs, 1, "queued follow-up must not create a user message") + assert.Equal(t, "main", userMsgs[0].Content().String()) + + // The queue was drained and the pending cancel consumed. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID)) + require.False(t, sa.hasPendingCancel(sess.ID)) +} + +// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run +// which completes normally clears any stale pending-cancel entry for the +// session, so it cannot catch a future run. +func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A stale pending cancel lingers (no queued work, no accepted run). + sa.pendingCancels.Set(sess.ID, struct{}{}) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + require.False(t, sa.hasPendingCancel(sess.ID), + "normal completion must clear the stale pending cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason()) +} diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 1bbb05e511a26f02cc1dad0e1da77454af0f8905..18ea8046d0647d3576a50769b7b2146d9aa103e9 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -60,6 +60,13 @@ func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, atta return nil, s.returnFn(ctx) } +func (s *runCoordinator) RunAccepted(ctx context.Context, 
accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return s.Run(ctx, sessionID, prompt, attachments...) +} + +func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *runCoordinator) Cancel(string) {} func (s *runCoordinator) CancelAll() {} func (s *runCoordinator) IsBusy() bool { return false } diff --git a/internal/server/sessions_isbusy_test.go b/internal/server/sessions_isbusy_test.go index 060c00abe9367dc7162bdb50dd77fe951041aa51..615f4a5b58cde3f5779b053ca9ce92e0d0d253a4 100644 --- a/internal/server/sessions_isbusy_test.go +++ b/internal/server/sessions_isbusy_test.go @@ -30,6 +30,14 @@ type stubCoordinator struct { func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { return nil, nil } + +func (s *stubCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (s *stubCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *stubCoordinator) Cancel(string) {} func (s *stubCoordinator) CancelAll() {} func (s *stubCoordinator) IsBusy() bool { return false } From 0a977431d6445551253c8677ec230beffdf215fe Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 22:36:20 -0400 Subject: [PATCH 04/15] chore(server): report background prompt failures via events Add a prompt failure event that can be delivered after the HTTP request has already returned. Remote clients and subscribers now receive the same failure message through the event stream. 
Co-Authored-By: Charm Crush --- internal/agent/notify/notify.go | 6 ++++++ internal/server/events.go | 18 ++++++++++++------ internal/workspace/client_workspace.go | 16 ++++++++++------ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/internal/agent/notify/notify.go b/internal/agent/notify/notify.go index ac7f724c0f07f552d9759247821a2555c9e12524..1a217a6d00650fe1134b24d9d779821015513063 100644 --- a/internal/agent/notify/notify.go +++ b/internal/agent/notify/notify.go @@ -12,6 +12,9 @@ const ( // TypeReAuthenticate indicates the agent encountered an // authentication error and the user needs to re-authenticate. TypeReAuthenticate Type = "re_authenticate" + // TypeAgentError indicates the agent's turn terminated with an + // error. The error text is carried in Notification.Message. + TypeAgentError Type = "error" ) // Notification represents a domain event published by the agent. @@ -20,6 +23,9 @@ type Notification struct { SessionTitle string Type Type ProviderID string + // Message carries the error text for TypeAgentError. Other + // notification types ignore it. + Message string } // RunComplete is the authoritative end-of-run signal for a session. 
diff --git a/internal/server/events.go b/internal/server/events.go index 4e3d6a1a262ae7b3399f0fa765ae863160ba57c8..fd085c5a415c0ef0fc402673ad23fff8435f1db6 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "log/slog" @@ -85,13 +86,18 @@ func wrapEvent(ev any) *pubsub.Payload { Payload: fileToProto(e.Payload), }) case pubsub.Event[notify.Notification]: + payload := proto.AgentEvent{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + Type: proto.AgentEventType(e.Payload.Type), + } + if e.Payload.Type == notify.TypeAgentError { + payload.Type = proto.AgentEventTypeError + payload.Error = errors.New(e.Payload.Message) + } return envelope(pubsub.PayloadTypeAgentEvent, pubsub.Event[proto.AgentEvent]{ - Type: e.Type, - Payload: proto.AgentEvent{ - SessionID: e.Payload.SessionID, - SessionTitle: e.Payload.SessionTitle, - Type: proto.AgentEventType(e.Payload.Type), - }, + Type: e.Type, + Payload: payload, }) case pubsub.Event[notify.RunComplete]: return envelope(pubsub.PayloadTypeRunComplete, pubsub.Event[proto.RunComplete]{ diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index 2018fab8a7dcc2cb3aeb0f44fc0920c1db72d852..a6a43731675698083671cae95f983d7a3a724a5d 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -703,13 +703,17 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg { Payload: protoToFile(e.Payload), } case pubsub.Event[proto.AgentEvent]: + n := notify.Notification{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + Type: notify.Type(e.Payload.Type), + } + if e.Payload.Error != nil { + n.Message = e.Payload.Error.Error() + } return pubsub.Event[notify.Notification]{ - Type: e.Type, - Payload: notify.Notification{ - SessionID: e.Payload.SessionID, - SessionTitle: e.Payload.SessionTitle, - Type: notify.Type(e.Payload.Type), 
- }, + Type: e.Type, + Payload: n, } case pubsub.Event[proto.RunComplete]: // Translate the wire-level proto.RunComplete back into the From 8b242c5313a79630d22429c04b699adcb624a89a Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 23:06:41 -0400 Subject: [PATCH 05/15] chore(server): run accepted server prompts in the background Accept validated prompts quickly, then run them on workspace-owned background work. Failures are published through events while normal cancellation continues to be represented by the canceled assistant turn. Co-Authored-By: Charm Crush --- internal/backend/agent.go | 78 +++++++++++-- internal/backend/agent_test.go | 163 +++++++++++++++++++++++++++ internal/backend/testing.go | 14 +++ internal/server/agent_cancel_test.go | 89 ++++++--------- internal/server/proto.go | 17 +-- 5 files changed, 287 insertions(+), 74 deletions(-) create mode 100644 internal/backend/agent_test.go diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 78447ab7c64a82bb2638fb3fe184d0be132b4589..2dd0479d3236d55e3919bdef1f16bb593fe5684e 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -2,22 +2,30 @@ package backend import ( "context" + "errors" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" ) -// SendMessage sends a prompt to the agent coordinator for the given -// workspace and session. +// SendMessage validates and accepts a prompt for the workspace's agent, +// then dispatches the run on a goroutine bound to the workspace context +// and returns immediately. It does not wait for the LLM turn to +// complete: the run's lifetime is owned by the workspace, not by the +// caller. 
Errors from the dispatched run reach observers through the +// agent event channels (a notify.TypeAgentError notification), not +// through this return value. // -// When msg.RunID is non-empty it is attached to the context via -// agent.WithRunID so the coordinator can stamp the resulting -// SessionAgentCall (and therefore the terminal notify.RunComplete -// event) with that correlator. This is the only way for the -// originating client to distinguish its own turn's RunComplete from -// any concurrent turn that finishes on the same session. -func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error { +// SendMessage returns synchronously when the request cannot be accepted: +// ErrWorkspaceNotFound if the workspace is missing, ErrAgentNotInitialized +// if its coordinator is nil, the structural validation errors from +// agent.ValidateCall (ErrEmptyPrompt, ErrSessionMissing) when the prompt +// or session is missing, and ErrWorkspaceClosing if the workspace is +// being torn down. +func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error { ws, err := b.GetWorkspace(workspaceID) if err != nil { return err @@ -27,11 +35,59 @@ func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto return ErrAgentNotInitialized } + if err := agent.ValidateCall(agent.SessionAgentCall{ + SessionID: msg.SessionID, + Prompt: msg.Prompt, + Attachments: proto.AttachmentsToMessage(msg.Attachments), + }); err != nil { + return err + } + + accept := ws.AgentCoordinator.BeginAccepted(msg.SessionID) + + ws.runMu.Lock() + if ws.closing { + ws.runMu.Unlock() + accept.Close() + return ErrWorkspaceClosing + } + ws.runWG.Add(1) + ws.runMu.Unlock() + + go b.runAgent(ws, msg, accept) + return nil +} + +// runAgent executes an accepted agent run for the workspace. It owns the +// accept reservation (releasing it on return) and the runWG ticket added +// by SendMessage. 
The run is bound to the workspace context so its +// lifetime is independent of any client's HTTP request. On a non-cancel +// error it surfaces the failure to observers via a notify.TypeAgentError +// notification; context.Canceled is expected (the FinishReasonCanceled +// marker is already published by sessionAgent.Run) and swallowed. +// +// When msg.RunID is non-empty it is attached to the context via +// agent.WithRunID so the coordinator can stamp the terminal +// notify.RunComplete event with that correlator. +func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) { + defer ws.runWG.Done() + defer accept.Close() + + ctx := ws.ctx if msg.RunID != "" { ctx = agent.WithRunID(ctx, msg.RunID) } - _, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) - return err + + _, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) + if err == nil || errors.Is(err, context.Canceled) { + return + } + + ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{ + SessionID: msg.SessionID, + Type: notify.TypeAgentError, + Message: err.Error(), + }) } // GetAgentInfo returns the agent's model and busy status. 
diff --git a/internal/backend/agent_test.go b/internal/backend/agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d9365ecc899236b91d73198ab322bcabbf9cc77 --- /dev/null +++ b/internal/backend/agent_test.go @@ -0,0 +1,163 @@ +package backend + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted +// blocks until release is closed. It records that RunAccepted was +// entered so tests can observe the dispatched goroutine. Every other +// method returns a zero value. +type blockingCoordinator struct { + entered chan struct{} + release chan struct{} + runCount atomic.Int32 +} + +func newBlockingCoordinator() *blockingCoordinator { + return &blockingCoordinator{ + entered: make(chan struct{}, 1), + release: make(chan struct{}), + } +} + +func (c *blockingCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + c.runCount.Add(1) + select { + case c.entered <- struct{}{}: + default: + } + <-c.release + return nil, nil +} + +func (c *blockingCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil } +func (c *blockingCoordinator) Cancel(string) {} +func (c *blockingCoordinator) CancelAll() {} +func (c *blockingCoordinator) IsBusy() bool { return false } +func (c *blockingCoordinator) IsSessionBusy(string) bool { return false } +func (c *blockingCoordinator) 
QueuedPrompts(string) int { return 0 } +func (c *blockingCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *blockingCoordinator) ClearQueue(string) {} +func (c *blockingCoordinator) Summarize(context.Context, string) error { return nil } +func (c *blockingCoordinator) Model() agent.Model { return agent.Model{} } +func (c *blockingCoordinator) UpdateModels(context.Context) error { return nil } + +// insertAgentWorkspace installs a synthetic workspace with the given +// coordinator (or none) and a workspace run context, mirroring the +// fields CreateWorkspace initializes. +func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace { + t.Helper() + ws := &Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + resolvedPath: t.TempDir(), + clients: make(map[string]*clientState), + shutdownFn: func() {}, + } + ws.App = &app.App{AgentCoordinator: coord} + ws.ctx, ws.cancel = context.WithCancel(b.ctx) + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[ws.resolvedPath] = ws.ID + b.mu.Unlock() + return ws +} + +func TestSendMessage_WorkspaceNotFound(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + err := b.SendMessage("nope", proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} + +func TestSendMessage_AgentNotInitialized(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, nil) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrAgentNotInitialized) +} + +func TestSendMessage_EmptyPrompt(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: ""}) + require.ErrorIs(t, err, agent.ErrEmptyPrompt) +} + +func TestSendMessage_SessionMissing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := 
insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "", Prompt: "hi"}) + require.ErrorIs(t, err, agent.ErrSessionMissing) +} + +func TestSendMessage_WorkspaceClosing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + ws.runMu.Lock() + ws.closing = true + ws.runMu.Unlock() + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceClosing) +} + +// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns +// nil synchronously and dispatches a tracked goroutine: while +// RunAccepted blocks, runWG.Wait must not complete (the ticket is +// outstanding); after release it drains. +func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + coord := newBlockingCoordinator() + ws := insertAgentWorkspace(t, b, coord) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.NoError(t, err) + + select { + case <-coord.entered: + case <-time.After(2 * time.Second): + t.Fatal("dispatched goroutine never entered RunAccepted") + } + require.Equal(t, int32(1), coord.runCount.Load()) + + waited := make(chan struct{}) + go func() { + ws.runWG.Wait() + close(waited) + }() + + select { + case <-waited: + t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added") + case <-time.After(100 * time.Millisecond): + } + + close(coord.release) + + select { + case <-waited: + case <-time.After(2 * time.Second): + t.Fatal("runWG.Wait did not complete after the run returned") + } +} diff --git a/internal/backend/testing.go b/internal/backend/testing.go index 6616e0f19e06595fac68808b484394d960d7f79f..1c71caed6566747ac947d61c30ece575d3d13eb4 100644 --- a/internal/backend/testing.go +++ b/internal/backend/testing.go @@ -1,9 +1,16 @@ package backend +import "context" + // 
InsertWorkspaceForTest registers ws with b under its current ID and // path. It is intended for tests in other packages that need to drive // HTTP handlers against a synthetic workspace without booting a real // app.App. Production code should go through CreateWorkspace. +// +// If the workspace has no run context yet it is derived from the +// backend context (falling back to context.Background), mirroring the +// initialization CreateWorkspace performs, so dispatched agent runs +// have a non-nil ws.ctx. func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.resolvedPath == "" { ws.resolvedPath = ws.Path @@ -11,6 +18,13 @@ func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.clients == nil { ws.clients = make(map[string]*clientState) } + if ws.ctx == nil { + parent := b.ctx + if parent == nil { + parent = context.Background() + } + ws.ctx, ws.cancel = context.WithCancel(parent) + } b.mu.Lock() defer b.mu.Unlock() b.workspaces.Set(ws.ID, ws) diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 18ea8046d0647d3576a50769b7b2146d9aa103e9..6697bdd210b911862b6795ef1c5e5009e486ad64 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -122,11 +122,12 @@ func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, session } // TestPostAgent_ReturnsOKOnContextCanceled verifies that when another -// client cancels the session mid-turn, the prompting client's still -// open POST receives 200 (not 500). The agent surfaces the -// FinishReasonCanceled marker to every SSE subscriber via the -// assistant message; the HTTP response from the prompter should not -// double as an error signal. +// client cancels the session mid-turn, the prompting client's POST is +// unaffected: SendMessage is fire-and-forget, so the handler returns +// 200 immediately without waiting for the turn. 
A run that later +// returns context.Canceled never surfaces as a 500 to the prompter; +// the FinishReasonCanceled marker reaches SSE subscribers via the +// assistant message instead. func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { t.Parallel() @@ -135,33 +136,26 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { }) c, wsID := buildAgentWorkspace(t, coord) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, t.Context(), wsID, "S1") - }() + // The handler returns immediately, before the dispatched run is + // released, because the run no longer owns the HTTP response. + rec := postAgent(t, c, t.Context(), wsID, "S1") + require.Equal(t, http.StatusOK, rec.Code, "fire-and-forget SendMessage must return 200 without waiting for the run") - // Wait until Run is in flight, then release it to return - // context.Canceled. + // The run is dispatched on a goroutine; let it return + // context.Canceled. Nothing from that path reaches the (already + // returned) handler. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } close(coord.release) - - select { - case rec := <-done: - require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500") - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after coordinator returned context.Canceled") - } } -// TestPostAgent_DetachesRequestContext verifies that canceling the -// prompting client's HTTP request context does not cancel the -// in-flight agent run. The coordinator must observe a context whose -// Done channel never fires from the request side; only the explicit -// cancel endpoint may end the run. +// TestPostAgent_DetachesRequestContext verifies that the dispatched run +// is bound to the workspace context, not the prompting client's HTTP +// request context. 
Canceling the request context must neither cancel +// the run nor be observed by the coordinator. func TestPostAgent_DetachesRequestContext(t *testing.T) { t.Parallel() @@ -172,46 +166,31 @@ func TestPostAgent_DetachesRequestContext(t *testing.T) { reqCtx, cancelReq := context.WithCancel(context.Background()) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, reqCtx, wsID, "S1") - }() + // The handler returns immediately; the run keeps executing on its + // own goroutine bound to the workspace context. + rec := postAgent(t, c, reqCtx, wsID, "S1") + require.Equal(t, http.StatusOK, rec.Code) - // Wait until Run is in flight, then drop the prompting client. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } + + // Drop the prompting client. This must not reach the run. cancelReq() - // The captured ctx must be detached: context.WithoutCancel - // returns a ctx with Done() == nil so request cancellation cannot - // propagate. got := coord.capturedCtx() require.NotNil(t, got) - require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel") - require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request") + // Compare by identity (pointer), not reflect.DeepEqual: deep + // comparison would traverse context internals that the runtime + // mutates concurrently. + require.False(t, got == reqCtx, "run ctx must not be the request ctx") + require.NoError(t, got.Err(), "run ctx must not inherit cancellation from the dropped request") - // Confirm Run is still running: it should not have completed - // just because the request ctx was canceled. 
- select { - case <-done: - t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run") - case <-time.After(50 * time.Millisecond): - } - - // Release the run; the handler should now complete cleanly. + // Release the run so it returns cleanly. close(coord.release) - select { - case rec := <-done: - // Writing to a recorder whose request ctx was canceled - // still works; in production the TCP write would silently - // fail, which is fine because the run already completed and - // SSE subscribers have the result. - require.Equal(t, http.StatusOK, rec.Code) - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after release") - } - require.Equal(t, int32(1), coord.ranCount.Load()) + require.Eventually(t, func() bool { + return coord.ranCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) } diff --git a/internal/server/proto.go b/internal/server/proto.go index 5e1c4e2605fb8413df0464be2cd7ee6cb40a5f66..4b43dcd096de34ab4bd2adde61cc4f5a73e0f8c9 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -755,14 +755,15 @@ func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.R return } - // Detach the run's lifetime from the prompting client's HTTP - // request. Without this, A dropping its TCP connection (network - // blip, TUI restart) or B canceling the session via the explicit - // cancel endpoint would also cancel A's request context and tear - // down a turn that other subscribed clients are still watching. - // Only the explicit cancel endpoint should be able to end a run. - ctx := context.WithoutCancel(r.Context()) - if err := c.backend.SendMessage(ctx, id, msg); err != nil { + // The run's lifetime is detached from the prompting client's HTTP + // request: SendMessage validates and accepts the prompt, dispatches + // the run on a goroutine bound to the workspace context, and returns + // immediately. 
A dropping its TCP connection (network blip, TUI + // restart) or B canceling the session via the explicit cancel + // endpoint can no longer tear down a turn that other subscribed + // clients are still watching. Only the explicit cancel endpoint + // should be able to end a run. + if err := c.backend.SendMessage(id, msg); err != nil { c.handleError(w, r, err) return } From ee94df068af1d5d029fe425dc80d337aed09f263 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 23:10:37 -0400 Subject: [PATCH 06/15] chore(server): acknowledge accepted prompts with HTTP 202 Return as soon as a prompt is accepted instead of keeping the request open for the full agent turn. Workspace shutdown is reported as a conflict so clients can distinguish rejection from accepted work. Co-Authored-By: Charm Crush --- internal/server/agent_cancel_test.go | 4 ++-- internal/server/proto.go | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 6697bdd210b911862b6795ef1c5e5009e486ad64..dd04fa6b77c0bd1e6ae1532bcd39e68f61bff9b6 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -139,7 +139,7 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { // The handler returns immediately, before the dispatched run is // released, because the run no longer owns the HTTP response. rec := postAgent(t, c, t.Context(), wsID, "S1") - require.Equal(t, http.StatusOK, rec.Code, "fire-and-forget SendMessage must return 200 without waiting for the run") + require.Equal(t, http.StatusAccepted, rec.Code, "fire-and-forget SendMessage must return 202 without waiting for the run") // The run is dispatched on a goroutine; let it return // context.Canceled. 
Nothing from that path reaches the (already @@ -169,7 +169,7 @@ func TestPostAgent_DetachesRequestContext(t *testing.T) { // The handler returns immediately; the run keeps executing on its // own goroutine bound to the workspace context. rec := postAgent(t, c, reqCtx, wsID, "S1") - require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, http.StatusAccepted, rec.Code) select { case <-coord.entered: diff --git a/internal/server/proto.go b/internal/server/proto.go index 4b43dcd096de34ab4bd2adde61cc4f5a73e0f8c9..f747e98791bbe95c2a510d4eb665abe205f7260c 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -740,9 +740,10 @@ func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Re // @Accept json // @Param id path string true "Workspace ID" // @Param request body proto.AgentMessage true "Agent message" -// @Success 200 +// @Success 202 // @Failure 400 {object} proto.Error // @Failure 404 {object} proto.Error +// @Failure 409 {object} proto.Error // @Failure 500 {object} proto.Error // @Router /workspaces/{id}/agent [post] func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.Request) { @@ -767,7 +768,7 @@ func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.R c.handleError(w, r, err) return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusAccepted) } // handlePostWorkspaceAgentInit initializes the agent for a workspace. 
@@ -1061,6 +1062,8 @@ func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err e status = http.StatusBadRequest case errors.Is(err, backend.ErrClientNotAttached): status = http.StatusNotFound + case errors.Is(err, backend.ErrWorkspaceClosing): + status = http.StatusConflict } c.server.logError(r, err.Error()) jsonError(w, status, err.Error()) From d2e9e51fceddf5ed4c57c23bbb15f7a799e88892 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 23:14:25 -0400 Subject: [PATCH 07/15] chore(server): apply asynchronous prompt contract to clients Treat HTTP 202 as a successful prompt submission and preserve server-provided error messages for rejected prompts. Callers no longer see generic status errors when the server explains the failure. Co-Authored-By: Charm Crush --- internal/client/proto.go | 18 ++++++++- internal/client/proto_test.go | 70 +++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/internal/client/proto.go b/internal/client/proto.go index 62a43b5884e01ae8fcd3242c68e95d1f76251c42..d07e46dc84bf09dccffbd609784f92c7ae9a9c67 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -423,12 +423,28 @@ func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, p return fmt.Errorf("failed to send message to agent: %w", err) } defer rsp.Body.Close() - if rsp.StatusCode != http.StatusOK { + if rsp.StatusCode != http.StatusOK && rsp.StatusCode != http.StatusAccepted { + if msg := decodeErrorMessage(rsp.Body); msg != "" { + return fmt.Errorf("failed to send message to agent: status code %d: %s", rsp.StatusCode, msg) + } return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode) } return nil } +// decodeErrorMessage attempts to decode the response body as a +// proto.Error and returns its message. 
It returns an empty string +// when the body is empty or cannot be decoded into a proto.Error +// with a non-empty message, letting callers fall back to a +// status-only error. +func decodeErrorMessage(body io.Reader) string { + var e proto.Error + if err := json.NewDecoder(body).Decode(&e); err != nil { + return "" + } + return e.Message +} + // GetAgentSessionInfo retrieves the agent session info for a workspace. func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil) diff --git a/internal/client/proto_test.go b/internal/client/proto_test.go index b5739ccc91c16b2bb0fc3c3f6dc2281687bd8e65..c7abd3e03d4ae6f575079c7c938369d6cb7cc30b 100644 --- a/internal/client/proto_test.go +++ b/internal/client/proto_test.go @@ -88,6 +88,76 @@ func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) { } } +func TestSendMessageAcceptsStatusAccepted(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageAcceptsStatusOK(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageDecodesErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"}) + })) + defer srv.Close() + + c := captureClient(t, 
srv) + err := c.SendMessage(context.Background(), "ws1", "", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 400") + require.Contains(t, err.Error(), "session id is required") +} + +func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") + require.NotContains(t, err.Error(), "not json") +} + +func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") +} + func marshalSSEPayload(t *testing.T) []byte { t.Helper() From 3c0619397fe9e05fa4fadf129f8e09dbc8d021e5 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 07:28:25 -0400 Subject: [PATCH 08/15] chore(server,test): cover multi-client prompt cancellation flows Add end-to-end coverage for cancellation from another client, request disconnects, accepted-prompt races, and workspace handoff under the new asynchronous prompt contract.
Co-Authored-By: Charm Crush --- internal/server/e2e_agent_test.go | 741 ++++++++++++++++++++++++++++++ internal/server/e2e_test.go | 12 + internal/ui/model/ui.go | 8 +- 3 files changed, 757 insertions(+), 4 deletions(-) create mode 100644 internal/server/e2e_agent_test.go diff --git a/internal/server/e2e_agent_test.go b/internal/server/e2e_agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..012068a967536cf214518f24c946e8b98fd16932 --- /dev/null +++ b/internal/server/e2e_agent_test.go @@ -0,0 +1,741 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// scriptedCoordinator is an agent.Coordinator stub that mimics the +// externally-observable contract of a real run over the SSE pipeline +// without booting a real model, database, or scheduler. It publishes a +// user message when a run begins and an assistant message (with the +// appropriate FinishReason) when the run ends, exactly the way the real +// sessionAgent.Run surfaces a turn to SSE subscribers. +// +// A run blocks until either its per-session context is canceled (via +// Cancel, mirroring the explicit cancel endpoint) or the test releases +// it. On cancel it emits a FinishReasonCanceled assistant message and +// returns context.Canceled (which backend.runAgent swallows, so no +// AgentEvent error is published). On normal release it emits a +// FinishReasonEndTurn assistant message and returns nil. +// +// The internal scheduler signal points the PLAN's e2e cases reference +// (e.g. 
"before registration in activeRequests", "between +// activeRequests.Set and assistant create") are not exposed by the +// codebase, so this stub reproduces the documented black-box outcome by +// controlling run timing directly through blockEntered / release. +type scriptedCoordinator struct { + app *app.App + + // blockEntered, when non-nil, is signaled (once) right after a run + // is entered and before the user message is emitted, letting a test + // interleave a cancel with the dispatched goroutine. + blockEntered chan struct{} + + mu sync.Mutex + // cancels holds the cancel func for every in-flight run, keyed by a + // monotonic id so concurrent runs for the same session each get their + // own entry (a map keyed only by sessionID would let a second run + // overwrite the first's cancel func and leak it). + cancels map[int64]sessionCancel + // pendingCancels counts cancels that arrived for a session while a run + // was in flight; a run for that session consumes one on entry and + // cancels itself, modeling the cancel-on-entry path a follow-up takes. + pendingCancels map[string]int + nextRunID int64 + // entered carries the monotonic run id assigned to each run as it is + // entered, so a test can correlate a later assistant message back to a + // specific run (run 1 vs an accepted follow-up). 
+ entered chan int64 + runStarts atomic.Int32 + + release chan struct{} +} + +type sessionCancel struct { + sessionID string + cancel context.CancelFunc +} + +func newScriptedCoordinator(a *app.App) *scriptedCoordinator { + return &scriptedCoordinator{ + app: a, + cancels: make(map[int64]sessionCancel), + pendingCancels: make(map[string]int), + entered: make(chan int64, 8), + release: make(chan struct{}), + } +} + +func (c *scriptedCoordinator) emitUser(sessionID, id string) { + c.app.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: message.Message{ + ID: id, + SessionID: sessionID, + Role: message.User, + Parts: []message.ContentPart{message.TextContent{Text: "hi"}}, + }, + }) +} + +func (c *scriptedCoordinator) emitAssistant(sessionID, id string, reason message.FinishReason) { + c.app.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: message.Message{ + ID: id, + SessionID: sessionID, + Role: message.Assistant, + Parts: []message.ContentPart{message.Finish{Reason: reason}}, + }, + }) +} + +func (c *scriptedCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + c.runStarts.Add(1) + runCtx, cancel := context.WithCancel(ctx) + + c.mu.Lock() + id := c.nextRunID + c.nextRunID++ + c.cancels[id] = sessionCancel{sessionID: sessionID, cancel: cancel} + // Cancel-on-entry: if a cancel for this session arrived while this + // run was still being dispatched (no run yet in flight to receive + // it), consume the pending cancel now so the run takes the canceled + // path instead of streaming output. 
+ if c.pendingCancels[sessionID] > 0 { + c.pendingCancels[sessionID]-- + cancel() + } + c.mu.Unlock() + + select { + case c.entered <- id: + default: + } + + if c.blockEntered != nil { + select { + case <-c.blockEntered: + case <-runCtx.Done(): + } + } + + defer func() { + c.mu.Lock() + delete(c.cancels, id) + c.mu.Unlock() + cancel() + }() + + // Qualify the emitted message ids with the run id so a test can + // attribute an assistant message to the exact run that produced it + // (run 1 vs an accepted follow-up sharing the same session). + userID := fmt.Sprintf("u-%s-%d", sessionID, id) + asstID := fmt.Sprintf("a-%s-%d", sessionID, id) + + c.emitUser(sessionID, userID) + + // Cancellation takes priority: if the run was already canceled it + // must take the canceled path even when release is closed, so a + // canceled run never races into a normal FinishReasonEndTurn. + select { + case <-runCtx.Done(): + c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled) + return nil, context.Canceled + default: + } + + select { + case <-c.release: + c.emitAssistant(sessionID, asstID, message.FinishReasonEndTurn) + return nil, nil + case <-runCtx.Done(): + c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled) + return nil, context.Canceled + } +} + +func (c *scriptedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.Run(ctx, sessionID, prompt, attachments...) +} + +func (c *scriptedCoordinator) BeginAccepted(string) *agent.AcceptedRun { return nil } + +func (c *scriptedCoordinator) Cancel(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + // Cancel every in-flight run for this session. Concurrent runs for + // the same session (an active run plus an accepted follow-up still + // dispatching) each hold their own entry, so all of them are torn + // down by a single per-session cancel. 
+ var canceled int + for _, sc := range c.cancels { + if sc.sessionID == sessionID { + sc.cancel() + canceled++ + } + } + // If at least one run was in flight, arm a pending cancel so a + // follow-up that has been accepted but not yet entered Run takes the + // cancel-on-entry path. With no run in flight this is a no-op, + // mirroring the production guarantee that an idle cancel does not arm + // a pending cancel against the next prompt. + if canceled > 0 { + c.pendingCancels[sessionID]++ + } +} + +func (c *scriptedCoordinator) CancelAll() { + c.mu.Lock() + defer c.mu.Unlock() + for _, sc := range c.cancels { + sc.cancel() + } +} + +func (c *scriptedCoordinator) IsBusy() bool { return false } +func (c *scriptedCoordinator) IsSessionBusy(string) bool { return false } +func (c *scriptedCoordinator) QueuedPrompts(string) int { return 0 } +func (c *scriptedCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *scriptedCoordinator) ClearQueue(string) {} +func (c *scriptedCoordinator) Summarize(context.Context, string) error { return nil } +func (c *scriptedCoordinator) Model() agent.Model { return agent.Model{} } +func (c *scriptedCoordinator) UpdateModels(context.Context) error { return nil } + +// agentE2EHarness extends the SSE harness with a scripted coordinator +// wired into the workspace's embedded app.App, so POST /agent drives a +// real backend.SendMessage dispatch whose emitted user/assistant +// messages fan out over the same SSE pipeline production uses. 
+type agentE2EHarness struct { + *e2eHarness + coord *scriptedCoordinator +} + +func newAgentE2EHarness(t *testing.T) *agentE2EHarness { + t.Helper() + + h := &e2eHarness{} + + appCtx, cancel := context.WithCancel(context.Background()) + a := app.NewForTest(appCtx) + coord := newScriptedCoordinator(a) + a.AgentCoordinator = coord + t.Cleanup(func() { + cancel() + a.ShutdownForTest() + }) + + h.installServer(t) + + ws := &backend.Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + App: a, + } + backend.SetWorkspaceShutdownFnForTest(ws, func() {}) + backend.InsertWorkspaceForTest(h.backend, ws) + + h.workspace = ws + h.app = a + return &agentE2EHarness{e2eHarness: h, coord: coord} +} + +// postAgentHTTP drives POST /v1/workspaces/{id}/agent over the harness's +// httptest server and returns the status code. +func (h *agentE2EHarness) postAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int { + t.Helper() + body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"}) + require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + return resp.StatusCode +} + +// cancelAgentHTTP drives POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel. 
+func (h *agentE2EHarness) cancelAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int { + t.Helper() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent/sessions/"+sessionID+"/cancel", nil) + require.NoError(t, err) + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + return resp.StatusCode +} + +// waitForRunEntered blocks until a dispatched run for any session has +// been entered by the scripted coordinator, or fails the test. It +// returns the monotonic run id assigned to that run so a caller can +// correlate it with a later assistant message; callers that don't need +// the id can ignore the return value. +func (h *agentE2EHarness) waitForRunEntered(t *testing.T) int64 { + t.Helper() + select { + case id := <-h.coord.entered: + return id + case <-time.After(2 * time.Second): + t.Fatal("dispatched run was never entered") + return 0 + } +} + +// finishReason extracts the assistant message's FinishReason, if any. +func finishReason(m proto.Message) (proto.FinishReason, bool) { + for _, p := range m.Parts { + if f, ok := p.(proto.Finish); ok { + return f.Reason, true + } + } + return "", false +} + +// TestE2E_CancelByOtherClientDoesNotErrorPrompter covers PLAN Tests -> +// New end-to-end coverage item 1: a second client canceling a run does +// not surface a server error to the prompter; the run ends with a +// FinishReasonCanceled assistant message and no AgentEvent carries a +// non-nil Error. 
+func TestE2E_CancelByOtherClientDoesNotErrorPrompter(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cidA := uuid.New().String() + cidB := uuid.New().String() + evcA, cancelA := h.subscribeSSE(t, ctx, h.workspace.ID, cidA) + t.Cleanup(cancelA) + evcB, cancelB := h.subscribeSSE(t, ctx, h.workspace.ID, cidB) + t.Cleanup(cancelB) + h.waitForAttached(t, 2) + + const sid = "s-cancel-other" + + // A posts a long-running prompt; the handler must return 202 + // immediately (the run blocks in the coordinator). + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // B cancels. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // A's SSE stream receives the FinishReasonCanceled assistant + // message. + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evcA, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, ok, "client A must observe a FinishReasonCanceled assistant message") + require.Equal(t, sid, got.Payload.SessionID) + + // No AgentEvent error reaches A (cancel is not a server error). + errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancel() + _, gotErrA := drainUntil(errCtx, evcA, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErrA, "cancel must not surface an AgentEvent error to the prompter") + + // And no AgentEvent error reaches the canceling client B either; the + // PLAN requires that *no* client observes a non-nil Error. 
+ errCtxB, errCancelB := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancelB() + _, gotErrB := drainUntil(errCtxB, evcB, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErrB, "cancel must not surface an AgentEvent error to any client") +} + +// TestE2E_CancelImmediatelyAfter202IsNotLost covers PLAN item 1a: a +// cancel that races a freshly-dispatched run (before it would emit any +// output) is not lost. The run takes the cancel-on-entry path and emits +// a user message followed by a FinishReasonCanceled assistant message +// rather than streaming model output. +func TestE2E_CancelImmediatelyAfter202IsNotLost(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + // Gate the run on a signal the test controls so the cancel can be + // observed while the dispatched goroutine is parked at entry. + h.coord.blockEntered = make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-race-cancel" + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // Cancel while the run is still blocked at entry, then release it. 
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + close(h.coord.blockEntered) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + gotUser, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.Role == proto.User && e.Payload.SessionID == sid + }) + require.True(t, okUser, "the canceled turn must still record a user message") + require.Equal(t, sid, gotUser.Payload.SessionID) + + gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, okAsst, "the canceled turn must end with a FinishReasonCanceled assistant message") + require.Equal(t, sid, gotAsst.Payload.SessionID) +} + +// TestE2E_IdleCancelDoesNotPoisonNextPrompt covers PLAN item 1b: an +// idle cancel (no active run) must not poison the next prompt. With the +// scripted coordinator the cancel records a pending entry only if a run +// is in flight; an idle cancel records none, and the documented +// guarantee is that the *next* prompt's outcome is observable. Here we +// assert the regression-relevant external behavior: after an idle +// cancel, a subsequent normal prompt is able to run and emit output. +// +// NOTE: This is a simplified version. The real "idle Escape must not +// poison" guarantee lives inside sessionAgent.Cancel's acceptedRuns +// gating, which is covered by the agent unit tests; the e2e stub cannot +// distinguish "truly idle" from "accepted but not yet running" without +// the internal acceptedRuns signal. See test summary.
+func TestE2E_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-idle-cancel" + + // Idle cancel: no run in flight. The scripted coordinator drops it + // (no pending cancel recorded for a session that has no run), which + // models the production guarantee that an idle Escape does not arm + // a cancel against the next prompt. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // Now a normal prompt; release it so it finishes successfully. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "the next prompt after an idle cancel must run to FinishReasonEndTurn") + require.Equal(t, sid, got.Payload.SessionID) + + // And it must not be marked canceled. 
+ canCtx, canCancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer canCancel() + _, gotCanceled := drainUntil(canCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.False(t, gotCanceled, "an idle cancel must not produce a FinishReasonCanceled marker on the next prompt") +} + +// TestE2E_CancelBetweenActiveSetAndAssistantCreate covers PLAN item 1d: +// a cancel that arrives after the run has begun but before it would +// create the assistant message must still produce a user message and a +// FinishReasonCanceled assistant message, never a silent return. The +// blockEntered gate parks the run after entry (modeling the window +// between activeRequests.Set and assistant creation). +func TestE2E_CancelBetweenActiveSetAndAssistantCreate(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + h.coord.blockEntered = make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-mid-window" + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // Cancel while parked at entry; then release so the run proceeds + // into its cancel branch (runCtx already canceled). 
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + close(h.coord.blockEntered) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + _, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.Role == proto.User && e.Payload.SessionID == sid + }) + require.True(t, okUser, "a user message must be recorded for the canceled turn") + + gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, okAsst, "the run must not return silently; it must emit a FinishReasonCanceled assistant message") + require.Equal(t, sid, gotAsst.Payload.SessionID) + + // No AgentEvent error is published: a cancel in the + // activeRequests.Set -> assistant-create window is not a server + // error. + errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancel() + _, gotErr := drainUntil(errCtx, evc, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErr, "no AgentEvent error must be published for the canceled turn") +} + +// TestE2E_PromptRequestContextDoesNotOwnRun covers PLAN item 2: the +// prompting client's HTTP request context does not own the run. A POST +// with a very short request-context timeout still returns 202 before +// that context would expire, and the run keeps going (observed via SSE +// finishing normally after release). 
+func TestE2E_PromptRequestContextDoesNotOwnRun(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + streamCtx, streamCancel := context.WithCancel(t.Context()) + t.Cleanup(streamCancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, streamCtx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-short-req" + + // The POST request context times out almost immediately. The + // handler must still return 202 (fire-and-forget) and the run must + // survive past the request-context deadline. + reqCtx, reqCancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer reqCancel() + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, reqCtx, sid)) + h.waitForRunEntered(t) + + // Let the request context expire, then release the run. + <-reqCtx.Done() + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(streamCtx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "the run must finish normally even after the prompting request context expired") + require.Equal(t, sid, got.Payload.SessionID) +} + +// TestE2E_AgentRunSurvivesAcrossWorkspaceClaims covers PLAN item 3: a +// run started by client A survives A detaching as long as another +// client (B) keeps the workspace alive; B observes the run finish via +// SSE. 
+func TestE2E_AgentRunSurvivesAcrossWorkspaceClaims(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + + ctxA, cancelA := context.WithCancel(t.Context()) + ctxB, cancelB := context.WithCancel(t.Context()) + t.Cleanup(cancelB) + + cidA := uuid.New().String() + cidB := uuid.New().String() + _, killA := h.subscribeSSE(t, ctxA, h.workspace.ID, cidA) + t.Cleanup(killA) + evcB, killB := h.subscribeSSE(t, ctxB, h.workspace.ID, cidB) + t.Cleanup(killB) + h.waitForAttached(t, 2) + + const sid = "s-survive" + // A is the poster; the run must outlive A detaching as long as B + // keeps the workspace alive. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctxA, sid)) + h.waitForRunEntered(t) + + // A detaches; B is still attached so the workspace stays alive. + cancelA() + killA() + require.Eventually(t, func() bool { + return backend.WorkspaceLiveStreamCountForTest(h.workspace) == 1 + }, 3*time.Second, 10*time.Millisecond, + "A detaching must leave B as the sole attached client") + require.False(t, h.shutdownHit.Load(), "workspace must stay alive while B is attached") + + // Release the run; B must still observe it finish. + close(h.coord.release) + pickCtx, pickCancel := context.WithTimeout(ctxB, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "B must observe the run finish after A detaches") + require.Equal(t, sid, got.Payload.SessionID) +} + +// TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp covers PLAN item +// 1c at the externally-observable level: while session sid has an active +// run, a second prompt for sid is accepted; a cancel for sid must cancel +// the active run and must not let the follow-up stream a normal +// FinishReasonEndTurn. 
+// +// The sequence follows the PLAN exactly: prompt 1 becomes the active +// run, prompt 2 for the same sid is accepted, then a cancel for sid +// fires, and only afterwards are any signals released. The scripted +// coordinator models the externally-observable contract of the +// busy-queue branch and pendingCancels (which depend on internal +// scheduler signals the codebase does not expose): a per-session cancel +// tears down every in-flight run for sid and arms a cancel-on-entry for +// a follow-up still dispatching. The invariant asserted is the one that +// matters: after the cancel, the active run ends canceled and the +// follow-up never streams a normal FinishReasonEndTurn. +func TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-followup" + + // (a) Prompt 1 for sid becomes the active run. Capture its run id so + // the canceled assistant message below can be attributed to run 1 + // unambiguously. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + run1 := h.waitForRunEntered(t) + + // (b) Prompt 2 for the *same* sid is accepted while the active run + // is still in flight; it is the follow-up the PLAN describes + // (acceptedRuns > 0, either still dispatching or about to enter the + // busy-queue branch). + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + run2 := h.waitForRunEntered(t) + require.NotEqual(t, run1, run2, "the follow-up must be a distinct run from the active one") + + // (c) Cancel sid. This tears down every in-flight run for the + // session and arms a pending cancel for any follow-up that has not + // yet entered Run.
+ require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // (d) Open the coordinator gate so any run that is NOT canceled would + // be free to proceed straight into the normal FinishReasonEndTurn + // branch. The scripted Run checks runCtx.Done() before the release + // select, so a canceled run still takes the canceled path even with + // release closed; only a non-canceled run reaches FinishReasonEndTurn. + // Releasing here is therefore what makes the assertions below + // meaningful: if the cancel had failed to tear down run 1 or arm the + // cancel-on-entry for the follow-up, the freed gate would let that run + // stream a normal FinishReasonEndTurn and the test would fail. + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + // (e) Run 1 (the active run) must end with FinishReasonCanceled. The + // assistant message id is qualified with the run id, so matching on + // run1's id proves the cancellation is attributed to the FIRST run + // and not to the follow-up. + // + // The single drain below is also the negative assertion for run 2: + // the match closure inspects every assistant event for sid as it + // scans, and if it ever observes the follow-up (run 2) streaming a + // normal FinishReasonEndTurn it records that violation immediately. + // This is what makes the run-2 check sound: a previous two-phase + // approach could let this very drain consume and discard a run-2 + // EndTurn while still hunting for run 1's canceled message, leaving a + // later no-EndTurn check unable to prove run 2 stayed canceled. + // Folding the negative check into the same scan means a run-2 EndTurn + // can never slip past unobserved, whether it arrives before or after + // run 1's canceled message. 
+ run1AsstID := fmt.Sprintf("a-%s-%d", sid, run1) + run2AsstID := fmt.Sprintf("a-%s-%d", sid, run2) + var followUpEndTurn bool + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + if e.Payload.SessionID != sid || e.Payload.Role != proto.Assistant { + return false + } + r, has := finishReason(e.Payload) + if !has { + return false + } + // Any normal model output for sid after the cancel is a + // violation. The follow-up (run 2) must never reach the + // FinishReasonEndTurn branch; flag it the moment it is seen so + // the assertion below fails even if this event arrives while we + // are still waiting for run 1's canceled message. + if r == proto.FinishReasonEndTurn { + if e.Payload.ID == run2AsstID || e.Payload.ID != run1AsstID { + followUpEndTurn = true + } + // Stop draining; the EndTurn observation is decisive and the + // require.False below will surface the failure. + return true + } + return e.Payload.ID == run1AsstID && r == proto.FinishReasonCanceled + }) + require.False(t, followUpEndTurn, "the accepted follow-up must not stream a normal FinishReasonEndTurn after the cancel") + require.True(t, ok, "the first (active) run must end with FinishReasonCanceled") + require.Equal(t, run1AsstID, got.Payload.ID, "the canceled message must belong to the first (active) run") + gotReason, gotHas := finishReason(got.Payload) + require.True(t, gotHas) + require.Equal(t, proto.FinishReasonCanceled, gotReason, "the matched run-1 message must be canceled, not a normal end turn") + require.Equal(t, sid, got.Payload.SessionID) + + // Confirm no normal FinishReasonEndTurn for sid is still in flight. + // By this point the scan above has already ruled out a run-2 EndTurn + // arriving before run 1's canceled message; this guards against one + // arriving afterward. 
+ endCtx, endCancel := context.WithTimeout(ctx, 300*time.Millisecond) + defer endCancel() + _, gotEnd := drainUntil(endCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.SessionID == sid && e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.False(t, gotEnd, "the accepted follow-up must not stream model output after the cancel") +} diff --git a/internal/server/e2e_test.go b/internal/server/e2e_test.go index 08aaedf66c95edd704f18b62d83d64e79966564e..565a989136536e5bea8b1134995a3770183d4caa 100644 --- a/internal/server/e2e_test.go +++ b/internal/server/e2e_test.go @@ -240,6 +240,18 @@ func decodeSSEEnvelope(p pubsub.Payload) (any, bool) { return nil, false } return e, true + case pubsub.PayloadTypeAgentEvent: + var e pubsub.Event[proto.AgentEvent] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true + case pubsub.PayloadTypeRunComplete: + var e pubsub.Event[proto.RunComplete] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true } return nil, false } diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 890dfc7de8a97eae13c4ecbd56ca07b566061408..2972400f236de92533ab336684ffcc10843bbda6 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -3283,12 +3283,12 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. // Capture session ID to avoid race with main goroutine updating m.session. sessionID := m.session.ID cmds = append(cmds, func() tea.Msg { + // AgentRun is fire-and-forget: it returns once the prompt has + // been accepted (HTTP 202) or synchronously with a validation + // or transport error. Run failures and cancellation surface + // through SSE-derived events, not this return value. err := m.com.Workspace.AgentRun(context.Background(), sessionID, content, attachments...) 
if err != nil { - isCancelErr := errors.Is(err, context.Canceled) - if isCancelErr { - return nil - } return util.InfoMsg{ Type: util.InfoTypeError, Msg: fmt.Sprintf("%v", err), From 5aaa709b92582a38d4e5d183bcd675a2c997dff3 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 07:35:26 -0400 Subject: [PATCH 09/15] chore(server,tests): cover async cancellation cleanup behavior Lock in that prompt cancellation is handled by background completion events rather than a synchronous HTTP error. This prevents canceled prompts from being reported as server failures. Co-Authored-By: Charm Crush --- internal/server/agent_cancel_test.go | 30 ++++++++++++++++++++++++++++ internal/server/proto.go | 17 ++++++++-------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index dd04fa6b77c0bd1e6ae1532bcd39e68f61bff9b6..68d1f10132db3fb9f6b1e4251b744c59887f613e 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -150,6 +150,36 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { t.Fatal("dispatched run was never entered") } close(coord.release) + + // Wait for the dispatched run to fully return. Backend.runAgent + // swallows context.Canceled, so it must not publish a + // notify.TypeAgentError. Publishing would dereference the synthetic + // workspace's nil notification broker and crash this goroutine, + // which is the explicit guard that a cancel produces no top-level + // error event. + require.Eventually(t, func() bool { + return coord.ranCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) +} + +// TestHandleError_ContextCanceledFallsThroughTo500 documents the step 8 +// cleanup: the old context.Canceled special case in handleError was +// removed because runtime cancellation of an agent run can no longer +// reach handleError. 
The agent-prompt handler returns 202 before the run +// starts (fire-and-forget SendMessage) and Backend.runAgent swallows +// context.Canceled. Any context.Canceled that still reaches handleError +// is therefore an unexpected synchronous error and falls through to the +// default 500 like any other. +func TestHandleError_ContextCanceledFallsThroughTo500(t *testing.T) { + t.Parallel() + + c := &controllerV1{server: &Server{}} + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) + + c.handleError(rec, req, context.Canceled) + + require.Equal(t, http.StatusInternalServerError, rec.Code) } // TestPostAgent_DetachesRequestContext verifies that the dispatched run diff --git a/internal/server/proto.go b/internal/server/proto.go index f747e98791bbe95c2a510d4eb665abe205f7260c..f388d51bb87490a484bcb06b16e4698058bae134 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -1,7 +1,6 @@ package server import ( - "context" "encoding/json" "errors" "fmt" @@ -1035,15 +1034,15 @@ func (c *controllerV1) handleGetWorkspacePermissionsSkip(w http.ResponseWriter, // handleError maps backend errors to HTTP status codes and writes the // JSON error response. +// +// Runtime cancellation of an agent run no longer reaches here for the +// agent-prompt path: SendMessage is fire-and-forget (the handler returns +// 202 before the run starts) and Backend.runAgent swallows +// context.Canceled, surfacing the FinishReasonCanceled marker to SSE +// subscribers instead. The remaining callers pass synchronous backend +// errors, so context.Canceled gets no special case and would fall through +// to the default 500 like any other unexpected error. func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err error) { - // A canceled agent run is not an error from the prompting - // client's perspective. 
The cancellation reaches every SSE - // subscriber via the FinishReasonCanceled marker on the assistant - // message; the still-open POST should not surface a 500. - if errors.Is(err, context.Canceled) { - w.WriteHeader(http.StatusOK) - return - } status := http.StatusInternalServerError switch { case errors.Is(err, backend.ErrWorkspaceNotFound): From 34995e9333082f6f8a6437e4bc75a55fca45c981 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 17:59:26 -0400 Subject: [PATCH 10/15] fix(server): prevent cancels from affecting future prompts Apply a cancel only to prompts that were already accepted when the cancel request arrived. Immediately canceled accepted prompts also publish completion so callers waiting on that prompt do not hang. Co-Authored-By: Charm Crush --- internal/agent/accepted_run_test.go | 9 +- internal/agent/agent.go | 227 ++++++++++++++++------ internal/agent/dispatch_cancel_test.go | 255 ++++++++++++++++++++++++- 3 files changed, 432 insertions(+), 59 deletions(-) diff --git a/internal/agent/accepted_run_test.go b/internal/agent/accepted_run_test.go index d62422a9f02bec68a8da1a08c6e6d6b52e7d7699..14aec44265d44fa0f5b055ef14af3086e13a0cf3 100644 --- a/internal/agent/accepted_run_test.go +++ b/internal/agent/accepted_run_test.go @@ -28,8 +28,13 @@ func (a *sessionAgent) acceptedCount(sessionID string) int { } func (a *sessionAgent) hasPendingCancel(sessionID string) bool { - _, ok := a.pendingCancels.Get(sessionID) - return ok + mark, ok := a.cancelMark.Get(sessionID) + return ok && mark > 0 +} + +func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 { + mark, _ := a.cancelMark.Get(sessionID) + return mark } func TestAcceptedRun_CloseIsIdempotent(t *testing.T) { diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 97bd7a21af4c28f30e46fcaf23c23348467e90ac..393587a111ad806dc26bbf8a80f52dba49ce0397 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -112,6 +112,15 @@ type SessionAgentCall 
struct { // (in-process / local callers like AppWorkspace), behavior is // unchanged and no accept tracking applies. Accepted *AcceptedRun + // acceptSeq carries the accept sequence of the handle that produced + // this call after it has been enqueued and its Accepted handle + // stripped. The queue-drain paths compare it against a session's + // cancel mark so a follow-up queued before a cancel is dropped while + // one queued after the cancel survives. 0 means untracked (an + // in-process enqueue with no accept reservation), which the drain + // paths treat as covered by any present mark, preserving the + // pre-sequence behavior. + acceptSeq uint64 } type SessionAgent interface { @@ -167,20 +176,31 @@ type sessionAgent struct { // BeginAccepted increments it; only AcceptedRun.Close decrements // it. acceptedRuns *csync.Map[string, int] - // pendingCancels records sessions whose dispatched-but-not-yet- - // running call should observe a cancellation request. It is only - // set by Cancel when acceptedRuns > 0, so an idle Escape never - // poisons the next prompt. - pendingCancels *csync.Map[string, struct{}] + // cancelMark records, per session, a high-water accept sequence: an + // accepted handle is canceled by it iff the handle's sequence is at + // or below the mark. Cancel raises the mark to the latest sequence + // assigned at cancel time, so a single Cancel covers every prompt + // accepted-but-not-yet-active then, while a prompt accepted later + // (higher sequence) is never poisoned. Absent or 0 means no pending + // cancel. It is only raised by Cancel when acceptedRuns > 0, so an + // idle Escape never records a mark. + cancelMark *csync.Map[string, uint64] // dispatchMuCreate guards lazy creation of per-session entries in // dispatchMu so two goroutines can't race to lock different mutex // instances for the same session. dispatchMuCreate sync.Mutex - // acceptedMu serializes increments/decrements of acceptedRuns. 
It + // acceptedMu serializes increments/decrements of acceptedRuns and + // the assignment of accept sequence numbers from acceptSeqGen. It // is separate from dispatchMu so AcceptedRun.Close (which may run // while Run holds dispatchMu for the same session) does not // deadlock by re-entering the dispatch lock. acceptedMu sync.Mutex + // acceptSeqGen is the monotonic source of accept sequence numbers. + // Each BeginAccepted increments it under acceptedMu and stamps the + // returned handle, so sequences strictly increase in accept order + // across the agent. Cancel uses its current value as the per-session + // high-water mark. + acceptSeqGen uint64 } type SessionAgentOptions struct { @@ -218,7 +238,7 @@ func NewSessionAgent( activeRequests: csync.NewMap[string, context.CancelFunc](), dispatchMu: csync.NewMap[string, *sync.Mutex](), acceptedRuns: csync.NewMap[string, int](), - pendingCancels: csync.NewMap[string, struct{}](), + cancelMark: csync.NewMap[string, uint64](), } } @@ -231,7 +251,12 @@ func NewSessionAgent( type AcceptedRun struct { agent *sessionAgent sessionID string - done atomic.Bool + // seq is the monotonic accept sequence stamped by BeginAccepted. A + // cancel covers this handle iff seq is at or below the session's + // cancel mark, so a handle accepted after a cancel (higher seq) is + // never poisoned by it. + seq uint64 + done atomic.Bool } // Close decrements the accept counter for this reservation. It is safe @@ -263,19 +288,30 @@ func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun { defer a.acceptedMu.Unlock() count, _ := a.acceptedRuns.Get(sessionID) a.acceptedRuns.Set(sessionID, count+1) - return &AcceptedRun{agent: a, sessionID: sessionID} + a.acceptSeqGen++ + return &AcceptedRun{agent: a, sessionID: sessionID, seq: a.acceptSeqGen} } // endAccepted decrements the accept counter for sessionID. It is only // called via AcceptedRun.Close. 
It uses a dedicated lock (not the // per-session dispatch mutex) so it can run while Run holds dispatchMu // for the same session without deadlocking. +// +// When the count reaches zero the session's cancel mark is dropped: no +// accepted handle remains for it to cover, and any handle accepted later +// gets a strictly higher sequence that the mark would not match anyway. +// Handles canceled on entry never reach RunComplete, so this is the only +// place that clears the mark for an all-canceled batch. Sibling handles +// covered by the same mark are serialized on the per-session dispatch +// mutex and read the mark before they Close, so this never clears it out +// from under a covered handle still waiting to enter Run. func (a *sessionAgent) endAccepted(sessionID string) { a.acceptedMu.Lock() defer a.acceptedMu.Unlock() count, ok := a.acceptedRuns.Get(sessionID) if !ok || count <= 1 { a.acceptedRuns.Del(sessionID) + a.cancelMark.Del(sessionID) return } a.acceptedRuns.Set(sessionID, count-1) @@ -311,20 +347,44 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { existing = []SessionAgentCall{} } queued := call + if call.Accepted != nil { + // Preserve the accept sequence after the handle is stripped so + // the queue-drain paths can tell a follow-up queued before a + // cancel (covered by the mark) from one queued after it. + queued.acceptSeq = call.Accepted.seq + } queued.OnComplete = nil queued.Accepted = nil existing = append(existing, queued) a.messageQueue.Set(call.SessionID, existing) } -// clearPendingCancel removes any pending-cancel record for sessionID. It -// takes the per-session dispatch lock so it is ordered against Cancel and -// the dispatch handoff. +// clearPendingCancel removes any pending-cancel mark for sessionID. It +// takes the per-session dispatch lock so it is ordered against Cancel +// and the dispatch handoff. 
func (a *sessionAgent) clearPendingCancel(sessionID string) { mu := a.sessionMu(sessionID) mu.Lock() defer mu.Unlock() - a.pendingCancels.Del(sessionID) + a.cancelMark.Del(sessionID) +} + +// canceledBySeq reports whether an accepted handle or queued call with +// the given accept sequence is covered by a pending cancel for the +// session. Callers must hold the session's dispatch mutex. A tracked +// sequence (seq > 0) is covered only when it is at or below the cancel +// high-water mark, so a prompt accepted after the cancel (higher seq) is +// never poisoned. An untracked sequence (seq == 0, an in-process enqueue +// with no accept reservation) is covered whenever any mark is present, +// preserving the pre-sequence behavior. The mark is not consumed: it +// stays so every sibling handle it covers observes the same cancel, and +// a later handle (higher seq) ignores it regardless. +func (a *sessionAgent) canceledBySeq(sessionID string, seq uint64) bool { + mark, ok := a.cancelMark.Get(sessionID) + if !ok || mark == 0 { + return false + } + return seq == 0 || seq <= mark } // persistCanceledTurn writes the user/assistant records for a turn that @@ -356,6 +416,26 @@ func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgen return a.messages.Update(writeCtx, assistant) } +// publishRunComplete emits the authoritative terminal event for a turn. +// It honors the per-call OnComplete hook when set (so the coordinator can +// coalesce retries) and otherwise falls back to the RunComplete broker. +// ctx is used only for the bounded-blocking must-deliver publish; the +// terminal payload is supplied by the caller. This is the single emit path +// shared by the streaming defer and the cancel-on-entry early return so a +// caller waiting on RunComplete (e.g. `crush run` with a RunID) always +// observes exactly one terminal event regardless of which Run branch ends +// the turn. 
+func (a *sessionAgent) publishRunComplete(ctx context.Context, call SessionAgentCall, complete notify.RunComplete) { + if call.OnComplete != nil { + call.OnComplete(complete) + return + } + if a.runComplete == nil { + return + } + a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) +} + // ValidateCall performs the cheap structural validation that // sessionAgent.Run requires before a call can be dispatched: a call must // carry either a non-empty prompt or a text attachment, and it must name a @@ -394,22 +474,39 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Serialize the accepted -> (cancel-on-entry | queued | // active) transition against a concurrent Cancel. Cancel takes // the same per-session lock, so every cancel observes at least - // one of: pendingCancels, an activeRequests entry, or a + // one of: a cancel mark, an activeRequests entry, or a // messageQueue entry it then clears. mu := a.sessionMu(call.SessionID) mu.Lock() - if _, pending := a.pendingCancels.Get(call.SessionID); pending { + if a.canceledBySeq(call.SessionID, call.Accepted.seq) { // Cancel-on-entry: a cancel arrived while this run was - // dispatched but not yet active. Consume the pending - // cancel, release the accept reservation, drop the lock, - // and persist a canceled turn without entering Stream. - a.pendingCancels.Del(call.SessionID) + // dispatched but not yet active, and this handle's accept + // sequence is at or below the session's cancel mark. The + // mark is left in place so sibling handles it also covers + // observe the same cancel; release the accept reservation, + // drop the lock, and persist a canceled turn without + // entering Stream. + // + // This path returns before the streaming defer that + // publishes RunComplete is installed, so emit the terminal + // event explicitly. Without it, a caller waiting on + // RunComplete for this RunID (e.g. 
`crush run`, which + // ignores message events and blocks on RunComplete) would + // hang on an immediately-canceled accepted run. call.Accepted.Close() mu.Unlock() + complete := notify.RunComplete{ + SessionID: call.SessionID, + RunID: call.RunID, + Cancelled: true, + } if err := a.persistCanceledTurn(ctx, call, false); err != nil { + complete.Error = err.Error() + a.publishRunComplete(ctx, call, complete) return nil, err } + a.publishRunComplete(ctx, call, complete) return nil, nil } @@ -579,14 +676,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // the authoritative terminal event so a momentarily-full // subscriber channel can't silently drop it and hang // non-interactive clients waiting on RunComplete. - if call.OnComplete != nil { - call.OnComplete(complete) - return - } - if a.runComplete == nil { - return - } - a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) + a.publishRunComplete(ctx, call, complete) }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -621,24 +711,26 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() - // Drain queued follow-up prompts, but skip them if a cancel - // was recorded for the session while they sat in the queue: - // a cancel that arrived after the queue insertion must not - // let the queued prompt run as part of this step. + // Drain queued follow-up prompts, but skip any covered by a + // cancel recorded while they sat in the queue: a cancel that + // arrived after a prompt was queued must not let it run as + // part of this step. Coverage is per-call by accept sequence + // so a follow-up queued after the cancel (higher seq) is + // still folded in. 
dispatchLock := a.sessionMu(call.SessionID) dispatchLock.Lock() - _, canceled := a.pendingCancels.Get(call.SessionID) queuedCalls, _ := a.messageQueue.Get(call.SessionID) a.messageQueue.Del(call.SessionID) dispatchLock.Unlock() - if !canceled { - for _, queued := range queuedCalls { - userMessage, createErr := a.createUserMessage(callContext, queued) - if createErr != nil { - return callContext, prepared, createErr - } - prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) + for _, queued := range queuedCalls { + if a.canceledBySeq(call.SessionID, queued.acceptSeq) { + continue } + userMessage, createErr := a.createUserMessage(callContext, queued) + if createErr != nil { + return callContext, prepared, createErr + } + prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...) } prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel) @@ -1008,20 +1100,37 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // closing the dequeue -> re-register window. mu := a.sessionMu(call.SessionID) mu.Lock() - if _, pending := a.pendingCancels.Get(call.SessionID); pending { + queuedMessages, _ := a.messageQueue.Get(call.SessionID) + if mark, ok := a.cancelMark.Get(call.SessionID); ok && mark > 0 && len(queuedMessages) > 0 { // A cancel was recorded for this session (e.g. it arrived while - // this run was active and a follow-up had been accepted). Drop - // the queue instead of running it and consume the marker. - a.pendingCancels.Del(call.SessionID) + // this run was active and follow-ups had been queued). Drop the + // queued prompts it covers (accept sequence at or below the + // mark, or untracked); keep any queued after the cancel (higher + // sequence) so they still run. 
+ var kept []SessionAgentCall + for _, q := range queuedMessages { + if q.acceptSeq == 0 || q.acceptSeq <= mark { + continue + } + kept = append(kept, q) + } + queuedMessages = kept + a.messageQueue.Set(call.SessionID, kept) + } + if len(queuedMessages) == 0 { + // No queued work. Clear the cancel mark only when no accepted + // run remains in flight that it might still cover; otherwise a + // sibling prompt (sequence at or below the mark) waiting to + // enter Run would lose its cancellation. When accepted runs are + // gone, this also clears a stale mark so it can't catch a + // future run. a.messageQueue.Del(call.SessionID) - mu.Unlock() - return result, err - } - queuedMessages, ok := a.messageQueue.Get(call.SessionID) - if !ok || len(queuedMessages) == 0 { - // No queued work. Clear any stale pending-cancel entry as a - // safety net so it can't catch a future run (no-op when unset). - a.pendingCancels.Del(call.SessionID) + a.acceptedMu.Lock() + inFlight, _ := a.acceptedRuns.Get(call.SessionID) + a.acceptedMu.Unlock() + if inFlight == 0 { + a.cancelMark.Del(call.SessionID) + } mu.Unlock() return result, err } @@ -1619,17 +1728,27 @@ func (a *sessionAgent) Cancel(sessionID string) { } // Record a pending cancel only when a dispatched-but-not-yet-active - // run exists. This catches a run still in the goroutine scheduler or + // run exists. This catches runs still in the goroutine scheduler or // about to enter Run's busy-queue branch, while leaving an idle // session untouched. Active and accepted are not mutually exclusive: // when a run is active and a follow-up has been accepted, both the // cancel above and this pending record fire. + // + // Raise the session's cancel mark to the latest accept sequence + // assigned so far. Every prompt currently accepted-but-not-yet- + // active has a sequence at or below that value, so one cancel covers + // all of them; a prompt accepted after this cancel gets a strictly + // higher sequence and is never poisoned. 
Using max keeps repeated + // cancels idempotent while the same prompts are in flight and lets a + // later cancel extend coverage to prompts accepted since. a.acceptedMu.Lock() count, ok := a.acceptedRuns.Get(sessionID) + mark := a.acceptSeqGen a.acceptedMu.Unlock() if ok && count > 0 { - slog.Debug("Recording pending cancel for accepted run", "session_id", sessionID) - a.pendingCancels.Set(sessionID, struct{}{}) + slog.Debug("Recording cancel mark for accepted runs", "session_id", sessionID, "count", count, "mark", mark) + existing, _ := a.cancelMark.Get(sessionID) + a.cancelMark.Set(sessionID, max(existing, mark)) } if a.QueuedPrompts(sessionID) > 0 { diff --git a/internal/agent/dispatch_cancel_test.go b/internal/agent/dispatch_cancel_test.go index f66de252e63559239c1d577fe51c0650589aa5b4..f1b0faad21da845f16728fd2d1101b64c569f2dc 100644 --- a/internal/agent/dispatch_cancel_test.go +++ b/internal/agent/dispatch_cancel_test.go @@ -5,9 +5,12 @@ import ( "errors" "sync/atomic" "testing" + "time" "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -140,7 +143,7 @@ func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) { // A follow-up prompt sits queued for the session. sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"}) // A cancel was recorded for the session while it sat in the queue. - sa.pendingCancels.Set(sess.ID, struct{}{}) + sa.cancelMark.Set(sess.ID, 1) result, err := sa.Run(t.Context(), SessionAgentCall{ SessionID: sess.ID, @@ -177,8 +180,8 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { sess, err := env.sessions.Create(t.Context(), "session") require.NoError(t, err) - // A stale pending cancel lingers (no queued work, no accepted run). 
- sa.pendingCancels.Set(sess.ID, struct{}{}) + // A stale cancel mark lingers (no queued work, no accepted run). + sa.cancelMark.Set(sess.ID, 1) result, err := sa.Run(t.Context(), SessionAgentCall{ SessionID: sess.ID, @@ -196,3 +199,249 @@ func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { assert.Equal(t, message.Assistant, msgs[1].Role) assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason()) } + +// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired +// to a RunComplete broker so tests can observe the terminal event a +// RunID-bearing caller (e.g. `crush run`) blocks on. +func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) { + t.Helper() + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + return sa, env, broker +} + +// TestRun_CancelOnEntryPublishesRunComplete covers the first review +// finding: the cancel-on-entry path returned before the streaming defer +// that publishes RunComplete was installed. A caller that dispatches a +// run with a RunID and blocks on RunComplete (ignoring message events, +// like `crush run`) would hang on an immediately-canceled accepted run. +// The cancel-on-entry path must publish a terminal RunComplete carrying +// the originating RunID. +func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) { + t.Parallel() + sa, env, broker := newCancelTestAgentWithRunComplete(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + ch := broker.Subscribe(ctx) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. 
+ sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-cancel-on-entry", + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + select { + case got := <-ch: + assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID, + "RunComplete must echo the originating RunID") + assert.Equal(t, sess.ID, got.Payload.SessionID) + assert.True(t, got.Payload.Cancelled, + "cancel-on-entry RunComplete must be marked cancelled") + case <-time.After(2 * time.Second): + t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise") + } +} + +// TestCancel_TwoAcceptedBothObserveCancellation covers the second review +// finding: a single cancel with two accepted-not-yet-active prompts must +// cancel both. The cancel raises the session's high-water mark to the +// latest accept sequence, so every prompt accepted-but-not-yet-active at +// cancel time is covered and both take the cancel-on-entry path. +func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Two prompts are accepted-but-not-yet-active for the same session. + accept1 := sa.BeginAccepted(sess.ID) + accept2 := sa.BeginAccepted(sess.ID) + require.Equal(t, 2, sa.acceptedCount(sess.ID)) + + // A single cancel arrives before either becomes active. + sa.Cancel(sess.ID) + require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID), + "one cancel must mark every currently-accepted prompt as canceled") + require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq, + "the mark must cover the earlier accepted prompt too") + + // Both prompts enter Run; each must take cancel-on-entry, not run. 
+ r1, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "first", + Accepted: accept1, + }) + require.NoError(t, err) + require.Nil(t, r1) + + r2, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "second", + Accepted: accept2, + }) + require.NoError(t, err) + require.Nil(t, r2) + + require.False(t, sa.hasPendingCancel(sess.ID), + "both reserved units must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), + "both accept reservations must be released") + + // Each canceled-on-entry turn writes a user + canceled assistant + // message, and neither prompt was enqueued to run normally. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither accepted prompt may be enqueued to run normally") + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages") + var canceled int + for _, m := range msgs { + if m.Role == message.Assistant { + assert.Equal(t, message.FinishReasonCanceled, m.FinishReason()) + canceled++ + } + } + require.Equal(t, 2, canceled, "both turns must finish canceled") +} + +// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel +// no-op guarantee end-to-end: an Escape on an idle session must not +// record a pending cancel that leaks into the next accepted prompt, which +// must run normally to completion. +func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Idle Escape: no accepted run, no active request. + sa.Cancel(sess.ID) + require.False(t, sa.hasPendingCancel(sess.ID), + "idle cancel must not record a pending cancel") + + // The next accepted prompt must run normally, not cancel on entry. 
+ accept := sa.BeginAccepted(sess.ID) + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "next", + Accepted: accept, + }) + require.NoError(t, err) + require.NotNil(t, result, "next prompt must run normally after an idle cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(), + "the prompt must finish normally, not canceled") +} + +// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for +// the reviewer's finding: a counted session-level pending cancel let a +// prompt accepted after the cancel enter Run first and consume a unit +// reserved for the earlier prompts. With a sequence high-water mark, a +// single cancel covers exactly the prompts accepted-but-not-yet-active at +// cancel time (A and B); a prompt accepted afterwards (C) gets a higher +// sequence and must run normally without consuming A or B's cancellation. +// C is run first to prove it neither cancels nor drains the mark, then A +// and B are run and must both cancel on entry. +func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A and B are accepted-but-not-yet-active. + acceptA := sa.BeginAccepted(sess.ID) + acceptB := sa.BeginAccepted(sess.ID) + + // One cancel arrives covering both A and B. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID), + "the mark must cover every prompt accepted before the cancel") + + // C is accepted AFTER the cancel; its sequence is above the mark. 
+ acceptC := sa.BeginAccepted(sess.ID) + require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID), + "a prompt accepted after the cancel must not be covered by the mark") + + // Run C first. It must run normally to completion and must not + // consume or clear the cancellation reserved for A and B. + rc, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "C", + Accepted: acceptC, + }) + require.NoError(t, err) + require.NotNil(t, rc, "C was accepted after the cancel and must run normally") + require.True(t, sa.hasPendingCancel(sess.ID), + "running C must not drain the cancellation reserved for A and B") + + // Now A and B run. Both were accepted before the cancel and must + // take the cancel-on-entry path. + ra, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "A", + Accepted: acceptA, + }) + require.NoError(t, err) + require.Nil(t, ra, "A must cancel on entry, not run") + + rb, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "B", + Accepted: acceptB, + }) + require.NoError(t, err) + require.Nil(t, rb, "B must cancel on entry, not run") + + require.False(t, sa.hasPendingCancel(sess.ID), + "the mark clears once all covered handles are resolved") + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither A nor B may be enqueued to run normally") + + // C produced a normal turn; A and B each produced a canceled turn. 
+ msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant") + + var normal, canceled int + for _, m := range msgs { + if m.Role != message.Assistant { + continue + } + switch m.FinishReason() { + case message.FinishReasonEndTurn: + normal++ + case message.FinishReasonCanceled: + canceled++ + } + } + require.Equal(t, 1, normal, "only C finished normally") + require.Equal(t, 2, canceled, "both A and B finished canceled") +} From a6a8459c07cc2088cd1ce012171b0cc9c05a1339 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 18:11:28 -0400 Subject: [PATCH 11/15] fix(server): ignore background errors from unrelated prompts Carry the submitted prompt identity through error events so command-line runs fail only for their own prompt. Older session-scoped errors remain supported for compatibility. Co-Authored-By: Charm Crush --- internal/agent/notify/notify.go | 6 ++ internal/backend/agent.go | 1 + internal/cmd/run.go | 24 +++++-- internal/cmd/run_stream_test.go | 88 ++++++++++++++++++++++++++ internal/proto/agent.go | 7 ++ internal/server/events.go | 1 + internal/server/events_test.go | 32 ++++++++++ internal/workspace/client_workspace.go | 1 + 8 files changed, 156 insertions(+), 4 deletions(-) diff --git a/internal/agent/notify/notify.go b/internal/agent/notify/notify.go index 1a217a6d00650fe1134b24d9d779821015513063..22e9f17769b5585302a195049bb3abca919f9a91 100644 --- a/internal/agent/notify/notify.go +++ b/internal/agent/notify/notify.go @@ -23,6 +23,12 @@ type Notification struct { SessionTitle string Type Type ProviderID string + // RunID, when non-empty, is the caller-supplied correlator + // (proto.AgentMessage.RunID) for the run that produced this + // notification. It lets observers attribute a TypeAgentError to a + // specific request rather than to any in-flight run on the + // session. Empty when no caller set one. 
+ RunID string // Message carries the error text for TypeAgentError. Other // notification types ignore it. Message string diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 2dd0479d3236d55e3919bdef1f16bb593fe5684e..4af3b8f0d2f88ad5daff41f40664b303c948b263 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -85,6 +85,7 @@ func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent. ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{ SessionID: msg.SessionID, + RunID: msg.RunID, Type: notify.TypeAgentError, Message: err.Error(), }) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 87fa32606674847741a9d028a26375fb98935fc4..2feeba78e6f4862e453fcb790e428b3e08ab0505 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -409,11 +409,27 @@ func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) { return true, nil case pubsub.Event[proto.AgentEvent]: - if e.Payload.Error != nil { - stop() - return true, fmt.Errorf("agent error: %w", e.Payload.Error) + if e.Payload.Error == nil { + return false, nil } - return false, nil + // Attribute the error to our run before treating it as + // fatal. Async errors from an unrelated workspace run share + // this channel, so a foreign failure must not abort us: + // - if the event carries a RunID, it is the authoritative + // correlator: it must match our run exactly, otherwise it + // belongs to a different request and we ignore it. + // - if the event carries no RunID (older server), fall back + // to SessionID: it must be present and match our session, + // otherwise we ignore it. 
+ if e.Payload.RunID != "" { + if e.Payload.RunID != s.runID { + return false, nil + } + } else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID { + return false, nil + } + stop() + return true, fmt.Errorf("agent error: %w", e.Payload.Error) } return false, nil } diff --git a/internal/cmd/run_stream_test.go b/internal/cmd/run_stream_test.go index ac168fa77045aa6aa5761b6f9c657f066c952734..028eb03baa0dc7a55a0037e67f033b708ff9634e 100644 --- a/internal/cmd/run_stream_test.go +++ b/internal/cmd/run_stream_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "errors" "testing" "time" @@ -307,6 +308,93 @@ func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) require.Equal(t, "streamed prefix final", buf.String()) } +// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async +// agent error carrying a non-empty RunID is fatal only when it matches +// our run. A foreign RunID is ignored regardless of the event's +// SessionID, because RunID is the authoritative correlator and async +// errors share the agent event channel: without strict RunID matching +// an unrelated workspace failure would abort our run. +func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) { + t.Parallel() + + // Foreign RunID with a matching session is still foreign. + s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a different session is ignored. 
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a missing session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Matching RunID is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-mine", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "matching RunID error must be fatal") + require.True(t, done) +} + +// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the +// compatibility fallback: when the event carries no RunID, attribution +// falls back to SessionID. An error for another session or with an +// empty session is ignored, while an error for our own session is fatal +// so a real failure is never dropped. +func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + + // Empty RunID for another session is ignored. + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error for another session must not abort our run") + require.False(t, done) + + // Empty RunID with an empty session is ignored. 
+ done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error with no session must not abort our run") + require.False(t, done) + + // Empty RunID with a matching session is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "error for our own session must be fatal") + require.True(t, done) +} + // TestRunStream_NoRunIDFallsBackToSessionID preserves the older // behaviour for callers (and tests) that don't supply a RunID: // SessionID-only matching still terminates the stream on the diff --git a/internal/proto/agent.go b/internal/proto/agent.go index e5266e52614a5bc43065ff62cf18b16f8ee7401f..2c85923e547b6755357479218f9ff4815e491527 100644 --- a/internal/proto/agent.go +++ b/internal/proto/agent.go @@ -31,6 +31,13 @@ type AgentEvent struct { Message Message `json:"message"` Error error `json:"error,omitempty"` + // RunID echoes the caller-supplied AgentMessage.RunID for the run + // that produced this event. It lets observers (notably + // `crush run`) attribute an error event to a specific request + // instead of to any in-flight run on the session. Empty when no + // caller set one. + RunID string `json:"run_id,omitempty"` + // When summarizing. 
SessionID string `json:"session_id,omitempty"` SessionTitle string `json:"session_title,omitempty"` diff --git a/internal/server/events.go b/internal/server/events.go index fd085c5a415c0ef0fc402673ad23fff8435f1db6..526f9e195009cd70c453958778fb98887aae4a37 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -89,6 +89,7 @@ func wrapEvent(ev any) *pubsub.Payload { payload := proto.AgentEvent{ SessionID: e.Payload.SessionID, SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, Type: proto.AgentEventType(e.Payload.Type), } if e.Payload.Type == notify.TypeAgentError { diff --git a/internal/server/events_test.go b/internal/server/events_test.go index 432bc42f910b4acec675baea46754b81defab9f6..e4238a05eb3abf50e13329acfaabd2cb77dd464c 100644 --- a/internal/server/events_test.go +++ b/internal/server/events_test.go @@ -123,6 +123,38 @@ func TestRunCompleteToProto_RoundTrip(t *testing.T) { require.False(t, decoded.Payload.Cancelled) } +// TestAgentErrorToProto_PreservesRunID verifies that an async agent +// error notification carries its originating RunID (and SessionID) +// through the SSE envelope. Without these correlators, `crush run` +// cannot tell whether an error event belongs to its own run and +// would abort on any unrelated workspace failure. 
+func TestAgentErrorToProto_PreservesRunID(t *testing.T) { + t.Parallel() + + src := pubsub.Event[notify.Notification]{ + Type: pubsub.CreatedEvent, + Payload: notify.Notification{ + SessionID: "S", + RunID: "run-99", + Type: notify.TypeAgentError, + Message: "boom", + }, + } + + env := wrapEvent(src) + require.NotNil(t, env) + require.Equal(t, pubsub.PayloadTypeAgentEvent, env.Type) + + var decoded pubsub.Event[proto.AgentEvent] + require.NoError(t, json.Unmarshal(env.Payload, &decoded)) + require.Equal(t, proto.AgentEventTypeError, decoded.Payload.Type) + require.Equal(t, "S", decoded.Payload.SessionID) + require.Equal(t, "run-99", decoded.Payload.RunID, + "RunID must survive so observers can attribute the error to its run") + require.NotNil(t, decoded.Payload.Error) + require.Equal(t, "boom", decoded.Payload.Error.Error()) +} + // TestRunCompleteToProto_Error verifies that error- and cancel-shaped // RunComplete events round-trip cleanly so clients can distinguish // "agent failed" (returns non-zero from `crush run`) from "agent diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index a6a43731675698083671cae95f983d7a3a724a5d..609a9145bd3a5374c6fbaf96b3a7549187b146d5 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -706,6 +706,7 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg { n := notify.Notification{ SessionID: e.Payload.SessionID, SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, Type: notify.Type(e.Payload.Type), } if e.Payload.Error != nil { From c6d15905a9ca53b176c1c3430c03a497de33a983 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 19:09:50 -0400 Subject: [PATCH 12/15] chore(server,tests): cover the accepted-prompt cancel race end to end Exercise the real server-to-agent path for a cancel that lands after a prompt is accepted but before it becomes active. 
This protects the race fix from passing only in mocked unit paths. Co-Authored-By: Charm Crush --- internal/agent/agenttest/coordinator.go | 79 +++++++++++ .../backend/accepted_run_integration_test.go | 131 ++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 internal/agent/agenttest/coordinator.go create mode 100644 internal/backend/accepted_run_integration_test.go diff --git a/internal/agent/agenttest/coordinator.go b/internal/agent/agenttest/coordinator.go new file mode 100644 index 0000000000000000000000000000000000000000..9cb1e139b20c4f54daf9ed89b2a2bf43ff72bfce --- /dev/null +++ b/internal/agent/agenttest/coordinator.go @@ -0,0 +1,79 @@ +// Package agenttest provides test-only constructors for wiring a real +// production agent.Coordinator without booting a full app.App. It is +// imported only from _test.go files (e.g. internal/backend integration +// tests) and is never referenced by production code, so it is compiled +// only under tests and never ships in the production binary or API. +package agenttest + +import ( + "context" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy/providers/openaicompat" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +// NewCoordinator builds a real agent.Coordinator through the production +// agent.NewCoordinator constructor so the RunAccepted / BeginAccepted / +// run path (including UpdateModels) is the actual code under test. +// +// It installs a minimal config with a single openai-compatible provider +// whose model resolves offline. run rebuilds the model on every call, so +// the provider must construct without network I/O; the cancel-on-entry +// path this helper is built to exercise returns before any model call, +// so no request is ever issued. 
The coder agent's allowed-tools list is +// cleared to keep tool construction cheap and free of sub-agent wiring. +// +// The optional coordinator dependencies (history, filetracker, LSP, +// notify, runComplete, skills) are nil: run guards the publisher fields +// and the cancel-on-entry path never touches the others. +func NewCoordinator( + ctx context.Context, + workingDir string, + sessions session.Service, + messages message.Service, +) (agent.Coordinator, error) { + cfg, err := config.Init(workingDir, "", false) + if err != nil { + return nil, err + } + + const ( + providerID = "test-openai-compat" + modelID = "test-model" + ) + cfg.Config().Providers.Set(providerID, config.ProviderConfig{ + ID: providerID, + Name: "Test", + Type: openaicompat.Name, + BaseURL: "http://127.0.0.1:0/v1", + APIKey: "test", + Models: []catwalk.Model{{ID: modelID, DefaultMaxTokens: 4096}}, + }) + selected := config.SelectedModel{Provider: providerID, Model: modelID} + cfg.Config().Models[config.SelectedModelTypeLarge] = selected + cfg.Config().Models[config.SelectedModelTypeSmall] = selected + + // Keep buildTools light: no sub-agent or agentic-fetch construction. 
+ coderCfg := cfg.Config().Agents[config.AgentCoder] + coderCfg.AllowedTools = nil + cfg.Config().Agents[config.AgentCoder] = coderCfg + + return agent.NewCoordinator( + ctx, + cfg, + sessions, + messages, + permission.NewPermissionService(workingDir, true, nil), + nil, + nil, + nil, + nil, + nil, + nil, + ) +} diff --git a/internal/backend/accepted_run_integration_test.go b/internal/backend/accepted_run_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a0a7ba249cc547956dd479cabcf5545a07a5c26 --- /dev/null +++ b/internal/backend/accepted_run_integration_test.go @@ -0,0 +1,131 @@ +package backend + +import ( + "context" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/agenttest" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +// gatedCoordinator wraps a real agent.Coordinator and parks RunAccepted +// before delegating to it. Every method other than RunAccepted is +// inherited from the embedded coordinator, so BeginAccepted (called by +// Backend.SendMessage) and RunAccepted (called by the dispatched run) +// are the production agent.Coordinator implementations under test, not +// stubs. The gate only delays entry into the real RunAccepted so a +// cancel can be made to land in the accepted-but-not-yet-active window +// deterministically: the accept handle is not consumed by +// sessionAgent.Run until the real RunAccepted runs after the gate opens. 
+type gatedCoordinator struct { + agent.Coordinator + entered chan struct{} + gate chan struct{} +} + +func (c *gatedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + close(c.entered) + <-c.gate + return c.Coordinator.RunAccepted(ctx, accept, sessionID, prompt, attachments...) +} + +// newRealCoordinator builds a production agent.Coordinator over a +// DB-backed session/message store, wrapped in a gate. It is constructed +// through the real agent.NewCoordinator path (via the test-only +// agenttest helper) with an offline-resolvable model: the +// cancel-on-entry path under test persists a canceled turn and returns +// before any model call, so no network I/O happens. +func newRealCoordinator(t *testing.T) (*gatedCoordinator, session.Service, message.Service) { + t.Helper() + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + messages := message.NewService(q) + + coord, err := agenttest.NewCoordinator(t.Context(), t.TempDir(), sessions, messages) + require.NoError(t, err) + + return &gatedCoordinator{ + Coordinator: coord, + entered: make(chan struct{}), + gate: make(chan struct{}), + }, sessions, messages +} + +// TestSendMessage_AcceptedCancelRace_RealMachinery exercises the +// 202/cancel race end-to-end through Backend.SendMessage against the +// production agent.Coordinator (BeginAccepted + RunAccepted), not a +// stub. It asserts that a cancel arriving after the prompt is accepted +// but before the run becomes active is not lost: the accepted handle +// reaches sessionAgent.Run and drives cancel-on-entry, which persists a +// canceled turn instead of streaming. 
+// +// This test would fail if Coordinator.BeginAccepted returned nil (Cancel +// would find no accepted run and record no mark, and the run would +// receive a nil Accepted handle and skip cancel-on-entry) or if +// Coordinator.RunAccepted dropped the handle on its way into +// sessionAgent.Run (the run would likewise skip cancel-on-entry and try +// to stream the model). In either case no FinishReasonCanceled turn +// would be persisted. +func TestSendMessage_AcceptedCancelRace_RealMachinery(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + + coord, sessions, messages := newRealCoordinator(t) + sess, err := sessions.Create(t.Context(), "session") + require.NoError(t, err) + + ws := insertAgentWorkspace(t, b, coord) + + require.NoError(t, b.SendMessage(ws.ID, proto.AgentMessage{SessionID: sess.ID, Prompt: "hi"})) + + // Coordinator.BeginAccepted ran synchronously inside SendMessage + // before dispatch; the dispatched run has now entered the gate but + // has not yet called the real RunAccepted, so the accept handle is + // not yet consumed: the prompt is accepted but not active. + select { + case <-coord.entered: + case <-time.After(2 * time.Second): + t.Fatal("dispatched run never entered RunAccepted") + } + + // A cancel arriving now lands in the accepted-but-not-yet-active + // window and is only recorded because BeginAccepted incremented the + // accept counter. + require.NoError(t, b.CancelSession(ws.ID, sess.ID)) + + // Release the gate so the real RunAccepted threads the handle into + // sessionAgent.Run, which drives cancel-on-entry. + close(coord.gate) + + // The dispatched run returns nil (cancel-on-entry), so runWG drains. 
+ waited := make(chan struct{}) + go func() { + ws.runWG.Wait() + close(waited) + }() + select { + case <-waited: + case <-time.After(5 * time.Second): + t.Fatal("runWG.Wait did not complete after the canceled run returned") + } + + // The accepted-but-not-yet-active cancel persisted a canceled turn + // rather than streaming a real response. + msgs, err := messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, message.User, msgs[0].Role) + require.Equal(t, message.Assistant, msgs[1].Role) + require.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} From d7a814c540c258f06cb798c29beff676ed4b130b Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 20:11:26 -0400 Subject: [PATCH 13/15] fix(server): close a queued-prompt cancel race Check queued prompt cancellation while holding the same session handoff lock used for accepted and active prompts. Queue draining now follows the same ordering guarantees as prompt dispatch. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 35 +++++++++++++----- internal/agent/run_complete_test.go | 56 +++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 9 deletions(-) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 393587a111ad806dc26bbf8a80f52dba49ce0397..6728864527b2e4a7f970262f25c3953168a0b5bc 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -359,6 +359,30 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { a.messageQueue.Set(call.SessionID, existing) } +// drainUncanceledQueue copies and clears the session's message queue and +// returns the queued calls not covered by a pending cancel. 
The queue +// copy/delete and the cancel-mark check happen under the per-session +// dispatch mutex so the filtering is atomic against a concurrent Cancel: +// canceledBySeq requires the caller to hold that mutex, and evaluating it +// here (rather than after unlocking) prevents a cancel recorded between +// the drain and the check from being observed inconsistently. The +// returned survivors are processed by the caller without the lock held. +func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall { + dispatchLock := a.sessionMu(sessionID) + dispatchLock.Lock() + defer dispatchLock.Unlock() + queuedCalls, _ := a.messageQueue.Get(sessionID) + a.messageQueue.Del(sessionID) + survivors := queuedCalls[:0] + for _, queued := range queuedCalls { + if a.canceledBySeq(sessionID, queued.acceptSeq) { + continue + } + survivors = append(survivors, queued) + } + return survivors +} + // clearPendingCancel removes any pending-cancel mark for sessionID. It // takes the per-session dispatch lock so it is ordered against Cancel // and the dispatch handoff. @@ -717,15 +741,8 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // part of this step. Coverage is per-call by accept sequence // so a follow-up queued after the cancel (higher seq) is // still folded in. 
- dispatchLock := a.sessionMu(call.SessionID) - dispatchLock.Lock() - queuedCalls, _ := a.messageQueue.Get(call.SessionID) - a.messageQueue.Del(call.SessionID) - dispatchLock.Unlock() - for _, queued := range queuedCalls { - if a.canceledBySeq(call.SessionID, queued.acceptSeq) { - continue - } + survivors := a.drainUncanceledQueue(call.SessionID) + for _, queued := range survivors { userMessage, createErr := a.createUserMessage(callContext, queued) if createErr != nil { return callContext, prepared, createErr diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go index 74f9232a0946b24d38f05873fa39066dcae40c27..50c5fdf9355c11123d91d26b7ec2e90b219df9d4 100644 --- a/internal/agent/run_complete_test.go +++ b/internal/agent/run_complete_test.go @@ -58,6 +58,62 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { "RunComplete still correlates with the originating SendMessage") } +// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the +// queue drain evaluates the per-session cancel mark while holding the +// dispatch mutex (canceledBySeq's documented precondition). Queued calls +// at or below the cancel high-water mark are dropped, calls queued after +// the cancel (higher seq) survive, untracked enqueues (seq == 0) are +// dropped whenever any mark is present, and the queue is cleared. +func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-session" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "below", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2}, + {SessionID: sessionID, Prompt: "after", acceptSeq: 3}, + {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0}, + }) + // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered. 
+ a.cancelMark.Set(sessionID, 2) + + survivors := a.drainUncanceledQueue(sessionID) + + require.Len(t, survivors, 1, + "only the follow-up queued after the cancel (seq > mark) must survive") + require.Equal(t, "after", survivors[0].Prompt) + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "drain must clear the session message queue") +} + +// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel +// mark recorded, every queued call survives the drain. +func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-nomark" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "a", acceptSeq: 0}, + {SessionID: sessionID, Prompt: "b", acceptSeq: 5}, + }) + + survivors := a.drainUncanceledQueue(sessionID) + require.Len(t, survivors, 2, "no cancel mark means all queued calls survive") +} + // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the // pubsub.Publisher interface change end-to-end: a Broker is the only // concrete Publisher implementation and must satisfy both Publish and From cbec4916b3d9bcdaeaff75157071e5e87a7526ef Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Mon, 1 Jun 2026 20:35:35 -0400 Subject: [PATCH 14/15] fix(server): complete queued prompts independently Ensure queued prompts with their own submitted identity emit one terminal completion when canceled or dropped. Callers no longer wait on a completion that was folded into another prompt's turn. 
Co-Authored-By: Charm Crush --- internal/agent/agent.go | 166 ++++++++++++++++++---- internal/agent/queued_runid_test.go | 181 ++++++++++++++++++++++++ internal/agent/run_complete_test.go | 211 +++++++++++++++++++++++++--- 3 files changed, 513 insertions(+), 45 deletions(-) create mode 100644 internal/agent/queued_runid_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 6728864527b2e4a7f970262f25c3953168a0b5bc..b901e162c8a3e0571b097735159d044611254616 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -359,28 +359,96 @@ func (a *sessionAgent) enqueueCall(call SessionAgentCall) { a.messageQueue.Set(call.SessionID, existing) } -// drainUncanceledQueue copies and clears the session's message queue and -// returns the queued calls not covered by a pending cancel. The queue -// copy/delete and the cancel-mark check happen under the per-session -// dispatch mutex so the filtering is atomic against a concurrent Cancel: -// canceledBySeq requires the caller to hold that mutex, and evaluating it -// here (rather than after unlocking) prevents a cancel recorded between -// the drain and the check from being observed inconsistently. The -// returned survivors are processed by the caller without the lock held. -func (a *sessionAgent) drainUncanceledQueue(sessionID string) []SessionAgentCall { +// drainQueueForStep partitions the session's queued calls for the current +// streaming step under the per-session dispatch mutex so the filtering is +// atomic against a concurrent Cancel: canceledBySeq requires the caller to +// hold that mutex, and evaluating it here (rather than after unlocking) +// prevents a cancel recorded between the drain and the check from being +// observed inconsistently. +// +// Calls covered by a pending cancel are dropped; the dropped ones that +// carry a RunID are returned in canceledWithRunID so the caller can +// publish their terminal cancelled RunComplete (a caller waiting on that +// RunID, e.g. 
`crush run`, would otherwise hang). Uncanceled calls without +// a RunID are returned in fold to be folded into the active turn, +// preserving the existing follow-up behavior. Uncanceled calls that carry +// a RunID are left in the queue so each runs as its own turn via the +// recursive run path and publishes its own RunComplete, giving every +// RunID-bearing prompt an explicit lifecycle instead of being silently +// absorbed into another turn. fold is processed by the caller without the +// lock held. +func (a *sessionAgent) drainQueueForStep(sessionID string) (fold, canceledWithRunID []SessionAgentCall) { dispatchLock := a.sessionMu(sessionID) dispatchLock.Lock() defer dispatchLock.Unlock() queuedCalls, _ := a.messageQueue.Get(sessionID) - a.messageQueue.Del(sessionID) - survivors := queuedCalls[:0] + var keep []SessionAgentCall for _, queued := range queuedCalls { if a.canceledBySeq(sessionID, queued.acceptSeq) { + if queued.RunID != "" { + canceledWithRunID = append(canceledWithRunID, queued) + } + continue + } + if queued.RunID != "" { + keep = append(keep, queued) + continue + } + fold = append(fold, queued) + } + if len(keep) == 0 { + a.messageQueue.Del(sessionID) + } else { + a.messageQueue.Set(sessionID, keep) + } + return fold, canceledWithRunID +} + +// publishCanceledQueueDrops emits a terminal cancelled RunComplete for +// every dropped queued call that carries a RunID. A queued prompt removed +// from the queue without ever running — covered by a pending cancel, or +// cleared by Cancel/ClearQueue — would otherwise leave a caller blocked on +// that RunID: `crush run` ignores live message events and exits only on a +// RunComplete whose RunID matches. Calls without a RunID had no such waiter +// and are dropped silently as before. A detached, bounded context keeps the +// must-deliver publish alive even when the run context that triggered the +// drop is already canceled. 
+func (a *sessionAgent) publishCanceledQueueDrops(drops []SessionAgentCall) { + var hasRunID bool + for _, d := range drops { + if d.RunID != "" { + hasRunID = true + break + } + } + if !hasRunID { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + for _, d := range drops { + if d.RunID == "" { continue } - survivors = append(survivors, queued) + a.publishRunComplete(ctx, d, notify.RunComplete{ + SessionID: d.SessionID, + RunID: d.RunID, + Cancelled: true, + }) } - return survivors +} + +// clearQueueAndNotify removes all queued prompts for the session and +// publishes a terminal cancelled RunComplete for any that carried a RunID, +// so callers waiting on those RunIDs (e.g. `crush run`) are not left +// hanging when their queued prompt is discarded without running. +func (a *sessionAgent) clearQueueAndNotify(sessionID string) { + queued, ok := a.messageQueue.Get(sessionID) + a.messageQueue.Del(sessionID) + if !ok { + return + } + a.publishCanceledQueueDrops(queued) } // clearPendingCancel removes any pending-cancel mark for sessionID. It @@ -735,14 +803,20 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() - // Drain queued follow-up prompts, but skip any covered by a - // cancel recorded while they sat in the queue: a cancel that - // arrived after a prompt was queued must not let it run as - // part of this step. Coverage is per-call by accept sequence - // so a follow-up queued after the cancel (higher seq) is - // still folded in. - survivors := a.drainUncanceledQueue(call.SessionID) - for _, queued := range survivors { + // Drain queued follow-up prompts for this step. Calls covered + // by a cancel recorded while they sat in the queue are dropped: + // a cancel that arrived after a prompt was queued must not let + // it run as part of this step. 
Coverage is per-call by accept + // sequence so a follow-up queued after the cancel (higher seq) + // is not dropped. A dropped prompt carrying a RunID still gets + // its terminal cancelled RunComplete so a caller waiting on it + // does not hang. Uncanceled prompts without a RunID are folded + // into this turn; uncanceled prompts with a RunID are left + // queued so each runs as its own turn (with its own + // RunComplete) via the recursive run path below. + fold, canceledRunIDs := a.drainQueueForStep(call.SessionID) + a.publishCanceledQueueDrops(canceledRunIDs) + for _, queued := range fold { userMessage, createErr := a.createUserMessage(callContext, queued) if createErr != nil { return callContext, prepared, createErr @@ -1125,14 +1199,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // mark, or untracked); keep any queued after the cancel (higher // sequence) so they still run. var kept []SessionAgentCall + var canceledRunIDDrops []SessionAgentCall for _, q := range queuedMessages { if q.acceptSeq == 0 || q.acceptSeq <= mark { + if q.RunID != "" { + canceledRunIDDrops = append(canceledRunIDDrops, q) + } continue } kept = append(kept, q) } queuedMessages = kept a.messageQueue.Set(call.SessionID, kept) + // A dropped prompt carrying a RunID must still publish its + // terminal cancelled RunComplete so a caller waiting on that + // RunID does not hang. + a.publishCanceledQueueDrops(canceledRunIDDrops) } if len(queuedMessages) == 0 { // No queued work. Clear the cancel mark only when no accepted @@ -1151,12 +1233,29 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * mu.Unlock() return result, err } - // There are queued messages restart the loop. 
The recursive Run - // publishes its own RunComplete for the queued prompt, so suppress - // the outer defer's emit to avoid a duplicate event whose Error - // field would belong to the recursive turn but whose MessageID/Text - // would belong to the outer turn. + // There are queued messages, restart the loop. Suppress the outer + // defer's emit: it would otherwise observe the recursive Run's retErr + // (named-return clobbering through the return below) against this + // turn's MessageID/Text and publish a mixed, racing event. skipRunComplete = true + // Decide whether this turn still owes its own terminal RunComplete. + // Each submitted prompt with a RunID has its own lifecycle, so a turn + // that is finished and handing off to a *different* queued prompt must + // publish its own RunComplete here — leaving it to the recursive turn + // (which carries a different RunID) would hang a caller waiting on + // this turn's RunID. The exception is the summarize-continuation path, + // which re-queues this same call (same RunID) to resume after a + // summary; in that case the eventual terminal turn for this RunID + // publishes, so publishing now would double-emit. + outerOwesRunComplete := call.RunID != "" + if outerOwesRunComplete { + for _, q := range queuedMessages { + if q.RunID == call.RunID { + outerOwesRunComplete = false + break + } + } + } firstQueuedMessage := queuedMessages[0] a.messageQueue.Set(call.SessionID, queuedMessages[1:]) // Reserve a fresh accept for the dequeued prompt before dropping the @@ -1167,6 +1266,17 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // the recursive Run's accepted path observes as cancel-on-entry. 
firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID) mu.Unlock() + if outerOwesRunComplete { + complete := notify.RunComplete{SessionID: call.SessionID, RunID: call.RunID} + if currentAssistant != nil { + complete.MessageID = currentAssistant.ID + complete.Text = currentAssistant.Content().String() + } + if ctx.Err() != nil { + complete.Cancelled = true + } + a.publishRunComplete(ctx, call, complete) + } return a.Run(ctx, firstQueuedMessage) } @@ -1770,14 +1880,14 @@ func (a *sessionAgent) Cancel(sessionID string) { if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) - a.messageQueue.Del(sessionID) + a.clearQueueAndNotify(sessionID) } } func (a *sessionAgent) ClearQueue(sessionID string) { if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) - a.messageQueue.Del(sessionID) + a.clearQueueAndNotify(sessionID) } } diff --git a/internal/agent/queued_runid_test.go b/internal/agent/queued_runid_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e3e99d4f648e12d5a98be052747841553b1fa8ae --- /dev/null +++ b/internal/agent/queued_runid_test.go @@ -0,0 +1,181 @@ +package agent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +// gatedStreamModel streams a single text part followed by a clean finish, +// but blocks the very first Stream call until its gate is released. That +// lets a test hold a run "active" (past PrepareStep, inside Stream) just +// long enough to enqueue a follow-up prompt behind the busy session. +// Subsequent Stream calls (e.g. the recursive run draining the queue) +// proceed immediately. 
+type gatedStreamModel struct { + text string + gate chan struct{} + entered chan struct{} + calls atomic.Int64 +} + +func (m *gatedStreamModel) Provider() string { return "fake" } +func (m *gatedStreamModel) Model() string { return "fake-model" } + +func (m *gatedStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}, + FinishReason: fantasy.FinishReasonStop, + }, nil +} + +func (m *gatedStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + if m.calls.Add(1) == 1 { + close(m.entered) + select { + case <-m.gate: + case <-ctx.Done(): + } + } + text := m.text + return func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) { + return + } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil +} + +func (m *gatedStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *gatedStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return nil, errors.New("not implemented") +} + +// TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete is the +// end-to-end proof of fix 2: a prompt carrying a RunID that is queued +// behind a busy session must NOT be silently folded into the active turn. +// It runs as its own turn via the recursive run path and publishes its +// own terminal RunComplete, so a `crush run` caller blocking on that +// RunID does not hang. The active turn keeps its own RunComplete too. 
+func TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + large := &gatedStreamModel{ + text: "done", + gate: make(chan struct{}), + entered: make(chan struct{}), + } + small := &finishStreamModel{text: "title"} + + sa := NewSessionAgent(SessionAgentOptions{ + LargeModel: Model{Model: large, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}}, + SmallModel: Model{Model: small, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}}, + IsYolo: true, + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + ch := broker.Subscribe(subCtx) + + // Start the main turn; it blocks inside Stream once active. + mainDone := make(chan error, 1) + go func() { + _, runErr := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-main", + Prompt: "main", + }) + mainDone <- runErr + }() + + // Wait until the main turn is active (inside Stream). + select { + case <-large.entered: + case <-time.After(5 * time.Second): + t.Fatal("main run never entered Stream") + } + require.True(t, sa.IsSessionBusy(sess.ID), "main run must be active before enqueueing the follow-up") + + // Enqueue a RunID-bearing follow-up behind the busy session. + res, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-follow", + Prompt: "follow", + }) + require.NoError(t, err) + require.Nil(t, res, "a busy-session follow-up must enqueue and return (nil, nil)") + require.Equal(t, 1, sa.QueuedPrompts(sess.ID), "the follow-up must be queued, not folded") + + // Release the main turn so it completes and hands off to the queue. 
+ close(large.gate) + require.NoError(t, <-mainDone) + + // Both turns must publish their own terminal RunComplete. + got := map[string]notify.RunComplete{} + deadline := time.After(5 * time.Second) + for len(got) < 2 { + select { + case ev := <-ch: + got[ev.Payload.RunID] = ev.Payload + case <-deadline: + t.Fatalf("timed out waiting for both RunCompletes; got %v", got) + } + } + + main, ok := got["run-main"] + require.True(t, ok, "the active turn must publish its own RunComplete") + require.Empty(t, main.Error) + require.False(t, main.Cancelled) + + follow, ok := got["run-follow"] + require.True(t, ok, + "the queued RunID prompt must publish its own RunComplete instead of being folded silently") + require.Empty(t, follow.Error) + require.False(t, follow.Cancelled) + require.Equal(t, "done", follow.Text, "the queued prompt ran as its own turn") + + // Two distinct assistant turns prove the follow-up was not folded. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + var assistants, follows int + for _, m := range msgs { + switch m.Role { + case message.Assistant: + assistants++ + case message.User: + if m.Content().String() == "follow" { + follows++ + } + } + } + require.Equal(t, 2, assistants, "the active turn and the recursive turn each produce one assistant message") + require.Equal(t, 1, follows, "the follow-up prompt is its own user turn") +} diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go index 50c5fdf9355c11123d91d26b7ec2e90b219df9d4..2fb6fbefab436ce97b428e9025ef79142da2ea85 100644 --- a/internal/agent/run_complete_test.go +++ b/internal/agent/run_complete_test.go @@ -58,13 +58,15 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { "RunComplete still correlates with the originating SendMessage") } -// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the -// queue drain evaluates the per-session cancel mark while holding the -// dispatch mutex (canceledBySeq's 
documented precondition). Queued calls -// at or below the cancel high-water mark are dropped, calls queued after -// the cancel (higher seq) survive, untracked enqueues (seq == 0) are -// dropped whenever any mark is present, and the queue is cleared. -func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) { +// TestDrainQueueForStep_FiltersUnderDispatchLock verifies that the queue +// drain evaluates the per-session cancel mark while holding the dispatch +// mutex (canceledBySeq's documented precondition). Queued calls at or +// below the cancel high-water mark are dropped, calls queued after the +// cancel (higher seq) are folded, untracked enqueues (seq == 0) are +// dropped whenever any mark is present, and the queue is cleared. These +// calls carry no RunID, so all foldable survivors are returned for +// folding (the existing follow-up behavior). +func TestDrainQueueForStep_FiltersUnderDispatchLock(t *testing.T) { t.Parallel() env := testEnv(t) @@ -83,19 +85,21 @@ func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) { // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered. 
a.cancelMark.Set(sessionID, 2) - survivors := a.drainUncanceledQueue(sessionID) + fold, canceledWithRunID := a.drainQueueForStep(sessionID) - require.Len(t, survivors, 1, - "only the follow-up queued after the cancel (seq > mark) must survive") - require.Equal(t, "after", survivors[0].Prompt) + require.Len(t, fold, 1, + "only the follow-up queued after the cancel (seq > mark) must be folded") + require.Equal(t, "after", fold[0].Prompt) + require.Empty(t, canceledWithRunID, + "no dropped call carried a RunID, so none need a terminal RunComplete") _, ok := a.messageQueue.Get(sessionID) - require.False(t, ok, "drain must clear the session message queue") + require.False(t, ok, "drain must clear the session message queue when nothing is kept") } -// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel -// mark recorded, every queued call survives the drain. -func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) { +// TestDrainQueueForStep_NoMarkFoldsAllNonRunID verifies that with no +// cancel mark recorded, every queued call without a RunID is folded. +func TestDrainQueueForStep_NoMarkFoldsAllNonRunID(t *testing.T) { t.Parallel() env := testEnv(t) @@ -110,8 +114,80 @@ func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) { {SessionID: sessionID, Prompt: "b", acceptSeq: 5}, }) - survivors := a.drainUncanceledQueue(sessionID) - require.Len(t, survivors, 2, "no cancel mark means all queued calls survive") + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + require.Len(t, fold, 2, "no cancel mark means all non-RunID queued calls are folded") + require.Empty(t, canceledWithRunID) +} + +// TestDrainQueueForStep_KeepsRunIDPromptsQueued is the core of fix 2: a +// queued prompt that carries a RunID must NOT be folded into the active +// turn. Folding it would silently absorb it into another turn and never +// publish a RunComplete for its RunID, hanging a `crush run` caller that +// blocks on that event. 
Such prompts are left in the queue so the +// recursive run path gives each its own turn and its own RunComplete. +// Non-RunID prompts are still folded. +func TestDrainQueueForStep_KeepsRunIDPromptsQueued(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "fold-me", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-a", Prompt: "keep-me", acceptSeq: 2}, + {SessionID: sessionID, RunID: "run-b", Prompt: "keep-me-too", acceptSeq: 3}, + }) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + + require.Len(t, fold, 1, "only the non-RunID prompt is folded into the active turn") + require.Equal(t, "fold-me", fold[0].Prompt) + require.Empty(t, canceledWithRunID) + + kept, ok := a.messageQueue.Get(sessionID) + require.True(t, ok, "RunID-bearing prompts must remain queued for the recursive run path") + require.Len(t, kept, 2) + require.Equal(t, "run-a", kept[0].RunID) + require.Equal(t, "run-b", kept[1].RunID) +} + +// TestDrainQueueForStep_ReportsCanceledRunIDDrops verifies that a queued +// prompt carrying a RunID that is dropped because a cancel covers it is +// reported in canceledWithRunID so the caller can publish its terminal +// cancelled RunComplete. A canceled prompt without a RunID is dropped +// silently as before. 
+func TestDrainQueueForStep_ReportsCanceledRunIDDrops(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-cancel-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, RunID: "run-canceled", Prompt: "canceled", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "canceled-no-runid", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-survives", Prompt: "survives", acceptSeq: 5}, + }) + a.cancelMark.Set(sessionID, 2) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + + require.Empty(t, fold, "no uncanceled non-RunID prompts to fold") + require.Len(t, canceledWithRunID, 1, + "only the dropped RunID-bearing prompt needs a terminal RunComplete") + require.Equal(t, "run-canceled", canceledWithRunID[0].RunID) + + kept, ok := a.messageQueue.Get(sessionID) + require.True(t, ok) + require.Len(t, kept, 1, "the uncanceled RunID prompt stays queued") + require.Equal(t, "run-survives", kept[0].RunID) } // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the @@ -143,3 +219,104 @@ func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) { t.Fatal("PublishMustDeliver did not deliver event") } } + +// requireSingleCancelledRunComplete reads exactly one RunComplete from ch, +// asserts it is the cancelled terminal event for runID, and verifies no +// second event arrives. This observes the published pubsub event rather +// than internal bookkeeping, which is the contract a `crush run` caller +// blocking on the broker actually relies on. 
+func requireSingleCancelledRunComplete(t *testing.T, ch <-chan pubsub.Event[notify.RunComplete], sessionID, runID string) { + t.Helper() + select { + case ev := <-ch: + require.Equal(t, runID, ev.Payload.RunID, + "the published RunComplete must carry the dropped queued prompt's RunID") + require.Equal(t, sessionID, ev.Payload.SessionID) + require.True(t, ev.Payload.Cancelled, + "a dropped queued prompt must publish a cancelled RunComplete") + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for the cancelled RunComplete") + } + select { + case extra := <-ch: + t.Fatalf("expected exactly one RunComplete, got a second: %+v", extra.Payload) + case <-time.After(100 * time.Millisecond): + } +} + +// TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete proves the +// terminal-event behavior end-to-end: a RunID-bearing prompt sitting in +// the queue that is canceled while queued (via the public Cancel path, +// which routes through clearQueueAndNotify -> publishCanceledQueueDrops) +// must emit exactly one cancelled RunComplete on the broker for its +// RunID. A queued prompt without a RunID is dropped silently. This is the +// coverage the earlier drain test lacked: it asserted the returned +// bookkeeping slice, not the published event a `crush run` caller awaits. 
+func TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + ch := broker.Subscribe(subCtx) + + const sessionID = "cancel-queued-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "no-runid", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-queued", Prompt: "queued", acceptSeq: 2}, + }) + + a.Cancel(sessionID) + + requireSingleCancelledRunComplete(t, ch, sessionID, "run-queued") + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "Cancel must clear the queue") +} + +// TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete drives +// the production drain sequence (drainQueueForStep then +// publishCanceledQueueDrops, mirroring the PrepareStep handoff) and +// asserts the dropped RunID-bearing prompt actually publishes exactly one +// cancelled RunComplete on the broker. The companion bookkeeping test +// covers the returned slice; this one covers the observable terminal +// event. 
+func TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete(t *testing.T) {
+	t.Parallel()
+
+	const (
+		sessionID    = "drain-drop-runid"
+		droppedRunID = "run-dropped"
+	)
+
+	fixture := testEnv(t)
+	completions := pubsub.NewBroker[notify.RunComplete]()
+	t.Cleanup(completions.Shutdown)
+
+	sa := NewSessionAgent(SessionAgentOptions{
+		RunComplete: completions,
+		Messages:    fixture.messages,
+		Sessions:    fixture.sessions,
+	}).(*sessionAgent)
+
+	subCtx, stopSub := context.WithCancel(t.Context())
+	defer stopSub()
+	events := completions.Subscribe(subCtx)
+
+	// Two queued prompts with acceptSeq 1 and a cancel mark of 2; only
+	// the first carries a RunID. NOTE(review): presumably prompts
+	// accepted before the mark are the ones dropped -- confirm against
+	// drainQueueForStep.
+	sa.messageQueue.Set(sessionID, []SessionAgentCall{
+		{SessionID: sessionID, RunID: droppedRunID, Prompt: "dropped", acceptSeq: 1},
+		{SessionID: sessionID, Prompt: "dropped-no-runid", acceptSeq: 1},
+	})
+	sa.cancelMark.Set(sessionID, 2)
+
+	// Drain first; exactly one dropped prompt carries a RunID.
+	_, droppedWithRunID := sa.drainQueueForStep(sessionID)
+	require.Len(t, droppedWithRunID, 1)
+
+	// Then publish the drops, as the drain handoff does.
+	sa.publishCanceledQueueDrops(droppedWithRunID)
+
+	requireSingleCancelledRunComplete(t, events, sessionID, droppedRunID)
+}

From d411ab5edc99865159f2a86fc3c1a2b8ef3bd56b Mon Sep 17 00:00:00 2001
From: Christian Rocha
Date: Mon, 1 Jun 2026 20:48:47 -0400
Subject: [PATCH 15/15] fix(server): complete background prompts that fail early

Publish a terminal completion for submitted background prompts that fail
before normal agent completion can run. The error still reaches observers,
and duplicate completion events remain suppressed.
Co-Authored-By: Charm Crush --- internal/agent/agenttest/coordinator.go | 1 + internal/agent/coordinator.go | 5 + internal/agent/run_marker.go | 52 +++++++ internal/app/app.go | 8 + internal/backend/agent.go | 38 ++++- internal/backend/agent_runcomplete_test.go | 162 +++++++++++++++++++++ 6 files changed, 261 insertions(+), 5 deletions(-) create mode 100644 internal/agent/run_marker.go create mode 100644 internal/backend/agent_runcomplete_test.go diff --git a/internal/agent/agenttest/coordinator.go b/internal/agent/agenttest/coordinator.go index 9cb1e139b20c4f54daf9ed89b2a2bf43ff72bfce..fdacb7e1292f8fcddcc903a0e70aba544d25fdd3 100644 --- a/internal/agent/agenttest/coordinator.go +++ b/internal/agent/agenttest/coordinator.go @@ -57,6 +57,7 @@ func NewCoordinator( selected := config.SelectedModel{Provider: providerID, Model: modelID} cfg.Config().Models[config.SelectedModelTypeLarge] = selected cfg.Config().Models[config.SelectedModelTypeSmall] = selected + cfg.SetupAgents() // Keep buildTools light: no sub-agent or agentic-fetch construction. coderCfg := cfg.Config().Agents[config.AgentCoder] diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index f5ca831e60cdb54edf0c0d7bfde83702a79701f1..bf05baa39b970a05922def09ab6c5ddb3b17dce1 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -294,6 +294,11 @@ func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID st if hasLatest && c.runComplete != nil { c.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, latest) + // Signal to the dispatcher (backend.runAgent) that the + // authoritative terminal RunComplete for this run was already + // emitted, so it does not publish a duplicate fallback for the + // error it is about to receive. 
+		MarkRunCompletePublished(ctx)
 	}
 	return result, originalErr
 }
diff --git a/internal/agent/run_marker.go b/internal/agent/run_marker.go
new file mode 100644
index 0000000000000000000000000000000000000000..404cca1e8c41bb9179deb886552f3580a977fdfc
--- /dev/null
+++ b/internal/agent/run_marker.go
@@ -0,0 +1,62 @@
+package agent
+
+import (
+	"context"
+	"sync/atomic"
+)
+
+// runCompleteMarkerKey is the private context key under which a
+// [runCompleteMarker] travels from the dispatch boundary
+// (backend.runAgent) down to the coordinator. Through it the dispatcher
+// can tell whether the coordinator has already published the run's
+// authoritative terminal notify.RunComplete, and emit a fallback
+// terminal event only when none exists (e.g. the error was returned
+// before sessionAgent.Run ever ran). Carrying this in the context
+// avoids a breaking change to the Coordinator interface.
+type runCompleteMarkerKey struct{}
+
+// runCompleteMarker remembers whether a run's terminal RunComplete has
+// been published. The context holds it by pointer, so a publish made
+// deep in the call stack stays visible to the dispatcher once the call
+// returns.
+type runCompleteMarker struct {
+	published atomic.Bool
+}
+
+// runCompleteMarkerFrom returns the marker attached to ctx, or nil when
+// ctx carries none.
+func runCompleteMarkerFrom(ctx context.Context) *runCompleteMarker {
+	marker, _ := ctx.Value(runCompleteMarkerKey{}).(*runCompleteMarker)
+	return marker
+}
+
+// WithRunCompleteMarker derives a context from ctx that carries a new,
+// unset marker. The coordinator sets it with
+// [MarkRunCompletePublished] after emitting the run's terminal
+// RunComplete, and callers query it with [RunCompletePublished].
+// Attaching a marker is optional: code paths without one simply skip
+// the dedup signal.
+func WithRunCompleteMarker(ctx context.Context) context.Context {
+	marker := new(runCompleteMarker)
+	return context.WithValue(ctx, runCompleteMarkerKey{}, marker)
+}
+
+// MarkRunCompletePublished flags the marker carried by ctx to record
+// that the authoritative terminal RunComplete for the run has been
+// published. Without a marker (e.g. the in-process/local Run path,
+// which is not dispatched through backend.runAgent) it does nothing.
+func MarkRunCompletePublished(ctx context.Context) {
+	marker := runCompleteMarkerFrom(ctx)
+	if marker == nil {
+		return
+	}
+	marker.published.Store(true)
+}
+
+// RunCompletePublished reports whether the marker carried by ctx has
+// been flagged through [MarkRunCompletePublished]. A ctx without a
+// marker always reports false.
+func RunCompletePublished(ctx context.Context) bool {
+	marker := runCompleteMarkerFrom(ctx)
+	return marker != nil && marker.published.Load()
+}
diff --git a/internal/app/app.go b/internal/app/app.go
index 9509fa3a9dc778d507d38f60c8ca523031b7ecb7..d8a3abc63b9901c528ec3cee7465eaa0b892349f 100644
--- a/internal/app/app.go
+++ b/internal/app/app.go
@@ -185,6 +185,14 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] {
 	return app.agentNotifications
 }
 
+// RunCompletions returns the broker for the authoritative per-run
+// terminal RunComplete events. The dispatcher (backend.runAgent) uses
+// it to emit a reliable terminal event when a run fails before the
+// coordinator could publish one of its own.
+func (app *App) RunCompletions() *pubsub.Broker[notify.RunComplete] {
+	return app.runCompletions
+}
+
 // resolveSession resolves which session to use for a non-interactive run
 // If continueSessionID is set, it looks up that session by ID
 // If useLast is set, it returns the most recently updated top-level session
diff --git a/internal/backend/agent.go b/internal/backend/agent.go
index 4af3b8f0d2f88ad5daff41f40664b303c948b263..3d08746ed35c07ab21221c8a6aa0df3941944fd9 100644
--- a/internal/backend/agent.go
+++ b/internal/backend/agent.go
@@ -61,14 +61,27 @@ func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error
 // runAgent executes an accepted agent run for the workspace. It owns the
 // accept reservation (releasing it on return) and the runWG ticket added
 // by SendMessage. The run is bound to the workspace context so its
-// lifetime is independent of any client's HTTP request.
On a non-cancel -// error it surfaces the failure to observers via a notify.TypeAgentError -// notification; context.Canceled is expected (the FinishReasonCanceled -// marker is already published by sessionAgent.Run) and swallowed. +// lifetime is independent of any client's HTTP request. +// +// On a non-cancel error it surfaces the failure to observers via a +// notify.TypeAgentError notification (lossy, best-effort). That alone is +// not a reliable terminal signal: the agent-event fan-in uses lossy +// subscribers, so a `crush run` caller blocking on its RunID could hang +// if the event is dropped. To guarantee termination, when msg.RunID is +// non-empty and the coordinator did not already publish the run's +// authoritative terminal RunComplete (e.g. the error was returned before +// sessionAgent.Run executed, such as a readyWg or UpdateModels failure), +// runAgent emits an errored RunComplete on the must-deliver +// runCompletions broker so the waiter observes a deterministic terminal +// event. context.Canceled is expected (sessionAgent.Run already +// publishes the cancelled terminal marker) and produces no error +// terminal event. // // When msg.RunID is non-empty it is attached to the context via // agent.WithRunID so the coordinator can stamp the terminal -// notify.RunComplete event with that correlator. +// notify.RunComplete event with that correlator. A run-complete marker +// is also attached so the coordinator can report whether it published +// the terminal event, letting runAgent avoid a duplicate fallback. func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) { defer ws.runWG.Done() defer accept.Close() @@ -77,6 +90,7 @@ func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent. 
if msg.RunID != "" { ctx = agent.WithRunID(ctx, msg.RunID) } + ctx = agent.WithRunCompleteMarker(ctx) _, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) if err == nil || errors.Is(err, context.Canceled) { @@ -89,6 +103,20 @@ func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent. Type: notify.TypeAgentError, Message: err.Error(), }) + + // Reliable terminal fallback. Only needed when a RunID waiter + // exists and the coordinator has not already emitted the run's + // terminal RunComplete; otherwise this would be a duplicate. + if msg.RunID == "" || agent.RunCompletePublished(ctx) { + return + } + if rc := ws.RunCompletions(); rc != nil { + rc.PublishMustDeliver(ctx, pubsub.UpdatedEvent, notify.RunComplete{ + SessionID: msg.SessionID, + RunID: msg.RunID, + Error: err.Error(), + }) + } } // GetAgentInfo returns the agent's model and busy status. diff --git a/internal/backend/agent_runcomplete_test.go b/internal/backend/agent_runcomplete_test.go new file mode 100644 index 0000000000000000000000000000000000000000..be3df103e66b539685e42269d7ced0c7e7e94d86 --- /dev/null +++ b/internal/backend/agent_runcomplete_test.go @@ -0,0 +1,162 @@ +package backend + +import ( + "context" + "errors" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// errorCoordinator is a minimal agent.Coordinator whose RunAccepted +// returns a configurable error. When markPublished is true it stamps +// the run-complete marker on the context before returning, simulating a +// real coordinator that already published the run's authoritative +// terminal RunComplete (so runAgent must not emit a duplicate fallback). 
+type errorCoordinator struct {
+	// err is returned verbatim by Run and RunAccepted.
+	err error
+
+	// markPublished makes RunAccepted flag the context's run-complete
+	// marker before returning.
+	markPublished bool
+}
+
+// Compile-time check that the stub keeps satisfying agent.Coordinator,
+// so an interface change fails here rather than at a distant call site.
+var _ agent.Coordinator = (*errorCoordinator)(nil)
+
+// Run ignores its arguments and returns the configured error.
+func (c *errorCoordinator) Run(context.Context, string, string, ...message.Attachment) (*fantasy.AgentResult, error) {
+	return nil, c.err
+}
+
+// RunAccepted optionally flags the run-complete marker carried by ctx,
+// then returns the configured error with a nil result.
+func (c *errorCoordinator) RunAccepted(ctx context.Context, _ *agent.AcceptedRun, _, _ string, _ ...message.Attachment) (*fantasy.AgentResult, error) {
+	if c.markPublished {
+		agent.MarkRunCompletePublished(ctx)
+	}
+	return nil, c.err
+}
+
+// The remaining Coordinator methods are inert stubs returning zero values.
+func (c *errorCoordinator) BeginAccepted(string) *agent.AcceptedRun { return nil }
+func (c *errorCoordinator) Cancel(string)                           {}
+func (c *errorCoordinator) CancelAll()                              {}
+func (c *errorCoordinator) IsBusy() bool                            { return false }
+func (c *errorCoordinator) IsSessionBusy(string) bool               { return false }
+func (c *errorCoordinator) QueuedPrompts(string) int                { return 0 }
+func (c *errorCoordinator) QueuedPromptsList(string) []string       { return nil }
+func (c *errorCoordinator) ClearQueue(string)                       {}
+func (c *errorCoordinator) Summarize(context.Context, string) error { return nil }
+func (c *errorCoordinator) Model() agent.Model                      { return agent.Model{} }
+func (c *errorCoordinator) UpdateModels(context.Context) error      { return nil }
+
+// insertRunCompleteWorkspace installs a workspace backed by a real
+// app.App (so the runCompletions broker exists) with the given
+// coordinator and a workspace run context derived from base.
+func insertRunCompleteWorkspace(t *testing.T, b *Backend, base context.Context, coord agent.Coordinator) *Workspace {
+	t.Helper()
+
+	// Build the app and swap in the caller's coordinator; the app is
+	// shut down when the test ends.
+	a := app.NewForTest(base)
+	a.AgentCoordinator = coord
+	t.Cleanup(a.ShutdownForTest)
+
+	// One directory backs both Path and resolvedPath so the fixture does
+	// not describe a workspace whose two paths are unrelated directories.
+	dir := t.TempDir()
+	ws := &Workspace{
+		ID:           uuid.New().String(),
+		Path:         dir,
+		resolvedPath: dir,
+		clients:      make(map[string]*clientState),
+		shutdownFn:   func() {},
+	}
+	ws.App = a
+	ws.ctx, ws.cancel = context.WithCancel(base)
+	// Release the run context when the test ends; otherwise it lives as
+	// long as base. Registered after ShutdownForTest, so it runs first
+	// (cleanups are LIFO). Calling cancel more than once is harmless.
+	t.Cleanup(ws.cancel)
+
+	// Register the workspace in both backend indexes under the lock.
+	b.mu.Lock()
+	b.workspaces.Set(ws.ID, ws)
+	b.pathIndex[ws.resolvedPath] = ws.ID
+	b.mu.Unlock()
+	return ws
+}
+
+// TestRunAgent_PreRunErrorPublishesTerminalRunComplete proves that an
+// error returned from RunAccepted before the coordinator could publish
+// its own terminal event (e.g. a readyWg or UpdateModels failure,
+// modeled here by a stub coordinator) still yields a reliable terminal
+// RunComplete for the run's RunID. Without it, a `crush run` caller
+// blocking on that RunID would hang because the lossy TypeAgentError
+// event is not a guaranteed terminal signal.
+func TestRunAgent_PreRunErrorPublishesTerminalRunComplete(t *testing.T) {
+	t.Parallel()
+	b, _ := newTestBackend(t)
+	runErr := errors.New("update models failed")
+	ws := insertRunCompleteWorkspace(t, b, context.Background(), &errorCoordinator{err: runErr})
+
+	// Subscribe before dispatching so the terminal event is observed.
+	subCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	ch := ws.RunCompletions().Subscribe(subCtx)
+
+	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+	require.NoError(t, err)
+
+	select {
+	case ev := <-ch:
+		require.Equal(t, "run-1", ev.Payload.RunID,
+			"the terminal RunComplete must carry the dispatched RunID")
+		require.Equal(t, "S1", ev.Payload.SessionID)
+		require.Equal(t, runErr.Error(), ev.Payload.Error,
+			"the fallback terminal event must be marked errored")
+		require.False(t, ev.Payload.Cancelled)
+	case <-time.After(2 * time.Second):
+		t.Fatal("no terminal RunComplete published for a pre-run error; a run waiter would hang")
+	}
+
+	// Let the dispatched run goroutine return before the test ends, as
+	// the sibling tests do, so it does not overlap with cleanup. This
+	// sits after the receive so a pending publish is consumed first.
+	ws.runWG.Wait()
+}
+
+// TestRunAgent_NoFallbackWhenCoordinatorPublished ensures the fallback
+// is suppressed when the coordinator already emitted the run's
+// authoritative terminal RunComplete, so callers never observe a
+// duplicate terminal event for the same RunID.
+func TestRunAgent_NoFallbackWhenCoordinatorPublished(t *testing.T) {
+	t.Parallel()
+	b, _ := newTestBackend(t)
+	runErr := errors.New("stream failed after publishing terminal event")
+	ws := insertRunCompleteWorkspace(t, b, context.Background(),
+		&errorCoordinator{err: runErr, markPublished: true})
+
+	subCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	ch := ws.RunCompletions().Subscribe(subCtx)
+
+	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+	require.NoError(t, err)
+
+	// Wait for the dispatched run goroutine to return so any publish
+	// has already happened.
+	ws.runWG.Wait()
+
+	// The stub flagged the marker, so the channel must stay silent.
+	select {
+	case ev := <-ch:
+		t.Fatalf("runAgent published a duplicate terminal RunComplete: %+v", ev.Payload)
+	case <-time.After(200 * time.Millisecond):
+	}
+}
+
+// TestRunAgent_CancellationPublishesNoErrorTerminal verifies that a
+// context.Canceled result from RunAccepted produces no errored terminal
+// RunComplete from runAgent: cancellation is sessionAgent.Run's
+// responsibility (it publishes the cancelled marker) and the dispatcher
+// must not synthesize an error terminal for it.
+func TestRunAgent_CancellationPublishesNoErrorTerminal(t *testing.T) {
+	t.Parallel()
+	b, _ := newTestBackend(t)
+	ws := insertRunCompleteWorkspace(t, b, context.Background(),
+		&errorCoordinator{err: context.Canceled})
+
+	subCtx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	ch := ws.RunCompletions().Subscribe(subCtx)
+
+	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"})
+	require.NoError(t, err)
+
+	// Wait for the dispatched run goroutine to return so any publish
+	// has already happened.
+	ws.runWG.Wait()
+
+	// The stub publishes nothing itself, so any event seen here would
+	// have come from runAgent.
+	select {
+	case ev := <-ch:
+		t.Fatalf("cancellation must not publish a terminal RunComplete: %+v", ev.Payload)
+	case <-time.After(200 * time.Millisecond):
+	}
+}