run_stream_test.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"errors"
  6	"testing"
  7	"time"
  8
  9	"github.com/charmbracelet/crush/internal/proto"
 10	"github.com/charmbracelet/crush/internal/pubsub"
 11	"github.com/stretchr/testify/require"
 12)
 13
 14// TestRunStream_ToolUseDoesNotTerminate is the regression test for
 15// the original bug: a tool-call assistant message has a Finish part
 16// with reason=tool_use and used to terminate `crush run` early via
 17// the discarded `msg.IsFinished()` exit condition. With the new
 18// RunComplete-driven loop, tool_use finishes must keep the stream
 19// alive so the post-tool final text still reaches stdout.
 20func TestRunStream_ToolUseDoesNotTerminate(t *testing.T) {
 21	t.Parallel()
 22
 23	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
 24
 25	toolUse := proto.Message{
 26		ID:        "m1",
 27		SessionID: "S",
 28		Role:      proto.Assistant,
 29		Parts: []proto.ContentPart{
 30			proto.TextContent{Text: ""},
 31			proto.Finish{Reason: proto.FinishReasonToolUse, Time: time.Now().Unix()},
 32		},
 33	}
 34	done, err := s.handle(pubsub.Event[proto.Message]{Payload: toolUse}, nil)
 35	require.NoError(t, err)
 36	require.False(t, done, "tool_use finish must NOT terminate the run loop")
 37}
 38
 39// TestRunStream_RunCompleteExits verifies the happy path: streaming
 40// assistant text then RunComplete terminates with the full final
 41// text on stdout. Together with the tool_use test above this
 42// nails down the "tool use + final text" sequence that the original
 43// bug truncated.
 44func TestRunStream_RunCompleteExits(t *testing.T) {
 45	t.Parallel()
 46
 47	buf := &bytes.Buffer{}
 48	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
 49
 50	// Tool-use step.
 51	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
 52		ID: "m1", SessionID: "S", Role: proto.Assistant,
 53		Parts: []proto.ContentPart{
 54			proto.TextContent{Text: ""},
 55			proto.Finish{Reason: proto.FinishReasonToolUse},
 56		},
 57	}}, nil)
 58	require.NoError(t, err)
 59	require.False(t, done)
 60
 61	// Final assistant message stream.
 62	done, err = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
 63		ID: "m2", SessionID: "S", Role: proto.Assistant,
 64		Parts: []proto.ContentPart{
 65			proto.TextContent{Text: "VERDICT: APPROVED"},
 66			proto.Finish{Reason: proto.FinishReasonEndTurn},
 67		},
 68	}}, nil)
 69	require.NoError(t, err)
 70	require.False(t, done, "message finish (even end_turn) must not exit; RunComplete is the only terminal signal")
 71
 72	// RunComplete.
 73	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
 74		SessionID: "S",
 75		MessageID: "m2",
 76		Text:      "VERDICT: APPROVED",
 77	}}, nil)
 78	require.NoError(t, err)
 79	require.True(t, done)
 80	require.Equal(t, "VERDICT: APPROVED", buf.String())
 81}
 82
 83// TestRunStream_ReconcilesOnOutOfOrderRunComplete is the worst-case
 84// ordering scenario: RunComplete reaches the client BEFORE any of
 85// the streaming assistant message events for the turn (the pubsub
 86// fan-in across upstream brokers does not preserve cross-broker
 87// ordering). The embedded Text field must rescue stdout so the
 88// caller still sees the complete final text.
 89func TestRunStream_ReconcilesOnOutOfOrderRunComplete(t *testing.T) {
 90	t.Parallel()
 91
 92	buf := &bytes.Buffer{}
 93	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
 94
 95	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
 96		SessionID: "S",
 97		MessageID: "m2",
 98		Text:      "VERDICT: APPROVED",
 99	}}, nil)
