conversation_list_stream_test.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"log/slog"
  7	"net/http"
  8	"net/http/httptest"
  9	"os"
 10	"strings"
 11	"testing"
 12	"time"
 13
 14	"shelley.exe.dev/claudetool"
 15	"shelley.exe.dev/loop"
 16)
 17
 18// TestConversationStreamReceivesListUpdateForNewConversation tests that when subscribed
 19// to one conversation's stream, we receive updates about new conversations.
 20func TestConversationStreamReceivesListUpdateForNewConversation(t *testing.T) {
 21	database, cleanup := setupTestDB(t)
 22	defer cleanup()
 23
 24	predictableService := loop.NewPredictableService()
 25	llmManager := &testLLMManager{service: predictableService}
 26	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
 27	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
 28
 29	// Create a conversation to subscribe to
 30	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
 31	if err != nil {
 32		t.Fatalf("failed to create conversation: %v", err)
 33	}
 34
 35	// Get or create conversation manager to ensure the conversation is active
 36	_, err = server.getOrCreateConversationManager(context.Background(), conversation.ConversationID)
 37	if err != nil {
 38		t.Fatalf("failed to get conversation manager: %v", err)
 39	}
 40
 41	// Start the conversation stream
 42	sseCtx, sseCancel := context.WithCancel(context.Background())
 43	defer sseCancel()
 44
 45	sseRecorder := newFlusherRecorder()
 46	sseReq := httptest.NewRequest("GET", "/api/conversation/"+conversation.ConversationID+"/stream", nil)
 47	sseReq = sseReq.WithContext(sseCtx)
 48
 49	sseStarted := make(chan struct{})
 50	sseDone := make(chan struct{})
 51	go func() {
 52		close(sseStarted)
 53		server.handleStreamConversation(sseRecorder, sseReq, conversation.ConversationID)
 54		close(sseDone)
 55	}()
 56
 57	<-sseStarted
 58
 59	// Wait for the initial event
 60	select {
 61	case <-sseRecorder.flushed:
 62	case <-time.After(2 * time.Second):
 63		t.Fatal("timeout waiting for initial SSE event")
 64	}
 65
 66	// Create another conversation via the API
 67	chatReq := ChatRequest{
 68		Message: "hello",
 69		Model:   "predictable",
 70	}
 71	chatBody, _ := json.Marshal(chatReq)
 72	req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
 73	req.Header.Set("Content-Type", "application/json")
 74	w := httptest.NewRecorder()
 75
 76	server.handleNewConversation(w, req)
 77	if w.Code != http.StatusCreated {
 78		t.Fatalf("expected status 201, got %d: %s", w.Code, w.Body.String())
 79	}
 80
 81	var resp struct {
 82		ConversationID string `json:"conversation_id"`
 83	}
 84	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
 85		t.Fatalf("failed to parse response: %v", err)
 86	}
 87
 88	// Wait for the conversation list update to come through the existing stream
 89	deadline := time.Now().Add(2 * time.Second)
 90	var receivedUpdate bool
 91	for time.Now().Before(deadline) && !receivedUpdate {
 92		select {
 93		case <-sseRecorder.flushed:
 94			chunks := sseRecorder.getChunks()
 95			for _, chunk := range chunks {
 96				// Check for conversation_list_update with the new conversation ID
 97				if strings.Contains(chunk, "conversation_list_update") && strings.Contains(chunk, resp.ConversationID) {
 98					receivedUpdate = true
 99					break
100				}
101			}
102		case <-time.After(100 * time.Millisecond):
103		}
104	}
105
106	if !receivedUpdate {
107		t.Error("did not receive conversation list update for new conversation")
108		chunks := sseRecorder.getChunks()
109		t.Logf("SSE chunks received: %v", chunks)
110	}
111
112	sseCancel()
113	<-sseDone
114}
115
116// TestConversationStreamReceivesListUpdateForRename tests that when subscribed
117// to one conversation's stream, we receive updates when another conversation is renamed.
118func TestConversationStreamReceivesListUpdateForRename(t *testing.T) {
119	database, cleanup := setupTestDB(t)
120	defer cleanup()
121
122	predictableService := loop.NewPredictableService()
123	llmManager := &testLLMManager{service: predictableService}
124	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
125	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
126
127	// Create two conversations
128	conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
129	if err != nil {
130		t.Fatalf("failed to create conversation 1: %v", err)
131	}
132	conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
133	if err != nil {
134		t.Fatalf("failed to create conversation 2: %v", err)
135	}
136
137	// Get or create conversation manager for conv1 (the one we'll subscribe to)
138	_, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
139	if err != nil {
140		t.Fatalf("failed to get conversation manager: %v", err)
141	}
142
143	// Start the conversation stream for conv1
144	sseCtx, sseCancel := context.WithCancel(context.Background())
145	defer sseCancel()
146
147	sseRecorder := newFlusherRecorder()
148	sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
149	sseReq = sseReq.WithContext(sseCtx)
150
151	sseStarted := make(chan struct{})
152	sseDone := make(chan struct{})
153	go func() {
154		close(sseStarted)
155		server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
156		close(sseDone)
157	}()
158
159	<-sseStarted
160
161	// Wait for the initial event
162	select {
163	case <-sseRecorder.flushed:
164	case <-time.After(2 * time.Second):
165		t.Fatal("timeout waiting for initial SSE event")
166	}
167
168	// Rename conv2
169	renameReq := RenameRequest{Slug: "test-slug-rename"}
170	renameBody, _ := json.Marshal(renameReq)
171	req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/rename", strings.NewReader(string(renameBody)))
172	req.Header.Set("Content-Type", "application/json")
173	w := httptest.NewRecorder()
174
175	server.handleRenameConversation(w, req, conv2.ConversationID)
176	if w.Code != http.StatusOK {
177		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
178	}
179
180	// Wait for the conversation list update with the new slug
181	deadline := time.Now().Add(2 * time.Second)
182	var receivedUpdate bool
183	for time.Now().Before(deadline) && !receivedUpdate {
184		select {
185		case <-sseRecorder.flushed:
186			chunks := sseRecorder.getChunks()
187			for _, chunk := range chunks {
188				if strings.Contains(chunk, "conversation_list_update") && strings.Contains(chunk, "test-slug-rename") {
189					receivedUpdate = true
190					break
191				}
192			}
193		case <-time.After(100 * time.Millisecond):
194		}
195	}
196
197	if !receivedUpdate {
198		t.Error("did not receive conversation list update for slug change")
199		chunks := sseRecorder.getChunks()
200		t.Logf("SSE chunks received: %v", chunks)
201	}
202
203	sseCancel()
204	<-sseDone
205}
206
207// TestConversationStreamReceivesListUpdateForDelete tests that when subscribed
208// to one conversation's stream, we receive updates when another conversation is deleted.
209func TestConversationStreamReceivesListUpdateForDelete(t *testing.T) {
210	database, cleanup := setupTestDB(t)
211	defer cleanup()
212
213	predictableService := loop.NewPredictableService()
214	llmManager := &testLLMManager{service: predictableService}
215	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
216	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
217
218	// Create two conversations
219	conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
220	if err != nil {
221		t.Fatalf("failed to create conversation 1: %v", err)
222	}
223	conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
224	if err != nil {
225		t.Fatalf("failed to create conversation 2: %v", err)
226	}
227
228	// Get or create conversation manager for conv1
229	_, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
230	if err != nil {
231		t.Fatalf("failed to get conversation manager: %v", err)
232	}
233
234	// Start the conversation stream for conv1
235	sseCtx, sseCancel := context.WithCancel(context.Background())
236	defer sseCancel()
237
238	sseRecorder := newFlusherRecorder()
239	sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
240	sseReq = sseReq.WithContext(sseCtx)
241
242	sseStarted := make(chan struct{})
243	sseDone := make(chan struct{})
244	go func() {
245		close(sseStarted)
246		server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
247		close(sseDone)
248	}()
249
250	<-sseStarted
251
252	// Wait for the initial event
253	select {
254	case <-sseRecorder.flushed:
255	case <-time.After(2 * time.Second):
256		t.Fatal("timeout waiting for initial SSE event")
257	}
258
259	// Delete conv2
260	req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/delete", nil)
261	w := httptest.NewRecorder()
262
263	server.handleDeleteConversation(w, req, conv2.ConversationID)
264	if w.Code != http.StatusOK {
265		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
266	}
267
268	// Wait for the delete update
269	deadline := time.Now().Add(2 * time.Second)
270	var receivedUpdate bool
271	for time.Now().Before(deadline) && !receivedUpdate {
272		select {
273		case <-sseRecorder.flushed:
274			chunks := sseRecorder.getChunks()
275			for _, chunk := range chunks {
276				if strings.Contains(chunk, "conversation_list_update") &&
277					strings.Contains(chunk, `"type":"delete"`) &&
278					strings.Contains(chunk, conv2.ConversationID) {
279					receivedUpdate = true
280					break
281				}
282			}
283		case <-time.After(100 * time.Millisecond):
284		}
285	}
286
287	if !receivedUpdate {
288		t.Error("did not receive conversation list delete update")
289		chunks := sseRecorder.getChunks()
290		t.Logf("SSE chunks received: %v", chunks)
291	}
292
293	sseCancel()
294	<-sseDone
295}
296
297// TestConversationStreamReceivesListUpdateForArchive tests that when subscribed
298// to one conversation's stream, we receive updates when another conversation is archived.
299func TestConversationStreamReceivesListUpdateForArchive(t *testing.T) {
300	database, cleanup := setupTestDB(t)
301	defer cleanup()
302
303	predictableService := loop.NewPredictableService()
304	llmManager := &testLLMManager{service: predictableService}
305	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
306	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
307
308	// Create two conversations
309	conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
310	if err != nil {
311		t.Fatalf("failed to create conversation 1: %v", err)
312	}
313	conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
314	if err != nil {
315		t.Fatalf("failed to create conversation 2: %v", err)
316	}
317
318	// Get or create conversation manager for conv1
319	_, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
320	if err != nil {
321		t.Fatalf("failed to get conversation manager: %v", err)
322	}
323
324	// Start the conversation stream for conv1
325	sseCtx, sseCancel := context.WithCancel(context.Background())
326	defer sseCancel()
327
328	sseRecorder := newFlusherRecorder()
329	sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
330	sseReq = sseReq.WithContext(sseCtx)
331
332	sseStarted := make(chan struct{})
333	sseDone := make(chan struct{})
334	go func() {
335		close(sseStarted)
336		server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
337		close(sseDone)
338	}()
339
340	<-sseStarted
341
342	// Wait for the initial event
343	select {
344	case <-sseRecorder.flushed:
345	case <-time.After(2 * time.Second):
346		t.Fatal("timeout waiting for initial SSE event")
347	}
348
349	// Archive conv2
350	req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/archive", nil)
351	w := httptest.NewRecorder()
352
353	server.handleArchiveConversation(w, req, conv2.ConversationID)
354	if w.Code != http.StatusOK {
355		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
356	}
357
358	// Wait for the archive update
359	deadline := time.Now().Add(2 * time.Second)
360	var receivedUpdate bool
361	for time.Now().Before(deadline) && !receivedUpdate {
362		select {
363		case <-sseRecorder.flushed:
364			chunks := sseRecorder.getChunks()
365			for _, chunk := range chunks {
366				if strings.Contains(chunk, "conversation_list_update") &&
367					strings.Contains(chunk, conv2.ConversationID) &&
368					strings.Contains(chunk, `"archived":true`) {
369					receivedUpdate = true
370					break
371				}
372			}
373		case <-time.After(100 * time.Millisecond):
374		}
375	}
376
377	if !receivedUpdate {
378		t.Error("did not receive conversation list archive update")
379		chunks := sseRecorder.getChunks()
380		t.Logf("SSE chunks received: %v", chunks)
381	}
382
383	sseCancel()
384	<-sseDone
385}