sse_immediacy_test.go

  1package server
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"log/slog"
  8	"net/http"
  9	"net/http/httptest"
 10	"os"
 11	"strings"
 12	"sync"
 13	"testing"
 14	"time"
 15
 16	"shelley.exe.dev/claudetool"
 17	"shelley.exe.dev/db"
 18	"shelley.exe.dev/llm"
 19	"shelley.exe.dev/loop"
 20)
 21
 22// flusherRecorder wraps httptest.ResponseRecorder to implement http.Flusher
 23// and provide immediate access to written data in a thread-safe manner
 24type flusherRecorder struct {
 25	*httptest.ResponseRecorder
 26	mu      sync.Mutex
 27	chunks  []string
 28	flushed chan struct{}
 29}
 30
 31func newFlusherRecorder() *flusherRecorder {
 32	return &flusherRecorder{
 33		ResponseRecorder: httptest.NewRecorder(),
 34		flushed:          make(chan struct{}, 100),
 35	}
 36}
 37
 38// Write overrides ResponseRecorder.Write to provide thread-safe access
 39func (f *flusherRecorder) Write(p []byte) (int, error) {
 40	f.mu.Lock()
 41	defer f.mu.Unlock()
 42	return f.ResponseRecorder.Write(p)
 43}
 44
 45func (f *flusherRecorder) Flush() {
 46	f.mu.Lock()
 47	body := f.Body.String()
 48	f.chunks = append(f.chunks, body)
 49	f.mu.Unlock()
 50
 51	select {
 52	case f.flushed <- struct{}{}:
 53	default:
 54	}
 55}
 56
 57func (f *flusherRecorder) getChunks() []string {
 58	f.mu.Lock()
 59	defer f.mu.Unlock()
 60	result := make([]string, len(f.chunks))
 61	copy(result, f.chunks)
 62	return result
 63}
 64
 65// getString returns the current body contents in a thread-safe manner
 66func (f *flusherRecorder) getString() string {
 67	f.mu.Lock()
 68	defer f.mu.Unlock()
 69	return f.Body.String()
 70}
 71
 72// TestSSEUserMessageAppearsImmediately tests that when a user sends a message,
 73// the message appears in the SSE stream immediately, before the LLM responds.
 74func TestSSEUserMessageAppearsImmediately(t *testing.T) {
 75	database, cleanup := setupTestDB(t)
 76	defer cleanup()
 77
 78	predictableService := loop.NewPredictableService()
 79	llmManager := &testLLMManager{service: predictableService}
 80	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
 81	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
 82
 83	// Create conversation
 84	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
 85	if err != nil {
 86		t.Fatalf("failed to create conversation: %v", err)
 87	}
 88	conversationID := conversation.ConversationID
 89
 90	// Set up a context we can cancel to stop the SSE handler
 91	sseCtx, sseCancel := context.WithCancel(context.Background())
 92	defer sseCancel()
 93
 94	// Start the SSE stream handler in a goroutine
 95	sseRecorder := newFlusherRecorder()
 96	sseReq := httptest.NewRequest("GET", "/api/conversation/"+conversationID+"/stream", nil)
 97	sseReq = sseReq.WithContext(sseCtx)
 98
 99	sseStarted := make(chan struct{})
100	sseDone := make(chan struct{})
101	go func() {
102		close(sseStarted)
103		server.handleStreamConversation(sseRecorder, sseReq, conversationID)
104		close(sseDone)
105	}()
106
107	// Wait for SSE handler to start and send initial state
108	<-sseStarted
109
110	// Wait for the initial SSE event (empty messages)
111	select {
112	case <-sseRecorder.flushed:
113		// Got initial state
114	case <-time.After(2 * time.Second):
115		t.Fatal("timed out waiting for initial SSE event")
116	}
117
118	// Now send a user message that triggers a SLOW LLM response (3 seconds delay)
119	chatReq := ChatRequest{
120		Message: "delay: 3",
121		Model:   "predictable",
122	}
123	chatBody, _ := json.Marshal(chatReq)
124
125	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
126	req.Header.Set("Content-Type", "application/json")
127	w := httptest.NewRecorder()
128
129	server.handleChatConversation(w, req, conversationID)
130	if w.Code != http.StatusAccepted {
131		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
132	}
133
134	// The user message should appear in the SSE stream IMMEDIATELY (within 500ms)
135	// NOT after the 3 second LLM delay
136	deadline := time.Now().Add(500 * time.Millisecond)
137	userMessageFound := false
138
139	for time.Now().Before(deadline) {
140		select {
141		case <-sseRecorder.flushed:
142			// Check if user message is now in the stream
143			body := sseRecorder.getString()
144			if containsUserMessage(body, "delay: 3") {
145				userMessageFound = true
146			}
147		case <-time.After(50 * time.Millisecond):
148			// Also check current body
149			body := sseRecorder.getString()
150			if containsUserMessage(body, "delay: 3") {
151				userMessageFound = true
152			}
153		}
154		if userMessageFound {
155			break
156		}
157	}
158
159	if !userMessageFound {
160		t.Errorf("BUG: user message did not appear in SSE stream within 500ms (LLM has 3s delay)")
161		t.Log("This likely means notifySubscribers is not being called immediately after recording the user message")
162		t.Logf("SSE body so far: %s", sseRecorder.getString())
163	} else {
164		t.Log("SUCCESS: user message appeared in SSE stream immediately")
165	}
166
167	// Clean up: cancel SSE context and wait for handler to finish
168	sseCancel()
169	select {
170	case <-sseDone:
171	case <-time.After(1 * time.Second):
172		// Handler may not exit immediately, that's OK
173	}
174}
175
176// containsUserMessage checks if the SSE body contains a user message with the given text
177func containsUserMessage(sseBody, messageText string) bool {
178	// SSE format is "data: {json}\n\n"
179	scanner := bufio.NewScanner(strings.NewReader(sseBody))
180	for scanner.Scan() {
181		line := scanner.Text()
182		if !strings.HasPrefix(line, "data: ") {
183			continue
184		}
185		jsonStr := strings.TrimPrefix(line, "data: ")
186
187		var streamResp StreamResponse
188		if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
189			continue
190		}
191
192		for _, msg := range streamResp.Messages {
193			if msg.Type != string(db.MessageTypeUser) {
194				continue
195			}
196			if msg.LlmData == nil {
197				continue
198			}
199			var llmMsg llm.Message
200			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
201				continue
202			}
203			for _, content := range llmMsg.Content {
204				if content.Type == llm.ContentTypeText && strings.Contains(content.Text, messageText) {
205					return true
206				}
207			}
208		}
209	}
210	return false
211}
212
213// TestSSEUserMessageWithRealHTTPServer tests with a real HTTP server to properly
214// test HTTP context cancellation behavior
215func TestSSEUserMessageWithRealHTTPServer(t *testing.T) {
216	database, cleanup := setupTestDB(t)
217	defer cleanup()
218
219	predictableService := loop.NewPredictableService()
220	llmManager := &testLLMManager{service: predictableService}
221	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
222	srv := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
223
224	// Create conversation
225	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
226	if err != nil {
227		t.Fatalf("failed to create conversation: %v", err)
228	}
229	conversationID := conversation.ConversationID
230
231	// Set up real HTTP server
232	mux := http.NewServeMux()
233	srv.RegisterRoutes(mux)
234	httpServer := httptest.NewServer(mux)
235	defer httpServer.Close()
236
237	// Connect to SSE stream
238	sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
239	if err != nil {
240		t.Fatalf("failed to connect to SSE stream: %v", err)
241	}
242	defer sseResp.Body.Close()
243
244	// Start reading SSE events in background
245	sseEvents := make(chan string, 100)
246	go func() {
247		scanner := bufio.NewScanner(sseResp.Body)
248		for scanner.Scan() {
249			line := scanner.Text()
250			if strings.HasPrefix(line, "data: ") {
251				sseEvents <- line
252			}
253		}
254	}()
255
256	// Wait for initial SSE event
257	select {
258	case <-sseEvents:
259		// Got initial state
260	case <-time.After(2 * time.Second):
261		t.Fatal("timed out waiting for initial SSE event")
262	}
263
264	// Send user message with slow LLM response via real HTTP client
265	chatReq := ChatRequest{
266		Message: "delay: 5",
267		Model:   "predictable",
268	}
269	chatBody, _ := json.Marshal(chatReq)
270
271	resp, err := http.Post(
272		httpServer.URL+"/api/conversation/"+conversationID+"/chat",
273		"application/json",
274		strings.NewReader(string(chatBody)),
275	)
276	if err != nil {
277		t.Fatalf("failed to send chat message: %v", err)
278	}
279	resp.Body.Close()
280
281	if resp.StatusCode != http.StatusAccepted {
282		t.Fatalf("expected status 202, got %d", resp.StatusCode)
283	}
284
285	// User message should appear in SSE stream within 500ms (before 5s LLM delay)
286	deadline := time.Now().Add(500 * time.Millisecond)
287	userMessageFound := false
288
289	for time.Now().Before(deadline) && !userMessageFound {
290		select {
291		case eventLine := <-sseEvents:
292			jsonStr := strings.TrimPrefix(eventLine, "data: ")
293			var streamResp StreamResponse
294			if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
295				continue
296			}
297			for _, msg := range streamResp.Messages {
298				if msg.Type == string(db.MessageTypeUser) && msg.LlmData != nil {
299					var llmMsg llm.Message
300					if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
301						for _, content := range llmMsg.Content {
302							if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "delay: 5") {
303								userMessageFound = true
304								break
305							}
306						}
307					}
308				}
309			}
310		case <-time.After(50 * time.Millisecond):
311			// Keep waiting
312		}
313	}
314
315	if !userMessageFound {
316		t.Error("BUG: user message did not appear in SSE stream within 500ms with real HTTP server")
317		t.Log("This confirms the context cancellation bug in notifySubscribers")
318	} else {
319		t.Log("SUCCESS: user message appeared in SSE stream immediately with real HTTP server")
320	}
321}
322
323// TestSSEUserMessageWithExistingConnection is a simpler version that tests
324// message recording and notification without the SSE complexity
325func TestSSEUserMessageWithExistingConnection(t *testing.T) {
326	database, cleanup := setupTestDB(t)
327	defer cleanup()
328
329	predictableService := loop.NewPredictableService()
330	llmManager := &testLLMManager{service: predictableService}
331	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
332	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
333
334	// Create conversation and get a manager (simulating an established SSE connection)
335	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
336	if err != nil {
337		t.Fatalf("failed to create conversation: %v", err)
338	}
339	conversationID := conversation.ConversationID
340
341	// Get the conversation manager to set up subscription
342	manager, err := server.getOrCreateConversationManager(context.Background(), conversationID)
343	if err != nil {
344		t.Fatalf("failed to get conversation manager: %v", err)
345	}
346
347	// Subscribe to updates
348	subCtx, subCancel := context.WithCancel(context.Background())
349	defer subCancel()
350	next := manager.subpub.Subscribe(subCtx, -1)
351
352	// Channel to receive updates
353	updates := make(chan StreamResponse, 10)
354	go func() {
355		for {
356			data, ok := next()
357			if !ok {
358				return
359			}
360			updates <- data
361		}
362	}()
363
364	// Now send a user message with slow LLM response
365	chatReq := ChatRequest{
366		Message: "delay: 5",
367		Model:   "predictable",
368	}
369	chatBody, _ := json.Marshal(chatReq)
370
371	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(chatBody)))
372	req.Header.Set("Content-Type", "application/json")
373	w := httptest.NewRecorder()
374
375	server.handleChatConversation(w, req, conversationID)
376	if w.Code != http.StatusAccepted {
377		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
378	}
379
380	// We should receive an update with the user message within 500ms
381	// (well before the 5 second LLM delay)
382	// Note: We may receive other updates first (e.g., ConversationListUpdate for slug changes),
383	// so we need to keep checking until we find the user message or timeout.
384	deadline := time.Now().Add(500 * time.Millisecond)
385	foundUserMsg := false
386
387	for time.Now().Before(deadline) && !foundUserMsg {
388		select {
389		case update := <-updates:
390			// Check if this update contains the user message
391			for _, msg := range update.Messages {
392				if msg.Type == string(db.MessageTypeUser) && msg.LlmData != nil {
393					var llmMsg llm.Message
394					if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
395						for _, content := range llmMsg.Content {
396							if content.Type == llm.ContentTypeText && strings.Contains(content.Text, "delay: 5") {
397								foundUserMsg = true
398								break
399							}
400						}
401					}
402				}
403			}
404		case <-time.After(50 * time.Millisecond):
405			// Keep waiting
406		}
407	}
408
409	if !foundUserMsg {
410		t.Error("BUG: did not receive subpub update with user message within 500ms")
411		t.Log("This means notifySubscribers is failing or not being called after user message is recorded")
412	} else {
413		t.Log("SUCCESS: received user message via subpub immediately")
414	}
415}