1package server
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "log/slog"
8 "net/http"
9 "net/http/httptest"
10 "os"
11 "strings"
12 "testing"
13 "time"
14
15 "shelley.exe.dev/claudetool"
16 "shelley.exe.dev/loop"
17)
18
19// TestMessageSentOnlyOnce verifies that each message is sent to SSE subscribers
20// only once, not with every update.
21func TestMessageSentOnlyOnce(t *testing.T) {
22 database, cleanup := setupTestDB(t)
23 defer cleanup()
24
25 predictableService := loop.NewPredictableService()
26 llmManager := &testLLMManager{service: predictableService}
27 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
28
29 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
30
31 // Create conversation
32 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
33 if err != nil {
34 t.Fatalf("failed to create conversation: %v", err)
35 }
36 conversationID := conversation.ConversationID
37
38 // Set up real HTTP server
39 mux := http.NewServeMux()
40 server.RegisterRoutes(mux)
41 httpServer := httptest.NewServer(mux)
42 defer httpServer.Close()
43
44 // Connect to SSE stream
45 sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
46 if err != nil {
47 t.Fatalf("failed to connect to SSE stream: %v", err)
48 }
49 defer sseResp.Body.Close()
50
51 // Start reading SSE events in background
52 type sseEvent struct {
53 data StreamResponse
54 msgCount int
55 totalSize int
56 }
57 sseEvents := make(chan sseEvent, 100)
58
59 go func() {
60 scanner := bufio.NewScanner(sseResp.Body)
61 for scanner.Scan() {
62 line := scanner.Text()
63 if !strings.HasPrefix(line, "data: ") {
64 continue
65 }
66 jsonStr := strings.TrimPrefix(line, "data: ")
67 var streamResp StreamResponse
68 if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
69 continue
70 }
71 sseEvents <- sseEvent{
72 data: streamResp,
73 msgCount: len(streamResp.Messages),
74 totalSize: len(jsonStr),
75 }
76 }
77 }()
78
79 // Wait for initial SSE event (empty)
80 select {
81 case ev := <-sseEvents:
82 t.Logf("Initial SSE event: %d messages, %d bytes", ev.msgCount, ev.totalSize)
83 case <-time.After(2 * time.Second):
84 t.Fatal("timed out waiting for initial SSE event")
85 }
86
87 // Send first user message
88 chatReq := ChatRequest{
89 Message: "hello",
90 Model: "predictable",
91 }
92 chatBody, _ := json.Marshal(chatReq)
93
94 resp, err := http.Post(
95 httpServer.URL+"/api/conversation/"+conversationID+"/chat",
96 "application/json",
97 strings.NewReader(string(chatBody)),
98 )
99 if err != nil {
100 t.Fatalf("failed to send chat message: %v", err)
101 }
102 resp.Body.Close()
103
104 // Collect SSE events for a short time to see the message progression
105 var receivedEvents []sseEvent
106 deadline := time.Now().Add(3 * time.Second)
107
108 for time.Now().Before(deadline) {
109 select {
110 case ev := <-sseEvents:
111 receivedEvents = append(receivedEvents, ev)
112 t.Logf("SSE event %d: %d messages, %d bytes", len(receivedEvents), ev.msgCount, ev.totalSize)
113
114 // Check if we have end_of_turn
115 if len(ev.data.Messages) > 0 {
116 lastMsg := ev.data.Messages[len(ev.data.Messages)-1]
117 if lastMsg.EndOfTurn != nil && *lastMsg.EndOfTurn {
118 t.Log("Got end_of_turn, stopping collection")
119 goto done
120 }
121 }
122 case <-time.After(100 * time.Millisecond):
123 // Keep waiting
124 }
125 }
126
127done:
128 if len(receivedEvents) == 0 {
129 t.Fatal("received no SSE events after sending message")
130 }
131
132 // Analyze: count how many times each message was sent
133 messagesSent := make(map[int64]int) // sequence_id -> count
134 totalBytes := 0
135
136 for _, ev := range receivedEvents {
137 totalBytes += ev.totalSize
138 for _, msg := range ev.data.Messages {
139 messagesSent[msg.SequenceID]++
140 }
141 }
142
143 t.Logf("Total bytes sent across all SSE events: %d", totalBytes)
144 t.Logf("Message send counts:")
145 for seqID, count := range messagesSent {
146 t.Logf(" Sequence %d: sent %d times", seqID, count)
147 if count > 1 {
148 t.Errorf("BUG: Message with sequence_id=%d was sent %d times (expected 1)", seqID, count)
149 }
150 }
151}
152
153// TestContextWindowSizeInSSE verifies that context_window_size is correctly
154// included only when agent messages with usage data are sent.
155func TestContextWindowSizeInSSE(t *testing.T) {
156 database, cleanup := setupTestDB(t)
157 defer cleanup()
158
159 predictableService := loop.NewPredictableService()
160 llmManager := &testLLMManager{service: predictableService}
161 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
162
163 server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
164
165 // Create conversation
166 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
167 if err != nil {
168 t.Fatalf("failed to create conversation: %v", err)
169 }
170 conversationID := conversation.ConversationID
171
172 // Set up real HTTP server
173 mux := http.NewServeMux()
174 server.RegisterRoutes(mux)
175 httpServer := httptest.NewServer(mux)
176 defer httpServer.Close()
177
178 // Connect to SSE stream
179 sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
180 if err != nil {
181 t.Fatalf("failed to connect to SSE stream: %v", err)
182 }
183 defer sseResp.Body.Close()
184
185 // Start reading SSE events in background
186 type sseEvent struct {
187 data StreamResponse
188 contextWindowSize uint64
189 hasContextWindow bool
190 }
191 sseEvents := make(chan sseEvent, 100)
192
193 go func() {
194 scanner := bufio.NewScanner(sseResp.Body)
195 for scanner.Scan() {
196 line := scanner.Text()
197 if !strings.HasPrefix(line, "data: ") {
198 continue
199 }
200 jsonStr := strings.TrimPrefix(line, "data: ")
201 var streamResp StreamResponse
202 if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
203 continue
204 }
205 // Check if context_window_size was present in the JSON
206 var raw map[string]interface{}
207 json.Unmarshal([]byte(jsonStr), &raw)
208 _, hasCtx := raw["context_window_size"]
209
210 sseEvents <- sseEvent{
211 data: streamResp,
212 contextWindowSize: streamResp.ContextWindowSize,
213 hasContextWindow: hasCtx,
214 }
215 }
216 }()
217
218 // Wait for initial SSE event (empty)
219 select {
220 case ev := <-sseEvents:
221 t.Logf("Initial: context_window_size present=%v value=%d", ev.hasContextWindow, ev.contextWindowSize)
222 case <-time.After(2 * time.Second):
223 t.Fatal("timed out waiting for initial SSE event")
224 }
225
226 // Send user message
227 chatReq := ChatRequest{
228 Message: "hello",
229 Model: "predictable",
230 }
231 chatBody, _ := json.Marshal(chatReq)
232
233 resp, err := http.Post(
234 httpServer.URL+"/api/conversation/"+conversationID+"/chat",
235 "application/json",
236 strings.NewReader(string(chatBody)),
237 )
238 if err != nil {
239 t.Fatalf("failed to send chat message: %v", err)
240 }
241 resp.Body.Close()
242
243 // Collect SSE events
244 var receivedEvents []sseEvent
245 deadline := time.Now().Add(3 * time.Second)
246
247 for time.Now().Before(deadline) {
248 select {
249 case ev := <-sseEvents:
250 receivedEvents = append(receivedEvents, ev)
251 msgType := "unknown"
252 if len(ev.data.Messages) > 0 {
253 msgType = ev.data.Messages[0].Type
254 }
255 t.Logf("Event %d: type=%s context_window_size present=%v value=%d",
256 len(receivedEvents), msgType, ev.hasContextWindow, ev.contextWindowSize)
257
258 // Check if we have end_of_turn
259 if len(ev.data.Messages) > 0 {
260 lastMsg := ev.data.Messages[len(ev.data.Messages)-1]
261 if lastMsg.EndOfTurn != nil && *lastMsg.EndOfTurn {
262 goto done
263 }
264 }
265 case <-time.After(100 * time.Millisecond):
266 }
267 }
268
269done:
270 // Verify: user messages should NOT have context_window_size (omitted via omitempty)
271 // Agent messages with usage data SHOULD have context_window_size
272 for i, ev := range receivedEvents {
273 if len(ev.data.Messages) == 0 {
274 continue
275 }
276 msg := ev.data.Messages[0]
277 if msg.Type == "user" {
278 // User messages have no usage data, context_window_size should be omitted (0)
279 if ev.hasContextWindow && ev.contextWindowSize != 0 {
280 t.Errorf("Event %d: user message should not have context_window_size, got %d", i+1, ev.contextWindowSize)
281 }
282 } else if msg.Type == "agent" && msg.UsageData != nil {
283 // Agent messages with usage data should have context_window_size
284 if !ev.hasContextWindow {
285 t.Errorf("Event %d: agent message with usage data should have context_window_size", i+1)
286 }
287 if ev.contextWindowSize == 0 {
288 t.Errorf("Event %d: agent message context_window_size should not be 0", i+1)
289 }
290 }
291 }
292}