1package agent
2
3import (
4 "context"
5 "errors"
6 "sync/atomic"
7 "testing"
8 "time"
9
10 "charm.land/fantasy"
11 "github.com/charmbracelet/crush/internal/agent/notify"
12 "github.com/charmbracelet/crush/internal/message"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16)
17
// finishStreamModel is a canned fantasy.LanguageModel for tests. Both
// Generate and Stream answer with the configured text and finish with
// FinishReasonStop, which lets sessionAgent.Run reach a clean
// completion without a recorded provider cassette.
type finishStreamModel struct{ text string }
25
26func (m *finishStreamModel) Provider() string { return "fake" }
27func (m *finishStreamModel) Model() string { return "fake-model" }
28
29func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
30 return &fantasy.Response{
31 Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
32 FinishReason: fantasy.FinishReasonStop,
33 }, nil
34}
35
36func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
37 text := m.text
38 return func(yield func(fantasy.StreamPart) bool) {
39 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
40 return
41 }
42 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
43 return
44 }
45 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
46 return
47 }
48 yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
49 }, nil
50}
51
52func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
53 return nil, errors.New("not implemented")
54}
55
56func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
57 return nil, errors.New("not implemented")
58}
59
60func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
61 t.Helper()
62 env := testEnv(t)
63 model := &finishStreamModel{text: "done"}
64 sa := testSessionAgent(env, model, model, "system").(*sessionAgent)
65 return sa, env
66}
67
68// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a
69// session is actively running (activeRequests set) AND a follow-up has
70// been accepted (acceptedRuns > 0). A single Cancel must fire both: it
71// invokes the active cancel func and records a pending cancel for the
72// accepted follow-up.
73func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
74 t.Parallel()
75 sa, _ := newCancelTestAgent(t)
76
77 const sid = "sid"
78 var activeCanceled atomic.Bool
79 sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) })
80
81 accept := sa.BeginAccepted(sid)
82 defer accept.Close()
83
84 sa.Cancel(sid)
85
86 require.True(t, activeCanceled.Load(), "active cancel func must fire")
87 require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
88}
89
90// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue
91// branch consulting pendingCancels: when the session is busy AND a
92// cancel has been recorded for an accepted follow-up, Run must take the
93// cancel-on-entry path (persist a canceled turn) instead of enqueueing
94// the call behind the active run.
95func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
96 t.Parallel()
97 sa, env := newCancelTestAgent(t)
98
99 sess, err := env.sessions.Create(t.Context(), "session")
100 require.NoError(t, err)
101
102 // Make the session look busy: an earlier prompt is active.
103 sa.activeRequests.Set(sess.ID, func() {})
104
105 accept := sa.BeginAccepted(sess.ID)
106 // A cancel arrives while this follow-up is accepted-but-not-active.
107 sa.Cancel(sess.ID)
108 require.True(t, sa.hasPendingCancel(sess.ID))
109
110 result, err := sa.Run(t.Context(), SessionAgentCall{
111 SessionID: sess.ID,
112 Prompt: "follow-up",
113 Accepted: accept,
114 })
115 require.NoError(t, err)
116 require.Nil(t, result)
117
118 // The follow-up was canceled on entry, not enqueued.
119 require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
120 "cancel-on-entry must not enqueue the follow-up behind the active run")
121 require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed")
122 require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released")
123
124 msgs, err := env.messages.List(t.Context(), sess.ID)
125 require.NoError(t, err)
126 require.Len(t, msgs, 2)
127 assert.Equal(t, message.User, msgs[0].Role)
128 assert.Equal(t, message.Assistant, msgs[1].Role)
129 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
130}
131
132// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the
133// queue drain inside PrepareStep skips queued follow-up prompts when a
134// cancel has been recorded for the session: the queued prompt must not
135// be folded into the active turn as an extra user message.
136func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
137 t.Parallel()
138 sa, env := newStreamTestAgent(t)
139
140 sess, err := env.sessions.Create(t.Context(), "session")
141 require.NoError(t, err)
142
143 // A follow-up prompt sits queued for the session.
144 sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
145 // A cancel was recorded for the session while it sat in the queue.
146 sa.cancelMark.Set(sess.ID, 1)
147
148 result, err := sa.Run(t.Context(), SessionAgentCall{
149 SessionID: sess.ID,
150 Prompt: "main",
151 })
152 require.NoError(t, err)
153 require.NotNil(t, result)
154
155 // Only the main prompt produced a user message; the queued
156 // follow-up was skipped, not folded into the turn.
157 msgs, err := env.messages.List(t.Context(), sess.ID)
158 require.NoError(t, err)
159 var userMsgs []message.Message
160 for _, m := range msgs {
161 if m.Role == message.User {
162 userMsgs = append(userMsgs, m)
163 }
164 }
165 require.Len(t, userMsgs, 1, "queued follow-up must not create a user message")
166 assert.Equal(t, "main", userMsgs[0].Content().String())
167
168 // The queue was drained and the pending cancel consumed.
169 require.Equal(t, 0, sa.QueuedPrompts(sess.ID))
170 require.False(t, sa.hasPendingCancel(sess.ID))
171}
172
173// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run
174// which completes normally clears any stale pending-cancel entry for the
175// session, so it cannot catch a future run.
176func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
177 t.Parallel()
178 sa, env := newStreamTestAgent(t)
179
180 sess, err := env.sessions.Create(t.Context(), "session")
181 require.NoError(t, err)
182
183 // A stale cancel mark lingers (no queued work, no accepted run).
184 sa.cancelMark.Set(sess.ID, 1)
185
186 result, err := sa.Run(t.Context(), SessionAgentCall{
187 SessionID: sess.ID,
188 Prompt: "main",
189 })
190 require.NoError(t, err)
191 require.NotNil(t, result)
192
193 require.False(t, sa.hasPendingCancel(sess.ID),
194 "normal completion must clear the stale pending cancel")
195
196 msgs, err := env.messages.List(t.Context(), sess.ID)
197 require.NoError(t, err)
198 require.Len(t, msgs, 2)
199 assert.Equal(t, message.Assistant, msgs[1].Role)
200 assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
201}
202
203// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired
204// to a RunComplete broker so tests can observe the terminal event a
205// RunID-bearing caller (e.g. `crush run`) blocks on.
206func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) {
207 t.Helper()
208 env := testEnv(t)
209 broker := pubsub.NewBroker[notify.RunComplete]()
210 t.Cleanup(broker.Shutdown)
211 sa := NewSessionAgent(SessionAgentOptions{
212 Sessions: env.sessions,
213 Messages: env.messages,
214 RunComplete: broker,
215 }).(*sessionAgent)
216 return sa, env, broker
217}
218
219// TestRun_CancelOnEntryPublishesRunComplete covers the first review
220// finding: the cancel-on-entry path returned before the streaming defer
221// that publishes RunComplete was installed. A caller that dispatches a
222// run with a RunID and blocks on RunComplete (ignoring message events,
223// like `crush run`) would hang on an immediately-canceled accepted run.
224// The cancel-on-entry path must publish a terminal RunComplete carrying
225// the originating RunID.
226func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) {
227 t.Parallel()
228 sa, env, broker := newCancelTestAgentWithRunComplete(t)
229
230 sess, err := env.sessions.Create(t.Context(), "session")
231 require.NoError(t, err)
232
233 ctx, cancel := context.WithCancel(t.Context())
234 defer cancel()
235 ch := broker.Subscribe(ctx)
236
237 accept := sa.BeginAccepted(sess.ID)
238 // A cancel arrives in the accepted-but-not-yet-active window.
239 sa.Cancel(sess.ID)
240 require.True(t, sa.hasPendingCancel(sess.ID))
241
242 result, err := sa.Run(t.Context(), SessionAgentCall{
243 SessionID: sess.ID,
244 RunID: "run-cancel-on-entry",
245 Prompt: "hello",
246 Accepted: accept,
247 })
248 require.NoError(t, err)
249 require.Nil(t, result)
250
251 select {
252 case got := <-ch:
253 assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID,
254 "RunComplete must echo the originating RunID")
255 assert.Equal(t, sess.ID, got.Payload.SessionID)
256 assert.True(t, got.Payload.Cancelled,
257 "cancel-on-entry RunComplete must be marked cancelled")
258 case <-time.After(2 * time.Second):
259 t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise")
260 }
261}
262
263// TestCancel_TwoAcceptedBothObserveCancellation covers the second review
264// finding: a single cancel with two accepted-not-yet-active prompts must
265// cancel both. The cancel raises the session's high-water mark to the
266// latest accept sequence, so every prompt accepted-but-not-yet-active at
267// cancel time is covered and both take the cancel-on-entry path.
268func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) {
269 t.Parallel()
270 sa, env := newCancelTestAgent(t)
271
272 sess, err := env.sessions.Create(t.Context(), "session")
273 require.NoError(t, err)
274
275 // Two prompts are accepted-but-not-yet-active for the same session.
276 accept1 := sa.BeginAccepted(sess.ID)
277 accept2 := sa.BeginAccepted(sess.ID)
278 require.Equal(t, 2, sa.acceptedCount(sess.ID))
279
280 // A single cancel arrives before either becomes active.
281 sa.Cancel(sess.ID)
282 require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID),
283 "one cancel must mark every currently-accepted prompt as canceled")
284 require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq,
285 "the mark must cover the earlier accepted prompt too")
286
287 // Both prompts enter Run; each must take cancel-on-entry, not run.
288 r1, err := sa.Run(t.Context(), SessionAgentCall{
289 SessionID: sess.ID,
290 Prompt: "first",
291 Accepted: accept1,
292 })
293 require.NoError(t, err)
294 require.Nil(t, r1)
295
296 r2, err := sa.Run(t.Context(), SessionAgentCall{
297 SessionID: sess.ID,
298 Prompt: "second",
299 Accepted: accept2,
300 })
301 require.NoError(t, err)
302 require.Nil(t, r2)
303
304 require.False(t, sa.hasPendingCancel(sess.ID),
305 "both reserved units must be consumed")
306 require.Equal(t, 0, sa.acceptedCount(sess.ID),
307 "both accept reservations must be released")
308
309 // Each canceled-on-entry turn writes a user + canceled assistant
310 // message, and neither prompt was enqueued to run normally.
311 require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
312 "neither accepted prompt may be enqueued to run normally")
313 msgs, err := env.messages.List(t.Context(), sess.ID)
314 require.NoError(t, err)
315 require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages")
316 var canceled int
317 for _, m := range msgs {
318 if m.Role == message.Assistant {
319 assert.Equal(t, message.FinishReasonCanceled, m.FinishReason())
320 canceled++
321 }
322 }
323 require.Equal(t, 2, canceled, "both turns must finish canceled")
324}
325
326// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel
327// no-op guarantee end-to-end: an Escape on an idle session must not
328// record a pending cancel that leaks into the next accepted prompt, which
329// must run normally to completion.
330func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) {
331 t.Parallel()
332 sa, env := newStreamTestAgent(t)
333
334 sess, err := env.sessions.Create(t.Context(), "session")
335 require.NoError(t, err)
336
337 // Idle Escape: no accepted run, no active request.
338 sa.Cancel(sess.ID)
339 require.False(t, sa.hasPendingCancel(sess.ID),
340 "idle cancel must not record a pending cancel")
341
342 // The next accepted prompt must run normally, not cancel on entry.
343 accept := sa.BeginAccepted(sess.ID)
344 result, err := sa.Run(t.Context(), SessionAgentCall{
345 SessionID: sess.ID,
346 Prompt: "next",
347 Accepted: accept,
348 })
349 require.NoError(t, err)
350 require.NotNil(t, result, "next prompt must run normally after an idle cancel")
351
352 msgs, err := env.messages.List(t.Context(), sess.ID)
353 require.NoError(t, err)
354 require.Len(t, msgs, 2)
355 assert.Equal(t, message.User, msgs[0].Role)
356 assert.Equal(t, message.Assistant, msgs[1].Role)
357 assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(),
358 "the prompt must finish normally, not canceled")
359}
360
361// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for
362// the reviewer's finding: a counted session-level pending cancel let a
363// prompt accepted after the cancel enter Run first and consume a unit
364// reserved for the earlier prompts. With a sequence high-water mark, a
365// single cancel covers exactly the prompts accepted-but-not-yet-active at
366// cancel time (A and B); a prompt accepted afterwards (C) gets a higher
367// sequence and must run normally without consuming A or B's cancellation.
368// C is run first to prove it neither cancels nor drains the mark, then A
369// and B are run and must both cancel on entry.
370func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) {
371 t.Parallel()
372 sa, env := newStreamTestAgent(t)
373
374 sess, err := env.sessions.Create(t.Context(), "session")
375 require.NoError(t, err)
376
377 // A and B are accepted-but-not-yet-active.
378 acceptA := sa.BeginAccepted(sess.ID)
379 acceptB := sa.BeginAccepted(sess.ID)
380
381 // One cancel arrives covering both A and B.
382 sa.Cancel(sess.ID)
383 require.True(t, sa.hasPendingCancel(sess.ID))
384 require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID),
385 "the mark must cover every prompt accepted before the cancel")
386
387 // C is accepted AFTER the cancel; its sequence is above the mark.
388 acceptC := sa.BeginAccepted(sess.ID)
389 require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID),
390 "a prompt accepted after the cancel must not be covered by the mark")
391
392 // Run C first. It must run normally to completion and must not
393 // consume or clear the cancellation reserved for A and B.
394 rc, err := sa.Run(t.Context(), SessionAgentCall{
395 SessionID: sess.ID,
396 Prompt: "C",
397 Accepted: acceptC,
398 })
399 require.NoError(t, err)
400 require.NotNil(t, rc, "C was accepted after the cancel and must run normally")
401 require.True(t, sa.hasPendingCancel(sess.ID),
402 "running C must not drain the cancellation reserved for A and B")
403
404 // Now A and B run. Both were accepted before the cancel and must
405 // take the cancel-on-entry path.
406 ra, err := sa.Run(t.Context(), SessionAgentCall{
407 SessionID: sess.ID,
408 Prompt: "A",
409 Accepted: acceptA,
410 })
411 require.NoError(t, err)
412 require.Nil(t, ra, "A must cancel on entry, not run")
413
414 rb, err := sa.Run(t.Context(), SessionAgentCall{
415 SessionID: sess.ID,
416 Prompt: "B",
417 Accepted: acceptB,
418 })
419 require.NoError(t, err)
420 require.Nil(t, rb, "B must cancel on entry, not run")
421
422 require.False(t, sa.hasPendingCancel(sess.ID),
423 "the mark clears once all covered handles are resolved")
424 require.Equal(t, 0, sa.acceptedCount(sess.ID))
425 require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
426 "neither A nor B may be enqueued to run normally")
427
428 // C produced a normal turn; A and B each produced a canceled turn.
429 msgs, err := env.messages.List(t.Context(), sess.ID)
430 require.NoError(t, err)
431 require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant")
432
433 var normal, canceled int
434 for _, m := range msgs {
435 if m.Role != message.Assistant {
436 continue
437 }
438 switch m.FinishReason() {
439 case message.FinishReasonEndTurn:
440 normal++
441 case message.FinishReasonCanceled:
442 canceled++
443 }
444 }
445 require.Equal(t, 1, normal, "only C finished normally")
446 require.Equal(t, 2, canceled, "both A and B finished canceled")
447}