1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: AGPL-3.0-or-later
4
5package db
6
7import (
8 "context"
9 "encoding/json"
10 "errors"
11 "os"
12 "path/filepath"
13 "runtime"
14 "strings"
15 "testing"
16 "time"
17)
18
19func openTestDB(t *testing.T) *Database {
20 t.Helper()
21
22 database, err := Open(Options{Path: t.TempDir()})
23 if err != nil {
24 t.Fatalf("open db: %v", err)
25 }
26
27 t.Cleanup(func() {
28 if err := database.Close(); err != nil {
29 t.Fatalf("closing db: %v", err)
30 }
31 })
32
33 return database
34}
35
36func TestDatabasePathAndClose(t *testing.T) {
37 database := openTestDB(t)
38 if database.Path() == "" {
39 t.Fatalf("Path() returned empty string")
40 }
41 if err := database.Close(); err != nil {
42 t.Fatalf("Close() returned error on first call: %v", err)
43 }
44 // Second close should be a noop.
45 if err := database.Close(); err != nil {
46 t.Fatalf("Close() returned error on second call: %v", err)
47 }
48}
49
50func TestDatabaseViewErrorMapping(t *testing.T) {
51 database := openTestDB(t)
52
53 err := database.View(context.Background(), func(txn *Txn) error {
54 _, err := txn.Get([]byte("missing"))
55 return err
56 })
57 if !errors.Is(err, ErrKeyNotFound) {
58 t.Fatalf("expected ErrKeyNotFound, got %v", err)
59 }
60
61 aborted := database.View(context.Background(), func(txn *Txn) error {
62 return txn.Abort()
63 })
64 if aborted != nil {
65 t.Fatalf("expected nil after abort, got %v", aborted)
66 }
67
68 readonlyErr := database.View(context.Background(), func(txn *Txn) error {
69 return txn.Set([]byte("k"), []byte("v"))
70 })
71 if !errors.Is(readonlyErr, ErrReadOnly) {
72 t.Fatalf("expected ErrReadOnly inside view, got %v", readonlyErr)
73 }
74}
75
76func TestDatabaseUpdateLifecycle(t *testing.T) {
77 database := openTestDB(t)
78
79 type payload struct {
80 Message string
81 Count int
82 }
83 p := payload{Message: "hello", Count: 42}
84
85 err := database.Update(context.Background(), func(txn *Txn) error {
86 if err := txn.Set([]byte("alpha"), []byte("bravo")); err != nil {
87 return err
88 }
89 if err := txn.SetJSON([]byte("payload"), p); err != nil {
90 return err
91 }
92 if _, err := txn.IncrementUint64([]byte("counter"), 5); err != nil {
93 return err
94 }
95 // Seed a value with an invalid length for later IncrementUint64 failure.
96 if err := txn.Set([]byte("counter-invalid"), []byte("oops")); err != nil {
97 return err
98 }
99 return nil
100 })
101 if err != nil {
102 t.Fatalf("Update() returned error: %v", err)
103 }
104
105 err = database.View(context.Background(), func(txn *Txn) error {
106 val, err := txn.Get([]byte("alpha"))
107 if err != nil {
108 return err
109 }
110 if string(val) != "bravo" {
111 t.Fatalf("unexpected value: %q", val)
112 }
113
114 var decoded payload
115 if err := txn.GetJSON([]byte("payload"), &decoded); err != nil {
116 return err
117 }
118 if decoded != p {
119 t.Fatalf("unexpected payload: %#v", decoded)
120 }
121
122 exists, err := txn.Exists([]byte("alpha"))
123 if err != nil {
124 return err
125 }
126 if !exists {
127 t.Fatalf("expected Exists to report true")
128 }
129
130 expected := map[string]string{
131 "alpha": "bravo",
132 "payload": string(mustJSON(p)),
133 }
134 collected := make(map[string]string)
135 iterErr := txn.Iterate(IterateOptions{PrefetchValues: true}, func(item Item) error {
136 key := item.KeyString()
137 val, err := item.Value()
138 if err != nil {
139 return err
140 }
141 collected[key] = string(val)
142 return nil
143 })
144 if iterErr != nil {
145 return iterErr
146 }
147
148 for key, want := range expected {
149 got, ok := collected[key]
150 if !ok {
151 t.Fatalf("missing key %q in iteration results", key)
152 }
153 if got != want {
154 t.Fatalf("unexpected value for %q: got %q want %q", key, got, want)
155 }
156 }
157
158 abortErr := txn.Iterate(IterateOptions{}, func(Item) error {
159 return ErrTxnAborted
160 })
161 if abortErr != nil {
162 t.Fatalf("expected nil when aborting iteration, got %v", abortErr)
163 }
164
165 prefixCount := 0
166 prefixErr := txn.Iterate(IterateOptions{Prefix: []byte("a")}, func(item Item) error {
167 if !strings.HasPrefix(item.KeyString(), "a") {
168 t.Fatalf("expected prefix match, got %q", item.KeyString())
169 }
170 prefixCount++
171 return nil
172 })
173 if prefixErr != nil {
174 t.Fatalf("prefix iteration error: %v", prefixErr)
175 }
176 if prefixCount == 0 {
177 t.Fatalf("expected at least one prefix item, got %d", prefixCount)
178 }
179
180 return nil
181 })
182 if err != nil {
183 t.Fatalf("View() returned error: %v", err)
184 }
185
186 err = database.Update(context.Background(), func(txn *Txn) error {
187 val, err := txn.IncrementUint64([]byte("counter"), 7)
188 if err != nil {
189 return err
190 }
191 if val != 12 {
192 t.Fatalf("expected counter=12, got %d", val)
193 }
194
195 _, err = txn.IncrementUint64([]byte("counter-invalid"), 1)
196 if err == nil {
197 t.Fatalf("expected error when incrementing invalid counter")
198 }
199
200 return txn.Delete([]byte("alpha"))
201 })
202 if err != nil {
203 t.Fatalf("Update() increment phase error: %v", err)
204 }
205
206 err = database.View(context.Background(), func(txn *Txn) error {
207 ok, err := txn.Exists([]byte("alpha"))
208 if err != nil {
209 return err
210 }
211 if ok {
212 t.Fatalf("expected alpha to be deleted")
213 }
214 return nil
215 })
216 if err != nil {
217 t.Fatalf("View after delete error: %v", err)
218 }
219}
220
221func TestDatabaseUpdateErrorMapping(t *testing.T) {
222 database := openTestDB(t)
223
224 err := database.Update(context.Background(), func(txn *Txn) error {
225 _, err := txn.Get([]byte("missing"))
226 return err
227 })
228 if !errors.Is(err, ErrKeyNotFound) {
229 t.Fatalf("expected ErrKeyNotFound, got %v", err)
230 }
231
232 aborted := database.Update(context.Background(), func(txn *Txn) error {
233 return txn.Abort()
234 })
235 if aborted != nil {
236 t.Fatalf("expected nil after abort, got %v", aborted)
237 }
238}
239
240func TestDatabaseUpdateReadOnly(t *testing.T) {
241 path := t.TempDir()
242 writable, err := Open(Options{Path: path})
243 if err != nil {
244 t.Fatalf("prepare writable db: %v", err)
245 }
246 if err := writable.Close(); err != nil {
247 t.Fatalf("close writable db: %v", err)
248 }
249
250 database, err := Open(Options{Path: path, ReadOnly: true})
251 if err != nil {
252 t.Fatalf("open read-only db: %v", err)
253 }
254 t.Cleanup(func() {
255 if err := database.Close(); err != nil {
256 t.Fatalf("close db: %v", err)
257 }
258 })
259
260 err = database.Update(context.Background(), func(txn *Txn) error {
261 return txn.Set([]byte("k"), []byte("v"))
262 })
263 if !errors.Is(err, ErrReadOnly) {
264 t.Fatalf("expected ErrReadOnly for read-only DB, got %v", err)
265 }
266}
267
268func TestDatabaseUpdateRetryRespectsContext(t *testing.T) {
269 database := openTestDB(t)
270
271 ctx, cancel := context.WithCancel(context.Background())
272 cancel() // cancel immediately
273
274 err := database.Update(ctx, func(txn *Txn) error {
275 return txn.Set([]byte("k"), []byte("v"))
276 })
277 if !errors.Is(err, context.Canceled) {
278 t.Fatalf("expected context.Canceled, got %v", err)
279 }
280}
281
282func TestIterateReverseOrder(t *testing.T) {
283 database := openTestDB(t)
284
285 keys := [][]byte{
286 []byte("alpha"),
287 []byte("beta"),
288 []byte("gamma"),
289 }
290 err := database.Update(context.Background(), func(txn *Txn) error {
291 for _, key := range keys {
292 if err := txn.Set(key, []byte(key)); err != nil {
293 return err
294 }
295 }
296 return nil
297 })
298 if err != nil {
299 t.Fatalf("populate keys: %v", err)
300 }
301
302 var seen []string
303 err = database.View(context.Background(), func(txn *Txn) error {
304 return txn.Iterate(IterateOptions{Reverse: true}, func(item Item) error {
305 seen = append(seen, item.KeyString())
306 return nil
307 })
308 })
309 if err != nil {
310 t.Fatalf("reverse iteration: %v", err)
311 }
312
313 expected := []string{"gamma", "beta", "alpha"}
314 if len(seen) != len(expected) {
315 t.Fatalf("unexpected number of items: got %d want %d", len(seen), len(expected))
316 }
317 for i, want := range expected {
318 if seen[i] != want {
319 t.Fatalf("unexpected key at %d: got %q want %q", i, seen[i], want)
320 }
321 }
322}
323
324func TestIncrementUint64InitialValue(t *testing.T) {
325 database := openTestDB(t)
326 var result uint64
327 err := database.Update(context.Background(), func(txn *Txn) error {
328 var err error
329 result, err = txn.IncrementUint64([]byte("counter"), 3)
330 return err
331 })
332 if err != nil {
333 t.Fatalf("increment: %v", err)
334 }
335 if result != 3 {
336 t.Fatalf("expected counter=3, got %d", result)
337 }
338}
339
// spyLogger is a test double usable wherever this package expects a logger
// (Options.Logger, badgerLoggerAdapter). It ignores every message and only
// counts how many times each log level was invoked.
type spyLogger struct {
	errors, warnings, infos, debugs int
}

