1package cmd
2
3import (
4 "bytes"
5 "testing"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/proto"
9 "github.com/charmbracelet/crush/internal/pubsub"
10 "github.com/stretchr/testify/require"
11)
12
13// TestRunStream_ToolUseDoesNotTerminate is the regression test for
14// the original bug: a tool-call assistant message has a Finish part
15// with reason=tool_use and used to terminate `crush run` early via
16// the discarded `msg.IsFinished()` exit condition. With the new
17// RunComplete-driven loop, tool_use finishes must keep the stream
18// alive so the post-tool final text still reaches stdout.
19func TestRunStream_ToolUseDoesNotTerminate(t *testing.T) {
20 t.Parallel()
21
22 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
23
24 toolUse := proto.Message{
25 ID: "m1",
26 SessionID: "S",
27 Role: proto.Assistant,
28 Parts: []proto.ContentPart{
29 proto.TextContent{Text: ""},
30 proto.Finish{Reason: proto.FinishReasonToolUse, Time: time.Now().Unix()},
31 },
32 }
33 done, err := s.handle(pubsub.Event[proto.Message]{Payload: toolUse}, nil)
34 require.NoError(t, err)
35 require.False(t, done, "tool_use finish must NOT terminate the run loop")
36}
37
38// TestRunStream_RunCompleteExits verifies the happy path: streaming
39// assistant text then RunComplete terminates with the full final
40// text on stdout. Together with the tool_use test above this
41// nails down the "tool use + final text" sequence that the original
42// bug truncated.
43func TestRunStream_RunCompleteExits(t *testing.T) {
44 t.Parallel()
45
46 buf := &bytes.Buffer{}
47 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
48
49 // Tool-use step.
50 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
51 ID: "m1", SessionID: "S", Role: proto.Assistant,
52 Parts: []proto.ContentPart{
53 proto.TextContent{Text: ""},
54 proto.Finish{Reason: proto.FinishReasonToolUse},
55 },
56 }}, nil)
57 require.NoError(t, err)
58 require.False(t, done)
59
60 // Final assistant message stream.
61 done, err = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
62 ID: "m2", SessionID: "S", Role: proto.Assistant,
63 Parts: []proto.ContentPart{
64 proto.TextContent{Text: "VERDICT: APPROVED"},
65 proto.Finish{Reason: proto.FinishReasonEndTurn},
66 },
67 }}, nil)
68 require.NoError(t, err)
69 require.False(t, done, "message finish (even end_turn) must not exit; RunComplete is the only terminal signal")
70
71 // RunComplete.
72 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
73 SessionID: "S",
74 MessageID: "m2",
75 Text: "VERDICT: APPROVED",
76 }}, nil)
77 require.NoError(t, err)
78 require.True(t, done)
79 require.Equal(t, "VERDICT: APPROVED", buf.String())
80}
81
82// TestRunStream_ReconcilesOnOutOfOrderRunComplete is the worst-case
83// ordering scenario: RunComplete reaches the client BEFORE any of
84// the streaming assistant message events for the turn (the pubsub
85// fan-in across upstream brokers does not preserve cross-broker
86// ordering). The embedded Text field must rescue stdout so the
87// caller still sees the complete final text.
88func TestRunStream_ReconcilesOnOutOfOrderRunComplete(t *testing.T) {
89 t.Parallel()
90
91 buf := &bytes.Buffer{}
92 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
93
94 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
95 SessionID: "S",
96 MessageID: "m2",
97 Text: "VERDICT: APPROVED",
98 }}, nil)
99 require.NoError(t, err)
100 require.True(t, done)
101 require.Equal(t, "VERDICT: APPROVED", buf.String(),
102 "RunComplete must reconcile stdout when message events did not arrive in time")
103}
104
105// TestRunStream_ReconcilesPartialStream covers the realistic case
106// where some streaming output reached stdout before RunComplete
107// arrived: the reconciliation pass must append only the unread tail,
108// never duplicate the prefix.
109func TestRunStream_ReconcilesPartialStream(t *testing.T) {
110 t.Parallel()
111
112 buf := &bytes.Buffer{}
113 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
114
115 _, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
116 ID: "m2", SessionID: "S", Role: proto.Assistant,
117 Parts: []proto.ContentPart{proto.TextContent{Text: "VERDICT: "}},
118 }}, nil)
119 require.NoError(t, err)
120
121 _, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
122 SessionID: "S",
123 MessageID: "m2",
124 Text: "VERDICT: APPROVED",
125 }}, nil)
126 require.NoError(t, err)
127 require.Equal(t, "VERDICT: APPROVED", buf.String())
128}
129
130// TestRunStream_IgnoresOtherSessions ensures multi-session
131// subscribers (e.g. a TUI watching workspace events while `crush
132// run` is in flight against the same workspace) do not cause
133// premature exit on RunComplete for a different session.
134func TestRunStream_IgnoresOtherSessions(t *testing.T) {
135 t.Parallel()
136
137 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
138 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
139 SessionID: "OTHER",
140 MessageID: "x",
141 Text: "noise",
142 }}, nil)
143 require.NoError(t, err)
144 require.False(t, done)
145}
146
147// TestRunStream_ErrorRunComplete surfaces a failing run as a
148// non-nil error from `crush run` so shells and CI catch it via
149// exit status.
150func TestRunStream_ErrorRunComplete(t *testing.T) {
151 t.Parallel()
152
153 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
154 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
155 SessionID: "S",
156 Error: "model temporarily unavailable",
157 }}, nil)
158 require.True(t, done)
159 require.Error(t, err)
160 require.Contains(t, err.Error(), "model temporarily unavailable")
161}
162
163// TestRunStream_CancelledRunCompleteIsClean ensures a cancelled
164// run (e.g. Ctrl+C while `crush run` waits) exits cleanly rather
165// than reporting the cancellation as a failure.
166func TestRunStream_CancelledRunCompleteIsClean(t *testing.T) {
167 t.Parallel()
168
169 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
170 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
171 SessionID: "S",
172 Error: "context canceled",
173 Cancelled: true,
174 }}, nil)
175 require.True(t, done)
176 require.NoError(t, err)
177}
178
179// TestRunStream_LeadingWhitespaceTrimmedOnce mirrors the
180// pre-existing trim of leading whitespace on the first byte of
181// stdout: the trim must happen exactly once even when stdout is
182// first produced by the RunComplete reconciliation path rather
183// than the live stream.
184func TestRunStream_LeadingWhitespaceTrimmedOnce(t *testing.T) {
185 t.Parallel()
186
187 buf := &bytes.Buffer{}
188 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
189
190 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
191 SessionID: "S",
192 MessageID: "m2",
193 Text: " \tactual output",
194 }}, nil)
195 require.NoError(t, err)
196 require.True(t, done)
197 require.Equal(t, "actual output", buf.String())
198}
199
200// TestRunStream_StopSpinnerInvokedOnFirstOutput verifies the
201// spinner is stopped exactly when meaningful output starts (either
202// a streamed assistant message or the reconciliation tail). This
203// matches the prior behaviour and prevents the spinner from
204// painting over the final response on TTYs.
205func TestRunStream_StopSpinnerInvokedOnFirstOutput(t *testing.T) {
206 t.Parallel()
207
208 calls := 0
209 stop := func() { calls++ }
210 s := &runStream{sessionID: "S", out: &bytes.Buffer{}, read: map[string]int{}}
211 _, _ = s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
212 ID: "m1", SessionID: "S", Role: proto.Assistant,
213 Parts: []proto.ContentPart{proto.TextContent{Text: "hi"}},
214 }}, stop)
215 require.GreaterOrEqual(t, calls, 1, "spinner must stop once stdout has content")
216}
217
218// TestRunStream_RunIDFiltersForeignTurns covers the busy-session
219// queue scenario: `crush run --continue` attaches to a session
220// whose currently running turn finishes first, publishing its
221// RunComplete on the same session ID. Without per-run correlation
222// the stream would exit on that foreign event and drop our own
223// queued turn's output. With RunID filtering the foreign event is
224// ignored and only the matching RunComplete terminates the stream.
225func TestRunStream_RunIDFiltersForeignTurns(t *testing.T) {
226 t.Parallel()
227
228 const sessionID = "S"
229 const myRun = "run-mine"
230 const otherRun = "run-other"
231
232 buf := &bytes.Buffer{}
233 s := &runStream{
234 sessionID: sessionID,
235 runID: myRun,
236 out: buf,
237 read: map[string]int{},
238 }
239
240 // The busy session's existing turn emits more text before it finishes.
241 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
242 ID: "other-msg",
243 SessionID: sessionID,
244 Role: proto.Assistant,
245 Parts: []proto.ContentPart{proto.TextContent{Text: "noise from another turn"}},
246 }}, nil)
247 require.NoError(t, err)
248 require.False(t, done,
249 "foreign message on same session must not terminate our run")
250 require.Empty(t, buf.String(),
251 "foreign message on same session must not write to our stdout")
252
253 // The busy session's existing turn finishes first.
254 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
255 SessionID: sessionID,
256 RunID: otherRun,
257 MessageID: "other-msg",
258 Text: "noise from another turn",
259 }}, nil)
260 require.NoError(t, err)
261 require.False(t, done,
262 "foreign RunComplete on same session must not terminate our run")
263 require.Empty(t, buf.String(),
264 "foreign RunComplete must not write to our stdout")
265
266 // Our own queued turn eventually finishes.
267 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
268 SessionID: sessionID,
269 RunID: myRun,
270 MessageID: "my-msg",
271 Text: "OK",
272 }}, nil)
273 require.NoError(t, err)
274 require.True(t, done, "matching RunID must terminate the stream")
275 require.Equal(t, "OK", buf.String())
276}
277
278func TestRunStream_RunIDSuppressesLiveMessagesAndPrintsRunComplete(t *testing.T) {
279 t.Parallel()
280
281 buf := &bytes.Buffer{}
282 s := &runStream{
283 sessionID: "S",
284 runID: "run-mine",
285 out: buf,
286 read: map[string]int{},
287 }
288
289 done, err := s.handle(pubsub.Event[proto.Message]{Payload: proto.Message{
290 ID: "my-msg",
291 SessionID: "S",
292 Role: proto.Assistant,
293 Parts: []proto.ContentPart{proto.TextContent{Text: "streamed prefix"}},
294 }}, nil)
295 require.NoError(t, err)
296 require.False(t, done)
297 require.Empty(t, buf.String())
298
299 done, err = s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
300 SessionID: "S",
301 RunID: "run-mine",
302 MessageID: "my-msg",
303 Text: "streamed prefix final",
304 }}, nil)
305 require.NoError(t, err)
306 require.True(t, done)
307 require.Equal(t, "streamed prefix final", buf.String())
308}
309
310// TestRunStream_NoRunIDFallsBackToSessionID preserves the older
311// behaviour for callers (and tests) that don't supply a RunID:
312// SessionID-only matching still terminates the stream on the
313// session's RunComplete. This keeps the contract backwards
314// compatible with servers that don't echo RunID and with the
315// pre-existing TestRunStream_* assertions.
316func TestRunStream_NoRunIDFallsBackToSessionID(t *testing.T) {
317 t.Parallel()
318
319 buf := &bytes.Buffer{}
320 s := &runStream{sessionID: "S", out: buf, read: map[string]int{}}
321 done, err := s.handle(pubsub.Event[proto.RunComplete]{Payload: proto.RunComplete{
322 SessionID: "S",
323 MessageID: "m2",
324 Text: "DONE",
325 }}, nil)
326 require.NoError(t, err)
327 require.True(t, done)
328 require.Equal(t, "DONE", buf.String())
329}