From 8b242c5313a79630d22429c04b699adcb624a89a Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 31 May 2026 23:06:41 -0400 Subject: [PATCH] chore(server): run accepted server prompts in the background Accept validated prompts quickly, then run them on workspace-owned background work. Failures are published through events while normal cancellation continues to be represented by the canceled assistant turn. Co-Authored-By: Charm Crush --- internal/backend/agent.go | 78 +++++++++++-- internal/backend/agent_test.go | 163 +++++++++++++++++++++++++++ internal/backend/testing.go | 14 +++ internal/server/agent_cancel_test.go | 89 ++++++--------- internal/server/proto.go | 17 +-- 5 files changed, 287 insertions(+), 74 deletions(-) create mode 100644 internal/backend/agent_test.go diff --git a/internal/backend/agent.go b/internal/backend/agent.go index 78447ab7c64a82bb2638fb3fe184d0be132b4589..2dd0479d3236d55e3919bdef1f16bb593fe5684e 100644 --- a/internal/backend/agent.go +++ b/internal/backend/agent.go @@ -2,22 +2,30 @@ package backend import ( "context" + "errors" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" ) -// SendMessage sends a prompt to the agent coordinator for the given -// workspace and session. +// SendMessage validates and accepts a prompt for the workspace's agent, +// then dispatches the run on a goroutine bound to the workspace context +// and returns immediately. It does not wait for the LLM turn to +// complete: the run's lifetime is owned by the workspace, not by the +// caller. Errors from the dispatched run reach observers through the +// agent event channels (a notify.TypeAgentError notification), not +// through this return value. // -// When msg.RunID is non-empty it is attached to the context via -// agent.WithRunID so the coordinator can stamp the resulting -// SessionAgentCall (and therefore the terminal notify.RunComplete -// event) with that correlator. This is the only way for the -// originating client to distinguish its own turn's RunComplete from -// any concurrent turn that finishes on the same session. -func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error { +// SendMessage returns synchronously when the request cannot be accepted: +// ErrWorkspaceNotFound if the workspace is missing, ErrAgentNotInitialized +// if its coordinator is nil, the structural validation errors from +// agent.ValidateCall (ErrEmptyPrompt, ErrSessionMissing) when the prompt +// or session is missing, and ErrWorkspaceClosing if the workspace is +// being torn down. +func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error { ws, err := b.GetWorkspace(workspaceID) if err != nil { return err @@ -27,11 +35,59 @@ func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto return ErrAgentNotInitialized } + if err := agent.ValidateCall(agent.SessionAgentCall{ + SessionID: msg.SessionID, + Prompt: msg.Prompt, + Attachments: proto.AttachmentsToMessage(msg.Attachments), + }); err != nil { + return err + } + + accept := ws.AgentCoordinator.BeginAccepted(msg.SessionID) + + ws.runMu.Lock() + if ws.closing { + ws.runMu.Unlock() + accept.Close() + return ErrWorkspaceClosing + } + ws.runWG.Add(1) + ws.runMu.Unlock() + + go b.runAgent(ws, msg, accept) + return nil +} + +// runAgent executes an accepted agent run for the workspace. It owns the +// accept reservation (releasing it on return) and the runWG ticket added +// by SendMessage. The run is bound to the workspace context so its +// lifetime is independent of any client's HTTP request. On a non-cancel +// error it surfaces the failure to observers via a notify.TypeAgentError +// notification; context.Canceled is expected (the FinishReasonCanceled +// marker is already published by sessionAgent.Run) and swallowed. +// +// When msg.RunID is non-empty it is attached to the context via +// agent.WithRunID so the coordinator can stamp the terminal +// notify.RunComplete event with that correlator. +func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) { + defer ws.runWG.Done() + defer accept.Close() + + ctx := ws.ctx if msg.RunID != "" { ctx = agent.WithRunID(ctx, msg.RunID) } - _, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) - return err + + _, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...) + if err == nil || errors.Is(err, context.Canceled) { + return + } + + ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{ + SessionID: msg.SessionID, + Type: notify.TypeAgentError, + Message: err.Error(), + }) } // GetAgentInfo returns the agent's model and busy status. diff --git a/internal/backend/agent_test.go b/internal/backend/agent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..5d9365ecc899236b91d73198ab322bcabbf9cc77 --- /dev/null +++ b/internal/backend/agent_test.go @@ -0,0 +1,163 @@ +package backend + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted +// blocks until release is closed. It records that RunAccepted was +// entered so tests can observe the dispatched goroutine. Every other +// method returns a zero value. +type blockingCoordinator struct { + entered chan struct{} + release chan struct{} + runCount atomic.Int32 +} + +func newBlockingCoordinator() *blockingCoordinator { + return &blockingCoordinator{ + entered: make(chan struct{}, 1), + release: make(chan struct{}), + } +} + +func (c *blockingCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} + +func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + c.runCount.Add(1) + select { + case c.entered <- struct{}{}: + default: + } + <-c.release + return nil, nil +} + +func (c *blockingCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil } +func (c *blockingCoordinator) Cancel(string) {} +func (c *blockingCoordinator) CancelAll() {} +func (c *blockingCoordinator) IsBusy() bool { return false } +func (c *blockingCoordinator) IsSessionBusy(string) bool { return false } +func (c *blockingCoordinator) QueuedPrompts(string) int { return 0 } +func (c *blockingCoordinator) QueuedPromptsList(string) []string { return nil } +func (c *blockingCoordinator) ClearQueue(string) {} +func (c *blockingCoordinator) Summarize(context.Context, string) error { return nil } +func (c *blockingCoordinator) Model() agent.Model { return agent.Model{} } +func (c *blockingCoordinator) UpdateModels(context.Context) error { return nil } + +// insertAgentWorkspace installs a synthetic workspace with the given +// coordinator (or none) and a workspace run context, mirroring the +// fields CreateWorkspace initializes. +func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace { + t.Helper() + ws := &Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + resolvedPath: t.TempDir(), + clients: make(map[string]*clientState), + shutdownFn: func() {}, + } + ws.App = &app.App{AgentCoordinator: coord} + ws.ctx, ws.cancel = context.WithCancel(b.ctx) + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[ws.resolvedPath] = ws.ID + b.mu.Unlock() + return ws +} + +func TestSendMessage_WorkspaceNotFound(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + err := b.SendMessage("nope", proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} + +func TestSendMessage_AgentNotInitialized(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, nil) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrAgentNotInitialized) +} + +func TestSendMessage_EmptyPrompt(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: ""}) + require.ErrorIs(t, err, agent.ErrEmptyPrompt) +} + +func TestSendMessage_SessionMissing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "", Prompt: "hi"}) + require.ErrorIs(t, err, agent.ErrSessionMissing) +} + +func TestSendMessage_WorkspaceClosing(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + ws := insertAgentWorkspace(t, b, newBlockingCoordinator()) + ws.runMu.Lock() + ws.closing = true + ws.runMu.Unlock() + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.ErrorIs(t, err, ErrWorkspaceClosing) +} + +// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns +// nil synchronously and dispatches a tracked goroutine: while +// RunAccepted blocks, runWG.Wait must not complete (the ticket is +// outstanding); after release it drains. +func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) { + t.Parallel() + b, _ := newTestBackend(t) + coord := newBlockingCoordinator() + ws := insertAgentWorkspace(t, b, coord) + + err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"}) + require.NoError(t, err) + + select { + case <-coord.entered: + case <-time.After(2 * time.Second): + t.Fatal("dispatched goroutine never entered RunAccepted") + } + require.Equal(t, int32(1), coord.runCount.Load()) + + waited := make(chan struct{}) + go func() { + ws.runWG.Wait() + close(waited) + }() + + select { + case <-waited: + t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added") + case <-time.After(100 * time.Millisecond): + } + + close(coord.release) + + select { + case <-waited: + case <-time.After(2 * time.Second): + t.Fatal("runWG.Wait did not complete after the run returned") + } +} diff --git a/internal/backend/testing.go b/internal/backend/testing.go index 6616e0f19e06595fac68808b484394d960d7f79f..1c71caed6566747ac947d61c30ece575d3d13eb4 100644 --- a/internal/backend/testing.go +++ b/internal/backend/testing.go @@ -1,9 +1,16 @@ package backend +import "context" + // InsertWorkspaceForTest registers ws with b under its current ID and // path. It is intended for tests in other packages that need to drive // HTTP handlers against a synthetic workspace without booting a real // app.App. Production code should go through CreateWorkspace. +// +// If the workspace has no run context yet it is derived from the +// backend context (falling back to context.Background), mirroring the +// initialization CreateWorkspace performs, so dispatched agent runs +// have a non-nil ws.ctx. func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.resolvedPath == "" { ws.resolvedPath = ws.Path @@ -11,6 +18,13 @@ func InsertWorkspaceForTest(b *Backend, ws *Workspace) { if ws.clients == nil { ws.clients = make(map[string]*clientState) } + if ws.ctx == nil { + parent := b.ctx + if parent == nil { + parent = context.Background() + } + ws.ctx, ws.cancel = context.WithCancel(parent) + } b.mu.Lock() defer b.mu.Unlock() b.workspaces.Set(ws.ID, ws) diff --git a/internal/server/agent_cancel_test.go b/internal/server/agent_cancel_test.go index 18ea8046d0647d3576a50769b7b2146d9aa103e9..6697bdd210b911862b6795ef1c5e5009e486ad64 100644 --- a/internal/server/agent_cancel_test.go +++ b/internal/server/agent_cancel_test.go @@ -122,11 +122,12 @@ func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, session } // TestPostAgent_ReturnsOKOnContextCanceled verifies that when another -// client cancels the session mid-turn, the prompting client's still -// open POST receives 200 (not 500). The agent surfaces the -// FinishReasonCanceled marker to every SSE subscriber via the -// assistant message; the HTTP response from the prompter should not -// double as an error signal. +// client cancels the session mid-turn, the prompting client's POST is +// unaffected: SendMessage is fire-and-forget, so the handler returns +// 200 immediately without waiting for the turn. A run that later +// returns context.Canceled never surfaces as a 500 to the prompter; +// the FinishReasonCanceled marker reaches SSE subscribers via the +// assistant message instead. func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { t.Parallel() @@ -135,33 +136,26 @@ func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) { }) c, wsID := buildAgentWorkspace(t, coord) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, t.Context(), wsID, "S1") - }() + // The handler returns immediately, before the dispatched run is + // released, because the run no longer owns the HTTP response. + rec := postAgent(t, c, t.Context(), wsID, "S1") + require.Equal(t, http.StatusOK, rec.Code, "fire-and-forget SendMessage must return 200 without waiting for the run") - // Wait until Run is in flight, then release it to return - // context.Canceled. + // The run is dispatched on a goroutine; let it return + // context.Canceled. Nothing from that path reaches the (already + // returned) handler. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } close(coord.release) - - select { - case rec := <-done: - require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500") - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after coordinator returned context.Canceled") - } } -// TestPostAgent_DetachesRequestContext verifies that canceling the -// prompting client's HTTP request context does not cancel the -// in-flight agent run. The coordinator must observe a context whose -// Done channel never fires from the request side; only the explicit -// cancel endpoint may end the run. +// TestPostAgent_DetachesRequestContext verifies that the dispatched run +// is bound to the workspace context, not the prompting client's HTTP +// request context. Canceling the request context must neither cancel +// the run nor be observed by the coordinator. func TestPostAgent_DetachesRequestContext(t *testing.T) { t.Parallel() @@ -172,46 +166,31 @@ func TestPostAgent_DetachesRequestContext(t *testing.T) { reqCtx, cancelReq := context.WithCancel(context.Background()) - done := make(chan *httptest.ResponseRecorder, 1) - go func() { - done <- postAgent(t, c, reqCtx, wsID, "S1") - }() + // The handler returns immediately; the run keeps executing on its + // own goroutine bound to the workspace context. + rec := postAgent(t, c, reqCtx, wsID, "S1") + require.Equal(t, http.StatusOK, rec.Code) - // Wait until Run is in flight, then drop the prompting client. select { case <-coord.entered: case <-time.After(2 * time.Second): - t.Fatal("coordinator Run was never entered") + t.Fatal("dispatched run was never entered") } + + // Drop the prompting client. This must not reach the run. cancelReq() - // The captured ctx must be detached: context.WithoutCancel - // returns a ctx with Done() == nil so request cancellation cannot - // propagate. got := coord.capturedCtx() require.NotNil(t, got) - require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel") - require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request") + // Compare by identity (pointer), not reflect.DeepEqual: deep + // comparison would traverse context internals that the runtime + // mutates concurrently. + require.False(t, got == reqCtx, "run ctx must not be the request ctx") + require.NoError(t, got.Err(), "run ctx must not inherit cancellation from the dropped request") - // Confirm Run is still running: it should not have completed - // just because the request ctx was canceled. - select { - case <-done: - t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run") - case <-time.After(50 * time.Millisecond): - } - - // Release the run; the handler should now complete cleanly. + // Release the run so it returns cleanly. close(coord.release) - select { - case rec := <-done: - // Writing to a recorder whose request ctx was canceled - // still works; in production the TCP write would silently - // fail, which is fine because the run already completed and - // SSE subscribers have the result. - require.Equal(t, http.StatusOK, rec.Code) - case <-time.After(2 * time.Second): - t.Fatal("handler did not return after release") - } - require.Equal(t, int32(1), coord.ranCount.Load()) + require.Eventually(t, func() bool { + return coord.ranCount.Load() == 1 + }, 2*time.Second, 10*time.Millisecond) } diff --git a/internal/server/proto.go b/internal/server/proto.go index 5e1c4e2605fb8413df0464be2cd7ee6cb40a5f66..4b43dcd096de34ab4bd2adde61cc4f5a73e0f8c9 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -755,14 +755,15 @@ func (c *controllerV1) handlePostWorkspaceAgent(w http.ResponseWriter, r *http.R return } - // Detach the run's lifetime from the prompting client's HTTP - // request. Without this, A dropping its TCP connection (network - // blip, TUI restart) or B canceling the session via the explicit - // cancel endpoint would also cancel A's request context and tear - // down a turn that other subscribed clients are still watching. - // Only the explicit cancel endpoint should be able to end a run. - ctx := context.WithoutCancel(r.Context()) - if err := c.backend.SendMessage(ctx, id, msg); err != nil { + // The run's lifetime is detached from the prompting client's HTTP + // request: SendMessage validates and accepts the prompt, dispatches + // the run on a goroutine bound to the workspace context, and returns + // immediately. A dropping its TCP connection (network blip, TUI + // restart) or B canceling the session via the explicit cancel + // endpoint can no longer tear down a turn that other subscribed + // clients are still watching. Only the explicit cancel endpoint + // should be able to end a run. + if err := c.backend.SendMessage(id, msg); err != nil { c.handleError(w, r, err) return }