run_complete_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"testing"
  6	"time"
  7
  8	"github.com/charmbracelet/crush/internal/agent/notify"
  9	"github.com/charmbracelet/crush/internal/pubsub"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13// TestSessionAgentRun_QueueStripsOnComplete verifies that when a Run
 14// call is enqueued (because the session is already busy), the
 15// OnComplete hook is NOT propagated onto the queued copy. The hook
 16// belongs to the caller's retry/coalesce scope (typically
 17// coordinator.Run) which has already returned by the time the queue
 18// drains; carrying it forward would silently funnel the terminal
 19// event into a closure nobody reads, and subscribers (`crush run`)
 20// would hang waiting for a RunComplete that never publishes.
 21func TestSessionAgentRun_QueueStripsOnComplete(t *testing.T) {
 22	t.Parallel()
 23
 24	env := testEnv(t)
 25	a := NewSessionAgent(SessionAgentOptions{
 26		Sessions: env.sessions,
 27		Messages: env.messages,
 28	}).(*sessionAgent)
 29
 30	const sessionID = "queued-session"
 31	// Mark the session as busy so Run takes the queue branch
 32	// without needing a real model.
 33	a.activeRequests.Set(sessionID, func() {})
 34
 35	var called bool
 36	hook := func(notify.RunComplete) { called = true }
 37
 38	res, err := a.Run(t.Context(), SessionAgentCall{
 39		SessionID:  sessionID,
 40		RunID:      "run-xyz",
 41		Prompt:     "queued prompt",
 42		OnComplete: hook,
 43	})
 44	require.NoError(t, err)
 45	require.Nil(t, res, "queued Run must return (nil, nil)")
 46	require.False(t, called,
 47		"OnComplete must not fire on the enqueue path; the caller's scope is still live")
 48
 49	queued, ok := a.messageQueue.Get(sessionID)
 50	require.True(t, ok)
 51	require.Len(t, queued, 1)
 52	require.Nil(t, queued[0].OnComplete,
 53		"queued SessionAgentCall must have OnComplete stripped so the drain falls back to the default broker publish")
 54	require.Equal(t, "queued prompt", queued[0].Prompt,
 55		"all other fields must be preserved on the queued copy")
 56	require.Equal(t, "run-xyz", queued[0].RunID,
 57		"RunID must be preserved on the queued copy so the drained turn's "+
 58			"RunComplete still correlates with the originating SendMessage")
 59}
 60
 61// TestDrainUncanceledQueue_FiltersUnderDispatchLock verifies that the
 62// queue drain evaluates the per-session cancel mark while holding the
 63// dispatch mutex (canceledBySeq's documented precondition). Queued calls
 64// at or below the cancel high-water mark are dropped, calls queued after
 65// the cancel (higher seq) survive, untracked enqueues (seq == 0) are
 66// dropped whenever any mark is present, and the queue is cleared.
 67func TestDrainUncanceledQueue_FiltersUnderDispatchLock(t *testing.T) {
 68	t.Parallel()
 69
 70	env := testEnv(t)
 71	a := NewSessionAgent(SessionAgentOptions{
 72		Sessions: env.sessions,
 73		Messages: env.messages,
 74	}).(*sessionAgent)
 75
 76	const sessionID = "drain-session"
 77	a.messageQueue.Set(sessionID, []SessionAgentCall{
 78		{SessionID: sessionID, Prompt: "below", acceptSeq: 1},
 79		{SessionID: sessionID, Prompt: "at-mark", acceptSeq: 2},
 80		{SessionID: sessionID, Prompt: "after", acceptSeq: 3},
 81		{SessionID: sessionID, Prompt: "untracked", acceptSeq: 0},
 82	})
 83	// Cancel high-water mark at seq 2: seq <= 2 and seq == 0 are covered.
 84	a.cancelMark.Set(sessionID, 2)
 85
 86	survivors := a.drainUncanceledQueue(sessionID)
 87
 88	require.Len(t, survivors, 1,
 89		"only the follow-up queued after the cancel (seq > mark) must survive")
 90	require.Equal(t, "after", survivors[0].Prompt)
 91
 92	_, ok := a.messageQueue.Get(sessionID)
 93	require.False(t, ok, "drain must clear the session message queue")
 94}
 95
 96// TestDrainUncanceledQueue_NoMarkKeepsAll verifies that with no cancel
 97// mark recorded, every queued call survives the drain.
 98func TestDrainUncanceledQueue_NoMarkKeepsAll(t *testing.T) {
 99	t.Parallel()
100
101	env := testEnv(t)
102	a := NewSessionAgent(SessionAgentOptions{
103		Sessions: env.sessions,
104		Messages: env.messages,
105	}).(*sessionAgent)
106
107	const sessionID = "drain-nomark"
108	a.messageQueue.Set(sessionID, []SessionAgentCall{
109		{SessionID: sessionID, Prompt: "a", acceptSeq: 0},
110		{SessionID: sessionID, Prompt: "b", acceptSeq: 5},
111	})
112
113	survivors := a.drainUncanceledQueue(sessionID)
114	require.Len(t, survivors, 2, "no cancel mark means all queued calls survive")
115}
116
117// TestRunCompletePublisher_MustDeliverOverTakesPublish exercises the
118// pubsub.Publisher interface change end-to-end: a Broker is the only
119// concrete Publisher implementation and must satisfy both Publish and
120// PublishMustDeliver. The coordinator's final RunComplete emit relies
121// on PublishMustDeliver to apply bounded-blocking semantics so a
122// momentarily-full subscriber buffer can't silently drop the
123// authoritative end-of-run event.
124func TestRunCompletePublisher_MustDeliverOverTakesPublish(t *testing.T) {
125	t.Parallel()
126
127	broker := pubsub.NewBroker[notify.RunComplete]()
128	t.Cleanup(broker.Shutdown)
129
130	// Subscribe before publishing so the event is delivered.
131	ctx, cancel := context.WithCancel(t.Context())
132	defer cancel()
133	ch := broker.Subscribe(ctx)
134
135	rc := notify.RunComplete{SessionID: "S", MessageID: "m", Text: "ok"}
136	var pub pubsub.Publisher[notify.RunComplete] = broker
137	pub.PublishMustDeliver(t.Context(), pubsub.UpdatedEvent, rc)
138
139	select {
140	case got := <-ch:
141		require.Equal(t, rc, got.Payload)
142	case <-time.After(time.Second):
143		t.Fatal("PublishMustDeliver did not deliver event")
144	}
145}