dispatch_cancel_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"sync/atomic"
  7	"testing"
  8	"time"
  9
 10	"charm.land/fantasy"
 11	"github.com/charmbracelet/crush/internal/agent/notify"
 12	"github.com/charmbracelet/crush/internal/message"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/stretchr/testify/assert"
 15	"github.com/stretchr/testify/require"
 16)
 17
 18// finishStreamModel is a minimal fantasy.LanguageModel that streams a
 19// single text part followed by a normal (FinishReasonStop) finish. It
 20// is enough to drive sessionAgent.Run through PrepareStep and a clean
 21// completion without a recorded provider cassette.
 22type finishStreamModel struct {
 23	text string
 24}
 25
 26func (m *finishStreamModel) Provider() string { return "fake" }
 27func (m *finishStreamModel) Model() string    { return "fake-model" }
 28
 29func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
 30	return &fantasy.Response{
 31		Content:      fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
 32		FinishReason: fantasy.FinishReasonStop,
 33	}, nil
 34}
 35
 36func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
 37	text := m.text
 38	return func(yield func(fantasy.StreamPart) bool) {
 39		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
 40			return
 41		}
 42		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
 43			return
 44		}
 45		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
 46			return
 47		}
 48		yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
 49	}, nil
 50}
 51
 52func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
 53	return nil, errors.New("not implemented")
 54}
 55
 56func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
 57	return nil, errors.New("not implemented")
 58}
 59
 60func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
 61	t.Helper()
 62	env := testEnv(t)
 63	model := &finishStreamModel{text: "done"}
 64	sa := testSessionAgent(env, model, model, "system").(*sessionAgent)
 65	return sa, env
 66}
 67
 68// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a
 69// session is actively running (activeRequests set) AND a follow-up has
 70// been accepted (acceptedRuns > 0). A single Cancel must fire both: it
 71// invokes the active cancel func and records a pending cancel for the
 72// accepted follow-up.
 73func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
 74	t.Parallel()
 75	sa, _ := newCancelTestAgent(t)
 76
 77	const sid = "sid"
 78	var activeCanceled atomic.Bool
 79	sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) })
 80
 81	accept := sa.BeginAccepted(sid)
 82	defer accept.Close()
 83
 84	sa.Cancel(sid)
 85
 86	require.True(t, activeCanceled.Load(), "active cancel func must fire")
 87	require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
 88}
 89
 90// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue
 91// branch consulting pendingCancels: when the session is busy AND a
 92// cancel has been recorded for an accepted follow-up, Run must take the
 93// cancel-on-entry path (persist a canceled turn) instead of enqueueing
 94// the call behind the active run.
 95func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
 96	t.Parallel()
 97	sa, env := newCancelTestAgent(t)
 98
 99	sess, err := env.sessions.Create(t.Context(), "session")
