1package test
2
import (
	"bufio"
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/db"
	"shelley.exe.dev/db/generated"
	"shelley.exe.dev/llm"
	"shelley.exe.dev/loop"
	"shelley.exe.dev/models"
	"shelley.exe.dev/server"
)
23
24// StreamResponse matches server.StreamResponse for testing
25type StreamResponse struct {
26 Messages []json.RawMessage `json:"messages"`
27 Conversation generated.Conversation `json:"conversation"`
28 ConversationState *ConversationState `json:"conversation_state,omitempty"`
29 ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"`
30 Heartbeat bool `json:"heartbeat,omitempty"`
31}
32
// ConversationState is the payload carried under the "conversation_state" key
// of a StreamResponse.
type ConversationState struct {
	// ConversationID is decoded from "conversation_id".
	ConversationID string `json:"conversation_id"`

	// Working is decoded from "working".
	Working bool `json:"working"`

	// Model is decoded from "model" and omitted from output when empty.
	Model string `json:"model,omitempty"`
}
38
39type ConversationListUpdate struct {
40 Type string `json:"type"`
41 Conversation *generated.Conversation `json:"conversation,omitempty"`
42 ConversationID string `json:"conversation_id,omitempty"`
43}
44
45type fakeLLMManager struct {
46 service *loop.PredictableService
47}
48
49func (m *fakeLLMManager) GetService(modelID string) (llm.Service, error) {
50 return m.service, nil
51}
52
53func (m *fakeLLMManager) GetAvailableModels() []string {
54 return []string{"predictable"}
55}
56
57func (m *fakeLLMManager) HasModel(modelID string) bool {
58 return modelID == "predictable"
59}
60
61func (m *fakeLLMManager) GetModelInfo(modelID string) *models.ModelInfo {
62 return nil
63}
64
65func (m *fakeLLMManager) RefreshCustomModels() error {
66 return nil
67}
68
69func setupTestServerForSubagent(t *testing.T) (*server.Server, *db.DB, *httptest.Server, *loop.PredictableService) {
70 t.Helper()
71
72 // Create temporary database
73 tempDB := t.TempDir() + "/test.db"
74 database, err := db.New(db.Config{DSN: tempDB})
75 if err != nil {
76 t.Fatalf("Failed to create test database: %v", err)
77 }
78 t.Cleanup(func() { database.Close() })
79
80 // Run migrations
81 if err := database.Migrate(context.Background()); err != nil {
82 t.Fatalf("Failed to migrate database: %v", err)
83 }
84
85 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
86 Level: slog.LevelDebug,
87 }))
88
89 // Use predictable model
90 predictableService := loop.NewPredictableService()
91 llmManager := &fakeLLMManager{service: predictableService}
92
93 toolSetConfig := claudetool.ToolSetConfig{
94 WorkingDir: t.TempDir(),
95 EnableBrowser: false,
96 }
97
98 svr := server.NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
99
100 mux := http.NewServeMux()
101 svr.RegisterRoutes(mux)
102 testServer := httptest.NewServer(mux)
103 t.Cleanup(testServer.Close)
104
105 return svr, database, testServer, predictableService
106}
107
108// readSSEEvent reads a single SSE event from the response body with a timeout
109func readSSEEventWithTimeout(reader *bufio.Reader, timeout time.Duration) (*StreamResponse, error) {
110 type result struct {
111 resp *StreamResponse
112 err error
113 }
114 ch := make(chan result, 1)
115
116 go func() {
117 var dataLines []string
118 for {
119 line, err := reader.ReadString('\n')
120 if err != nil {
121 ch <- result{nil, err}
122 return
123 }
124 line = strings.TrimSpace(line)
125
126 if line == "" && len(dataLines) > 0 {
127 // End of event
128 break
129 }
130
131 if strings.HasPrefix(line, "data: ") {
132 dataLines = append(dataLines, strings.TrimPrefix(line, "data: "))
133 }
134 }
135
136 if len(dataLines) == 0 {
137 ch <- result{nil, nil}
138 return
139 }
140
141 data := strings.Join(dataLines, "\n")
142 var response StreamResponse
143 if err := json.Unmarshal([]byte(data), &response); err != nil {
144 ch <- result{nil, err}
145 return
146 }
147 ch <- result{&response, nil}
148 }()
149
150 select {
151 case r := <-ch:
152 return r.resp, r.err
153 case <-time.After(timeout):
154 return nil, context.DeadlineExceeded
155 }
156}
157
158// TestSubagentNotificationViaStream tests that when RunSubagent is called,
159// the subagent conversation is properly notified to all SSE streams.
160func TestSubagentNotificationViaStream(t *testing.T) {
161 svr, database, testServer, _ := setupTestServerForSubagent(t)
162
163 ctx := context.Background()
164
165 // Create parent conversation
166 parentSlug := "parent-convo"
167 parentConv, err := database.CreateConversation(ctx, &parentSlug, true, nil, nil)
168 if err != nil {
169 t.Fatalf("Failed to create parent conversation: %v", err)
170 }
171
172 // Start streaming from parent conversation
173 streamURL := testServer.URL + "/api/conversation/" + parentConv.ConversationID + "/stream"
174 resp, err := http.Get(streamURL)
175 if err != nil {
176 t.Fatalf("Failed to connect to stream: %v", err)
177 }
178 defer resp.Body.Close()
179
180 reader := bufio.NewReader(resp.Body)
181
182 // Read initial event (should be the conversation state)
183 initialEvent, err := readSSEEventWithTimeout(reader, 2*time.Second)
184 if err != nil {
185 t.Fatalf("Failed to read initial SSE event: %v", err)
186 }
187 if initialEvent == nil {
188 t.Fatal("Expected initial event")
189 }
190 t.Logf("Initial event: conversation_id=%s, has_state=%v",
191 initialEvent.Conversation.ConversationID,
192 initialEvent.ConversationState != nil)
193
194 // Create a subagent conversation directly in DB (simulating what SubagentTool.Run does)
195 subSlug := "sub-worker"
196 subConv, err := database.CreateSubagentConversation(ctx, subSlug, parentConv.ConversationID, nil)
197 if err != nil {
198 t.Fatalf("Failed to create subagent conversation: %v", err)
199 }
200 t.Logf("Created subagent: id=%s, slug=%s, parent=%s",
201 subConv.ConversationID, *subConv.Slug, *subConv.ParentConversationID)
202
203 // Now call RunSubagent (what the subagent tool does after creating the conversation)
204 // This should trigger the notification to all SSE streams
205 subagentRunner := server.NewSubagentRunner(svr)
206 go func() {
207 // Call RunSubagent with wait=false so it returns quickly
208 subagentRunner.RunSubagent(ctx, subConv.ConversationID, "Test prompt", false, 10*time.Second)
209 }()
210
211 // Wait for notification
212 var receivedSubagentUpdate bool
213 var receivedUpdate *ConversationListUpdate
214
215 deadline := time.Now().Add(3 * time.Second)
216 for time.Now().Before(deadline) {
217 event, err := readSSEEventWithTimeout(reader, 500*time.Millisecond)
218 if err == context.DeadlineExceeded {
219 continue // Keep waiting
220 }
221 if err != nil {
222 t.Logf("Error reading event: %v", err)
223 break
224 }
225 if event == nil {
226 continue
227 }
228
229 t.Logf("Received event: has_list_update=%v, has_state=%v, heartbeat=%v",
230 event.ConversationListUpdate != nil,
231 event.ConversationState != nil,
232 event.Heartbeat)
233
234 if event.ConversationListUpdate != nil {
235 update := event.ConversationListUpdate
236 t.Logf("List update: type=%s", update.Type)
237 if update.Conversation != nil {
238 t.Logf(" conversation_id=%s, parent=%v, slug=%v",
239 update.Conversation.ConversationID,
240 update.Conversation.ParentConversationID,
241 update.Conversation.Slug)
242 if update.Conversation.ConversationID == subConv.ConversationID {
243 receivedSubagentUpdate = true
244 receivedUpdate = update
245 break
246 }
247 }
248 }
249 }
250
251 // Verify we received the notification
252 if !receivedSubagentUpdate {
253 t.Error("Expected to receive subagent update notification via SSE stream when RunSubagent is called")
254 } else {
255 t.Logf("SUCCESS: Received subagent update: type=%s, slug=%v", receivedUpdate.Type, receivedUpdate.Conversation.Slug)
256 }
257}
258
// TestSubagentWorkingStateNotification is a placeholder for checking that
// subagent working-state changes are notified via the SSE stream. It is
// currently always skipped.
func TestSubagentWorkingStateNotification(t *testing.T) {
	// Intended coverage: when a subagent starts or stops working, the parent
	// conversation's stream receives a ConversationState update. For now
	// this only documents that it should work via publishConversationState.
	const reason = "Skipping - requires more infrastructure to trigger working state changes"
	t.Skip(reason)
}