subagent_stream_test.go

  1package test
  2
  import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/db"
	"shelley.exe.dev/db/generated"
	"shelley.exe.dev/llm"
	"shelley.exe.dev/loop"
	"shelley.exe.dev/models"
	"shelley.exe.dev/server"
)
 23
 24// StreamResponse matches server.StreamResponse for testing
 25type StreamResponse struct {
 26	Messages               []json.RawMessage       `json:"messages"`
 27	Conversation           generated.Conversation  `json:"conversation"`
 28	ConversationState      *ConversationState      `json:"conversation_state,omitempty"`
 29	ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"`
 30	Heartbeat              bool                    `json:"heartbeat,omitempty"`
 31}
 32
 33type ConversationState struct {
 34	ConversationID string `json:"conversation_id"`
 35	Working        bool   `json:"working"`
 36	Model          string `json:"model,omitempty"`
 37}
 38
 39type ConversationListUpdate struct {
 40	Type           string                  `json:"type"`
 41	Conversation   *generated.Conversation `json:"conversation,omitempty"`
 42	ConversationID string                  `json:"conversation_id,omitempty"`
 43}
 44
 45type fakeLLMManager struct {
 46	service *loop.PredictableService
 47}
 48
 49func (m *fakeLLMManager) GetService(modelID string) (llm.Service, error) {
 50	return m.service, nil
 51}
 52
 53func (m *fakeLLMManager) GetAvailableModels() []string {
 54	return []string{"predictable"}
 55}
 56
 57func (m *fakeLLMManager) HasModel(modelID string) bool {
 58	return modelID == "predictable"
 59}
 60
 61func (m *fakeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
 62	return nil
 63}
 64
 65func (m *fakeLLMManager) RefreshCustomModels() error {
 66	return nil
 67}
 68
 69func setupTestServerForSubagent(t *testing.T) (*server.Server, *db.DB, *httptest.Server, *loop.PredictableService) {
 70	t.Helper()
 71
 72	// Create temporary database
 73	tempDB := t.TempDir() + "/test.db"
 74	database, err := db.New(db.Config{DSN: tempDB})
 75	if err != nil {
 76		t.Fatalf("Failed to create test database: %v", err)
 77	}
 78	t.Cleanup(func() { database.Close() })
 79
 80	// Run migrations
 81	if err := database.Migrate(context.Background()); err != nil {
 82		t.Fatalf("Failed to migrate database: %v", err)
 83	}
 84
 85	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
 86		Level: slog.LevelDebug,
 87	}))
 88
 89	// Use predictable model
 90	predictableService := loop.NewPredictableService()
 91	llmManager := &fakeLLMManager{service: predictableService}
 92
 93	toolSetConfig := claudetool.ToolSetConfig{
 94		WorkingDir:    t.TempDir(),
 95		EnableBrowser: false,
 96	}
 97
 98	svr := server.NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
 99
