1package db
2
3import (
4 "context"
5 "fmt"
6 "strings"
7 "testing"
8 "time"
9
10 "shelley.exe.dev/db/generated"
11)
12
13// setupTestDB creates a test database with schema migrated
14func setupTestDB(t *testing.T) *DB {
15 t.Helper()
16
17 // Use a temporary file instead of :memory: because the pool requires multiple connections
18 tmpDir := t.TempDir()
19 db, err := New(Config{DSN: tmpDir + "/test.db"})
20 if err != nil {
21 t.Fatalf("Failed to create test database: %v", err)
22 }
23
24 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
25 defer cancel()
26
27 if err := db.Migrate(ctx); err != nil {
28 t.Fatalf("Failed to migrate test database: %v", err)
29 }
30
31 return db
32}
33
34func TestNew(t *testing.T) {
35 tests := []struct {
36 name string
37 cfg Config
38 wantErr bool
39 }{
40 {
41 name: "memory database not supported",
42 cfg: Config{DSN: ":memory:"},
43 wantErr: true,
44 },
45 {
46 name: "empty DSN",
47 cfg: Config{DSN: ""},
48 wantErr: true,
49 },
50 }
51
52 for _, tt := range tests {
53 t.Run(tt.name, func(t *testing.T) {
54 db, err := New(tt.cfg)
55 if (err != nil) != tt.wantErr {
56 t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
57 return
58 }
59 if db != nil {
60 defer db.Close()
61 }
62 })
63 }
64}
65
66func TestDB_Migrate(t *testing.T) {
67 tmpDir := t.TempDir()
68 db, err := New(Config{DSN: tmpDir + "/test.db"})
69 if err != nil {
70 t.Fatalf("Failed to create test database: %v", err)
71 }
72 defer db.Close()
73
74 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
75 defer cancel()
76
77 // Run migrations first time
78 if err := db.Migrate(ctx); err != nil {
79 t.Errorf("Migrate() error = %v", err)
80 }
81
82 // Verify tables were created by trying to count conversations
83 var count int64
84 err = db.Queries(ctx, func(q *generated.Queries) error {
85 var err error
86 count, err = q.CountConversations(ctx)
87 return err
88 })
89 if err != nil {
90 t.Errorf("Failed to query conversations after migration: %v", err)
91 }
92 if count != 0 {
93 t.Errorf("Expected 0 conversations, got %d", count)
94 }
95
96 // Run migrations a second time to verify idempotency
97 if err := db.Migrate(ctx); err != nil {
98 t.Errorf("Second Migrate() error = %v", err)
99 }
100
101 // Verify we can still query after running migrations twice
102 err = db.Queries(ctx, func(q *generated.Queries) error {
103 var err error
104 count, err = q.CountConversations(ctx)
105 return err
106 })
107 if err != nil {
108 t.Errorf("Failed to query conversations after second migration: %v", err)
109 }
110}
111
112func TestDB_WithTx(t *testing.T) {
113 db := setupTestDB(t)
114 defer db.Close()
115
116 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
117 defer cancel()
118
119 // Test successful transaction
120 err := db.WithTx(ctx, func(q *generated.Queries) error {
121 _, err := q.CreateConversation(ctx, generated.CreateConversationParams{
122 ConversationID: "test-conv-1",
123 Slug: stringPtr("test-slug"),
124 UserInitiated: true,
125 Model: nil,
126 })
127 return err
128 })
129 if err != nil {
130 t.Errorf("WithTx() error = %v", err)
131 }
132
133 // Verify the conversation was created
134 var conv generated.Conversation
135 err = db.Queries(ctx, func(q *generated.Queries) error {
136 var err error
137 conv, err = q.GetConversation(ctx, "test-conv-1")
138 return err
139 })
140 if err != nil {
141 t.Errorf("Failed to get conversation after transaction: %v", err)
142 }
143 if conv.ConversationID != "test-conv-1" {
144 t.Errorf("Expected conversation ID 'test-conv-1', got %s", conv.ConversationID)
145 }
146}
147
// stringPtr allocates a new string holding a copy of s and returns its
// address. It is convenient for filling optional *string fields in literals.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
152
153func TestDB_ForeignKeyConstraints(t *testing.T) {
154 db := setupTestDB(t)
155 defer db.Close()
156
157 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
158 defer cancel()
159
160 // Try to create a message with a non-existent conversation_id
161 // This should fail due to foreign key constraint
162 err := db.QueriesTx(ctx, func(q *generated.Queries) error {
163 _, err := q.CreateMessage(ctx, generated.CreateMessageParams{
164 MessageID: "test-msg-1",
165 ConversationID: "non-existent-conversation",
166 Type: "user",
167 })
168 return err
169 })
170
171 if err == nil {
172 t.Error("Expected error when creating message with non-existent conversation_id")
173 return
174 }
175
176 // Verify the error is related to foreign key constraint
177 if !strings.Contains(err.Error(), "FOREIGN KEY constraint failed") {
178 t.Errorf("Expected foreign key constraint error, got: %v", err)
179 }
180}
181
182func TestDB_Pool(t *testing.T) {
183 db := setupTestDB(t)
184 defer db.Close()
185
186 // Test Pool method
187 pool := db.Pool()
188 if pool == nil {
189 t.Error("Expected non-nil pool")
190 }
191}
192
193func TestDB_WithTxRes(t *testing.T) {
194 db := setupTestDB(t)
195 defer db.Close()
196
197 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
198 defer cancel()
199
200 // Test WithTxRes with a simple function that returns a string
201 result, err := WithTxRes[string](db, ctx, func(queries *generated.Queries) (string, error) {
202 return "test result", nil
203 })
204 if err != nil {
205 t.Errorf("WithTxRes() error = %v", err)
206 }
207
208 if result != "test result" {
209 t.Errorf("Expected 'test result', got %s", result)
210 }
211
212 // Test WithTxRes with error handling
213 _, err = WithTxRes[string](db, ctx, func(queries *generated.Queries) (string, error) {
214 return "", fmt.Errorf("test error")
215 })
216
217 if err == nil {
218 t.Error("Expected error from WithTxRes, got none")
219 }
220}
221
222func TestLLMRequestPrefixDeduplication(t *testing.T) {
223 db := setupTestDB(t)
224 defer db.Close()
225
226 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
227 defer cancel()
228
229 // Create a conversation first
230 slug := "test-prefix-conv"
231 conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
232 if err != nil {
233 t.Fatalf("Failed to create conversation: %v", err)
234 }
235
236 // Create a long shared prefix (must be > 100 bytes for deduplication to kick in)
237 sharedPrefix := strings.Repeat("A", 200) // 200 bytes of 'A's
238
239 // First request - full body stored
240 req1Body := sharedPrefix + "_suffix1"
241 req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
242 ConversationID: &conv.ConversationID,
243 Model: "test-model",
244 Provider: "test-provider",
245 Url: "http://example.com",
246 RequestBody: &req1Body,
247 })
248 if err != nil {
249 t.Fatalf("Failed to insert first request: %v", err)
250 }
251
252 // First request should have full body, no prefix reference
253 if req1.PrefixRequestID != nil {
254 t.Errorf("First request should not have prefix reference, got %v", *req1.PrefixRequestID)
255 }
256 if req1.PrefixLength != nil && *req1.PrefixLength != 0 {
257 t.Errorf("First request should have no prefix length, got %v", *req1.PrefixLength)
258 }
259 if req1.RequestBody == nil || *req1.RequestBody != req1Body {
260 t.Errorf("First request body mismatch: expected %q, got %q", req1Body, safeDeref(req1.RequestBody))
261 }
262
263 // Second request - shares prefix with first
264 req2Body := sharedPrefix + "_suffix2_longer"
265 req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
266 ConversationID: &conv.ConversationID,
267 Model: "test-model",
268 Provider: "test-provider",
269 Url: "http://example.com",
270 RequestBody: &req2Body,
271 })
272 if err != nil {
273 t.Fatalf("Failed to insert second request: %v", err)
274 }
275
276 // Second request should have prefix reference
277 if req2.PrefixRequestID == nil || *req2.PrefixRequestID != req1.ID {
278 t.Errorf("Second request should reference first request, got prefix_request_id=%v", safeDeref64(req2.PrefixRequestID))
279 }
280 // Common prefix is sharedPrefix + "_suffix" = 200 + 7 = 207 bytes
281 expectedPrefixLen := len(sharedPrefix) + len("_suffix")
282 if req2.PrefixLength == nil || *req2.PrefixLength != int64(expectedPrefixLen) {
283 t.Errorf("Second request prefix length should be %d, got %v", expectedPrefixLen, safeDeref64(req2.PrefixLength))
284 }
285 // Stored body should only be the suffix after the shared prefix ("1" vs "2_longer")
286 expectedSuffix := "2_longer"
287 if req2.RequestBody == nil || *req2.RequestBody != expectedSuffix {
288 t.Errorf("Second request should only store suffix %q, got %q", expectedSuffix, safeDeref(req2.RequestBody))
289 }
290
291 // Third request - shares even longer prefix with second
292 req3Body := sharedPrefix + "_suffix2_longer_and_more"
293 req3, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
294 ConversationID: &conv.ConversationID,
295 Model: "test-model",
296 Provider: "test-provider",
297 Url: "http://example.com",
298 RequestBody: &req3Body,
299 })
300 if err != nil {
301 t.Fatalf("Failed to insert third request: %v", err)
302 }
303
304 // Third request should reference second request
305 if req3.PrefixRequestID == nil || *req3.PrefixRequestID != req2.ID {
306 t.Errorf("Third request should reference second request, got prefix_request_id=%v", safeDeref64(req3.PrefixRequestID))
307 }
308 // The prefix length should be the full length of req2Body (since req3Body starts with req2Body)
309 if req3.PrefixLength == nil || *req3.PrefixLength != int64(len(sharedPrefix)+len("_suffix2_longer")) {
310 t.Errorf("Third request prefix length should be %d, got %v", len(sharedPrefix)+len("_suffix2_longer"), safeDeref64(req3.PrefixLength))
311 }
312
313 // Test reconstruction of full bodies
314 reconstructed1, err := db.GetFullLLMRequestBody(ctx, req1.ID)
315 if err != nil {
316 t.Fatalf("Failed to reconstruct first request: %v", err)
317 }
318 if reconstructed1 != req1Body {
319 t.Errorf("Reconstructed first request mismatch: expected %q, got %q", req1Body, reconstructed1)
320 }
321
322 reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
323 if err != nil {
324 t.Fatalf("Failed to reconstruct second request: %v", err)
325 }
326 if reconstructed2 != req2Body {
327 t.Errorf("Reconstructed second request mismatch: expected %q, got %q", req2Body, reconstructed2)
328 }
329
330 reconstructed3, err := db.GetFullLLMRequestBody(ctx, req3.ID)
331 if err != nil {
332 t.Fatalf("Failed to reconstruct third request: %v", err)
333 }
334 if reconstructed3 != req3Body {
335 t.Errorf("Reconstructed third request mismatch: expected %q, got %q", req3Body, reconstructed3)
336 }
337}
338
339func TestLLMRequestNoPrefixForShortOverlap(t *testing.T) {
340 db := setupTestDB(t)
341 defer db.Close()
342
343 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
344 defer cancel()
345
346 slug := "test-short-conv"
347 conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
348 if err != nil {
349 t.Fatalf("Failed to create conversation: %v", err)
350 }
351
352 // Short prefix (< 100 bytes) - should NOT deduplicate
353 shortPrefix := strings.Repeat("B", 50)
354
355 req1Body := shortPrefix + "_first"
356 _, err = db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
357 ConversationID: &conv.ConversationID,
358 Model: "test-model",
359 Provider: "test-provider",
360 Url: "http://example.com",
361 RequestBody: &req1Body,
362 })
363 if err != nil {
364 t.Fatalf("Failed to insert first request: %v", err)
365 }
366
367 req2Body := shortPrefix + "_second"
368 req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
369 ConversationID: &conv.ConversationID,
370 Model: "test-model",
371 Provider: "test-provider",
372 Url: "http://example.com",
373 RequestBody: &req2Body,
374 })
375 if err != nil {
376 t.Fatalf("Failed to insert second request: %v", err)
377 }
378
379 // With short prefix, should NOT have prefix reference (full body stored)
380 if req2.PrefixRequestID != nil {
381 t.Errorf("Short overlap should not have prefix reference, got %v", *req2.PrefixRequestID)
382 }
383 if req2.RequestBody == nil || *req2.RequestBody != req2Body {
384 t.Errorf("Short overlap should store full body %q, got %q", req2Body, safeDeref(req2.RequestBody))
385 }
386}
387
388func TestLLMRequestNoConversationID(t *testing.T) {
389 db := setupTestDB(t)
390 defer db.Close()
391
392 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
393 defer cancel()
394
395 // Request without conversation_id - should store full body
396 reqBody := strings.Repeat("C", 300)
397 req, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
398 ConversationID: nil,
399 Model: "test-model",
400 Provider: "test-provider",
401 Url: "http://example.com",
402 RequestBody: &reqBody,
403 })
404 if err != nil {
405 t.Fatalf("Failed to insert request: %v", err)
406 }
407
408 // Should not have prefix reference
409 if req.PrefixRequestID != nil {
410 t.Errorf("Request without conversation_id should not have prefix reference")
411 }
412 if req.RequestBody == nil || *req.RequestBody != reqBody {
413 t.Errorf("Request should store full body")
414 }
415}
416
// safeDeref returns the string s points to. When s is nil it returns the
// placeholder "<nil>", so optional strings can go into failure messages
// without panicking.
func safeDeref(s *string) string {
	if s != nil {
		return *s
	}
	return "<nil>"
}
423
// safeDeref64 returns the int64 i points to. When i is nil it returns the
// sentinel -1, so optional integers can go into failure messages without
// panicking.
func safeDeref64(i *int64) int64 {
	if i != nil {
		return *i
	}
	return -1
}
430
431func TestLLMRequestRealisticConversation(t *testing.T) {
432 // This test simulates realistic LLM API request patterns where each
433 // subsequent request includes all previous messages plus new ones
434 db := setupTestDB(t)
435 defer db.Close()
436
437 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
438 defer cancel()
439
440 slug := "test-realistic-conv"
441 conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
442 if err != nil {
443 t.Fatalf("Failed to create conversation: %v", err)
444 }
445
446 // Simulate Anthropic-style messages array growing over conversation
447 // Each request adds to the previous messages
448 baseRequest := `{"model":"claude-sonnet-4-5-20250929","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[`
449
450 message1 := `{"role":"user","content":[{"type":"text","text":"Hello, how are you?"}]}`
451 req1Body := baseRequest + message1 + `],"max_tokens":8192}`
452
453 req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
454 ConversationID: &conv.ConversationID,
455 Model: "claude-sonnet-4-5-20250929",
456 Provider: "anthropic",
457 Url: "https://api.anthropic.com/v1/messages",
458 RequestBody: &req1Body,
459 })
460 if err != nil {
461 t.Fatalf("Failed to insert first request: %v", err)
462 }
463
464 // First request stored in full
465 if req1.PrefixRequestID != nil {
466 t.Errorf("First request should not have prefix reference")
467 }
468
469 // Second request: user message + assistant response + new user message
470 message2 := `{"role":"assistant","content":[{"type":"text","text":"I'm doing well, thank you for asking!"}]}`
471 message3 := `{"role":"user","content":[{"type":"text","text":"Can you help me write some code?"}]}`
472 req2Body := baseRequest + message1 + `,` + message2 + `,` + message3 + `],"max_tokens":8192}`
473
474 req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
475 ConversationID: &conv.ConversationID,
476 Model: "claude-sonnet-4-5-20250929",
477 Provider: "anthropic",
478 Url: "https://api.anthropic.com/v1/messages",
479 RequestBody: &req2Body,
480 })
481 if err != nil {
482 t.Fatalf("Failed to insert second request: %v", err)
483 }
484
485 // Second request should have prefix deduplication
486 if req2.PrefixRequestID == nil {
487 t.Errorf("Second request should have prefix reference")
488 } else if *req2.PrefixRequestID != req1.ID {
489 t.Errorf("Second request should reference first request")
490 }
491
492 // Verify prefix length is reasonable (should be at least the base + message1 length)
493 minExpectedPrefix := len(baseRequest) + len(message1)
494 if req2.PrefixLength == nil || *req2.PrefixLength < int64(minExpectedPrefix) {
495 t.Errorf("Second request prefix length should be at least %d, got %v", minExpectedPrefix, safeDeref64(req2.PrefixLength))
496 }
497
498 // Verify we saved significant space
499 req2StoredLen := len(safeDeref(req2.RequestBody))
500 req2FullLen := len(req2Body)
501 if req2StoredLen >= req2FullLen {
502 t.Errorf("Second request should store less than full body: stored %d, full %d", req2StoredLen, req2FullLen)
503 }
504 t.Logf("Space saved for request 2: %d bytes (%.1f%% reduction)",
505 req2FullLen-req2StoredLen,
506 100.0*float64(req2FullLen-req2StoredLen)/float64(req2FullLen))
507
508 // Third request: even more messages
509 message4 := `{"role":"assistant","content":[{"type":"text","text":"Of course! What kind of code would you like me to help you with?"}]}`
510 message5 := `{"role":"user","content":[{"type":"text","text":"I need a function to calculate fibonacci numbers."}]}`
511 req3Body := baseRequest + message1 + `,` + message2 + `,` + message3 + `,` + message4 + `,` + message5 + `],"max_tokens":8192}`
512
513 req3, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
514 ConversationID: &conv.ConversationID,
515 Model: "claude-sonnet-4-5-20250929",
516 Provider: "anthropic",
517 Url: "https://api.anthropic.com/v1/messages",
518 RequestBody: &req3Body,
519 })
520 if err != nil {
521 t.Fatalf("Failed to insert third request: %v", err)
522 }
523
524 // Third request should reference second
525 if req3.PrefixRequestID == nil || *req3.PrefixRequestID != req2.ID {
526 t.Errorf("Third request should reference second request")
527 }
528
529 req3StoredLen := len(safeDeref(req3.RequestBody))
530 req3FullLen := len(req3Body)
531 t.Logf("Space saved for request 3: %d bytes (%.1f%% reduction)",
532 req3FullLen-req3StoredLen,
533 100.0*float64(req3FullLen-req3StoredLen)/float64(req3FullLen))
534
535 // Verify reconstruction works for all requests
536 reconstructed1, err := db.GetFullLLMRequestBody(ctx, req1.ID)
537 if err != nil {
538 t.Fatalf("Failed to reconstruct request 1: %v", err)
539 }
540 if reconstructed1 != req1Body {
541 t.Errorf("Reconstructed request 1 mismatch")
542 }
543
544 reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
545 if err != nil {
546 t.Fatalf("Failed to reconstruct request 2: %v", err)
547 }
548 if reconstructed2 != req2Body {
549 t.Errorf("Reconstructed request 2 mismatch")
550 }
551
552 reconstructed3, err := db.GetFullLLMRequestBody(ctx, req3.ID)
553 if err != nil {
554 t.Fatalf("Failed to reconstruct request 3: %v", err)
555 }
556 if reconstructed3 != req3Body {
557 t.Errorf("Reconstructed request 3 mismatch")
558 }
559
560 // Calculate total storage savings
561 totalOriginal := len(req1Body) + len(req2Body) + len(req3Body)
562 totalStored := len(safeDeref(req1.RequestBody)) + len(safeDeref(req2.RequestBody)) + len(safeDeref(req3.RequestBody))
563 t.Logf("Total space: original %d bytes, stored %d bytes, saved %d bytes (%.1f%% reduction)",
564 totalOriginal, totalStored, totalOriginal-totalStored,
565 100.0*float64(totalOriginal-totalStored)/float64(totalOriginal))
566}
567
568func TestLLMRequestOpenAIStyle(t *testing.T) {
569 // Test with OpenAI-style request format
570 db := setupTestDB(t)
571 defer db.Close()
572
573 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
574 defer cancel()
575
576 slug := "test-openai-conv"
577 conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
578 if err != nil {
579 t.Fatalf("Failed to create conversation: %v", err)
580 }
581
582 // OpenAI-style request format
583 baseRequest := `{"model":"gpt-4","messages":[`
584 message1 := `{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Hello!"}`
585 req1Body := baseRequest + message1 + `],"stream":true}`
586
587 req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
588 ConversationID: &conv.ConversationID,
589 Model: "gpt-4",
590 Provider: "openai",
591 Url: "https://api.openai.com/v1/chat/completions",
592 RequestBody: &req1Body,
593 })
594 if err != nil {
595 t.Fatalf("Failed to insert first request: %v", err)
596 }
597
598 // Second request with more messages
599 message2 := `{"role":"assistant","content":"Hello! How can I help you today?"},{"role":"user","content":"What's the weather like?"}`
600 req2Body := baseRequest + message1 + `,` + message2 + `],"stream":true}`
601
602 req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
603 ConversationID: &conv.ConversationID,
604 Model: "gpt-4",
605 Provider: "openai",
606 Url: "https://api.openai.com/v1/chat/completions",
607 RequestBody: &req2Body,
608 })
609 if err != nil {
610 t.Fatalf("Failed to insert second request: %v", err)
611 }
612
613 // Should have prefix deduplication
614 if req2.PrefixRequestID == nil || *req2.PrefixRequestID != req1.ID {
615 t.Errorf("Second request should reference first request")
616 }
617
618 // Verify reconstruction
619 reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
620 if err != nil {
621 t.Fatalf("Failed to reconstruct second request: %v", err)
622 }
623 if reconstructed2 != req2Body {
624 t.Errorf("Reconstructed request mismatch:\nexpected: %s\ngot: %s", req2Body, reconstructed2)
625 }
626
627 // Calculate savings
628 req2StoredLen := len(safeDeref(req2.RequestBody))
629 req2FullLen := len(req2Body)
630 t.Logf("OpenAI-style space saved: %d bytes (%.1f%% reduction)",
631 req2FullLen-req2StoredLen,
632 100.0*float64(req2FullLen-req2StoredLen)/float64(req2FullLen))
633}