dispatch_cancel_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"sync/atomic"
  7	"testing"
  8
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/message"
 11	"github.com/stretchr/testify/assert"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15// finishStreamModel is a minimal fantasy.LanguageModel that streams a
 16// single text part followed by a normal (FinishReasonStop) finish. It
 17// is enough to drive sessionAgent.Run through PrepareStep and a clean
 18// completion without a recorded provider cassette.
 19type finishStreamModel struct {
 20	text string
 21}
 22
 23func (m *finishStreamModel) Provider() string { return "fake" }
 24func (m *finishStreamModel) Model() string    { return "fake-model" }
 25
 26func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
 27	return &fantasy.Response{
 28		Content:      fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
 29		FinishReason: fantasy.FinishReasonStop,
 30	}, nil
 31}
 32
 33func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
 34	text := m.text
 35	return func(yield func(fantasy.StreamPart) bool) {
 36		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
 37			return
 38		}
 39		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
 40			return
 41		}
 42		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
 43			return
 44		}
 45		yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
 46	}, nil
 47}
 48
 49func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
 50	return nil, errors.New("not implemented")
 51}
 52
 53func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
 54	return nil, errors.New("not implemented")
 55}
 56
 57func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
 58	t.Helper()
 59	env := testEnv(t)
 60	model := &finishStreamModel{text: "done"}
 61	sa := testSessionAgent(env, model, model, "system").(*sessionAgent)
 62	return sa, env
 63}
 64
 65// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a
 66// session is actively running (activeRequests set) AND a follow-up has
 67// been accepted (acceptedRuns > 0). A single Cancel must fire both: it
 68// invokes the active cancel func and records a pending cancel for the
 69// accepted follow-up.
 70func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
 71	t.Parallel()
 72	sa, _ := newCancelTestAgent(t)
 73
 74	const sid = "sid"
 75	var activeCanceled atomic.Bool
 76	sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) })
 77
 78	accept := sa.BeginAccepted(sid)
 79	defer accept.Close()
 80
 81	sa.Cancel(sid)
 82
 83	require.True(t, activeCanceled.Load(), "active cancel func must fire")
 84	require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
 85}
 86
 87// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue
 88// branch consulting pendingCancels: when the session is busy AND a
 89// cancel has been recorded for an accepted follow-up, Run must take the
 90// cancel-on-entry path (persist a canceled turn) instead of enqueueing
 91// the call behind the active run.
 92func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
 93	t.Parallel()
 94	sa, env := newCancelTestAgent(t)
 95
 96	sess, err := env.sessions.Create(t.Context(), "session")
 97	require.NoError(t, err)
 98
 99	// Make the session look busy: an earlier prompt is active.
100	sa.activeRequests.Set(sess.ID, func() {})
101
102	accept := sa.BeginAccepted(sess.ID)
103	// A cancel arrives while this follow-up is accepted-but-not-active.
104	sa.Cancel(sess.ID)
105	require.True(t, sa.hasPendingCancel(sess.ID))
106
107	result, err := sa.Run(t.Context(), SessionAgentCall{
108		SessionID: sess.ID,
109		Prompt:    "follow-up",
110		Accepted:  accept,
111	})
112	require.NoError(t, err)
113	require.Nil(t, result)
114
115	// The follow-up was canceled on entry, not enqueued.
116	require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
117		"cancel-on-entry must not enqueue the follow-up behind the active run")
118	require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed")
119	require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released")
120
121	msgs, err := env.messages.List(t.Context(), sess.ID)
122	require.NoError(t, err)
123	require.Len(t, msgs, 2)
124	assert.Equal(t, message.User, msgs[0].Role)
125	assert.Equal(t, message.Assistant, msgs[1].Role)
126	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
127}
128
129// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the
130// queue drain inside PrepareStep skips queued follow-up prompts when a
131// cancel has been recorded for the session: the queued prompt must not
132// be folded into the active turn as an extra user message.
133func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
134	t.Parallel()
135	sa, env := newStreamTestAgent(t)
136
137	sess, err := env.sessions.Create(t.Context(), "session")
138	require.NoError(t, err)
139
140	// A follow-up prompt sits queued for the session.
141	sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
142	// A cancel was recorded for the session while it sat in the queue.
143	sa.pendingCancels.Set(sess.ID, struct{}{})
144
145	result, err := sa.Run(t.Context(), SessionAgentCall{
146		SessionID: sess.ID,
147		Prompt:    "main",
148	})
149	require.NoError(t, err)
150	require.NotNil(t, result)
151
152	// Only the main prompt produced a user message; the queued
153	// follow-up was skipped, not folded into the turn.
154	msgs, err := env.messages.List(t.Context(), sess.ID)
155	require.NoError(t, err)
156	var userMsgs []message.Message
157	for _, m := range msgs {
158		if m.Role == message.User {
159			userMsgs = append(userMsgs, m)
160		}
161	}
162	require.Len(t, userMsgs, 1, "queued follow-up must not create a user message")
163	assert.Equal(t, "main", userMsgs[0].Content().String())
164
165	// The queue was drained and the pending cancel consumed.
166	require.Equal(t, 0, sa.QueuedPrompts(sess.ID))
167	require.False(t, sa.hasPendingCancel(sess.ID))
168}
169
170// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run
171// which completes normally clears any stale pending-cancel entry for the
172// session, so it cannot catch a future run.
173func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
174	t.Parallel()
175	sa, env := newStreamTestAgent(t)
176
177	sess, err := env.sessions.Create(t.Context(), "session")
178	require.NoError(t, err)
179
180	// A stale pending cancel lingers (no queued work, no accepted run).
181	sa.pendingCancels.Set(sess.ID, struct{}{})
182
183	result, err := sa.Run(t.Context(), SessionAgentCall{
184		SessionID: sess.ID,
185		Prompt:    "main",
186	})
187	require.NoError(t, err)
188	require.NotNil(t, result)
189
190	require.False(t, sa.hasPendingCancel(sess.ID),
191		"normal completion must clear the stale pending cancel")
192
193	msgs, err := env.messages.List(t.Context(), sess.ID)
194	require.NoError(t, err)
195	require.Len(t, msgs, 2)
196	assert.Equal(t, message.Assistant, msgs[1].Role)
197	assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
198}