agent_cancel_test.go

  1package server
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"net/http"
  8	"net/http/httptest"
  9	"sync"
 10	"sync/atomic"
 11	"testing"
 12	"time"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/crush/internal/agent"
 16	"github.com/charmbracelet/crush/internal/app"
 17	"github.com/charmbracelet/crush/internal/backend"
 18	"github.com/charmbracelet/crush/internal/message"
 19	"github.com/charmbracelet/crush/internal/proto"
 20	"github.com/google/uuid"
 21	"github.com/stretchr/testify/require"
 22)
 23
 24// runCoordinator is a configurable agent.Coordinator stub for the
 25// cancel/drop tests. Run blocks until either ctx is canceled (so it
 26// can observe explicit Cancel paths) or release fires (so the test
 27// can let a "still running" turn finish on its own). The most recent
 28// ctx and the error returned to the caller are recorded for
 29// assertions.
 30type runCoordinator struct {
 31	release  chan struct{}
 32	returnFn func(ctx context.Context) error
 33
 34	mu         sync.Mutex
 35	gotCtx     context.Context
 36	ranCount   atomic.Int32
 37	entered    chan struct{} // closed exactly once when Run is first entered.
 38	enteredOne sync.Once
 39}
 40
 41func newRunCoordinator(returnFn func(ctx context.Context) error) *runCoordinator {
 42	return &runCoordinator{
 43		release:  make(chan struct{}),
 44		returnFn: returnFn,
 45		entered:  make(chan struct{}),
 46	}
 47}
 48
 49func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 50	s.mu.Lock()
 51	s.gotCtx = ctx
 52	s.mu.Unlock()
 53	s.ranCount.Add(1)
 54	s.enteredOne.Do(func() { close(s.entered) })
 55	select {
 56	case <-s.release:
 57	case <-ctx.Done():
 58		// Only fires if the run is actually cancellable.
 59	}
 60	return nil, s.returnFn(ctx)
 61}
 62
 63func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 64	return s.Run(ctx, sessionID, prompt, attachments...)
 65}
 66
 67func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
 68	return nil
 69}
 70func (s *runCoordinator) Cancel(string) {}
 71func (s *runCoordinator) CancelAll()    {}
 72func (s *runCoordinator) IsBusy() bool  { return false }
 73func (s *runCoordinator) IsSessionBusy(string) bool {
 74	return false
 75}
 76func (s *runCoordinator) QueuedPrompts(string) int          { return 0 }
 77func (s *runCoordinator) QueuedPromptsList(string) []string { return nil }
 78func (s *runCoordinator) ClearQueue(string)                 {}
 79func (s *runCoordinator) Summarize(context.Context, string) error {
 80	return nil
 81}
 82func (s *runCoordinator) Model() agent.Model                 { return agent.Model{} }
 83func (s *runCoordinator) UpdateModels(context.Context) error { return nil }
 84
 85func (s *runCoordinator) capturedCtx() context.Context {
 86	s.mu.Lock()
 87	defer s.mu.Unlock()
 88	return s.gotCtx
 89}
 90
 91// buildAgentWorkspace returns a controller wired to a backend whose
 92// single workspace exposes the given coordinator. The workspace
 93// shutdown hook is overridden to avoid driving a real [app.App]
 94// through teardown when the test exits.
 95func buildAgentWorkspace(t *testing.T, coord agent.Coordinator) (*controllerV1, string) {
 96	t.Helper()
 97	b := backend.New(context.Background(), nil, nil)
 98	a := &app.App{AgentCoordinator: coord}
 99
100	ws := &backend.Workspace{
101		ID:   uuid.New().String(),
102		Path: t.TempDir(),
103		App:  a,
104	}
105	backend.InsertWorkspaceForTest(b, ws)
106	backend.SetWorkspaceShutdownFnForTest(ws, func() {})
107
108	s := &Server{backend: b}
109	return &controllerV1{backend: b, server: s}, ws.ID
110}
111
112func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, sessionID string) *httptest.ResponseRecorder {
113	t.Helper()
114	body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"})
115	require.NoError(t, err)
116	req := httptest.NewRequestWithContext(ctx, http.MethodPost, "/v1/workspaces/"+wsID+"/agent", bytes.NewReader(body))
117	req.SetPathValue("id", wsID)
118	req.Header.Set("Content-Type", "application/json")
119	rec := httptest.NewRecorder()
120	c.handlePostWorkspaceAgent(rec, req)
121	return rec
122}
123
124// TestPostAgent_ReturnsOKOnContextCanceled verifies that when another
125// client cancels the session mid-turn, the prompting client's POST is
126// unaffected: SendMessage is fire-and-forget, so the handler returns
127// 200 immediately without waiting for the turn. A run that later
128// returns context.Canceled never surfaces as a 500 to the prompter;
129// the FinishReasonCanceled marker reaches SSE subscribers via the
130// assistant message instead.
131func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
132	t.Parallel()
133
134	coord := newRunCoordinator(func(context.Context) error {
135		return context.Canceled
136	})
137	c, wsID := buildAgentWorkspace(t, coord)
138
139	// The handler returns immediately, before the dispatched run is
140	// released, because the run no longer owns the HTTP response.
141	rec := postAgent(t, c, t.Context(), wsID, "S1")
142	require.Equal(t, http.StatusAccepted, rec.Code, "fire-and-forget SendMessage must return 202 without waiting for the run")
143
144	// The run is dispatched on a goroutine; let it return
145	// context.Canceled. Nothing from that path reaches the (already
146	// returned) handler.
147	select {
148	case <-coord.entered:
149	case <-time.After(2 * time.Second):
150		t.Fatal("dispatched run was never entered")
151	}
152	close(coord.release)
153}
154
155// TestPostAgent_DetachesRequestContext verifies that the dispatched run
156// is bound to the workspace context, not the prompting client's HTTP
157// request context. Canceling the request context must neither cancel
158// the run nor be observed by the coordinator.
159func TestPostAgent_DetachesRequestContext(t *testing.T) {
160	t.Parallel()
161
162	coord := newRunCoordinator(func(context.Context) error {
163		return nil
164	})
165	c, wsID := buildAgentWorkspace(t, coord)
166
167	reqCtx, cancelReq := context.WithCancel(context.Background())
168
169	// The handler returns immediately; the run keeps executing on its
170	// own goroutine bound to the workspace context.
171	rec := postAgent(t, c, reqCtx, wsID, "S1")
172	require.Equal(t, http.StatusAccepted, rec.Code)
173
174	select {
175	case <-coord.entered:
176	case <-time.After(2 * time.Second):
177		t.Fatal("dispatched run was never entered")
178	}
179
180	// Drop the prompting client. This must not reach the run.
181	cancelReq()
182
183	got := coord.capturedCtx()
184	require.NotNil(t, got)
185	// Compare by identity (pointer), not reflect.DeepEqual: deep
186	// comparison would traverse context internals that the runtime
187	// mutates concurrently.
188	require.False(t, got == reqCtx, "run ctx must not be the request ctx")
189	require.NoError(t, got.Err(), "run ctx must not inherit cancellation from the dropped request")
190
191	// Release the run so it returns cleanly.
192	close(coord.release)
193	require.Eventually(t, func() bool {
194		return coord.ranCount.Load() == 1
195	}, 2*time.Second, 10*time.Millisecond)
196}