1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: AGPL-3.0-or-later
  4
  5package db
  6
  7import (
  8	"context"
  9	"encoding/json"
 10	"errors"
 11	"os"
 12	"path/filepath"
 13	"runtime"
 14	"strings"
 15	"testing"
 16	"time"
 17)
 18
 19func openTestDB(t *testing.T) *Database {
 20	t.Helper()
 21
 22	database, err := Open(Options{Path: t.TempDir()})
 23	if err != nil {
 24		t.Fatalf("open db: %v", err)
 25	}
 26
 27	t.Cleanup(func() {
 28		if err := database.Close(); err != nil {
 29			t.Fatalf("closing db: %v", err)
 30		}
 31	})
 32
 33	return database
 34}
 35
 36func TestDatabasePathAndClose(t *testing.T) {
 37	database := openTestDB(t)
 38	if database.Path() == "" {
 39		t.Fatalf("Path() returned empty string")
 40	}
 41	if err := database.Close(); err != nil {
 42		t.Fatalf("Close() returned error on first call: %v", err)
 43	}
 44	// Second close should be a noop.
 45	if err := database.Close(); err != nil {
 46		t.Fatalf("Close() returned error on second call: %v", err)
 47	}
 48}
 49
 50func TestDatabaseViewErrorMapping(t *testing.T) {
 51	database := openTestDB(t)
 52
 53	err := database.View(context.Background(), func(txn *Txn) error {
 54		_, err := txn.Get([]byte("missing"))
 55		return err
 56	})
 57	if !errors.Is(err, ErrKeyNotFound) {
 58		t.Fatalf("expected ErrKeyNotFound, got %v", err)
 59	}
 60
 61	aborted := database.View(context.Background(), func(txn *Txn) error {
 62		return txn.Abort()
 63	})
 64	if aborted != nil {
 65		t.Fatalf("expected nil after abort, got %v", aborted)
 66	}
 67
 68	readonlyErr := database.View(context.Background(), func(txn *Txn) error {
 69		return txn.Set([]byte("k"), []byte("v"))
 70	})
 71	if !errors.Is(readonlyErr, ErrReadOnly) {
 72		t.Fatalf("expected ErrReadOnly inside view, got %v", readonlyErr)
 73	}
 74}
 75
 76func TestDatabaseUpdateLifecycle(t *testing.T) {
 77	database := openTestDB(t)
 78
 79	type payload struct {
 80		Message string
 81		Count   int
 82	}
 83	p := payload{Message: "hello", Count: 42}
 84
 85	err := database.Update(context.Background(), func(txn *Txn) error {
 86		if err := txn.Set([]byte("alpha"), []byte("bravo")); err != nil {
 87			return err
 88		}
 89		if err := txn.SetJSON([]byte("payload"), p); err != nil {
 90			return err
 91		}
 92		if _, err := txn.IncrementUint64([]byte("counter"), 5); err != nil {
 93			return err
 94		}
 95		// Seed a value with an invalid length for later IncrementUint64 failure.
 96		if err := txn.Set([]byte("counter-invalid"), []byte("oops")); err != nil {
 97			return err
 98		}
 99		return nil
100	})
101	if err != nil {
102		t.Fatalf("Update() returned error: %v", err)
103	}
104
105	err = database.View(context.Background(), func(txn *Txn) error {
106		val, err := txn.Get([]byte("alpha"))
107		if err != nil {
108			return err
109		}
110		if string(val) != "bravo" {
111			t.Fatalf("unexpected value: %q", val)
112		}
113
114		var decoded payload
115		if err := txn.GetJSON([]byte("payload"), &decoded); err != nil {
116			return err
117		}
118		if decoded != p {
119			t.Fatalf("unexpected payload: %#v", decoded)
120		}
121
122		exists, err := txn.Exists([]byte("alpha"))
123		if err != nil {
124			return err
125		}
126		if !exists {
127			t.Fatalf("expected Exists to report true")
128		}
129
130		expected := map[string]string{
131			"alpha":   "bravo",
132			"payload": string(mustJSON(p)),
133		}
134		collected := make(map[string]string)
135		iterErr := txn.Iterate(IterateOptions{PrefetchValues: true}, func(item Item) error {
136			key := item.KeyString()
137			val, err := item.Value()
138			if err != nil {
139				return err
140			}
141			collected[key] = string(val)
142			return nil
143		})
144		if iterErr != nil {
145			return iterErr
146		}
147
148		for key, want := range expected {
149			got, ok := collected[key]
150			if !ok {
151				t.Fatalf("missing key %q in iteration results", key)
152			}
153			if got != want {
154				t.Fatalf("unexpected value for %q: got %q want %q", key, got, want)
155			}
156		}
157
158		abortErr := txn.Iterate(IterateOptions{}, func(Item) error {
159			return ErrTxnAborted
160		})
161		if abortErr != nil {
162			t.Fatalf("expected nil when aborting iteration, got %v", abortErr)
163		}
164
165		prefixCount := 0
166		prefixErr := txn.Iterate(IterateOptions{Prefix: []byte("a")}, func(item Item) error {
167			if !strings.HasPrefix(item.KeyString(), "a") {
168				t.Fatalf("expected prefix match, got %q", item.KeyString())
169			}
170			prefixCount++
171			return nil
172		})
173		if prefixErr != nil {
174			t.Fatalf("prefix iteration error: %v", prefixErr)
175		}
176		if prefixCount == 0 {
177			t.Fatalf("expected at least one prefix item, got %d", prefixCount)
178		}
179
180		return nil
181	})
182	if err != nil {
183		t.Fatalf("View() returned error: %v", err)
184	}
185
186	err = database.Update(context.Background(), func(txn *Txn) error {
187		val, err := txn.IncrementUint64([]byte("counter"), 7)
188		if err != nil {
189			return err
190		}
191		if val != 12 {
192			t.Fatalf("expected counter=12, got %d", val)
193		}
194
195		_, err = txn.IncrementUint64([]byte("counter-invalid"), 1)
196		if err == nil {
197			t.Fatalf("expected error when incrementing invalid counter")
198		}
199
200		return txn.Delete([]byte("alpha"))
201	})
202	if err != nil {
203		t.Fatalf("Update() increment phase error: %v", err)
204	}
205
206	err = database.View(context.Background(), func(txn *Txn) error {
207		ok, err := txn.Exists([]byte("alpha"))
208		if err != nil {
209			return err
210		}
211		if ok {
212			t.Fatalf("expected alpha to be deleted")
213		}
214		return nil
215	})
216	if err != nil {
217		t.Fatalf("View after delete error: %v", err)
218	}
219}
220
221func TestDatabaseUpdateErrorMapping(t *testing.T) {
222	database := openTestDB(t)
223
224	err := database.Update(context.Background(), func(txn *Txn) error {
225		_, err := txn.Get([]byte("missing"))
226		return err
227	})
228	if !errors.Is(err, ErrKeyNotFound) {
229		t.Fatalf("expected ErrKeyNotFound, got %v", err)
230	}
231
232	aborted := database.Update(context.Background(), func(txn *Txn) error {
233		return txn.Abort()
234	})
235	if aborted != nil {
236		t.Fatalf("expected nil after abort, got %v", aborted)
237	}
238}
239
240func TestDatabaseUpdateReadOnly(t *testing.T) {
241	path := t.TempDir()
242	writable, err := Open(Options{Path: path})
243	if err != nil {
244		t.Fatalf("prepare writable db: %v", err)
245	}
246	if err := writable.Close(); err != nil {
247		t.Fatalf("close writable db: %v", err)
248	}
249
250	database, err := Open(Options{Path: path, ReadOnly: true})
251	if err != nil {
252		t.Fatalf("open read-only db: %v", err)
253	}
254	t.Cleanup(func() {
255		if err := database.Close(); err != nil {
256			t.Fatalf("close db: %v", err)
257		}
258	})
259
260	err = database.Update(context.Background(), func(txn *Txn) error {
261		return txn.Set([]byte("k"), []byte("v"))
262	})
263	if !errors.Is(err, ErrReadOnly) {
264		t.Fatalf("expected ErrReadOnly for read-only DB, got %v", err)
265	}
266}
267
268func TestDatabaseUpdateRetryRespectsContext(t *testing.T) {
269	database := openTestDB(t)
270
271	ctx, cancel := context.WithCancel(context.Background())
272	cancel() // cancel immediately
273
274	err := database.Update(ctx, func(txn *Txn) error {
275		return txn.Set([]byte("k"), []byte("v"))
276	})
277	if !errors.Is(err, context.Canceled) {
278		t.Fatalf("expected context.Canceled, got %v", err)
279	}
280}
281
282func TestIterateReverseOrder(t *testing.T) {
283	database := openTestDB(t)
284
285	keys := [][]byte{
286		[]byte("alpha"),
287		[]byte("beta"),
288		[]byte("gamma"),
289	}
290	err := database.Update(context.Background(), func(txn *Txn) error {
291		for _, key := range keys {
292			if err := txn.Set(key, []byte(key)); err != nil {
293				return err
294			}
295		}
296		return nil
297	})
298	if err != nil {
299		t.Fatalf("populate keys: %v", err)
300	}
301
302	var seen []string
303	err = database.View(context.Background(), func(txn *Txn) error {
304		return txn.Iterate(IterateOptions{Reverse: true}, func(item Item) error {
305			seen = append(seen, item.KeyString())
306			return nil
307		})
308	})
309	if err != nil {
310		t.Fatalf("reverse iteration: %v", err)
311	}
312
313	expected := []string{"gamma", "beta", "alpha"}
314	if len(seen) != len(expected) {
315		t.Fatalf("unexpected number of items: got %d want %d", len(seen), len(expected))
316	}
317	for i, want := range expected {
318		if seen[i] != want {
319			t.Fatalf("unexpected key at %d: got %q want %q", i, seen[i], want)
320		}
321	}
322}
323
324func TestIncrementUint64InitialValue(t *testing.T) {
325	database := openTestDB(t)
326	var result uint64
327	err := database.Update(context.Background(), func(txn *Txn) error {
328		var err error
329		result, err = txn.IncrementUint64([]byte("counter"), 3)
330		return err
331	})
332	if err != nil {
333		t.Fatalf("increment: %v", err)
334	}
335	if result != 3 {
336		t.Fatalf("expected counter=3, got %d", result)
337	}
338}
339
// spyLogger is a test double that only counts how many times each log level
// is invoked. Format strings and arguments are discarded.
type spyLogger struct {
	errors   int // calls to Errorf
	warnings int // calls to Warningf
	infos    int // calls to Infof
	debugs   int // calls to Debugf
}

