diff --git a/internal/session/session.go b/internal/session/session.go index 8ea76a36e0334000d47fe58c455561d79a7c291e..6612dc0691ea5bc806585f11b368e7cfd1a0c11f 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -67,6 +67,59 @@ func NewStore(database *db.Database, clock timeutil.Clock) *Store { } } +// WithTxn exposes transactional helpers for use within db.Update. +func (s *Store) WithTxn(txn *db.Txn) TxnStore { + clock := s.clock + if clock == nil { + clock = timeutil.UTCClock{} + } + return TxnStore{ + txn: txn, + clock: clock, + } +} + +// TxnStore coordinates session operations within an existing transaction. +type TxnStore struct { + txn *db.Txn + clock timeutil.Clock +} + +// Load retrieves the session document for sid. +func (s TxnStore) Load(sid string) (Document, error) { + if s.txn == nil { + return Document{}, errors.New("session: transaction is nil") + } + return loadDocument(s.txn, sid) +} + +// TouchAt updates LastUpdatedAt for sid using at when provided (or the store's clock). +func (s TxnStore) TouchAt(sid string, at time.Time) (Document, error) { + if s.txn == nil { + return Document{}, errors.New("session: transaction is nil") + } + + if at.IsZero() { + clock := s.clock + if clock == nil { + clock = timeutil.UTCClock{} + } + at = clock.Now() + } + at = timeutil.EnsureUTC(at) + + doc, err := loadDocument(s.txn, sid) + if err != nil { + return Document{}, err + } + doc.LastUpdatedAt = at + + if err := s.txn.SetJSON(db.KeySessionMeta(sid), doc); err != nil { + return Document{}, err + } + return doc, nil +} + // Start creates a new session bound to path. When a session already exists for // the directory, an AlreadyActiveError is returned containing the existing // session document. @@ -93,18 +146,10 @@ func (s *Store) Start(ctx context.Context, path string) (Document, error) { return err } sid := string(sidBytes) - metaKey := db.KeySessionMeta(sid) - already, err := txn.Exists(metaKey) + existing, err := loadDocument(txn, sid) if err != nil { return err } - if !already { - return fmt.Errorf("session: active session %q missing metadata", sid) - } - var existing Document - if err := txn.GetJSON(metaKey, &existing); err != nil { - return err - } return AlreadyActiveError{Session: existing} } @@ -153,19 +198,36 @@ func (s *Store) Start(ctx context.Context, path string) (Document, error) { func (s *Store) Get(ctx context.Context, sid string) (Document, error) { var doc Document err := s.db.View(ctx, func(txn *db.Txn) error { - metaKey := db.KeySessionMeta(sid) - exists, err := txn.Exists(metaKey) - if err != nil { - return err - } - if !exists { - return ErrNotFound - } - return txn.GetJSON(metaKey, &doc) + var err error + doc, err = loadDocument(txn, sid) + return err }) return doc, err } +func loadDocument(txn *db.Txn, sid string) (Document, error) { + if txn == nil { + return Document{}, errors.New("session: transaction is nil") + } + key := db.KeySessionMeta(sid) + exists, err := txn.Exists(key) + if err != nil { + return Document{}, err + } + if !exists { + return Document{}, ErrNotFound + } + + var doc Document + if err := txn.GetJSON(key, &doc); err != nil { + return Document{}, err + } + if doc.SID == "" { + doc.SID = sid + } + return doc, nil +} + var crockfordEncoding = base32.NewEncoding("0123456789ABCDEFGHJKMNPQRSTVWXYZ").WithPadding(base32.NoPadding) func newSessionID(now time.Time) (string, error) { @@ -218,10 +280,14 @@ func (s *Store) ActiveByPath(ctx context.Context, path string) (Document, bool, if err != nil { return err } - metaKey := db.KeySessionMeta(string(sidBytes)) - if err := txn.GetJSON(metaKey, &doc); err != nil { + loaded, err := loadDocument(txn, string(sidBytes)) + if err != nil { + if errors.Is(err, ErrNotFound) { + continue + } return err } + doc = loaded found = true return nil } @@ -242,17 +308,11 @@ func (s *Store) Archive(ctx context.Context, sid string) (Document, error) { var doc Document err := s.db.Update(ctx, func(txn *db.Txn) error { - metaKey := db.KeySessionMeta(sid) - exists, err := txn.Exists(metaKey) + var err error + doc, err = loadDocument(txn, sid) if err != nil { return err } - if !exists { - return ErrNotFound - } - if err := txn.GetJSON(metaKey, &doc); err != nil { - return err - } if doc.State == StateArchived { return nil } @@ -276,7 +336,7 @@ func (s *Store) Archive(ctx context.Context, sid string) (Document, error) { if err := txn.Set(db.KeyDirArchived(doc.DirHash, tsHex, doc.SID), []byte{}); err != nil { return err } - if err := txn.SetJSON(metaKey, doc); err != nil { + if err := txn.SetJSON(db.KeySessionMeta(sid), doc); err != nil { return err } return nil diff --git a/internal/session/session_test.go b/internal/session/session_test.go index 032dbfaf60bd1bb8c7c2ecc1a73b041b9622b299..3700d0612f3fd2222f9459e97db5960492b633f5 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -191,3 +191,58 @@ func TestStoreStartAndArchive(t *testing.T) { t.Fatalf("ActiveByPath should not find archived session") } } + +func TestTxnStoreTouchAt(t *testing.T) { + ctx := context.Background() + database := testutil.OpenDB(t) + + startTime := time.Date(2025, time.May, 5, 9, 0, 0, 0, time.FixedZone("A", 3600)) + autoTime := time.Date(2025, time.May, 5, 10, 30, 0, 0, time.FixedZone("B", -3600)) + + clock := &testutil.SequenceClock{ + Times: []time.Time{startTime, autoTime}, + } + + store := session.NewStore(database, clock) + + dir := t.TempDir() + + started, err := store.Start(ctx, dir) + if err != nil { + t.Fatalf("Start: %v", err) + } + + custom := time.Date(2025, time.May, 6, 8, 15, 0, 0, time.UTC) + err = database.Update(ctx, func(txn *db.Txn) error { + _, err := store.WithTxn(txn).TouchAt(started.SID, custom) + return err + }) + if err != nil { + t.Fatalf("TouchAt(custom): %v", err) + } + + updated, err := store.Get(ctx, started.SID) + if err != nil { + t.Fatalf("Get after custom touch: %v", err) + } + if !updated.LastUpdatedAt.Equal(custom) { + t.Fatalf("LastUpdatedAt mismatch: got %v want %v", updated.LastUpdatedAt, custom) + } + + err = database.Update(ctx, func(txn *db.Txn) error { + _, err := store.WithTxn(txn).TouchAt(started.SID, time.Time{}) + return err + }) + if err != nil { + t.Fatalf("TouchAt(clock): %v", err) + } + + autoUpdated, err := store.Get(ctx, started.SID) + if err != nil { + t.Fatalf("Get after clock touch: %v", err) + } + + if want := autoTime.UTC(); !autoUpdated.LastUpdatedAt.Equal(want) { + t.Fatalf("Expected LastUpdatedAt %v, got %v", want, autoUpdated.LastUpdatedAt) + } +}