1package cmd
2
3import (
4 "bytes"
5 "errors"
6 "testing"
7 "time"
8
9 "github.com/charmbracelet/crush/internal/proto"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/stretchr/testify/require"
12)
13
14// TestRunStream_ToolUseDoesNotTerminate is the regression test for
15// the original bug: a tool-call assistant message has a Finish part
16// with reason=tool_use and used to terminate `crush run` early via
17// the discarded `msg.IsFinished()` exit condition. With the new
18// RunComplete-driven loop, tool_use finishes must keep the stream
19// alive so the post-tool final text still reaches stdout.
20func TestRunStream_ToolUseDoesNotTerminate(t *testing.T) {
21 t.Parallel()
22
23 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
24
25 toolUse := proto.Message{
26 ID: "m1",
27 SessionID: "S",
28 Role: proto.Assistant,
29 Parts: []proto.ContentPart{
30 proto.TextContent{Text: ""},
31 proto.Finish{Reason: proto.FinishReasonToolUse, Time: time.Now().Unix()},
32 },
33 }
34 done, err := s.handle(pubsub.Event[proto.Message]{Payload: toolUse}, nil)
35 require.NoError(t, err)
36 require.False(t, done, "tool_use finish must NOT terminate the run loop")
37}
38
39// TestRunStream_RunCompleteExits verifies the happy path: streaming
40// assistant text then RunComplete terminates with the full final
41// text on stdout. Together with the tool_use test above this
42// nails down the "tool use + final text" sequence that the original
43// bug truncated.
44func TestRunStream_RunCompleteExits(t *testing.T) {
45 t.Parallel()
46
47 buf := &bytes.Buffer{}
48 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
49
50 // Tool-use step.
51 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
52 ID: "m1", SessionID: "S", Role: proto.Assistant,
53 Parts: []proto.ContentPart{
54 proto.TextContent{Text: ""},
55 proto.Finish{Reason: proto.FinishReasonToolUse},
56 },
57 }}, nil)
58 require.NoError(t, err)
59 require.False(t, done)
60
61 // Final assistant message stream.
62 done, err = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
63 ID: "m2", SessionID: "S", Role: proto.Assistant,
64 Parts: []proto.ContentPart{
65 proto.TextContent{Text: "VERDICT: APPROVED"},
66 proto.Finish{Reason: proto.FinishReasonEndTurn},
67 },
68 }}, nil)
69 require.NoError(t, err)
70 require.False(t, done, "message finish (even end_turn) must not exit; RunComplete is the only terminal signal")
71
72 // RunComplete.
73 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
74 SessionID: "S",
75 MessageID: "m2",
76 Text: "VERDICT: APPROVED",
77 }}, nil)
78 require.NoError(t, err)
79 require.True(t, done)
80 require.Equal(t, "VERDICT: APPROVED", buf.String())
81}
82
83// TestRunStream_ReconcilesOnOutOfOrderRunComplete is the worst-case
84// ordering scenario: RunComplete reaches the client BEFORE any of
85// the streaming assistant message events for the turn (the pubsub
86// fan-in across upstream brokers does not preserve cross-broker
87// ordering). The embedded Text field must rescue stdout so the
88// caller still sees the complete final text.
89func TestRunStream_ReconcilesOnOutOfOrderRunComplete(t *testing.T) {
90 t.Parallel()
91
92 buf := &bytes.Buffer{}
93 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
94
95 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
96 SessionID: "S",
97 MessageID: "m2",
98 Text: "VERDICT: APPROVED",
99 }}, nil)
100 require.NoError(t, err)
101 require.True(t, done)
102 require.Equal(t, "VERDICT: APPROVED", buf.String(),
103 "RunComplete must reconcile stdout when message events did not arrive in time")
104}
105
106// TestRunStream_ReconcilesPartialStream covers the realistic case
107// where some streaming output reached stdout before RunComplete
108// arrived: the reconciliation pass must append only the unread tail,
109// never duplicate the prefix.
110func TestRunStream_ReconcilesPartialStream(t *testing.T) {
111 t.Parallel()
112
113 buf := &bytes.Buffer{}
114 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
115
116 _, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
117 ID: "m2", SessionID: "S", Role: proto.Assistant,
118 Parts: []proto.ContentPart{proto.TextContent{Text: "VERDICT: "}},
119 }}, nil)
120 require.NoError(t, err)
121
122 _, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
123 SessionID: "S",
124 MessageID: "m2",
125 Text: "VERDICT: APPROVED",
126 }}, nil)
127 require.NoError(t, err)
128 require.Equal(t, "VERDICT: APPROVED", buf.String())
129}
130
131// TestRunStream_IgnoresOtherSessions ensures multi-session
132// subscribers (e.g. a TUI watching workspace events while `crush
133// run` is in flight against the same workspace) do not cause
134// premature exit on RunComplete for a different session.
135func TestRunStream_IgnoresOtherSessions(t *testing.T) {
136 t.Parallel()
137
138 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
139 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
140 SessionID: "OTHER",
141 MessageID: "x",
142 Text: "noise",
143 }}, nil)
144 require.NoError(t, err)
145 require.False(t, done)
146}
147
148// TestRunStream_ErrorRunComplete surfaces a failing run as a
149// non-nil error from `crush run` so shells and CI catch it via
150// exit status.
151func TestRunStream_ErrorRunComplete(t *testing.T) {
152 t.Parallel()
153
154 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
155 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
156 SessionID: "S",
157 Error: "model temporarily unavailable",
158 }}, nil)
159 require.True(t, done)
160 require.Error(t, err)
161 require.Contains(t, err.Error(), "model temporarily unavailable")
162}
163
164// TestRunStream_CancelledRunCompleteIsClean ensures a cancelled
165// run (e.g. Ctrl+C while `crush run` waits) exits cleanly rather
166// than reporting the cancellation as a failure.
167func TestRunStream_CancelledRunCompleteIsClean(t *testing.T) {
168 t.Parallel()
169
170 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
171 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
172 SessionID: "S",
173 Error: "context canceled",
174 Cancelled: true,
175 }}, nil)
176 require.True(t, done)
177 require.NoError(t, err)
178}
179
180// TestRunStream_LeadingWhitespaceTrimmedOnce mirrors the
181// pre-existing trim of leading whitespace on the first byte of
182// stdout: the trim must happen exactly once even when stdout is
183// first produced by the RunComplete reconciliation path rather
184// than the live stream.
185func TestRunStream_LeadingWhitespaceTrimmedOnce(t *testing.T) {
186 t.Parallel()
187
188 buf := &bytes.Buffer{}
189 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
190
191 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
192 SessionID: "S",
193 MessageID: "m2",
194 Text: " \tactual output",
195 }}, nil)
196 require.NoError(t, err)
197 require.True(t, done)
198 require.Equal(t, "actual output", buf.String())
199}
200
201// TestRunStream_StopSpinnerInvokedOnFirstOutput verifies the
202// spinner is stopped exactly when meaningful output starts (either
203// a streamed assistant message or the reconciliation tail). This
204// matches the prior behaviour and prevents the spinner from
205// painting over the final response on TTYs.
206func TestRunStream_StopSpinnerInvokedOnFirstOutput(t *testing.T) {
207 t.Parallel()
208
209 calls := 0
210 stop := func() { calls++ }
211 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
212 _, _ = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
213 ID: "m1", SessionID: "S", Role: proto.Assistant,
214 Parts: []proto.ContentPart{proto.TextContent{Text: "hi"}},
215 }}, stop)
216 require.GreaterOrEqual(t, calls, 1, "spinner must stop once stdout has content")
217}
218
219// TestRunStream_RunIDFiltersForeignTurns covers the busy-session
220// queue scenario: `crush run --continue` attaches to a session
221// whose currently running turn finishes first, publishing its
222// RunComplete on the same session ID. Without per-run correlation
223// the stream would exit on that foreign event and drop our own
224// queued turn's output. With RunID filtering the foreign event is
225// ignored and only the matching RunComplete terminates the stream.
226func TestRunStream_RunIDFiltersForeignTurns(t *testing.T) {
227 t.Parallel()
228
229 const sessionID = "S"
230 const myRun = "run-mine"
231 const otherRun = "run-other"
232
233 buf := &bytes.Buffer{}
234 s := &runStream{
235 sessionID: sessionID,
236 runID: myRun,
237 out: buf,
238 read: map[string]int{},
239 }
240
241 // The busy session's existing turn emits more text before it finishes.
242 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
243 ID: "other-msg",
244 SessionID: sessionID,
245 Role: proto.Assistant,
246 Parts: []proto.ContentPart{proto.TextContent{Text: "noise from another turn"}},
247 }}, nil)
248 require.NoError(t, err)
249 require.False(t, done,
250 "foreign message on same session must not terminate our run")
251 require.Empty(t, buf.String(),
252 "foreign message on same session must not write to our stdout")
253
254 // The busy session's existing turn finishes first.
255 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
256 SessionID: sessionID,
257 RunID: otherRun,
258 MessageID: "other-msg",
259 Text: "noise from another turn",
260 }}, nil)
261 require.NoError(t, err)
262 require.False(t, done,
263 "foreign RunComplete on same session must not terminate our run")
264 require.Empty(t, buf.String(),
265 "foreign RunComplete must not write to our stdout")
266
267 // Our own queued turn eventually finishes.
268 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
269 SessionID: sessionID,
270 RunID: myRun,
271 MessageID: "my-msg",
272 Text: "OK",
273 }}, nil)
274 require.NoError(t, err)
275 require.True(t, done, "matching RunID must terminate the stream")
276 require.Equal(t, "OK", buf.String())
277}
278
279func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) {
280 t.Parallel()
281
282 buf := &bytes.Buffer{}
283 s := &runStream{
284 sessionID: "S",
285 runID: "run-mine",
286 out: buf,
287 read: map[string]int{},
288 }
289
290 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
291 ID: "my-msg",
292 SessionID: "S",
293 Role: proto.Assistant,
294 Parts: []proto.ContentPart{proto.TextContent{Text: "streamed prefix"}},
295 }}, nil)
296 require.NoError(t, err)
297 require.False(t, done)
298 require.Empty(t, buf.String())
299
300 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
301 SessionID: "S",
302 RunID: "run-mine",
303 MessageID: "my-msg",
304 Text: "streamed prefix final",
305 }}, nil)
306 require.NoError(t, err)
307 require.True(t, done)
308 require.Equal(t, "streamed prefix final", buf.String())
309}
310
311// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async
312// agent error carrying a non-empty RunID is fatal only when it matches
313// our run. A foreign RunID is ignored regardless of the event's
314// SessionID, because RunID is the authoritative correlator and async
315// errors share the agent event channel: without strict RunID matching
316// an unrelated workspace failure would abort our run.
317func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) {
318 t.Parallel()
319
320 // Foreign RunID with a matching session is still foreign.
321 s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}}
322 done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
323 Type: proto.AgentEventTypeError,
324 SessionID: "S",
325 RunID: "run-other",
326 Error: errors.New("foreign boom"),
327 }}, nil)
328 require.NoError(t, err, "foreign RunID error must not abort our run")
329 require.False(t, done)
330
331 // Foreign RunID with a different session is ignored.
332 done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
333 Type: proto.AgentEventTypeError,
334 SessionID: "other",
335 RunID: "run-other",
336 Error: errors.New("foreign boom"),
337 }}, nil)
338 require.NoError(t, err, "foreign RunID error must not abort our run")
339 require.False(t, done)
340
341 // Foreign RunID with a missing session is ignored.
342 done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
343 Type: proto.AgentEventTypeError,
344 RunID: "run-other",
345 Error: errors.New("foreign boom"),
346 }}, nil)
347 require.NoError(t, err, "foreign RunID error must not abort our run")
348 require.False(t, done)
349
350 // Matching RunID is fatal.
351 done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
352 Type: proto.AgentEventTypeError,
353 SessionID: "S",
354 RunID: "run-mine",
355 Error: errors.New("my boom"),
356 }}, nil)
357 require.Error(t, err, "matching RunID error must be fatal")
358 require.True(t, done)
359}
360
361// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the
362// compatibility fallback: when the event carries no RunID, attribution
363// falls back to SessionID. An error for another session or with an
364// empty session is ignored, while an error for our own session is fatal
365// so a real failure is never dropped.
366func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) {
367 t.Parallel()
368
369 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
370
371 // Empty RunID for another session is ignored.
372 done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
373 Type: proto.AgentEventTypeError,
374 SessionID: "other",
375 Error: errors.New("foreign boom"),
376 }}, nil)
377 require.NoError(t, err, "error for another session must not abort our run")
378 require.False(t, done)
379
380 // Empty RunID with an empty session is ignored.
381 done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
382 Type: proto.AgentEventTypeError,
383 Error: errors.New("foreign boom"),
384 }}, nil)
385 require.NoError(t, err, "error with no session must not abort our run")
386 require.False(t, done)
387
388 // Empty RunID with a matching session is fatal.
389 done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
390 Type: proto.AgentEventTypeError,
391 SessionID: "S",
392 Error: errors.New("my boom"),
393 }}, nil)
394 require.Error(t, err, "error for our own session must be fatal")
395 require.True(t, done)
396}
397
398// TestRunStream_NoRunIDFallsBackToSessionID preserves the older
399// behaviour for callers (and tests) that don't supply a RunID:
400// SessionID-only matching still terminates the stream on the
401// session's RunComplete. This keeps the contract backwards
402// compatible with servers that don't echo RunID and with the
403// pre-existing TestRunStream_* assertions.
404func TestRunStream_NoRunIDFallsBackToSessionID(t *testing.T) {
405 t.Parallel()
406
407 buf := &bytes.Buffer{}
408 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
409 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
410 SessionID: "S",
411 MessageID: "m2",
412 Text: "DONE",
413 }}, nil)
414 require.NoError(t, err)
415 require.True(t, done)
416 require.Equal(t, "DONE", buf.String())
417}