db_test.go

  1package db
  2
  3import (
  4	"context"
  5	"fmt"
  6	"strings"
  7	"testing"
  8	"time"
  9
 10	"shelley.exe.dev/db/generated"
 11)
 12
 13// setupTestDB creates a test database with schema migrated
 14func setupTestDB(t *testing.T) *DB {
 15	t.Helper()
 16
 17	// Use a temporary file instead of :memory: because the pool requires multiple connections
 18	tmpDir := t.TempDir()
 19	db, err := New(Config{DSN: tmpDir + "/test.db"})
 20	if err != nil {
 21		t.Fatalf("Failed to create test database: %v", err)
 22	}
 23
 24	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 25	defer cancel()
 26
 27	if err := db.Migrate(ctx); err != nil {
 28		t.Fatalf("Failed to migrate test database: %v", err)
 29	}
 30
 31	return db
 32}
 33
 34func TestNew(t *testing.T) {
 35	tests := []struct {
 36		name    string
 37		cfg     Config
 38		wantErr bool
 39	}{
 40		{
 41			name:    "memory database not supported",
 42			cfg:     Config{DSN: ":memory:"},
 43			wantErr: true,
 44		},
 45		{
 46			name:    "empty DSN",
 47			cfg:     Config{DSN: ""},
 48			wantErr: true,
 49		},
 50	}
 51
 52	for _, tt := range tests {
 53		t.Run(tt.name, func(t *testing.T) {
 54			db, err := New(tt.cfg)
 55			if (err != nil) != tt.wantErr {
 56				t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
 57				return
 58			}
 59			if db != nil {
 60				defer db.Close()
 61			}
 62		})
 63	}
 64}
 65
 66func TestDB_Migrate(t *testing.T) {
 67	tmpDir := t.TempDir()
 68	db, err := New(Config{DSN: tmpDir + "/test.db"})
 69	if err != nil {
 70		t.Fatalf("Failed to create test database: %v", err)
 71	}
 72	defer db.Close()
 73
 74	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 75	defer cancel()
 76
 77	// Run migrations first time
 78	if err := db.Migrate(ctx); err != nil {
 79		t.Errorf("Migrate() error = %v", err)
 80	}
 81
 82	// Verify tables were created by trying to count conversations
 83	var count int64
 84	err = db.Queries(ctx, func(q *generated.Queries) error {
 85		var err error
 86		count, err = q.CountConversations(ctx)
 87		return err
 88	})
 89	if err != nil {
 90		t.Errorf("Failed to query conversations after migration: %v", err)
 91	}
 92	if count != 0 {
 93		t.Errorf("Expected 0 conversations, got %d", count)
 94	}
 95
 96	// Run migrations a second time to verify idempotency
 97	if err := db.Migrate(ctx); err != nil {
 98		t.Errorf("Second Migrate() error = %v", err)
 99	}
100
101	// Verify we can still query after running migrations twice
102	err = db.Queries(ctx, func(q *generated.Queries) error {
103		var err error
104		count, err = q.CountConversations(ctx)
105		return err
106	})
107	if err != nil {
108		t.Errorf("Failed to query conversations after second migration: %v", err)
109	}
110}
111
112func TestDB_WithTx(t *testing.T) {
113	db := setupTestDB(t)
114	defer db.Close()
115
116	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
117	defer cancel()
118
119	// Test successful transaction
120	err := db.WithTx(ctx, func(q *generated.Queries) error {
121		_, err := q.CreateConversation(ctx, generated.CreateConversationParams{
122			ConversationID: "test-conv-1",
123			Slug:           stringPtr("test-slug"),
124			UserInitiated:  true,
125			Model:          nil,
126		})
127		return err
128	})
129	if err != nil {
130		t.Errorf("WithTx() error = %v", err)
131	}
132
133	// Verify the conversation was created
134	var conv generated.Conversation
135	err = db.Queries(ctx, func(q *generated.Queries) error {
136		var err error
137		conv, err = q.GetConversation(ctx, "test-conv-1")
138		return err
139	})
140	if err != nil {
141		t.Errorf("Failed to get conversation after transaction: %v", err)
142	}
143	if conv.ConversationID != "test-conv-1" {
144		t.Errorf("Expected conversation ID 'test-conv-1', got %s", conv.ConversationID)
145	}
146}
147
// stringPtr returns the address of a new copy of s. It is a convenience for
// filling optional *string fields in parameter structs.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
152
153func TestDB_ForeignKeyConstraints(t *testing.T) {
154	db := setupTestDB(t)
155	defer db.Close()
156
157	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
158	defer cancel()
159
160	// Try to create a message with a non-existent conversation_id
161	// This should fail due to foreign key constraint
162	err := db.QueriesTx(ctx, func(q *generated.Queries) error {
163		_, err := q.CreateMessage(ctx, generated.CreateMessageParams{
164			MessageID:      "test-msg-1",
165			ConversationID: "non-existent-conversation",
166			Type:           "user",
167		})
168		return err
169	})
170
171	if err == nil {
172		t.Error("Expected error when creating message with non-existent conversation_id")
173		return
174	}
175
176	// Verify the error is related to foreign key constraint
177	if !strings.Contains(err.Error(), "FOREIGN KEY constraint failed") {
178		t.Errorf("Expected foreign key constraint error, got: %v", err)
179	}
180}
181
182func TestDB_Pool(t *testing.T) {
183	db := setupTestDB(t)
184	defer db.Close()
185
186	// Test Pool method
187	pool := db.Pool()
188	if pool == nil {
189		t.Error("Expected non-nil pool")
190	}
191}
192
193func TestDB_WithTxRes(t *testing.T) {
194	db := setupTestDB(t)
195	defer db.Close()
196
197	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
198	defer cancel()
199
200	// Test WithTxRes with a simple function that returns a string
201	result, err := WithTxRes[string](db, ctx, func(queries *generated.Queries) (string, error) {
202		return "test result", nil
203	})
204	if err != nil {
205		t.Errorf("WithTxRes() error = %v", err)
206	}
207
208	if result != "test result" {
209		t.Errorf("Expected 'test result', got %s", result)
210	}
211
212	// Test WithTxRes with error handling
213	_, err = WithTxRes[string](db, ctx, func(queries *generated.Queries) (string, error) {
214		return "", fmt.Errorf("test error")
215	})
216
217	if err == nil {
218		t.Error("Expected error from WithTxRes, got none")
219	}
220}
221
222func TestLLMRequestPrefixDeduplication(t *testing.T) {
223	db := setupTestDB(t)
224	defer db.Close()
225
226	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
227	defer cancel()
228
229	// Create a conversation first
230	slug := "test-prefix-conv"
231	conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
232	if err != nil {
233		t.Fatalf("Failed to create conversation: %v", err)
234	}
235
236	// Create a long shared prefix (must be > 100 bytes for deduplication to kick in)
237	sharedPrefix := strings.Repeat("A", 200) // 200 bytes of 'A's
238
239	// First request - full body stored
240	req1Body := sharedPrefix + "_suffix1"
241	req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
242		ConversationID: &conv.ConversationID,
243		Model:          "test-model",
244		Provider:       "test-provider",
245		Url:            "http://example.com",
246		RequestBody:    &req1Body,
247	})
248	if err != nil {
249		t.Fatalf("Failed to insert first request: %v", err)
250	}
251
252	// First request should have full body, no prefix reference
253	if req1.PrefixRequestID != nil {
254		t.Errorf("First request should not have prefix reference, got %v", *req1.PrefixRequestID)
255	}
256	if req1.PrefixLength != nil && *req1.PrefixLength != 0 {
257		t.Errorf("First request should have no prefix length, got %v", *req1.PrefixLength)
258	}
259	if req1.RequestBody == nil || *req1.RequestBody != req1Body {
260		t.Errorf("First request body mismatch: expected %q, got %q", req1Body, safeDeref(req1.RequestBody))
261	}
262
263	// Second request - shares prefix with first
264	req2Body := sharedPrefix + "_suffix2_longer"
265	req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
266		ConversationID: &conv.ConversationID,
267		Model:          "test-model",
268		Provider:       "test-provider",
269		Url:            "http://example.com",
270		RequestBody:    &req2Body,
271	})
272	if err != nil {
273		t.Fatalf("Failed to insert second request: %v", err)
274	}
275
276	// Second request should have prefix reference
277	if req2.PrefixRequestID == nil || *req2.PrefixRequestID != req1.ID {
278		t.Errorf("Second request should reference first request, got prefix_request_id=%v", safeDeref64(req2.PrefixRequestID))
279	}
280	// Common prefix is sharedPrefix + "_suffix" = 200 + 7 = 207 bytes
281	expectedPrefixLen := len(sharedPrefix) + len("_suffix")
282	if req2.PrefixLength == nil || *req2.PrefixLength != int64(expectedPrefixLen) {
283		t.Errorf("Second request prefix length should be %d, got %v", expectedPrefixLen, safeDeref64(req2.PrefixLength))
284	}
285	// Stored body should only be the suffix after the shared prefix ("1" vs "2_longer")
286	expectedSuffix := "2_longer"
287	if req2.RequestBody == nil || *req2.RequestBody != expectedSuffix {
288		t.Errorf("Second request should only store suffix %q, got %q", expectedSuffix, safeDeref(req2.RequestBody))
289	}
290
291	// Third request - shares even longer prefix with second
292	req3Body := sharedPrefix + "_suffix2_longer_and_more"
293	req3, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
294		ConversationID: &conv.ConversationID,
295		Model:          "test-model",
296		Provider:       "test-provider",
297		Url:            "http://example.com",
298		RequestBody:    &req3Body,
299	})
300	if err != nil {
301		t.Fatalf("Failed to insert third request: %v", err)
302	}
303
304	// Third request should reference second request
305	if req3.PrefixRequestID == nil || *req3.PrefixRequestID != req2.ID {
306		t.Errorf("Third request should reference second request, got prefix_request_id=%v", safeDeref64(req3.PrefixRequestID))
307	}
308	// The prefix length should be the full length of req2Body (since req3Body starts with req2Body)
309	if req3.PrefixLength == nil || *req3.PrefixLength != int64(len(sharedPrefix)+len("_suffix2_longer")) {
310		t.Errorf("Third request prefix length should be %d, got %v", len(sharedPrefix)+len("_suffix2_longer"), safeDeref64(req3.PrefixLength))
311	}
312
313	// Test reconstruction of full bodies
314	reconstructed1, err := db.GetFullLLMRequestBody(ctx, req1.ID)
315	if err != nil {
316		t.Fatalf("Failed to reconstruct first request: %v", err)
317	}
318	if reconstructed1 != req1Body {
319		t.Errorf("Reconstructed first request mismatch: expected %q, got %q", req1Body, reconstructed1)
320	}
321
322	reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
323	if err != nil {
324		t.Fatalf("Failed to reconstruct second request: %v", err)
325	}
326	if reconstructed2 != req2Body {
327		t.Errorf("Reconstructed second request mismatch: expected %q, got %q", req2Body, reconstructed2)
328	}
329
330	reconstructed3, err := db.GetFullLLMRequestBody(ctx, req3.ID)
331	if err != nil {
332		t.Fatalf("Failed to reconstruct third request: %v", err)
333	}
334	if reconstructed3 != req3Body {
335		t.Errorf("Reconstructed third request mismatch: expected %q, got %q", req3Body, reconstructed3)
336	}
337}
338
339func TestLLMRequestNoPrefixForShortOverlap(t *testing.T) {
340	db := setupTestDB(t)
341	defer db.Close()
342
343	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
344	defer cancel()
345
346	slug := "test-short-conv"
347	conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
348	if err != nil {
349		t.Fatalf("Failed to create conversation: %v", err)
350	}
351
352	// Short prefix (< 100 bytes) - should NOT deduplicate
353	shortPrefix := strings.Repeat("B", 50)
354
355	req1Body := shortPrefix + "_first"
356	_, err = db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
357		ConversationID: &conv.ConversationID,
358		Model:          "test-model",
359		Provider:       "test-provider",
360		Url:            "http://example.com",
361		RequestBody:    &req1Body,
362	})
363	if err != nil {
364		t.Fatalf("Failed to insert first request: %v", err)
365	}
366
367	req2Body := shortPrefix + "_second"
368	req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
369		ConversationID: &conv.ConversationID,
370		Model:          "test-model",
371		Provider:       "test-provider",
372		Url:            "http://example.com",
373		RequestBody:    &req2Body,
374	})
375	if err != nil {
376		t.Fatalf("Failed to insert second request: %v", err)
377	}
378
379	// With short prefix, should NOT have prefix reference (full body stored)
380	if req2.PrefixRequestID != nil {
381		t.Errorf("Short overlap should not have prefix reference, got %v", *req2.PrefixRequestID)
382	}
383	if req2.RequestBody == nil || *req2.RequestBody != req2Body {
384		t.Errorf("Short overlap should store full body %q, got %q", req2Body, safeDeref(req2.RequestBody))
385	}
386}
387
388func TestLLMRequestNoConversationID(t *testing.T) {
389	db := setupTestDB(t)
390	defer db.Close()
391
392	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
393	defer cancel()
394
395	// Request without conversation_id - should store full body
396	reqBody := strings.Repeat("C", 300)
397	req, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
398		ConversationID: nil,
399		Model:          "test-model",
400		Provider:       "test-provider",
401		Url:            "http://example.com",
402		RequestBody:    &reqBody,
403	})
404	if err != nil {
405		t.Fatalf("Failed to insert request: %v", err)
406	}
407
408	// Should not have prefix reference
409	if req.PrefixRequestID != nil {
410		t.Errorf("Request without conversation_id should not have prefix reference")
411	}
412	if req.RequestBody == nil || *req.RequestBody != reqBody {
413		t.Errorf("Request should store full body")
414	}
415}
416
// safeDeref returns *s, or the placeholder "<nil>" when s is nil. This lets
// failure messages print optional strings without panicking.
func safeDeref(s *string) string {
	if s != nil {
		return *s
	}
	return "<nil>"
}
423
// safeDeref64 returns *i, or -1 when i is nil. It is meant for failure
// messages: -1 is ambiguous if the pointed-to value can itself be -1.
func safeDeref64(i *int64) int64 {
	if i != nil {
		return *i
	}
	return -1
}
430
431func TestLLMRequestRealisticConversation(t *testing.T) {
432	// This test simulates realistic LLM API request patterns where each
433	// subsequent request includes all previous messages plus new ones
434	db := setupTestDB(t)
435	defer db.Close()
436
437	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
438	defer cancel()
439
440	slug := "test-realistic-conv"
441	conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
442	if err != nil {
443		t.Fatalf("Failed to create conversation: %v", err)
444	}
445
446	// Simulate Anthropic-style messages array growing over conversation
447	// Each request adds to the previous messages
448	baseRequest := `{"model":"claude-sonnet-4-5-20250929","system":[{"type":"text","text":"You are a helpful assistant."}],"messages":[`
449
450	message1 := `{"role":"user","content":[{"type":"text","text":"Hello, how are you?"}]}`
451	req1Body := baseRequest + message1 + `],"max_tokens":8192}`
452
453	req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
454		ConversationID: &conv.ConversationID,
455		Model:          "claude-sonnet-4-5-20250929",
456		Provider:       "anthropic",
457		Url:            "https://api.anthropic.com/v1/messages",
458		RequestBody:    &req1Body,
459	})
460	if err != nil {
461		t.Fatalf("Failed to insert first request: %v", err)
462	}
463
464	// First request stored in full
465	if req1.PrefixRequestID != nil {
466		t.Errorf("First request should not have prefix reference")
467	}
468
469	// Second request: user message + assistant response + new user message
470	message2 := `{"role":"assistant","content":[{"type":"text","text":"I'm doing well, thank you for asking!"}]}`
471	message3 := `{"role":"user","content":[{"type":"text","text":"Can you help me write some code?"}]}`
472	req2Body := baseRequest + message1 + `,` + message2 + `,` + message3 + `],"max_tokens":8192}`
473
474	req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
475		ConversationID: &conv.ConversationID,
476		Model:          "claude-sonnet-4-5-20250929",
477		Provider:       "anthropic",
478		Url:            "https://api.anthropic.com/v1/messages",
479		RequestBody:    &req2Body,
480	})
481	if err != nil {
482		t.Fatalf("Failed to insert second request: %v", err)
483	}
484
485	// Second request should have prefix deduplication
486	if req2.PrefixRequestID == nil {
487		t.Errorf("Second request should have prefix reference")
488	} else if *req2.PrefixRequestID != req1.ID {
489		t.Errorf("Second request should reference first request")
490	}
491
492	// Verify prefix length is reasonable (should be at least the base + message1 length)
493	minExpectedPrefix := len(baseRequest) + len(message1)
494	if req2.PrefixLength == nil || *req2.PrefixLength < int64(minExpectedPrefix) {
495		t.Errorf("Second request prefix length should be at least %d, got %v", minExpectedPrefix, safeDeref64(req2.PrefixLength))
496	}
497
498	// Verify we saved significant space
499	req2StoredLen := len(safeDeref(req2.RequestBody))
500	req2FullLen := len(req2Body)
501	if req2StoredLen >= req2FullLen {
502		t.Errorf("Second request should store less than full body: stored %d, full %d", req2StoredLen, req2FullLen)
503	}
504	t.Logf("Space saved for request 2: %d bytes (%.1f%% reduction)",
505		req2FullLen-req2StoredLen,
506		100.0*float64(req2FullLen-req2StoredLen)/float64(req2FullLen))
507
508	// Third request: even more messages
509	message4 := `{"role":"assistant","content":[{"type":"text","text":"Of course! What kind of code would you like me to help you with?"}]}`
510	message5 := `{"role":"user","content":[{"type":"text","text":"I need a function to calculate fibonacci numbers."}]}`
511	req3Body := baseRequest + message1 + `,` + message2 + `,` + message3 + `,` + message4 + `,` + message5 + `],"max_tokens":8192}`
512
513	req3, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
514		ConversationID: &conv.ConversationID,
515		Model:          "claude-sonnet-4-5-20250929",
516		Provider:       "anthropic",
517		Url:            "https://api.anthropic.com/v1/messages",
518		RequestBody:    &req3Body,
519	})
520	if err != nil {
521		t.Fatalf("Failed to insert third request: %v", err)
522	}
523
524	// Third request should reference second
525	if req3.PrefixRequestID == nil || *req3.PrefixRequestID != req2.ID {
526		t.Errorf("Third request should reference second request")
527	}
528
529	req3StoredLen := len(safeDeref(req3.RequestBody))
530	req3FullLen := len(req3Body)
531	t.Logf("Space saved for request 3: %d bytes (%.1f%% reduction)",
532		req3FullLen-req3StoredLen,
533		100.0*float64(req3FullLen-req3StoredLen)/float64(req3FullLen))
534
535	// Verify reconstruction works for all requests
536	reconstructed1, err := db.GetFullLLMRequestBody(ctx, req1.ID)
537	if err != nil {
538		t.Fatalf("Failed to reconstruct request 1: %v", err)
539	}
540	if reconstructed1 != req1Body {
541		t.Errorf("Reconstructed request 1 mismatch")
542	}
543
544	reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
545	if err != nil {
546		t.Fatalf("Failed to reconstruct request 2: %v", err)
547	}
548	if reconstructed2 != req2Body {
549		t.Errorf("Reconstructed request 2 mismatch")
550	}
551
552	reconstructed3, err := db.GetFullLLMRequestBody(ctx, req3.ID)
553	if err != nil {
554		t.Fatalf("Failed to reconstruct request 3: %v", err)
555	}
556	if reconstructed3 != req3Body {
557		t.Errorf("Reconstructed request 3 mismatch")
558	}
559
560	// Calculate total storage savings
561	totalOriginal := len(req1Body) + len(req2Body) + len(req3Body)
562	totalStored := len(safeDeref(req1.RequestBody)) + len(safeDeref(req2.RequestBody)) + len(safeDeref(req3.RequestBody))
563	t.Logf("Total space: original %d bytes, stored %d bytes, saved %d bytes (%.1f%% reduction)",
564		totalOriginal, totalStored, totalOriginal-totalStored,
565		100.0*float64(totalOriginal-totalStored)/float64(totalOriginal))
566}
567
568func TestLLMRequestOpenAIStyle(t *testing.T) {
569	// Test with OpenAI-style request format
570	db := setupTestDB(t)
571	defer db.Close()
572
573	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
574	defer cancel()
575
576	slug := "test-openai-conv"
577	conv, err := db.CreateConversation(ctx, &slug, true, nil, nil)
578	if err != nil {
579		t.Fatalf("Failed to create conversation: %v", err)
580	}
581
582	// OpenAI-style request format
583	baseRequest := `{"model":"gpt-4","messages":[`
584	message1 := `{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Hello!"}`
585	req1Body := baseRequest + message1 + `],"stream":true}`
586
587	req1, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
588		ConversationID: &conv.ConversationID,
589		Model:          "gpt-4",
590		Provider:       "openai",
591		Url:            "https://api.openai.com/v1/chat/completions",
592		RequestBody:    &req1Body,
593	})
594	if err != nil {
595		t.Fatalf("Failed to insert first request: %v", err)
596	}
597
598	// Second request with more messages
599	message2 := `{"role":"assistant","content":"Hello! How can I help you today?"},{"role":"user","content":"What's the weather like?"}`
600	req2Body := baseRequest + message1 + `,` + message2 + `],"stream":true}`
601
602	req2, err := db.InsertLLMRequest(ctx, generated.InsertLLMRequestParams{
603		ConversationID: &conv.ConversationID,
604		Model:          "gpt-4",
605		Provider:       "openai",
606		Url:            "https://api.openai.com/v1/chat/completions",
607		RequestBody:    &req2Body,
608	})
609	if err != nil {
610		t.Fatalf("Failed to insert second request: %v", err)
611	}
612
613	// Should have prefix deduplication
614	if req2.PrefixRequestID == nil || *req2.PrefixRequestID != req1.ID {
615		t.Errorf("Second request should reference first request")
616	}
617
618	// Verify reconstruction
619	reconstructed2, err := db.GetFullLLMRequestBody(ctx, req2.ID)
620	if err != nil {
621		t.Fatalf("Failed to reconstruct second request: %v", err)
622	}
623	if reconstructed2 != req2Body {
624		t.Errorf("Reconstructed request mismatch:\nexpected: %s\ngot: %s", req2Body, reconstructed2)
625	}
626
627	// Calculate savings
628	req2StoredLen := len(safeDeref(req2.RequestBody))
629	req2FullLen := len(req2Body)
630	t.Logf("OpenAI-style space saved: %d bytes (%.1f%% reduction)",
631		req2FullLen-req2StoredLen,
632		100.0*float64(req2FullLen-req2StoredLen)/float64(req2FullLen))
633}