1package agent
2
3import (
4 "context"
5 "testing"
6 "time"
7
8 "github.com/charmbracelet/crush/internal/agent/notify"
9 "github.com/charmbracelet/crush/internal/pubsub"
10 "github.com/stretchr/testify/require"
11)
12
// TestSessionAgentRun_QueueStripsOnComplete checks the enqueue path of
// sessionAgent.Run. When the session already has an active request, the
// call is queued and Run returns (nil, nil), and the queued copy must not
// carry the caller's OnComplete hook. That hook is scoped to the caller's
// retry/coalesce logic (typically coordinator.Run), which has returned by
// the time the queue drains. Keeping it would route the terminal event
// into a closure nobody reads, and subscribers (`crush run`) would be left
// waiting for a RunComplete that is never published.
func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) {
	t.Parallel()

	const sessionID = "queued-session"

	env := testEnv(t)
	sa := NewSessionAgent(SessionAgentOptions{
		Sessions: env.sessions,
		Messages: env.messages,
	}).(*sessionAgent)

	// Registering an active request marks the session busy, which forces
	// Run down the queue branch without needing a real model.
	sa.activeRequests.Set(sessionID, func() {})

	hookFired := false
	call := SessionAgentCall{
		SessionID: sessionID,
		RunID:     "run-xyz",
		Prompt:    "queued prompt",
		OnComplete: func(notify.RunComplete) {
			hookFired = true
		},
	}

	result, err := sa.Run(t.Context(), call)
	require.NoError(t, err)
	require.Nil(t, result, "queued Run must return (nil, nil)")
	require.False(t, hookFired,
		"OnComplete must not fire on the enqueue path; the caller's scope is still live")

	pending, found := sa.messageQueue.Get(sessionID)
	require.True(t, found)
	require.Len(t, pending, 1)

	head := pending[0]
	require.Nil(t, head.OnComplete,
		"queued SessionAgentCall must have OnComplete stripped so the drain falls back to the default broker publish")
	require.Equal(t, "queued prompt", head.Prompt,
		"all other fields must be preserved on the queued copy")
	require.Equal(t, "run-xyz", head.RunID,
		"RunID must be preserved on the queued copy so the drained turn's "+
			"RunComplete still correlates with the originating SendMessage")
}
60
// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies the cancel-mark
// filtering done by drainUncanceledQueue:
//   - queued calls at or below the cancel high-water mark are dropped;
//   - calls queued after the cancel (higher seq) survive;
//   - untracked enqueues (seq == 0) are dropped whenever any mark is present;
//   - the session's queue is cleared.
//
// The drain is expected to evaluate the mark while holding the dispatch
// mutex (canceledBySeq's documented precondition). This test checks only
// the filtering outcome. It does not observe the lock itself, so a
// regression that stops taking the lock would need the race detector or a
// dedicated concurrency test to catch.
func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) {
	t.Parallel()

	env := testEnv(t)
	a := NewSessionAgent(SessionAgentOptions{
		Sessions: env.sessions,
		Messages: env.messages,
	}).(*sessionAgent)

	const sessionID = "drain-session"
	a.messageQueue.Set(sessionID, []SessionAgentCall{
		{SessionID: sessionID, Prompt: "below", acceptSeq: 1},
		{SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2},
		{SessionID: sessionID, Prompt: "after", acceptSeq: 3},
		{SessionID: sessionID, Prompt: "untracked", acceptSeq: 0},
	})
	// Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered.
	a.cancelMark.Set(sessionID, 2)

	survivors := a.drainUncanceledQueue(sessionID)

	require.Len(t, survivors, 1,
		"only the follow-up queued after the cancel (seq > mark) must survive")
	require.Equal(t, "after", survivors[0].Prompt)
	// Pin the seq too, so the survivor is identified by its accept
	// sequence rather than by prompt text alone.
	require.EqualValues(t, 3, survivors[0].acceptSeq,
		"the surviving call must be the one accepted after the cancel mark")

	_, ok := a.messageQueue.Get(sessionID)
	require.False(t, ok, "drain must clear the session message queue")
}
95
// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that, with no cancel mark
// recorded, every queued call survives the drain. This includes untracked
// (seq == 0) calls. Survivors must keep their enqueue order, and the
// session's queue must still be cleared.
func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) {
	t.Parallel()

	env := testEnv(t)
	a := NewSessionAgent(SessionAgentOptions{
		Sessions: env.sessions,
		Messages: env.messages,
	}).(*sessionAgent)

	const sessionID = "drain-nomark"
	a.messageQueue.Set(sessionID, []SessionAgentCall{
		{SessionID: sessionID, Prompt: "a", acceptSeq: 0},
		{SessionID: sessionID, Prompt: "b", acceptSeq: 5},
	})

	survivors := a.drainUncanceledQueue(sessionID)
	require.Len(t, survivors, 2, "no cancel mark means all queued calls survive")
	// A count-only check would miss a drain that reorders or duplicates
	// entries; queued follow-ups must replay in the order they arrived.
	require.Equal(t, "a", survivors[0].Prompt, "drain must preserve enqueue order")
	require.Equal(t, "b", survivors[1].Prompt, "drain must preserve enqueue order")

	_, ok := a.messageQueue.Get(sessionID)
	require.False(t, ok, "drain must clear the session message queue")
}
116
// TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the
// pubsub.Publisher interface change end to end. A Broker is the only
// concrete Publisher implementation and must satisfy both Publish and
// PublishMustDeliver. The interface-typed assignment below makes that a
// compile-time requirement. The coordinator's final RunComplete emit relies
// on PublishMustDeliver to apply bounded-blocking semantics, so that a
// momentarily full subscriber buffer can't silently drop the authoritative
// end-of-run event.
func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) {
	t.Parallel()

	broker := pubsub.NewBroker[notify.RunComplete]()
	t.Cleanup(broker.Shutdown)

	// Subscribe before publishing so the event is delivered.
	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()
	ch := broker.Subscribe(ctx)

	rc := notify.RunComplete{SessionID: "S", MessageID: "m", Text: "ok"}
	var pub pubsub.Publisher[notify.RunComplete] = broker
	pub.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, rc)

	select {
	case got, ok := <-ch:
		// A closed channel would otherwise surface as a confusing
		// zero-value payload diff rather than the real failure.
		require.True(t, ok, "subscriber channel closed before delivery")
		require.Equal(t, pubsub.UpdatedEvent, got.Type,
			"delivered event must carry the type passed to PublishMustDeliver")
		require.Equal(t, rc, got.Payload)
	case <-time.After(time.Second):
		t.Fatal("PublishMustDeliver did not deliver event")
	}
}