message_bandwidth_test.go

  1package server
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"log/slog"
  8	"net/http"
  9	"net/http/httptest"
 10	"os"
 11	"strings"
 12	"testing"
 13	"time"
 14
 15	"shelley.exe.dev/claudetool"
 16	"shelley.exe.dev/loop"
 17)
 18
 19// TestMessageSentOnlyOnce verifies that each message is sent to SSE subscribers
 20// only once, not with every update.
 21func TestMessageSentOnlyOnce(t *testing.T) {
 22	database, cleanup := setupTestDB(t)
 23	defer cleanup()
 24
 25	predictableService := loop.NewPredictableService()
 26	llmManager := &testLLMManager{service: predictableService}
 27	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
 28
 29	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
 30
 31	// Create conversation
 32	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
 33	if err != nil {
 34		t.Fatalf("failed to create conversation: %v", err)
 35	}
 36	conversationID := conversation.ConversationID
 37
 38	// Set up real HTTP server
 39	mux := http.NewServeMux()
 40	server.RegisterRoutes(mux)
 41	httpServer := httptest.NewServer(mux)
 42	defer httpServer.Close()
 43
 44	// Connect to SSE stream
 45	sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
 46	if err != nil {
 47		t.Fatalf("failed to connect to SSE stream: %v", err)
 48	}
 49	defer sseResp.Body.Close()
 50
 51	// Start reading SSE events in background
 52	type sseEvent struct {
 53		data      StreamResponse
 54		msgCount  int
 55		totalSize int
 56	}
 57	sseEvents := make(chan sseEvent, 100)
 58
 59	go func() {
 60		scanner := bufio.NewScanner(sseResp.Body)
 61		for scanner.Scan() {
 62			line := scanner.Text()
 63			if !strings.HasPrefix(line, "data: ") {
 64				continue
 65			}
 66			jsonStr := strings.TrimPrefix(line, "data: ")
 67			var streamResp StreamResponse
 68			if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
 69				continue
 70			}
 71			sseEvents <- sseEvent{
 72				data:      streamResp,
 73				msgCount:  len(streamResp.Messages),
 74				totalSize: len(jsonStr),
 75			}
 76		}
 77	}()
 78
 79	// Wait for initial SSE event (empty)
 80	select {
 81	case ev := <-sseEvents:
 82		t.Logf("Initial SSE event: %d messages, %d bytes", ev.msgCount, ev.totalSize)
 83	case <-time.After(2 * time.Second):
 84		t.Fatal("timed out waiting for initial SSE event")
 85	}
 86
 87	// Send first user message
 88	chatReq := ChatRequest{
 89		Message: "hello",
 90		Model:   "predictable",
 91	}
 92	chatBody, _ := json.Marshal(chatReq)
 93
 94	resp, err := http.Post(
 95		httpServer.URL+"/api/conversation/"+conversationID+"/chat",
 96		"application/json",
 97		strings.NewReader(string(chatBody)),
 98	)
 99	if err != nil {
100		t.Fatalf("failed to send chat message: %v", err)
101	}
102	resp.Body.Close()
103
104	// Collect SSE events for a short time to see the message progression
105	var receivedEvents []sseEvent
106	deadline := time.Now().Add(3 * time.Second)
107
108	for time.Now().Before(deadline) {
109		select {
110		case ev := <-sseEvents:
111			receivedEvents = append(receivedEvents, ev)
112			t.Logf("SSE event %d: %d messages, %d bytes", len(receivedEvents), ev.msgCount, ev.totalSize)
113
114			// Check if we have end_of_turn
115			if len(ev.data.Messages) > 0 {
116				lastMsg := ev.data.Messages[len(ev.data.Messages)-1]
117				if lastMsg.EndOfTurn != nil && *lastMsg.EndOfTurn {
118					t.Log("Got end_of_turn, stopping collection")
119					goto done
120				}
121			}
122		case <-time.After(100 * time.Millisecond):
123			// Keep waiting
124		}
125	}
126
127done:
128	if len(receivedEvents) == 0 {
129		t.Fatal("received no SSE events after sending message")
130	}
131
132	// Analyze: count how many times each message was sent
133	messagesSent := make(map[int64]int) // sequence_id -> count
134	totalBytes := 0
135
136	for _, ev := range receivedEvents {
137		totalBytes += ev.totalSize
138		for _, msg := range ev.data.Messages {
139			messagesSent[msg.SequenceID]++
140		}
141	}
142
143	t.Logf("Total bytes sent across all SSE events: %d", totalBytes)
144	t.Logf("Message send counts:")
145	for seqID, count := range messagesSent {
146		t.Logf("  Sequence %d: sent %d times", seqID, count)
147		if count > 1 {
148			t.Errorf("BUG: Message with sequence_id=%d was sent %d times (expected 1)", seqID, count)
149		}
150	}
151}
152
153// TestContextWindowSizeInSSE verifies that context_window_size is correctly
154// included only when agent messages with usage data are sent.
155func TestContextWindowSizeInSSE(t *testing.T) {
156	database, cleanup := setupTestDB(t)
157	defer cleanup()
158
159	predictableService := loop.NewPredictableService()
160	llmManager := &testLLMManager{service: predictableService}
161	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
162
163	server := NewServer(database, llmManager, claudetool.ToolSetConfig{}, logger, true, "", "predictable", "", nil)
164
165	// Create conversation
166	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
167	if err != nil {
168		t.Fatalf("failed to create conversation: %v", err)
169	}
170	conversationID := conversation.ConversationID
171
172	// Set up real HTTP server
173	mux := http.NewServeMux()
174	server.RegisterRoutes(mux)
175	httpServer := httptest.NewServer(mux)
176	defer httpServer.Close()
177
178	// Connect to SSE stream
179	sseResp, err := http.Get(httpServer.URL + "/api/conversation/" + conversationID + "/stream")
180	if err != nil {
181		t.Fatalf("failed to connect to SSE stream: %v", err)
182	}
183	defer sseResp.Body.Close()
184
185	// Start reading SSE events in background
186	type sseEvent struct {
187		data              StreamResponse
188		contextWindowSize uint64
189		hasContextWindow  bool
190	}
191	sseEvents := make(chan sseEvent, 100)
192
193	go func() {
194		scanner := bufio.NewScanner(sseResp.Body)
195		for scanner.Scan() {
196			line := scanner.Text()
197			if !strings.HasPrefix(line, "data: ") {
198				continue
199			}
200			jsonStr := strings.TrimPrefix(line, "data: ")
201			var streamResp StreamResponse
202			if err := json.Unmarshal([]byte(jsonStr), &streamResp); err != nil {
203				continue
204			}
205			// Check if context_window_size was present in the JSON
206			var raw map[string]interface{}
207			json.Unmarshal([]byte(jsonStr), &raw)
208			_, hasCtx := raw["context_window_size"]
209
210			sseEvents <- sseEvent{
211				data:              streamResp,
212				contextWindowSize: streamResp.ContextWindowSize,
213				hasContextWindow:  hasCtx,
214			}
215		}
216	}()
217
218	// Wait for initial SSE event (empty)
219	select {
220	case ev := <-sseEvents:
221		t.Logf("Initial: context_window_size present=%v value=%d", ev.hasContextWindow, ev.contextWindowSize)
222	case <-time.After(2 * time.Second):
223		t.Fatal("timed out waiting for initial SSE event")
224	}
225
226	// Send user message
227	chatReq := ChatRequest{
228		Message: "hello",
229		Model:   "predictable",
230	}
231	chatBody, _ := json.Marshal(chatReq)
232
233	resp, err := http.Post(
234		httpServer.URL+"/api/conversation/"+conversationID+"/chat",
235		"application/json",
236		strings.NewReader(string(chatBody)),
237	)
238	if err != nil {
239		t.Fatalf("failed to send chat message: %v", err)
240	}
241	resp.Body.Close()
242
243	// Collect SSE events
244	var receivedEvents []sseEvent
245	deadline := time.Now().Add(3 * time.Second)
246
247	for time.Now().Before(deadline) {
248		select {
249		case ev := <-sseEvents:
250			receivedEvents = append(receivedEvents, ev)
251			msgType := "unknown"
252			if len(ev.data.Messages) > 0 {
253				msgType = ev.data.Messages[0].Type
254			}
255			t.Logf("Event %d: type=%s context_window_size present=%v value=%d",
256				len(receivedEvents), msgType, ev.hasContextWindow, ev.contextWindowSize)
257
258			// Check if we have end_of_turn
259			if len(ev.data.Messages) > 0 {
260				lastMsg := ev.data.Messages[len(ev.data.Messages)-1]
261				if lastMsg.EndOfTurn != nil && *lastMsg.EndOfTurn {
262					goto done
263				}
264			}
265		case <-time.After(100 * time.Millisecond):
266		}
267	}
268
269done:
270	// Verify: user messages should NOT have context_window_size (omitted via omitempty)
271	// Agent messages with usage data SHOULD have context_window_size
272	for i, ev := range receivedEvents {
273		if len(ev.data.Messages) == 0 {
274			continue
275		}
276		msg := ev.data.Messages[0]
277		if msg.Type == "user" {
278			// User messages have no usage data, context_window_size should be omitted (0)
279			if ev.hasContextWindow && ev.contextWindowSize != 0 {
280				t.Errorf("Event %d: user message should not have context_window_size, got %d", i+1, ev.contextWindowSize)
281			}
282		} else if msg.Type == "agent" && msg.UsageData != nil {
283			// Agent messages with usage data should have context_window_size
284			if !ev.hasContextWindow {
285				t.Errorf("Event %d: agent message with usage data should have context_window_size", i+1)
286			}
287			if ev.contextWindowSize == 0 {
288				t.Errorf("Event %d: agent message context_window_size should not be 0", i+1)
289			}
290		}
291	}
292}