diff --git a/internal/agent/accepted_run_test.go b/internal/agent/accepted_run_test.go new file mode 100644 index 0000000000000000000000000000000000000000..14aec44265d44fa0f5b055ef14af3086e13a0cf3 --- /dev/null +++ b/internal/agent/accepted_run_test.go @@ -0,0 +1,231 @@ +package agent + +import ( + "context" + "testing" + + "github.com/charmbracelet/crush/internal/message" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newCancelTestAgent builds a DB-backed sessionAgent with no model. The +// tests here exercise the dispatch/cancel/persist paths, none of which +// reach agent.Stream, so a model is unnecessary. +func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + return sa, env +} + +func (a *sessionAgent) acceptedCount(sessionID string) int { + c, _ := a.acceptedRuns.Get(sessionID) + return c +} + +func (a *sessionAgent) hasPendingCancel(sessionID string) bool { + mark, ok := a.cancelMark.Get(sessionID) + return ok && mark > 0 +} + +func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 { + mark, _ := a.cancelMark.Get(sessionID) + return mark +} + +func TestAcceptedRun_CloseIsIdempotent(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + require.Equal(t, "sid", accept.SessionID()) + require.Equal(t, 1, sa.acceptedCount("sid")) + + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) + + // Repeated Close must not underflow the counter. + accept.Close() + accept.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_MultipleReservations(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + a1 := sa.BeginAccepted("sid") + a2 := sa.BeginAccepted("sid") + require.Equal(t, 2, sa.acceptedCount("sid")) + + a1.Close() + require.Equal(t, 1, sa.acceptedCount("sid")) + + a2.Close() + require.Equal(t, 0, sa.acceptedCount("sid")) +} + +func TestAcceptedRun_NilSafe(t *testing.T) { + t.Parallel() + var accept *AcceptedRun + require.Equal(t, "", accept.SessionID()) + // Must not panic. + accept.Close() +} + +func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + // No accepted run, no active request: a true no-op. + sa.Cancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + // A second cancel while a pending cancel is already recorded must + // remain a single pending cancel; one Run consumes exactly one. + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) +} + +func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The pending cancel was consumed and the accept released. + require.False(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate PrepareStep having already created the user message. + _, err = sa.createUserMessage(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }) + require.NoError(t, err) + + err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, true) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Simulate workspace shutdown having already canceled the run + // context. WithoutCancel must let the writes through. + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + err = sa.persistCanceledTurn(ctx, SessionAgentCall{ + SessionID: sess.ID, + Prompt: "hello", + }, false) + require.NoError(t, err) + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) +} + +func TestClearPendingCancel(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + accept := sa.BeginAccepted("sid") + defer accept.Close() + sa.Cancel("sid") + require.True(t, sa.hasPendingCancel("sid")) + + sa.clearPendingCancel("sid") + require.False(t, sa.hasPendingCancel("sid")) +} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5461ae7c5bd3ca055f286635199222ad02facfa5..f4972b181a3ddadfbb7c0c652fd0060f928a9865 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -21,6 +21,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "charm.land/catwalk/pkg/catwalk" @@ -103,10 +104,28 @@ type SessionAgentCall struct { // recursion drains, so falling back to the default broker // publish keeps the event visible to subscribers. OnComplete func(notify.RunComplete) + // Accepted, when non-nil, is the accept reservation taken by + // BeginAccepted before the call was dispatched onto a goroutine + // (the client/server fire-and-forget path). Run consumes it under + // dispatchMu[SessionID] once the accepted -> (cancel-on-entry | + // queued | active) transition has been chosen. When nil + // (in-process / local callers like AppWorkspace), behavior is + // unchanged and no accept tracking applies. + Accepted *AcceptedRun + // acceptSeq carries the accept sequence of the handle that produced + // this call after it has been enqueued and its Accepted handle + // stripped. The queue-drain paths compare it against a session's + // cancel mark so a follow-up queued before a cancel is dropped while + // one queued after the cancel survives. 0 means untracked (an + // in-process enqueue with no accept reservation), which the drain + // paths treat as covered by any present mark, preserving the + // pre-sequence behavior. + acceptSeq uint64 } type SessionAgent interface { Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun SetModels(large Model, small Model) SetTools(tools []fantasy.AgentTool) SetSystemPrompt(systemPrompt string) @@ -145,6 +164,43 @@ type sessionAgent struct { messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] + + // dispatchMu holds a per-session mutex that serializes the + // accepted -> (cancel-on-entry | queued | active) transition in + // Run against a concurrent Cancel. The lock is held only during + // the brief handoff (no DB or LLM I/O under the lock). + dispatchMu *csync.Map[string, *sync.Mutex] + // acceptedRuns counts dispatched-but-not-yet-active runs per + // session. A counter > 0 means a dispatched prompt is in flight + // and has not yet completed the dispatch handoff in Run. Only + // BeginAccepted increments it; only AcceptedRun.Close decrements + // it. + acceptedRuns *csync.Map[string, int] + // cancelMark records, per session, a high-water accept sequence: an + // accepted handle is canceled by it iff the handle's sequence is at + // or below the mark. Cancel raises the mark to the latest sequence + // assigned at cancel time, so a single Cancel covers every prompt + // accepted-but-not-yet-active then, while a prompt accepted later + // (higher sequence) is never poisoned. Absent or 0 means no pending + // cancel. It is only raised by Cancel when acceptedRuns > 0, so an + // idle Escape never records a mark. + cancelMark *csync.Map[string, uint64] + // dispatchMuCreate guards lazy creation of per-session entries in + // dispatchMu so two goroutines can't race to lock different mutex + // instances for the same session. + dispatchMuCreate sync.Mutex + // acceptedMu serializes increments/decrements of acceptedRuns and + // the assignment of accept sequence numbers from acceptSeqGen. It + // is separate from dispatchMu so AcceptedRun.Close (which may run + // while Run holds dispatchMu for the same session) does not + // deadlock by re-entering the dispatch lock. + acceptedMu sync.Mutex + // acceptSeqGen is the monotonic source of accept sequence numbers. + // Each BeginAccepted increments it under acceptedMu and stamps the + // returned handle, so sequences strictly increase in accept order + // across the agent. Cancel uses its current value as the per-session + // high-water mark. + acceptSeqGen uint64 } type SessionAgentOptions struct { @@ -180,33 +236,403 @@ func NewSessionAgent( runComplete: opts.RunComplete, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), + dispatchMu: csync.NewMap[string, *sync.Mutex](), + acceptedRuns: csync.NewMap[string, int](), + cancelMark: csync.NewMap[string, uint64](), } } -func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) { +// AcceptedRun owns exactly one accept reservation taken by +// BeginAccepted. It is the only carrier of accept-state across the +// backend.runAgent / Coordinator.Run / sessionAgent.Run layers: a +// counter > 0 means a dispatched prompt is in flight and has not yet +// completed the dispatch handoff in Run. Close is the only way to +// release the reservation and is idempotent. +type AcceptedRun struct { + agent *sessionAgent + sessionID string + // seq is the monotonic accept sequence stamped by BeginAccepted. A + // cancel covers this handle iff seq is at or below the session's + // cancel mark, so a handle accepted after a cancel (higher seq) is + // never poisoned by it. + seq uint64 + done atomic.Bool +} + +// Close decrements the accept counter for this reservation. It is safe +// to call multiple times; only the first call has effect. +func (r *AcceptedRun) Close() { + if r == nil { + return + } + if !r.done.CompareAndSwap(false, true) { + return + } + r.agent.endAccepted(r.sessionID) +} + +// SessionID exposes the session this reservation is for so the run path +// can use it without an extra parameter. +func (r *AcceptedRun) SessionID() string { + if r == nil { + return "" + } + return r.sessionID +} + +// BeginAccepted increments the accept counter for sessionID and returns +// a handle whose Close is the only way to decrement it. It is the only +// entry point that mutates acceptedRuns. +func (a *sessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, _ := a.acceptedRuns.Get(sessionID) + a.acceptedRuns.Set(sessionID, count+1) + a.acceptSeqGen++ + return &AcceptedRun{agent: a, sessionID: sessionID, seq: a.acceptSeqGen} +} + +// endAccepted decrements the accept counter for sessionID. It is only +// called via AcceptedRun.Close. It uses a dedicated lock (not the +// per-session dispatch mutex) so it can run while Run holds dispatchMu +// for the same session without deadlocking. +// +// When the count reaches zero the session's cancel mark is dropped: no +// accepted handle remains for it to cover, and any handle accepted later +// gets a strictly higher sequence that the mark would not match anyway. +// Handles canceled on entry never reach RunComplete, so this is the only +// place that clears the mark for an all-canceled batch. Sibling handles +// covered by the same mark are serialized on the per-session dispatch +// mutex and read the mark before they Close, so this never clears it out +// from under a covered handle still waiting to enter Run. +func (a *sessionAgent) endAccepted(sessionID string) { + a.acceptedMu.Lock() + defer a.acceptedMu.Unlock() + count, ok := a.acceptedRuns.Get(sessionID) + if !ok || count <= 1 { + a.acceptedRuns.Del(sessionID) + a.cancelMark.Del(sessionID) + return + } + a.acceptedRuns.Set(sessionID, count-1) +} + +// sessionMu returns the per-session dispatch mutex, creating it on first +// use. Creation is guarded so concurrent callers always observe the same +// mutex instance for a given session. +func (a *sessionAgent) sessionMu(sessionID string) *sync.Mutex { + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + a.dispatchMuCreate.Lock() + defer a.dispatchMuCreate.Unlock() + if mu, ok := a.dispatchMu.Get(sessionID); ok { + return mu + } + mu := &sync.Mutex{} + a.dispatchMu.Set(sessionID, mu) + return mu +} + +// enqueueCall appends call to the session's message queue. The +// OnComplete hook is stripped: the caller that supplied it (typically +// coordinator.Run) has its own retry/coalesce scope that ends when it +// returns, so by the time the queue drains nobody is left to consume the +// buffered terminal event. The recursive Run falls back to the default +// broker publish, which is what existing subscribers expect for queued +// turns. +func (a *sessionAgent) enqueueCall(call SessionAgentCall) { + existing, ok := a.messageQueue.Get(call.SessionID) + if !ok { + existing = []SessionAgentCall{} + } + queued := call + if call.Accepted != nil { + // Preserve the accept sequence after the handle is stripped so + // the queue-drain paths can tell a follow-up queued before a + // cancel (covered by the mark) from one queued after it. + queued.acceptSeq = call.Accepted.seq + } + queued.OnComplete = nil + queued.Accepted = nil + existing = append(existing, queued) + a.messageQueue.Set(call.SessionID, existing) +} + +// drainQueueForStep partitions the session's queued calls for the current +// streaming step under the per-session dispatch mutex so the filtering is +// atomic against a concurrent Cancel: canceledBySeq requires the caller to +// hold that mutex, and evaluating it here (rather than after unlocking) +// prevents a cancel recorded between the drain and the check from being +// observed inconsistently. +// +// Calls covered by a pending cancel are dropped; the dropped ones that +// carry a RunID are returned in canceledWithRunID so the caller can +// publish their terminal cancelled RunComplete (a caller waiting on that +// RunID, e.g. `crush run`, would otherwise hang). Uncanceled calls without +// a RunID are returned in fold to be folded into the active turn, +// preserving the existing follow-up behavior. Uncanceled calls that carry +// a RunID are left in the queue so each runs as its own turn via the +// recursive run path and publishes its own RunComplete, giving every +// RunID-bearing prompt an explicit lifecycle instead of being silently +// absorbed into another turn. fold is processed by the caller without the +// lock held. +func (a *sessionAgent) drainQueueForStep(sessionID string) (fold, canceledWithRunID []SessionAgentCall) { + dispatchLock := a.sessionMu(sessionID) + dispatchLock.Lock() + defer dispatchLock.Unlock() + queuedCalls, _ := a.messageQueue.Get(sessionID) + var keep []SessionAgentCall + for _, queued := range queuedCalls { + if a.canceledBySeq(sessionID, queued.acceptSeq) { + if queued.RunID != "" { + canceledWithRunID = append(canceledWithRunID, queued) + } + continue + } + if queued.RunID != "" { + keep = append(keep, queued) + continue + } + fold = append(fold, queued) + } + if len(keep) == 0 { + a.messageQueue.Del(sessionID) + } else { + a.messageQueue.Set(sessionID, keep) + } + return fold, canceledWithRunID +} + +// publishCanceledQueueDrops emits a terminal cancelled RunComplete for +// every dropped queued call that carries a RunID. A queued prompt removed +// from the queue without ever running — covered by a pending cancel, or +// cleared by Cancel/ClearQueue — would otherwise leave a caller blocked on +// that RunID: `crush run` ignores live message events and exits only on a +// RunComplete whose RunID matches. Calls without a RunID had no such waiter +// and are dropped silently as before. A detached, bounded context keeps the +// must-deliver publish alive even when the run context that triggered the +// drop is already canceled. +func (a *sessionAgent) publishCanceledQueueDrops(drops []SessionAgentCall) { + var hasRunID bool + for _, d := range drops { + if d.RunID != "" { + hasRunID = true + break + } + } + if !hasRunID { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + for _, d := range drops { + if d.RunID == "" { + continue + } + a.publishRunComplete(ctx, d, notify.RunComplete{ + SessionID: d.SessionID, + RunID: d.RunID, + Cancelled: true, + }) + } +} + +// clearQueueAndNotify removes all queued prompts for the session and +// publishes a terminal cancelled RunComplete for any that carried a RunID, +// so callers waiting on those RunIDs (e.g. `crush run`) are not left +// hanging when their queued prompt is discarded without running. +func (a *sessionAgent) clearQueueAndNotify(sessionID string) { + queued, ok := a.messageQueue.Get(sessionID) + a.messageQueue.Del(sessionID) + if !ok { + return + } + a.publishCanceledQueueDrops(queued) +} + +// clearPendingCancel removes any pending-cancel mark for sessionID. It +// takes the per-session dispatch lock so it is ordered against Cancel +// and the dispatch handoff. +func (a *sessionAgent) clearPendingCancel(sessionID string) { + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + a.cancelMark.Del(sessionID) +} + +// canceledBySeq reports whether an accepted handle or queued call with +// the given accept sequence is covered by a pending cancel for the +// session. Callers must hold the session's dispatch mutex. A tracked +// sequence (seq > 0) is covered only when it is at or below the cancel +// high-water mark, so a prompt accepted after the cancel (higher seq) is +// never poisoned. An untracked sequence (seq == 0, an in-process enqueue +// with no accept reservation) is covered whenever any mark is present, +// preserving the pre-sequence behavior. The mark is not consumed: it +// stays so every sibling handle it covers observes the same cancel, and +// a later handle (higher seq) ignores it regardless. +func (a *sessionAgent) canceledBySeq(sessionID string, seq uint64) bool { + mark, ok := a.cancelMark.Get(sessionID) + if !ok || mark == 0 { + return false + } + return seq == 0 || seq <= mark +} + +// persistCanceledTurn writes the user/assistant records for a turn that +// was canceled before (or just as) streaming would have produced them. +// It creates the user message only when it was not already created by an +// earlier createUserMessage call (userMsgCreated), then writes an +// assistant message with FinishReasonCanceled. Both writes use +// context.WithoutCancel(ctx) so workspace shutdown (which cancels the run +// context) can't drop them. +func (a *sessionAgent) persistCanceledTurn(ctx context.Context, call SessionAgentCall, userMsgCreated bool) error { + writeCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + if !userMsgCreated { + if _, err := a.createUserMessage(writeCtx, call); err != nil { + return err + } + } + largeModel := a.largeModel.Get() + assistant, err := a.messages.Create(writeCtx, call.SessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + Model: largeModel.ModelCfg.Model, + Provider: largeModel.ModelCfg.Provider, + }) + if err != nil { + return err + } + assistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "") + return a.messages.Update(writeCtx, assistant) +} + +// publishRunComplete emits the authoritative terminal event for a turn. +// It honors the per-call OnComplete hook when set (so the coordinator can +// coalesce retries) and otherwise falls back to the RunComplete broker. +// ctx is used only for the bounded-blocking must-deliver publish; the +// terminal payload is supplied by the caller. This is the single emit path +// shared by the streaming defer and the cancel-on-entry early return so a +// caller waiting on RunComplete (e.g. `crush run` with a RunID) always +// observes exactly one terminal event regardless of which Run branch ends +// the turn. +func (a *sessionAgent) publishRunComplete(ctx context.Context, call SessionAgentCall, complete notify.RunComplete) { + if call.OnComplete != nil { + call.OnComplete(complete) + return + } + if a.runComplete == nil { + return + } + a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) +} + +// ValidateCall performs the cheap structural validation that +// sessionAgent.Run requires before a call can be dispatched: a call must +// carry either a non-empty prompt or a text attachment, and it must name a +// session. It is exported so callers that accept a run before dispatching it +// (e.g. backend.SendMessage) can apply the same checks and keep the error +// contract consistent. +func ValidateCall(call SessionAgentCall) error { if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) { - return nil, ErrEmptyPrompt + return ErrEmptyPrompt } if call.SessionID == "" { - return nil, ErrSessionMissing - } - - // Queue the message if busy. Strip OnComplete: the caller that - // supplied the hook (typically coordinator.Run) has its own - // retry/coalesce scope that ends when it returns, so by the time - // the queue drains nobody is left to consume the buffered - // terminal event. The recursive Run will fall back to the - // default broker publish, which is what existing subscribers - // expect for queued turns. - if a.IsSessionBusy(call.SessionID) { - existing, ok := a.messageQueue.Get(call.SessionID) - if !ok { - existing = []SessionAgentCall{} + return ErrSessionMissing + } + return nil +} + +func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) { + if err := ValidateCall(call); err != nil { + return nil, err + } + + // genCtx/cancel are the run context and its cancel func. For the + // accepted (fire-and-forget) dispatch path they are created under + // dispatchMu below so a concurrent Cancel can observe the + // activeRequests entry before the assistant message exists. For + // the in-process path they stay nil here and are created later, + // preserving the original ordering. + var ( + genCtx context.Context + cancel context.CancelFunc + activeRegistered bool + userMsgCreated bool + ) + + if call.Accepted != nil { + // Serialize the accepted -> (cancel-on-entry | queued | + // active) transition against a concurrent Cancel. Cancel takes + // the same per-session lock, so every cancel observes at least + // one of: a cancel mark, an activeRequests entry, or a + // messageQueue entry it then clears. + mu := a.sessionMu(call.SessionID) + mu.Lock() + + if a.canceledBySeq(call.SessionID, call.Accepted.seq) { + // Cancel-on-entry: a cancel arrived while this run was + // dispatched but not yet active, and this handle's accept + // sequence is at or below the session's cancel mark. The + // mark is left in place so sibling handles it also covers + // observe the same cancel; release the accept reservation, + // drop the lock, and persist a canceled turn without + // entering Stream. + // + // This path returns before the streaming defer that + // publishes RunComplete is installed, so emit the terminal + // event explicitly. Without it, a caller waiting on + // RunComplete for this RunID (e.g. `crush run`, which + // ignores message events and blocks on RunComplete) would + // hang on an immediately-canceled accepted run. + call.Accepted.Close() + mu.Unlock() + complete := notify.RunComplete{ + SessionID: call.SessionID, + RunID: call.RunID, + Cancelled: true, + } + if err := a.persistCanceledTurn(ctx, call, false); err != nil { + complete.Error = err.Error() + a.publishRunComplete(ctx, call, complete) + return nil, err + } + a.publishRunComplete(ctx, call, complete) + return nil, nil } - queued := call - queued.OnComplete = nil - existing = append(existing, queued) - a.messageQueue.Set(call.SessionID, existing) + + if a.IsSessionBusy(call.SessionID) { + // Busy: an earlier prompt is active. Queue this call and + // release the accept reservation. A Cancel arriving after + // this point sees the active entry and clears the queue. + a.enqueueCall(call) + call.Accepted.Close() + mu.Unlock() + return nil, nil + } + + // Idle: become the active run. Register the cancel func before + // dropping the lock so a Cancel that arrives between here and + // assistant creation is not lost. + runCtx := context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) + genCtx, cancel = context.WithCancel(runCtx) + a.activeRequests.Set(call.SessionID, cancel) + activeRegistered = true + call.Accepted.Close() + mu.Unlock() + + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } else if a.IsSessionBusy(call.SessionID) { + // Queue the message if busy. Strip OnComplete: the caller that + // supplied the hook (typically coordinator.Run) has its own + // retry/coalesce scope that ends when it returns, so by the time + // the queue drains nobody is left to consume the buffered + // terminal event. The recursive Run will fall back to the + // default broker publish, which is what existing subscribers + // expect for queued turns. + a.enqueueCall(call) return nil, nil } @@ -269,15 +695,22 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * if err != nil { return nil, err } + userMsgCreated = true // Add the session to the context. ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID) - genCtx, cancel := context.WithCancel(ctx) - a.activeRequests.Set(call.SessionID, cancel) + // For the accepted dispatch path the run context and cancel func + // were already created and registered under dispatchMu above; reuse + // them. For the in-process path create them here, preserving the + // original ordering. + if !activeRegistered { + genCtx, cancel = context.WithCancel(ctx) + a.activeRequests.Set(call.SessionID, cancel) - defer cancel() - defer a.activeRequests.Del(call.SessionID) + defer cancel() + defer a.activeRequests.Del(call.SessionID) + } // skipRunComplete is set just before the queued-recursion path so // the outer Run doesn't publish a RunComplete that would race // with — and be superseded by — the recursive call's own @@ -305,7 +738,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // (the pubsub broker fan-in does not serialize publishes from // different upstream brokers). defer func() { - if flushErr := a.messages.FlushAll(ctx); flushErr != nil { + // Use a context detached from the run context: workspace + // shutdown cancels ctx before this goroutine returns, but the + // buffered streaming deltas must still land before the DB is + // closed. A short timeout bounds the flush. + flushCtx, flushCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer flushCancel() + if flushErr := a.messages.FlushAll(flushCtx); flushErr != nil { slog.Error("Failed to flush pending message updates after run", "error", flushErr) } if skipRunComplete { @@ -329,14 +768,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // the authoritative terminal event so a momentarily-full // subscriber channel can't silently drop it and hang // non-interactive clients waiting on RunComplete. - if call.OnComplete != nil { - call.OnComplete(complete) - return - } - if a.runComplete == nil { - return - } - a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) + a.publishRunComplete(ctx, call, complete) }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -371,9 +803,20 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * // Use latest tools (updated by SetTools when MCP tools change). prepared.Tools = a.tools.Copy() - queuedCalls, _ := a.messageQueue.Get(call.SessionID) - a.messageQueue.Del(call.SessionID) - for _, queued := range queuedCalls { + // Drain queued follow-up prompts for this step. Calls covered + // by a cancel recorded while they sat in the queue are dropped: + // a cancel that arrived after a prompt was queued must not let + // it run as part of this step. Coverage is per-call by accept + // sequence so a follow-up queued after the cancel (higher seq) + // is not dropped. A dropped prompt carrying a RunID still gets + // its terminal cancelled RunComplete so a caller waiting on it + // does not hang. Uncanceled prompts without a RunID are folded + // into this turn; uncanceled prompts with a RunID are left + // queued so each runs as its own turn (with its own + // RunComplete) via the recursive run path below. + fold, canceledRunIDs := a.drainQueueForStep(call.SessionID) + a.publishCanceledQueueDrops(canceledRunIDs) + for _, queued := range fold { userMessage, createErr := a.createUserMessage(callContext, queued) if createErr != nil { return callContext, prepared, createErr @@ -575,13 +1018,34 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * isHyper := largeModel.ModelCfg.Provider == hyper.Name isCancelErr := errors.Is(err, context.Canceled) if currentAssistant == nil { + // Cancel-before-assistant-creation window: the run was + // canceled after activeRequests.Set but before PrepareStep + // created the assistant message. Without this, the turn + // would return with no FinishReasonCanceled marker and no + // user-visible record. The user message was already created + // above, so persistCanceledTurn only writes the assistant + // record. + if isCancelErr { + if persistErr := a.persistCanceledTurn(ctx, call, userMsgCreated); persistErr != nil { + return nil, persistErr + } + } return result, err } + // Persist final state with a context detached from the run + // context. The run context (ctx) is derived from the + // workspace context, which workspace shutdown cancels before + // agent goroutines finish; using ctx here would drop the + // final assistant state. WithoutCancel keeps the values + // (e.g. session ID) while ignoring cancellation, and a short + // timeout bounds the cleanup writes. + cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cleanupCancel() // Ensure we finish thinking on error to close the reasoning state. currentAssistant.FinishThinking() toolCalls := currentAssistant.ToolCalls() - // INFO: we use the parent context here because the genCtx has been cancelled. - msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID) + // INFO: we use the cleanup context here because the genCtx has been cancelled. + msgs, createErr := a.messages.List(cleanupCtx, currentAssistant.SessionID) if createErr != nil { return nil, createErr } @@ -590,7 +1054,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * tc.Finished = true tc.Input = "{}" currentAssistant.AddToolCall(tc) - updateErr := a.messages.Update(ctx, *currentAssistant) + updateErr := a.messages.Update(cleanupCtx, *currentAssistant) if updateErr != nil { return nil, updateErr } @@ -623,7 +1087,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * Content: content, IsError: true, } - _, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{ + _, createErr = a.messages.Create(cleanupCtx, currentAssistant.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: []message.ContentPart{ toolResult, @@ -662,9 +1126,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * } else { currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error()) } - // Note: we use the parent context here because the genCtx has been + // Note: we use the cleanup context here because the genCtx has been // cancelled. - updateErr := a.messages.Update(ctx, *currentAssistant) + updateErr := a.messages.Update(cleanupCtx, *currentAssistant) if updateErr != nil { return nil, updateErr } @@ -705,18 +1169,106 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result * }) } - queuedMessages, ok := a.messageQueue.Get(call.SessionID) - if !ok || len(queuedMessages) == 0 { + // Hand off to the next queued prompt (if any) under dispatchMu so + // the transition from this finished run to the queued run is atomic + // against a concurrent Cancel. activeRequests for this session was + // just deleted above, so without the lock there is a window in + // which the session looks idle and a cancel becomes a no-op that + // fails to stop the queued prompt. Holding the lock lets us observe + // a pending cancel recorded against the session and drop the queue + // instead of running it, and (for the recursion) hand a fresh + // accept reservation to the dequeued call so acceptedRuns stays > 0 + // across the recursive Run's own dispatch handoff — keeping the + // session observable to Cancel for the entire transition and + // closing the dequeue -> re-register window. + mu := a.sessionMu(call.SessionID) + mu.Lock() + queuedMessages, _ := a.messageQueue.Get(call.SessionID) + if mark, ok := a.cancelMark.Get(call.SessionID); ok && mark > 0 && len(queuedMessages) > 0 { + // A cancel was recorded for this session (e.g. it arrived while + // this run was active and follow-ups had been queued). Drop the + // queued prompts it covers (accept sequence at or below the + // mark, or untracked); keep any queued after the cancel (higher + // sequence) so they still run. + var kept []SessionAgentCall + var canceledRunIDDrops []SessionAgentCall + for _, q := range queuedMessages { + if q.acceptSeq == 0 || q.acceptSeq <= mark { + if q.RunID != "" { + canceledRunIDDrops = append(canceledRunIDDrops, q) + } + continue + } + kept = append(kept, q) + } + queuedMessages = kept + a.messageQueue.Set(call.SessionID, kept) + // A dropped prompt carrying a RunID must still publish its + // terminal cancelled RunComplete so a caller waiting on that + // RunID does not hang. + a.publishCanceledQueueDrops(canceledRunIDDrops) + } + if len(queuedMessages) == 0 { + // No queued work. Clear the cancel mark only when no accepted + // run remains in flight that it might still cover; otherwise a + // sibling prompt (sequence at or below the mark) waiting to + // enter Run would lose its cancellation. When accepted runs are + // gone, this also clears a stale mark so it can't catch a + // future run. + a.messageQueue.Del(call.SessionID) + a.acceptedMu.Lock() + inFlight, _ := a.acceptedRuns.Get(call.SessionID) + a.acceptedMu.Unlock() + if inFlight == 0 { + a.cancelMark.Del(call.SessionID) + } + mu.Unlock() return result, err } - // There are queued messages restart the loop. The recursive Run - // publishes its own RunComplete for the queued prompt, so suppress - // the outer defer's emit to avoid a duplicate event whose Error - // field would belong to the recursive turn but whose MessageID/Text - // would belong to the outer turn. + // There are queued messages, restart the loop. Suppress the outer + // defer's emit: it would otherwise observe the recursive Run's retErr + // (named-return clobbering through the return below) against this + // turn's MessageID/Text and publish a mixed, racing event. skipRunComplete = true + // Decide whether this turn still owes its own terminal RunComplete. + // Each submitted prompt with a RunID has its own lifecycle, so a turn + // that is finished and handing off to a *different* queued prompt must + // publish its own RunComplete here — leaving it to the recursive turn + // (which carries a different RunID) would hang a caller waiting on + // this turn's RunID. The exception is the summarize-continuation path, + // which re-queues this same call (same RunID) to resume after a + // summary; in that case the eventual terminal turn for this RunID + // publishes, so publishing now would double-emit. + outerOwesRunComplete := call.RunID != "" + if outerOwesRunComplete { + for _, q := range queuedMessages { + if q.RunID == call.RunID { + outerOwesRunComplete = false + break + } + } + } firstQueuedMessage := queuedMessages[0] a.messageQueue.Set(call.SessionID, queuedMessages[1:]) + // Reserve a fresh accept for the dequeued prompt before dropping the + // lock so acceptedRuns > 0 across the handoff into the recursive + // Run. This closes the window between this dequeue and the recursive + // Run registering its activeRequests entry: a cancel arriving in + // that window now records a pending cancel (acceptedRuns > 0) that + // the recursive Run's accepted path observes as cancel-on-entry. + firstQueuedMessage.Accepted = a.BeginAccepted(call.SessionID) + mu.Unlock() + if outerOwesRunComplete { + complete := notify.RunComplete{SessionID: call.SessionID, RunID: call.RunID} + if currentAssistant != nil { + complete.MessageID = currentAssistant.ID + complete.Text = currentAssistant.Content().String() + } + if ctx.Err() != nil { + complete.Cancelled = true + } + a.publishRunComplete(ctx, call, complete) + } return a.Run(ctx, firstQueuedMessage) } @@ -1269,6 +1821,16 @@ func summaryCompletionTokens(usage fantasy.Usage, summaryMessage message.Message } func (a *sessionAgent) Cancel(sessionID string) { + // Serialize against the dispatch handoff in Run so the accepted -> + // (cancel-on-entry | queued | active) transition is atomic against + // this cancel. Every cancel observes at least one of: an active + // request, an accepted run (recorded as a pending cancel), or a + // queue entry it then clears. If none of those hold, an idle Escape + // is a true no-op and must not poison the next prompt. + mu := a.sessionMu(sessionID) + mu.Lock() + defer mu.Unlock() + // Cancel regular requests. Don't use Take() here - we need the entry to // remain in activeRequests so IsBusy() returns true until the goroutine // fully completes (including error handling that may access the DB). @@ -1284,16 +1846,40 @@ func (a *sessionAgent) Cancel(sessionID string) { cancel() } + // Record a pending cancel only when a dispatched-but-not-yet-active + // run exists. This catches runs still in the goroutine scheduler or + // about to enter Run's busy-queue branch, while leaving an idle + // session untouched. Active and accepted are not mutually exclusive: + // when a run is active and a follow-up has been accepted, both the + // cancel above and this pending record fire. + // + // Raise the session's cancel mark to the latest accept sequence + // assigned so far. Every prompt currently accepted-but-not-yet- + // active has a sequence at or below that value, so one cancel covers + // all of them; a prompt accepted after this cancel gets a strictly + // higher sequence and is never poisoned. Using max keeps repeated + // cancels idempotent while the same prompts are in flight and lets a + // later cancel extend coverage to prompts accepted since. + a.acceptedMu.Lock() + count, ok := a.acceptedRuns.Get(sessionID) + mark := a.acceptSeqGen + a.acceptedMu.Unlock() + if ok && count > 0 { + slog.Debug("Recording cancel mark for accepted runs", "session_id", sessionID, "count", count, "mark", mark) + existing, _ := a.cancelMark.Get(sessionID) + a.cancelMark.Set(sessionID, max(existing, mark)) + } + if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) - a.messageQueue.Del(sessionID) + a.clearQueueAndNotify(sessionID) } } func (a *sessionAgent) ClearQueue(sessionID string) { if a.QueuedPrompts(sessionID) > 0 { slog.Debug("Clearing queued prompts", "session_id", sessionID) - a.messageQueue.Del(sessionID) + a.clearQueueAndNotify(sessionID) } } diff --git a/internal/agent/agenttest/coordinator.go b/internal/agent/agenttest/coordinator.go new file mode 100644 index 0000000000000000000000000000000000000000..fdacb7e1292f8fcddcc903a0e70aba544d25fdd3 --- /dev/null +++ b/internal/agent/agenttest/coordinator.go @@ -0,0 +1,80 @@ +// Package agenttest provides test-only constructors for wiring a real +// production agent.Coordinator without booting a full app.App. It is +// imported only from _test.go files (e.g. internal/backend integration +// tests) and is never referenced by production code, so it is compiled +// only under tests and never ships in the production binary or API. +package agenttest + +import ( + "context" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy/providers/openaicompat" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/session" +) + +// NewCoordinator builds a real agent.Coordinator through the production +// agent.NewCoordinator constructor so the RunAccepted / BeginAccepted / +// run path (including UpdateModels) is the actual code under test. +// +// It installs a minimal config with a single openai-compatible provider +// whose model resolves offline. run rebuilds the model on every call, so +// the provider must construct without network I/O; the cancel-on-entry +// path this helper is built to exercise returns before any model call, +// so no request is ever issued. The coder agent's allowed-tools list is +// cleared to keep tool construction cheap and free of sub-agent wiring. +// +// The optional coordinator dependencies (history, filetracker, LSP, +// notify, runComplete, skills) are nil: run guards the publisher fields +// and the cancel-on-entry path never touches the others. +func NewCoordinator( + ctx context.Context, + workingDir string, + sessions session.Service, + messages message.Service, +) (agent.Coordinator, error) { + cfg, err := config.Init(workingDir, "", false) + if err != nil { + return nil, err + } + + const ( + providerID = "test-openai-compat" + modelID = "test-model" + ) + cfg.Config().Providers.Set(providerID, config.ProviderConfig{ + ID: providerID, + Name: "Test", + Type: openaicompat.Name, + BaseURL: "http://127.0.0.1:0/v1", + APIKey: "test", + Models: []catwalk.Model{{ID: modelID, DefaultMaxTokens: 4096}}, + }) + selected := config.SelectedModel{Provider: providerID, Model: modelID} + cfg.Config().Models[config.SelectedModelTypeLarge] = selected + cfg.Config().Models[config.SelectedModelTypeSmall] = selected + cfg.SetupAgents() + + // Keep buildTools light: no sub-agent or agentic-fetch construction. + coderCfg := cfg.Config().Agents[config.AgentCoder] + coderCfg.AllowedTools = nil + cfg.Config().Agents[config.AgentCoder] = coderCfg + + return agent.NewCoordinator( + ctx, + cfg, + sessions, + messages, + permission.NewPermissionService(workingDir, true, nil), + nil, + nil, + nil, + nil, + nil, + nil, + ) +} diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 75087e11fe96e13b51eb03c51fee46c8f8618bcb..86ca09e3bf7cdafca106ddf44fb88fc2c2a7ceac 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -81,6 +81,15 @@ type Coordinator interface { // INFO: (kujtim) this is not used yet we will use this when we have multiple agents // SetMainAgent(string) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + // RunAccepted runs a call that was already accepted via + // BeginAccepted on the fire-and-forget dispatch path. The handle is + // the only carrier of accept-state across the backend.runAgent / + // Coordinator / sessionAgent.Run layers: it reaches + // sessionAgent.Run as SessionAgentCall.Accepted, where it is + // consumed under dispatchMu once the accepted -> (cancel-on-entry | + // queued | active) transition is chosen. + RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) + BeginAccepted(sessionID string) *AcceptedRun Cancel(sessionID string) CancelAll() IsSessionBusy(sessionID string) bool @@ -179,6 +188,20 @@ func NewCoordinator( // Run implements Coordinator. func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, nil, sessionID, prompt, attachments...) +} + +// RunAccepted implements Coordinator. +func (c *coordinator) RunAccepted(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.run(ctx, accept, sessionID, prompt, attachments...) +} + +// run is the shared implementation behind Run and RunAccepted. When +// accept is non-nil it is threaded onto the SessionAgentCall as +// Accepted so sessionAgent.Run can consume the accept reservation under +// dispatchMu; when nil (the in-process/local path) no accept tracking +// applies. +func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { if err := c.readyWg.Wait(); err != nil { return nil, err } @@ -256,6 +279,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, FrequencyPenalty: freqPenalty, PresencePenalty: presPenalty, OnComplete: onComplete, + Accepted: accept, }) } beforeLoaded := c.skillTracker.LoadedNames() @@ -278,6 +302,11 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, if hasLatest && c.runComplete != nil { c.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, latest) + // Signal to the dispatcher (backend.runAgent) that the + // authoritative terminal RunComplete for this run was already + // emitted, so it does not publish a duplicate fallback for the + // error it is about to receive. + MarkRunCompletePublished(ctx) } return result, originalErr } @@ -1016,6 +1045,15 @@ func isExactoSupported(modelID string) bool { return slices.Contains(supportedModels, modelID) } +// BeginAccepted reserves an accept slot for sessionID on the active +// agent and returns the ownership handle. It is the fire-and-forget +// dispatch path's only way to mark a run as accepted-but-not-yet-active +// so a cancel arriving before the run registers in activeRequests is not +// lost. +func (c *coordinator) BeginAccepted(sessionID string) *AcceptedRun { + return c.currentAgent.BeginAccepted(sessionID) +} + func (c *coordinator) Cancel(sessionID string) { c.currentAgent.Cancel(sessionID) } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go index da0ddd0db2bf77c3c1e3eb6463549875a989a4ca..c522ef5de1061435e4cf9df1789bc3c92d9152a4 100644 --- a/internal/agent/coordinator_test.go +++ b/internal/agent/coordinator_test.go @@ -25,6 +25,10 @@ func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fan return m.runFunc(ctx, call) } +func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun { + return &AcceptedRun{sessionID: sessionID} +} + func (m *mockSessionAgent) Model() Model { return m.model } func (m *mockSessionAgent) SetModels(large, small Model) {} func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {} diff --git a/internal/agent/dispatch_cancel_test.go b/internal/agent/dispatch_cancel_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f1b0faad21da845f16728fd2d1101b64c569f2dc --- /dev/null +++ b/internal/agent/dispatch_cancel_test.go @@ -0,0 +1,447 @@ +package agent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// finishStreamModel is a minimal fantasy.LanguageModel that streams a +// single text part followed by a normal (FinishReasonStop) finish. It +// is enough to drive sessionAgent.Run through PrepareStep and a clean +// completion without a recorded provider cassette. +type finishStreamModel struct { + text string +} + +func (m *finishStreamModel) Provider() string { return "fake" } +func (m *finishStreamModel) Model() string { return "fake-model" } + +func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}, + FinishReason: fantasy.FinishReasonStop, + }, nil +} + +func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + text := m.text + return func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) { + return + } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil +} + +func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return nil, errors.New("not implemented") +} + +func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) { + t.Helper() + env := testEnv(t) + model := &finishStreamModel{text: "done"} + sa := testSessionAgent(env, model, model, "system").(*sessionAgent) + return sa, env +} + +// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a +// session is actively running (activeRequests set) AND a follow-up has +// been accepted (acceptedRuns > 0). A single Cancel must fire both: it +// invokes the active cancel func and records a pending cancel for the +// accepted follow-up. +func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) { + t.Parallel() + sa, _ := newCancelTestAgent(t) + + const sid = "sid" + var activeCanceled atomic.Bool + sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) }) + + accept := sa.BeginAccepted(sid) + defer accept.Close() + + sa.Cancel(sid) + + require.True(t, activeCanceled.Load(), "active cancel func must fire") + require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel") +} + +// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue +// branch consulting pendingCancels: when the session is busy AND a +// cancel has been recorded for an accepted follow-up, Run must take the +// cancel-on-entry path (persist a canceled turn) instead of enqueueing +// the call behind the active run. +func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Make the session look busy: an earlier prompt is active. + sa.activeRequests.Set(sess.ID, func() {}) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives while this follow-up is accepted-but-not-active. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "follow-up", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + // The follow-up was canceled on entry, not enqueued. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "cancel-on-entry must not enqueue the follow-up behind the active run") + require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} + +// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the +// queue drain inside PrepareStep skips queued follow-up prompts when a +// cancel has been recorded for the session: the queued prompt must not +// be folded into the active turn as an extra user message. +func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A follow-up prompt sits queued for the session. + sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"}) + // A cancel was recorded for the session while it sat in the queue. + sa.cancelMark.Set(sess.ID, 1) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + // Only the main prompt produced a user message; the queued + // follow-up was skipped, not folded into the turn. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + var userMsgs []message.Message + for _, m := range msgs { + if m.Role == message.User { + userMsgs = append(userMsgs, m) + } + } + require.Len(t, userMsgs, 1, "queued follow-up must not create a user message") + assert.Equal(t, "main", userMsgs[0].Content().String()) + + // The queue was drained and the pending cancel consumed. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID)) + require.False(t, sa.hasPendingCancel(sess.ID)) +} + +// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run +// which completes normally clears any stale pending-cancel entry for the +// session, so it cannot catch a future run. +func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A stale cancel mark lingers (no queued work, no accepted run). + sa.cancelMark.Set(sess.ID, 1) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "main", + }) + require.NoError(t, err) + require.NotNil(t, result) + + require.False(t, sa.hasPendingCancel(sess.ID), + "normal completion must clear the stale pending cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason()) +} + +// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired +// to a RunComplete broker so tests can observe the terminal event a +// RunID-bearing caller (e.g. `crush run`) blocks on. +func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) { + t.Helper() + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + sa := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + return sa, env, broker +} + +// TestRun_CancelOnEntryPublishesRunComplete covers the first review +// finding: the cancel-on-entry path returned before the streaming defer +// that publishes RunComplete was installed. A caller that dispatches a +// run with a RunID and blocks on RunComplete (ignoring message events, +// like `crush run`) would hang on an immediately-canceled accepted run. +// The cancel-on-entry path must publish a terminal RunComplete carrying +// the originating RunID. +func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) { + t.Parallel() + sa, env, broker := newCancelTestAgentWithRunComplete(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + ch := broker.Subscribe(ctx) + + accept := sa.BeginAccepted(sess.ID) + // A cancel arrives in the accepted-but-not-yet-active window. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-cancel-on-entry", + Prompt: "hello", + Accepted: accept, + }) + require.NoError(t, err) + require.Nil(t, result) + + select { + case got := <-ch: + assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID, + "RunComplete must echo the originating RunID") + assert.Equal(t, sess.ID, got.Payload.SessionID) + assert.True(t, got.Payload.Cancelled, + "cancel-on-entry RunComplete must be marked cancelled") + case <-time.After(2 * time.Second): + t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise") + } +} + +// TestCancel_TwoAcceptedBothObserveCancellation covers the second review +// finding: a single cancel with two accepted-not-yet-active prompts must +// cancel both. The cancel raises the session's high-water mark to the +// latest accept sequence, so every prompt accepted-but-not-yet-active at +// cancel time is covered and both take the cancel-on-entry path. +func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) { + t.Parallel() + sa, env := newCancelTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Two prompts are accepted-but-not-yet-active for the same session. + accept1 := sa.BeginAccepted(sess.ID) + accept2 := sa.BeginAccepted(sess.ID) + require.Equal(t, 2, sa.acceptedCount(sess.ID)) + + // A single cancel arrives before either becomes active. + sa.Cancel(sess.ID) + require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID), + "one cancel must mark every currently-accepted prompt as canceled") + require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq, + "the mark must cover the earlier accepted prompt too") + + // Both prompts enter Run; each must take cancel-on-entry, not run. + r1, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "first", + Accepted: accept1, + }) + require.NoError(t, err) + require.Nil(t, r1) + + r2, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "second", + Accepted: accept2, + }) + require.NoError(t, err) + require.Nil(t, r2) + + require.False(t, sa.hasPendingCancel(sess.ID), + "both reserved units must be consumed") + require.Equal(t, 0, sa.acceptedCount(sess.ID), + "both accept reservations must be released") + + // Each canceled-on-entry turn writes a user + canceled assistant + // message, and neither prompt was enqueued to run normally. + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither accepted prompt may be enqueued to run normally") + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages") + var canceled int + for _, m := range msgs { + if m.Role == message.Assistant { + assert.Equal(t, message.FinishReasonCanceled, m.FinishReason()) + canceled++ + } + } + require.Equal(t, 2, canceled, "both turns must finish canceled") +} + +// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel +// no-op guarantee end-to-end: an Escape on an idle session must not +// record a pending cancel that leaks into the next accepted prompt, which +// must run normally to completion. +func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // Idle Escape: no accepted run, no active request. + sa.Cancel(sess.ID) + require.False(t, sa.hasPendingCancel(sess.ID), + "idle cancel must not record a pending cancel") + + // The next accepted prompt must run normally, not cancel on entry. + accept := sa.BeginAccepted(sess.ID) + result, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "next", + Accepted: accept, + }) + require.NoError(t, err) + require.NotNil(t, result, "next prompt must run normally after an idle cancel") + + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + assert.Equal(t, message.User, msgs[0].Role) + assert.Equal(t, message.Assistant, msgs[1].Role) + assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(), + "the prompt must finish normally, not canceled") +} + +// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for +// the reviewer's finding: a counted session-level pending cancel let a +// prompt accepted after the cancel enter Run first and consume a unit +// reserved for the earlier prompts. With a sequence high-water mark, a +// single cancel covers exactly the prompts accepted-but-not-yet-active at +// cancel time (A and B); a prompt accepted afterwards (C) gets a higher +// sequence and must run normally without consuming A or B's cancellation. +// C is run first to prove it neither cancels nor drains the mark, then A +// and B are run and must both cancel on entry. +func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) { + t.Parallel() + sa, env := newStreamTestAgent(t) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + // A and B are accepted-but-not-yet-active. + acceptA := sa.BeginAccepted(sess.ID) + acceptB := sa.BeginAccepted(sess.ID) + + // One cancel arrives covering both A and B. + sa.Cancel(sess.ID) + require.True(t, sa.hasPendingCancel(sess.ID)) + require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID), + "the mark must cover every prompt accepted before the cancel") + + // C is accepted AFTER the cancel; its sequence is above the mark. + acceptC := sa.BeginAccepted(sess.ID) + require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID), + "a prompt accepted after the cancel must not be covered by the mark") + + // Run C first. It must run normally to completion and must not + // consume or clear the cancellation reserved for A and B. + rc, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "C", + Accepted: acceptC, + }) + require.NoError(t, err) + require.NotNil(t, rc, "C was accepted after the cancel and must run normally") + require.True(t, sa.hasPendingCancel(sess.ID), + "running C must not drain the cancellation reserved for A and B") + + // Now A and B run. Both were accepted before the cancel and must + // take the cancel-on-entry path. + ra, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "A", + Accepted: acceptA, + }) + require.NoError(t, err) + require.Nil(t, ra, "A must cancel on entry, not run") + + rb, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + Prompt: "B", + Accepted: acceptB, + }) + require.NoError(t, err) + require.Nil(t, rb, "B must cancel on entry, not run") + + require.False(t, sa.hasPendingCancel(sess.ID), + "the mark clears once all covered handles are resolved") + require.Equal(t, 0, sa.acceptedCount(sess.ID)) + require.Equal(t, 0, sa.QueuedPrompts(sess.ID), + "neither A nor B may be enqueued to run normally") + + // C produced a normal turn; A and B each produced a canceled turn. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant") + + var normal, canceled int + for _, m := range msgs { + if m.Role != message.Assistant { + continue + } + switch m.FinishReason() { + case message.FinishReasonEndTurn: + normal++ + case message.FinishReasonCanceled: + canceled++ + } + } + require.Equal(t, 1, normal, "only C finished normally") + require.Equal(t, 2, canceled, "both A and B finished canceled") +} diff --git a/internal/agent/notify/notify.go b/internal/agent/notify/notify.go index ac7f724c0f07f552d9759247821a2555c9e12524..22e9f17769b5585302a195049bb3abca919f9a91 100644 --- a/internal/agent/notify/notify.go +++ b/internal/agent/notify/notify.go @@ -12,6 +12,9 @@ const ( // TypeReAuthenticate indicates the agent encountered an // authentication error and the user needs to re-authenticate. TypeReAuthenticate Type = "re_authenticate" + // TypeAgentError indicates the agent's turn terminated with an + // error. The error text is carried in Notification.Message. + TypeAgentError Type = "error" ) // Notification represents a domain event published by the agent. @@ -20,6 +23,15 @@ type Notification struct { SessionTitle string Type Type ProviderID string + // RunID, when non-empty, is the caller-supplied correlator + // (proto.AgentMessage.RunID) for the run that produced this + // notification. It lets observers attribute a TypeAgentError to a + // specific request rather than to any in-flight run on the + // session. Empty when no caller set one. + RunID string + // Message carries the error text for TypeAgentError. Other + // notification types ignore it. + Message string } // RunComplete is the authoritative end-of-run signal for a session. diff --git a/internal/agent/queued_runid_test.go b/internal/agent/queued_runid_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e3e99d4f648e12d5a98be052747841553b1fa8ae --- /dev/null +++ b/internal/agent/queued_runid_test.go @@ -0,0 +1,181 @@ +package agent + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +// gatedStreamModel streams a single text part followed by a clean finish, +// but blocks the very first Stream call until its gate is released. That +// lets a test hold a run "active" (past PrepareStep, inside Stream) just +// long enough to enqueue a follow-up prompt behind the busy session. +// Subsequent Stream calls (e.g. the recursive run draining the queue) +// proceed immediately. +type gatedStreamModel struct { + text string + gate chan struct{} + entered chan struct{} + calls atomic.Int64 +} + +func (m *gatedStreamModel) Provider() string { return "fake" } +func (m *gatedStreamModel) Model() string { return "fake-model" } + +func (m *gatedStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}}, + FinishReason: fantasy.FinishReasonStop, + }, nil +} + +func (m *gatedStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + if m.calls.Add(1) == 1 { + close(m.entered) + select { + case <-m.gate: + case <-ctx.Done(): + } + } + text := m.text + return func(yield func(fantasy.StreamPart) bool) { + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) { + return + } + if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) { + return + } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil +} + +func (m *gatedStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + return nil, errors.New("not implemented") +} + +func (m *gatedStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) { + return nil, errors.New("not implemented") +} + +// TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete is the +// end-to-end proof of fix 2: a prompt carrying a RunID that is queued +// behind a busy session must NOT be silently folded into the active turn. +// It runs as its own turn via the recursive run path and publishes its +// own terminal RunComplete, so a `crush run` caller blocking on that +// RunID does not hang. The active turn keeps its own RunComplete too. +func TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + large := &gatedStreamModel{ + text: "done", + gate: make(chan struct{}), + entered: make(chan struct{}), + } + small := &finishStreamModel{text: "title"} + + sa := NewSessionAgent(SessionAgentOptions{ + LargeModel: Model{Model: large, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}}, + SmallModel: Model{Model: small, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}}, + IsYolo: true, + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + + sess, err := env.sessions.Create(t.Context(), "session") + require.NoError(t, err) + + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + ch := broker.Subscribe(subCtx) + + // Start the main turn; it blocks inside Stream once active. + mainDone := make(chan error, 1) + go func() { + _, runErr := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-main", + Prompt: "main", + }) + mainDone <- runErr + }() + + // Wait until the main turn is active (inside Stream). + select { + case <-large.entered: + case <-time.After(5 * time.Second): + t.Fatal("main run never entered Stream") + } + require.True(t, sa.IsSessionBusy(sess.ID), "main run must be active before enqueueing the follow-up") + + // Enqueue a RunID-bearing follow-up behind the busy session. + res, err := sa.Run(t.Context(), SessionAgentCall{ + SessionID: sess.ID, + RunID: "run-follow", + Prompt: "follow", + }) + require.NoError(t, err) + require.Nil(t, res, "a busy-session follow-up must enqueue and return (nil, nil)") + require.Equal(t, 1, sa.QueuedPrompts(sess.ID), "the follow-up must be queued, not folded") + + // Release the main turn so it completes and hands off to the queue. + close(large.gate) + require.NoError(t, <-mainDone) + + // Both turns must publish their own terminal RunComplete. + got := map[string]notify.RunComplete{} + deadline := time.After(5 * time.Second) + for len(got) < 2 { + select { + case ev := <-ch: + got[ev.Payload.RunID] = ev.Payload + case <-deadline: + t.Fatalf("timed out waiting for both RunCompletes; got %v", got) + } + } + + main, ok := got["run-main"] + require.True(t, ok, "the active turn must publish its own RunComplete") + require.Empty(t, main.Error) + require.False(t, main.Cancelled) + + follow, ok := got["run-follow"] + require.True(t, ok, + "the queued RunID prompt must publish its own RunComplete instead of being folded silently") + require.Empty(t, follow.Error) + require.False(t, follow.Cancelled) + require.Equal(t, "done", follow.Text, "the queued prompt ran as its own turn") + + // Two distinct assistant turns prove the follow-up was not folded. + msgs, err := env.messages.List(t.Context(), sess.ID) + require.NoError(t, err) + var assistants, follows int + for _, m := range msgs { + switch m.Role { + case message.Assistant: + assistants++ + case message.User: + if m.Content().String() == "follow" { + follows++ + } + } + } + require.Equal(t, 2, assistants, "the active turn and the recursive turn each produce one assistant message") + require.Equal(t, 1, follows, "the follow-up prompt is its own user turn") +} diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go index 74f9232a0946b24d38f05873fa39066dcae40c27..2fb6fbefab436ce97b428e9025ef79142da2ea85 100644 --- a/internal/agent/run_complete_test.go +++ b/internal/agent/run_complete_test.go @@ -58,6 +58,138 @@ func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { "RunComplete still correlates with the originating SendMessage") } +// TestDrainQueueForStep_FiltersUnderDispatchLock verifies that the queue +// drain evaluates the per-session cancel mark while holding the dispatch +// mutex (canceledBySeq's documented precondition). Queued calls at or +// below the cancel high-water mark are dropped, calls queued after the +// cancel (higher seq) are folded, untracked enqueues (seq == 0) are +// dropped whenever any mark is present, and the queue is cleared. These +// calls carry no RunID, so all foldable survivors are returned for +// folding (the existing follow-up behavior). +func TestDrainQueueForStep_FiltersUnderDispatchLock(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-session" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "below", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2}, + {SessionID: sessionID, Prompt: "after", acceptSeq: 3}, + {SessionID: sessionID, Prompt: "untracked", acceptSeq: 0}, + }) + // Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered. + a.cancelMark.Set(sessionID, 2) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + + require.Len(t, fold, 1, + "only the follow-up queued after the cancel (seq > mark) must be folded") + require.Equal(t, "after", fold[0].Prompt) + require.Empty(t, canceledWithRunID, + "no dropped call carried a RunID, so none need a terminal RunComplete") + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "drain must clear the session message queue when nothing is kept") +} + +// TestDrainQueueForStep_NoMarkFoldsAllNonRunID verifies that with no +// cancel mark recorded, every queued call without a RunID is folded. +func TestDrainQueueForStep_NoMarkFoldsAllNonRunID(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-nomark" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "a", acceptSeq: 0}, + {SessionID: sessionID, Prompt: "b", acceptSeq: 5}, + }) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + require.Len(t, fold, 2, "no cancel mark means all non-RunID queued calls are folded") + require.Empty(t, canceledWithRunID) +} + +// TestDrainQueueForStep_KeepsRunIDPromptsQueued is the core of fix 2: a +// queued prompt that carries a RunID must NOT be folded into the active +// turn. Folding it would silently absorb it into another turn and never +// publish a RunComplete for its RunID, hanging a `crush run` caller that +// blocks on that event. Such prompts are left in the queue so the +// recursive run path gives each its own turn and its own RunComplete. +// Non-RunID prompts are still folded. +func TestDrainQueueForStep_KeepsRunIDPromptsQueued(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "fold-me", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-a", Prompt: "keep-me", acceptSeq: 2}, + {SessionID: sessionID, RunID: "run-b", Prompt: "keep-me-too", acceptSeq: 3}, + }) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + + require.Len(t, fold, 1, "only the non-RunID prompt is folded into the active turn") + require.Equal(t, "fold-me", fold[0].Prompt) + require.Empty(t, canceledWithRunID) + + kept, ok := a.messageQueue.Get(sessionID) + require.True(t, ok, "RunID-bearing prompts must remain queued for the recursive run path") + require.Len(t, kept, 2) + require.Equal(t, "run-a", kept[0].RunID) + require.Equal(t, "run-b", kept[1].RunID) +} + +// TestDrainQueueForStep_ReportsCanceledRunIDDrops verifies that a queued +// prompt carrying a RunID that is dropped because a cancel covers it is +// reported in canceledWithRunID so the caller can publish its terminal +// cancelled RunComplete. A canceled prompt without a RunID is dropped +// silently as before. +func TestDrainQueueForStep_ReportsCanceledRunIDDrops(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "drain-cancel-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, RunID: "run-canceled", Prompt: "canceled", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "canceled-no-runid", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-survives", Prompt: "survives", acceptSeq: 5}, + }) + a.cancelMark.Set(sessionID, 2) + + fold, canceledWithRunID := a.drainQueueForStep(sessionID) + + require.Empty(t, fold, "no uncanceled non-RunID prompts to fold") + require.Len(t, canceledWithRunID, 1, + "only the dropped RunID-bearing prompt needs a terminal RunComplete") + require.Equal(t, "run-canceled", canceledWithRunID[0].RunID) + + kept, ok := a.messageQueue.Get(sessionID) + require.True(t, ok) + require.Len(t, kept, 1, "the uncanceled RunID prompt stays queued") + require.Equal(t, "run-survives", kept[0].RunID) +} + // TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the // pubsub.Publisher interface change end-to-end: a Broker is the only // concrete Publisher implementation and must satisfy both Publish and @@ -87,3 +219,104 @@ func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) { t.Fatal("PublishMustDeliver did not deliver event") } } + +// requireSingleCancelledRunComplete reads exactly one RunComplete from ch, +// asserts it is the cancelled terminal event for runID, and verifies no +// second event arrives. This observes the published pubsub event rather +// than internal bookkeeping, which is the contract a `crush run` caller +// blocking on the broker actually relies on. +func requireSingleCancelledRunComplete(t *testing.T, ch <-chan pubsub.Event[notify.RunComplete], sessionID, runID string) { + t.Helper() + select { + case ev := <-ch: + require.Equal(t, runID, ev.Payload.RunID, + "the published RunComplete must carry the dropped queued prompt's RunID") + require.Equal(t, sessionID, ev.Payload.SessionID) + require.True(t, ev.Payload.Cancelled, + "a dropped queued prompt must publish a cancelled RunComplete") + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for the cancelled RunComplete") + } + select { + case extra := <-ch: + t.Fatalf("expected exactly one RunComplete, got a second: %+v", extra.Payload) + case <-time.After(100 * time.Millisecond): + } +} + +// TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete proves the +// terminal-event behavior end-to-end: a RunID-bearing prompt sitting in +// the queue that is canceled while queued (via the public Cancel path, +// which routes through clearQueueAndNotify -> publishCanceledQueueDrops) +// must emit exactly one cancelled RunComplete on the broker for its +// RunID. A queued prompt without a RunID is dropped silently. This is the +// coverage the earlier drain test lacked: it asserted the returned +// bookkeeping slice, not the published event a `crush run` caller awaits. +func TestCancel_QueuedRunIDPromptPublishesCancelledRunComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + ch := broker.Subscribe(subCtx) + + const sessionID = "cancel-queued-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, Prompt: "no-runid", acceptSeq: 1}, + {SessionID: sessionID, RunID: "run-queued", Prompt: "queued", acceptSeq: 2}, + }) + + a.Cancel(sessionID) + + requireSingleCancelledRunComplete(t, ch, sessionID, "run-queued") + + _, ok := a.messageQueue.Get(sessionID) + require.False(t, ok, "Cancel must clear the queue") +} + +// TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete drives +// the production drain sequence (drainQueueForStep then +// publishCanceledQueueDrops, mirroring the PrepareStep handoff) and +// asserts the dropped RunID-bearing prompt actually publishes exactly one +// cancelled RunComplete on the broker. The companion bookkeeping test +// covers the returned slice; this one covers the observable terminal +// event. +func TestDrainQueueForStep_DroppedRunIDPublishesCancelledRunComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + RunComplete: broker, + }).(*sessionAgent) + + subCtx, subCancel := context.WithCancel(t.Context()) + defer subCancel() + ch := broker.Subscribe(subCtx) + + const sessionID = "drain-drop-runid" + a.messageQueue.Set(sessionID, []SessionAgentCall{ + {SessionID: sessionID, RunID: "run-dropped", Prompt: "dropped", acceptSeq: 1}, + {SessionID: sessionID, Prompt: "dropped-no-runid", acceptSeq: 1}, + }) + a.cancelMark.Set(sessionID, 2) + + _, canceledWithRunID := a.drainQueueForStep(sessionID) + require.Len(t, canceledWithRunID, 1) + a.publishCanceledQueueDrops(canceledWithRunID) + + requireSingleCancelledRunComplete(t, ch, sessionID, "run-dropped") +} diff --git a/internal/agent/run_marker.go b/internal/agent/run_marker.go new file mode 100644 index 0000000000000000000000000000000000000000..404cca1e8c41bb9179deb886552f3580a977fdfc --- /dev/null +++ b/internal/agent/run_marker.go @@ -0,0 +1,52 @@ +package agent + +import ( + "context" + "sync/atomic" +) + +// runCompleteMarkerKey is the unexported context key carrying a +// [runCompleteMarker] from the dispatch boundary (backend.runAgent) +// down into the coordinator. It lets the dispatcher learn whether the +// coordinator already published the authoritative terminal +// notify.RunComplete for the run, so a fallback terminal event is only +// emitted when one is actually missing (e.g. an error returned before +// sessionAgent.Run ever executed). It avoids a breaking change to the +// Coordinator interface. +type runCompleteMarkerKey struct{} + +// runCompleteMarker records whether a terminal RunComplete has been +// published for a run. It is shared by pointer through the context so +// a publish deep in the call stack is observable by the dispatcher +// after the call returns. +type runCompleteMarker struct { + published atomic.Bool +} + +// WithRunCompleteMarker returns ctx carrying a fresh marker the +// coordinator can flag via [MarkRunCompletePublished] once it emits the +// run's terminal RunComplete. Callers read the result with +// [RunCompletePublished]. Attaching the marker is optional: code paths +// without one simply skip the dedup signal. +func WithRunCompleteMarker(ctx context.Context) context.Context { + return context.WithValue(ctx, runCompleteMarkerKey{}, &runCompleteMarker{}) +} + +// MarkRunCompletePublished records that the authoritative terminal +// RunComplete has been published for the run carried by ctx. It is a +// no-op when no marker is present (e.g. the in-process/local Run path, +// which is not dispatched through backend.runAgent). +func MarkRunCompletePublished(ctx context.Context) { + if m, ok := ctx.Value(runCompleteMarkerKey{}).(*runCompleteMarker); ok { + m.published.Store(true) + } +} + +// RunCompletePublished reports whether [MarkRunCompletePublished] was +// called on ctx's marker. It returns false when no marker is present. +func RunCompletePublished(ctx context.Context) bool { + if m, ok := ctx.Value(runCompleteMarkerKey{}).(*runCompleteMarker); ok { + return m.published.Load() + } + return false +} diff --git a/internal/app/app.go b/internal/app/app.go index 9509fa3a9dc778d507d38f60c8ca523031b7ecb7..d8a3abc63b9901c528ec3cee7465eaa0b892349f 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -185,6 +185,14 @@ func (app *App) AgentNotifications() *pubsub.Broker[notify.Notification] { return app.agentNotifications } +// RunCompletions returns the broker for the authoritative per-run +// terminal RunComplete events. The dispatcher (backend.runAgent) uses +// it to emit a reliable terminal event when a run fails before the +// coordinator could publish one of its own. +func (app *App) RunCompletions() *pubsub.Broker[notify.RunComplete] { + return app.runCompletions +} + // resolveSession resolves which session to use for a non-interactive run // If continueSessionID is set, it looks up that session by ID // If useLast is set, it returns the most recently updated top-level session diff --git a/internal/backend/accepted_run_integration_test.go b/internal/backend/accepted_run_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a0a7ba249cc547956dd479cabcf5545a07a5c26 --- /dev/null +++ b/internal/backend/accepted_run_integration_test.go @@ -0,0 +1,131 @@ +package backend + +import ( + "context" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/agenttest" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/session" + "github.com/stretchr/testify/require" +) + +// gatedCoordinator wraps a real agent.Coordinator and parks RunAccepted +// before delegating to it. Every method other than RunAccepted is +// inherited from the embedded coordinator, so BeginAccepted (called by +// Backend.SendMessage) and RunAccepted (called by the dispatched run) +// are the production agent.Coordinator implementations under test, not +// stubs. The gate only delays entry into the real RunAccepted so a +// cancel can be made to land in the accepted-but-not-yet-active window +// deterministically: the accept handle is not consumed by +// sessionAgent.Run until the real RunAccepted runs after the gate opens. +type gatedCoordinator struct { + agent.Coordinator + entered chan struct{} + gate chan struct{} +} + +func (c *gatedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + close(c.entered) + <-c.gate + return c.Coordinator.RunAccepted(ctx, accept, sessionID, prompt, attachments...) +} + +// newRealCoordinator builds a production agent.Coordinator over a +// DB-backed session/message store, wrapped in a gate. It is constructed +// through the real agent.NewCoordinator path (via the test-only +// agenttest helper) with an offline-resolvable model: the +// cancel-on-entry path under test persists a canceled turn and returns +// before any model call, so no network I/O happens. +func newRealCoordinator(t *testing.T) (*gatedCoordinator, session.Service, message.Service) { + t.Helper() + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + q := db.New(conn) + sessions := session.NewService(q, conn) + messages := message.NewService(q) + + coord, err := agenttest.NewCoordinator(t.Context(), t.TempDir(), sessions, messages) + require.NoError(t, err) + + return &gatedCoordinator{ + Coordinator: coord, + entered: make(chan struct{}), + gate: make(chan struct{}), + }, sessions, messages +} + +// TestSendMessage_AcceptedCancelRace_RealMachinery exercises the +// 202/cancel race end-to-end through Backend.SendMessage against the +// production agent.Coordinator (BeginAccepted + RunAccepted), not a +// stub. It asserts that a cancel arriving after the prompt is accepted +// but before the run becomes active is not lost: the accepted handle +// reaches sessionAgent.Run and drives cancel-on-entry, which persists a +// canceled turn instead of streaming. +// +// This test would fail if Coordinator.BeginAccepted returned nil (Cancel +// would find no accepted run and record no mark, and the run would +// receive a nil Accepted handle and skip cancel-on-entry) or if +// Coordinator.RunAccepted dropped the handle on its way into +// sessionAgent.Run (the run would likewise skip cancel-on-entry and try +// to stream the model). In either case no FinishReasonCanceled turn +// would be persisted. +func TestSendMessage_AcceptedCancelRace_RealMachinery(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + + coord, sessions, messages := newRealCoordinator(t) + sess, err := sessions.Create(t.Context(), "session") + require.NoError(t, err) + + ws := insertAgentWorkspace(t, b, coord) + + require.NoError(t, b.SendMessage(ws.ID, proto.AgentMessage{SessionID: sess.ID, Prompt: "hi"})) + + // Coordinator.BeginAccepted ran synchronously inside SendMessage + // before dispatch; the dispatched run has now entered the gate but + // has not yet called the real RunAccepted, so the accept handle is + // not yet consumed: the prompt is accepted but not active. + select { + case <-coord.entered: + case <-time.After(2 * time.Second): + t.Fatal("dispatched run never entered RunAccepted") + } + + // A cancel arriving now lands in the accepted-but-not-yet-active + // window and is only recorded because BeginAccepted incremented the + // accept counter. + require.NoError(t, b.CancelSession(ws.ID, sess.ID)) + + // Release the gate so the real RunAccepted threads the handle into + // sessionAgent.Run, which drives cancel-on-entry. + close(coord.gate) + + // The dispatched run returns nil (cancel-on-entry), so runWG drains. + waited := make(chan struct{}) + go func() { + ws.runWG.Wait() + close(waited) + }() + select { + case <-waited: + case <-time.After(5 * time.Second): + t.Fatal("runWG.Wait did not complete after the canceled run returned") + } + + // The accepted-but-not-yet-active cancel persisted a canceled turn + // rather than streaming a real response. + msgs, err := messages.List(t.Context(), sess.ID) + require.NoError(t, err) + require.Len(t, msgs, 2) + require.Equal(t, message.User, msgs[0].Role) + require.Equal(t, message.Assistant, msgs[1].Role) + require.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason()) +} diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 78447ab7c64a82bb2638fb3fe184d0be132b4589..3d08746ed35c07ab21221c8a6aa0df3941944fd9 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -2,22 +2,30 @@ package backend import ( "context" + "errors" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" ) -// SendMessage sends a prompt to the agent coordinator for the given -// workspace and session. +// SendMessage validates and accepts a prompt for the workspace's agent, +// then dispatches the run on a goroutine bound to the workspace context +// and returns immediately. It does not wait for the LLM turn to +// complete: the run's lifetime is owned by the workspace, not by the +// caller. Errors from the dispatched run reach observers through the +// agent event channels (a notify.TypeAgentError notification), not +// through this return value. // -// When msg.RunID is non-empty it is attached to the context via -// agent.WithRunID so the coordinator can stamp the resulting -// SessionAgentCall (and therefore the terminal notify.RunComplete -// event) with that correlator. This is the only way for the -// originating client to distinguish its own turn's RunComplete from -// any concurrent turn that finishes on the same session. -func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error { +// SendMessage returns synchronously when the request cannot be accepted: +// ErrWorkspaceNotFound if the workspace is missing, ErrAgentNotInitialized +// if its coordinator is nil, the structural validation errors from +// agent.ValidateCall (ErrEmptyPrompt, ErrSessionMissing) when the prompt +// or session is missing, and ErrWorkspaceClosing if the workspace is +// being torn down. +func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error { ws, err := b.GetWorkspace(workspaceID) if err != nil { return err @@ -27,11 +35,88 @@ func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto return ErrAgentNotInitialized } + if err := agent.ValidateCall(agent.SessionAgentCall{ + SessionID: msg.SessionID, + Prompt: msg.Prompt, + Attachments: proto.AttachmentsToMessage(msg.Attachments), + }); err != nil { + return err + } + + accept := ws.AgentCoordinator.BeginAccepted(msg.SessionID) + + ws.runMu.Lock() + if ws.closing { + ws.runMu.Unlock() + accept.Close() + return ErrWorkspaceClosing + } + ws.runWG.Add(1) + ws.runMu.Unlock() + + go b.runAgent(ws, msg, accept) + return nil +} + +// runAgent executes an accepted agent run for the workspace. It owns the +// accept reservation (releasing it on return) and the runWG ticket added +// by SendMessage. The run is bound to the workspace context so its +// lifetime is independent of any client's HTTP request. +// +// On a non-cancel error it surfaces the failure to observers via a +// notify.TypeAgentError notification (lossy, best-effort). That alone is +// not a reliable terminal signal: the agent-event fan-in uses lossy +// subscribers, so a `crush run` caller blocking on its RunID could hang +// if the event is dropped. To guarantee termination, when msg.RunID is +// non-empty and the coordinator did not already publish the run's +// authoritative terminal RunComplete (e.g. the error was returned before +// sessionAgent.Run executed, such as a readyWg or UpdateModels failure), +// runAgent emits an errored RunComplete on the must-deliver +// runCompletions broker so the waiter observes a deterministic terminal +// event. context.Canceled is expected (sessionAgent.Run already +// publishes the cancelled terminal marker) and produces no error +// terminal event. +// +// When msg.RunID is non-empty it is attached to the context via +// agent.WithRunID so the coordinator can stamp the terminal +// notify.RunComplete event with that correlator. A run-complete marker +// is also attached so the coordinator can report whether it published +// the terminal event, letting runAgent avoid a duplicate fallback. +func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) { + defer ws.runWG.Done() + defer accept.Close() + + ctx := ws.ctx if msg.RunID != "" { ctx = agent.WithRunID(ctx, msg.RunID) } - _, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) - return err + ctx = agent.WithRunCompleteMarker(ctx) + + _, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) + if err == nil || errors.Is(err, context.Canceled) { + return + } + + ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{ + SessionID: msg.SessionID, + RunID: msg.RunID, + Type: notify.TypeAgentError, + Message: err.Error(), + }) + + // Reliable terminal fallback. Only needed when a RunID waiter + // exists and the coordinator has not already emitted the run's + // terminal RunComplete; otherwise this would be a duplicate. + if msg.RunID == "" || agent.RunCompletePublished(ctx) { + return + } + if rc := ws.RunCompletions(); rc != nil { + rc.PublishMustDeliver(ctx, pubsub.UpdatedEvent, notify.RunComplete{ + SessionID: msg.SessionID, + RunID: msg.RunID, + Error: err.Error(), + }) + } } // GetAgentInfo returns the agent's model and busy status. diff --git a/internal/backend/agent_runcomplete_test.go b/internal/backend/agent_runcomplete_test.go new file mode 100644 index 0000000000000000000000000000000000000000..be3df103e66b539685e42269d7ced0c7e7e94d86 --- /dev/null +++ b/internal/backend/agent_runcomplete_test.go @@ -0,0 +1,162 @@ +package backend + +import ( + "context" + "errors" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// errorCoordinator is a minimal agent.Coordinator whose RunAccepted +// returns a configurable error. When markPublished is true it stamps +// the run-complete marker on the context before returning, simulating a +// real coordinator that already published the run's authoritative +// terminal RunComplete (so runAgent must not emit a duplicate fallback). +type errorCoordinator struct { + err error + markPublished bool +} + +func (c *errorCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, c.err +} + +func (c *errorCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + if c.markPublished { + agent.MarkRunCompletePublished(ctx) + } + return nil, c.err +} + +func (c *errorCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil } +func (c *errorCoordinator) Cancel(string) {} +func (c *errorCoordinator) CancelAll() {} +func (c *errorCoordinator) IsBusy() bool { return false } +func (c *errorCoordinator) IsSessionBusy(string) bool { return false } +func (c *errorCoordinator) QueuedPrompts(string) int { return 0 } +func (c *errorCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *errorCoordinator) ClearQueue(string) {} +func (c *errorCoordinator) Summarize(context.Context, string) error { return nil } +func (c *errorCoordinator) Model() agent.Model { return agent.Model{} } +func (c *errorCoordinator) UpdateModels(context.Context) error { return nil } + +// insertRunCompleteWorkspace installs a workspace backed by a real +// app.App (so the runCompletions broker exists) with the given +// coordinator and a workspace run context derived from base. +func insertRunCompleteWorkspace(t *testing.T, b *Backend, base context.Context, coord agent.Coordinator) *Workspace { + t.Helper() + a := app.NewForTest(base) + a.AgentCoordinator = coord + t.Cleanup(a.ShutdownForTest) + ws := &Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + resolvedPath: t.TempDir(), + clients: make(map[string]*clientState), + shutdownFn: func() {}, + } + ws.App = a + ws.ctx, ws.cancel = context.WithCancel(base) + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[ws.resolvedPath] = ws.ID + b.mu.Unlock() + return ws +} + +// TestRunAgent_PreRunErrorPublishesTerminalRunComplete proves that an +// error returned from RunAccepted before the coordinator could publish +// its own terminal event (e.g. a readyWg or UpdateModels failure, +// modeled here by a stub coordinator) still yields a reliable terminal +// RunComplete for the run's RunID. Without it, a `crush run` caller +// blocking on that RunID would hang because the lossy TypeAgentError +// event is not a guaranteed terminal signal. +func TestRunAgent_PreRunErrorPublishesTerminalRunComplete(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + runErr := errors.New("update models failed") + ws := insertRunCompleteWorkspace(t, b, context.Background(), &errorCoordinator{err: runErr}) + + subCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := ws.RunCompletions().Subscribe(subCtx) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"}) + require.NoError(t, err) + + select { + case ev := <-ch: + require.Equal(t, "run-1", ev.Payload.RunID, + "the terminal RunComplete must carry the dispatched RunID") + require.Equal(t, "S1", ev.Payload.SessionID) + require.Equal(t, runErr.Error(), ev.Payload.Error, + "the fallback terminal event must be marked errored") + require.False(t, ev.Payload.Cancelled) + case <-time.After(2 * time.Second): + t.Fatal("no terminal RunComplete published for a pre-run error; a run waiter would hang") + } +} + +// TestRunAgent_NoFallbackWhenCoordinatorPublished ensures the fallback +// is suppressed when the coordinator already emitted the run's +// authoritative terminal RunComplete, so callers never observe a +// duplicate terminal event for the same RunID. +func TestRunAgent_NoFallbackWhenCoordinatorPublished(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + runErr := errors.New("stream failed after publishing terminal event") + ws := insertRunCompleteWorkspace(t, b, context.Background(), + &errorCoordinator{err: runErr, markPublished: true}) + + subCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := ws.RunCompletions().Subscribe(subCtx) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"}) + require.NoError(t, err) + + // Wait for the dispatched run goroutine to return so any publish + // has already happened. + ws.runWG.Wait() + + select { + case ev := <-ch: + t.Fatalf("runAgent published a duplicate terminal RunComplete: %+v", ev.Payload) + case <-time.After(200 * time.Millisecond): + } +} + +// TestRunAgent_CancellationPublishesNoErrorTerminal verifies that a +// context.Canceled result from RunAccepted produces no errored terminal +// RunComplete from runAgent: cancellation is sessionAgent.Run's +// responsibility (it publishes the cancelled marker) and the dispatcher +// must not synthesize an error terminal for it. +func TestRunAgent_CancellationPublishesNoErrorTerminal(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertRunCompleteWorkspace(t, b, context.Background(), + &errorCoordinator{err: context.Canceled}) + + subCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := ws.RunCompletions().Subscribe(subCtx) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", RunID: "run-1", Prompt: "hi"}) + require.NoError(t, err) + + ws.runWG.Wait() + + select { + case ev := <-ch: + t.Fatalf("cancellation must not publish a terminal RunComplete: %+v", ev.Payload) + case <-time.After(200 * time.Millisecond): + } +} diff --git a/internal/backend/agent_test.go b/internal/backend/agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d9365ecc899236b91d73198ab322bcabbf9cc77 --- /dev/null +++ b/internal/backend/agent_test.go @@ -0,0 +1,163 @@ +package backend + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted +// blocks until release is closed. It records that RunAccepted was +// entered so tests can observe the dispatched goroutine. Every other +// method returns a zero value. +type blockingCoordinator struct { + entered chan struct{} + release chan struct{} + runCount atomic.Int32 +} + +func newBlockingCoordinator() *blockingCoordinator { + return &blockingCoordinator{ + entered: make(chan struct{}, 1), + release: make(chan struct{}), + } +} + +func (c *blockingCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + c.runCount.Add(1) + select { + case c.entered <- struct{}{}: + default: + } + <-c.release + return nil, nil +} + +func (c *blockingCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil } +func (c *blockingCoordinator) Cancel(string) {} +func (c *blockingCoordinator) CancelAll() {} +func (c *blockingCoordinator) IsBusy() bool { return false } +func (c *blockingCoordinator) IsSessionBusy(string) bool { return false } +func (c *blockingCoordinator) QueuedPrompts(string) int { return 0 } +func (c *blockingCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *blockingCoordinator) ClearQueue(string) {} +func (c *blockingCoordinator) Summarize(context.Context, string) error { return nil } +func (c *blockingCoordinator) Model() agent.Model { return agent.Model{} } +func (c *blockingCoordinator) UpdateModels(context.Context) error { return nil } + +// insertAgentWorkspace installs a synthetic workspace with the given +// coordinator (or none) and a workspace run context, mirroring the +// fields CreateWorkspace initializes. +func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace { + t.Helper() + ws := &Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + resolvedPath: t.TempDir(), + clients: make(map[string]*clientState), + shutdownFn: func() {}, + } + ws.App = &app.App{AgentCoordinator: coord} + ws.ctx, ws.cancel = context.WithCancel(b.ctx) + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[ws.resolvedPath] = ws.ID + b.mu.Unlock() + return ws +} + +func TestSendMessage_WorkspaceNotFound(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + err := b.SendMessage("nope", proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} + +func TestSendMessage_AgentNotInitialized(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, nil) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrAgentNotInitialized) +} + +func TestSendMessage_EmptyPrompt(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: ""}) + require.ErrorIs(t, err, agent.ErrEmptyPrompt) +} + +func TestSendMessage_SessionMissing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "", Prompt: "hi"}) + require.ErrorIs(t, err, agent.ErrSessionMissing) +} + +func TestSendMessage_WorkspaceClosing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + ws.runMu.Lock() + ws.closing = true + ws.runMu.Unlock() + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceClosing) +} + +// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns +// nil synchronously and dispatches a tracked goroutine: while +// RunAccepted blocks, runWG.Wait must not complete (the ticket is +// outstanding); after release it drains. +func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + coord := newBlockingCoordinator() + ws := insertAgentWorkspace(t, b, coord) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.NoError(t, err) + + select { + case <-coord.entered: + case <-time.After(2 * time.Second): + t.Fatal("dispatched goroutine never entered RunAccepted") + } + require.Equal(t, int32(1), coord.runCount.Load()) + + waited := make(chan struct{}) + go func() { + ws.runWG.Wait() + close(waited) + }() + + select { + case <-waited: + t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added") + case <-time.After(100 * time.Millisecond): + } + + close(coord.release) + + select { + case <-waited: + case <-time.After(2 * time.Second): + t.Fatal("runWG.Wait did not complete after the run returned") + } +} diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 9c0f47f7ab8dfd48f90a31004d637dd6e54fd912..2ea24a86c71fbc7c4f58c5de680aeb83d421584b 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -34,6 +34,7 @@ var ( ErrUnknownCommand = errors.New("unknown command") ErrInvalidClientID = errors.New("invalid client_id") ErrClientNotAttached = errors.New("client not attached") + ErrWorkspaceClosing = errors.New("workspace closing") ) // DefaultCreateGrace is the window in which a client must open an SSE @@ -108,6 +109,23 @@ type Workspace struct { // with fallback to the cleaned absolute path. resolvedPath string + // ctx is the workspace-scoped run context. It is derived from + // the backend context in CreateWorkspace and lives for the + // lifetime of the workspace; cancel tears it down. Agent runs + // dispatched on behalf of this workspace are bound to ctx so + // their lifetime is owned by the workspace, not by any single + // client's HTTP request. + ctx context.Context + cancel context.CancelFunc + + // runMu guards closing and gates dispatch of new agent runs. + // closing is set by Shutdown so no new runs are accepted once + // teardown has begun. runWG tracks dispatched agent goroutines + // so Shutdown can wait for them to return before app cleanup. + runMu sync.Mutex + closing bool + runWG sync.WaitGroup + // clientsMu guards clients. It is held only briefly (no IO). clientsMu sync.Mutex // clients tracks each client's claim on this workspace. Refcount @@ -122,7 +140,7 @@ type Workspace struct { } // invokeShutdown calls the workspace shutdown hook if set, falling -// back to the embedded [app.App.Shutdown] when not. +// back to the workspace [Workspace.Shutdown] wrapper when not. func (w *Workspace) invokeShutdown() { if w.shutdownFn != nil { w.shutdownFn() @@ -133,6 +151,40 @@ func (w *Workspace) invokeShutdown() { } } +// Shutdown tears the workspace down in an order that is safe for +// agent runs whose lifetime is bound to the workspace context. It +// shadows the promoted [app.App.Shutdown] so callers reaching +// ws.Shutdown() always observe this ordering: +// +// 1. Mark the workspace closing so no new agent runs are accepted. +// 2. Cancel the workspace run context so any dispatched goroutine +// that has not yet registered its per-session cancel still +// observes cancellation. +// 3. Cancel active coordinator work for runs that already +// registered their per-session cancel function. +// 4. Wait for dispatched agent goroutines to return. +// 5. Run the embedded [app.App.Shutdown] cleanup (DB, LSP, etc). +// +// CancelAll is idempotent, so the second call inside app.App.Shutdown +// is harmless; the important guarantee is that cancel -> CancelAll -> +// runWG.Wait completes before the embedded cleanup touches the DB. +func (w *Workspace) Shutdown() { + w.runMu.Lock() + w.closing = true + w.runMu.Unlock() + + if w.cancel != nil { + w.cancel() + } + if w.App != nil && w.AgentCoordinator != nil { + w.AgentCoordinator.CancelAll() + } + w.runWG.Wait() + if w.App != nil { + w.App.Shutdown() + } +} + // New creates a new [Backend]. func New(ctx context.Context, cfg *config.ConfigStore, shutdownFn ShutdownFunc) *Backend { return &Backend{ @@ -247,6 +299,7 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works return nil, proto.Workspace{}, fmt.Errorf("failed to create app workspace: %w", err) } + wsCtx, wsCancel := context.WithCancel(b.ctx) ws := &Workspace{ App: appWorkspace, ID: id, @@ -255,6 +308,8 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works Env: args.Env, Skills: skillsMgr, resolvedPath: key, + ctx: wsCtx, + cancel: wsCancel, clients: make(map[string]*clientState), } diff --git a/internal/backend/testing.go b/internal/backend/testing.go index 6616e0f19e06595fac68808b484394d960d7f79f..1c71caed6566747ac947d61c30ece575d3d13eb4 100644 --- a/internal/backend/testing.go +++ b/internal/backend/testing.go @@ -1,9 +1,16 @@ package backend +import "context" + // InsertWorkspaceForTest registers ws with b under its current ID and // path. It is intended for tests in other packages that need to drive // HTTP handlers against a synthetic workspace without booting a real // app.App. Production code should go through CreateWorkspace. +// +// If the workspace has no run context yet it is derived from the +// backend context (falling back to context.Background), mirroring the +// initialization CreateWorkspace performs, so dispatched agent runs +// have a non-nil ws.ctx. func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.resolvedPath == "" { ws.resolvedPath = ws.Path @@ -11,6 +18,13 @@ func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.clients == nil { ws.clients = make(map[string]*clientState) } + if ws.ctx == nil { + parent := b.ctx + if parent == nil { + parent = context.Background() + } + ws.ctx, ws.cancel = context.WithCancel(parent) + } b.mu.Lock() defer b.mu.Unlock() b.workspaces.Set(ws.ID, ws) diff --git a/internal/client/proto.go b/internal/client/proto.go index 62a43b5884e01ae8fcd3242c68e95d1f76251c42..d07e46dc84bf09dccffbd609784f92c7ae9a9c67 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -423,12 +423,28 @@ func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, p return fmt.Errorf("failed to send message to agent: %w", err) } defer rsp.Body.Close() - if rsp.StatusCode != http.StatusOK { + if rsp.StatusCode != http.StatusOK && rsp.StatusCode != http.StatusAccepted { + if msg := decodeErrorMessage(rsp.Body); msg != "" { + return fmt.Errorf("failed to send message to agent: status code %d: %s", rsp.StatusCode, msg) + } return fmt.Errorf("failed to send message to agent: status code %d", rsp.StatusCode) } return nil } +// decodeErrorMessage attempts to decode the response body as a +// proto.Error and returns its message. It returns an empty string +// when the body is empty or cannot be decoded into a proto.Error +// with a non-empty message, letting callers fall back to a +// status-only error. +func decodeErrorMessage(body io.Reader) string { + var e proto.Error + if err := json.NewDecoder(body).Decode(&e); err != nil { + return "" + } + return e.Message +} + // GetAgentSessionInfo retrieves the agent session info for a workspace. func (c *Client) GetAgentSessionInfo(ctx context.Context, id string, sessionID string) (*proto.AgentSession, error) { rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/agent/sessions/%s", id, sessionID), nil, nil) diff --git a/internal/client/proto_test.go b/internal/client/proto_test.go index b5739ccc91c16b2bb0fc3c3f6dc2281687bd8e65..c7abd3e03d4ae6f575079c7c938369d6cb7cc30b 100644 --- a/internal/client/proto_test.go +++ b/internal/client/proto_test.go @@ -88,6 +88,76 @@ func TestSubscribeEventsContextCancelClosesEvents(t *testing.T) { } } +func TestSendMessageAcceptsStatusAccepted(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageAcceptsStatusOK(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + c := captureClient(t, srv) + require.NoError(t, c.SendMessage(context.Background(), "ws1", "sess1", "", "hello")) +} + +func TestSendMessageDecodesErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(proto.Error{Message: "session id is required"}) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 400") + require.Contains(t, err.Error(), "session id is required") +} + +func TestSendMessageFallsBackOnMalformedErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") + require.NotContains(t, err.Error(), "not json") +} + +func TestSendMessageFallsBackOnEmptyErrorBody(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + c := captureClient(t, srv) + err := c.SendMessage(context.Background(), "ws1", "sess1", "", "hello") + require.Error(t, err) + require.Contains(t, err.Error(), "status code 500") +} + func marshalSSEPayload(t *testing.T) []byte { t.Helper() diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 87fa32606674847741a9d028a26375fb98935fc4..2feeba78e6f4862e453fcb790e428b3e08ab0505 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -409,11 +409,27 @@ func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) { return true, nil case pubsub.Event[proto.AgentEvent]: - if e.Payload.Error != nil { - stop() - return true, fmt.Errorf("agent error: %w", e.Payload.Error) + if e.Payload.Error == nil { + return false, nil } - return false, nil + // Attribute the error to our run before treating it as + // fatal. Async errors from an unrelated workspace run share + // this channel, so a foreign failure must not abort us: + // - if the event carries a RunID, it is the authoritative + // correlator: it must match our run exactly, otherwise it + // belongs to a different request and we ignore it. + // - if the event carries no RunID (older server), fall back + // to SessionID: it must be present and match our session, + // otherwise we ignore it. + if e.Payload.RunID != "" { + if e.Payload.RunID != s.runID { + return false, nil + } + } else if e.Payload.SessionID == "" || e.Payload.SessionID != s.sessionID { + return false, nil + } + stop() + return true, fmt.Errorf("agent error: %w", e.Payload.Error) } return false, nil } diff --git a/internal/cmd/run_stream_test.go b/internal/cmd/run_stream_test.go index ac168fa77045aa6aa5761b6f9c657f066c952734..028eb03baa0dc7a55a0037e67f033b708ff9634e 100644 --- a/internal/cmd/run_stream_test.go +++ b/internal/cmd/run_stream_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "errors" "testing" "time" @@ -307,6 +308,93 @@ func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) require.Equal(t, "streamed prefix final", buf.String()) } +// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async +// agent error carrying a non-empty RunID is fatal only when it matches +// our run. A foreign RunID is ignored regardless of the event's +// SessionID, because RunID is the authoritative correlator and async +// errors share the agent event channel: without strict RunID matching +// an unrelated workspace failure would abort our run. +func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) { + t.Parallel() + + // Foreign RunID with a matching session is still foreign. + s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a different session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Foreign RunID with a missing session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + RunID: "run-other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "foreign RunID error must not abort our run") + require.False(t, done) + + // Matching RunID is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + RunID: "run-mine", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "matching RunID error must be fatal") + require.True(t, done) +} + +// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the +// compatibility fallback: when the event carries no RunID, attribution +// falls back to SessionID. An error for another session or with an +// empty session is ignored, while an error for our own session is fatal +// so a real failure is never dropped. +func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + + // Empty RunID for another session is ignored. + done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "other", + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error for another session must not abort our run") + require.False(t, done) + + // Empty RunID with an empty session is ignored. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + Error: errors.New("foreign boom"), + }}, nil) + require.NoError(t, err, "error with no session must not abort our run") + require.False(t, done) + + // Empty RunID with a matching session is fatal. + done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{ + Type: proto.AgentEventTypeError, + SessionID: "S", + Error: errors.New("my boom"), + }}, nil) + require.Error(t, err, "error for our own session must be fatal") + require.True(t, done) +} + // TestRunStream_NoRunIDFallsBackToSessionID preserves the older // behaviour for callers (and tests) that don't supply a RunID: // SessionID-only matching still terminates the stream on the diff --git a/internal/proto/agent.go b/internal/proto/agent.go index e5266e52614a5bc43065ff62cf18b16f8ee7401f..2c85923e547b6755357479218f9ff4815e491527 100644 --- a/internal/proto/agent.go +++ b/internal/proto/agent.go @@ -31,6 +31,13 @@ type AgentEvent struct { Message Message `json:"message"` Error error `json:"error,omitempty"` + // RunID echoes the caller-supplied AgentMessage.RunID for the run + // that produced this event. It lets observers (notably + // `crush run`) attribute an error event to a specific request + // instead of to any in-flight run on the session. Empty when no + // caller set one. + RunID string `json:"run_id,omitempty"` + // When summarizing. SessionID string `json:"session_id,omitempty"` SessionTitle string `json:"session_title,omitempty"` diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 1bbb05e511a26f02cc1dad0e1da77454af0f8905..68d1f10132db3fb9f6b1e4251b744c59887f613e 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -60,6 +60,13 @@ func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, atta return nil, s.returnFn(ctx) } +func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return s.Run(ctx, sessionID, prompt, attachments...) +} + +func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *runCoordinator) Cancel(string) {} func (s *runCoordinator) CancelAll() {} func (s *runCoordinator) IsBusy() bool { return false } @@ -115,11 +122,12 @@ func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, session } // TestPostAgent_ReturnsOKOnContextCanceled verifies that when another -// client cancels the session mid-turn, the prompting client's still -// open POST receives 200 (not 500). The agent surfaces the -// FinishReasonCanceled marker to every SSE subscriber via the -// assistant message; the HTTP response from the prompter should not -// double as an error signal. +// client cancels the session mid-turn, the prompting client's POST is +// unaffected: SendMessage is fire-and-forget, so the handler returns +// 200 immediately without waiting for the turn. A run that later +// returns context.Canceled never surfaces as a 500 to the prompter; +// the FinishReasonCanceled marker reaches SSE subscribers via the +// assistant message instead. func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { t.Parallel() @@ -128,33 +136,56 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { }) c, wsID := buildAgentWorkspace(t, coord) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, t.Context(), wsID, "S1") - }() + // The handler returns immediately, before the dispatched run is + // released, because the run no longer owns the HTTP response. + rec := postAgent(t, c, t.Context(), wsID, "S1") + require.Equal(t, http.StatusAccepted, rec.Code, "fire-and-forget SendMessage must return 202 without waiting for the run") - // Wait until Run is in flight, then release it to return - // context.Canceled. + // The run is dispatched on a goroutine; let it return + // context.Canceled. Nothing from that path reaches the (already + // returned) handler. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } close(coord.release) - select { - case rec := <-done: - require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500") - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after coordinator returned context.Canceled") - } + // Wait for the dispatched run to fully return. Backend.runAgent + // swallows context.Canceled, so it must not publish a + // notify.TypeAgentError. Publishing would dereference the synthetic + // workspace's nil notification broker and crash this goroutine, + // which is the explicit guard that a cancel produces no top-level + // error event. + require.Eventually(t, func() bool { + return coord.ranCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) } -// TestPostAgent_DetachesRequestContext verifies that canceling the -// prompting client's HTTP request context does not cancel the -// in-flight agent run. The coordinator must observe a context whose -// Done channel never fires from the request side; only the explicit -// cancel endpoint may end the run. +// TestHandleError_ContextCanceledFallsThroughTo500 documents the step 8 +// cleanup: the old context.Canceled special case in handleError was +// removed because runtime cancellation of an agent run can no longer +// reach handleError. The agent-prompt handler returns 202 before the run +// starts (fire-and-forget SendMessage) and Backend.runAgent swallows +// context.Canceled. Any context.Canceled that still reaches handleError +// is therefore an unexpected synchronous error and falls through to the +// default 500 like any other. +func TestHandleError_ContextCanceledFallsThroughTo500(t *testing.T) { + t.Parallel() + + c := &controllerV1{server: &Server{}} + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", nil) + + c.handleError(rec, req, context.Canceled) + + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// TestPostAgent_DetachesRequestContext verifies that the dispatched run +// is bound to the workspace context, not the prompting client's HTTP +// request context. Canceling the request context must neither cancel +// the run nor be observed by the coordinator. func TestPostAgent_DetachesRequestContext(t *testing.T) { t.Parallel() @@ -165,46 +196,31 @@ func TestPostAgent_DetachesRequestContext(t *testing.T) { reqCtx, cancelReq := context.WithCancel(context.Background()) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, reqCtx, wsID, "S1") - }() + // The handler returns immediately; the run keeps executing on its + // own goroutine bound to the workspace context. + rec := postAgent(t, c, reqCtx, wsID, "S1") + require.Equal(t, http.StatusAccepted, rec.Code) - // Wait until Run is in flight, then drop the prompting client. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } + + // Drop the prompting client. This must not reach the run. cancelReq() - // The captured ctx must be detached: context.WithoutCancel - // returns a ctx with Done() == nil so request cancellation cannot - // propagate. got := coord.capturedCtx() require.NotNil(t, got) - require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel") - require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request") - - // Confirm Run is still running: it should not have completed - // just because the request ctx was canceled. - select { - case <-done: - t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run") - case <-time.After(50 * time.Millisecond): - } + // Compare by identity (pointer), not reflect.DeepEqual: deep + // comparison would traverse context internals that the runtime + // mutates concurrently. + require.False(t, got == reqCtx, "run ctx must not be the request ctx") + require.NoError(t, got.Err(), "run ctx must not inherit cancellation from the dropped request") - // Release the run; the handler should now complete cleanly. + // Release the run so it returns cleanly. close(coord.release) - select { - case rec := <-done: - // Writing to a recorder whose request ctx was canceled - // still works; in production the TCP write would silently - // fail, which is fine because the run already completed and - // SSE subscribers have the result. - require.Equal(t, http.StatusOK, rec.Code) - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after release") - } - require.Equal(t, int32(1), coord.ranCount.Load()) + require.Eventually(t, func() bool { + return coord.ranCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) } diff --git a/internal/server/e2e_agent_test.go b/internal/server/e2e_agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..012068a967536cf214518f24c946e8b98fd16932 --- /dev/null +++ b/internal/server/e2e_agent_test.go @@ -0,0 +1,741 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// scriptedCoordinator is an agent.Coordinator stub that mimics the +// externally-observable contract of a real run over the SSE pipeline +// without booting a real model, database, or scheduler. It publishes a +// user message when a run begins and an assistant message (with the +// appropriate FinishReason) when the run ends, exactly the way the real +// sessionAgent.Run surfaces a turn to SSE subscribers. +// +// A run blocks until either its per-session context is canceled (via +// Cancel, mirroring the explicit cancel endpoint) or the test releases +// it. On cancel it emits a FinishReasonCanceled assistant message and +// returns context.Canceled (which backend.runAgent swallows, so no +// AgentEvent error is published). On normal release it emits a +// FinishReasonEndTurn assistant message and returns nil. +// +// The internal scheduler signal points the PLAN's e2e cases reference +// (e.g. "before registration in activeRequests", "between +// activeRequests.Set and assistant create") are not exposed by the +// codebase, so this stub reproduces the documented black-box outcome by +// controlling run timing directly through blockEntered / release. +type scriptedCoordinator struct { + app *app.App + + // blockEntered, when non-nil, is signaled (once) right after a run + // is entered and before the user message is emitted, letting a test + // interleave a cancel with the dispatched goroutine. + blockEntered chan struct{} + + mu sync.Mutex + // cancels holds the cancel func for every in-flight run, keyed by a + // monotonic id so concurrent runs for the same session each get their + // own entry (a map keyed only by sessionID would let a second run + // overwrite the first's cancel func and leak it). + cancels map[int64]sessionCancel + // pendingCancels counts cancels that arrived for a session while a run + // was in flight; a run for that session consumes one on entry and + // cancels itself, modeling the cancel-on-entry path a follow-up takes. + pendingCancels map[string]int + nextRunID int64 + // entered carries the monotonic run id assigned to each run as it is + // entered, so a test can correlate a later assistant message back to a + // specific run (run 1 vs an accepted follow-up). + entered chan int64 + runStarts atomic.Int32 + + release chan struct{} +} + +type sessionCancel struct { + sessionID string + cancel context.CancelFunc +} + +func newScriptedCoordinator(a *app.App) *scriptedCoordinator { + return &scriptedCoordinator{ + app: a, + cancels: make(map[int64]sessionCancel), + pendingCancels: make(map[string]int), + entered: make(chan int64, 8), + release: make(chan struct{}), + } +} + +func (c *scriptedCoordinator) emitUser(sessionID, id string) { + c.app.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: message.Message{ + ID: id, + SessionID: sessionID, + Role: message.User, + Parts: []message.ContentPart{message.TextContent{Text: "hi"}}, + }, + }) +} + +func (c *scriptedCoordinator) emitAssistant(sessionID, id string, reason message.FinishReason) { + c.app.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: message.Message{ + ID: id, + SessionID: sessionID, + Role: message.Assistant, + Parts: []message.ContentPart{message.Finish{Reason: reason}}, + }, + }) +} + +func (c *scriptedCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + c.runStarts.Add(1) + runCtx, cancel := context.WithCancel(ctx) + + c.mu.Lock() + id := c.nextRunID + c.nextRunID++ + c.cancels[id] = sessionCancel{sessionID: sessionID, cancel: cancel} + // Cancel-on-entry: if a cancel for this session arrived while this + // run was still being dispatched (no run yet in flight to receive + // it), consume the pending cancel now so the run takes the canceled + // path instead of streaming output. + if c.pendingCancels[sessionID] > 0 { + c.pendingCancels[sessionID]-- + cancel() + } + c.mu.Unlock() + + select { + case c.entered <- id: + default: + } + + if c.blockEntered != nil { + select { + case <-c.blockEntered: + case <-runCtx.Done(): + } + } + + defer func() { + c.mu.Lock() + delete(c.cancels, id) + c.mu.Unlock() + cancel() + }() + + // Qualify the emitted message ids with the run id so a test can + // attribute an assistant message to the exact run that produced it + // (run 1 vs an accepted follow-up sharing the same session). + userID := fmt.Sprintf("u-%s-%d", sessionID, id) + asstID := fmt.Sprintf("a-%s-%d", sessionID, id) + + c.emitUser(sessionID, userID) + + // Cancellation takes priority: if the run was already canceled it + // must take the canceled path even when release is closed, so a + // canceled run never races into a normal FinishReasonEndTurn. + select { + case <-runCtx.Done(): + c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled) + return nil, context.Canceled + default: + } + + select { + case <-c.release: + c.emitAssistant(sessionID, asstID, message.FinishReasonEndTurn) + return nil, nil + case <-runCtx.Done(): + c.emitAssistant(sessionID, asstID, message.FinishReasonCanceled) + return nil, context.Canceled + } +} + +func (c *scriptedCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return c.Run(ctx, sessionID, prompt, attachments...) +} + +func (c *scriptedCoordinator) BeginAccepted(string) *agent.AcceptedRun { return nil } + +func (c *scriptedCoordinator) Cancel(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + // Cancel every in-flight run for this session. Concurrent runs for + // the same session (an active run plus an accepted follow-up still + // dispatching) each hold their own entry, so all of them are torn + // down by a single per-session cancel. + var canceled int + for _, sc := range c.cancels { + if sc.sessionID == sessionID { + sc.cancel() + canceled++ + } + } + // If at least one run was in flight, arm a pending cancel so a + // follow-up that has been accepted but not yet entered Run takes the + // cancel-on-entry path. With no run in flight this is a no-op, + // mirroring the production guarantee that an idle cancel does not arm + // a pending cancel against the next prompt. + if canceled > 0 { + c.pendingCancels[sessionID]++ + } +} + +func (c *scriptedCoordinator) CancelAll() { + c.mu.Lock() + defer c.mu.Unlock() + for _, sc := range c.cancels { + sc.cancel() + } +} + +func (c *scriptedCoordinator) IsBusy() bool { return false } +func (c *scriptedCoordinator) IsSessionBusy(string) bool { return false } +func (c *scriptedCoordinator) QueuedPrompts(string) int { return 0 } +func (c *scriptedCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *scriptedCoordinator) ClearQueue(string) {} +func (c *scriptedCoordinator) Summarize(context.Context, string) error { return nil } +func (c *scriptedCoordinator) Model() agent.Model { return agent.Model{} } +func (c *scriptedCoordinator) UpdateModels(context.Context) error { return nil } + +// agentE2EHarness extends the SSE harness with a scripted coordinator +// wired into the workspace's embedded app.App, so POST /agent drives a +// real backend.SendMessage dispatch whose emitted user/assistant +// messages fan out over the same SSE pipeline production uses. +type agentE2EHarness struct { + *e2eHarness + coord *scriptedCoordinator +} + +func newAgentE2EHarness(t *testing.T) *agentE2EHarness { + t.Helper() + + h := &e2eHarness{} + + appCtx, cancel := context.WithCancel(context.Background()) + a := app.NewForTest(appCtx) + coord := newScriptedCoordinator(a) + a.AgentCoordinator = coord + t.Cleanup(func() { + cancel() + a.ShutdownForTest() + }) + + h.installServer(t) + + ws := &backend.Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + App: a, + } + backend.SetWorkspaceShutdownFnForTest(ws, func() {}) + backend.InsertWorkspaceForTest(h.backend, ws) + + h.workspace = ws + h.app = a + return &agentE2EHarness{e2eHarness: h, coord: coord} +} + +// postAgentHTTP drives POST /v1/workspaces/{id}/agent over the harness's +// httptest server and returns the status code. +func (h *agentE2EHarness) postAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int { + t.Helper() + body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"}) + require.NoError(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + return resp.StatusCode +} + +// cancelAgentHTTP drives POST /v1/workspaces/{id}/agent/sessions/{sid}/cancel. +func (h *agentE2EHarness) cancelAgentHTTP(t *testing.T, ctx context.Context, sessionID string) int { + t.Helper() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID+"/agent/sessions/"+sessionID+"/cancel", nil) + require.NoError(t, err) + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + return resp.StatusCode +} + +// waitForRunEntered blocks until a dispatched run for any session has +// been entered by the scripted coordinator, or fails the test. It +// returns the monotonic run id assigned to that run so a caller can +// correlate it with a later assistant message; callers that don't need +// the id can ignore the return value. +func (h *agentE2EHarness) waitForRunEntered(t *testing.T) int64 { + t.Helper() + select { + case id := <-h.coord.entered: + return id + case <-time.After(2 * time.Second): + t.Fatal("dispatched run was never entered") + return 0 + } +} + +// finishReason extracts the assistant message's FinishReason, if any. +func finishReason(m proto.Message) (proto.FinishReason, bool) { + for _, p := range m.Parts { + if f, ok := p.(proto.Finish); ok { + return f.Reason, true + } + } + return "", false +} + +// TestE2E_CancelByOtherClientDoesNotErrorPrompter covers PLAN Tests -> +// New end-to-end coverage item 1: a second client canceling a run does +// not surface a server error to the prompter; the run ends with a +// FinishReasonCanceled assistant message and no AgentEvent carries a +// non-nil Error. +func TestE2E_CancelByOtherClientDoesNotErrorPrompter(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cidA := uuid.New().String() + cidB := uuid.New().String() + evcA, cancelA := h.subscribeSSE(t, ctx, h.workspace.ID, cidA) + t.Cleanup(cancelA) + evcB, cancelB := h.subscribeSSE(t, ctx, h.workspace.ID, cidB) + t.Cleanup(cancelB) + h.waitForAttached(t, 2) + + const sid = "s-cancel-other" + + // A posts a long-running prompt; the handler must return 202 + // immediately (the run blocks in the coordinator). + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // B cancels. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // A's SSE stream receives the FinishReasonCanceled assistant + // message. + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evcA, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, ok, "client A must observe a FinishReasonCanceled assistant message") + require.Equal(t, sid, got.Payload.SessionID) + + // No AgentEvent error reaches A (cancel is not a server error). + errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancel() + _, gotErrA := drainUntil(errCtx, evcA, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErrA, "cancel must not surface an AgentEvent error to the prompter") + + // And no AgentEvent error reaches the canceling client B either; the + // PLAN requires that *no* client observes a non-nil Error. + errCtxB, errCancelB := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancelB() + _, gotErrB := drainUntil(errCtxB, evcB, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErrB, "cancel must not surface an AgentEvent error to any client") +} + +// TestE2E_CancelImmediatelyAfter202IsNotLost covers PLAN item 1a: a +// cancel that races a freshly-dispatched run (before it would emit any +// output) is not lost. The run takes the cancel-on-entry path and emits +// a user message followed by a FinishReasonCanceled assistant message +// rather than streaming model output. +func TestE2E_CancelImmediatelyAfter202IsNotLost(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + // Gate the run on a signal the test controls so the cancel can be + // observed while the dispatched goroutine is parked at entry. + h.coord.blockEntered = make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-race-cancel" + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // Cancel while the run is still blocked at entry, then release it. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + close(h.coord.blockEntered) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + gotUser, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.Role == proto.User && e.Payload.SessionID == sid + }) + require.True(t, okUser, "the canceled turn must still record a user message") + require.Equal(t, sid, gotUser.Payload.SessionID) + + gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, okAsst, "the canceled turn must end with a FinishReasonCanceled assistant message") + require.Equal(t, sid, gotAsst.Payload.SessionID) +} + +// TestE2E_IdleCancelDoesNotPoisonNextPrompt covers PLAN item 1b: an +// idle cancel (no active run) must not poison the next prompt. With the +// scripted coordinator the cancel records a pending entry only if a run +// is in flight; an idle cancel records one, but the documented +// guarantee is that the *next* prompt's outcome is observable. Here we +// assert the regression-relevant external behavior: after an idle +// cancel, a subsequent normal prompt is able to run and emit output. +// +// NOTE: This is a simplified version. The real "idle Escape must not +// poison" guarantee lives inside sessionAgent.Cancel's acceptedRuns +// gating, which is covered by the agent unit tests; the e2e stub cannot +// distinguish "truly idle" from "accepted but not yet running" without +// the internal acceptedRuns signal. See test summary. +func TestE2E_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-idle-cancel" + + // Idle cancel: no run in flight. The scripted coordinator drops it + // (no pending cancel recorded for a session that has no run), which + // models the production guarantee that an idle Escape does not arm + // a cancel against the next prompt. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // Now a normal prompt; release it so it finishes successfully. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "the next prompt after an idle cancel must run to FinishReasonEndTurn") + require.Equal(t, sid, got.Payload.SessionID) + + // And it must not be marked canceled. + canCtx, canCancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer canCancel() + _, gotCanceled := drainUntil(canCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.False(t, gotCanceled, "an idle cancel must not produce a FinishReasonCanceled marker on the next prompt") +} + +// TestE2E_CancelBetweenActiveSetAndAssistantCreate covers PLAN item 1d: +// a cancel that arrives after the run has begun but before it would +// create the assistant message must still produce a user message and a +// FinishReasonCanceled assistant message, never a silent return. The +// blockEntered gate parks the run after entry (modeling the window +// between activeRequests.Set and assistant creation). +func TestE2E_CancelBetweenActiveSetAndAssistantCreate(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + h.coord.blockEntered = make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-mid-window" + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + h.waitForRunEntered(t) + + // Cancel while parked at entry; then release so the run proceeds + // into its cancel branch (runCtx already canceled). + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + close(h.coord.blockEntered) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + _, okUser := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.Role == proto.User && e.Payload.SessionID == sid + }) + require.True(t, okUser, "a user message must be recorded for the canceled turn") + + gotAsst, okAsst := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonCanceled + }) + require.True(t, okAsst, "the run must not return silently; it must emit a FinishReasonCanceled assistant message") + require.Equal(t, sid, gotAsst.Payload.SessionID) + + // No AgentEvent error is published: a cancel in the + // activeRequests.Set -> assistant-create window is not a server + // error. + errCtx, errCancel := context.WithTimeout(ctx, 250*time.Millisecond) + defer errCancel() + _, gotErr := drainUntil(errCtx, evc, func(e pubsub.Event[proto.AgentEvent]) bool { + return e.Payload.Type == proto.AgentEventTypeError && e.Payload.Error != nil + }) + require.False(t, gotErr, "no AgentEvent error must be published for the canceled turn") +} + +// TestE2E_PromptRequestContextDoesNotOwnRun covers PLAN item 2: the +// prompting client's HTTP request context does not own the run. A POST +// with a very short request-context timeout still returns 202 before +// that context would expire, and the run keeps going (observed via SSE +// finishing normally after release). +func TestE2E_PromptRequestContextDoesNotOwnRun(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + streamCtx, streamCancel := context.WithCancel(t.Context()) + t.Cleanup(streamCancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, streamCtx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-short-req" + + // The POST request context times out almost immediately. The + // handler must still return 202 (fire-and-forget) and the run must + // survive past the request-context deadline. + reqCtx, reqCancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer reqCancel() + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, reqCtx, sid)) + h.waitForRunEntered(t) + + // Let the request context expire, then release the run. + <-reqCtx.Done() + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(streamCtx, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "the run must finish normally even after the prompting request context expired") + require.Equal(t, sid, got.Payload.SessionID) +} + +// TestE2E_AgentRunSurvivesAcrossWorkspaceClaims covers PLAN item 3: a +// run started by client A survives A detaching as long as another +// client (B) keeps the workspace alive; B observes the run finish via +// SSE. +func TestE2E_AgentRunSurvivesAcrossWorkspaceClaims(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + + ctxA, cancelA := context.WithCancel(t.Context()) + ctxB, cancelB := context.WithCancel(t.Context()) + t.Cleanup(cancelB) + + cidA := uuid.New().String() + cidB := uuid.New().String() + _, killA := h.subscribeSSE(t, ctxA, h.workspace.ID, cidA) + t.Cleanup(killA) + evcB, killB := h.subscribeSSE(t, ctxB, h.workspace.ID, cidB) + t.Cleanup(killB) + h.waitForAttached(t, 2) + + const sid = "s-survive" + // A is the poster; the run must outlive A detaching as long as B + // keeps the workspace alive. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctxA, sid)) + h.waitForRunEntered(t) + + // A detaches; B is still attached so the workspace stays alive. + cancelA() + killA() + require.Eventually(t, func() bool { + return backend.WorkspaceLiveStreamCountForTest(h.workspace) == 1 + }, 3*time.Second, 10*time.Millisecond, + "A detaching must leave B as the sole attached client") + require.False(t, h.shutdownHit.Load(), "workspace must stay alive while B is attached") + + // Release the run; B must still observe it finish. + close(h.coord.release) + pickCtx, pickCancel := context.WithTimeout(ctxB, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.True(t, ok, "B must observe the run finish after A detaches") + require.Equal(t, sid, got.Payload.SessionID) +} + +// TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp covers PLAN item +// 1c at the externally-observable level: while session sid has an active +// run, a second prompt for sid is accepted; a cancel for sid must cancel +// the active run and must not let the follow-up stream a normal +// FinishReasonEndTurn. +// +// The sequence follows the PLAN exactly: prompt 1 becomes the active +// run, prompt 2 for the same sid is accepted, then a cancel for sid +// fires, and only afterwards are any signals released. The scripted +// coordinator models the externally-observable contract of the +// busy-queue branch and pendingCancels (which depend on internal +// scheduler signals the codebase does not expose): a per-session cancel +// tears down every in-flight run for sid and arms a cancel-on-entry for +// a follow-up still dispatching. The invariant asserted is the one that +// matters: after the cancel, the active run ends canceled and the +// follow-up never streams a normal FinishReasonEndTurn. +func TestE2E_CancelOfActiveRunAlsoCancelsAcceptedFollowUp(t *testing.T) { + t.Parallel() + h := newAgentE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cid := uuid.New().String() + evc, cancelSSE := h.subscribeSSE(t, ctx, h.workspace.ID, cid) + t.Cleanup(cancelSSE) + h.waitForAttached(t, 1) + + const sid = "s-followup" + + // (a) Prompt 1 for sid becomes the active run. Capture its run id so + // the canceled assistant message below can be attributed to run 1 + // unambiguously. + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + run1 := h.waitForRunEntered(t) + + // (b) Prompt 2 for the *same* sid is accepted while the active run + // is still in flight; it is the follow-up the PLAN describes + // (acceptedRuns > 0, either still dispatching or about to enter the + // busy-queue branch). + require.Equal(t, http.StatusAccepted, h.postAgentHTTP(t, ctx, sid)) + run2 := h.waitForRunEntered(t) + require.NotEqual(t, run1, run2, "the follow-up must be a distinct run from the active one") + + // (c) B cancels sid. This tears down every in-flight run for the + // session and arms a pending cancel for any follow-up that has not + // yet entered Run. + require.Equal(t, http.StatusOK, h.cancelAgentHTTP(t, ctx, sid)) + + // (d) Open the coordinator gate so any run that is NOT canceled would + // be free to proceed straight into the normal FinishReasonEndTurn + // branch. The scripted Run checks runCtx.Done() before the release + // select, so a canceled run still takes the canceled path even with + // release closed; only a non-canceled run reaches FinishReasonEndTurn. + // Releasing here is therefore what makes the assertions below + // meaningful: if the cancel had failed to tear down run 1 or arm the + // cancel-on-entry for the follow-up, the freed gate would let that run + // stream a normal FinishReasonEndTurn and the test would fail. + close(h.coord.release) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + + // (e) Run 1 (the active run) must end with FinishReasonCanceled. The + // assistant message id is qualified with the run id, so matching on + // run1's id proves the cancellation is attributed to the FIRST run + // and not to the follow-up. + // + // The single drain below is also the negative assertion for run 2: + // the match closure inspects every assistant event for sid as it + // scans, and if it ever observes the follow-up (run 2) streaming a + // normal FinishReasonEndTurn it records that violation immediately. + // This is what makes the run-2 check sound: a previous two-phase + // approach could let this very drain consume and discard a run-2 + // EndTurn while still hunting for run 1's canceled message, leaving a + // later no-EndTurn check unable to prove run 2 stayed canceled. + // Folding the negative check into the same scan means a run-2 EndTurn + // can never slip past unobserved, whether it arrives before or after + // run 1's canceled message. + run1AsstID := fmt.Sprintf("a-%s-%d", sid, run1) + run2AsstID := fmt.Sprintf("a-%s-%d", sid, run2) + var followUpEndTurn bool + got, ok := drainUntil(pickCtx, evc, func(e pubsub.Event[proto.Message]) bool { + if e.Payload.SessionID != sid || e.Payload.Role != proto.Assistant { + return false + } + r, has := finishReason(e.Payload) + if !has { + return false + } + // Any normal model output for sid after the cancel is a + // violation. The follow-up (run 2) must never reach the + // FinishReasonEndTurn branch; flag it the moment it is seen so + // the assertion below fails even if this event arrives while we + // are still waiting for run 1's canceled message. + if r == proto.FinishReasonEndTurn { + if e.Payload.ID == run2AsstID || e.Payload.ID != run1AsstID { + followUpEndTurn = true + } + // Stop draining; the EndTurn observation is decisive and the + // require.False below will surface the failure. + return true + } + return e.Payload.ID == run1AsstID && r == proto.FinishReasonCanceled + }) + require.False(t, followUpEndTurn, "the accepted follow-up must not stream a normal FinishReasonEndTurn after the cancel") + require.True(t, ok, "the first (active) run must end with FinishReasonCanceled") + require.Equal(t, run1AsstID, got.Payload.ID, "the canceled message must belong to the first (active) run") + gotReason, gotHas := finishReason(got.Payload) + require.True(t, gotHas) + require.Equal(t, proto.FinishReasonCanceled, gotReason, "the matched run-1 message must be canceled, not a normal end turn") + require.Equal(t, sid, got.Payload.SessionID) + + // Confirm no normal FinishReasonEndTurn for sid is still in flight. + // By this point the scan above has already ruled out a run-2 EndTurn + // arriving before run 1's canceled message; this guards against one + // arriving afterward. + endCtx, endCancel := context.WithTimeout(ctx, 300*time.Millisecond) + defer endCancel() + _, gotEnd := drainUntil(endCtx, evc, func(e pubsub.Event[proto.Message]) bool { + r, has := finishReason(e.Payload) + return e.Payload.SessionID == sid && e.Payload.Role == proto.Assistant && has && r == proto.FinishReasonEndTurn + }) + require.False(t, gotEnd, "the accepted follow-up must not stream model output after the cancel") +} diff --git a/internal/server/e2e_test.go b/internal/server/e2e_test.go index 08aaedf66c95edd704f18b62d83d64e79966564e..565a989136536e5bea8b1134995a3770183d4caa 100644 --- a/internal/server/e2e_test.go +++ b/internal/server/e2e_test.go @@ -240,6 +240,18 @@ func decodeSSEEnvelope(p pubsub.Payload) (any, bool) { return nil, false } return e, true + case pubsub.PayloadTypeAgentEvent: + var e pubsub.Event[proto.AgentEvent] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true + case pubsub.PayloadTypeRunComplete: + var e pubsub.Event[proto.RunComplete] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true } return nil, false } diff --git a/internal/server/events.go b/internal/server/events.go index 4e3d6a1a262ae7b3399f0fa765ae863160ba57c8..526f9e195009cd70c453958778fb98887aae4a37 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "log/slog" @@ -85,13 +86,19 @@ func wrapEvent(ev any) *pubsub.Payload { Payload: fileToProto(e.Payload), }) case pubsub.Event[notify.Notification]: + payload := proto.AgentEvent{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, + Type: proto.AgentEventType(e.Payload.Type), + } + if e.Payload.Type == notify.TypeAgentError { + payload.Type = proto.AgentEventTypeError + payload.Error = errors.New(e.Payload.Message) + } return envelope(pubsub.PayloadTypeAgentEvent, pubsub.Event[proto.AgentEvent]{ - Type: e.Type, - Payload: proto.AgentEvent{ - SessionID: e.Payload.SessionID, - SessionTitle: e.Payload.SessionTitle, - Type: proto.AgentEventType(e.Payload.Type), - }, + Type: e.Type, + Payload: payload, }) case pubsub.Event[notify.RunComplete]: return envelope(pubsub.PayloadTypeRunComplete, pubsub.Event[proto.RunComplete]{ diff --git a/internal/server/events_test.go b/internal/server/events_test.go index 432bc42f910b4acec675baea46754b81defab9f6..e4238a05eb3abf50e13329acfaabd2cb77dd464c 100644 --- a/internal/server/events_test.go +++ b/internal/server/events_test.go @@ -123,6 +123,38 @@ func TestRunCompleteToProto_RoundTrip(t *testing.T) { require.False(t, decoded.Payload.Cancelled) } +// TestAgentErrorToProto_PreservesRunID verifies that an async agent +// error notification carries its originating RunID (and SessionID) +// through the SSE envelope. Without these correlators, `crush run` +// cannot tell whether an error event belongs to its own run and +// would abort on any unrelated workspace failure. +func TestAgentErrorToProto_PreservesRunID(t *testing.T) { + t.Parallel() + + src := pubsub.Event[notify.Notification]{ + Type: pubsub.CreatedEvent, + Payload: notify.Notification{ + SessionID: "S", + RunID: "run-99", + Type: notify.TypeAgentError, + Message: "boom", + }, + } + + env := wrapEvent(src) + require.NotNil(t, env) + require.Equal(t, pubsub.PayloadTypeAgentEvent, env.Type) + + var decoded pubsub.Event[proto.AgentEvent] + require.NoError(t, json.Unmarshal(env.Payload, &decoded)) + require.Equal(t, proto.AgentEventTypeError, decoded.Payload.Type) + require.Equal(t, "S", decoded.Payload.SessionID) + require.Equal(t, "run-99", decoded.Payload.RunID, + "RunID must survive so observers can attribute the error to its run") + require.NotNil(t, decoded.Payload.Error) + require.Equal(t, "boom", decoded.Payload.Error.Error()) +} + // TestRunCompleteToProto_Error verifies that error- and cancel-shaped // RunComplete events round-trip cleanly so clients can distinguish // "agent failed" (returns non-zero from `crush run`) from "agent diff --git a/internal/server/proto.go b/internal/server/proto.go index 5e1c4e2605fb8413df0464be2cd7ee6cb40a5f66..f388d51bb87490a484bcb06b16e4698058bae134 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -1,7 +1,6 @@ package server import ( - "context" "encoding/json" "errors" "fmt" @@ -740,9 +739,10 @@ func (c *controllerV1) handleGetWorkspaceAgent(w http.ResponseWriter, r *http.Re // @Accept json // @Param id path string true "Workspace ID" // @Param request body proto.AgentMessage true "Agent message" -// @Success 200 +// @Success 202 // @Failure 400 {object} proto.Error // @Failure 404 {object} proto.Error +// @Failure 409 {object} proto.Error // @Failure 500 {object} proto.Error // @Router /workspaces/{id}/agent [post] func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.Request) { @@ -755,18 +755,19 @@ func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.R return } - // Detach the run's lifetime from the prompting client's HTTP - // request. Without this, A dropping its TCP connection (network - // blip, TUI restart) or B canceling the session via the explicit - // cancel endpoint would also cancel A's request context and tear - // down a turn that other subscribed clients are still watching. - // Only the explicit cancel endpoint should be able to end a run. - ctx := context.WithoutCancel(r.Context()) - if err := c.backend.SendMessage(ctx, id, msg); err != nil { + // The run's lifetime is detached from the prompting client's HTTP + // request: SendMessage validates and accepts the prompt, dispatches + // the run on a goroutine bound to the workspace context, and returns + // immediately. A dropping its TCP connection (network blip, TUI + // restart) or B canceling the session via the explicit cancel + // endpoint can no longer tear down a turn that other subscribed + // clients are still watching. Only the explicit cancel endpoint + // should be able to end a run. + if err := c.backend.SendMessage(id, msg); err != nil { c.handleError(w, r, err) return } - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusAccepted) } // handlePostWorkspaceAgentInit initializes the agent for a workspace. @@ -1033,15 +1034,15 @@ func (c *controllerV1) handleGetWorkspacePermissionsSkip(w http.ResponseWriter, // handleError maps backend errors to HTTP status codes and writes the // JSON error response. +// +// Runtime cancellation of an agent run no longer reaches here for the +// agent-prompt path: SendMessage is fire-and-forget (the handler returns +// 202 before the run starts) and Backend.runAgent swallows +// context.Canceled, surfacing the FinishReasonCanceled marker to SSE +// subscribers instead. The remaining callers pass synchronous backend +// errors, so context.Canceled gets no special case and would fall through +// to the default 500 like any other unexpected error. func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err error) { - // A canceled agent run is not an error from the prompting - // client's perspective. The cancellation reaches every SSE - // subscriber via the FinishReasonCanceled marker on the assistant - // message; the still-open POST should not surface a 500. - if errors.Is(err, context.Canceled) { - w.WriteHeader(http.StatusOK) - return - } status := http.StatusInternalServerError switch { case errors.Is(err, backend.ErrWorkspaceNotFound): @@ -1060,6 +1061,8 @@ func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err e status = http.StatusBadRequest case errors.Is(err, backend.ErrClientNotAttached): status = http.StatusNotFound + case errors.Is(err, backend.ErrWorkspaceClosing): + status = http.StatusConflict } c.server.logError(r, err.Error()) jsonError(w, status, err.Error()) diff --git a/internal/server/sessions_isbusy_test.go b/internal/server/sessions_isbusy_test.go index 060c00abe9367dc7162bdb50dd77fe951041aa51..615f4a5b58cde3f5779b053ca9ce92e0d0d253a4 100644 --- a/internal/server/sessions_isbusy_test.go +++ b/internal/server/sessions_isbusy_test.go @@ -30,6 +30,14 @@ type stubCoordinator struct { func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { return nil, nil } + +func (s *stubCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (s *stubCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { + return nil +} func (s *stubCoordinator) Cancel(string) {} func (s *stubCoordinator) CancelAll() {} func (s *stubCoordinator) IsBusy() bool { return false } diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 63226d0a2f9ec951ccc0cf1167f4e06ac6c88c16..0c6102572cdc768ce0ebdf061a816eacb66af974 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -3285,12 +3285,12 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. // Capture session ID to avoid race with main goroutine updating m.session. sessionID := m.session.ID cmds = append(cmds, func() tea.Msg { + // AgentRun is fire-and-forget: it returns once the prompt has + // been accepted (HTTP 202) or synchronously with a validation + // or transport error. Run failures and cancellation surface + // through SSE-derived events, not this return value. err := m.com.Workspace.AgentRun(context.Background(), sessionID, content, attachments...) if err != nil { - isCancelErr := errors.Is(err, context.Canceled) - if isCancelErr { - return nil - } return util.InfoMsg{ Type: util.InfoTypeError, Msg: fmt.Sprintf("%v", err), diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index 3653046709c66f87c6b496e11dce2373a5ea40f2..09ff57c612cd776c0feae6801345035224b90422 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -704,13 +704,18 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg { Payload: protoToFile(e.Payload), } case pubsub.Event[proto.AgentEvent]: + n := notify.Notification{ + SessionID: e.Payload.SessionID, + SessionTitle: e.Payload.SessionTitle, + RunID: e.Payload.RunID, + Type: notify.Type(e.Payload.Type), + } + if e.Payload.Error != nil { + n.Message = e.Payload.Error.Error() + } return pubsub.Event[notify.Notification]{ - Type: e.Type, - Payload: notify.Notification{ - SessionID: e.Payload.SessionID, - SessionTitle: e.Payload.SessionTitle, - Type: notify.Type(e.Payload.Type), - }, + Type: e.Type, + Payload: n, } case pubsub.Event[proto.RunComplete]: // Translate the wire-level proto.RunComplete back into the