1package agent
2
3import (
4 "context"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/message"
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10)
11
12// newCancelTestAgent builds a DB-backed sessionAgent with no model. The
13// tests here exercise the dispatch/cancel/persist paths, none of which
14// reach agent.Stream, so a model is unnecessary.
15func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
16 t.Helper()
17 env := testEnv(t)
18 sa := NewSessionAgent(SessionAgentOptions{
19 Sessions: env.sessions,
20 Messages: env.messages,
21 }).(*sessionAgent)
22 return sa, env
23}
24
25func (a *sessionAgent) acceptedCount(sessionID string) int {
26 c, _ := a.acceptedRuns.Get(sessionID)
27 return c
28}
29
30func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
31 mark, ok := a.cancelMark.Get(sessionID)
32 return ok && mark > 0
33}
34
35func (a *sessionAgent) pendingCancelMark(sessionID string) uint64 {
36 mark, _ := a.cancelMark.Get(sessionID)
37 return mark
38}
39
40func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
41 t.Parallel()
42 sa, _ := newCancelTestAgent(t)
43
44 accept := sa.BeginAccepted("sid")
45 require.Equal(t, "sid", accept.SessionID())
46 require.Equal(t, 1, sa.acceptedCount("sid"))
47
48 accept.Close()
49 require.Equal(t, 0, sa.acceptedCount("sid"))
50
51 // Repeated Close must not underflow the counter.
52 accept.Close()
53 accept.Close()
54 require.Equal(t, 0, sa.acceptedCount("sid"))
55}
56
57func TestAcceptedRun_MultipleReservations(t *testing.T) {
58 t.Parallel()
59 sa, _ := newCancelTestAgent(t)
60
61 a1 := sa.BeginAccepted("sid")
62 a2 := sa.BeginAccepted("sid")
63 require.Equal(t, 2, sa.acceptedCount("sid"))
64
65 a1.Close()
66 require.Equal(t, 1, sa.acceptedCount("sid"))
67
68 a2.Close()
69 require.Equal(t, 0, sa.acceptedCount("sid"))
70}
71
72func TestAcceptedRun_NilSafe(t *testing.T) {
73 t.Parallel()
74 var accept *AcceptedRun
75 require.Equal(t, "", accept.SessionID())
76 // Must not panic.
77 accept.Close()
78}
79
80func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
81 t.Parallel()
82 sa, _ := newCancelTestAgent(t)
83
84 // No accepted run, no active request: a true no-op.
85 sa.Cancel("sid")
86 require.False(t, sa.hasPendingCancel("sid"))
87}
88
89func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
90 t.Parallel()
91 sa, _ := newCancelTestAgent(t)
92
93 accept := sa.BeginAccepted("sid")
94 defer accept.Close()
95
96 sa.Cancel("sid")
97 require.True(t, sa.hasPendingCancel("sid"))
98}
99
100func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
101 t.Parallel()
102 sa, _ := newCancelTestAgent(t)
103
104 accept := sa.BeginAccepted("sid")
105 defer accept.Close()
106
107 sa.Cancel("sid")
108 require.True(t, sa.hasPendingCancel("sid"))
109
110 // A second cancel while a pending cancel is already recorded must
111 // remain a single pending cancel; one Run consumes exactly one.
112 sa.Cancel("sid")
113 require.True(t, sa.hasPendingCancel("sid"))
114}
115
116func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
117 t.Parallel()
118 sa, env := newCancelTestAgent(t)
119
120 sess, err := env.sessions.Create(t.Context(), "session")
121 require.NoError(t, err)
122
123 accept := sa.BeginAccepted(sess.ID)
124 // A cancel arrives in the accepted-but-not-yet-active window.
125 sa.Cancel(sess.ID)
126 require.True(t, sa.hasPendingCancel(sess.ID))
127
128 result, err := sa.Run(t.Context(), SessionAgentCall{
129 SessionID: sess.ID,
130 Prompt: "hello",
131 Accepted: accept,
132 })
133 require.NoError(t, err)
134 require.Nil(t, result)
135
136 // The pending cancel was consumed and the accept released.
137 require.False(t, sa.hasPendingCancel(sess.ID))
138 require.Equal(t, 0, sa.acceptedCount(sess.ID))
139
140 msgs, err := env.messages.List(t.Context(), sess.ID)
141 require.NoError(t, err)
142 require.Len(t, msgs, 2)
143 assert.Equal(t, message.User, msgs[0].Role)
144 assert.Equal(t, message.Assistant, msgs[1].Role)
145 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
146}
147
148func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
149 t.Parallel()
150 sa, env := newCancelTestAgent(t)
151
152 sess, err := env.sessions.Create(t.Context(), "session")
153 require.NoError(t, err)
154
155 err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
156 SessionID: sess.ID,
157 Prompt: "hello",
158 }, false)
159 require.NoError(t, err)
160
161 msgs, err := env.messages.List(t.Context(), sess.ID)
162 require.NoError(t, err)
163 require.Len(t, msgs, 2)
164 assert.Equal(t, message.User, msgs[0].Role)
165 assert.Equal(t, message.Assistant, msgs[1].Role)
166 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
167}
168
169func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
170 t.Parallel()
171 sa, env := newCancelTestAgent(t)
172
173 sess, err := env.sessions.Create(t.Context(), "session")
174 require.NoError(t, err)
175
176 // Simulate PrepareStep having already created the user message.
177 _, err = sa.createUserMessage(t.Context(), SessionAgentCall{
178 SessionID: sess.ID,
179 Prompt: "hello",
180 })
181 require.NoError(t, err)
182
183 err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
184 SessionID: sess.ID,
185 Prompt: "hello",
186 }, true)
187 require.NoError(t, err)
188
189 msgs, err := env.messages.List(t.Context(), sess.ID)
190 require.NoError(t, err)
191 require.Len(t, msgs, 2)
192 assert.Equal(t, message.User, msgs[0].Role)
193 assert.Equal(t, message.Assistant, msgs[1].Role)
194 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
195}
196
197func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
198 t.Parallel()
199 sa, env := newCancelTestAgent(t)
200
201 sess, err := env.sessions.Create(t.Context(), "session")
202 require.NoError(t, err)
203
204 // Simulate workspace shutdown having already canceled the run
205 // context. WithoutCancel must let the writes through.
206 ctx, cancel := context.WithCancel(t.Context())
207 cancel()
208
209 err = sa.persistCanceledTurn(ctx, SessionAgentCall{
210 SessionID: sess.ID,
211 Prompt: "hello",
212 }, false)
213 require.NoError(t, err)
214
215 msgs, err := env.messages.List(t.Context(), sess.ID)
216 require.NoError(t, err)
217 require.Len(t, msgs, 2)
218}
219
220func TestClearPendingCancel(t *testing.T) {
221 t.Parallel()
222 sa, _ := newCancelTestAgent(t)
223
224 accept := sa.BeginAccepted("sid")
225 defer accept.Close()
226 sa.Cancel("sid")
227 require.True(t, sa.hasPendingCancel("sid"))
228
229 sa.clearPendingCancel("sid")
230 require.False(t, sa.hasPendingCancel("sid"))
231}