100	require.NoError(t, err)
101	require.True(t, done)
102	require.Equal(t, "VERDICT: APPROVED", buf.String(),
103		"RunComplete must reconcile stdout when message events did not arrive in time")
104}
105
106// TestRunStream_ReconcilesPartialStream covers the realistic case
107// where some streaming output reached stdout before RunComplete
108// arrived: the reconciliation pass must append only the unread tail,
109// never duplicate the prefix.
110func TestRunStream_ReconcilesPartialStream(t *testing.T) {
111	t.Parallel()
112
113	buf := &bytes.Buffer{}
114	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
115
116	_, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
117		ID: "m2", SessionID: "S", Role: proto.Assistant,
118		Parts: []proto.ContentPart{proto.TextContent{Text: "VERDICT: "}},
119	}}, nil)
120	require.NoError(t, err)
121
122	_, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
123		SessionID: "S",
124		MessageID: "m2",
125		Text:      "VERDICT: APPROVED",
126	}}, nil)
127	require.NoError(t, err)
128	require.Equal(t, "VERDICT: APPROVED", buf.String())
129}
130
131// TestRunStream_IgnoresOtherSessions ensures multi-session
132// subscribers (e.g. a TUI watching workspace events while `crush
133// run` is in flight against the same workspace) do not cause
134// premature exit on RunComplete for a different session.
135func TestRunStream_IgnoresOtherSessions(t *testing.T) {
136	t.Parallel()
137
138	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
139	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
140		SessionID: "OTHER",
141		MessageID: "x",
142		Text:      "noise",
143	}}, nil)
144	require.NoError(t, err)
145	require.False(t, done)
146}
147
148// TestRunStream_ErrorRunComplete surfaces a failing run as a
149// non-nil error from `crush run` so shells and CI catch it via
150// exit status.
151func TestRunStream_ErrorRunComplete(t *testing.T) {
152	t.Parallel()
153
154	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
155	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
156		SessionID: "S",
157		Error:     "model temporarily unavailable",
158	}}, nil)
159	require.True(t, done)
160	require.Error(t, err)
161	require.Contains(t, err.Error(), "model temporarily unavailable")
162}
163
164// TestRunStream_CancelledRunCompleteIsClean ensures a cancelled
165// run (e.g. Ctrl+C while `crush run` waits) exits cleanly rather
166// than reporting the cancellation as a failure.
167func TestRunStream_CancelledRunCompleteIsClean(t *testing.T) {
168	t.Parallel()
169
170	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
171	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
172		SessionID: "S",
173		Error:     "context canceled",
174		Cancelled: true,
175	}}, nil)
176	require.True(t, done)
177	require.NoError(t, err)
178}
179
180// TestRunStream_LeadingWhitespaceTrimmedOnce mirrors the
181// pre-existing trim of leading whitespace on the first byte of
182// stdout: the trim must happen exactly once even when stdout is
183// first produced by the RunComplete reconciliation path rather
184// than the live stream.
185func TestRunStream_LeadingWhitespaceTrimmedOnce(t *testing.T) {
186	t.Parallel()
187
188	buf := &bytes.Buffer{}
189	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
190
191	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
192		SessionID: "S",
193		MessageID: "m2",
194		Text:      "  \tactual output",
195	}}, nil)
196	require.NoError(t, err)
197	require.True(t, done)
198	require.Equal(t, "actual output", buf.String())
199}
200
201// TestRunStream_StopSpinnerInvokedOnFirstOutput verifies the
202// spinner is stopped exactly when meaningful output starts (either
203// a streamed assistant message or the reconciliation tail). This
204// matches the prior behaviour and prevents the spinner from
205// painting over the final response on TTYs.
206func TestRunStream_StopSpinnerInvokedOnFirstOutput(t *testing.T) {
207	t.Parallel()
208
209	calls := 0
210	stop := func() { calls++ }
211	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
212	_, _ = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
213		ID: "m1", SessionID: "S", Role: proto.Assistant,
214		Parts: []proto.ContentPart{proto.TextContent{Text: "hi"}},
215	}}, stop)
216	require.GreaterOrEqual(t, calls, 1, "spinner must stop once stdout has content")
217}
218
219// TestRunStream_RunIDFiltersForeignTurns covers the busy-session
220// queue scenario: `crush run --continue` attaches to a session
221// whose currently running turn finishes first, publishing its
222// RunComplete on the same session ID. Without per-run correlation
223// the stream would exit on that foreign event and drop our own
224// queued turn's output. With RunID filtering the foreign event is
225// ignored and only the matching RunComplete terminates the stream.
226func TestRunStream_RunIDFiltersForeignTurns(t *testing.T) {
227	t.Parallel()
228
229	const sessionID = "S"
230	const myRun = "run-mine"
231	const otherRun = "run-other"
232
233	buf := &bytes.Buffer{}
234	s := &runStream{
235		sessionID: sessionID,
236		runID:     myRun,
237		out:       buf,
238		read:      map[string]int{},
239	}
240
241	// The busy session's existing turn emits more text before it finishes.
242	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
243		ID:        "other-msg",
244		SessionID: sessionID,
245		Role:      proto.Assistant,
246		Parts:     []proto.ContentPart{proto.TextContent{Text: "noise from another turn"}},
247	}}, nil)
248	require.NoError(t, err)
249	require.False(t, done,
250		"foreign message on same session must not terminate our run")
251	require.Empty(t, buf.String(),
252		"foreign message on same session must not write to our stdout")
253
254	// The busy session's existing turn finishes first.
255	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
256		SessionID: sessionID,
257		RunID:     otherRun,
258		MessageID: "other-msg",
259		Text:      "noise from another turn",
260	}}, nil)
261	require.NoError(t, err)
262	require.False(t, done,
263		"foreign RunComplete on same session must not terminate our run")
264	require.Empty(t, buf.String(),
265		"foreign RunComplete must not write to our stdout")
266
267	// Our own queued turn eventually finishes.
268	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
269		SessionID: sessionID,
270		RunID:     myRun,
271		MessageID: "my-msg",
272		Text:      "OK",
273	}}, nil)
274	require.NoError(t, err)
275	require.True(t, done, "matching RunID must terminate the stream")
276	require.Equal(t, "OK", buf.String())
277}
278
279func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) {
280	t.Parallel()
281
282	buf := &bytes.Buffer{}
283	s := &runStream{
284		sessionID: "S",
285		runID:     "run-mine",
286		out:       buf,
287		read:      map[string]int{},
288	}
289
290	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
291		ID:        "my-msg",
292		SessionID: "S",
293		Role:      proto.Assistant,
294		Parts:     []proto.ContentPart{proto.TextContent{Text: "streamed prefix"}},
295	}}, nil)
296	require.NoError(t, err)
297	require.False(t, done)
298	require.Empty(t, buf.String())
299
300	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
301		SessionID: "S",
302		RunID:     "run-mine",
303		MessageID: "my-msg",
304		Text:      "streamed prefix final",
305	}}, nil)
306	require.NoError(t, err)
307	require.True(t, done)
308	require.Equal(t, "streamed prefix final", buf.String())
309}
310
311// TestRunStream_AgentErrorRunIDFiltersForeign verifies that an async
312// agent error carrying a non-empty RunID is fatal only when it matches
313// our run. A foreign RunID is ignored regardless of the event's
314// SessionID, because RunID is the authoritative correlator and async
315// errors share the agent event channel: without strict RunID matching
316// an unrelated workspace failure would abort our run.
317func TestRunStream_AgentErrorRunIDFiltersForeign(t *testing.T) {
318	t.Parallel()
319
320	// Foreign RunID with a matching session is still foreign.
321	s := &runStream{sessionID: "S", runID: "run-mine", out: &bytes.Buffer{}, read: map[string]int{}}
322	done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
323		Type:      proto.AgentEventTypeError,
324		SessionID: "S",
325		RunID:     "run-other",
326		Error:     errors.New("foreign boom"),
327	}}, nil)
328	require.NoError(t, err, "foreign RunID error must not abort our run")
329	require.False(t, done)
330
331	// Foreign RunID with a different session is ignored.
332	done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
333		Type:      proto.AgentEventTypeError,
334		SessionID: "other",
335		RunID:     "run-other",
336		Error:     errors.New("foreign boom"),
337	}}, nil)
338	require.NoError(t, err, "foreign RunID error must not abort our run")
339	require.False(t, done)
340
341	// Foreign RunID with a missing session is ignored.
342	done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
343		Type:  proto.AgentEventTypeError,
344		RunID: "run-other",
345		Error: errors.New("foreign boom"),
346	}}, nil)
347	require.NoError(t, err, "foreign RunID error must not abort our run")
348	require.False(t, done)
349
350	// Matching RunID is fatal.
351	done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
352		Type:      proto.AgentEventTypeError,
353		SessionID: "S",
354		RunID:     "run-mine",
355		Error:     errors.New("my boom"),
356	}}, nil)
357	require.Error(t, err, "matching RunID error must be fatal")
358	require.True(t, done)
359}
360
361// TestRunStream_AgentErrorNoRunIDFiltersBySession verifies the
362// compatibility fallback: when the event carries no RunID, attribution
363// falls back to SessionID. An error for another session or with an
364// empty session is ignored, while an error for our own session is fatal
365// so a real failure is never dropped.
366func TestRunStream_AgentErrorNoRunIDFiltersBySession(t *testing.T) {
367	t.Parallel()
368
369	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
370
371	// Empty RunID for another session is ignored.
372	done, err := s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
373		Type:      proto.AgentEventTypeError,
374		SessionID: "other",
375		Error:     errors.New("foreign boom"),
376	}}, nil)
377	require.NoError(t, err, "error for another session must not abort our run")
378	require.False(t, done)
379
380	// Empty RunID with an empty session is ignored.
381	done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
382		Type:  proto.AgentEventTypeError,
383		Error: errors.New("foreign boom"),
384	}}, nil)
385	require.NoError(t, err, "error with no session must not abort our run")
386	require.False(t, done)
387
388	// Empty RunID with a matching session is fatal.
389	done, err = s.handle(pubsub.Event[proto.AgentEvent]{Payload: proto.AgentEvent{
390		Type:      proto.AgentEventTypeError,
391		SessionID: "S",
392		Error:     errors.New("my boom"),
393	}}, nil)
394	require.Error(t, err, "error for our own session must be fatal")
395	require.True(t, done)
396}
397
398// TestRunStream_NoRunIDFallsBackToSessionID preserves the older
399// behaviour for callers (and tests) that don't supply a RunID:
400// SessionID-only matching still terminates the stream on the
401// session's RunComplete. This keeps the contract backwards
402// compatible with servers that don't echo RunID and with the
403// pre-existing TestRunStream_* assertions.
404func TestRunStream_NoRunIDFallsBackToSessionID(t *testing.T) {
405	t.Parallel()
406
407	buf := &bytes.Buffer{}
408	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
409	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
410		SessionID: "S",
411		MessageID: "m2",
412		Text:      "DONE",
413	}}, nil)
414	require.NoError(t, err)
415	require.True(t, done)
416	require.Equal(t, "DONE", buf.String())
417}