run_stream_test.go

  1package cmd
  2
  3import (
  4	"bytes"
  5	"testing"
  6	"time"
  7
  8	"github.com/charmbracelet/crush/internal/proto"
  9	"github.com/charmbracelet/crush/internal/pubsub"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13// TestRunStream_ToolUseDoesNotTerminate is the regression test for
 14// the original bug: a tool-call assistant message has a Finish part
 15// with reason=tool_use and used to terminate `crush run` early via
 16// the discarded `msg.IsFinished()` exit condition. With the new
 17// RunComplete-driven loop, tool_use finishes must keep the stream
 18// alive so the post-tool final text still reaches stdout.
 19func TestRunStream_ToolUseDoesNotTerminate(t *testing.T) {
 20	t.Parallel()
 21
 22	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
 23
 24	toolUse := proto.Message{
 25		ID:        "m1",
 26		SessionID: "S",
 27		Role:      proto.Assistant,
 28		Parts: []proto.ContentPart{
 29			proto.TextContent{Text: ""},
 30			proto.Finish{Reason: proto.FinishReasonToolUse, Time: time.Now().Unix()},
 31		},
 32	}
 33	done, err := s.handle(pubsub.Event[proto.Message]{Payload: toolUse}, nil)
 34	require.NoError(t, err)
 35	require.False(t, done, "tool_use finish must NOT terminate the run loop")
 36}
 37
 38// TestRunStream_RunCompleteExits verifies the happy path: streaming
 39// assistant text then RunComplete terminates with the full final
 40// text on stdout. Together with the tool_use test above this
 41// nails down the "tool use + final text" sequence that the original
 42// bug truncated.
 43func TestRunStream_RunCompleteExits(t *testing.T) {
 44	t.Parallel()
 45
 46	buf := &bytes.Buffer{}
 47	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
 48
 49	// Tool-use step.
 50	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
 51		ID: "m1", SessionID: "S", Role: proto.Assistant,
 52		Parts: []proto.ContentPart{
 53			proto.TextContent{Text: ""},
 54			proto.Finish{Reason: proto.FinishReasonToolUse},
 55		},
 56	}}, nil)
 57	require.NoError(t, err)
 58	require.False(t, done)
 59
 60	// Final assistant message stream.
 61	done, err = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
 62		ID: "m2", SessionID: "S", Role: proto.Assistant,
 63		Parts: []proto.ContentPart{
 64			proto.TextContent{Text: "VERDICT: APPROVED"},
 65			proto.Finish{Reason: proto.FinishReasonEndTurn},
 66		},
 67	}}, nil)
 68	require.NoError(t, err)
 69	require.False(t, done, "message finish (even end_turn) must not exit; RunComplete is the only terminal signal")
 70
 71	// RunComplete.
 72	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
 73		SessionID: "S",
 74		MessageID: "m2",
 75		Text:      "VERDICT: APPROVED",
 76	}}, nil)
 77	require.NoError(t, err)
 78	require.True(t, done)
 79	require.Equal(t, "VERDICT: APPROVED", buf.String())
 80}
 81
 82// TestRunStream_ReconcilesOnOutOfOrderRunComplete is the worst-case
 83// ordering scenario: RunComplete reaches the client BEFORE any of
 84// the streaming assistant message events for the turn (the pubsub
 85// fan-in across upstream brokers does not preserve cross-broker
 86// ordering). The embedded Text field must rescue stdout so the
 87// caller still sees the complete final text.
 88func TestRunStream_ReconcilesOnOutOfOrderRunComplete(t *testing.T) {
 89	t.Parallel()
 90
 91	buf := &bytes.Buffer{}
 92	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
 93
 94	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
 95		SessionID: "S",
 96		MessageID: "m2",
 97		Text:      "VERDICT: APPROVED",
 98	}}, nil)
 99	require.NoError(t, err)
