accepted_run_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/message"
  8	"github.com/stretchr/testify/assert"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12// newCancelTestAgent builds a DB-backed sessionAgent with no model. The
 13// tests here exercise the dispatch/cancel/persist paths, none of which
 14// reach agent.Stream, so a model is unnecessary.
 15func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
 16	t.Helper()
 17	env := testEnv(t)
 18	sa := NewSessionAgent(SessionAgentOptions{
 19		Sessions: env.sessions,
 20		Messages: env.messages,
 21	}).(*sessionAgent)
 22	return sa, env
 23}
 24
 25func (a *sessionAgent) acceptedCount(sessionID string) int {
 26	c, _ := a.acceptedRuns.Get(sessionID)
 27	return c
 28}
 29
 30func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
 31	_, ok := a.pendingCancels.Get(sessionID)
 32	return ok
 33}
 34
 35func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
 36	t.Parallel()
 37	sa, _ := newCancelTestAgent(t)
 38
 39	accept := sa.BeginAccepted("sid")
 40	require.Equal(t, "sid", accept.SessionID())
 41	require.Equal(t, 1, sa.acceptedCount("sid"))
 42
 43	accept.Close()
 44	require.Equal(t, 0, sa.acceptedCount("sid"))
 45
 46	// Repeated Close must not underflow the counter.
 47	accept.Close()
 48	accept.Close()
 49	require.Equal(t, 0, sa.acceptedCount("sid"))
 50}
 51
 52func TestAcceptedRun_MultipleReservations(t *testing.T) {
 53	t.Parallel()
 54	sa, _ := newCancelTestAgent(t)
 55
 56	a1 := sa.BeginAccepted("sid")
 57	a2 := sa.BeginAccepted("sid")
 58	require.Equal(t, 2, sa.acceptedCount("sid"))
 59
 60	a1.Close()
 61	require.Equal(t, 1, sa.acceptedCount("sid"))
 62
 63	a2.Close()
 64	require.Equal(t, 0, sa.acceptedCount("sid"))
 65}
 66
 67func TestAcceptedRun_NilSafe(t *testing.T) {
 68	t.Parallel()
 69	var accept *AcceptedRun
 70	require.Equal(t, "", accept.SessionID())
 71	// Must not panic.
 72	accept.Close()
 73}
 74
 75func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
 76	t.Parallel()
 77	sa, _ := newCancelTestAgent(t)
 78
 79	// No accepted run, no active request: a true no-op.
 80	sa.Cancel("sid")
 81	require.False(t, sa.hasPendingCancel("sid"))
 82}
 83
 84func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
 85	t.Parallel()
 86	sa, _ := newCancelTestAgent(t)
 87
 88	accept := sa.BeginAccepted("sid")
 89	defer accept.Close()
 90
 91	sa.Cancel("sid")
 92	require.True(t, sa.hasPendingCancel("sid"))
 93}
 94
 95func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
 96	t.Parallel()
 97	sa, _ := newCancelTestAgent(t)
 98
 99	accept := sa.BeginAccepted("sid")
100	defer accept.Close()
101
102	sa.Cancel("sid")
103	require.True(t, sa.hasPendingCancel("sid"))
104
105	// A second cancel while a pending cancel is already recorded must
106	// remain a single pending cancel; one Run consumes exactly one.
107	sa.Cancel("sid")
108	require.True(t, sa.hasPendingCancel("sid"))
109}
110
111func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
112	t.Parallel()
113	sa, env := newCancelTestAgent(t)
114
115	sess, err := env.sessions.Create(t.Context(), "session")
116	require.NoError(t, err)
117
118	accept := sa.BeginAccepted(sess.ID)
119	// A cancel arrives in the accepted-but-not-yet-active window.
120	sa.Cancel(sess.ID)
121	require.True(t, sa.hasPendingCancel(sess.ID))
122
123	result, err := sa.Run(t.Context(), SessionAgentCall{
124		SessionID: sess.ID,
125		Prompt:    "hello",
126		Accepted:  accept,
127	})
128	require.NoError(t, err)
129	require.Nil(t, result)
130
131	// The pending cancel was consumed and the accept released.
132	require.False(t, sa.hasPendingCancel(sess.ID))
133	require.Equal(t, 0, sa.acceptedCount(sess.ID))
134
135	msgs, err := env.messages.List(t.Context(), sess.ID)
136	require.NoError(t, err)
137	require.Len(t, msgs, 2)
138	assert.Equal(t, message.User, msgs[0].Role)
139	assert.Equal(t, message.Assistant, msgs[1].Role)
140	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
141}
142
143func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
144	t.Parallel()
145	sa, env := newCancelTestAgent(t)
146
147	sess, err := env.sessions.Create(t.Context(), "session")
148	require.NoError(t, err)
149
150	err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
151		SessionID: sess.ID,
152		Prompt:    "hello",
153	}, false)
154	require.NoError(t, err)
155
156	msgs, err := env.messages.List(t.Context(), sess.ID)
157	require.NoError(t, err)
158	require.Len(t, msgs, 2)
159	assert.Equal(t, message.User, msgs[0].Role)
160	assert.Equal(t, message.Assistant, msgs[1].Role)
161	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
162}
163
164func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
165	t.Parallel()
166	sa, env := newCancelTestAgent(t)
167
168	sess, err := env.sessions.Create(t.Context(), "session")
169	require.NoError(t, err)
170
171	// Simulate PrepareStep having already created the user message.
172	_, err = sa.createUserMessage(t.Context(), SessionAgentCall{
173		SessionID: sess.ID,
174		Prompt:    "hello",
175	})
176	require.NoError(t, err)
177
178	err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
179		SessionID: sess.ID,
180		Prompt:    "hello",
181	}, true)
182	require.NoError(t, err)
183
184	msgs, err := env.messages.List(t.Context(), sess.ID)
185	require.NoError(t, err)
186	require.Len(t, msgs, 2)
187	assert.Equal(t, message.User, msgs[0].Role)
188	assert.Equal(t, message.Assistant, msgs[1].Role)
189	assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
190}
191
192func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
193	t.Parallel()
194	sa, env := newCancelTestAgent(t)
195
196	sess, err := env.sessions.Create(t.Context(), "session")
197	require.NoError(t, err)
198
199	// Simulate workspace shutdown having already canceled the run
200	// context. WithoutCancel must let the writes through.
201	ctx, cancel := context.WithCancel(t.Context())
202	cancel()
203
204	err = sa.persistCanceledTurn(ctx, SessionAgentCall{
205		SessionID: sess.ID,
206		Prompt:    "hello",
207	}, false)
208	require.NoError(t, err)
209
210	msgs, err := env.messages.List(t.Context(), sess.ID)
211	require.NoError(t, err)
212	require.Len(t, msgs, 2)
213}
214
215func TestClearPendingCancel(t *testing.T) {
216	t.Parallel()
217	sa, _ := newCancelTestAgent(t)
218
219	accept := sa.BeginAccepted("sid")
220	defer accept.Close()
221	sa.Cancel("sid")
222	require.True(t, sa.hasPendingCancel("sid"))
223
224	sa.clearPendingCancel("sid")
225	require.False(t, sa.hasPendingCancel("sid"))
226}