1package agent
2
3import (
4 "context"
5 "errors"
6 "sync/atomic"
7 "testing"
8 "time"
9
10 "charm.land/catwalk/pkg/catwalk"
11 "charm.land/fantasy"
12 "github.com/charmbracelet/crush/internal/agent/notify"
13 "github.com/charmbracelet/crush/internal/message"
14 "github.com/charmbracelet/crush/internal/pubsub"
15 "github.com/stretchr/testify/require"
16)
17
// gatedStreamModel is a fake model for tests. It streams a single text
// part followed by a clean finish, but blocks the very first Stream call
// until its gate is released (or that call's context is done). That lets
// a test hold a run "active" (past PrepareStep, inside Stream) just long
// enough to enqueue a follow-up prompt behind the busy session.
// Subsequent Stream calls (e.g. the recursive run draining the queue)
// proceed immediately.
type gatedStreamModel struct {
	// text is the content returned by Generate and emitted by Stream as
	// its single text delta.
	text string
	// gate holds the first Stream call until a receive on it succeeds;
	// closing the channel releases the call.
	gate chan struct{}
	// entered is closed by the first Stream call right before it starts
	// waiting on gate. It must be non-nil, since closing a nil channel
	// panics.
	entered chan struct{}
	// calls counts Stream invocations so that only the first is gated.
	calls atomic.Int64
}
30
31func (m *gatedStreamModel) Provider() string { return "fake" }
32func (m *gatedStreamModel) Model() string { return "fake-model" }
33
34func (m *gatedStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
35 return &fantasy.Response{
36 Content: fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
37 FinishReason: fantasy.FinishReasonStop,
38 }, nil
39}
40
41func (m *gatedStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
42 if m.calls.Add(1) == 1 {
43 close(m.entered)
44 select {
45 case <-m.gate:
46 case <-ctx.Done():
47 }
48 }
49 text := m.text
50 return func(yield func(fantasy.StreamPart) bool) {
51 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
52 return
53 }
54 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
55 return
56 }
57 if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
58 return
59 }
60 yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
61 }, nil
62}
63
64func (m *gatedStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
65 return nil, errors.New("not implemented")
66}
67
68func (m *gatedStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
69 return nil, errors.New("not implemented")
70}
71
72// TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete is the
73// end-to-end proof of fix 2: a prompt carrying a RunID that is queued
74// behind a busy session must NOT be silently folded into the active turn.
75// It runs as its own turn via the recursive run path and publishes its
76// own terminal RunComplete, so a `crush run` caller blocking on that
77// RunID does not hang. The active turn keeps its own RunComplete too.
78func TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete(t *testing.T) {
79 t.Parallel()
80
81 env := testEnv(t)
82 broker := pubsub.NewBroker[notify.RunComplete]()
83 t.Cleanup(broker.Shutdown)
84
85 large := &gatedStreamModel{
86 text: "done",
87 gate: make(chan struct{}),
88 entered: make(chan struct{}),
89 }
90 small := &finishStreamModel{text: "title"}
91
92 sa := NewSessionAgent(SessionAgentOptions{
93 LargeModel: Model{Model: large, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
94 SmallModel: Model{Model: small, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
95 IsYolo: true,
96 Sessions: env.sessions,
97 Messages: env.messages,
98 RunComplete: broker,
99 }).(*sessionAgent)
100
101 sess, err := env.sessions.Create(t.Context(), "session")
102 require.NoError(t, err)
103
104 subCtx, subCancel := context.WithCancel(t.Context())
105 defer subCancel()
106 ch := broker.Subscribe(subCtx)
107
108 // Start the main turn; it blocks inside Stream once active.
109 mainDone := make(chan error, 1)
110 go func() {
111 _, runErr := sa.Run(t.Context(), SessionAgentCall{
112 SessionID: sess.ID,
113 RunID: "run-main",
114 Prompt: "main",
115 })
116 mainDone <- runErr
117 }()
118
119 // Wait until the main turn is active (inside Stream).
120 select {
121 case <-large.entered:
122 case <-time.After(5 * time.Second):
123 t.Fatal("main run never entered Stream")
124 }
125 require.True(t, sa.IsSessionBusy(sess.ID), "main run must be active before enqueueing the follow-up")
126
127 // Enqueue a RunID-bearing follow-up behind the busy session.
128 res, err := sa.Run(t.Context(), SessionAgentCall{
129 SessionID: sess.ID,
130 RunID: "run-follow",
131 Prompt: "follow",
132 })
133 require.NoError(t, err)
134 require.Nil(t, res, "a busy-session follow-up must enqueue and return (nil, nil)")
135 require.Equal(t, 1, sa.QueuedPrompts(sess.ID), "the follow-up must be queued, not folded")
136
137 // Release the main turn so it completes and hands off to the queue.
138 close(large.gate)
139 require.NoError(t, <-mainDone)
140
141 // Both turns must publish their own terminal RunComplete.
142 got := map[string]notify.RunComplete{}
143 deadline := time.After(5 * time.Second)
144 for len(got) < 2 {
145 select {
146 case ev := <-ch:
147 got[ev.Payload.RunID] = ev.Payload
148 case <-deadline:
149 t.Fatalf("timed out waiting for both RunCompletes; got %v", got)
150 }
151 }
152
153 main, ok := got["run-main"]
154 require.True(t, ok, "the active turn must publish its own RunComplete")
155 require.Empty(t, main.Error)
156 require.False(t, main.Cancelled)
157
158 follow, ok := got["run-follow"]
159 require.True(t, ok,
160 "the queued RunID prompt must publish its own RunComplete instead of being folded silently")
161 require.Empty(t, follow.Error)
162 require.False(t, follow.Cancelled)
163 require.Equal(t, "done", follow.Text, "the queued prompt ran as its own turn")
164
165 // Two distinct assistant turns prove the follow-up was not folded.
166 msgs, err := env.messages.List(t.Context(), sess.ID)
167 require.NoError(t, err)
168 var assistants, follows int
169 for _, m := range msgs {
170 switch m.Role {
171 case message.Assistant:
172 assistants++
173 case message.User:
174 if m.Content().String() == "follow" {
175 follows++
176 }
177 }
178 }
179 require.Equal(t, 2, assistants, "the active turn and the recursive turn each produce one assistant message")
180 require.Equal(t, 1, follows, "the follow-up prompt is its own user turn")
181}