1package agent
2
3import (
4 "context"
5 "errors"
6 "sync/atomic"
7 "testing"
8
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/message"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
15// finishStreamModel is a minimal fantasy.LanguageModel that streams a
16// single text part followed by a normal (FinishReasonStop) finish. It
17// is enough to drive sessionAgent.Run through PrepareStep and a clean
18// completion without a recorded provider cassette.
19type finishStreamModel struct {
20 text string
21}
22
23func (m *finishStreamModel) Provider() string { return "fake" }
24func (m *finishStreamModel) Model() string { return "fake-model" }
25
26func (m *finishStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
27 return &fantasy.Response{
28 Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
29 FinishReason: fantasy.FinishReasonStop,
30 }, nil
31}
32
33func (m *finishStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
34 text := m.text
35 return func(yield func(fantasy.StreamPart) bool) {
36 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
37 return
38 }
39 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
40 return
41 }
42 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
43 return
44 }
45 yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
46 }, nil
47}
48
49func (m *finishStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
50 return nil, errors.New("not implemented")
51}
52
53func (m *finishStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
54 return nil, errors.New("not implemented")
55}
56
57func newStreamTestAgent(t *testing.T) (*sessionAgent, fakeEnv) {
58 t.Helper()
59 env := testEnv(t)
60 model := &finishStreamModel{text: "done"}
61 sa := testSessionAgent(env, model, model, "system").(*sessionAgent)
62 return sa, env
63}
64
65// TestCancel_ActiveAndAcceptedFiresBothBranches covers the case where a
66// session is actively running (activeRequests set) AND a follow-up has
67// been accepted (acceptedRuns > 0). A single Cancel must fire both: it
68// invokes the active cancel func and records a pending cancel for the
69// accepted follow-up.
70func TestCancel_ActiveAndAcceptedFiresBothBranches(t *testing.T) {
71 t.Parallel()
72 sa, _ := newCancelTestAgent(t)
73
74 const sid = "sid"
75 var activeCanceled atomic.Bool
76 sa.activeRequests.Set(sid, func() { activeCanceled.Store(true) })
77
78 accept := sa.BeginAccepted(sid)
79 defer accept.Close()
80
81 sa.Cancel(sid)
82
83 require.True(t, activeCanceled.Load(), "active cancel func must fire")
84 require.True(t, sa.hasPendingCancel(sid), "accepted follow-up must record a pending cancel")
85}
86
87// TestRun_BusyWithPendingCancelTakesCancelOnEntry covers the busy-queue
88// branch consulting pendingCancels: when the session is busy AND a
89// cancel has been recorded for an accepted follow-up, Run must take the
90// cancel-on-entry path (persist a canceled turn) instead of enqueueing
91// the call behind the active run.
92func TestRun_BusyWithPendingCancelTakesCancelOnEntry(t *testing.T) {
93 t.Parallel()
94 sa, env := newCancelTestAgent(t)
95
96 sess, err := env.sessions.Create(t.Context(), "session")
97 require.NoError(t, err)
98
99 // Make the session look busy: an earlier prompt is active.
100 sa.activeRequests.Set(sess.ID, func() {})
101
102 accept := sa.BeginAccepted(sess.ID)
103 // A cancel arrives while this follow-up is accepted-but-not-active.
104 sa.Cancel(sess.ID)
105 require.True(t, sa.hasPendingCancel(sess.ID))
106
107 result, err := sa.Run(t.Context(), SessionAgentCall{
108 SessionID: sess.ID,
109 Prompt: "follow-up",
110 Accepted: accept,
111 })
112 require.NoError(t, err)
113 require.Nil(t, result)
114
115 // The follow-up was canceled on entry, not enqueued.
116 require.Equal(t, 0, sa.QueuedPrompts(sess.ID),
117 "cancel-on-entry must not enqueue the follow-up behind the active run")
118 require.False(t, sa.hasPendingCancel(sess.ID), "pending cancel must be consumed")
119 require.Equal(t, 0, sa.acceptedCount(sess.ID), "accept reservation must be released")
120
121 msgs, err := env.messages.List(t.Context(), sess.ID)
122 require.NoError(t, err)
123 require.Len(t, msgs, 2)
124 assert.Equal(t, message.User, msgs[0].Role)
125 assert.Equal(t, message.Assistant, msgs[1].Role)
126 assert.Equal(t, message.FinishReasonCanceled, msgs[1].FinishReason())
127}
128
129// TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel verifies that the
130// queue drain inside PrepareStep skips queued follow-up prompts when a
131// cancel has been recorded for the session: the queued prompt must not
132// be folded into the active turn as an extra user message.
133func TestRun_PrepareStepDrainSkipsQueuedOnPendingCancel(t *testing.T) {
134 t.Parallel()
135 sa, env := newStreamTestAgent(t)
136
137 sess, err := env.sessions.Create(t.Context(), "session")
138 require.NoError(t, err)
139
140 // A follow-up prompt sits queued for the session.
141 sa.enqueueCall(SessionAgentCall{SessionID: sess.ID, Prompt: "queued-followup"})
142 // A cancel was recorded for the session while it sat in the queue.
143 sa.pendingCancels.Set(sess.ID, struct{}{})
144
145 result, err := sa.Run(t.Context(), SessionAgentCall{
146 SessionID: sess.ID,
147 Prompt: "main",
148 })
149 require.NoError(t, err)
150 require.NotNil(t, result)
151
152 // Only the main prompt produced a user message; the queued
153 // follow-up was skipped, not folded into the turn.
154 msgs, err := env.messages.List(t.Context(), sess.ID)
155 require.NoError(t, err)
156 var userMsgs []message.Message
157 for _, m := range msgs {
158 if m.Role == message.User {
159 userMsgs = append(userMsgs, m)
160 }
161 }
162 require.Len(t, userMsgs, 1, "queued follow-up must not create a user message")
163 assert.Equal(t, "main", userMsgs[0].Content().String())
164
165 // The queue was drained and the pending cancel consumed.
166 require.Equal(t, 0, sa.QueuedPrompts(sess.ID))
167 require.False(t, sa.hasPendingCancel(sess.ID))
168}
169
170// TestRun_NormalCompletionClearsStalePendingCancel verifies that a Run
171// which completes normally clears any stale pending-cancel entry for the
172// session, so it cannot catch a future run.
173func TestRun_NormalCompletionClearsStalePendingCancel(t *testing.T) {
174 t.Parallel()
175 sa, env := newStreamTestAgent(t)
176
177 sess, err := env.sessions.Create(t.Context(), "session")
178 require.NoError(t, err)
179
180 // A stale pending cancel lingers (no queued work, no accepted run).
181 sa.pendingCancels.Set(sess.ID, struct{}{})
182
183 result, err := sa.Run(t.Context(), SessionAgentCall{
184 SessionID: sess.ID,
185 Prompt: "main",
186 })
187 require.NoError(t, err)
188 require.NotNil(t, result)
189
190 require.False(t, sa.hasPendingCancel(sess.ID),
191 "normal completion must clear the stale pending cancel")
192
193 msgs, err := env.messages.List(t.Context(), sess.ID)
194 require.NoError(t, err)
195 require.Len(t, msgs, 2)
196 assert.Equal(t, message.Assistant, msgs[1].Role)
197 assert.Equal(t, message.FinishReasonEndTurn, msgs[1].FinishReason())
198}