100	require.True(t, done)
101	require.Equal(t, "VERDICT: APPROVED", buf.String(),
102		"RunComplete must reconcile stdout when message events did not arrive in time")
103}
104
105// TestRunStream_ReconcilesPartialStream covers the realistic case
106// where some streaming output reached stdout before RunComplete
107// arrived: the reconciliation pass must append only the unread tail,
108// never duplicate the prefix.
109func TestRunStream_ReconcilesPartialStream(t *testing.T) {
110	t.Parallel()
111
112	buf := &bytes.Buffer{}
113	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
114
115	_, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
116		ID: "m2", SessionID: "S", Role: proto.Assistant,
117		Parts: []proto.ContentPart{proto.TextContent{Text: "VERDICT: "}},
118	}}, nil)
119	require.NoError(t, err)
120
121	_, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
122		SessionID: "S",
123		MessageID: "m2",
124		Text:      "VERDICT: APPROVED",
125	}}, nil)
126	require.NoError(t, err)
127	require.Equal(t, "VERDICT: APPROVED", buf.String())
128}
129
130// TestRunStream_IgnoresOtherSessions ensures multi-session
131// subscribers (e.g. a TUI watching workspace events while `crush
132// run` is in flight against the same workspace) do not cause
133// premature exit on RunComplete for a different session.
134func TestRunStream_IgnoresOtherSessions(t *testing.T) {
135	t.Parallel()
136
137	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
138	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
139		SessionID: "OTHER",
140		MessageID: "x",
141		Text:      "noise",
142	}}, nil)
143	require.NoError(t, err)
144	require.False(t, done)
145}
146
147// TestRunStream_ErrorRunComplete surfaces a failing run as a
148// non-nil error from `crush run` so shells and CI catch it via
149// exit status.
150func TestRunStream_ErrorRunComplete(t *testing.T) {
151	t.Parallel()
152
153	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
154	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
155		SessionID: "S",
156		Error:     "model temporarily unavailable",
157	}}, nil)
158	require.True(t, done)
159	require.Error(t, err)
160	require.Contains(t, err.Error(), "model temporarily unavailable")
161}
162
163// TestRunStream_CancelledRunCompleteIsClean ensures a cancelled
164// run (e.g. Ctrl+C while `crush run` waits) exits cleanly rather
165// than reporting the cancellation as a failure.
166func TestRunStream_CancelledRunCompleteIsClean(t *testing.T) {
167	t.Parallel()
168
169	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
170	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
171		SessionID: "S",
172		Error:     "context canceled",
173		Cancelled: true,
174	}}, nil)
175	require.True(t, done)
176	require.NoError(t, err)
177}
178
179// TestRunStream_LeadingWhitespaceTrimmedOnce mirrors the
180// pre-existing trim of leading whitespace on the first byte of
181// stdout: the trim must happen exactly once even when stdout is
182// first produced by the RunComplete reconciliation path rather
183// than the live stream.
184func TestRunStream_LeadingWhitespaceTrimmedOnce(t *testing.T) {
185	t.Parallel()
186
187	buf := &bytes.Buffer{}
188	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
189
190	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
191		SessionID: "S",
192		MessageID: "m2",
193		Text:      "  \tactual output",
194	}}, nil)
195	require.NoError(t, err)
196	require.True(t, done)
197	require.Equal(t, "actual output", buf.String())
198}
199
200// TestRunStream_StopSpinnerInvokedOnFirstOutput verifies the
201// spinner is stopped exactly when meaningful output starts (either
202// a streamed assistant message or the reconciliation tail). This
203// matches the prior behaviour and prevents the spinner from
204// painting over the final response on TTYs.
205func TestRunStream_StopSpinnerInvokedOnFirstOutput(t *testing.T) {
206	t.Parallel()
207
208	calls := 0
209	stop := func() { calls++ }
210	s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
211	_, _ = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
212		ID: "m1", SessionID: "S", Role: proto.Assistant,
213		Parts: []proto.ContentPart{proto.TextContent{Text: "hi"}},
214	}}, stop)
215	require.GreaterOrEqual(t, calls, 1, "spinner must stop once stdout has content")
216}
217
218// TestRunStream_RunIDFiltersForeignTurns covers the busy-session
219// queue scenario: `crush run --continue` attaches to a session
220// whose currently running turn finishes first, publishing its
221// RunComplete on the same session ID. Without per-run correlation
222// the stream would exit on that foreign event and drop our own
223// queued turn's output. With RunID filtering the foreign event is
224// ignored and only the matching RunComplete terminates the stream.
225func TestRunStream_RunIDFiltersForeignTurns(t *testing.T) {
226	t.Parallel()
227
228	const sessionID = "S"
229	const myRun = "run-mine"
230	const otherRun = "run-other"
231
232	buf := &bytes.Buffer{}
233	s := &runStream{
234		sessionID: sessionID,
235		runID:     myRun,
236		out:       buf,
237		read:      map[string]int{},
238	}
239
240	// The busy session's existing turn emits more text before it finishes.
241	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
242		ID:        "other-msg",
243		SessionID: sessionID,
244		Role:      proto.Assistant,
245		Parts:     []proto.ContentPart{proto.TextContent{Text: "noise from another turn"}},
246	}}, nil)
247	require.NoError(t, err)
248	require.False(t, done,
249		"foreign message on same session must not terminate our run")
250	require.Empty(t, buf.String(),
251		"foreign message on same session must not write to our stdout")
252
253	// The busy session's existing turn finishes first.
254	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
255		SessionID: sessionID,
256		RunID:     otherRun,
257		MessageID: "other-msg",
258		Text:      "noise from another turn",
259	}}, nil)
260	require.NoError(t, err)
261	require.False(t, done,
262		"foreign RunComplete on same session must not terminate our run")
263	require.Empty(t, buf.String(),
264		"foreign RunComplete must not write to our stdout")
265
266	// Our own queued turn eventually finishes.
267	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
268		SessionID: sessionID,
269		RunID:     myRun,
270		MessageID: "my-msg",
271		Text:      "OK",
272	}}, nil)
273	require.NoError(t, err)
274	require.True(t, done, "matching RunID must terminate the stream")
275	require.Equal(t, "OK", buf.String())
276}
277
278func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) {
279	t.Parallel()
280
281	buf := &bytes.Buffer{}
282	s := &runStream{
283		sessionID: "S",
284		runID:     "run-mine",
285		out:       buf,
286		read:      map[string]int{},
287	}
288
289	done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
290		ID:        "my-msg",
291		SessionID: "S",
292		Role:      proto.Assistant,
293		Parts:     []proto.ContentPart{proto.TextContent{Text: "streamed prefix"}},
294	}}, nil)
295	require.NoError(t, err)
296	require.False(t, done)
297	require.Empty(t, buf.String())
298
299	done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
300		SessionID: "S",
301		RunID:     "run-mine",
302		MessageID: "my-msg",
303		Text:      "streamed prefix final",
304	}}, nil)
305	require.NoError(t, err)
306	require.True(t, done)
307	require.Equal(t, "streamed prefix final", buf.String())
308}
309
310// TestRunStream_NoRunIDFallsBackToSessionID preserves the older
311// behaviour for callers (and tests) that don't supply a RunID:
312// SessionID-only matching still terminates the stream on the
313// session's RunComplete. This keeps the contract backwards
314// compatible with servers that don't echo RunID and with the
315// pre-existing TestRunStream_* assertions.
316func TestRunStream_NoRunIDFallsBackToSessionID(t *testing.T) {
317	t.Parallel()
318
319	buf := &bytes.Buffer{}
320	s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
321	done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
322		SessionID: "S",
323		MessageID: "m2",
324		Text:      "DONE",
325	}}, nil)
326	require.NoError(t, err)
327	require.True(t, done)
328	require.Equal(t, "DONE", buf.String())
329}