1package server
2
3import (
4 "encoding/json"
5 "testing"
6
7 "shelley.exe.dev/db"
8 "shelley.exe.dev/llm"
9)
10
11// TestContextWindowSizeCalculation tests that the context window size is correctly
12// calculated including cached tokens.
13func TestContextWindowSizeCalculation(t *testing.T) {
14 // Test the calculateContextWindowSize function directly
15 t.Run("includes_all_token_types", func(t *testing.T) {
16 // Create usage data with all token types
17 usage := llm.Usage{
18 InputTokens: 100,
19 CacheCreationInputTokens: 50,
20 CacheReadInputTokens: 200,
21 OutputTokens: 30,
22 }
23 usageJSON, _ := json.Marshal(usage)
24 usageStr := string(usageJSON)
25
26 messages := []APIMessage{
27 {
28 Type: string(db.MessageTypeAgent),
29 UsageData: &usageStr,
30 },
31 }
32
33 // Expected: 100 + 50 + 200 + 30 = 380
34 got := calculateContextWindowSize(messages)
35 want := uint64(380)
36
37 if got != want {
38 t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
39 }
40 })
41
42 t.Run("only_input_tokens", func(t *testing.T) {
43 // Test with just input tokens (no caching)
44 usage := llm.Usage{
45 InputTokens: 150,
46 OutputTokens: 50,
47 }
48 usageJSON, _ := json.Marshal(usage)
49 usageStr := string(usageJSON)
50
51 messages := []APIMessage{
52 {
53 Type: string(db.MessageTypeAgent),
54 UsageData: &usageStr,
55 },
56 }
57
58 // Expected: 150 + 50 = 200
59 got := calculateContextWindowSize(messages)
60 want := uint64(200)
61
62 if got != want {
63 t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
64 }
65 })
66
67 t.Run("uses_last_message_with_usage", func(t *testing.T) {
68 // Test that we use the last message, not the first
69 usage1 := llm.Usage{
70 InputTokens: 100,
71 OutputTokens: 50,
72 }
73 usage1JSON, _ := json.Marshal(usage1)
74 usage1Str := string(usage1JSON)
75
76 usage2 := llm.Usage{
77 InputTokens: 200,
78 CacheReadInputTokens: 100,
79 OutputTokens: 75,
80 }
81 usage2JSON, _ := json.Marshal(usage2)
82 usage2Str := string(usage2JSON)
83
84 messages := []APIMessage{
85 {
86 Type: string(db.MessageTypeAgent),
87 UsageData: &usage1Str,
88 },
89 {
90 Type: string(db.MessageTypeUser),
91 UsageData: nil, // User messages typically don't have usage
92 },
93 {
94 Type: string(db.MessageTypeAgent),
95 UsageData: &usage2Str,
96 },
97 }
98
99 // Expected: 200 + 100 + 75 = 375 (from the last message)
100 got := calculateContextWindowSize(messages)
101 want := uint64(375)
102
103 if got != want {
104 t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
105 }
106 })
107
108 t.Run("empty_messages", func(t *testing.T) {
109 messages := []APIMessage{}
110 got := calculateContextWindowSize(messages)
111 want := uint64(0)
112
113 if got != want {
114 t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
115 }
116 })
117
118 t.Run("skips_zero_usage_messages", func(t *testing.T) {
119 // Test that we skip messages with zero usage data (common for user/tool messages)
120 // and find the last message with actual usage
121 validUsage := llm.Usage{
122 InputTokens: 200,
123 OutputTokens: 50,
124 }
125 validUsageJSON, _ := json.Marshal(validUsage)
126 validUsageStr := string(validUsageJSON)
127
128 zeroUsage := llm.Usage{} // All zeros
129 zeroUsageJSON, _ := json.Marshal(zeroUsage)
130 zeroUsageStr := string(zeroUsageJSON)
131
132 messages := []APIMessage{
133 {
134 Type: string(db.MessageTypeSystem),
135 UsageData: &zeroUsageStr, // System message with zero usage
136 },
137 {
138 Type: string(db.MessageTypeUser),
139 UsageData: &zeroUsageStr, // User message with zero usage
140 },
141 {
142 Type: string(db.MessageTypeAgent),
143 UsageData: &validUsageStr, // Agent message with valid usage
144 },
145 {
146 Type: string(db.MessageTypeUser),
147 UsageData: &zeroUsageStr, // User message after agent (zero usage)
148 },
149 }
150
151 // Should find the agent message's usage (200 + 50 = 250), not the last message's zero usage
152 got := calculateContextWindowSize(messages)
153 want := uint64(250)
154
155 if got != want {
156 t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
157 }
158 })
159}
160
161// TestContextWindowGrowsWithConversation tests that the context window size grows
162// as the conversation progresses, using the test harness and predictable service.
163func TestContextWindowGrowsWithConversation(t *testing.T) {
164 h := NewTestHarness(t)
165 defer h.Close()
166
167 // Start a new conversation
168 h.NewConversation("echo: first message", "/tmp")
169
170 // Wait for the response
171 resp1 := h.WaitResponse()
172 t.Logf("First response: %q", resp1)
173
174 // Get the context window size from the first message
175 firstSize := h.GetContextWindowSize()
176 t.Logf("First context window size: %d", firstSize)
177 if firstSize == 0 {
178 t.Fatal("expected non-zero context window size after first message")
179 }
180
181 // Send another message
182 h.Chat("echo: second message that is longer")
183 resp2 := h.WaitResponse()
184 t.Logf("Second response: %q", resp2)
185
186 // Context window should have grown
187 secondSize := h.GetContextWindowSize()
188 t.Logf("Second context window size: %d", secondSize)
189 if secondSize <= firstSize {
190 t.Errorf("context window should grow: first=%d, second=%d", firstSize, secondSize)
191 }
192
193 // Send a third message
194 h.Chat("echo: third message with even more text to demonstrate growth")
195 resp3 := h.WaitResponse()
196 t.Logf("Third response: %q", resp3)
197
198 thirdSize := h.GetContextWindowSize()
199 t.Logf("Third context window size: %d", thirdSize)
200 if thirdSize <= secondSize {
201 t.Errorf("context window should grow: second=%d, third=%d", secondSize, thirdSize)
202 }
203
204 t.Logf("Context window sizes: first=%d, second=%d, third=%d", firstSize, secondSize, thirdSize)
205}