1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "net/http"
8 "net/http/httptest"
9 "sync"
10 "sync/atomic"
11 "testing"
12 "time"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/crush/internal/agent"
16 "github.com/charmbracelet/crush/internal/app"
17 "github.com/charmbracelet/crush/internal/backend"
18 "github.com/charmbracelet/crush/internal/message"
19 "github.com/charmbracelet/crush/internal/proto"
20 "github.com/google/uuid"
21 "github.com/stretchr/testify/require"
22)
23
// runCoordinator is a configurable agent.Coordinator stub for the
// cancel/drop tests. Run parks until either its ctx is done (so a
// cancellable ctx ends the run) or release is closed (so the test can
// let a "still running" turn finish on its own), then returns whatever
// returnFn yields. For assertions it records the most recent ctx passed
// to Run and the number of Run calls; the error handed back to the
// caller is not recorded.
type runCoordinator struct {
	release  chan struct{}                   // closed by the test to let a parked Run return
	returnFn func(ctx context.Context) error // yields Run's error once it stops waiting

	mu         sync.Mutex
	gotCtx     context.Context // most recent ctx passed to Run; guarded by mu
	ranCount   atomic.Int32    // number of Run invocations
	entered    chan struct{}   // closed exactly once when Run is first entered.
	enteredOne sync.Once       // guards the single close of entered
}
40
41func newRunCoordinator(returnFn func(ctx context.Context) error) *runCoordinator {
42 return &runCoordinator{
43 release: make(chan struct{}),
44 returnFn: returnFn,
45 entered: make(chan struct{}),
46 }
47}
48
49func (s *runCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
50 s.mu.Lock()
51 s.gotCtx = ctx
52 s.mu.Unlock()
53 s.ranCount.Add(1)
54 s.enteredOne.Do(func() { close(s.entered) })
55 select {
56 case <-s.release:
57 case <-ctx.Done():
58 // Only fires if the run is actually cancellable.
59 }
60 return nil, s.returnFn(ctx)
61}
62
63func (s *runCoordinator) Cancel(string) {}
64func (s *runCoordinator) CancelAll() {}
65func (s *runCoordinator) IsBusy() bool { return false }
66func (s *runCoordinator) IsSessionBusy(string) bool {
67 return false
68}
69func (s *runCoordinator) QueuedPrompts(string) int { return 0 }
70func (s *runCoordinator) QueuedPromptsList(string) []string { return nil }
71func (s *runCoordinator) ClearQueue(string) {}
72func (s *runCoordinator) Summarize(context.Context, string) error {
73 return nil
74}
75func (s *runCoordinator) Model() agent.Model { return agent.Model{} }
76func (s *runCoordinator) UpdateModels(context.Context) error { return nil }
77
78func (s *runCoordinator) capturedCtx() context.Context {
79 s.mu.Lock()
80 defer s.mu.Unlock()
81 return s.gotCtx
82}
83
84// buildAgentWorkspace returns a controller wired to a backend whose
85// single workspace exposes the given coordinator. The workspace
86// shutdown hook is overridden to avoid driving a real [app.App]
87// through teardown when the test exits.
88func buildAgentWorkspace(t *testing.T, coord agent.Coordinator) (*controllerV1, string) {
89 t.Helper()
90 b := backend.New(context.Background(), nil, nil)
91 a := &app.App{AgentCoordinator: coord}
92
93 ws := &backend.Workspace{
94 ID: uuid.New().String(),
95 Path: t.TempDir(),
96 App: a,
97 }
98 backend.InsertWorkspaceForTest(b, ws)
99 backend.SetWorkspaceShutdownFnForTest(ws, func() {})
100
101 s := &Server{backend: b}
102 return &controllerV1{backend: b, server: s}, ws.ID
103}
104
105func postAgent(t *testing.T, c *controllerV1, ctx context.Context, wsID, sessionID string) *httptest.ResponseRecorder {
106 t.Helper()
107 body, err := json.Marshal(proto.AgentMessage{SessionID: sessionID, Prompt: "hi"})
108 require.NoError(t, err)
109 req := httptest.NewRequestWithContext(ctx, http.MethodPost, "/v1/workspaces/"+wsID+"/agent", bytes.NewReader(body))
110 req.SetPathValue("id", wsID)
111 req.Header.Set("Content-Type", "application/json")
112 rec := httptest.NewRecorder()
113 c.handlePostWorkspaceAgent(rec, req)
114 return rec
115}
116
117// TestPostAgent_ReturnsOKOnContextCanceled verifies that when another
118// client cancels the session mid-turn, the prompting client's still
119// open POST receives 200 (not 500). The agent surfaces the
120// FinishReasonCanceled marker to every SSE subscriber via the
121// assistant message; the HTTP response from the prompter should not
122// double as an error signal.
123func TestPostAgent_ReturnsOKOnContextCanceled(t *testing.T) {
124 t.Parallel()
125
126 coord := newRunCoordinator(func(context.Context) error {
127 return context.Canceled
128 })
129 c, wsID := buildAgentWorkspace(t, coord)
130
131 done := make(chan *httptest.ResponseRecorder, 1)
132 go func() {
133 done <- postAgent(t, c, t.Context(), wsID, "S1")
134 }()
135
136 // Wait until Run is in flight, then release it to return
137 // context.Canceled.
138 select {
139 case <-coord.entered:
140 case <-time.After(2 * time.Second):
141 t.Fatal("coordinator Run was never entered")
142 }
143 close(coord.release)
144
145 select {
146 case rec := <-done:
147 require.Equal(t, http.StatusOK, rec.Code, "context.Canceled from another client's cancel must not surface as 500")
148 case <-time.After(2 * time.Second):
149 t.Fatal("handler did not return after coordinator returned context.Canceled")
150 }
151}
152
153// TestPostAgent_DetachesRequestContext verifies that canceling the
154// prompting client's HTTP request context does not cancel the
155// in-flight agent run. The coordinator must observe a context whose
156// Done channel never fires from the request side; only the explicit
157// cancel endpoint may end the run.
158func TestPostAgent_DetachesRequestContext(t *testing.T) {
159 t.Parallel()
160
161 coord := newRunCoordinator(func(context.Context) error {
162 return nil
163 })
164 c, wsID := buildAgentWorkspace(t, coord)
165
166 reqCtx, cancelReq := context.WithCancel(context.Background())
167
168 done := make(chan *httptest.ResponseRecorder, 1)
169 go func() {
170 done <- postAgent(t, c, reqCtx, wsID, "S1")
171 }()
172
173 // Wait until Run is in flight, then drop the prompting client.
174 select {
175 case <-coord.entered:
176 case <-time.After(2 * time.Second):
177 t.Fatal("coordinator Run was never entered")
178 }
179 cancelReq()
180
181 // The captured ctx must be detached: context.WithoutCancel
182 // returns a ctx with Done() == nil so request cancellation cannot
183 // propagate.
184 got := coord.capturedCtx()
185 require.NotNil(t, got)
186 require.Nil(t, got.Done(), "coordinator ctx must be detached from r.Context() via context.WithoutCancel")
187 require.NoError(t, got.Err(), "coordinator ctx must not inherit cancellation from the dropped request")
188
189 // Confirm Run is still running: it should not have completed
190 // just because the request ctx was canceled.
191 select {
192 case <-done:
193 t.Fatal("handler returned before run completed; request ctx cancellation leaked into the run")
194 case <-time.After(50 * time.Millisecond):
195 }
196
197 // Release the run; the handler should now complete cleanly.
198 close(coord.release)
199 select {
200 case rec := <-done:
201 // Writing to a recorder whose request ctx was canceled
202 // still works; in production the TCP write would silently
203 // fail, which is fine because the run already completed and
204 // SSE subscribers have the result.
205 require.Equal(t, http.StatusOK, rec.Code)
206 case <-time.After(2 * time.Second):
207 t.Fatal("handler did not return after release")
208 }
209 require.Equal(t, int32(1), coord.ranCount.Load())
210}