From e14d47a8135cef6950d6a337a8f7e3c62157bc27 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Thu, 28 May 2026 10:55:58 -0400 Subject: [PATCH] fix(noninteractive): crush run reliability in client/server mode This fixes premature exits during tool use, hangs after certain tool-calling stops, and issues when continuing into a session that is already busy. Co-Authored-By: Charm Crush --- internal/agent/agent.go | 102 +++++++- internal/agent/coordinator.go | 35 ++- internal/agent/notify/notify.go | 28 +++ internal/agent/run_complete_test.go | 89 +++++++ internal/agent/runid.go | 33 +++ internal/app/app.go | 42 ++++ internal/app/testing.go | 3 + internal/backend/agent.go | 11 + internal/client/proto.go | 15 +- internal/cmd/run.go | 176 ++++++++++--- internal/cmd/run_stream_test.go | 329 +++++++++++++++++++++++++ internal/proto/proto.go | 39 +++ internal/pubsub/events.go | 8 + internal/server/events.go | 12 + internal/server/events_test.go | 65 +++++ internal/workspace/client_workspace.go | 24 +- 16 files changed, 965 insertions(+), 46 deletions(-) create mode 100644 internal/agent/run_complete_test.go create mode 100644 internal/agent/runid.go create mode 100644 internal/cmd/run_stream_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f8e4d716c8f9579b4bb16d8479af55e4ff78f12c..53e63af3b95e8bb4fba6144675d97c3686e78546 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -70,7 +70,16 @@ var ( ) type SessionAgentCall struct { - SessionID string + SessionID string + // RunID, when non-empty, is the caller-supplied correlator that + // gets echoed back on the notify.RunComplete event emitted for + // this turn. It is preserved when the call is enqueued behind a + // busy session so the queued turn's terminal event is still + // recognisable to the original caller. Callers that need a + // reliable completion contract (e.g. `crush run` against a + // session that may be busy) MUST set it; SessionID alone is + // ambiguous when concurrent turns share the same session. + RunID string Prompt string ProviderOptions fantasy.ProviderOptions Attachments []message.Attachment @@ -81,6 +90,19 @@ type SessionAgentCall struct { FrequencyPenalty *float64 PresencePenalty *float64 NonInteractive bool + // OnComplete, when non-nil, replaces the default RunComplete + // publish path: the inner Run hands the terminal payload to this + // callback instead of emitting it on the RunComplete broker. The + // coordinator uses this hook to coalesce the unauthorized → + // re-auth → retry chain into a single user-visible terminal + // event, so non-interactive clients (e.g. `crush run`) don't + // exit on a stale failed-attempt RunComplete before the + // successful retry. It is intentionally stripped when queueing + // a busy-session call (see Run): the originating + // coordinator.Run has long returned by the time the queued + // recursion drains, so falling back to the default broker + // publish keeps the event visible to subscribers. + OnComplete func(notify.RunComplete) } type SessionAgent interface { @@ -119,6 +141,7 @@ type sessionAgent struct { disableAutoSummarize bool isYolo bool notify pubsub.Publisher[notify.Notification] + runComplete pubsub.Publisher[notify.RunComplete] messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] @@ -136,6 +159,7 @@ type SessionAgentOptions struct { Messages message.Service Tools []fantasy.AgentTool Notify pubsub.Publisher[notify.Notification] + RunComplete pubsub.Publisher[notify.RunComplete] } func NewSessionAgent( @@ -153,12 +177,13 @@ func NewSessionAgent( tools: csync.NewSliceFrom(opts.Tools), isYolo: opts.IsYolo, notify: opts.Notify, + runComplete: opts.RunComplete, messageQueue: csync.NewMap[string, []SessionAgentCall](), activeRequests: csync.NewMap[string, context.CancelFunc](), } } -func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { +func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (result *fantasy.AgentResult, retErr error) { if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) { return nil, ErrEmptyPrompt } @@ -166,13 +191,21 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy return nil, ErrSessionMissing } - // Queue the message if busy + // Queue the message if busy. Strip OnComplete: the caller that + // supplied the hook (typically coordinator.Run) has its own + // retry/coalesce scope that ends when it returns, so by the time + // the queue drains nobody is left to consume the buffered + // terminal event. The recursive Run will fall back to the + // default broker publish, which is what existing subscribers + // expect for queued turns. if a.IsSessionBusy(call.SessionID) { existing, ok := a.messageQueue.Get(call.SessionID) if !ok { existing = []SessionAgentCall{} } - existing = append(existing, call) + queued := call + queued.OnComplete = nil + existing = append(existing, queued) a.messageQueue.Set(call.SessionID, existing) return nil, nil } @@ -245,14 +278,65 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy defer cancel() defer a.activeRequests.Del(call.SessionID) + // skipRunComplete is set just before the queued-recursion path so + // the outer Run doesn't publish a RunComplete that would race + // with — and be superseded by — the recursive call's own + // RunComplete (each queued user prompt is its own turn and + // publishes exactly one terminal event). + var skipRunComplete bool + // currentAssistant is declared here so the deferred RunComplete + // publish below can capture the pointer that PrepareStep will + // later (re)assign for each streaming step. The final assistant + // message of the turn is the value reachable through this + // pointer when the defer runs. + var currentAssistant *message.Message // Drain any debounced message updates before returning. message.Service // already flushes synchronously on terminal updates, but a defer here // guarantees the contract at every Run exit (success, error, panic // recovery upstream) without callers needing to know. + // + // After the flush completes — meaning all per-message + // Publish(UpdatedEvent) calls have fired and been buffered into + // every subscriber's channel — publish the authoritative + // RunComplete event for this turn. The flush-then-publish order + // gives well-behaved clients the best chance of seeing the final + // message event before RunComplete; the embedded Text field + // reconciles for clients that observe the events out of order + // (the pubsub broker fan-in does not serialize publishes from + // different upstream brokers). defer func() { if flushErr := a.messages.FlushAll(ctx); flushErr != nil { slog.Error("Failed to flush pending message updates after run", "error", flushErr) } + if skipRunComplete { + return + } + complete := notify.RunComplete{SessionID: call.SessionID, RunID: call.RunID} + if currentAssistant != nil { + complete.MessageID = currentAssistant.ID + complete.Text = currentAssistant.Content().String() + } + if retErr != nil { + complete.Error = retErr.Error() + complete.Cancelled = errors.Is(retErr, context.Canceled) + } else if ctx.Err() != nil { + complete.Cancelled = true + } + // Prefer the per-call hook when supplied so the coordinator + // can coalesce retries (e.g. unauthorized → re-auth → retry) + // into a single user-visible terminal event. The fallback + // must-deliver publish applies bounded-blocking semantics to + // the authoritative terminal event so a momentarily-full + // subscriber channel can't silently drop it and hang + // non-interactive clients waiting on RunComplete. + if call.OnComplete != nil { + call.OnComplete(complete) + return + } + if a.runComplete == nil { + return + } + a.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, complete) }() history, files := a.preparePrompt(msgs, largeModel.CatwalkCfg.SupportsImages, call.Attachments...) @@ -260,7 +344,6 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy startTime := time.Now() a.eventPromptSent(call.SessionID) - var currentAssistant *message.Message var stepMessages []fantasy.Message var shouldSummarize bool // Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it @@ -268,7 +351,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if call.MaxOutputTokens > 0 { maxOutputTokens = &call.MaxOutputTokens } - result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ + result, err = agent.Stream(genCtx, fantasy.AgentStreamCall{ Prompt: message.PromptWithTextAttachments(call.Prompt, call.Attachments), Files: files, Messages: history, @@ -634,7 +717,12 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy if !ok || len(queuedMessages) == 0 { return result, err } - // There are queued messages restart the loop. + // There are queued messages restart the loop. The recursive Run + // publishes its own RunComplete for the queued prompt, so suppress + // the outer defer's emit to avoid a duplicate event whose Error + // field would belong to the recursive turn but whose MessageID/Text + // would belong to the outer turn. + skipRunComplete = true firstQueuedMessage := queuedMessages[0] a.messageQueue.Set(call.SessionID, queuedMessages[1:]) return a.Run(ctx, firstQueuedMessage) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 14f67f724a8410b043164928d2d5fbe253108300..d4c05951e6af1e0243ca80df9dba9410c1529936 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -97,6 +97,7 @@ type coordinator struct { filetracker filetracker.Service lspManager *lsp.Manager notify pubsub.Publisher[notify.Notification] + runComplete pubsub.Publisher[notify.RunComplete] currentAgent SessionAgent agents map[string]SessionAgent @@ -119,6 +120,7 @@ func NewCoordinator( filetracker filetracker.Service, lspManager *lsp.Manager, notify pubsub.Publisher[notify.Notification], + runComplete pubsub.Publisher[notify.RunComplete], skillsMgr *skills.Manager, ) (Coordinator, error) { // Skills are pre-discovered by the caller (see app.New / @@ -143,6 +145,7 @@ func NewCoordinator( filetracker: filetracker, lspManager: lspManager, notify: notify, + runComplete: runComplete, agents: make(map[string]SessionAgent), allSkills: allSkills, activeSkills: activeSkills, @@ -210,9 +213,34 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, slog.Error("Failed to refresh OAuth2 token. Proceeding with existing token.", "error", err) } + // Coalesce per-attempt RunComplete payloads so only the final + // outcome reaches subscribers. Without this, the first attempt's + // failed RunComplete (unauthorized) would race ahead of the + // retry's success, and `crush run` would exit on the stale error + // before ever seeing the retry result. Each attempt's + // SessionAgentCall.OnComplete hook overwrites latest; we publish + // exactly once after retries resolve, via PublishMustDeliver, so + // a momentarily-full subscriber buffer can't silently drop the + // terminal event. + var ( + latest notify.RunComplete + hasLatest bool + ) + onComplete := func(rc notify.RunComplete) { + latest = rc + hasLatest = true + } + // Propagate the caller-supplied RunID (set via agent.WithRunID + // at the HTTP boundary in backend.SendMessage) onto the + // SessionAgentCall so the terminal RunComplete event echoes it + // back. Both attempts in the retry chain reuse the same RunID; + // the coalesce closure publishes the final outcome under that + // same correlator. + runID := RunIDFromContext(ctx) run := func() (*fantasy.AgentResult, error) { return c.currentAgent.Run(ctx, SessionAgentCall{ SessionID: sessionID, + RunID: runID, Prompt: prompt, Attachments: attachments, MaxOutputTokens: maxTokens, @@ -222,6 +250,7 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, TopK: topK, FrequencyPenalty: freqPenalty, PresencePenalty: presPenalty, + OnComplete: onComplete, }) } beforeLoaded := c.skillTracker.LoadedNames() @@ -230,10 +259,13 @@ func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, if c.isUnauthorized(originalErr) { if err := c.retryAfterUnauthorized(ctx, providerCfg); err == nil { - return run() + result, originalErr = run() } } + if hasLatest && c.runComplete != nil { + c.runComplete.PublishMustDeliver(ctx, pubsub.UpdatedEvent, latest) + } return result, originalErr } @@ -452,6 +484,7 @@ func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, age Messages: c.messages, Tools: nil, Notify: c.notify, + RunComplete: c.runComplete, }) c.readyWg.Go(func() error { diff --git a/internal/agent/notify/notify.go b/internal/agent/notify/notify.go index 2ffb03203dd36f646cad6c544c717c400917b007..ac7f724c0f07f552d9759247821a2555c9e12524 100644 --- a/internal/agent/notify/notify.go +++ b/internal/agent/notify/notify.go @@ -21,3 +21,31 @@ type Notification struct { Type Type ProviderID string } + +// RunComplete is the authoritative end-of-run signal for a session. +// It is published exactly once per top-level agent run (per +// [sessionAgent.Run] invocation that actually executed) after all +// message updates for the turn have been flushed via +// message.Service.FlushAll. Carries the final assistant text and +// message ID so non-interactive clients can reconcile stdout even if +// SSE events arrive out of order or are dropped by the broker. Error +// is non-empty when the run terminated with an error; Cancelled is +// true when the run terminated due to context cancellation. The two +// are mutually exclusive in the success case but may overlap when a +// cancel triggers a downstream error. +// +// RunID identifies the specific request that produced this event. +// It is the value the caller set on `proto.AgentMessage.RunID` (or +// equivalently propagated via agent.WithRunID on the context that +// reaches the coordinator); empty when no caller set one. Filtering +// by RunID lets a client correlate a SendMessage call with its +// terminal event even when the session is busy and other turns are +// finishing on the same session. +type RunComplete struct { + SessionID string + RunID string + MessageID string + Text string + Error string + Cancelled bool +} diff --git a/internal/agent/run_complete_test.go b/internal/agent/run_complete_test.go new file mode 100644 index 0000000000000000000000000000000000000000..74f9232a0946b24d38f05873fa39066dcae40c27 --- /dev/null +++ b/internal/agent/run_complete_test.go @@ -0,0 +1,89 @@ +package agent + +import ( + "context" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +// TestSessionAgentRun_QueueStripsOnComplete verifies that when a Run +// call is enqueued (because the session is already busy), the +// OnComplete hook is NOT propagated onto the queued copy. The hook +// belongs to the caller's retry/coalesce scope (typically +// coordinator.Run) which has already returned by the time the queue +// drains; carrying it forward would silently funnel the terminal +// event into a closure nobody reads, and subscribers (`crush run`) +// would hang waiting for a RunComplete that never publishes. +func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) { + t.Parallel() + + env := testEnv(t) + a := NewSessionAgent(SessionAgentOptions{ + Sessions: env.sessions, + Messages: env.messages, + }).(*sessionAgent) + + const sessionID = "queued-session" + // Mark the session as busy so Run takes the queue branch + // without needing a real model. + a.activeRequests.Set(sessionID, func() {}) + + var called bool + hook := func(notify.RunComplete) { called = true } + + res, err := a.Run(t.Context(), SessionAgentCall{ + SessionID: sessionID, + RunID: "run-xyz", + Prompt: "queued prompt", + OnComplete: hook, + }) + require.NoError(t, err) + require.Nil(t, res, "queued Run must return (nil, nil)") + require.False(t, called, + "OnComplete must not fire on the enqueue path; the caller's scope is still live") + + queued, ok := a.messageQueue.Get(sessionID) + require.True(t, ok) + require.Len(t, queued, 1) + require.Nil(t, queued[0].OnComplete, + "queued SessionAgentCall must have OnComplete stripped so the drain falls back to the default broker publish") + require.Equal(t, "queued prompt", queued[0].Prompt, + "all other fields must be preserved on the queued copy") + require.Equal(t, "run-xyz", queued[0].RunID, + "RunID must be preserved on the queued copy so the drained turn's "+ + "RunComplete still correlates with the originating SendMessage") +} + +// TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the +// pubsub.Publisher interface change end-to-end: a Broker is the only +// concrete Publisher implementation and must satisfy both Publish and +// PublishMustDeliver. The coordinator's final RunComplete emit relies +// on PublishMustDeliver to apply bounded-blocking semantics so a +// momentarily-full subscriber buffer can't silently drop the +// authoritative end-of-run event. +func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) { + t.Parallel() + + broker := pubsub.NewBroker[notify.RunComplete]() + t.Cleanup(broker.Shutdown) + + // Subscribe before publishing so the event is delivered. + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + ch := broker.Subscribe(ctx) + + rc := notify.RunComplete{SessionID: "S", MessageID: "m", Text: "ok"} + var pub pubsub.Publisher[notify.RunComplete] = broker + pub.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, rc) + + select { + case got := <-ch: + require.Equal(t, rc, got.Payload) + case <-time.After(time.Second): + t.Fatal("PublishMustDeliver did not deliver event") + } +} diff --git a/internal/agent/runid.go b/internal/agent/runid.go new file mode 100644 index 0000000000000000000000000000000000000000..1afac005b9e4c627c9c06f21a2c565bab86e1c28 --- /dev/null +++ b/internal/agent/runid.go @@ -0,0 +1,33 @@ +package agent + +import "context" + +// runIDContextKey is the unexported context key used to carry a +// caller-supplied RunID from the workspace HTTP boundary +// (backend.SendMessage) down into coordinator.Run without forcing a +// breaking change to the Coordinator.Run signature. The value is +// then copied onto SessionAgentCall.RunID by the coordinator so the +// agent's terminal RunComplete event can echo it back to the +// originating caller. +type runIDContextKey struct{} + +// WithRunID returns ctx tagged with a per-request RunID. It is the +// boundary helper for callers that need their SendMessage→Run +// terminal event to be uniquely correlatable (e.g. `crush run` +// against a session that may be busy). Empty runIDs are stored +// as-is; downstream code treats an empty RunID as "caller did not +// supply one" and falls back to SessionID-only correlation. +func WithRunID(ctx context.Context, runID string) context.Context { + return context.WithValue(ctx, runIDContextKey{}, runID) +} + +// RunIDFromContext returns the RunID set by [WithRunID], or "" if +// none was set or the value is not a string. Exported because the +// coordinator and tests in other packages need to read it; safe to +// call on any context. +func RunIDFromContext(ctx context.Context) string { + if v, ok := ctx.Value(runIDContextKey{}).(string); ok { + return v + } + return "" +} diff --git a/internal/app/app.go b/internal/app/app.go index 8fd3c9d9749b112d66dff8bb236734e92ab2f843..9509fa3a9dc778d507d38f60c8ca523031b7ecb7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -75,6 +75,13 @@ type App struct { globalCtx context.Context cleanupFuncs []func(context.Context) error agentNotifications *pubsub.Broker[notify.Notification] + // runCompletions is the authoritative per-run completion signal, + // emitted once per top-level agent turn after all message + // updates have been flushed. Bridged into app.events so SSE + // subscribers (notably `crush run` in client/server mode) can + // drive their exit on a deterministic, payload-bearing event + // instead of guessing from message finish parts. + runCompletions *pubsub.Broker[notify.RunComplete] } // New initializes a new application instance. skillsMgr carries the @@ -110,6 +117,7 @@ func New(ctx context.Context, conn *sql.DB, store *config.ConfigStore, skillsMgr serviceEventsWG: &sync.WaitGroup{}, tuiWG: &sync.WaitGroup{}, agentNotifications: pubsub.NewBroker[notify.Notification](), + runCompletions: pubsub.NewBroker[notify.RunComplete](), } app.setupEvents() @@ -485,6 +493,7 @@ func (app *App) setupEvents() { setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events) setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "agent-notifications", app.agentNotifications.Subscribe, app.events) + setupSubscriberMustDeliver(ctx, app.serviceEventsWG, "run-completions", app.runCompletions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events) setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) if app.Skills != nil { @@ -524,6 +533,38 @@ func setupSubscriber[T any]( }) } +// setupSubscriberMustDeliver is the bounded-blocking fan-in variant of +// setupSubscriber: it re-publishes upstream events onto the shared +// app.events broker using PublishMustDeliver instead of Publish. Use +// this for terminal events that subscribers cannot tolerate losing — +// notably RunComplete, which is the authoritative end-of-run signal +// for `crush run`. A lossy fan-in here can drop the only terminal +// event and hang non-interactive clients waiting on it. +func setupSubscriberMustDeliver[T any]( + ctx context.Context, + wg *sync.WaitGroup, + name string, + subscriber func(context.Context) <-chan pubsub.Event[T], + broker *pubsub.Broker[tea.Msg], +) { + wg.Go(func() { + subCh := subscriber(ctx) + for { + select { + case event, ok := <-subCh: + if !ok { + slog.Debug("Subscription channel closed", "name", name) + return + } + broker.PublishMustDeliver(ctx, pubsub.UpdatedEvent, tea.Msg(event)) + case <-ctx.Done(): + slog.Debug("Subscription cancelled", "name", name) + return + } + } + }) +} + func (app *App) InitCoderAgent(ctx context.Context) error { coderAgentCfg := app.config.Config().Agents[config.AgentCoder] if coderAgentCfg.ID == "" { @@ -540,6 +581,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { app.FileTracker, app.LSPManager, app.agentNotifications, + app.runCompletions, app.Skills, ) if err != nil { diff --git a/internal/app/testing.go b/internal/app/testing.go index f17e94cfa99411b4594fce72bd894cc5fba4c4fd..1722e2b1544ebed0850941090eff34152567bbdd 100644 --- a/internal/app/testing.go +++ b/internal/app/testing.go @@ -34,6 +34,7 @@ func NewForTest(ctx context.Context) *App { serviceEventsWG: &sync.WaitGroup{}, tuiWG: &sync.WaitGroup{}, agentNotifications: pubsub.NewBroker[notify.Notification](), + runCompletions: pubsub.NewBroker[notify.RunComplete](), } eventsCtx, cancel := context.WithCancel(ctx) @@ -44,6 +45,8 @@ func NewForTest(ctx context.Context) *App { app.Permissions.SubscribeNotifications, app.events) setupSubscriber(eventsCtx, app.serviceEventsWG, "agent-notifications", app.agentNotifications.Subscribe, app.events) + setupSubscriber(eventsCtx, app.serviceEventsWG, "run-completions", + app.runCompletions.Subscribe, app.events) app.cleanupFuncs = append(app.cleanupFuncs, func(context.Context) error { cancel() app.serviceEventsWG.Wait() diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 346e54bf8069af5f0f8b3bbde049aaed70e86cff..78447ab7c64a82bb2638fb3fe184d0be132b4589 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -3,12 +3,20 @@ package backend import ( "context" + "github.com/charmbracelet/crush/internal/agent" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" ) // SendMessage sends a prompt to the agent coordinator for the given // workspace and session. +// +// When msg.RunID is non-empty it is attached to the context via +// agent.WithRunID so the coordinator can stamp the resulting +// SessionAgentCall (and therefore the terminal notify.RunComplete +// event) with that correlator. This is the only way for the +// originating client to distinguish its own turn's RunComplete from +// any concurrent turn that finishes on the same session. func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error { ws, err := b.GetWorkspace(workspaceID) if err != nil { @@ -19,6 +27,9 @@ func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto return ErrAgentNotInitialized } + if msg.RunID != "" { + ctx = agent.WithRunID(ctx, msg.RunID) + } _, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) return err } diff --git a/internal/client/proto.go b/internal/client/proto.go index f5f2f4273aba69526fdb66c6a1e0230c554be456..62a43b5884e01ae8fcd3242c68e95d1f76251c42 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -231,6 +231,12 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er if !sendEvent(ctx, events, e) { return } + case pubsub.PayloadTypeRunComplete: + var e pubsub.Event[proto.RunComplete] + _ = json.Unmarshal(p.Payload, &e) + if !sendEvent(ctx, events, e) { + return + } default: slog.Warn("Unknown event type", "type", p.Type) continue @@ -400,9 +406,16 @@ func (c *Client) UpdateAgent(ctx context.Context, id string) error { } // SendMessage sends a message to the agent for a workspace. -func (c *Client) SendMessage(ctx context.Context, id string, sessionID, prompt string, attachments ...message.Attachment) error { +// +// When runID is non-empty it is echoed back on the resulting +// proto.RunComplete event, giving the caller a unique correlator +// for completion detection. Pass "" when the caller does not need +// to distinguish its own turn's terminal event from any concurrent +// turn on the same session (e.g. interactive TUI usage). +func (c *Client) SendMessage(ctx context.Context, id string, sessionID, runID, prompt string, attachments ...message.Attachment) error { rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/agent", id), nil, jsonBody(proto.AgentMessage{ SessionID: sessionID, + RunID: runID, Prompt: prompt, Attachments: proto.AttachmentsFromMessage(attachments), }), http.Header{"Content-Type": []string{"application/json"}}) diff --git a/internal/cmd/run.go b/internal/cmd/run.go index d957720068c47b85038731d7c4ec6eb2dbe9c135..87fa32606674847741a9d028a26375fb98935fc4 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -3,6 +3,7 @@ package cmd import ( "context" "fmt" + "io" "log/slog" "os" "os/signal" @@ -24,6 +25,7 @@ import ( "github.com/charmbracelet/x/ansi" "github.com/charmbracelet/x/exp/charmtone" "github.com/charmbracelet/x/term" + "github.com/google/uuid" "github.com/spf13/cobra" ) @@ -243,12 +245,22 @@ func runNonInteractive( return fmt.Errorf("failed to subscribe to events: %w", err) } - if err := c.SendMessage(ctx, ws.ID, sess.ID, prompt); err != nil { + // Mint a per-call RunID so we can correlate the terminal + // RunComplete with *this* SendMessage even if the session was + // busy and another turn finished first. Without it the stream + // loop would exit on whichever RunComplete arrived first for + // the same session and drop the queued prompt's output. + runID := uuid.New().String() + if err := c.SendMessage(ctx, ws.ID, sess.ID, runID, prompt); err != nil { return fmt.Errorf("failed to send message: %w", err) } - messageReadBytes := make(map[string]int) - var printed bool + stream := &runStream{ + sessionID: sess.ID, + runID: runID, + out: os.Stdout, + read: make(map[string]int), + } defer func() { if progress && stderrTTY { @@ -269,49 +281,141 @@ func runNonInteractive( return nil } - switch e := ev.(type) { - case pubsub.Event[proto.Message]: - msg := e.Payload - if msg.SessionID != sess.ID || msg.Role != proto.Assistant || len(msg.Parts) == 0 { - continue - } - stopSpinner() + done, err := stream.handle(ev, stopSpinner) + if err != nil { + return err + } + if done { + return nil + } - content := msg.Content().String() - readBytes := messageReadBytes[msg.ID] + case <-ctx.Done(): + stopSpinner() + return ctx.Err() + } + } +} - if len(content) < readBytes { - slog.Error("Non-interactive: message content shorter than read bytes", - "message_length", len(content), "read_bytes", readBytes) - return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes) - } +// runStream tracks the per-message stdout cursor and the +// reconciliation state used by [runNonInteractive] to translate +// streaming SSE events into a final, complete stdout for `crush run`. +// It is split out so the state machine can be exercised in unit tests +// without spinning up the full server/client harness. +// +// runID, when non-empty, is the authoritative correlator for the +// terminal RunComplete event: the stream suppresses live message +// events and only exits on a RunComplete whose RunID matches, so a +// turn that finishes first on the same session (e.g. when our prompt +// was queued behind a busy session) cannot contaminate stdout or +// terminate us prematurely. When empty (older servers, tests that +// don't supply one) the stream falls back to SessionID-only matching +// and live message streaming, which is still correct for the +// single-turn case. +type runStream struct { + sessionID string + runID string + out io.Writer + read map[string]int + printed bool +} - part := content[readBytes:] - if readBytes == 0 { - part = strings.TrimLeft(part, " \t") - } - if printed || strings.TrimSpace(part) != "" { - printed = true - fmt.Fprint(os.Stdout, part) - } - messageReadBytes[msg.ID] = len(content) +// handle processes one SSE event. Returns done=true when the run +// loop should exit (RunComplete observed); returns an error only +// when the agent run failed (not on context cancel — that path is +// handled by the caller's select). stopSpinner is called on the +// first observable assistant output and on completion; passing nil +// is safe for tests. +func (s *runStream) handle(ev any, stopSpinner func()) (done bool, err error) { + stop := func() { + if stopSpinner != nil { + stopSpinner() + } + } + switch e := ev.(type) { + case pubsub.Event[proto.Message]: + msg := e.Payload + if msg.SessionID != s.sessionID || msg.Role != proto.Assistant || len(msg.Parts) == 0 { + return false, nil + } + if s.runID != "" { + return false, nil + } + stop() + + content := msg.Content().String() + readBytes := s.read[msg.ID] + if len(content) < readBytes { + slog.Error("Non-interactive: message content shorter than read bytes", + "message_length", len(content), "read_bytes", readBytes) + return false, fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes) + } - if msg.IsFinished() { - return nil + part := content[readBytes:] + if readBytes == 0 { + part = strings.TrimLeft(part, " \t") + } + if s.printed || strings.TrimSpace(part) != "" { + s.printed = true + fmt.Fprint(s.out, part) + } + s.read[msg.ID] = len(content) + return false, nil + + case pubsub.Event[proto.RunComplete]: + // RunComplete is the authoritative end-of-run signal. We + // exit on it instead of guessing from message finish parts, + // which fire on every tool-call step too and were the + // source of the regression where `crush run` exited + // mid-turn on finish.reason == tool_use. + // + // Correlation: + // - if we minted a RunID for this SendMessage, only the + // event whose RunID matches is ours; any other turn + // finishing first on the same session (busy-session + // queue path) must be ignored. + // - if we have no RunID (older server, tests), fall back + // to SessionID matching. + if s.runID != "" { + if e.Payload.RunID != s.runID { + return false, nil + } + } else if e.Payload.SessionID != s.sessionID { + return false, nil + } + stop() + if e.Payload.Error != "" && !e.Payload.Cancelled { + return true, fmt.Errorf("agent run failed: %s", e.Payload.Error) + } + // Reconcile stdout against the authoritative final + // assistant text carried in the event. The pubsub fan-in + // does not serialize publishes across upstream brokers, so + // the final message event may not have reached this loop + // yet; the embedded Text field is the backstop that + // guarantees the full final text always appears on stdout. + if e.Payload.MessageID != "" { + full := e.Payload.Text + readBytes := s.read[e.Payload.MessageID] + if readBytes < len(full) { + tail := full[readBytes:] + if readBytes == 0 { + tail = strings.TrimLeft(tail, " \t") } - - case pubsub.Event[proto.AgentEvent]: - if e.Payload.Error != nil { - stopSpinner() - return fmt.Errorf("agent error: %w", e.Payload.Error) + if s.printed || strings.TrimSpace(tail) != "" { + s.printed = true + fmt.Fprint(s.out, tail) } } + } + return true, nil - case <-ctx.Done(): - stopSpinner() - return ctx.Err() + case pubsub.Event[proto.AgentEvent]: + if e.Payload.Error != nil { + stop() + return true, fmt.Errorf("agent error: %w", e.Payload.Error) } + return false, nil } + return false, nil } // waitForAgent polls GetAgentInfo until the agent is ready, with a diff --git a/internal/cmd/run_stream_test.go b/internal/cmd/run_stream_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ac168fa77045aa6aa5761b6f9c657f066c952734 --- /dev/null +++ b/internal/cmd/run_stream_test.go @@ -0,0 +1,329 @@ +package cmd + +import ( + "bytes" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/stretchr/testify/require" +) + +// TestRunStream_ToolUseDoesNotTerminate is the regression test for +// the original bug: a tool-call assistant message has a Finish part +// with reason=tool_use and used to terminate `crush run` early via +// the discarded `msg.IsFinished()` exit condition. With the new +// RunComplete-driven loop, tool_use finishes must keep the stream +// alive so the post-tool final text still reaches stdout. +func TestRunStream_ToolUseDoesNotTerminate(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + + toolUse := proto.Message{ + ID: "m1", + SessionID: "S", + Role: proto.Assistant, + Parts: []proto.ContentPart{ + proto.TextContent{Text: ""}, + proto.Finish{Reason: proto.FinishReasonToolUse, Time: time.Now().Unix()}, + }, + } + done, err := s.handle(pubsub.Event[proto.Message]{Payload: toolUse}, nil) + require.NoError(t, err) + require.False(t, done, "tool_use finish must NOT terminate the run loop") +} + +// TestRunStream_RunCompleteExits verifies the happy path: streaming +// assistant text then RunComplete terminates with the full final +// text on stdout. Together with the tool_use test above this +// nails down the "tool use + final text" sequence that the original +// bug truncated. +func TestRunStream_RunCompleteExits(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{sessionID: "S", out: buf, read: map[string]int{}} + + // Tool-use step. + done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "m1", SessionID: "S", Role: proto.Assistant, + Parts: []proto.ContentPart{ + proto.TextContent{Text: ""}, + proto.Finish{Reason: proto.FinishReasonToolUse}, + }, + }}, nil) + require.NoError(t, err) + require.False(t, done) + + // Final assistant message stream. + done, err = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "m2", SessionID: "S", Role: proto.Assistant, + Parts: []proto.ContentPart{ + proto.TextContent{Text: "VERDICT: APPROVED"}, + proto.Finish{Reason: proto.FinishReasonEndTurn}, + }, + }}, nil) + require.NoError(t, err) + require.False(t, done, "message finish (even end_turn) must not exit; RunComplete is the only terminal signal") + + // RunComplete. + done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + MessageID: "m2", + Text: "VERDICT: APPROVED", + }}, nil) + require.NoError(t, err) + require.True(t, done) + require.Equal(t, "VERDICT: APPROVED", buf.String()) +} + +// TestRunStream_ReconcilesOnOutOfOrderRunComplete is the worst-case +// ordering scenario: RunComplete reaches the client BEFORE any of +// the streaming assistant message events for the turn (the pubsub +// fan-in across upstream brokers does not preserve cross-broker +// ordering). The embedded Text field must rescue stdout so the +// caller still sees the complete final text. +func TestRunStream_ReconcilesOnOutOfOrderRunComplete(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{sessionID: "S", out: buf, read: map[string]int{}} + + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + MessageID: "m2", + Text: "VERDICT: APPROVED", + }}, nil) + require.NoError(t, err) + require.True(t, done) + require.Equal(t, "VERDICT: APPROVED", buf.String(), + "RunComplete must reconcile stdout when message events did not arrive in time") +} + +// TestRunStream_ReconcilesPartialStream covers the realistic case +// where some streaming output reached stdout before RunComplete +// arrived: the reconciliation pass must append only the unread tail, +// never duplicate the prefix. +func TestRunStream_ReconcilesPartialStream(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{sessionID: "S", out: buf, read: map[string]int{}} + + _, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "m2", SessionID: "S", Role: proto.Assistant, + Parts: []proto.ContentPart{proto.TextContent{Text: "VERDICT: "}}, + }}, nil) + require.NoError(t, err) + + _, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + MessageID: "m2", + Text: "VERDICT: APPROVED", + }}, nil) + require.NoError(t, err) + require.Equal(t, "VERDICT: APPROVED", buf.String()) +} + +// TestRunStream_IgnoresOtherSessions ensures multi-session +// subscribers (e.g. a TUI watching workspace events while `crush +// run` is in flight against the same workspace) do not cause +// premature exit on RunComplete for a different session. +func TestRunStream_IgnoresOtherSessions(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "OTHER", + MessageID: "x", + Text: "noise", + }}, nil) + require.NoError(t, err) + require.False(t, done) +} + +// TestRunStream_ErrorRunComplete surfaces a failing run as a +// non-nil error from `crush run` so shells and CI catch it via +// exit status. +func TestRunStream_ErrorRunComplete(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + Error: "model temporarily unavailable", + }}, nil) + require.True(t, done) + require.Error(t, err) + require.Contains(t, err.Error(), "model temporarily unavailable") +} + +// TestRunStream_CancelledRunCompleteIsClean ensures a cancelled +// run (e.g. Ctrl+C while `crush run` waits) exits cleanly rather +// than reporting the cancellation as a failure. +func TestRunStream_CancelledRunCompleteIsClean(t *testing.T) { + t.Parallel() + + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + Error: "context canceled", + Cancelled: true, + }}, nil) + require.True(t, done) + require.NoError(t, err) +} + +// TestRunStream_LeadingWhitespaceTrimmedOnce mirrors the +// pre-existing trim of leading whitespace on the first byte of +// stdout: the trim must happen exactly once even when stdout is +// first produced by the RunComplete reconciliation path rather +// than the live stream. +func TestRunStream_LeadingWhitespaceTrimmedOnce(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{sessionID: "S", out: buf, read: map[string]int{}} + + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + MessageID: "m2", + Text: " \tactual output", + }}, nil) + require.NoError(t, err) + require.True(t, done) + require.Equal(t, "actual output", buf.String()) +} + +// TestRunStream_StopSpinnerInvokedOnFirstOutput verifies the +// spinner is stopped exactly when meaningful output starts (either +// a streamed assistant message or the reconciliation tail). This +// matches the prior behaviour and prevents the spinner from +// painting over the final response on TTYs. +func TestRunStream_StopSpinnerInvokedOnFirstOutput(t *testing.T) { + t.Parallel() + + calls := 0 + stop := func() { calls++ } + s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}} + _, _ = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "m1", SessionID: "S", Role: proto.Assistant, + Parts: []proto.ContentPart{proto.TextContent{Text: "hi"}}, + }}, stop) + require.GreaterOrEqual(t, calls, 1, "spinner must stop once stdout has content") +} + +// TestRunStream_RunIDFiltersForeignTurns covers the busy-session +// queue scenario: `crush run --continue` attaches to a session +// whose currently running turn finishes first, publishing its +// RunComplete on the same session ID. Without per-run correlation +// the stream would exit on that foreign event and drop our own +// queued turn's output. With RunID filtering the foreign event is +// ignored and only the matching RunComplete terminates the stream. +func TestRunStream_RunIDFiltersForeignTurns(t *testing.T) { + t.Parallel() + + const sessionID = "S" + const myRun = "run-mine" + const otherRun = "run-other" + + buf := &bytes.Buffer{} + s := &runStream{ + sessionID: sessionID, + runID: myRun, + out: buf, + read: map[string]int{}, + } + + // The busy session's existing turn emits more text before it finishes. + done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "other-msg", + SessionID: sessionID, + Role: proto.Assistant, + Parts: []proto.ContentPart{proto.TextContent{Text: "noise from another turn"}}, + }}, nil) + require.NoError(t, err) + require.False(t, done, + "foreign message on same session must not terminate our run") + require.Empty(t, buf.String(), + "foreign message on same session must not write to our stdout") + + // The busy session's existing turn finishes first. + done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: sessionID, + RunID: otherRun, + MessageID: "other-msg", + Text: "noise from another turn", + }}, nil) + require.NoError(t, err) + require.False(t, done, + "foreign RunComplete on same session must not terminate our run") + require.Empty(t, buf.String(), + "foreign RunComplete must not write to our stdout") + + // Our own queued turn eventually finishes. + done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: sessionID, + RunID: myRun, + MessageID: "my-msg", + Text: "OK", + }}, nil) + require.NoError(t, err) + require.True(t, done, "matching RunID must terminate the stream") + require.Equal(t, "OK", buf.String()) +} + +func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{ + sessionID: "S", + runID: "run-mine", + out: buf, + read: map[string]int{}, + } + + done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{ + ID: "my-msg", + SessionID: "S", + Role: proto.Assistant, + Parts: []proto.ContentPart{proto.TextContent{Text: "streamed prefix"}}, + }}, nil) + require.NoError(t, err) + require.False(t, done) + require.Empty(t, buf.String()) + + done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + RunID: "run-mine", + MessageID: "my-msg", + Text: "streamed prefix final", + }}, nil) + require.NoError(t, err) + require.True(t, done) + require.Equal(t, "streamed prefix final", buf.String()) +} + +// TestRunStream_NoRunIDFallsBackToSessionID preserves the older +// behaviour for callers (and tests) that don't supply a RunID: +// SessionID-only matching still terminates the stream on the +// session's RunComplete. This keeps the contract backwards +// compatible with servers that don't echo RunID and with the +// pre-existing TestRunStream_* assertions. +func TestRunStream_NoRunIDFallsBackToSessionID(t *testing.T) { + t.Parallel() + + buf := &bytes.Buffer{} + s := &runStream{sessionID: "S", out: buf, read: map[string]int{}} + done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{ + SessionID: "S", + MessageID: "m2", + Text: "DONE", + }}, nil) + require.NoError(t, err) + require.True(t, done) + require.Equal(t, "DONE", buf.String()) +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 739d8ddd9ef34c40f4b1d8ca25ddc20cd8a9f581..3e37f61def9cd15ea4884ca6535fb62af82431e8 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -46,6 +46,31 @@ type CurrentSession struct { SessionID string `json:"session_id"` } +// RunComplete is the authoritative end-of-run signal for a session, +// emitted exactly once per top-level agent turn after all message +// updates for the turn have flushed. Clients that need a reliable +// completion contract (notably `crush run` in client/server mode) +// should listen for this event filtered by RunID (preferred) — or +// by SessionID when no RunID was supplied — and use Text and +// MessageID to reconcile any output they have already streamed from +// earlier message events. Error is non-empty when the run terminated +// with an error; Cancelled is true when terminated due to context +// cancellation. +// +// RunID echoes the value the caller set on AgentMessage.RunID. It is +// the only safe correlator when the caller's prompt was queued +// behind a busy session: another turn's RunComplete for the same +// SessionID may arrive first, and filtering by SessionID alone +// would terminate the caller before its own turn ran. +type RunComplete struct { + SessionID string `json:"session_id"` + RunID string `json:"run_id,omitempty"` + MessageID string `json:"message_id"` + Text string `json:"text,omitempty"` + Error string `json:"error,omitempty"` + Cancelled bool `json:"cancelled,omitempty"` +} + // SkillInfo describes a visible skill exposed to a frontend. type SkillInfo struct { ID string `json:"id"` @@ -89,8 +114,22 @@ func (a AgentInfo) IsZero() bool { } // AgentMessage represents a message sent to the agent. +// +// RunID, when non-empty, is echoed back on the [RunComplete] event +// emitted for the resulting turn. Callers that need to correlate a +// specific SendMessage with its terminal event (notably +// `crush run`, which may attach to a busy session whose currently +// running turn finishes first) should set it to a fresh unique +// value before the request. Server-side propagation flows through +// agent.WithRunID on the request context into the +// SessionAgentCall; it is preserved across the busy-session queue. +// When empty the resulting RunComplete carries an empty RunID and +// callers must fall back to SessionID-only filtering, which +// remains correct only when no other turns are in flight for the +// same session. type AgentMessage struct { SessionID string `json:"session_id"` + RunID string `json:"run_id,omitempty"` Prompt string `json:"prompt"` Attachments []Attachment `json:"attachments,omitempty"` } diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index 7f75d7d19e39f2a714fccc5be0232a19fadab7b9..682672dfb2730718565fa195b9f5ef56005773d4 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -26,6 +26,7 @@ const ( PayloadTypeAgentEvent PayloadType = "agent_event" PayloadTypeConfigChanged PayloadType = "config_changed" PayloadTypeSkillsEvent PayloadType = "skills_event" + PayloadTypeRunComplete PayloadType = "run_complete" ) // Payload wraps a discriminated JSON payload with a type tag. @@ -50,7 +51,14 @@ type ( } // Publisher can publish events of type T. + // + // Publish is best-effort and lossy under back-pressure; + // PublishMustDeliver applies the bounded-blocking semantics used + // for terminal events that must reach subscribers (finish, tool + // result, error, cancel, RunComplete). See [Broker.Publish] and + // [Broker.PublishMustDeliver]. Publisher[T any] interface { Publish(EventType, T) + PublishMustDeliver(context.Context, EventType, T) } ) diff --git a/internal/server/events.go b/internal/server/events.go index f38619c52528679bf75675780eb4bb47961bd640..4e3d6a1a262ae7b3399f0fa765ae863160ba57c8 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -93,6 +93,18 @@ func wrapEvent(ev any) *pubsub.Payload { Type: proto.AgentEventType(e.Payload.Type), }, }) + case pubsub.Event[notify.RunComplete]: + return envelope(pubsub.PayloadTypeRunComplete, pubsub.Event[proto.RunComplete]{ + Type: e.Type, + Payload: proto.RunComplete{ + SessionID: e.Payload.SessionID, + RunID: e.Payload.RunID, + MessageID: e.Payload.MessageID, + Text: e.Payload.Text, + Error: e.Payload.Error, + Cancelled: e.Payload.Cancelled, + }, + }) case pubsub.Event[proto.ConfigChanged]: return envelope(pubsub.PayloadTypeConfigChanged, e) case pubsub.Event[skills.Event]: diff --git a/internal/server/events_test.go b/internal/server/events_test.go index b32d694d793e04f216a51035098489588aa39628..432bc42f910b4acec675baea46754b81defab9f6 100644 --- a/internal/server/events_test.go +++ b/internal/server/events_test.go @@ -5,6 +5,7 @@ import ( "errors" "testing" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" @@ -83,3 +84,67 @@ func TestSkillsEventToProto_RoundTrip(t *testing.T) { require.Equal(t, proto.SkillStateError, decoded.Payload.States[1].State) require.Equal(t, "bad frontmatter", decoded.Payload.States[1].Error) } + +// TestRunCompleteToProto_RoundTrip verifies that the authoritative +// per-run completion event survives the SSE envelope conversion with +// all reconciliation fields intact. SessionID, MessageID, and Text +// are what non-interactive clients (e.g. `crush run`) rely on to +// terminate the run loop and guarantee final text on stdout when +// message events arrive out of order. +func TestRunCompleteToProto_RoundTrip(t *testing.T) { + t.Parallel() + + src := pubsub.Event[notify.RunComplete]{ + Type: pubsub.UpdatedEvent, + Payload: notify.RunComplete{ + SessionID: "S", + RunID: "run-42", + MessageID: "M", + Text: "VERDICT: APPROVED", + Error: "", + Cancelled: false, + }, + } + + env := wrapEvent(src) + require.NotNil(t, env) + require.Equal(t, pubsub.PayloadTypeRunComplete, env.Type) + + var decoded pubsub.Event[proto.RunComplete] + require.NoError(t, json.Unmarshal(env.Payload, &decoded)) + require.Equal(t, pubsub.UpdatedEvent, decoded.Type) + require.Equal(t, "S", decoded.Payload.SessionID) + require.Equal(t, "run-42", decoded.Payload.RunID, + "RunID must survive the SSE envelope so clients can correlate "+ + "this event with the SendMessage call that produced it") + require.Equal(t, "M", decoded.Payload.MessageID) + require.Equal(t, "VERDICT: APPROVED", decoded.Payload.Text) + require.Empty(t, decoded.Payload.Error) + require.False(t, decoded.Payload.Cancelled) +} + +// TestRunCompleteToProto_Error verifies that error- and cancel-shaped +// RunComplete events round-trip cleanly so clients can distinguish +// "agent failed" (returns non-zero from `crush run`) from "agent +// cancelled by user" (clean exit). +func TestRunCompleteToProto_Error(t *testing.T) { + t.Parallel() + + src := pubsub.Event[notify.RunComplete]{ + Type: pubsub.UpdatedEvent, + Payload: notify.RunComplete{ + SessionID: "S", + MessageID: "M", + Text: "partial", + Error: "context canceled", + Cancelled: true, + }, + } + + env := wrapEvent(src) + require.NotNil(t, env) + var decoded pubsub.Event[proto.RunComplete] + require.NoError(t, json.Unmarshal(env.Payload, &decoded)) + require.Equal(t, "context canceled", decoded.Payload.Error) + require.True(t, decoded.Payload.Cancelled) +} diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index f4bd4ba35a6eb26db98420215f2a6282ebac0f9f..2018fab8a7dcc2cb3aeb0f44fc0920c1db72d852 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -179,7 +179,11 @@ func (w *ClientWorkspace) ListAllUserMessages(ctx context.Context) ([]message.Me // -- Agent -- func (w *ClientWorkspace) AgentRun(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) error { - return w.client.SendMessage(ctx, w.workspaceID(), sessionID, prompt, attachments...) + // The interactive TUI does not consume notify.RunComplete for + // completion detection (it observes message events directly), + // so passing an empty RunID is correct here: it skips the + // correlator stamping path without functional consequences. + return w.client.SendMessage(ctx, w.workspaceID(), sessionID, "", prompt, attachments...) } func (w *ClientWorkspace) AgentCancel(sessionID string) { @@ -707,6 +711,24 @@ func (w *ClientWorkspace) translateEvent(ev any) tea.Msg { Type: notify.Type(e.Payload.Type), }, } + case pubsub.Event[proto.RunComplete]: + // Translate the wire-level proto.RunComplete back into the + // agent's domain notify.RunComplete. Without this case the + // default branch below warns on every run completion in the + // server-mode TUI, even though the TUI itself doesn't act + // on RunComplete — converting silently keeps the workspace + // event bridge symmetric with the server-side wrapEvent. + return pubsub.Event[notify.RunComplete]{ + Type: e.Type, + Payload: notify.RunComplete{ + SessionID: e.Payload.SessionID, + RunID: e.Payload.RunID, + MessageID: e.Payload.MessageID, + Text: e.Payload.Text, + Error: e.Payload.Error, + Cancelled: e.Payload.Cancelled, + }, + } case pubsub.Event[proto.SkillsEvent]: states := protoToSkillStates(e.Payload.States) if w.skills != nil {