queued_runid_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"sync/atomic"
  7	"testing"
  8	"time"
  9
 10	"charm.land/catwalk/pkg/catwalk"
 11	"charm.land/fantasy"
 12	"github.com/charmbracelet/crush/internal/agent/notify"
 13	"github.com/charmbracelet/crush/internal/message"
 14	"github.com/charmbracelet/crush/internal/pubsub"
 15	"github.com/stretchr/testify/require"
 16)
 17
 18// gatedStreamModel streams a single text part followed by a clean finish,
 19// but blocks the very first Stream call until its gate is released. That
 20// lets a test hold a run "active" (past PrepareStep, inside Stream) just
 21// long enough to enqueue a follow-up prompt behind the busy session.
 22// Subsequent Stream calls (e.g. the recursive run draining the queue)
 23// proceed immediately.
 24type gatedStreamModel struct {
 25	text    string
 26	gate    chan struct{}
 27	entered chan struct{}
 28	calls   atomic.Int64
 29}
 30
 31func (m *gatedStreamModel) Provider() string { return "fake" }
 32func (m *gatedStreamModel) Model() string    { return "fake-model" }
 33
 34func (m *gatedStreamModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
 35	return &fantasy.Response{
 36		Content:      fantasy.ResponseContent{fantasy.TextContent{Text: m.text}},
 37		FinishReason: fantasy.FinishReasonStop,
 38	}, nil
 39}
 40
 41func (m *gatedStreamModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
 42	if m.calls.Add(1) == 1 {
 43		close(m.entered)
 44		select {
 45		case <-m.gate:
 46		case <-ctx.Done():
 47		}
 48	}
 49	text := m.text
 50	return func(yield func(fantasy.StreamPart) bool) {
 51		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "1"}) {
 52			return
 53		}
 54		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "1", Delta: text}) {
 55			return
 56		}
 57		if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "1"}) {
 58			return
 59		}
 60		yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
 61	}, nil
 62}
 63
 64func (m *gatedStreamModel) GenerateObject(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
 65	return nil, errors.New("not implemented")
 66}
 67
 68func (m *gatedStreamModel) StreamObject(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
 69	return nil, errors.New("not implemented")
 70}
 71
 72// TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete is the
 73// end-to-end proof of fix 2: a prompt carrying a RunID that is queued
 74// behind a busy session must NOT be silently folded into the active turn.
 75// It runs as its own turn via the recursive run path and publishes its
 76// own terminal RunComplete, so a `crush run` caller blocking on that
 77// RunID does not hang. The active turn keeps its own RunComplete too.
 78func TestRun_QueuedRunIDPromptRunsRecursivelyAndPublishesRunComplete(t *testing.T) {
 79	t.Parallel()
 80
 81	env := testEnv(t)
 82	broker := pubsub.NewBroker[notify.RunComplete]()
 83	t.Cleanup(broker.Shutdown)
 84
 85	large := &gatedStreamModel{
 86		text:    "done",
 87		gate:    make(chan struct{}),
 88		entered: make(chan struct{}),
 89	}
 90	small := &finishStreamModel{text: "title"}
 91
 92	sa := NewSessionAgent(SessionAgentOptions{
 93		LargeModel:  Model{Model: large, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
 94		SmallModel:  Model{Model: small, CatwalkCfg: catwalk.Model{ContextWindow: 200000, DefaultMaxTokens: 10000}},
 95		IsYolo:      true,
 96		Sessions:    env.sessions,
 97		Messages:    env.messages,
 98		RunComplete: broker,
 99	}).(*sessionAgent)
100
101	sess, err := env.sessions.Create(t.Context(), "session")
102	require.NoError(t, err)
103
104	subCtx, subCancel := context.WithCancel(t.Context())
105	defer subCancel()
106	ch := broker.Subscribe(subCtx)
107
108	// Start the main turn; it blocks inside Stream once active.
109	mainDone := make(chan error, 1)
110	go func() {
111		_, runErr := sa.Run(t.Context(), SessionAgentCall{
112			SessionID: sess.ID,
113			RunID:     "run-main",
114			Prompt:    "main",
115		})
116		mainDone <- runErr
117	}()
118
119	// Wait until the main turn is active (inside Stream).
120	select {
121	case <-large.entered:
122	case <-time.After(5 * time.Second):
123		t.Fatal("main run never entered Stream")
124	}
125	require.True(t, sa.IsSessionBusy(sess.ID), "main run must be active before enqueueing the follow-up")
126
127	// Enqueue a RunID-bearing follow-up behind the busy session.
128	res, err := sa.Run(t.Context(), SessionAgentCall{
129		SessionID: sess.ID,
130		RunID:     "run-follow",
131		Prompt:    "follow",
132	})
133	require.NoError(t, err)
134	require.Nil(t, res, "a busy-session follow-up must enqueue and return (nil, nil)")
135	require.Equal(t, 1, sa.QueuedPrompts(sess.ID), "the follow-up must be queued, not folded")
136
137	// Release the main turn so it completes and hands off to the queue.
138	close(large.gate)
139	require.NoError(t, <-mainDone)
140
141	// Both turns must publish their own terminal RunComplete.
142	got := map[string]notify.RunComplete{}
143	deadline := time.After(5 * time.Second)
144	for len(got) < 2 {
145		select {
146		case ev := <-ch:
147			got[ev.Payload.RunID] = ev.Payload
148		case <-deadline:
149			t.Fatalf("timed out waiting for both RunCompletes; got %v", got)
150		}
151	}
152
153	main, ok := got["run-main"]
154	require.True(t, ok, "the active turn must publish its own RunComplete")
155	require.Empty(t, main.Error)
156	require.False(t, main.Cancelled)
157
158	follow, ok := got["run-follow"]
159	require.True(t, ok,
160		"the queued RunID prompt must publish its own RunComplete instead of being folded silently")
161	require.Empty(t, follow.Error)
162	require.False(t, follow.Cancelled)
163	require.Equal(t, "done", follow.Text, "the queued prompt ran as its own turn")
164
165	// Two distinct assistant turns prove the follow-up was not folded.
166	msgs, err := env.messages.List(t.Context(), sess.ID)
167	require.NoError(t, err)
168	var assistants, follows int
169	for _, m := range msgs {
170		switch m.Role {
171		case message.Assistant:
172			assistants++
173		case message.User:
174			if m.Content().String() == "follow" {
175				follows++
176			}
177		}
178	}
179	require.Equal(t, 2, assistants, "the active turn and the recursive turn each produce one assistant message")
180	require.Equal(t, 1, follows, "the follow-up prompt is its own user turn")
181}