agent_cancel_test.go

  1package server
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"net/http"
  8	"net/http/httptest"
  9	"sync"
 10	"sync/atomic"
 11	"testing"
 12	"time"
 13
 14	"charm.land/fantasy"
 15	"github.com/charmbracelet/crush/internal/agent"
 16	"github.com/charmbracelet/crush/internal/app"
 17	"github.com/charmbracelet/crush/internal/backend"
 18	"github.com/charmbracelet/crush/internal/message"
 19	"github.com/charmbracelet/crush/internal/proto"
 20	"github.com/google/uuid"
 21	"github.com/stretchr/testify/require"
 22)
 23
 24// runCoordinator is a configurable agent.Coordinator stub for the
 25// cancel/drop tests. Run blocks until either ctx is canceled (so it
 26// can observe explicit Cancel paths) or release fires (so the test
 27// can let a "still running" turn finish on its own). The most recent
 28// ctx and the error returned to the caller are recorded for
 29// assertions.
 30type runCoordinator struct {
 31	release  chan struct{}
 32	returnFn func(ctx context.Context) error
 33
 34	mu         sync.Mutex
 35	gotCtx     context.Context
 36	ranCount   atomic.Int32
 37	entered    chan struct{} // closed exactly once when Run is first entered.
 38	enteredOne sync.Once
 39}
 40
 41func newRunCoordinator(returnFn func(ctx context.Context) error) *runCoordinator {
 42	return &runCoordinator{
 43		release:  make(chan struct{}),
 44		returnFn: returnFn,
 45		entered:  make(chan struct{}),
 46	}
 47}
 48
 49func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 50	s.mu.Lock()
 51	s.gotCtx = ctx
 52	s.mu.Unlock()
 53	s.ranCount.Add(1)
 54	s.enteredOne.Do(func() { close(s.entered) })
 55	select {
 56	case <-s.release:
 57	case <-ctx.Done():
 58		// Only fires if the run is actually cancellable.
 59	}
 60	return nil, s.returnFn(ctx)
 61}
 62
 63func (s *runCoordinator) Cancel(string) {}
 64func (s *runCoordinator) CancelAll()    {}
 65func (s *runCoordinator) IsBusy() bool  { return false }
 66func (s *runCoordinator) IsSessionBusy(string) bool {
 67	return false
 68}
 69func (s *runCoordinator) QueuedPrompts(string) int          { return 0 }
 70func (s *runCoordinator) QueuedPromptsList(string) []string { return nil }
 71func (s *runCoordinator) ClearQueue(string)                 {}
 72func (s *runCoordinator) Summarize(context.Context, string) error {
 73	return nil
 74}
 75func (s *runCoordinator) Model() agent.Model                 { return agent.Model{} }
 76func (s *runCoordinator) UpdateModels(context.Context) error { return nil }
 77
 78func (s *runCoordinator) capturedCtx() context.Context {
 79	s.mu.Lock()
 80	defer s.mu.Unlock()
 81	return s.gotCtx
 82}
 83
 84// buildAgentWorkspace returns a controller wired to a backend whose
 85// single workspace exposes the given coordinator. The workspace
 86// shutdown hook is overridden to avoid driving a real [app.App]
 87// through teardown when the test exits.
 88func buildAgentWorkspace(t *testing.T, coord agent.Coordinator) (*controllerV1, string) {
 89	t.Helper()
 90	b := backend.New(context.Background(), nil, nil)
 91	a := &app.App{AgentCoordinator: coord}
 92
 93	ws := &backend.Workspace{
 94		ID:   uuid.New().String(),
 95		Path: t.TempDir(),
 96		App:  a,
 97	}
 98	backend.InsertWorkspaceForTest(b, ws)
 99	backend.SetWorkspaceShutdownFnForTest(ws, func() {})
100
101	s := &Server{backend: b}
102	return &controllerV1{backend: b, server: s}, ws.ID
103}
104
105func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, sessionID string) *httptest.ResponseRecorder {
106	t.Helper()
107	body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"})
108	require.NoError(t, err)
109	req := httptest.NewRequestWithContext(ctx, http.MethodPost, "/v1/workspaces/"+wsID+"/agent", bytes.NewReader(body))
110	req.SetPathValue("id", wsID)
111	req.Header.Set("Content-Type", "application/json")
112	rec := httptest.NewRecorder()
113	c.handlePostWorkspaceAgent(rec, req)
114	return rec
115}
116
117// TestPostAgent_ReturnsOKOnContextCanceled verifies that when another
118// client cancels the session mid-turn, the prompting client's still
119// open POST receives 200 (not 500). The agent surfaces the
120// FinishReasonCanceled marker to every SSE subscriber via the
121// assistant message; the HTTP response from the prompter should not
122// double as an error signal.
123func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
124	t.Parallel()
125
126	coord := newRunCoordinator(func(context.Context) error {
127		return context.Canceled
128	})
129	c, wsID := buildAgentWorkspace(t, coord)
130
131	done := make(chan *httptest.ResponseRecorder, 1)
132	go func() {
133		done <- postAgent(t, c, t.Context(), wsID, "S1")
134	}()
135
136	// Wait until Run is in flight, then release it to return
137	// context.Canceled.
138	select {
139	case <-coord.entered:
140	case <-time.After(2 * time.Second):
141		t.Fatal("coordinator Run was never entered")
142	}
143	close(coord.release)
144
145	select {
146	case rec := <-done:
147		require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500")
148	case <-time.After(2 * time.Second):
149		t.Fatal("handler did not return after coordinator returned context.Canceled")
150	}
151}
152
153// TestPostAgent_DetachesRequestContext verifies that canceling the
154// prompting client's HTTP request context does not cancel the
155// in-flight agent run. The coordinator must observe a context whose
156// Done channel never fires from the request side; only the explicit
157// cancel endpoint may end the run.
158func TestPostAgent_DetachesRequestContext(t *testing.T) {
159	t.Parallel()
160
161	coord := newRunCoordinator(func(context.Context) error {
162		return nil
163	})
164	c, wsID := buildAgentWorkspace(t, coord)
165
166	reqCtx, cancelReq := context.WithCancel(context.Background())
167
168	done := make(chan *httptest.ResponseRecorder, 1)
169	go func() {
170		done <- postAgent(t, c, reqCtx, wsID, "S1")
171	}()
172
173	// Wait until Run is in flight, then drop the prompting client.
174	select {
175	case <-coord.entered:
176	case <-time.After(2 * time.Second):
177		t.Fatal("coordinator Run was never entered")
178	}
179	cancelReq()
180
181	// The captured ctx must be detached: context.WithoutCancel
182	// returns a ctx with Done() == nil so request cancellation cannot
183	// propagate.
184	got := coord.capturedCtx()
185	require.NotNil(t, got)
186	require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel")
187	require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request")
188
189	// Confirm Run is still running: it should not have completed
190	// just because the request ctx was canceled.
191	select {
192	case <-done:
193		t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run")
194	case <-time.After(50 * time.Millisecond):
195	}
196
197	// Release the run; the handler should now complete cleanly.
198	close(coord.release)
199	select {
200	case rec := <-done:
201		// Writing to a recorder whose request ctx was canceled
202		// still works; in production the TCP write would silently
203		// fail, which is fine because the run already completed and
204		// SSE subscribers have the result.
205		require.Equal(t, http.StatusOK, rec.Code)
206	case <-time.After(2 * time.Second):
207		t.Fatal("handler did not return after release")
208	}
209	require.Equal(t, int32(1), coord.ranCount.Load())
210}