// Errorf counts one error-level call; format and arguments are discarded.
func (l *spyLogger) Errorf(_ string, _ ...any) { l.errors++ }

// Warningf counts one warning-level call; format and arguments are discarded.
func (l *spyLogger) Warningf(_ string, _ ...any) { l.warnings++ }

// Infof counts one info-level call; format and arguments are discarded.
func (l *spyLogger) Infof(_ string, _ ...any) { l.infos++ }

// Debugf counts one debug-level call; format and arguments are discarded.
func (l *spyLogger) Debugf(_ string, _ ...any) { l.debugs++ }
351
352func TestBadgerLoggerAdapter(t *testing.T) {
353 logger := &spyLogger{}
354 adapter := badgerLoggerAdapter{logger: logger}
355 adapter.Errorf("error")
356 adapter.Warningf("warn")
357 adapter.Infof("info")
358 adapter.Debugf("debug")
359
360 if logger.errors != 1 || logger.warnings != 1 || logger.infos != 1 || logger.debugs != 1 {
361 t.Fatalf("logger counts unexpected: %+v", logger)
362 }
363}
364
365func TestEnsureDir(t *testing.T) {
366 base := t.TempDir()
367 target := filepath.Join(base, "a", "b")
368 if err := ensureDir(target); err != nil {
369 t.Fatalf("ensureDir: %v", err)
370 }
371 info, err := os.Stat(target)
372 if err != nil {
373 t.Fatalf("stat target: %v", err)
374 }
375 if !info.IsDir() {
376 t.Fatalf("expected directory, got file")
377 }
378}
379
// mustJSON returns the encoding/json encoding of v and panics if marshalling
// fails. It exists to build expected fixtures in tests, where a marshal error
// indicates a bug in the test itself rather than a condition to handle.
func mustJSON(v any) []byte {
	encoded, err := json.Marshal(v)
	if err == nil {
		return encoded
	}
	panic(err)
}
387
388func TestDefaultPath(t *testing.T) {
389 path, err := DefaultPath()
390 if err != nil {
391 t.Fatalf("DefaultPath() error: %v", err)
392 }
393 if path == "" {
394 t.Fatalf("DefaultPath() returned empty string")
395 }
396}
397
398func TestOptionsApplyDefaults(t *testing.T) {
399 t.Run("DefaultsApplied", func(t *testing.T) {
400 opts, err := (Options{}).applyDefaults()
401 if err != nil {
402 t.Fatalf("applyDefaults: %v", err)
403 }
404 if opts.Path == "" {
405 t.Fatalf("expected Path to be populated")
406 }
407 if opts.Logger == nil {
408 t.Fatalf("expected Logger to be non-nil")
409 }
410 if opts.MaxTxnRetries != defaultTxnMaxRetries {
411 t.Fatalf("expected MaxTxnRetries=%d, got %d", defaultTxnMaxRetries, opts.MaxTxnRetries)
412 }
413 if opts.ConflictBackoff != defaultConflictBackoff {
414 t.Fatalf("expected ConflictBackoff=%s, got %s", defaultConflictBackoff, opts.ConflictBackoff)
415 }
416 if !opts.SyncWrites {
417 t.Fatalf("expected SyncWrites default true")
418 }
419 })
420
421 t.Run("NegativeRetries", func(t *testing.T) {
422 _, err := (Options{MaxTxnRetries: -1}).applyDefaults()
423 if err == nil {
424 t.Fatalf("expected error for negative MaxTxnRetries")
425 }
426 })
427
428 t.Run("RespectExistingValues", func(t *testing.T) {
429 opts, err := (Options{
430 Path: "/tmp/custom",
431 Logger: &spyLogger{},
432 MaxTxnRetries: 3,
433 ConflictBackoff: time.Second,
434 ReadOnly: true,
435 }).applyDefaults()
436 if err != nil {
437 t.Fatalf("applyDefaults: %v", err)
438 }
439 if opts.Path != "/tmp/custom" {
440 t.Fatalf("expected Path to be preserved, got %q", opts.Path)
441 }
442 if opts.MaxTxnRetries != 3 {
443 t.Fatalf("expected MaxTxnRetries=3, got %d", opts.MaxTxnRetries)
444 }
445 if opts.ConflictBackoff != time.Second {
446 t.Fatalf("expected ConflictBackoff=1s, got %s", opts.ConflictBackoff)
447 }
448 if opts.SyncWrites {
449 t.Fatalf("expected SyncWrites to remain false for read-only")
450 }
451 if _, ok := opts.Logger.(*spyLogger); !ok {
452 t.Fatalf("expected Logger to remain spyLogger")
453 }
454 })
455}
456
457func TestCanonicalizeDir(t *testing.T) {
458 dir := t.TempDir()
459 canonical, err := CanonicalizeDir(dir)
460 if err != nil {
461 t.Fatalf("CanonicalizeDir error: %v", err)
462 }
463 if canonical == "" {
464 t.Fatalf("expected non-empty canonical path")
465 }
466 if runtime.GOOS != "windows" && !strings.HasPrefix(canonical, "/") {
467 t.Fatalf("expected absolute path, got %q", canonical)
468 }
469}
470
471func TestDirHashConsistency(t *testing.T) {
472 dir := t.TempDir()
473 canonical, err := CanonicalizeDir(dir)
474 if err != nil {
475 t.Fatalf("canonicalize: %v", err)
476 }
477 hash1 := DirHash(canonical)
478 hash2 := DirHash(canonical)
479 if hash1 != hash2 {
480 t.Fatalf("expected consistent hashes, got %q vs %q", hash1, hash2)
481 }
482 if len(hash1) != 64 {
483 t.Fatalf("expected 64 hex chars, got %d", len(hash1))
484 }
485}
486
487func TestCanonicalizeAndHash(t *testing.T) {
488 dir := t.TempDir()
489 canonical, hash, err := CanonicalizeAndHash(dir)
490 if err != nil {
491 t.Fatalf("CanonicalizeAndHash error: %v", err)
492 }
493 if canonical == "" || hash == "" {
494 t.Fatalf("expected non-empty canonical/hash")
495 }
496}
497
498func TestParentWalk(t *testing.T) {
499 dir := t.TempDir()
500 canonical, err := CanonicalizeDir(dir)
501 if err != nil {
502 t.Fatalf("canonicalize: %v", err)
503 }
504 parents := ParentWalk(canonical)
505 if len(parents) == 0 {
506 t.Fatalf("expected at least one parent")
507 }
508 if parents[0] != canonical {
509 t.Fatalf("expected first element to be input path")
510 }
511 seenRoot := false
512 for _, p := range parents {
513 if p == "/" || (runtime.GOOS == "windows" && len(p) == 3 && p[1] == ':' && p[2] == '/') {
514 seenRoot = true
515 }
516 }
517 if !seenRoot {
518 t.Fatalf("expected parent list to reach root, got %v", parents)
519 }
520}
521
522func TestKeysAndPrefixes(t *testing.T) {
523 if got := string(KeySchemaVersion()); got != "meta/schema_version" {
524 t.Fatalf("unexpected schema key: %q", got)
525 }
526 if got := string(KeyDirActive("hash")); got != "dir/hash/active" {
527 t.Fatalf("unexpected dir active key: %q", got)
528 }
529 if got := string(KeyDirArchived("hash", "ts", "sid")); got != "dir/hash/archived/ts/sid" {
530 t.Fatalf("unexpected dir archived key: %q", got)
531 }
532 if got := string(KeyIdxActive("sid")); got != "idx/active/sid" {
533 t.Fatalf("unexpected idx active key: %q", got)
534 }
535 if got := string(KeyIdxArchived("ts", "sid")); got != "idx/archived/ts/sid" {
536 t.Fatalf("unexpected idx archived key: %q", got)
537 }
538 if got := string(KeySessionMeta("sid")); got != "s/sid/meta" {
539 t.Fatalf("unexpected session meta key: %q", got)
540 }
541 if got := string(KeySessionGoal("sid")); got != "s/sid/goal" {
542 t.Fatalf("unexpected session goal key: %q", got)
543 }
544 if got := string(KeySessionTask("sid", "tid")); got != "s/sid/task/tid" {
545 t.Fatalf("unexpected session task key: %q", got)
546 }
547 if got := string(KeySessionTaskStatusIndex("sid", "pending", "tid")); got != "s/sid/idx/status/pending/tid" {
548 t.Fatalf("unexpected status idx key: %q", got)
549 }
550 if got := string(KeySessionEventSeq("sid")); got != "s/sid/meta/evt_seq" {
551 t.Fatalf("unexpected event seq key: %q", got)
552 }
553 if got := string(KeySessionEvent("sid", 10)); got != "s/sid/evt/000000000000000a" {
554 t.Fatalf("unexpected session event key: %q", got)
555 }
556 if prefix := string(PrefixSessionTasks("sid")); prefix != "s/sid/task" {
557 t.Fatalf("unexpected tasks prefix: %q", prefix)
558 }
559 if prefix := string(PrefixSessionStatusIndex("sid", "pending")); prefix != "s/sid/idx/status/pending" {
560 t.Fatalf("unexpected status prefix: %q", prefix)
561 }
562 if prefix := string(PrefixSessionEvents("sid")); prefix != "s/sid/evt" {
563 t.Fatalf("unexpected events prefix: %q", prefix)
564 }
565 if prefix := string(PrefixDirArchived("hash")); prefix != "dir/hash/archived" {
566 t.Fatalf("unexpected dir archived prefix: %q", prefix)
567 }
568 if prefix := string(PrefixIdxActive()); prefix != "idx/active" {
569 t.Fatalf("unexpected idx active prefix: %q", prefix)
570 }
571 if prefix := string(PrefixIdxArchived()); prefix != "idx/archived" {
572 t.Fatalf("unexpected idx archived prefix: %q", prefix)
573 }
574}
575
576func TestEncodingHelpers(t *testing.T) {
577 var buf [8]byte
578 putUint64(buf[:], 123)
579 if got := readUint64(buf[:]); got != 123 {
580 t.Fatalf("unexpected roundtrip: got %d want 123", got)
581 }
582 if _, err := decodeUint64([]byte{1, 2}); err == nil {
583 t.Fatalf("expected error for short buffer")
584 }
585 if hex := Uint64Hex(10); hex != "000000000000000a" {
586 t.Fatalf("unexpected Uint64Hex: %q", hex)
587 }
588 if hex := encodeHex([]byte{0xde, 0xad}); hex != "dead" {
589 t.Fatalf("unexpected encodeHex: %q", hex)
590 }
591}