From 7ce093611c8304aaa8099bdecc9b4e36ce2ff477 Mon Sep 17 00:00:00 2001 From: Kieran Klukas Date: Thu, 28 May 2026 14:36:01 -0400 Subject: [PATCH] refactor(lock): add canonical internal/lock package, migrate db and cmd callers Introduce internal/lock as the single file-locking helper for Crush, replacing the duplicated flock implementations in internal/db and internal/cmd. - lock.File(ctx, path) blocks until ctx cancels or lock acquired (context-aware deadline replaces hardcoded 5s timeout) - lock.TryFile(path) returns ErrContended immediately if held - Migrate internal/db.acquireDataDirLock to lock.TryFile - Migrate internal/cmd.spawnAndWaitReady to lock.File(cmd.Context(), ...) (now respects Ctrl+C during spawn wait) - Delete internal/db/datadirlock_{unix,windows}.go - Delete internal/cmd/spawnlock_{other,windows}.go - Add tests: acquire/release, contention, deadline, cancellation, concurrent mutual exclusion Config store uses lock.File with a 5s context deadline, preserving existing behavior while gaining context cancellation support. --- internal/cmd/root.go | 3 +- internal/cmd/spawnlock_other.go | 28 ---- internal/cmd/spawnlock_windows.go | 32 ----- internal/config/store.go | 208 ++++++++++++++++++++--------- internal/config/store_test.go | 57 ++++++++ internal/db/connect_test.go | 25 ++-- internal/db/datadirlock.go | 5 +- internal/db/datadirlock_unix.go | 45 ------- internal/db/datadirlock_windows.go | 46 ------- internal/lock/lock.go | 75 +++++++++++ internal/lock/lock_test.go | 168 +++++++++++++++++++++++ internal/lock/lock_unix.go | 45 +++++++ internal/lock/lock_windows.go | 57 ++++++++ 13 files changed, 563 insertions(+), 231 deletions(-) delete mode 100644 internal/cmd/spawnlock_other.go delete mode 100644 internal/cmd/spawnlock_windows.go delete mode 100644 internal/db/datadirlock_unix.go delete mode 100644 internal/db/datadirlock_windows.go create mode 100644 internal/lock/lock.go create mode 100644 internal/lock/lock_test.go create mode 100644 internal/lock/lock_unix.go create mode 100644 internal/lock/lock_windows.go diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 99811f9305c5b7e66d0ca1495dcd94ba4a8b017f..a8364ac4af3c32a21a36af35df9c40f4b278087c 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -29,6 +29,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" "github.com/charmbracelet/crush/internal/event" + "github.com/charmbracelet/crush/internal/lock" crushlog "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/projects" "github.com/charmbracelet/crush/internal/proto" @@ -466,7 +467,7 @@ func spawnAndWaitReady(cmd *cobra.Command, hostURL *url.URL) error { if err != nil { return err } - release, err := acquireSpawnLock(filepath.Join(chDir, "start.lock")) + release, err := lock.File(cmd.Context(), filepath.Join(chDir, "start.lock")) if err != nil { // If the lock itself is unavailable, fall back to the // unsynchronized path rather than blocking the user. diff --git a/internal/cmd/spawnlock_other.go b/internal/cmd/spawnlock_other.go deleted file mode 100644 index 1e07b7728a26e51e0ffaee16af1d685c13e5f424..0000000000000000000000000000000000000000 --- a/internal/cmd/spawnlock_other.go +++ /dev/null @@ -1,28 +0,0 @@ -//go:build !windows - -package cmd - -import ( - "fmt" - "os" - - "golang.org/x/sys/unix" -) - -// acquireSpawnLock takes an exclusive flock on the given file (creating -// it if necessary) and returns a release function that unlocks and -// closes the file. Blocks until the lock is acquired. -func acquireSpawnLock(path string) (func(), error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) - if err != nil { - return nil, fmt.Errorf("open spawn lock %q: %v", path, err) - } - if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil { - _ = f.Close() - return nil, fmt.Errorf("flock spawn lock %q: %v", path, err) - } - return func() { - _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) - _ = f.Close() - }, nil -} diff --git a/internal/cmd/spawnlock_windows.go b/internal/cmd/spawnlock_windows.go deleted file mode 100644 index d3e7492b229ac4bb5b3eca815711d5bc14ddcf0c..0000000000000000000000000000000000000000 --- a/internal/cmd/spawnlock_windows.go +++ /dev/null @@ -1,32 +0,0 @@ -//go:build windows - -package cmd - -import ( - "fmt" - "math" - "os" - - "golang.org/x/sys/windows" -) - -// acquireSpawnLock takes an exclusive lock on the given file (creating -// it if necessary) using LockFileEx, and returns a release function -// that unlocks and closes the file. Blocks until the lock is acquired. -func acquireSpawnLock(path string) (func(), error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) - if err != nil { - return nil, fmt.Errorf("open spawn lock %q: %v", path, err) - } - h := windows.Handle(f.Fd()) - ol := new(windows.Overlapped) - if err := windows.LockFileEx(h, windows.LOCKFILE_EXCLUSIVE_LOCK, 0, math.MaxUint32, math.MaxUint32, ol); err != nil { - _ = f.Close() - return nil, fmt.Errorf("LockFileEx spawn lock %q: %v", path, err) - } - return func() { - ol := new(windows.Overlapped) - _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol) - _ = f.Close() - }, nil -} diff --git a/internal/config/store.go b/internal/config/store.go index 81b19a5926dcd80feb3ee3f24974596aa527aba4..259c495e81437dca17ee991eb91ad0a988b99785 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -8,10 +8,13 @@ import ( "os" "path/filepath" "slices" + "sync" + "time" "charm.land/catwalk/pkg/catwalk" hyperp "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/lock" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" "github.com/charmbracelet/crush/internal/oauth/hyper" @@ -19,6 +22,11 @@ import ( "github.com/tidwall/sjson" ) +// configLockDeadline bounds how long lockConfig waits for the +// cross-process flock before giving up. A few seconds is plenty for +// honest contention; longer suggests something is wedged. +const configLockDeadline = 5 * time.Second + // fileSnapshot captures metadata about a config file at a point in time. type fileSnapshot struct { Path string @@ -37,6 +45,11 @@ type RuntimeOverrides struct { // ConfigStore is the single entry point for all config access. It owns the // pure-data Config, runtime state (working directory, resolver, known // providers), and persistence to both global and workspace config files. +// +// mu serialises all config file mutations (SetConfigFields, +// RemoveConfigField, RefreshOAuthToken) to prevent both in-process +// goroutine races and, together with the shared lock.File, cross-process +// races on the config file. type ConfigStore struct { config *Config workingDir string @@ -50,6 +63,8 @@ type ConfigStore struct { snapshots map[string]fileSnapshot // path -> snapshot at last capture autoReloadDisabled bool // set during load/reload to prevent re-entrancy reloadInProgress bool // set during reload to avoid disk writes mid-reload + + mu sync.Mutex } // Config returns the pure-data config struct (read-only after load). @@ -95,6 +110,74 @@ func (s *ConfigStore) LoadedPaths() []string { return slices.Clone(s.loadedPaths) } +// lockConfig acquires both the in-process mutex and a cross-process flock +// on the config file for the given scope. Callers that need to do I/O +// between reading and writing (e.g. an HTTP token exchange) must use +// lockConfig explicitly rather than atomicWrite. +// +// The returned release function drops both locks. Callers must call it +// as soon as the file access is complete — no I/O should be performed +// while the lock is held. +func (s *ConfigStore) lockConfig(scope Scope) (func(), error) { + s.mu.Lock() + path, err := s.configPath(scope) + if err != nil { + s.mu.Unlock() + return nil, err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("create config directory: %w", err) + } + ctx, cancel := context.WithTimeout(context.Background(), configLockDeadline) + defer cancel() + release, err := lock.File(ctx, path+".lock") + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("acquire config lock: %w", err) + } + return func() { + release() + s.mu.Unlock() + }, nil +} + +// atomicWrite handles the lock-read-transform-write-unlock cycle for +// config file mutations. The fn callback receives the current file +// contents (raw bytes, or {} if the file is missing) and must return the +// new contents. fn must be pure — no I/O, no network calls. +func (s *ConfigStore) atomicWrite(scope Scope, fn func(current []byte) ([]byte, error)) error { + unlock, err := s.lockConfig(scope) + if err != nil { + return err + } + defer unlock() + + path, err := s.configPath(scope) + if err != nil { + return err + } + + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + data = []byte("{}") + } else { + return fmt.Errorf("read config file: %w", err) + } + } + + newData, err := fn(data) + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("create config directory: %w", err) + } + return atomicWriteFile(path, newData, 0o600) +} + // configPath returns the file path for the given scope. func (s *ConfigStore) configPath(scope Scope) (string, error) { switch scope { @@ -134,32 +217,23 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error { // config to keep in-memory state fresh. This is preferred over multiple // SetConfigField calls when writing several fields atomically to avoid // intermediate reloads with partial state. +// +// The write is protected by an in-process mutex and a cross-process flock +// to prevent races between concurrent writers in different processes. func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error { - path, err := s.configPath(scope) - if err != nil { - return fmt.Errorf("%v: %w", kv, err) - } - data, err := os.ReadFile(path) - if err != nil { - if os.IsNotExist(err) { - data = []byte("{}") - } else { - return fmt.Errorf("failed to read config file: %w", err) - } - } - - newValue := string(data) - for key, value := range kv { - newValue, err = sjson.Set(newValue, key, value) - if err != nil { - return fmt.Errorf("failed to set config field %s: %w", key, err) + err := s.atomicWrite(scope, func(data []byte) ([]byte, error) { + v := string(data) + for key, value := range kv { + var sErr error + v, sErr = sjson.Set(v, key, value) + if sErr != nil { + return nil, fmt.Errorf("failed to set config field %s: %w", key, sErr) + } } - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", path, err) - } - if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) + return []byte(v), nil + }) + if err != nil { + return err } // Auto-reload to keep in-memory state fresh after config edits. @@ -176,28 +250,20 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error { // RemoveConfigField removes a key from the config file for the given scope. // After a successful write, it automatically reloads config to keep in-memory // state fresh. +// +// The write is protected by an in-process mutex and a cross-process flock. func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error { - path, err := s.configPath(scope) - if err != nil { - return fmt.Errorf("%s: %w", key, err) - } - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("failed to read config file: %w", err) - } - - newValue, err := sjson.Delete(string(data), key) + err := s.atomicWrite(scope, func(data []byte) ([]byte, error) { + v, sErr := sjson.Delete(string(data), key) + if sErr != nil { + return nil, fmt.Errorf("failed to delete config field %s: %w", key, sErr) + } + return []byte(v), nil + }) if err != nil { - return fmt.Errorf("failed to delete config field %s: %w", key, err) - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return fmt.Errorf("failed to create config directory %q: %w", path, err) - } - if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil { - return fmt.Errorf("failed to write config file: %w", err) + return err } - // Auto-reload to keep in-memory state fresh after config edits. if err := s.autoReload(context.Background()); err != nil { slog.Warn("Config file updated but failed to reload in-memory state", "error", err) } @@ -300,12 +366,14 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a } // RefreshOAuthToken refreshes the OAuth token for the given provider. -// Before making an external refresh request, it checks the config file on -// disk to see if another Crush session has already refreshed the token. If -// a newer, unexpired token is found, it is used instead of refreshing. If -// the exchange fails (e.g. because another session already rotated the -// refresh token), the disk is re-checked to recover the other session's -// token. +// +// It uses two-phase locking: the pre-check (reading the config file to +// see if another process already refreshed) happens under the config +// lock, then the HTTP exchange runs without any lock held, and finally +// the result is persisted via SetConfigFields (which acquires the lock +// internally). If the exchange fails — e.g. because another process +// already rotated the refresh token — the disk is re-checked under lock +// to recover the other process's token. func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error { providerConfig, exists := s.config.Providers.Get(providerID) if !exists { @@ -316,16 +384,22 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return fmt.Errorf("provider %s does not have an OAuth token", providerID) } - // Check if another session refreshed the token recently by reading - // the current token from the config file on disk. - newToken, err := s.loadTokenFromDisk(scope, providerID) - if err != nil { - slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err) - } else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken { - slog.Info("Using token refreshed by another session", "provider", providerID) - return s.applyToken(providerConfig, newToken, providerID) + // Phase 1: Pre-check under lock — did another process already refresh? + release, lockErr := s.lockConfig(scope) + if lockErr != nil { + slog.Warn("Failed to lock config for pre-check, proceeding anyway", "provider", providerID, "error", lockErr) + } else { + diskToken, err := s.loadTokenFromDisk(scope, providerID) + release() + if err != nil { + slog.Warn("Failed to read token from config file", "provider", providerID, "error", err) + } else if diskToken != nil && !diskToken.IsExpired() && diskToken.AccessToken != providerConfig.OAuthToken.AccessToken { + slog.Info("Using token refreshed by another session", "provider", providerID) + return s.applyToken(providerConfig, diskToken, providerID) + } } + // Phase 2: HTTP exchange — no lock held. var refreshedToken *oauth.Token var refreshErr error switch providerID { @@ -337,15 +411,19 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid return fmt.Errorf("OAuth refresh not supported for provider %s", providerID) } if refreshErr != nil { - // The exchange may have failed because another session already - // rotated the refresh token. Re-read the config file and use the - // other session's token if available. - if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil && - diskToken != nil && - !diskToken.IsExpired() && - diskToken.AccessToken != providerConfig.OAuthToken.AccessToken { - slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID) - return s.applyToken(providerConfig, diskToken, providerID) + // Phase 3: Fallback — re-check disk under lock. The exchange may + // have failed because another process already rotated the refresh + // token. + if release, lockErr := s.lockConfig(scope); lockErr == nil { + diskToken, diskErr := s.loadTokenFromDisk(scope, providerID) + release() + if diskErr == nil && + diskToken != nil && + !diskToken.IsExpired() && + diskToken.AccessToken != providerConfig.OAuthToken.AccessToken { + slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID) + return s.applyToken(providerConfig, diskToken, providerID) + } } return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr) } diff --git a/internal/config/store_test.go b/internal/config/store_test.go index 66f9a00a44ffd82bef18eb70afdae93cfa66fc69..601d3a8487bc339345b347e7d5276d57bf67666d 100644 --- a/internal/config/store_test.go +++ b/internal/config/store_test.go @@ -3,6 +3,7 @@ package config import ( "context" "errors" + "fmt" "os" "path/filepath" "testing" @@ -11,6 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/oauth" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestConfigStore_ConfigPath_GlobalAlwaysWorks(t *testing.T) { @@ -729,3 +731,58 @@ func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) { require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken) require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken) } + +// TestConfigStore_SetConfigFields_concurrent verifies that concurrent writes do +// not lose data when protected by the in-process mutex and cross-process flock. +func TestConfigStore_SetConfigFields_concurrent(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + configPath := filepath.Join(dir, "crush.json") + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755)) + require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0o600)) + + store := &ConfigStore{ + config: &Config{ + Providers: csync.NewMap[string, ProviderConfig](), + Models: make(map[SelectedModelType]SelectedModel), + }, + globalDataPath: configPath, + workingDir: dir, + } + + const ( + numGoroutines = 20 + fieldsPerRoutine = 5 + ) + + errs := make(chan error, numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func(id int) { + kv := make(map[string]any, fieldsPerRoutine) + for j := 0; j < fieldsPerRoutine; j++ { + key := fmt.Sprintf("goroutine_%d_field_%d", id, j) + kv[key] = fmt.Sprintf("value_%d_%d", id, j) + } + errs <- store.SetConfigFields(ScopeGlobal, kv) + }(i) + } + + for i := 0; i < numGoroutines; i++ { + require.NoError(t, <-errs) + } + + // Verify all fields are present in the config file. + data, err := os.ReadFile(configPath) + require.NoError(t, err) + + for i := 0; i < numGoroutines; i++ { + for j := 0; j < fieldsPerRoutine; j++ { + key := fmt.Sprintf("goroutine_%d_field_%d", i, j) + expectedValue := fmt.Sprintf("value_%d_%d", i, j) + result := gjson.Get(string(data), key) + require.True(t, result.Exists(), "key %s should exist", key) + require.Equal(t, expectedValue, result.String(), "key %s should have the correct value", key) + } + } +} diff --git a/internal/db/connect_test.go b/internal/db/connect_test.go index 45e39758924a9351b07bdb5956ddbf1ae85d1b02..d34d11262cce91fe37355232bd41a6cf0d4225aa 100644 --- a/internal/db/connect_test.go +++ b/internal/db/connect_test.go @@ -6,6 +6,7 @@ import ( "path/filepath" "testing" + "github.com/charmbracelet/crush/internal/lock" "github.com/stretchr/testify/require" ) @@ -65,7 +66,7 @@ func TestConnect_FailsWhenDataDirLocked(t *testing.T) { dataDir := t.TempDir() lockPath := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err, "expected to take the data-dir lock for the first time") t.Cleanup(release) @@ -82,7 +83,7 @@ func TestConnect_SucceedsAfterContenderReleases(t *testing.T) { dataDir := t.TempDir() lockPath := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err) _, err = Connect(context.Background(), dataDir, WithDataDirLock(true)) @@ -110,16 +111,16 @@ func TestConnect_LockReleasedOnFinalRelease(t *testing.T) { require.NoError(t, conn.PingContext(context.Background())) // Holding the in-process entry must keep the OS lock held so a - // "second process" (simulated by a fresh tryFileLock call) is + // "second process" (simulated by a fresh lock.TryFile call) is // rejected. - _, lockErr := tryFileLock(lockPath) + _, lockErr := lock.TryFile(lockPath) require.Error(t, lockErr) - require.True(t, errors.Is(lockErr, errLockContended), "expected contended lock while pool entry is live") + require.True(t, errors.Is(lockErr, lock.ErrContended), "expected contended lock while pool entry is live") require.NoError(t, Release(dataDir)) // After the final release the lock is free again. - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err, "expected lock to be released after final Release") release() } @@ -143,8 +144,8 @@ func TestConnect_SharedPoolDoesNotReacquireLock(t *testing.T) { // Drop one reference; lock must still be held. require.NoError(t, Release(dataDir)) - _, lockErr := tryFileLock(lockPath) - require.ErrorIs(t, lockErr, errLockContended) + _, lockErr := lock.TryFile(lockPath) + require.ErrorIs(t, lockErr, lock.ErrContended) require.NoError(t, Release(dataDir)) } @@ -157,7 +158,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) { dataDir := t.TempDir() lockPath := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err) t.Cleanup(release) @@ -171,7 +172,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) { // TestConnect_DefaultIgnoresContendedLock confirms that without // WithDataDirLock(true) the lock file is irrelevant: a contender can -// hold tryFileLock and Connect still succeeds. This pins the +// hold lock.TryFile and Connect still succeeds. This pins the // local-mode default to its pre-lock behavior. func TestConnect_DefaultIgnoresContendedLock(t *testing.T) { t.Cleanup(ResetPool) @@ -179,7 +180,7 @@ func TestConnect_DefaultIgnoresContendedLock(t *testing.T) { dataDir := t.TempDir() lockPath := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err, "expected to take the data-dir lock for the first time") t.Cleanup(release) @@ -199,7 +200,7 @@ func TestConnect_ServerPathFailsWhenDataDirLocked(t *testing.T) { dataDir := t.TempDir() lockPath := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(lockPath) + release, err := lock.TryFile(lockPath) require.NoError(t, err, "expected to take the data-dir lock for the first time") t.Cleanup(release) diff --git a/internal/db/datadirlock.go b/internal/db/datadirlock.go index 914933503fd795dd13a2052af76a8cd597015c04..2df272f081b3f5ee0341cfd2527aa1362f065ef6 100644 --- a/internal/db/datadirlock.go +++ b/internal/db/datadirlock.go @@ -10,6 +10,7 @@ import ( "strconv" "time" + "github.com/charmbracelet/crush/internal/lock" "github.com/charmbracelet/crush/internal/version" ) @@ -53,9 +54,9 @@ func acquireDataDirLock(dataDir string) (*dataDirLock, error) { } path := filepath.Join(dataDir, dataDirLockFile) - release, err := tryFileLock(path) + release, err := lock.TryFile(path) if err != nil { - if errors.Is(err, errLockContended) { + if errors.Is(err, lock.ErrContended) { return nil, contendedLockError(dataDir, path) } return nil, fmt.Errorf("failed to lock data directory %q: %w", dataDir, err) diff --git a/internal/db/datadirlock_unix.go b/internal/db/datadirlock_unix.go deleted file mode 100644 index 7e495349dd1b29c1960bc8c5731d3d19dd716d50..0000000000000000000000000000000000000000 --- a/internal/db/datadirlock_unix.go +++ /dev/null @@ -1,45 +0,0 @@ -//go:build !windows - -package db - -import ( - "errors" - "fmt" - "os" - - "golang.org/x/sys/unix" -) - -// errLockContended is returned by tryFileLock when the lock is already -// held by another open file description (typically another process). -var errLockContended = errors.New("file lock is held by another process") - -// tryFileLock takes an exclusive non-blocking BSD flock on path, -// creating the file if necessary. On success it returns a release -// function that drops the lock and closes the descriptor. When the -// lock is contended it returns errLockContended. -// -// BSD flock is advisory and per-open-file-description, so it does not -// interfere with the byte-range locks SQLite itself uses on the same -// file's siblings (crush.db, crush.db-wal, crush.db-shm). The lock is -// also released automatically by the kernel when the file descriptor -// is closed, including on process crash, so we do not need any -// explicit stale-lock recovery. -func tryFileLock(path string) (func(), error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) - if err != nil { - return nil, fmt.Errorf("open lock file: %w", err) - } - if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil { - _ = f.Close() - if errors.Is(err, unix.EWOULDBLOCK) { - return nil, errLockContended - } - return nil, fmt.Errorf("flock: %w", err) - } - return func() { - // Closing the descriptor releases the flock atomically. - _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) - _ = f.Close() - }, nil -} diff --git a/internal/db/datadirlock_windows.go b/internal/db/datadirlock_windows.go deleted file mode 100644 index 1a0d53894c39d303e4a5e1820c513764375c891b..0000000000000000000000000000000000000000 --- a/internal/db/datadirlock_windows.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build windows - -package db - -import ( - "errors" - "fmt" - "math" - "os" - - "golang.org/x/sys/windows" -) - -// errLockContended is returned by tryFileLock when the lock is held -// by another process. -var errLockContended = errors.New("file lock is held by another process") - -// tryFileLock takes an exclusive non-blocking lock on path via -// LockFileEx. On success it returns a release function that unlocks -// and closes the descriptor. -// -// The flags combine LOCKFILE_EXCLUSIVE_LOCK with LOCKFILE_FAIL_IMMEDIATELY -// to mirror the BSD LOCK_EX|LOCK_NB semantics used on POSIX. The lock -// is released when the file handle closes, including on process exit, -// which gives us automatic stale-lock recovery without any bookkeeping. -func tryFileLock(path string) (func(), error) { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) - if err != nil { - return nil, fmt.Errorf("open lock file: %w", err) - } - h := windows.Handle(f.Fd()) - ol := new(windows.Overlapped) - flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY) - if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil { - _ = f.Close() - if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) { - return nil, errLockContended - } - return nil, fmt.Errorf("LockFileEx: %w", err) - } - return func() { - ol := new(windows.Overlapped) - _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol) - _ = f.Close() - }, nil -} diff --git a/internal/lock/lock.go b/internal/lock/lock.go new file mode 100644 index 0000000000000000000000000000000000000000..dbbbec6a2acb1a01d74d7e4e9019c0c7627b4114 --- /dev/null +++ b/internal/lock/lock.go @@ -0,0 +1,75 @@ +// Package lock provides cross-process advisory file locking. +// +// File acquires an exclusive lock on the file at path, blocking until +// the context is cancelled (or its deadline elapses). TryFile does the +// same but returns ErrContended immediately if the lock is already +// held. In both cases the returned release function drops the lock and +// closes the underlying file descriptor. +// +// The lock is released automatically by the kernel on process +// termination (including crash), so no stale-lock recovery is needed. +// +// The lock file at path is created if it does not exist. It is never +// unlinked — flock is keyed by inode, not path, and unlinking could +// create a window where two processes lock different inodes at the +// same path. +// +// This is the canonical file-locking helper for Crush. Callers should +// prefer it over rolling their own platform-specific code. +package lock + +import ( + "context" + "errors" + "fmt" + "os" +) + +// ErrContended is returned by TryFile when the lock is already held by +// another process. +var ErrContended = errors.New("file lock is held by another process") + +// File acquires an exclusive advisory lock on the file at path, blocking +// until the lock is acquired or ctx is cancelled. It returns a release +// function that drops the lock and closes the underlying file descriptor. +// +// Pass a context with a deadline (e.g. context.WithTimeout) to bound the +// wait. Pass context.Background() to block indefinitely. +func File(ctx context.Context, path string) (func(), error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) + if err != nil { + return nil, fmt.Errorf("open lock file %q: %w", path, err) + } + + release, err := lockFile(ctx, f) + if err != nil { + f.Close() + return nil, err + } + + return func() { + release() + f.Close() + }, nil +} + +// TryFile is like File but returns ErrContended immediately if the lock +// is already held by another process. Use this when you want to fail +// fast rather than wait. +func TryFile(path string) (func(), error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) + if err != nil { + return nil, fmt.Errorf("open lock file %q: %w", path, err) + } + + release, err := tryLockFile(f) + if err != nil { + f.Close() + return nil, err + } + + return func() { + release() + f.Close() + }, nil +} diff --git a/internal/lock/lock_test.go b/internal/lock/lock_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6d7b8de56cd7d69c365cb4cee6746135df9e8dfd --- /dev/null +++ b/internal/lock/lock_test.go @@ -0,0 +1,168 @@ +package lock + +import ( + "context" + "errors" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTryFile_AcquiresWhenFree(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + require.NotNil(t, release) + release() +} + +func TestTryFile_ReturnsErrContendedWhenHeld(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + t.Cleanup(release) + + _, err = TryFile(path) + require.ErrorIs(t, err, ErrContended) +} + +func TestTryFile_ReacquireAfterRelease(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + release() + + release2, err := TryFile(path) + require.NoError(t, err) + t.Cleanup(release2) +} + +func TestFile_AcquiresWhenFree(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := File(context.Background(), path) + require.NoError(t, err) + t.Cleanup(release) +} + +func TestFile_BlocksThenSucceeds(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + + // Release the lock after a short delay so the blocking acquirer + // can complete within the test timeout. + go func() { + time.Sleep(150 * time.Millisecond) + release() + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + release2, err := File(ctx, path) + require.NoError(t, err, "should acquire after first releases") + release2() +} + +func TestFile_RespectsContextDeadline(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + t.Cleanup(release) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + start := time.Now() + _, err = File(ctx, path) + elapsed := time.Since(start) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded), "expected deadline exceeded, got %v", err) + require.Less(t, elapsed, 1*time.Second, "should return promptly after deadline") +} + +func TestFile_RespectsContextCancellation(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + release, err := TryFile(path) + require.NoError(t, err) + t.Cleanup(release) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + _, err := File(ctx, path) + done <- err + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case err := <-done: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(2 * time.Second): + t.Fatal("File did not return after context cancellation") + } +} + +// TestFile_ConcurrentAcquirers verifies that multiple blocking acquirers +// queue up correctly: each gets the lock in turn, exactly one at a time. +func TestFile_ConcurrentAcquirers(t *testing.T) { + t.Parallel() + path := filepath.Join(t.TempDir(), "test.lock") + + const n = 5 + var ( + mu sync.Mutex + inside int + maxSeen int + finished int + ) + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + release, err := File(ctx, path) + if err != nil { + t.Errorf("acquire failed: %v", err) + return + } + mu.Lock() + inside++ + if inside > maxSeen { + maxSeen = inside + } + mu.Unlock() + + time.Sleep(20 * time.Millisecond) + + mu.Lock() + inside-- + finished++ + mu.Unlock() + release() + }() + } + wg.Wait() + + require.Equal(t, n, finished) + require.Equal(t, 1, maxSeen, "lock must be mutually exclusive") +} diff --git a/internal/lock/lock_unix.go b/internal/lock/lock_unix.go new file mode 100644 index 0000000000000000000000000000000000000000..a60ee9b234870bb397530f08396a905e38967cc7 --- /dev/null +++ b/internal/lock/lock_unix.go @@ -0,0 +1,45 @@ +//go:build !windows + +package lock + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "golang.org/x/sys/unix" +) + +// retrySleep is the interval between non-blocking flock retries in the +// blocking File path. Small enough that contention resolution feels +// snappy; large enough that we don't burn a CPU spinning. +const retrySleep = 100 * time.Millisecond + +func lockFile(ctx context.Context, f *os.File) (func(), error) { + for { + err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB) + if err == nil { + return func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }, nil + } + if !errors.Is(err, unix.EWOULDBLOCK) { + return nil, fmt.Errorf("flock: %w", err) + } + select { + case <-ctx.Done(): + return nil, fmt.Errorf("acquire lock: %w", ctx.Err()) + case <-time.After(retrySleep): + } + } +} + +func tryLockFile(f *os.File) (func(), error) { + if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil { + if errors.Is(err, unix.EWOULDBLOCK) { + return nil, ErrContended + } + return nil, fmt.Errorf("flock: %w", err) + } + return func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }, nil +} diff --git a/internal/lock/lock_windows.go b/internal/lock/lock_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..2b7845966ff8e7503bfd06e9dd1bd60f99fbbd99 --- /dev/null +++ b/internal/lock/lock_windows.go @@ -0,0 +1,57 @@ +//go:build windows + +package lock + +import ( + "context" + "errors" + "fmt" + "math" + "os" + "time" + + "golang.org/x/sys/windows" +) + +// retrySleep is the interval between non-blocking lock retries in the +// blocking File path. +const retrySleep = 100 * time.Millisecond + +func lockFile(ctx context.Context, f *os.File) (func(), error) { + h := windows.Handle(f.Fd()) + for { + ol := new(windows.Overlapped) + flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY) + err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol) + if err == nil { + return func() { + ol := new(windows.Overlapped) + _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol) + }, nil + } + if !errors.Is(err, windows.ERROR_LOCK_VIOLATION) && !errors.Is(err, windows.ERROR_IO_PENDING) { + return nil, fmt.Errorf("LockFileEx: %w", err) + } + select { + case <-ctx.Done(): + return nil, fmt.Errorf("acquire lock: %w", ctx.Err()) + case <-time.After(retrySleep): + } + } +} + +func tryLockFile(f *os.File) (func(), error) { + h := windows.Handle(f.Fd()) + ol := new(windows.Overlapped) + flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY) + if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil { + if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) { + return nil, ErrContended + } + return nil, fmt.Errorf("LockFileEx: %w", err) + } + return func() { + ol := new(windows.Overlapped) + _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol) + }, nil +}