1package server
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "net/http"
8 "net/http/httptest"
9 "os"
10 "strings"
11 "testing"
12 "time"
13
14 "shelley.exe.dev/claudetool"
15 "shelley.exe.dev/loop"
16)
17
18// TestConversationStreamReceivesListUpdateForNewConversation tests that when subscribed
19// to one conversation's stream, we receive updates about new conversations.
20func TestConversationStreamReceivesListUpdateForNewConversation(t *testing.T) {
21 database, cleanup := setupTestDB(t)
22 defer cleanup()
23
24 predictableService := loop.NewPredictableService()
25 llmManager := &testLLMManager{service: predictableService}
26 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
27 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
28
29 // Create a conversation to subscribe to
30 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
31 if err != nil {
32 t.Fatalf("failed to create conversation: %v", err)
33 }
34
35 // Get or create conversation manager to ensure the conversation is active
36 _, err = server.getOrCreateConversationManager(context.Background(), conversation.ConversationID)
37 if err != nil {
38 t.Fatalf("failed to get conversation manager: %v", err)
39 }
40
41 // Start the conversation stream
42 sseCtx, sseCancel := context.WithCancel(context.Background())
43 defer sseCancel()
44
45 sseRecorder := newFlusherRecorder()
46 sseReq := httptest.NewRequest("GET", "/api/conversation/"+conversation.ConversationID+"/stream", nil)
47 sseReq = sseReq.WithContext(sseCtx)
48
49 sseStarted := make(chan struct{})
50 sseDone := make(chan struct{})
51 go func() {
52 close(sseStarted)
53 server.handleStreamConversation(sseRecorder, sseReq, conversation.ConversationID)
54 close(sseDone)
55 }()
56
57 <-sseStarted
58
59 // Wait for the initial event
60 select {
61 case <-sseRecorder.flushed:
62 case <-time.After(2 * time.Second):
63 t.Fatal("timeout waiting for initial SSE event")
64 }
65
66 // Create another conversation via the API
67 chatReq := ChatRequest{
68 Message: "hello",
69 Model: "predictable",
70 }
71 chatBody, _ := json.Marshal(chatReq)
72 req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
73 req.Header.Set("Content-Type", "application/json")
74 w := httptest.NewRecorder()
75
76 server.handleNewConversation(w, req)
77 if w.Code != http.StatusCreated {
78 t.Fatalf("expected status 201, got %d: %s", w.Code, w.Body.String())
79 }
80
81 var resp struct {
82 ConversationID string `json:"conversation_id"`
83 }
84 if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
85 t.Fatalf("failed to parse response: %v", err)
86 }
87
88 // Wait for the conversation list update to come through the existing stream
89 deadline := time.Now().Add(2 * time.Second)
90 var receivedUpdate bool
91 for time.Now().Before(deadline) && !receivedUpdate {
92 select {
93 case <-sseRecorder.flushed:
94 chunks := sseRecorder.getChunks()
95 for _, chunk := range chunks {
96 // Check for conversation_list_update with the new conversation ID
97 if strings.Contains(chunk, "conversation_list_update") && strings.Contains(chunk, resp.ConversationID) {
98 receivedUpdate = true
99 break
100 }
101 }
102 case <-time.After(100 * time.Millisecond):
103 }
104 }
105
106 if !receivedUpdate {
107 t.Error("did not receive conversation list update for new conversation")
108 chunks := sseRecorder.getChunks()
109 t.Logf("SSE chunks received: %v", chunks)
110 }
111
112 sseCancel()
113 <-sseDone
114}
115
116// TestConversationStreamReceivesListUpdateForRename tests that when subscribed
117// to one conversation's stream, we receive updates when another conversation is renamed.
118func TestConversationStreamReceivesListUpdateForRename(t *testing.T) {
119 database, cleanup := setupTestDB(t)
120 defer cleanup()
121
122 predictableService := loop.NewPredictableService()
123 llmManager := &testLLMManager{service: predictableService}
124 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
125 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
126
127 // Create two conversations
128 conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
129 if err != nil {
130 t.Fatalf("failed to create conversation 1: %v", err)
131 }
132 conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
133 if err != nil {
134 t.Fatalf("failed to create conversation 2: %v", err)
135 }
136
137 // Get or create conversation manager for conv1 (the one we'll subscribe to)
138 _, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
139 if err != nil {
140 t.Fatalf("failed to get conversation manager: %v", err)
141 }
142
143 // Start the conversation stream for conv1
144 sseCtx, sseCancel := context.WithCancel(context.Background())
145 defer sseCancel()
146
147 sseRecorder := newFlusherRecorder()
148 sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
149 sseReq = sseReq.WithContext(sseCtx)
150
151 sseStarted := make(chan struct{})
152 sseDone := make(chan struct{})
153 go func() {
154 close(sseStarted)
155 server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
156 close(sseDone)
157 }()
158
159 <-sseStarted
160
161 // Wait for the initial event
162 select {
163 case <-sseRecorder.flushed:
164 case <-time.After(2 * time.Second):
165 t.Fatal("timeout waiting for initial SSE event")
166 }
167
168 // Rename conv2
169 renameReq := RenameRequest{Slug: "test-slug-rename"}
170 renameBody, _ := json.Marshal(renameReq)
171 req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/rename", strings.NewReader(string(renameBody)))
172 req.Header.Set("Content-Type", "application/json")
173 w := httptest.NewRecorder()
174
175 server.handleRenameConversation(w, req, conv2.ConversationID)
176 if w.Code != http.StatusOK {
177 t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
178 }
179
180 // Wait for the conversation list update with the new slug
181 deadline := time.Now().Add(2 * time.Second)
182 var receivedUpdate bool
183 for time.Now().Before(deadline) && !receivedUpdate {
184 select {
185 case <-sseRecorder.flushed:
186 chunks := sseRecorder.getChunks()
187 for _, chunk := range chunks {
188 if strings.Contains(chunk, "conversation_list_update") && strings.Contains(chunk, "test-slug-rename") {
189 receivedUpdate = true
190 break
191 }
192 }
193 case <-time.After(100 * time.Millisecond):
194 }
195 }
196
197 if !receivedUpdate {
198 t.Error("did not receive conversation list update for slug change")
199 chunks := sseRecorder.getChunks()
200 t.Logf("SSE chunks received: %v", chunks)
201 }
202
203 sseCancel()
204 <-sseDone
205}
206
207// TestConversationStreamReceivesListUpdateForDelete tests that when subscribed
208// to one conversation's stream, we receive updates when another conversation is deleted.
209func TestConversationStreamReceivesListUpdateForDelete(t *testing.T) {
210 database, cleanup := setupTestDB(t)
211 defer cleanup()
212
213 predictableService := loop.NewPredictableService()
214 llmManager := &testLLMManager{service: predictableService}
215 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
216 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
217
218 // Create two conversations
219 conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
220 if err != nil {
221 t.Fatalf("failed to create conversation 1: %v", err)
222 }
223 conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
224 if err != nil {
225 t.Fatalf("failed to create conversation 2: %v", err)
226 }
227
228 // Get or create conversation manager for conv1
229 _, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
230 if err != nil {
231 t.Fatalf("failed to get conversation manager: %v", err)
232 }
233
234 // Start the conversation stream for conv1
235 sseCtx, sseCancel := context.WithCancel(context.Background())
236 defer sseCancel()
237
238 sseRecorder := newFlusherRecorder()
239 sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
240 sseReq = sseReq.WithContext(sseCtx)
241
242 sseStarted := make(chan struct{})
243 sseDone := make(chan struct{})
244 go func() {
245 close(sseStarted)
246 server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
247 close(sseDone)
248 }()
249
250 <-sseStarted
251
252 // Wait for the initial event
253 select {
254 case <-sseRecorder.flushed:
255 case <-time.After(2 * time.Second):
256 t.Fatal("timeout waiting for initial SSE event")
257 }
258
259 // Delete conv2
260 req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/delete", nil)
261 w := httptest.NewRecorder()
262
263 server.handleDeleteConversation(w, req, conv2.ConversationID)
264 if w.Code != http.StatusOK {
265 t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
266 }
267
268 // Wait for the delete update
269 deadline := time.Now().Add(2 * time.Second)
270 var receivedUpdate bool
271 for time.Now().Before(deadline) && !receivedUpdate {
272 select {
273 case <-sseRecorder.flushed:
274 chunks := sseRecorder.getChunks()
275 for _, chunk := range chunks {
276 if strings.Contains(chunk, "conversation_list_update") &&
277 strings.Contains(chunk, `"type":"delete"`) &&
278 strings.Contains(chunk, conv2.ConversationID) {
279 receivedUpdate = true
280 break
281 }
282 }
283 case <-time.After(100 * time.Millisecond):
284 }
285 }
286
287 if !receivedUpdate {
288 t.Error("did not receive conversation list delete update")
289 chunks := sseRecorder.getChunks()
290 t.Logf("SSE chunks received: %v", chunks)
291 }
292
293 sseCancel()
294 <-sseDone
295}
296
297// TestConversationStreamReceivesListUpdateForArchive tests that when subscribed
298// to one conversation's stream, we receive updates when another conversation is archived.
299func TestConversationStreamReceivesListUpdateForArchive(t *testing.T) {
300 database, cleanup := setupTestDB(t)
301 defer cleanup()
302
303 predictableService := loop.NewPredictableService()
304 llmManager := &testLLMManager{service: predictableService}
305 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
306 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
307
308 // Create two conversations
309 conv1, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
310 if err != nil {
311 t.Fatalf("failed to create conversation 1: %v", err)
312 }
313 conv2, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
314 if err != nil {
315 t.Fatalf("failed to create conversation 2: %v", err)
316 }
317
318 // Get or create conversation manager for conv1
319 _, err = server.getOrCreateConversationManager(context.Background(), conv1.ConversationID)
320 if err != nil {
321 t.Fatalf("failed to get conversation manager: %v", err)
322 }
323
324 // Start the conversation stream for conv1
325 sseCtx, sseCancel := context.WithCancel(context.Background())
326 defer sseCancel()
327
328 sseRecorder := newFlusherRecorder()
329 sseReq := httptest.NewRequest("GET", "/api/conversation/"+conv1.ConversationID+"/stream", nil)
330 sseReq = sseReq.WithContext(sseCtx)
331
332 sseStarted := make(chan struct{})
333 sseDone := make(chan struct{})
334 go func() {
335 close(sseStarted)
336 server.handleStreamConversation(sseRecorder, sseReq, conv1.ConversationID)
337 close(sseDone)
338 }()
339
340 <-sseStarted
341
342 // Wait for the initial event
343 select {
344 case <-sseRecorder.flushed:
345 case <-time.After(2 * time.Second):
346 t.Fatal("timeout waiting for initial SSE event")
347 }
348
349 // Archive conv2
350 req := httptest.NewRequest("POST", "/api/conversation/"+conv2.ConversationID+"/archive", nil)
351 w := httptest.NewRecorder()
352
353 server.handleArchiveConversation(w, req, conv2.ConversationID)
354 if w.Code != http.StatusOK {
355 t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
356 }
357
358 // Wait for the archive update
359 deadline := time.Now().Add(2 * time.Second)
360 var receivedUpdate bool
361 for time.Now().Before(deadline) && !receivedUpdate {
362 select {
363 case <-sseRecorder.flushed:
364 chunks := sseRecorder.getChunks()
365 for _, chunk := range chunks {
366 if strings.Contains(chunk, "conversation_list_update") &&
367 strings.Contains(chunk, conv2.ConversationID) &&
368 strings.Contains(chunk, `"archived":true`) {
369 receivedUpdate = true
370 break
371 }
372 }
373 case <-time.After(100 * time.Millisecond):
374 }
375 }
376
377 if !receivedUpdate {
378 t.Error("did not receive conversation list archive update")
379 chunks := sseRecorder.getChunks()
380 t.Logf("SSE chunks received: %v", chunks)
381 }
382
383 sseCancel()
384 <-sseDone
385}