context_window_test.go

  1package server
  2
  3import (
  4	"encoding/json"
  5	"testing"
  6
  7	"shelley.exe.dev/db"
  8	"shelley.exe.dev/llm"
  9)
 10
 11// TestContextWindowSizeCalculation tests that the context window size is correctly
 12// calculated including cached tokens.
 13func TestContextWindowSizeCalculation(t *testing.T) {
 14	// Test the calculateContextWindowSize function directly
 15	t.Run("includes_all_token_types", func(t *testing.T) {
 16		// Create usage data with all token types
 17		usage := llm.Usage{
 18			InputTokens:              100,
 19			CacheCreationInputTokens: 50,
 20			CacheReadInputTokens:     200,
 21			OutputTokens:             30,
 22		}
 23		usageJSON, _ := json.Marshal(usage)
 24		usageStr := string(usageJSON)
 25
 26		messages := []APIMessage{
 27			{
 28				Type:      string(db.MessageTypeAgent),
 29				UsageData: &usageStr,
 30			},
 31		}
 32
 33		// Expected: 100 + 50 + 200 + 30 = 380
 34		got := calculateContextWindowSize(messages)
 35		want := uint64(380)
 36
 37		if got != want {
 38			t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
 39		}
 40	})
 41
 42	t.Run("only_input_tokens", func(t *testing.T) {
 43		// Test with just input tokens (no caching)
 44		usage := llm.Usage{
 45			InputTokens:  150,
 46			OutputTokens: 50,
 47		}
 48		usageJSON, _ := json.Marshal(usage)
 49		usageStr := string(usageJSON)
 50
 51		messages := []APIMessage{
 52			{
 53				Type:      string(db.MessageTypeAgent),
 54				UsageData: &usageStr,
 55			},
 56		}
 57
 58		// Expected: 150 + 50 = 200
 59		got := calculateContextWindowSize(messages)
 60		want := uint64(200)
 61
 62		if got != want {
 63			t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
 64		}
 65	})
 66
 67	t.Run("uses_last_message_with_usage", func(t *testing.T) {
 68		// Test that we use the last message, not the first
 69		usage1 := llm.Usage{
 70			InputTokens:  100,
 71			OutputTokens: 50,
 72		}
 73		usage1JSON, _ := json.Marshal(usage1)
 74		usage1Str := string(usage1JSON)
 75
 76		usage2 := llm.Usage{
 77			InputTokens:          200,
 78			CacheReadInputTokens: 100,
 79			OutputTokens:         75,
 80		}
 81		usage2JSON, _ := json.Marshal(usage2)
 82		usage2Str := string(usage2JSON)
 83
 84		messages := []APIMessage{
 85			{
 86				Type:      string(db.MessageTypeAgent),
 87				UsageData: &usage1Str,
 88			},
 89			{
 90				Type:      string(db.MessageTypeUser),
 91				UsageData: nil, // User messages typically don't have usage
 92			},
 93			{
 94				Type:      string(db.MessageTypeAgent),
 95				UsageData: &usage2Str,
 96			},
 97		}
 98
 99		// Expected: 200 + 100 + 75 = 375 (from the last message)
100		got := calculateContextWindowSize(messages)
101		want := uint64(375)
102
103		if got != want {
104			t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
105		}
106	})
107
108	t.Run("empty_messages", func(t *testing.T) {
109		messages := []APIMessage{}
110		got := calculateContextWindowSize(messages)
111		want := uint64(0)
112
113		if got != want {
114			t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
115		}
116	})
117
118	t.Run("skips_zero_usage_messages", func(t *testing.T) {
119		// Test that we skip messages with zero usage data (common for user/tool messages)
120		// and find the last message with actual usage
121		validUsage := llm.Usage{
122			InputTokens:  200,
123			OutputTokens: 50,
124		}
125		validUsageJSON, _ := json.Marshal(validUsage)
126		validUsageStr := string(validUsageJSON)
127
128		zeroUsage := llm.Usage{} // All zeros
129		zeroUsageJSON, _ := json.Marshal(zeroUsage)
130		zeroUsageStr := string(zeroUsageJSON)
131
132		messages := []APIMessage{
133			{
134				Type:      string(db.MessageTypeSystem),
135				UsageData: &zeroUsageStr, // System message with zero usage
136			},
137			{
138				Type:      string(db.MessageTypeUser),
139				UsageData: &zeroUsageStr, // User message with zero usage
140			},
141			{
142				Type:      string(db.MessageTypeAgent),
143				UsageData: &validUsageStr, // Agent message with valid usage
144			},
145			{
146				Type:      string(db.MessageTypeUser),
147				UsageData: &zeroUsageStr, // User message after agent (zero usage)
148			},
149		}
150
151		// Should find the agent message's usage (200 + 50 = 250), not the last message's zero usage
152		got := calculateContextWindowSize(messages)
153		want := uint64(250)
154
155		if got != want {
156			t.Errorf("calculateContextWindowSize() = %d, want %d", got, want)
157		}
158	})
159}
160
161// TestContextWindowGrowsWithConversation tests that the context window size grows
162// as the conversation progresses, using the test harness and predictable service.
163func TestContextWindowGrowsWithConversation(t *testing.T) {
164	h := NewTestHarness(t)
165	defer h.Close()
166
167	// Start a new conversation
168	h.NewConversation("echo: first message", "/tmp")
169
170	// Wait for the response
171	resp1 := h.WaitResponse()
172	t.Logf("First response: %q", resp1)
173
174	// Get the context window size from the first message
175	firstSize := h.GetContextWindowSize()
176	t.Logf("First context window size: %d", firstSize)
177	if firstSize == 0 {
178		t.Fatal("expected non-zero context window size after first message")
179	}
180
181	// Send another message
182	h.Chat("echo: second message that is longer")
183	resp2 := h.WaitResponse()
184	t.Logf("Second response: %q", resp2)
185
186	// Context window should have grown
187	secondSize := h.GetContextWindowSize()
188	t.Logf("Second context window size: %d", secondSize)
189	if secondSize <= firstSize {
190		t.Errorf("context window should grow: first=%d, second=%d", firstSize, secondSize)
191	}
192
193	// Send a third message
194	h.Chat("echo: third message with even more text to demonstrate growth")
195	resp3 := h.WaitResponse()
196	t.Logf("Third response: %q", resp3)
197
198	thirdSize := h.GetContextWindowSize()
199	t.Logf("Third context window size: %d", thirdSize)
200	if thirdSize <= secondSize {
201		t.Errorf("context window should grow: second=%d, third=%d", secondSize, thirdSize)
202	}
203
204	t.Logf("Context window sizes: first=%d, second=%d, third=%d", firstSize, secondSize, thirdSize)
205}