1package server
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "log/slog"
8 "net/http"
9 "net/http/httptest"
10 "os"
11 "strings"
12 "sync"
13 "testing"
14 "time"
15
16 "shelley.exe.dev/claudetool"
17 "shelley.exe.dev/db"
18 "shelley.exe.dev/llm"
19 "shelley.exe.dev/loop"
20)
21
// flusherRecorder is a test http.ResponseWriter built on
// httptest.ResponseRecorder that also implements http.Flusher and lets the
// test goroutine read the response body while a handler goroutine is still
// writing it.
//
// Only the methods defined on flusherRecorder itself take mu; promoted
// methods of the embedded recorder that are not overridden here are not
// synchronized.
type flusherRecorder struct {
	*httptest.ResponseRecorder

	// mu guards chunks and the embedded recorder's Body for the overriding
	// methods (Write, Flush, getString, getChunks).
	mu sync.Mutex

	// chunks holds one snapshot of the entire body per Flush call.
	chunks []string

	// flushed receives a best-effort (dropped when full) signal per Flush.
	flushed chan struct{}
}
30
31func newFlusherRecorder() *flusherRecorder {
32 return &flusherRecorder{
33 ResponseRecorder: httptest.NewRecorder(),
34 flushed: make(chan struct{}, 100),
35 }
36}
37
38// Write overrides ResponseRecorder.Write to provide thread-safe access
39func (f *flusherRecorder) Write(p []byte) (int, error) {
40 f.mu.Lock()
41 defer f.mu.Unlock()
42 return f.ResponseRecorder.Write(p)
43}
44
45func (f *flusherRecorder) Flush() {
46 f.mu.Lock()
47 body := f.Body.String()
48 f.chunks = append(f.chunks, body)
49 f.mu.Unlock()
50
51 select {
52 case f.flushed <- struct{}{}:
53 default:
54 }
55}
56
57func (f *flusherRecorder) getChunks() []string {
58 f.mu.Lock()
59 defer f.mu.Unlock()
60 result := make([]string, len(f.chunks))
61 copy(result, f.chunks)
62 return result
63}
64
65// getString returns the current body contents in a thread-safe manner
66func (f *flusherRecorder) getString() string {
67 f.mu.Lock()
68 defer f.mu.Unlock()
69 return f.Body.String()
70}
71
72// TestSSEUserMessageAppearsImmediately tests that when a user sends a message,
73// the message appears in the SSE stream immediately, before the LLM responds.
74func TestSSEUserMessageAppearsImmediately(t *testing.T) {
75 database, cleanup := setupTestDB(t)
76 defer cleanup()
77
78 predictableService := loop.NewPredictableService()
79 llmManager := &testLLMManager{service: predictableService}
80 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
81 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
82
83 // Create conversation
84 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
85 if err != nil {
86 t.Fatalf("failed to create conversation: %v", err)
87 }
88 conversationID := conversation.ConversationID
89
90 // Set up a context we can cancel to stop the SSE handler
91 sseCtx, sseCancel := context.WithCancel(context.Background())
92 defer sseCancel()
93
94 // Start the SSE stream handler in a goroutine
95 sseRecorder := newFlusherRecorder()
96 sseReq := httptest.NewRequest("GET", "/api/conversation/"+conversationID+"/stream", nil)
97 sseReq = sseReq.WithContext(sseCtx)
98
99 sseStarted := make(chan struct{})
100 sseDone := make(chan struct{})
101 go func() {
102 close(sseStarted)
103 server.handleStreamConversation(sseRecorder, sseReq, conversationID)
104 close(sseDone)
105 }()
106
107 // Wait for SSE handler to start and send initial state
108 <-sseStarted
109
110 // Wait for the initial SSE event (empty messages)
111 select {
112 case <-sseRecorder.flushed:
113 // Got initial state
114 case <-time.After(2 * time.Second):
115 t.Fatal("timed out waiting for initial SSE event")
116 }
117
118 // Now send a user message that triggers a SLOW LLM response (3 seconds delay)
119 chatReq := ChatRequest{
120 Message: "delay: 3",
121 Model: "predictable",
122 }
123 chatBody, _ := json.Marshal(chatReq)
124
125 req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
126 req.Header.Set("Content-Type", "application/json")
127 w := httptest.NewRecorder()
128
129 server.handleChatConversation(w, req, conversationID)
130 if w.Code != http.StatusAccepted {
131 t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
132 }
133
134 // The user message should appear in the SSE stream IMMEDIATELY (within 500ms)
135 // NOT after the 3 second LLM delay
136 deadline := time.Now().Add(500 * time.Millisecond)
137 userMessageFound := false
138
139 for time.Now().Before(deadline) {
140 select {
141 case <-sseRecorder.flushed:
142 // Check if user message is now in the stream
143 body := sseRecorder.getString()
144 if containsUserMessage(body, "delay: 3") {
145 userMessageFound = true
146 }
147 case <-time.After(50 * time.Millisecond):
148 // Also check current body
149 body := sseRecorder.getString()
150 if containsUserMessage(body, "delay: 3") {
151 userMessageFound = true
152 }
153 }
154 if userMessageFound {
155 break
156 }
157 }
158
159 if !userMessageFound {
160 t.Errorf("BUG: user message did not appear in SSE stream within 500ms (LLM has 3s delay)")
161 t.Log("This likely means notifySubscribers is not being called immediately after recording the user message")
162 t.Logf("SSE body so far: %s", sseRecorder.getString())
163 } else {
164 t.Log("SUCCESS: user message appeared in SSE stream immediately")
165 }
166
167 // Clean up: cancel SSE context and wait for handler to finish
168 sseCancel()
169 select {
170 case <-sseDone:
171 case <-time.After(1 * time.Second):
172 // Handler may not exit immediately, that's OK
173 }
174}
175
176// containsUserMessage checks if the SSE body contains a user message with the given text
177func containsUserMessage(sseBody, messageText string) bool {
178 // SSE format is "data: {json}\n\n"
179 scanner := bufio.NewScanner(strings.NewReader(sseBody))
180 for scanner.Scan() {
181 line := scanner.Text()
182 if !strings.HasPrefix(line, "data: ") {
183 continue
184 }
185 jsonStr := strings.TrimPrefix(line, "data: ")
186
187 var streamResp StreamResponse
188 if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
189 continue
190 }
191
192 for _, msg := range streamResp.Messages {
193 if msg.Type != string(db.MessageTypeUser) {
194 continue
195 }
196 if msg.LlmData == nil {
197 continue
198 }
199 var llmMsg llm.Message
200 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
201 continue
202 }
203 for _, content := range llmMsg.Content {
204 if content.Type == llm.ContentTypeText && strings.Contains(content.Text, messageText) {
205 return true
206 }
207 }
208 }
209 }
210 return false
211}
212
213// TestSSEUserMessageWithRealHTTPServer tests with a real HTTP server to properly
214// test HTTP context cancellation behavior
215func TestSSEUserMessageWithRealHTTPServer(t *testing.T) {
216 database, cleanup := setupTestDB(t)
217 defer cleanup()
218
219 predictableService := loop.NewPredictableService()
220 llmManager := &testLLMManager{service: predictableService}
221 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
222 srv := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
223
224 // Create conversation
225 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
226 if err != nil {
227 t.Fatalf("failed to create conversation: %v", err)
228 }
229 conversationID := conversation.ConversationID
230
231 // Set up real HTTP server
232 mux := http.NewServeMux()
233 srv.RegisterRoutes(mux)
234 httpServer := httptest.NewServer(mux)
235 defer httpServer.Close()
236
237 // Connect to SSE stream
238 sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
239 if err != nil {
240 t.Fatalf("failed to connect to SSE stream: %v", err)
241 }
242 defer sseResp.Body.Close()
243
244 // Start reading SSE events in background
245 sseEvents := make(chan string, 100)
246 go func() {
247 scanner := bufio.NewScanner(sseResp.Body)
248 for scanner.Scan() {
249 line := scanner.Text()
250 if strings.HasPrefix(line, "data: ") {
251 sseEvents <- line
252 }
253 }
254 }()
255
256 // Wait for initial SSE event
257 select {
258 case <-sseEvents:
259 // Got initial state
260 case <-time.After(2 * time.Second):
261 t.Fatal("timed out waiting for initial SSE event")
262 }
263
264 // Send user message with slow LLM response via real HTTP client
265 chatReq := ChatRequest{
266 Message: "delay: 5",
267 Model: "predictable",
268 }
269 chatBody, _ := json.Marshal(chatReq)
270
271 resp, err := http.Post(
272 httpServer.URL+"/api/conversation/"+conversationID+"/chat",
273 "application/json",
274 strings.NewReader(string(chatBody)),
275 )
276 if err != nil {
277 t.Fatalf("failed to send chat message: %v", err)
278 }
279 resp.Body.Close()
280
281 if resp.StatusCode != http.StatusAccepted {
282 t.Fatalf("expected status 202, got %d", resp.StatusCode)
283 }
284
285 // User message should appear in SSE stream within 500ms (before 5s LLM delay)
286 deadline := time.Now().Add(500 * time.Millisecond)
287 userMessageFound := false
288
289 for time.Now().Before(deadline) && !userMessageFound {
290 select {
291 case eventLine := <-sseEvents:
292 jsonStr := strings.TrimPrefix(eventLine, "data: ")
293 var streamResp StreamResponse
294 if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
295 continue
296 }
297 for _, msg := range streamResp.Messages {
298 if msg.Type == string(db.MessageTypeUser) && msg.LlmData != nil {
299 var llmMsg llm.Message
300 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
301 for _, content := range llmMsg.Content {
302 if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "delay: 5") {
303 userMessageFound = true
304 break
305 }
306 }
307 }
308 }
309 }
310 case <-time.After(50 * time.Millisecond):
311 // Keep waiting
312 }
313 }
314
315 if !userMessageFound {
316 t.Error("BUG: user message did not appear in SSE stream within 500ms with real HTTP server")
317 t.Log("This confirms the context cancellation bug in notifySubscribers")
318 } else {
319 t.Log("SUCCESS: user message appeared in SSE stream immediately with real HTTP server")
320 }
321}
322
323// TestSSEUserMessageWithExistingConnection is a simpler version that tests
324// message recording and notification without the SSE complexity
325func TestSSEUserMessageWithExistingConnection(t *testing.T) {
326 database, cleanup := setupTestDB(t)
327 defer cleanup()
328
329 predictableService := loop.NewPredictableService()
330 llmManager := &testLLMManager{service: predictableService}
331 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
332 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
333
334 // Create conversation and get a manager (simulating an established SSE connection)
335 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
336 if err != nil {
337 t.Fatalf("failed to create conversation: %v", err)
338 }
339 conversationID := conversation.ConversationID
340
341 // Get the conversation manager to set up subscription
342 manager, err := server.getOrCreateConversationManager(context.Background(), conversationID)
343 if err != nil {
344 t.Fatalf("failed to get conversation manager: %v", err)
345 }
346
347 // Subscribe to updates
348 subCtx, subCancel := context.WithCancel(context.Background())
349 defer subCancel()
350 next := manager.subpub.Subscribe(subCtx, -1)
351
352 // Channel to receive updates
353 updates := make(chan StreamResponse, 10)
354 go func() {
355 for {
356 data, ok := next()
357 if !ok {
358 return
359 }
360 updates <- data
361 }
362 }()
363
364 // Now send a user message with slow LLM response
365 chatReq := ChatRequest{
366 Message: "delay: 5",
367 Model: "predictable",
368 }
369 chatBody, _ := json.Marshal(chatReq)
370
371 req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
372 req.Header.Set("Content-Type", "application/json")
373 w := httptest.NewRecorder()
374
375 server.handleChatConversation(w, req, conversationID)
376 if w.Code != http.StatusAccepted {
377 t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
378 }
379
380 // We should receive an update with the user message within 500ms
381 // (well before the 5 second LLM delay)
382 // Note: We may receive other updates first (e.g., ConversationListUpdate for slug changes),
383 // so we need to keep checking until we find the user message or timeout.
384 deadline := time.Now().Add(500 * time.Millisecond)
385 foundUserMsg := false
386
387 for time.Now().Before(deadline) && !foundUserMsg {
388 select {
389 case update := <-updates:
390 // Check if this update contains the user message
391 for _, msg := range update.Messages {
392 if msg.Type == string(db.MessageTypeUser) && msg.LlmData != nil {
393 var llmMsg llm.Message
394 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
395 for _, content := range llmMsg.Content {
396 if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "delay: 5") {
397 foundUserMsg = true
398 break
399 }
400 }
401 }
402 }
403 }
404 case <-time.After(50 * time.Millisecond):
405 // Keep waiting
406 }
407 }
408
409 if !foundUserMsg {
410 t.Error("BUG: did not receive subpub update with user message within 500ms")
411 t.Log("This means notifySubscribers is failing or not being called after user message is recorded")
412 } else {
413 t.Log("SUCCESS: received user message via subpub immediately")
414 }
415}