accepted_run_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/message"
  8	"github.com/stretchr/testify/assert"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// newCancelTestAgent builds a DB-backed sessionAgent with no model. The
 13// tests here exercise the dispatch/cancel/persist paths, none of which
 14// reach agent.Stream, so a model is unnecessary.
 15func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
 16	t.Helper()
 17	env := testEnv(t)
 18	sa := NewSessionAgent(SessionAgentOptions{
 19		Sessions: env.sessions,
 20		Messages: env.messages,
 21	}).(*sessionAgent)
 22	return sa, env
 23}
 24
 25func (a *sessionAgent) acceptedCount(sessionID string) int {
 26	c, _ := a.acceptedRuns.Get(sessionID)
 27	return c
 28}
 29
 30func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
 31	mark, ok := a.cancelMark.Get(sessionID)
 32	return ok && mark > 0
 33}
 34
 35func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 {
 36	mark, _ := a.cancelMark.Get(sessionID)
 37	return mark
 38}
 39
 40func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
 41	t.Parallel()
 42	sa, _ := newCancelTestAgent(t)
 43
 44	accept := sa.BeginAccepted("sid")
 45	require.Equal(t, "sid", accept.SessionID())
 46	require.Equal(t, 1, sa.acceptedCount("sid"))
 47
 48	accept.Close()
 49	require.Equal(t, 0, sa.acceptedCount("sid"))
 50
 51	// Repeated Close must not underflow the counter.
 52	accept.Close()
 53	accept.Close()
 54	require.Equal(t, 0, sa.acceptedCount("sid"))
 55}
 56
 57func TestAcceptedRun_MultipleReservations(t *testing.T) {
 58	t.Parallel()
 59	sa, _ := newCancelTestAgent(t)
 60
 61	a1 := sa.BeginAccepted("sid")
 62	a2 := sa.BeginAccepted("sid")
 63	require.Equal(t, 2, sa.acceptedCount("sid"))
 64
 65	a1.Close()
 66	require.Equal(t, 1, sa.acceptedCount("sid"))
 67
 68	a2.Close()
 69	require.Equal(t, 0, sa.acceptedCount("sid"))
 70}
 71
 72func TestAcceptedRun_NilSafe(t *testing.T) {
 73	t.Parallel()
 74	var accept *AcceptedRun
 75	require.Equal(t, "", accept.SessionID())
 76	// Must not panic.
 77	accept.Close()
 78}
 79
 80func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
 81	t.Parallel()
 82	sa, _ := newCancelTestAgent(t)
 83
 84	// No accepted run, no active request: a true no-op.
 85	sa.Cancel("sid")
 86	require.False(t, sa.hasPendingCancel("sid"))
 87}
 88
 89func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
 90	t.Parallel()
 91	sa, _ := newCancelTestAgent(t)
 92
 93	accept := sa.BeginAccepted("sid")
 94	defer accept.Close()
 95
 96	sa.Cancel("sid")
 97	require.True(t, sa.hasPendingCancel("sid"))
 98}
 99
100func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
101	t.Parallel()
102	sa, _ := newCancelTestAgent(t)
103
104	accept := sa.BeginAccepted("sid")
105	defer accept.Close()
106
107	sa.Cancel("sid")
108	require.True(t, sa.hasPendingCancel("sid"))
109
110	// A second cancel while a pending cancel is already recorded must
111	// remain a single pending cancel; one Run consumes exactly one.
112	sa.Cancel("sid")
113	require.True(t, sa.hasPendingCancel("sid"))
114}
115
116func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
117	t.Parallel()
118	sa, env := newCancelTestAgent(t)
119
120	sess, err := env.sessions.Create(t.Context(), "session")
121	require.NoError(t, err)
122
123	accept := sa.BeginAccepted(sess.ID)
124	// A cancel arrives in the accepted-but-not-yet-active window.
125	sa.Cancel(sess.ID)
126	require.True(t, sa.hasPendingCancel(sess.ID))
127
128	result, err := sa.Run(t.Context(), SessionAgentCall{
129		SessionID: sess.ID,
130		Prompt:    "hello",
131		Accepted:  accept,
132	})
133	require.NoError(t, err)
134	require.Nil(t, result)
135
136	// The pending cancel was consumed and the accept released.
137	require.False(t, sa.hasPendingCancel(sess.ID))
138	require.Equal(t, 0, sa.acceptedCount(sess.ID))
139
140	msgs, err := env.messages.List(t.Context(), sess.ID)
141	require.NoError(t, err)
142	require.Len(t, msgs, 2)
143	assert.Equal(t, message.User, msgs[0].Role)
144	assert.Equal(t, message.Assistant, msgs[1].Role)
145	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
146}
147
148func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
149	t.Parallel()
150	sa, env := newCancelTestAgent(t)
151
152	sess, err := env.sessions.Create(t.Context(), "session")
153	require.NoError(t, err)
154
155	err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
156		SessionID: sess.ID,
157		Prompt:    "hello",
158	}, false)
159	require.NoError(t, err)
160
161	msgs, err := env.messages.List(t.Context(), sess.ID)
162	require.NoError(t, err)
163	require.Len(t, msgs, 2)
164	assert.Equal(t, message.User, msgs[0].Role)
165	assert.Equal(t, message.Assistant, msgs[1].Role)
166	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
167}
168
169func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
170	t.Parallel()
171	sa, env := newCancelTestAgent(t)
172
173	sess, err := env.sessions.Create(t.Context(), "session")
174	require.NoError(t, err)
175
176	// Simulate PrepareStep having already created the user message.
177	_, err = sa.createUserMessage(t.Context(), SessionAgentCall{
178		SessionID: sess.ID,
179		Prompt:    "hello",
180	})
181	require.NoError(t, err)
182
183	err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
184		SessionID: sess.ID,
185		Prompt:    "hello",
186	}, true)
187	require.NoError(t, err)
188
189	msgs, err := env.messages.List(t.Context(), sess.ID)
190	require.NoError(t, err)
191	require.Len(t, msgs, 2)
192	assert.Equal(t, message.User, msgs[0].Role)
193	assert.Equal(t, message.Assistant, msgs[1].Role)
194	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
195}
196
197func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
198	t.Parallel()
199	sa, env := newCancelTestAgent(t)
200
201	sess, err := env.sessions.Create(t.Context(), "session")
202	require.NoError(t, err)
203
204	// Simulate workspace shutdown having already canceled the run
205	// context. WithoutCancel must let the writes through.
206	ctx, cancel := context.WithCancel(t.Context())
207	cancel()
208
209	err = sa.persistCanceledTurn(ctx, SessionAgentCall{
210		SessionID: sess.ID,
211		Prompt:    "hello",
212	}, false)
213	require.NoError(t, err)
214
215	msgs, err := env.messages.List(t.Context(), sess.ID)
216	require.NoError(t, err)
217	require.Len(t, msgs, 2)
218}
219
220func TestClearPendingCancel(t *testing.T) {
221	t.Parallel()
222	sa, _ := newCancelTestAgent(t)
223
224	accept := sa.BeginAccepted("sid")
225	defer accept.Close()
226	sa.Cancel("sid")
227	require.True(t, sa.hasPendingCancel("sid"))
228
229	sa.clearPendingCancel("sid")
230	require.False(t, sa.hasPendingCancel("sid"))
231}