1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "net/http"
8 "net/http/httptest"
9 "sync"
10 "sync/atomic"
11 "testing"
12 "time"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/crush/internal/agent"
16 "github.com/charmbracelet/crush/internal/app"
17 "github.com/charmbracelet/crush/internal/backend"
18 "github.com/charmbracelet/crush/internal/message"
19 "github.com/charmbracelet/crush/internal/proto"
20 "github.com/google/uuid"
21 "github.com/stretchr/testify/require"
22)
23
// runCoordinator is a configurable agent.Coordinator stub for the
// cancel/drop tests. Run blocks until either its ctx is done or the
// test closes release (so a "still running" turn can be let finish on
// its own), then returns (nil, returnFn(ctx)). For assertions it
// records the ctx most recently passed to Run (read it via
// capturedCtx) and counts every call in ranCount; the error Run
// returns is NOT recorded.
type runCoordinator struct {
	release  chan struct{}                   // closed by the test to let a blocked Run return.
	returnFn func(ctx context.Context) error // produces the error Run returns once unblocked.

	mu       sync.Mutex      // guards gotCtx.
	gotCtx   context.Context // ctx most recently passed to Run; nil until Run is called.
	ranCount atomic.Int32    // number of Run calls.

	entered    chan struct{} // closed exactly once when Run is first entered.
	enteredOne sync.Once     // ensures entered is closed only once.
}
40
41func newRunCoordinator(returnFn func(ctx context.Context) error) *runCoordinator {
42 return &runCoordinator{
43 release: make(chan struct{}),
44 returnFn: returnFn,
45 entered: make(chan struct{}),
46 }
47}
48
49func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
50 s.mu.Lock()
51 s.gotCtx = ctx
52 s.mu.Unlock()
53 s.ranCount.Add(1)
54 s.enteredOne.Do(func() { close(s.entered) })
55 select {
56 case <-s.release:
57 case <-ctx.Done():
58 // Only fires if the run is actually cancellable.
59 }
60 return nil, s.returnFn(ctx)
61}
62
63func (s *runCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
64 return s.Run(ctx, sessionID, prompt, attachments...)
65}
66
67func (s *runCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun {
68 return nil
69}
70func (s *runCoordinator) Cancel(string) {}
71func (s *runCoordinator) CancelAll() {}
72func (s *runCoordinator) IsBusy() bool { return false }
73func (s *runCoordinator) IsSessionBusy(string) bool {
74 return false
75}
76func (s *runCoordinator) QueuedPrompts(string) int { return 0 }
77func (s *runCoordinator) QueuedPromptsList(string) []string { return nil }
78func (s *runCoordinator) ClearQueue(string) {}
79func (s *runCoordinator) Summarize(context.Context, string) error {
80 return nil
81}
82func (s *runCoordinator) Model() agent.Model { return agent.Model{} }
83func (s *runCoordinator) UpdateModels(context.Context) error { return nil }
84
85func (s *runCoordinator) capturedCtx() context.Context {
86 s.mu.Lock()
87 defer s.mu.Unlock()
88 return s.gotCtx
89}
90
91// buildAgentWorkspace returns a controller wired to a backend whose
92// single workspace exposes the given coordinator. The workspace
93// shutdown hook is overridden to avoid driving a real [app.App]
94// through teardown when the test exits.
95func buildAgentWorkspace(t *testing.T, coord agent.Coordinator) (*controllerV1, string) {
96 t.Helper()
97 b := backend.New(context.Background(), nil, nil)
98 a := &app.App{AgentCoordinator: coord}
99
100 ws := &backend.Workspace{
101 ID: uuid.New().String(),
102 Path: t.TempDir(),
103 App: a,
104 }
105 backend.InsertWorkspaceForTest(b, ws)
106 backend.SetWorkspaceShutdownFnForTest(ws, func() {})
107
108 s := &Server{backend: b}
109 return &controllerV1{backend: b, server: s}, ws.ID
110}
111
112func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, sessionID string) *httptest.ResponseRecorder {
113 t.Helper()
114 body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"})
115 require.NoError(t, err)
116 req := httptest.NewRequestWithContext(ctx, http.MethodPost, "/v1/workspaces/"+wsID+"/agent", bytes.NewReader(body))
117 req.SetPathValue("id", wsID)
118 req.Header.Set("Content-Type", "application/json")
119 rec := httptest.NewRecorder()
120 c.handlePostWorkspaceAgent(rec, req)
121 return rec
122}
123
124// TestPostAgent_ReturnsOKOnContextCanceled verifies that when another
125// client cancels the session mid-turn, the prompting client's still
126// open POST receives 200 (not 500). The agent surfaces the
127// FinishReasonCanceled marker to every SSE subscriber via the
128// assistant message; the HTTP response from the prompter should not
129// double as an error signal.
130func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
131 t.Parallel()
132
133 coord := newRunCoordinator(func(context.Context) error {
134 return context.Canceled
135 })
136 c, wsID := buildAgentWorkspace(t, coord)
137
138 done := make(chan *httptest.ResponseRecorder, 1)
139 go func() {
140 done <- postAgent(t, c, t.Context(), wsID, "S1")
141 }()
142
143 // Wait until Run is in flight, then release it to return
144 // context.Canceled.
145 select {
146 case <-coord.entered:
147 case <-time.After(2 * time.Second):
148 t.Fatal("coordinator Run was never entered")
149 }
150 close(coord.release)
151
152 select {
153 case rec := <-done:
154 require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500")
155 case <-time.After(2 * time.Second):
156 t.Fatal("handler did not return after coordinator returned context.Canceled")
157 }
158}
159
160// TestPostAgent_DetachesRequestContext verifies that canceling the
161// prompting client's HTTP request context does not cancel the
162// in-flight agent run. The coordinator must observe a context whose
163// Done channel never fires from the request side; only the explicit
164// cancel endpoint may end the run.
165func TestPostAgent_DetachesRequestContext(t *testing.T) {
166 t.Parallel()
167
168 coord := newRunCoordinator(func(context.Context) error {
169 return nil
170 })
171 c, wsID := buildAgentWorkspace(t, coord)
172
173 reqCtx, cancelReq := context.WithCancel(context.Background())
174
175 done := make(chan *httptest.ResponseRecorder, 1)
176 go func() {
177 done <- postAgent(t, c, reqCtx, wsID, "S1")
178 }()
179
180 // Wait until Run is in flight, then drop the prompting client.
181 select {
182 case <-coord.entered:
183 case <-time.After(2 * time.Second):
184 t.Fatal("coordinator Run was never entered")
185 }
186 cancelReq()
187
188 // The captured ctx must be detached: context.WithoutCancel
189 // returns a ctx with Done() == nil so request cancellation cannot
190 // propagate.
191 got := coord.capturedCtx()
192 require.NotNil(t, got)
193 require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel")
194 require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request")
195
196 // Confirm Run is still running: it should not have completed
197 // just because the request ctx was canceled.
198 select {
199 case <-done:
200 t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run")
201 case <-time.After(50 * time.Millisecond):
202 }
203
204 // Release the run; the handler should now complete cleanly.
205 close(coord.release)
206 select {
207 case rec := <-done:
208 // Writing to a recorder whose request ctx was canceled
209 // still works; in production the TCP write would silently
210 // fail, which is fine because the run already completed and
211 // SSE subscribers have the result.
212 require.Equal(t, http.StatusOK, rec.Code)
213 case <-time.After(2 * time.Second):
214 t.Fatal("handler did not return after release")
215 }
216 require.Equal(t, int32(1), coord.ranCount.Load())
217}