100	require.NoError(t, err)
101
102	// Make the session look busy: an earlier prompt is active.
103	sa.activeRequests.Set(sess.ID, func() {})
104
105	accept := sa.BeginAccepted(sess.ID)
106	// A cancel arrives while this follow-up is accepted-but-not-active.
107	sa.Cancel(sess.ID)
108	require.True(t, sa.hasPendingCancel(sess.ID))
109
110	result, err := sa.Run(t.Context(), SessionAgentCall{
111		SessionID: sess.ID,
112		Prompt:    "follow-up",
113		Accepted:  accept,
114	})
115	require.NoError(t, err)
116	require.Nil(t, result)
117
118	// The follow-up was canceled on entry, not enqueued.
119	require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
120		"cancel-on-entry must not enqueue the follow-up behind the active run")
121	require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed")
122	require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released")
123
124	msgs, err := env.messages.List(t.Context(), sess.ID)
125	require.NoError(t, err)
126	require.Len(t, msgs, 2)
127	assert.Equal(t, message.User, msgs[0].Role)
128	assert.Equal(t, message.Assistant, msgs[1].Role)
129	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
130}
131
132// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the
133// queue drain inside PrepareStep skips queued follow-up prompts when a
134// cancel has been recorded for the session: the queued prompt must not
135// be folded into the active turn as an extra user message.
136func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
137	t.Parallel()
138	sa, env := newStreamTestAgent(t)
139
140	sess, err := env.sessions.Create(t.Context(), "session")
141	require.NoError(t, err)
142
143	// A follow-up prompt sits queued for the session.
144	sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
145	// A cancel was recorded for the session while it sat in the queue.
146	sa.cancelMark.Set(sess.ID, 1)
147
148	result, err := sa.Run(t.Context(), SessionAgentCall{
149		SessionID: sess.ID,
150		Prompt:    "main",
151	})
152	require.NoError(t, err)
153	require.NotNil(t, result)
154
155	// Only the main prompt produced a user message; the queued
156	// follow-up was skipped, not folded into the turn.
157	msgs, err := env.messages.List(t.Context(), sess.ID)
158	require.NoError(t, err)
159	var userMsgs []message.Message
160	for _, m := range msgs {
161		if m.Role == message.User {
162			userMsgs = append(userMsgs, m)
163		}
164	}
165	require.Len(t, userMsgs, 1, "queued follow-up must not create a user message")
166	assert.Equal(t, "main", userMsgs[0].Content().String())
167
168	// The queue was drained and the pending cancel consumed.
169	require.Equal(t, 0, sa.QueuedPrompts(sess.ID))
170	require.False(t, sa.hasPendingCancel(sess.ID))
171}
172
173// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run
174// which completes normally clears any stale pending-cancel entry for the
175// session, so it cannot catch a future run.
176func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
177	t.Parallel()
178	sa, env := newStreamTestAgent(t)
179
180	sess, err := env.sessions.Create(t.Context(), "session")
181	require.NoError(t, err)
182
183	// A stale cancel mark lingers (no queued work, no accepted run).
184	sa.cancelMark.Set(sess.ID, 1)
185
186	result, err := sa.Run(t.Context(), SessionAgentCall{
187		SessionID: sess.ID,
188		Prompt:    "main",
189	})
190	require.NoError(t, err)
191	require.NotNil(t, result)
192
193	require.False(t, sa.hasPendingCancel(sess.ID),
194		"normal completion must clear the stale pending cancel")
195
196	msgs, err := env.messages.List(t.Context(), sess.ID)
197	require.NoError(t, err)
198	require.Len(t, msgs, 2)
199	assert.Equal(t, message.Assistant, msgs[1].Role)
200	assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
201}
202
203// newCancelTestAgentWithRunComplete builds a DB-backed sessionAgent wired
204// to a RunComplete broker so tests can observe the terminal event a
205// RunID-bearing caller (e.g. `crush run`) blocks on.
206func newCancelTestAgentWithRunComplete(t *testing.T) (*sessionAgent, fakeEnv, *pubsub.Broker[notify.RunComplete]) {
207	t.Helper()
208	env := testEnv(t)
209	broker := pubsub.NewBroker[notify.RunComplete]()
210	t.Cleanup(broker.Shutdown)
211	sa := NewSessionAgent(SessionAgentOptions{
212		Sessions:    env.sessions,
213		Messages:    env.messages,
214		RunComplete: broker,
215	}).(*sessionAgent)
216	return sa, env, broker
217}
218
219// TestRun_CancelOnEntryPublishesRunComplete covers the first review
220// finding: the cancel-on-entry path returned before the streaming defer
221// that publishes RunComplete was installed. A caller that dispatches a
222// run with a RunID and blocks on RunComplete (ignoring message events,
223// like `crush run`) would hang on an immediately-canceled accepted run.
224// The cancel-on-entry path must publish a terminal RunComplete carrying
225// the originating RunID.
226func TestRun_CancelOnEntryPublishesRunComplete(t *testing.T) {
227	t.Parallel()
228	sa, env, broker := newCancelTestAgentWithRunComplete(t)
229
230	sess, err := env.sessions.Create(t.Context(), "session")
231	require.NoError(t, err)
232
233	ctx, cancel := context.WithCancel(t.Context())
234	defer cancel()
235	ch := broker.Subscribe(ctx)
236
237	accept := sa.BeginAccepted(sess.ID)
238	// A cancel arrives in the accepted-but-not-yet-active window.
239	sa.Cancel(sess.ID)
240	require.True(t, sa.hasPendingCancel(sess.ID))
241
242	result, err := sa.Run(t.Context(), SessionAgentCall{
243		SessionID: sess.ID,
244		RunID:     "run-cancel-on-entry",
245		Prompt:    "hello",
246		Accepted:  accept,
247	})
248	require.NoError(t, err)
249	require.Nil(t, result)
250
251	select {
252	case got := <-ch:
253		assert.Equal(t, "run-cancel-on-entry", got.Payload.RunID,
254			"RunComplete must echo the originating RunID")
255		assert.Equal(t, sess.ID, got.Payload.SessionID)
256		assert.True(t, got.Payload.Cancelled,
257			"cancel-on-entry RunComplete must be marked cancelled")
258	case <-time.After(2 * time.Second):
259		t.Fatal("cancel-on-entry must publish RunComplete; a RunID caller would hang otherwise")
260	}
261}
262
263// TestCancel_TwoAcceptedBothObserveCancellation covers the second review
264// finding: a single cancel with two accepted-not-yet-active prompts must
265// cancel both. The cancel raises the session's high-water mark to the
266// latest accept sequence, so every prompt accepted-but-not-yet-active at
267// cancel time is covered and both take the cancel-on-entry path.
268func TestCancel_TwoAcceptedBothObserveCancellation(t *testing.T) {
269	t.Parallel()
270	sa, env := newCancelTestAgent(t)
271
272	sess, err := env.sessions.Create(t.Context(), "session")
273	require.NoError(t, err)
274
275	// Two prompts are accepted-but-not-yet-active for the same session.
276	accept1 := sa.BeginAccepted(sess.ID)
277	accept2 := sa.BeginAccepted(sess.ID)
278	require.Equal(t, 2, sa.acceptedCount(sess.ID))
279
280	// A single cancel arrives before either becomes active.
281	sa.Cancel(sess.ID)
282	require.Equal(t, accept2.seq, sa.pendingCancelMark(sess.ID),
283		"one cancel must mark every currently-accepted prompt as canceled")
284	require.GreaterOrEqual(t, sa.pendingCancelMark(sess.ID), accept1.seq,
285		"the mark must cover the earlier accepted prompt too")
286
287	// Both prompts enter Run; each must take cancel-on-entry, not run.
288	r1, err := sa.Run(t.Context(), SessionAgentCall{
289		SessionID: sess.ID,
290		Prompt:    "first",
291		Accepted:  accept1,
292	})
293	require.NoError(t, err)
294	require.Nil(t, r1)
295
296	r2, err := sa.Run(t.Context(), SessionAgentCall{
297		SessionID: sess.ID,
298		Prompt:    "second",
299		Accepted:  accept2,
300	})
301	require.NoError(t, err)
302	require.Nil(t, r2)
303
304	require.False(t, sa.hasPendingCancel(sess.ID),
305		"both reserved units must be consumed")
306	require.Equal(t, 0, sa.acceptedCount(sess.ID),
307		"both accept reservations must be released")
308
309	// Each canceled-on-entry turn writes a user + canceled assistant
310	// message, and neither prompt was enqueued to run normally.
311	require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
312		"neither accepted prompt may be enqueued to run normally")
313	msgs, err := env.messages.List(t.Context(), sess.ID)
314	require.NoError(t, err)
315	require.Len(t, msgs, 4, "two canceled turns produce two user + two assistant messages")
316	var canceled int
317	for _, m := range msgs {
318		if m.Role == message.Assistant {
319			assert.Equal(t, message.FinishReasonCanceled, m.FinishReason())
320			canceled++
321		}
322	}
323	require.Equal(t, 2, canceled, "both turns must finish canceled")
324}
325
326// TestRun_IdleCancelDoesNotPoisonNextPrompt covers the idle-cancel
327// no-op guarantee end-to-end: an Escape on an idle session must not
328// record a pending cancel that leaks into the next accepted prompt, which
329// must run normally to completion.
330func TestRun_IdleCancelDoesNotPoisonNextPrompt(t *testing.T) {
331	t.Parallel()
332	sa, env := newStreamTestAgent(t)
333
334	sess, err := env.sessions.Create(t.Context(), "session")
335	require.NoError(t, err)
336
337	// Idle Escape: no accepted run, no active request.
338	sa.Cancel(sess.ID)
339	require.False(t, sa.hasPendingCancel(sess.ID),
340		"idle cancel must not record a pending cancel")
341
342	// The next accepted prompt must run normally, not cancel on entry.
343	accept := sa.BeginAccepted(sess.ID)
344	result, err := sa.Run(t.Context(), SessionAgentCall{
345		SessionID: sess.ID,
346		Prompt:    "next",
347		Accepted:  accept,
348	})
349	require.NoError(t, err)
350	require.NotNil(t, result, "next prompt must run normally after an idle cancel")
351
352	msgs, err := env.messages.List(t.Context(), sess.ID)
353	require.NoError(t, err)
354	require.Len(t, msgs, 2)
355	assert.Equal(t, message.User, msgs[0].Role)
356	assert.Equal(t, message.Assistant, msgs[1].Role)
357	assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason(),
358		"the prompt must finish normally, not canceled")
359}
360
361// TestCancel_AcceptedAfterCancelIsNotPoisoned is the regression test for
362// the reviewer's finding: a counted session-level pending cancel let a
363// prompt accepted after the cancel enter Run first and consume a unit
364// reserved for the earlier prompts. With a sequence high-water mark, a
365// single cancel covers exactly the prompts accepted-but-not-yet-active at
366// cancel time (A and B); a prompt accepted afterwards (C) gets a higher
367// sequence and must run normally without consuming A or B's cancellation.
368// C is run first to prove it neither cancels nor drains the mark, then A
369// and B are run and must both cancel on entry.
370func TestCancel_AcceptedAfterCancelIsNotPoisoned(t *testing.T) {
371	t.Parallel()
372	sa, env := newStreamTestAgent(t)
373
374	sess, err := env.sessions.Create(t.Context(), "session")
375	require.NoError(t, err)
376
377	// A and B are accepted-but-not-yet-active.
378	acceptA := sa.BeginAccepted(sess.ID)
379	acceptB := sa.BeginAccepted(sess.ID)
380
381	// One cancel arrives covering both A and B.
382	sa.Cancel(sess.ID)
383	require.True(t, sa.hasPendingCancel(sess.ID))
384	require.Equal(t, acceptB.seq, sa.pendingCancelMark(sess.ID),
385		"the mark must cover every prompt accepted before the cancel")
386
387	// C is accepted AFTER the cancel; its sequence is above the mark.
388	acceptC := sa.BeginAccepted(sess.ID)
389	require.Greater(t, acceptC.seq, sa.pendingCancelMark(sess.ID),
390		"a prompt accepted after the cancel must not be covered by the mark")
391
392	// Run C first. It must run normally to completion and must not
393	// consume or clear the cancellation reserved for A and B.
394	rc, err := sa.Run(t.Context(), SessionAgentCall{
395		SessionID: sess.ID,
396		Prompt:    "C",
397		Accepted:  acceptC,
398	})
399	require.NoError(t, err)
400	require.NotNil(t, rc, "C was accepted after the cancel and must run normally")
401	require.True(t, sa.hasPendingCancel(sess.ID),
402		"running C must not drain the cancellation reserved for A and B")
403
404	// Now A and B run. Both were accepted before the cancel and must
405	// take the cancel-on-entry path.
406	ra, err := sa.Run(t.Context(), SessionAgentCall{
407		SessionID: sess.ID,
408		Prompt:    "A",
409		Accepted:  acceptA,
410	})
411	require.NoError(t, err)
412	require.Nil(t, ra, "A must cancel on entry, not run")
413
414	rb, err := sa.Run(t.Context(), SessionAgentCall{
415		SessionID: sess.ID,
416		Prompt:    "B",
417		Accepted:  acceptB,
418	})
419	require.NoError(t, err)
420	require.Nil(t, rb, "B must cancel on entry, not run")
421
422	require.False(t, sa.hasPendingCancel(sess.ID),
423		"the mark clears once all covered handles are resolved")
424	require.Equal(t, 0, sa.acceptedCount(sess.ID))
425	require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
426		"neither A nor B may be enqueued to run normally")
427
428	// C produced a normal turn; A and B each produced a canceled turn.
429	msgs, err := env.messages.List(t.Context(), sess.ID)
430	require.NoError(t, err)
431	require.Len(t, msgs, 6, "C normal + A canceled + B canceled = 3 user + 3 assistant")
432
433	var normal, canceled int
434	for _, m := range msgs {
435		if m.Role != message.Assistant {
436			continue
437		}
438		switch m.FinishReason() {
439		case message.FinishReasonEndTurn:
440			normal++
441		case message.FinishReasonCanceled:
442			canceled++
443		}
444	}
445	require.Equal(t, 1, normal, "only C finished normally")
446	require.Equal(t, 2, canceled, "both A and B finished canceled")
447}