1package agent
2
3import (
4 "context"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/message"
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10)
11
12// newCancelTestAgent builds a DB-backed sessionAgent with no model. The
13// tests here exercise the dispatch/cancel/persist paths, none of which
14// reach agent.Stream, so a model is unnecessary.
15func newCancelTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
16 t.Helper()
17 env := testEnv(t)
18 sa := NewSessionAgent(SessionAgentOptions{
19 Sessions: env.sessions,
20 Messages: env.messages,
21 }).(*sessionAgent)
22 return sa, env
23}
24
25func (a *sessionAgent) acceptedCount(sessionID string) int {
26 c, _ := a.acceptedRuns.Get(sessionID)
27 return c
28}
29
30func (a *sessionAgent) hasPendingCancel(sessionID string) bool {
31 _, ok := a.pendingCancels.Get(sessionID)
32 return ok
33}
34
35func TestAcceptedRun_CloseIsIdempotent(t *testing.T) {
36 t.Parallel()
37 sa, _ := newCancelTestAgent(t)
38
39 accept := sa.BeginAccepted("sid")
40 require.Equal(t, "sid", accept.SessionID())
41 require.Equal(t, 1, sa.acceptedCount("sid"))
42
43 accept.Close()
44 require.Equal(t, 0, sa.acceptedCount("sid"))
45
46 // Repeated Close must not underflow the counter.
47 accept.Close()
48 accept.Close()
49 require.Equal(t, 0, sa.acceptedCount("sid"))
50}
51
52func TestAcceptedRun_MultipleReservations(t *testing.T) {
53 t.Parallel()
54 sa, _ := newCancelTestAgent(t)
55
56 a1 := sa.BeginAccepted("sid")
57 a2 := sa.BeginAccepted("sid")
58 require.Equal(t, 2, sa.acceptedCount("sid"))
59
60 a1.Close()
61 require.Equal(t, 1, sa.acceptedCount("sid"))
62
63 a2.Close()
64 require.Equal(t, 0, sa.acceptedCount("sid"))
65}
66
67func TestAcceptedRun_NilSafe(t *testing.T) {
68 t.Parallel()
69 var accept *AcceptedRun
70 require.Equal(t, "", accept.SessionID())
71 // Must not panic.
72 accept.Close()
73}
74
75func TestCancel_IdleDoesNotRecordPendingCancel(t *testing.T) {
76 t.Parallel()
77 sa, _ := newCancelTestAgent(t)
78
79 // No accepted run, no active request: a true no-op.
80 sa.Cancel("sid")
81 require.False(t, sa.hasPendingCancel("sid"))
82}
83
84func TestCancel_AcceptedRecordsPendingCancel(t *testing.T) {
85 t.Parallel()
86 sa, _ := newCancelTestAgent(t)
87
88 accept := sa.BeginAccepted("sid")
89 defer accept.Close()
90
91 sa.Cancel("sid")
92 require.True(t, sa.hasPendingCancel("sid"))
93}
94
95func TestCancel_SecondCancelWhilePendingIsNoOp(t *testing.T) {
96 t.Parallel()
97 sa, _ := newCancelTestAgent(t)
98
99 accept := sa.BeginAccepted("sid")
100 defer accept.Close()
101
102 sa.Cancel("sid")
103 require.True(t, sa.hasPendingCancel("sid"))
104
105 // A second cancel while a pending cancel is already recorded must
106 // remain a single pending cancel; one Run consumes exactly one.
107 sa.Cancel("sid")
108 require.True(t, sa.hasPendingCancel("sid"))
109}
110
111func TestRun_CancelOnEntryPersistsCanceledTurn(t *testing.T) {
112 t.Parallel()
113 sa, env := newCancelTestAgent(t)
114
115 sess, err := env.sessions.Create(t.Context(), "session")
116 require.NoError(t, err)
117
118 accept := sa.BeginAccepted(sess.ID)
119 // A cancel arrives in the accepted-but-not-yet-active window.
120 sa.Cancel(sess.ID)
121 require.True(t, sa.hasPendingCancel(sess.ID))
122
123 result, err := sa.Run(t.Context(), SessionAgentCall{
124 SessionID: sess.ID,
125 Prompt: "hello",
126 Accepted: accept,
127 })
128 require.NoError(t, err)
129 require.Nil(t, result)
130
131 // The pending cancel was consumed and the accept released.
132 require.False(t, sa.hasPendingCancel(sess.ID))
133 require.Equal(t, 0, sa.acceptedCount(sess.ID))
134
135 msgs, err := env.messages.List(t.Context(), sess.ID)
136 require.NoError(t, err)
137 require.Len(t, msgs, 2)
138 assert.Equal(t, message.User, msgs[0].Role)
139 assert.Equal(t, message.Assistant, msgs[1].Role)
140 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
141}
142
143func TestPersistCanceledTurn_WritesBothWhenUserMissing(t *testing.T) {
144 t.Parallel()
145 sa, env := newCancelTestAgent(t)
146
147 sess, err := env.sessions.Create(t.Context(), "session")
148 require.NoError(t, err)
149
150 err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
151 SessionID: sess.ID,
152 Prompt: "hello",
153 }, false)
154 require.NoError(t, err)
155
156 msgs, err := env.messages.List(t.Context(), sess.ID)
157 require.NoError(t, err)
158 require.Len(t, msgs, 2)
159 assert.Equal(t, message.User, msgs[0].Role)
160 assert.Equal(t, message.Assistant, msgs[1].Role)
161 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
162}
163
164func TestPersistCanceledTurn_WritesAssistantOnlyWhenUserCreated(t *testing.T) {
165 t.Parallel()
166 sa, env := newCancelTestAgent(t)
167
168 sess, err := env.sessions.Create(t.Context(), "session")
169 require.NoError(t, err)
170
171 // Simulate PrepareStep having already created the user message.
172 _, err = sa.createUserMessage(t.Context(), SessionAgentCall{
173 SessionID: sess.ID,
174 Prompt: "hello",
175 })
176 require.NoError(t, err)
177
178 err = sa.persistCanceledTurn(t.Context(), SessionAgentCall{
179 SessionID: sess.ID,
180 Prompt: "hello",
181 }, true)
182 require.NoError(t, err)
183
184 msgs, err := env.messages.List(t.Context(), sess.ID)
185 require.NoError(t, err)
186 require.Len(t, msgs, 2)
187 assert.Equal(t, message.User, msgs[0].Role)
188 assert.Equal(t, message.Assistant, msgs[1].Role)
189 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
190}
191
192func TestPersistCanceledTurn_SucceedsWithCanceledContext(t *testing.T) {
193 t.Parallel()
194 sa, env := newCancelTestAgent(t)
195
196 sess, err := env.sessions.Create(t.Context(), "session")
197 require.NoError(t, err)
198
199 // Simulate workspace shutdown having already canceled the run
200 // context. WithoutCancel must let the writes through.
201 ctx, cancel := context.WithCancel(t.Context())
202 cancel()
203
204 err = sa.persistCanceledTurn(ctx, SessionAgentCall{
205 SessionID: sess.ID,
206 Prompt: "hello",
207 }, false)
208 require.NoError(t, err)
209
210 msgs, err := env.messages.List(t.Context(), sess.ID)
211 require.NoError(t, err)
212 require.Len(t, msgs, 2)
213}
214
215func TestClearPendingCancel(t *testing.T) {
216 t.Parallel()
217 sa, _ := newCancelTestAgent(t)
218
219 accept := sa.BeginAccepted("sid")
220 defer accept.Close()
221 sa.Cancel("sid")
222 require.True(t, sa.hasPendingCancel("sid"))
223
224 sa.clearPendingCancel("sid")
225 require.False(t, sa.hasPendingCancel("sid"))
226}