// Errorf records one error-level call.
func (l *spyLogger) Errorf(_ string, _ ...any) {
	l.errors++
}

// Warningf records one warning-level call.
func (l *spyLogger) Warningf(_ string, _ ...any) {
	l.warnings++
}

// Infof records one info-level call.
func (l *spyLogger) Infof(_ string, _ ...any) {
	l.infos++
}

// Debugf records one debug-level call.
func (l *spyLogger) Debugf(_ string, _ ...any) {
	l.debugs++
}
351
352func TestBadgerLoggerAdapter(t *testing.T) {
353	logger := &spyLogger{}
354	adapter := badgerLoggerAdapter{logger: logger}
355	adapter.Errorf("error")
356	adapter.Warningf("warn")
357	adapter.Infof("info")
358	adapter.Debugf("debug")
359
360	if logger.errors != 1 || logger.warnings != 1 || logger.infos != 1 || logger.debugs != 1 {
361		t.Fatalf("logger counts unexpected: %+v", logger)
362	}
363}
364
365func TestEnsureDir(t *testing.T) {
366	base := t.TempDir()
367	target := filepath.Join(base, "a", "b")
368	if err := ensureDir(target); err != nil {
369		t.Fatalf("ensureDir: %v", err)
370	}
371	info, err := os.Stat(target)
372	if err != nil {
373		t.Fatalf("stat target: %v", err)
374	}
375	if !info.IsDir() {
376		t.Fatalf("expected directory, got file")
377	}
378}
379
// mustJSON returns the encoding/json encoding of v and panics if marshalling
// fails. It is meant for building expected fixtures in tests, where a
// marshal error indicates a bug in the test itself.
func mustJSON(v any) []byte {
	encoded, err := json.Marshal(v)
	if err == nil {
		return encoded
	}
	panic(err)
}
387
388func TestDefaultPath(t *testing.T) {
389	path, err := DefaultPath()
390	if err != nil {
391		t.Fatalf("DefaultPath() error: %v", err)
392	}
393	if path == "" {
394		t.Fatalf("DefaultPath() returned empty string")
395	}
396}
397
398func TestOptionsApplyDefaults(t *testing.T) {
399	t.Run("DefaultsApplied", func(t *testing.T) {
400		opts, err := (Options{}).applyDefaults()
401		if err != nil {
402			t.Fatalf("applyDefaults: %v", err)
403		}
404		if opts.Path == "" {
405			t.Fatalf("expected Path to be populated")
406		}
407		if opts.Logger == nil {
408			t.Fatalf("expected Logger to be non-nil")
409		}
410		if opts.MaxTxnRetries != defaultTxnMaxRetries {
411			t.Fatalf("expected MaxTxnRetries=%d, got %d", defaultTxnMaxRetries, opts.MaxTxnRetries)
412		}
413		if opts.ConflictBackoff != defaultConflictBackoff {
414			t.Fatalf("expected ConflictBackoff=%s, got %s", defaultConflictBackoff, opts.ConflictBackoff)
415		}
416		if !opts.SyncWrites {
417			t.Fatalf("expected SyncWrites default true")
418		}
419	})
420
421	t.Run("NegativeRetries", func(t *testing.T) {
422		_, err := (Options{MaxTxnRetries: -1}).applyDefaults()
423		if err == nil {
424			t.Fatalf("expected error for negative MaxTxnRetries")
425		}
426	})
427
428	t.Run("RespectExistingValues", func(t *testing.T) {
429		opts, err := (Options{
430			Path:            "/tmp/custom",
431			Logger:          &spyLogger{},
432			MaxTxnRetries:   3,
433			ConflictBackoff: time.Second,
434			ReadOnly:        true,
435		}).applyDefaults()
436		if err != nil {
437			t.Fatalf("applyDefaults: %v", err)
438		}
439		if opts.Path != "/tmp/custom" {
440			t.Fatalf("expected Path to be preserved, got %q", opts.Path)
441		}
442		if opts.MaxTxnRetries != 3 {
443			t.Fatalf("expected MaxTxnRetries=3, got %d", opts.MaxTxnRetries)
444		}
445		if opts.ConflictBackoff != time.Second {
446			t.Fatalf("expected ConflictBackoff=1s, got %s", opts.ConflictBackoff)
447		}
448		if opts.SyncWrites {
449			t.Fatalf("expected SyncWrites to remain false for read-only")
450		}
451		if _, ok := opts.Logger.(*spyLogger); !ok {
452			t.Fatalf("expected Logger to remain spyLogger")
453		}
454	})
455}
456
457func TestCanonicalizeDir(t *testing.T) {
458	dir := t.TempDir()
459	canonical, err := CanonicalizeDir(dir)
460	if err != nil {
461		t.Fatalf("CanonicalizeDir error: %v", err)
462	}
463	if canonical == "" {
464		t.Fatalf("expected non-empty canonical path")
465	}
466	if runtime.GOOS != "windows" && !strings.HasPrefix(canonical, "/") {
467		t.Fatalf("expected absolute path, got %q", canonical)
468	}
469}
470
471func TestDirHashConsistency(t *testing.T) {
472	dir := t.TempDir()
473	canonical, err := CanonicalizeDir(dir)
474	if err != nil {
475		t.Fatalf("canonicalize: %v", err)
476	}
477	hash1 := DirHash(canonical)
478	hash2 := DirHash(canonical)
479	if hash1 != hash2 {
480		t.Fatalf("expected consistent hashes, got %q vs %q", hash1, hash2)
481	}
482	if len(hash1) != 64 {
483		t.Fatalf("expected 64 hex chars, got %d", len(hash1))
484	}
485}
486
487func TestCanonicalizeAndHash(t *testing.T) {
488	dir := t.TempDir()
489	canonical, hash, err := CanonicalizeAndHash(dir)
490	if err != nil {
491		t.Fatalf("CanonicalizeAndHash error: %v", err)
492	}
493	if canonical == "" || hash == "" {
494		t.Fatalf("expected non-empty canonical/hash")
495	}
496}
497
498func TestParentWalk(t *testing.T) {
499	dir := t.TempDir()
500	canonical, err := CanonicalizeDir(dir)
501	if err != nil {
502		t.Fatalf("canonicalize: %v", err)
503	}
504	parents := ParentWalk(canonical)
505	if len(parents) == 0 {
506		t.Fatalf("expected at least one parent")
507	}
508	if parents[0] != canonical {
509		t.Fatalf("expected first element to be input path")
510	}
511	seenRoot := false
512	for _, p := range parents {
513		if p == "/" || (runtime.GOOS == "windows" && len(p) == 3 && p[1] == ':' && p[2] == '/') {
514			seenRoot = true
515		}
516	}
517	if !seenRoot {
518		t.Fatalf("expected parent list to reach root, got %v", parents)
519	}
520}
521
522func TestKeysAndPrefixes(t *testing.T) {
523	if got := string(KeySchemaVersion()); got != "meta/schema_version" {
524		t.Fatalf("unexpected schema key: %q", got)
525	}
526	if got := string(KeyDirActive("hash")); got != "dir/hash/active" {
527		t.Fatalf("unexpected dir active key: %q", got)
528	}
529	if got := string(KeyDirArchived("hash", "ts", "sid")); got != "dir/hash/archived/ts/sid" {
530		t.Fatalf("unexpected dir archived key: %q", got)
531	}
532	if got := string(KeyIdxActive("sid")); got != "idx/active/sid" {
533		t.Fatalf("unexpected idx active key: %q", got)
534	}
535	if got := string(KeyIdxArchived("ts", "sid")); got != "idx/archived/ts/sid" {
536		t.Fatalf("unexpected idx archived key: %q", got)
537	}
538	if got := string(KeySessionMeta("sid")); got != "s/sid/meta" {
539		t.Fatalf("unexpected session meta key: %q", got)
540	}
541	if got := string(KeySessionGoal("sid")); got != "s/sid/goal" {
542		t.Fatalf("unexpected session goal key: %q", got)
543	}
544	if got := string(KeySessionTask("sid", "tid")); got != "s/sid/task/tid" {
545		t.Fatalf("unexpected session task key: %q", got)
546	}
547	if got := string(KeySessionTaskStatusIndex("sid", "pending", "tid")); got != "s/sid/idx/status/pending/tid" {
548		t.Fatalf("unexpected status idx key: %q", got)
549	}
550	if got := string(KeySessionEventSeq("sid")); got != "s/sid/meta/evt_seq" {
551		t.Fatalf("unexpected event seq key: %q", got)
552	}
553	if got := string(KeySessionEvent("sid", 10)); got != "s/sid/evt/000000000000000a" {
554		t.Fatalf("unexpected session event key: %q", got)
555	}
556	if prefix := string(PrefixSessionTasks("sid")); prefix != "s/sid/task" {
557		t.Fatalf("unexpected tasks prefix: %q", prefix)
558	}
559	if prefix := string(PrefixSessionStatusIndex("sid", "pending")); prefix != "s/sid/idx/status/pending" {
560		t.Fatalf("unexpected status prefix: %q", prefix)
561	}
562	if prefix := string(PrefixSessionEvents("sid")); prefix != "s/sid/evt" {
563		t.Fatalf("unexpected events prefix: %q", prefix)
564	}
565	if prefix := string(PrefixDirArchived("hash")); prefix != "dir/hash/archived" {
566		t.Fatalf("unexpected dir archived prefix: %q", prefix)
567	}
568	if prefix := string(PrefixIdxActive()); prefix != "idx/active" {
569		t.Fatalf("unexpected idx active prefix: %q", prefix)
570	}
571	if prefix := string(PrefixIdxArchived()); prefix != "idx/archived" {
572		t.Fatalf("unexpected idx archived prefix: %q", prefix)
573	}
574}
575
576func TestEncodingHelpers(t *testing.T) {
577	var buf [8]byte
578	putUint64(buf[:], 123)
579	if got := readUint64(buf[:]); got != 123 {
580		t.Fatalf("unexpected roundtrip: got %d want 123", got)
581	}
582	if _, err := decodeUint64([]byte{1, 2}); err == nil {
583		t.Fatalf("expected error for short buffer")
584	}
585	if hex := Uint64Hex(10); hex != "000000000000000a" {
586		t.Fatalf("unexpected Uint64Hex: %q", hex)
587	}
588	if hex := encodeHex([]byte{0xde, 0xad}); hex != "dead" {
589		t.Fatalf("unexpected encodeHex: %q", hex)
590	}
591}
592
593func TestIterateReversePrefixOrder(t *testing.T) {
594	database := openTestDB(t)
595
596	keys := [][]byte{
597		[]byte("dir/abc/archived/0001/sid1"),
598		[]byte("dir/abc/archived/0002/sid2"),
599		[]byte("dir/abc/archived/0003/sid3"),
600		[]byte("dir/xyz/archived/0001/sid4"),
601		[]byte("other/key"),
602	}
603
604	err := database.Update(context.Background(), func(txn *Txn) error {
605		for _, key := range keys {
606			if err := txn.Set(key, key); err != nil {
607				return err
608			}
609		}
610		return nil
611	})
612	if err != nil {
613		t.Fatalf("populate keys: %v", err)
614	}
615
616	t.Run("ReverseWithPrefix", func(t *testing.T) {
617		var seen []string
618		err := database.View(context.Background(), func(txn *Txn) error {
619			return txn.Iterate(IterateOptions{
620				Prefix:  []byte("dir/abc/archived"),
621				Reverse: true,
622			}, func(item Item) error {
623				seen = append(seen, item.KeyString())
624				return nil
625			})
626		})
627		if err != nil {
628			t.Fatalf("reverse prefix iteration: %v", err)
629		}
630
631		expected := []string{
632			"dir/abc/archived/0003/sid3",
633			"dir/abc/archived/0002/sid2",
634			"dir/abc/archived/0001/sid1",
635		}
636		if len(seen) != len(expected) {
637			t.Fatalf("unexpected number of items: got %d want %d", len(seen), len(expected))
638		}
639		for i, want := range expected {
640			if seen[i] != want {
641				t.Fatalf("unexpected key at %d: got %q want %q", i, seen[i], want)
642			}
643		}
644	})
645
646	t.Run("ForwardWithPrefix", func(t *testing.T) {
647		var seen []string
648		err := database.View(context.Background(), func(txn *Txn) error {
649			return txn.Iterate(IterateOptions{
650				Prefix: []byte("dir/abc/archived"),
651			}, func(item Item) error {
652				seen = append(seen, item.KeyString())
653				return nil
654			})
655		})
656		if err != nil {
657			t.Fatalf("forward prefix iteration: %v", err)
658		}
659
660		expected := []string{
661			"dir/abc/archived/0001/sid1",
662			"dir/abc/archived/0002/sid2",
663			"dir/abc/archived/0003/sid3",
664		}
665		if len(seen) != len(expected) {
666			t.Fatalf("unexpected number of items: got %d want %d", len(seen), len(expected))
667		}
668		for i, want := range expected {
669			if seen[i] != want {
670				t.Fatalf("unexpected key at %d: got %q want %q", i, seen[i], want)
671			}
672		}
673	})
674
675	t.Run("EmptyPrefixReverse", func(t *testing.T) {
676		var count int
677		err := database.View(context.Background(), func(txn *Txn) error {
678			return txn.Iterate(IterateOptions{
679				Prefix:  []byte("nonexistent/prefix"),
680				Reverse: true,
681			}, func(item Item) error {
682				count++
683				return nil
684			})
685		})
686		if err != nil {
687			t.Fatalf("empty prefix iteration: %v", err)
688		}
689		if count != 0 {
690			t.Fatalf("expected 0 items for nonexistent prefix, got %d", count)
691		}
692	})
693}