Detailed changes
@@ -6,6 +6,7 @@ import (
"embed"
"fmt"
"log/slog"
+ "os"
"path/filepath"
"sync"
"testing"
@@ -38,10 +39,16 @@ func init() {
}
}
-// connEntry holds a shared database connection and its reference count.
+// connEntry holds a shared database connection, its reference count,
+// and the data-directory lock that gates access to this entry. The
+// lock is acquired exactly once when the entry is created and released
+// when the last reference is dropped, which lets the same process open
+// the same data directory concurrently while still blocking a second
+// crush process from racing the storage.
type connEntry struct {
db *sql.DB
refCount int
+ lock *dataDirLock
}
var (
@@ -76,8 +83,23 @@ func Connect(ctx context.Context, dataDir string) (*sql.DB, error) {
return entry.db, nil
}
+ // Take the per-data-directory lock before opening the database so
+ // we fail fast and with a clear error rather than racing another
+ // crush process on the same SQLite file. The lock is released when
+ // the matching Release call drops the refcount to zero. Ensuring
+ // the data directory exists is required because the lock file
+ // lives inside it.
+ if err := os.MkdirAll(dataDir, 0o700); err != nil {
+ return nil, fmt.Errorf("failed to create data directory %q: %w", dataDir, err)
+ }
+ lock, err := acquireDataDirLock(dataDir)
+ if err != nil {
+ return nil, err
+ }
+
conn, err := openDB(dbPath)
if err != nil {
+ lock.release()
return nil, err
}
@@ -90,22 +112,25 @@ func Connect(ctx context.Context, dataDir string) (*sql.DB, error) {
if err = conn.PingContext(ctx); err != nil {
conn.Close()
+ lock.release()
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
if err := initGoose(); err != nil {
conn.Close()
+ lock.release()
slog.Error("Failed to initialize goose", "error", err)
return nil, fmt.Errorf("failed to initialize goose: %w", err)
}
if err := goose.Up(conn, "migrations"); err != nil {
conn.Close()
+ lock.release()
slog.Error("Failed to apply migrations", "error", err)
return nil, fmt.Errorf("failed to apply migrations: %w", err)
}
- pool[absPath] = &connEntry{db: conn, refCount: 1}
+ pool[absPath] = &connEntry{db: conn, refCount: 1, lock: lock}
return conn, nil
}
@@ -133,7 +158,11 @@ func Release(dataDir string) error {
}
delete(pool, absPath)
- return entry.db.Close()
+ closeErr := entry.db.Close()
+ if entry.lock != nil {
+ entry.lock.release()
+ }
+ return closeErr
}
// ResetPool closes all pooled connections and clears the pool. This is
@@ -143,6 +172,9 @@ func ResetPool() {
defer poolMu.Unlock()
for path, entry := range pool {
entry.db.Close()
+ if entry.lock != nil {
+ entry.lock.release()
+ }
delete(pool, path)
}
}
@@ -2,6 +2,8 @@ package db
import (
"context"
+ "errors"
+ "path/filepath"
"testing"
"github.com/stretchr/testify/require"
@@ -52,3 +54,117 @@ func TestRelease_NoopForUnknownDataDir(t *testing.T) {
require.NoError(t, Release("/nonexistent/path"), "releasing unknown data dir should not error")
}
+
+// TestConnect_FailsWhenDataDirLocked plays the role of a competing
+// crush process: it takes the data-dir lock through tryFileLock on its
+// own file descriptor, then checks that Connect refuses to open the
+// database and reports ErrDataDirLocked.
+func TestConnect_FailsWhenDataDirLocked(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	dir := t.TempDir()
+
+	// Hold the lock for the rest of the test; cleanup drops it.
+	unlock, lockErr := tryFileLock(filepath.Join(dir, dataDirLockFile))
+	require.NoError(t, lockErr, "expected to take the data-dir lock for the first time")
+	t.Cleanup(unlock)
+
+	_, connErr := Connect(context.Background(), dir)
+	require.Error(t, connErr, "Connect must refuse to open a locked data dir")
+	require.ErrorIs(t, connErr, ErrDataDirLocked)
+}
+
+// TestConnect_SucceedsAfterContenderReleases checks that contention is
+// not sticky: Connect fails with ErrDataDirLocked while another holder
+// owns the lock and succeeds once that holder releases it.
+func TestConnect_SucceedsAfterContenderReleases(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	dataDir := t.TempDir()
+	lockPath := filepath.Join(dataDir, dataDirLockFile)
+
+	release, err := tryFileLock(lockPath)
+	require.NoError(t, err)
+	// If an assertion aborts the test before the explicit release below,
+	// still drop the contender's lock so its descriptor is not left open
+	// while the temporary directory is removed.
+	released := false
+	t.Cleanup(func() {
+		if !released {
+			release()
+		}
+	})
+
+	_, err = Connect(context.Background(), dataDir)
+	require.ErrorIs(t, err, ErrDataDirLocked)
+
+	release()
+	released = true
+
+	conn, err := Connect(context.Background(), dataDir)
+	require.NoError(t, err, "Connect should succeed once the contender releases the lock")
+	require.NoError(t, conn.PingContext(context.Background()))
+	require.NoError(t, Release(dataDir))
+}
+
+// TestConnect_LockReleasedOnFinalRelease confirms that closing the last
+// reference to a pool entry also drops the OS lock, so subsequent
+// processes can take the data dir.
+func TestConnect_LockReleasedOnFinalRelease(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	dataDir := t.TempDir()
+	lockPath := filepath.Join(dataDir, dataDirLockFile)
+
+	conn, err := Connect(context.Background(), dataDir)
+	require.NoError(t, err)
+	require.NoError(t, conn.PingContext(context.Background()))
+
+	// Holding the in-process entry must keep the OS lock held so a
+	// "second process" (simulated by a fresh tryFileLock call) is
+	// rejected.
+	contender, lockErr := tryFileLock(lockPath)
+	if lockErr == nil {
+		// The lock was unexpectedly granted: give it back before the
+		// assertions below fail the test, so nothing is leaked.
+		contender()
+	}
+	require.Error(t, lockErr)
+	require.True(t, errors.Is(lockErr, errLockContended), "expected contended lock while pool entry is live")
+
+	require.NoError(t, Release(dataDir))
+
+	// After the final release the lock is free again.
+	release, err := tryFileLock(lockPath)
+	require.NoError(t, err, "expected lock to be released after final Release")
+	release()
+}
+
+// TestConnect_SharedPoolDoesNotReacquireLock makes sure that subsequent
+// in-process Connect calls reuse the existing OS lock through refcount,
+// not by re-acquiring it. The simplest observable signal of correctness
+// is that the second Connect does not error and the lock is still held
+// after a single Release.
+func TestConnect_SharedPoolDoesNotReacquireLock(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	dataDir := t.TempDir()
+	lockPath := filepath.Join(dataDir, dataDirLockFile)
+
+	_, err := Connect(context.Background(), dataDir)
+	require.NoError(t, err)
+
+	_, err = Connect(context.Background(), dataDir)
+	require.NoError(t, err)
+
+	// Drop one reference; lock must still be held.
+	require.NoError(t, Release(dataDir))
+	contender, lockErr := tryFileLock(lockPath)
+	if lockErr == nil {
+		// The lock was unexpectedly granted: give it back before the
+		// assertion below fails the test, so nothing is leaked.
+		contender()
+	}
+	require.ErrorIs(t, lockErr, errLockContended)
+
+	require.NoError(t, Release(dataDir))
+}
+
+// TestConnect_SkipLockEnvBypassesAcquisition covers the
+// CRUSH_SKIP_DATADIR_LOCK escape hatch: with the variable set to a true
+// value, Connect must succeed even though a contender already holds the
+// data-dir lock.
+func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) {
+	t.Cleanup(ResetPool)
+
+	dir := t.TempDir()
+
+	// Simulate a contender that keeps the lock until test cleanup.
+	unlock, lockErr := tryFileLock(filepath.Join(dir, dataDirLockFile))
+	require.NoError(t, lockErr)
+	t.Cleanup(unlock)
+
+	t.Setenv("CRUSH_SKIP_DATADIR_LOCK", "1")
+
+	ctx := context.Background()
+	db, connErr := Connect(ctx, dir)
+	require.NoError(t, connErr, "skip-lock env should bypass contention")
+	require.NoError(t, db.PingContext(ctx))
+	require.NoError(t, Release(dir))
+}
@@ -0,0 +1,125 @@
+package db
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"strconv"
+	"time"
+
+	"github.com/charmbracelet/crush/internal/version"
+)
+
+// ErrDataDirLocked is returned by Connect when the data directory is
+// already in use by another crush process.
+var ErrDataDirLocked = errors.New("data directory already in use by another crush process")
+
+// dataDirLockFile is the name of the lock file inside the data
+// directory. It lives next to crush.db so users can `ls` and find it.
+const dataDirLockFile = "crush.lock"
+
+// dataDirOwnerInfo is the JSON payload written into the lock file by
+// the process that currently owns it. It is purely informational; the
+// authoritative state of ownership is the operating-system lock held
+// on the lock file's descriptor.
+type dataDirOwnerInfo struct {
+	// PID is the owning process ID as reported by os.Getpid.
+	PID int `json:"pid"`
+	// Version is the owner's version.Version value; omitted from the
+	// JSON when empty.
+	Version string `json:"version,omitempty"`
+	// StartedAt is the UTC time, formatted as RFC 3339, at which the
+	// owner wrote this record; omitted from the JSON when empty.
+	StartedAt string `json:"started_at,omitempty"`
+}
+
+// dataDirLock represents an acquired exclusive lock on a data
+// directory. release gives the lock back: for a lock obtained through
+// tryFileLock it is the release function that call returned, and when
+// acquisition was skipped via CRUSH_SKIP_DATADIR_LOCK it is a no-op.
+type dataDirLock struct {
+	release func()
+}
+
+// acquireDataDirLock takes an exclusive non-blocking lock on
+// {dataDir}/crush.lock. If the lock is already held by another
+// process, it returns ErrDataDirLocked wrapped with a diagnostic that
+// includes whatever owner info that process wrote. Any other locking
+// failure is returned wrapped with the data directory path.
+//
+// Acquisition is skipped (returning a no-op lock) when
+// CRUSH_SKIP_DATADIR_LOCK is set to a truthy value. This is intended
+// as an escape hatch for hostile filesystems that do not implement
+// advisory locking; it should not be used in normal operation.
+func acquireDataDirLock(dataDir string) (*dataDirLock, error) {
+	if skipDataDirLock() {
+		return &dataDirLock{release: func() {}}, nil
+	}
+
+	path := filepath.Join(dataDir, dataDirLockFile)
+	release, err := tryFileLock(path)
+	if err != nil {
+		if errors.Is(err, errLockContended) {
+			return nil, contendedLockError(dataDir, path)
+		}
+		return nil, fmt.Errorf("failed to lock data directory %q: %w", dataDir, err)
+	}
+
+	// Record ownership metadata so a contending process can identify
+	// us. Failures here are non-fatal: the OS-level lock is what
+	// actually guarantees mutual exclusion, and a missing/partial JSON
+	// payload only degrades diagnostics. The failure is still logged
+	// at debug level so it can be investigated instead of vanishing.
+	if err := writeOwnerInfo(path); err != nil {
+		slog.Debug("Failed to record data directory lock owner", "path", path, "error", err)
+	}
+
+	return &dataDirLock{release: release}, nil
+}
+
+// skipDataDirLock reports whether CRUSH_SKIP_DATADIR_LOCK holds a value
+// that strconv.ParseBool reads as true. An unset, empty, or unparseable
+// value leaves locking enabled.
+func skipDataDirLock() bool {
+	skip, err := strconv.ParseBool(os.Getenv("CRUSH_SKIP_DATADIR_LOCK"))
+	return err == nil && skip
+}
+
+// writeOwnerInfo replaces the contents of the lock file at path with an
+// indented JSON record describing the current process: its PID, the
+// value of version.Version, and the current UTC time in RFC 3339 form,
+// followed by a trailing newline. acquireDataDirLock calls it once the
+// lock is held. It returns any marshalling or file-write error.
+//
+// NOTE(review): the write goes through os.WriteFile, i.e. a fresh
+// handle rather than the locked one. On Windows, where tryFileLock
+// holds an exclusive LockFileEx range lock, this write may be
+// rejected — confirm.
+func writeOwnerInfo(path string) error {
+	owner := dataDirOwnerInfo{
+		PID:       os.Getpid(),
+		Version:   version.Version,
+		StartedAt: time.Now().UTC().Format(time.RFC3339),
+	}
+
+	encoded, err := json.MarshalIndent(owner, "", " ")
+	if err != nil {
+		return err
+	}
+	return os.WriteFile(path, append(encoded, '\n'), 0o600)
+}
+
+// readOwnerInfo returns the lock file's recorded owner, if it parses.
+// A missing, unreadable, empty, or malformed file yields an empty
+// struct and no error; the caller decides what to surface to the user.
+func readOwnerInfo(path string) dataDirOwnerInfo {
+	raw, err := os.ReadFile(path)
+	if err != nil || len(raw) == 0 {
+		return dataDirOwnerInfo{}
+	}
+	var info dataDirOwnerInfo
+	// json.Unmarshal can leave some fields populated even when it
+	// reports an error (for example on a field type mismatch), so
+	// discard the partial result to honour the "malformed yields an
+	// empty struct" contract.
+	if err := json.Unmarshal(raw, &info); err != nil {
+		return dataDirOwnerInfo{}
+	}
+	return info
+}
+
+// contendedLockError builds a wrapped ErrDataDirLocked annotated with
+// whatever owner metadata is currently in the lock file.
+//
+// The message always names dataDir. When the lock file records a
+// non-zero PID, an "(owner pid=...)" suffix is appended; the version
+// and started_at fields are each added to that suffix only when they
+// are non-empty, so the message never shows a blank "version=" and
+// never drops a version that is known.
+func contendedLockError(dataDir, lockPath string) error {
+	info := readOwnerInfo(lockPath)
+	if info.PID == 0 {
+		// No usable owner record: report the directory alone.
+		return fmt.Errorf("%w: %s", ErrDataDirLocked, dataDir)
+	}
+
+	details := fmt.Sprintf(" (owner pid=%d", info.PID)
+	if info.Version != "" {
+		details += " version=" + info.Version
+	}
+	if info.StartedAt != "" {
+		details += " started_at=" + info.StartedAt
+	}
+	details += ")"
+
+	return fmt.Errorf("%w: %s%s", ErrDataDirLocked, dataDir, details)
+}
@@ -0,0 +1,45 @@
+//go:build !windows
+
+package db
+
+import (
+ "errors"
+ "fmt"
+ "os"
+
+ "golang.org/x/sys/unix"
+)
+
+// errLockContended is returned by tryFileLock when the lock is already
+// held by another open file description (typically another process).
+var errLockContended = errors.New("file lock is held by another process")
+
+// tryFileLock opens path read-write (creating it with mode 0600 when
+// absent) and requests an exclusive, non-blocking BSD flock on it.
+//
+// On success it returns a release function that unlocks the file and
+// closes the descriptor. If the lock request fails with EWOULDBLOCK
+// the descriptor is closed and errLockContended is returned; any other
+// open or flock failure is returned wrapped.
+//
+// BSD flock is advisory and tied to the open file description, and
+// this lock lives on its own file rather than on the SQLite files
+// next to it (crush.db, crush.db-wal, crush.db-shm). The kernel also
+// drops the lock when the descriptor goes away, including on process
+// crash, so no stale-lock recovery is needed.
+func tryFileLock(path string) (func(), error) {
+	lockFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open lock file: %w", err)
+	}
+
+	lockErr := unix.Flock(int(lockFile.Fd()), unix.LOCK_EX|unix.LOCK_NB)
+	if lockErr == nil {
+		return func() {
+			// Unlock explicitly, then close; closing alone would
+			// also drop the lock, so both errors are ignored.
+			_ = unix.Flock(int(lockFile.Fd()), unix.LOCK_UN)
+			_ = lockFile.Close()
+		}, nil
+	}
+
+	// The lock was not granted: do not leak the descriptor.
+	_ = lockFile.Close()
+	if errors.Is(lockErr, unix.EWOULDBLOCK) {
+		return nil, errLockContended
+	}
+	return nil, fmt.Errorf("flock: %w", lockErr)
+}
@@ -0,0 +1,46 @@
+//go:build windows
+
+package db
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "os"
+
+ "golang.org/x/sys/windows"
+)
+
+// errLockContended is returned by tryFileLock when the lock is held
+// by another process.
+var errLockContended = errors.New("file lock is held by another process")
+
+// tryFileLock opens path read-write (creating it when absent) and
+// requests an exclusive lock on it through LockFileEx.
+//
+// LOCKFILE_EXCLUSIVE_LOCK is combined with LOCKFILE_FAIL_IMMEDIATELY so
+// the call does not wait, mirroring the LOCK_EX|LOCK_NB semantics used
+// on POSIX. On success it returns a release function that unlocks the
+// range and closes the handle. ERROR_LOCK_VIOLATION and
+// ERROR_IO_PENDING are reported as errLockContended; any other open or
+// lock failure is returned wrapped.
+//
+// The lock is released when the file handle closes, including on
+// process exit, so no stale-lock bookkeeping is needed.
+func tryFileLock(path string) (func(), error) {
+	lockFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open lock file: %w", err)
+	}
+
+	lockFlags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
+	lockErr := windows.LockFileEx(
+		windows.Handle(lockFile.Fd()),
+		lockFlags,
+		0,
+		math.MaxUint32,
+		math.MaxUint32,
+		new(windows.Overlapped),
+	)
+	if lockErr == nil {
+		return func() {
+			// Unlock the same byte range that was locked, then close
+			// the handle; both errors are deliberately ignored.
+			_ = windows.UnlockFileEx(windows.Handle(lockFile.Fd()), 0, math.MaxUint32, math.MaxUint32, new(windows.Overlapped))
+			_ = lockFile.Close()
+		}, nil
+	}
+
+	// The lock was not granted: do not leak the handle.
+	_ = lockFile.Close()
+	if errors.Is(lockErr, windows.ERROR_LOCK_VIOLATION) || errors.Is(lockErr, windows.ERROR_IO_PENDING) {
+		return nil, errLockContended
+	}
+	return nil, fmt.Errorf("LockFileEx: %w", lockErr)
+}