100	mux := http.NewServeMux()
101	svr.RegisterRoutes(mux)
102	testServer := httptest.NewServer(mux)
103	t.Cleanup(testServer.Close)
104
105	return svr, database, testServer, predictableService
106}
107
108// readSSEEvent reads a single SSE event from the response body with a timeout
109func readSSEEventWithTimeout(reader *bufio.Reader, timeout time.Duration) (*StreamResponse, error) {
110	type result struct {
111		resp *StreamResponse
112		err  error
113	}
114	ch := make(chan result, 1)
115
116	go func() {
117		var dataLines []string
118		for {
119			line, err := reader.ReadString('\n')
120			if err != nil {
121				ch <- result{nil, err}
122				return
123			}
124			line = strings.TrimSpace(line)
125
126			if line == "" && len(dataLines) > 0 {
127				// End of event
128				break
129			}
130
131			if strings.HasPrefix(line, "data: ") {
132				dataLines = append(dataLines, strings.TrimPrefix(line, "data: "))
133			}
134		}
135
136		if len(dataLines) == 0 {
137			ch <- result{nil, nil}
138			return
139		}
140
141		data := strings.Join(dataLines, "\n")
142		var response StreamResponse
143		if err := json.Unmarshal([]byte(data), &response); err != nil {
144			ch <- result{nil, err}
145			return
146		}
147		ch <- result{&response, nil}
148	}()
149
150	select {
151	case r := <-ch:
152		return r.resp, r.err
153	case <-time.After(timeout):
154		return nil, context.DeadlineExceeded
155	}
156}
157
158// TestSubagentNotificationViaStream tests that when RunSubagent is called,
159// the subagent conversation is properly notified to all SSE streams.
160func TestSubagentNotificationViaStream(t *testing.T) {
161	svr, database, testServer, _ := setupTestServerForSubagent(t)
162
163	ctx := context.Background()
164
165	// Create parent conversation
166	parentSlug := "parent-convo"
167	parentConv, err := database.CreateConversation(ctx, &parentSlug, true, nil, nil)
168	if err != nil {
169		t.Fatalf("Failed to create parent conversation: %v", err)
170	}
171
172	// Start streaming from parent conversation
173	streamURL := testServer.URL + "/api/conversation/" + parentConv.ConversationID + "/stream"
174	resp, err := http.Get(streamURL)
175	if err != nil {
176		t.Fatalf("Failed to connect to stream: %v", err)
177	}
178	defer resp.Body.Close()
179
180	reader := bufio.NewReader(resp.Body)
181
182	// Read initial event (should be the conversation state)
183	initialEvent, err := readSSEEventWithTimeout(reader, 2*time.Second)
184	if err != nil {
185		t.Fatalf("Failed to read initial SSE event: %v", err)
186	}
187	if initialEvent == nil {
188		t.Fatal("Expected initial event")
189	}
190	t.Logf("Initial event: conversation_id=%s, has_state=%v",
191		initialEvent.Conversation.ConversationID,
192		initialEvent.ConversationState != nil)
193
194	// Create a subagent conversation directly in DB (simulating what SubagentTool.Run does)
195	subSlug := "sub-worker"
196	subConv, err := database.CreateSubagentConversation(ctx, subSlug, parentConv.ConversationID, nil)
197	if err != nil {
198		t.Fatalf("Failed to create subagent conversation: %v", err)
199	}
200	t.Logf("Created subagent: id=%s, slug=%s, parent=%s",
201		subConv.ConversationID, *subConv.Slug, *subConv.ParentConversationID)
202
203	// Now call RunSubagent (what the subagent tool does after creating the conversation)
204	// This should trigger the notification to all SSE streams
205	subagentRunner := server.NewSubagentRunner(svr)
206	go func() {
207		// Call RunSubagent with wait=false so it returns quickly
208		subagentRunner.RunSubagent(ctx, subConv.ConversationID, "Test prompt", false, 10*time.Second)
209	}()
210
211	// Wait for notification
212	var receivedSubagentUpdate bool
213	var receivedUpdate *ConversationListUpdate
214
215	deadline := time.Now().Add(3 * time.Second)
216	for time.Now().Before(deadline) {
217		event, err := readSSEEventWithTimeout(reader, 500*time.Millisecond)
218		if err == context.DeadlineExceeded {
219			continue // Keep waiting
220		}
221		if err != nil {
222			t.Logf("Error reading event: %v", err)
223			break
224		}
225		if event == nil {
226			continue
227		}
228
229		t.Logf("Received event: has_list_update=%v, has_state=%v, heartbeat=%v",
230			event.ConversationListUpdate != nil,
231			event.ConversationState != nil,
232			event.Heartbeat)
233
234		if event.ConversationListUpdate != nil {
235			update := event.ConversationListUpdate
236			t.Logf("List update: type=%s", update.Type)
237			if update.Conversation != nil {
238				t.Logf("  conversation_id=%s, parent=%v, slug=%v",
239					update.Conversation.ConversationID,
240					update.Conversation.ParentConversationID,
241					update.Conversation.Slug)
242				if update.Conversation.ConversationID == subConv.ConversationID {
243					receivedSubagentUpdate = true
244					receivedUpdate = update
245					break
246				}
247			}
248		}
249	}
250
251	// Verify we received the notification
252	if !receivedSubagentUpdate {
253		t.Error("Expected to receive subagent update notification via SSE stream when RunSubagent is called")
254	} else {
255		t.Logf("SUCCESS: Received subagent update: type=%s, slug=%v", receivedUpdate.Type, receivedUpdate.Conversation.Slug)
256	}
257}
258
// TestSubagentWorkingStateNotification is a placeholder. It is meant to check
// that when a subagent starts or stops working, the parent conversation's SSE
// stream receives a ConversationState update; that update is expected to come
// from publishConversationState. The test is skipped until the harness can
// drive those working-state changes.
func TestSubagentWorkingStateNotification(t *testing.T) {
	const reason = "Skipping - requires more infrastructure to trigger working state changes"
	t.Skip(reason)
}