Detailed changes
@@ -29,6 +29,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
+ "github.com/charmbracelet/crush/internal/lock"
crushlog "github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/projects"
"github.com/charmbracelet/crush/internal/proto"
@@ -466,7 +467,7 @@ func spawnAndWaitReady(cmd *cobra.Command, hostURL *url.URL) error {
if err != nil {
return err
}
- release, err := acquireSpawnLock(filepath.Join(chDir, "start.lock"))
+ release, err := lock.File(cmd.Context(), filepath.Join(chDir, "start.lock"))
if err != nil {
// If the lock itself is unavailable, fall back to the
// unsynchronized path rather than blocking the user.
@@ -1,28 +0,0 @@
-//go:build !windows
-
-package cmd
-
-import (
- "fmt"
- "os"
-
- "golang.org/x/sys/unix"
-)
-
-// acquireSpawnLock takes an exclusive flock on the given file (creating
-// it if necessary) and returns a release function that unlocks and
-// closes the file. Blocks until the lock is acquired.
-func acquireSpawnLock(path string) (func(), error) {
- f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
- if err != nil {
- return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
- }
- if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil {
- _ = f.Close()
- return nil, fmt.Errorf("flock spawn lock %q: %v", path, err)
- }
- return func() {
- _ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
- _ = f.Close()
- }, nil
-}
@@ -1,32 +0,0 @@
-//go:build windows
-
-package cmd
-
-import (
- "fmt"
- "math"
- "os"
-
- "golang.org/x/sys/windows"
-)
-
-// acquireSpawnLock takes an exclusive lock on the given file (creating
-// it if necessary) using LockFileEx, and returns a release function
-// that unlocks and closes the file. Blocks until the lock is acquired.
-func acquireSpawnLock(path string) (func(), error) {
- f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
- if err != nil {
- return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
- }
- h := windows.Handle(f.Fd())
- ol := new(windows.Overlapped)
- if err := windows.LockFileEx(h, windows.LOCKFILE_EXCLUSIVE_LOCK, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
- _ = f.Close()
- return nil, fmt.Errorf("LockFileEx spawn lock %q: %v", path, err)
- }
- return func() {
- ol := new(windows.Overlapped)
- _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
- _ = f.Close()
- }, nil
-}
@@ -8,10 +8,13 @@ import (
"os"
"path/filepath"
"slices"
+ "sync"
+ "time"
"charm.land/catwalk/pkg/catwalk"
hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
"github.com/charmbracelet/crush/internal/env"
+ "github.com/charmbracelet/crush/internal/lock"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
"github.com/charmbracelet/crush/internal/oauth/hyper"
@@ -19,6 +22,11 @@ import (
"github.com/tidwall/sjson"
)
+// configLockDeadline bounds how long lockConfig waits for the
+// cross-process flock before giving up. A few seconds is plenty for
+// honest contention; longer suggests something is wedged. lockConfig
+// takes the in-process mutex before it starts waiting, so a stalled
+// acquisition also holds up other writers in this process for up to
+// this long.
+const configLockDeadline = 5 * time.Second
+
// fileSnapshot captures metadata about a config file at a point in time.
type fileSnapshot struct {
Path string
@@ -37,6 +45,11 @@ type RuntimeOverrides struct {
// ConfigStore is the single entry point for all config access. It owns the
// pure-data Config, runtime state (working directory, resolver, known
// providers), and persistence to both global and workspace config files.
+//
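+// mu serialises config file access within this process: SetConfigFields
+// and RemoveConfigField hold it for their whole read-modify-write cycle,
+// and RefreshOAuthToken holds it around its on-disk token checks (but
+// not across the HTTP exchange). It is acquired through lockConfig,
+// which pairs it with a cross-process lock.File so that other processes
+// touching the same config file are excluded as well.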
type ConfigStore struct {
config *Config
workingDir string
@@ -50,6 +63,8 @@ type ConfigStore struct {
snapshots map[string]fileSnapshot // path -> snapshot at last capture
autoReloadDisabled bool // set during load/reload to prevent re-entrancy
reloadInProgress bool // set during reload to avoid disk writes mid-reload
+
+ mu sync.Mutex
}
// Config returns the pure-data config struct (read-only after load).
@@ -95,6 +110,74 @@ func (s *ConfigStore) LoadedPaths() []string {
return slices.Clone(s.loadedPaths)
}
+// lockConfig takes the in-process mutex and then a cross-process file
+// lock on "<config path>.lock" for the given scope, creating the config
+// directory first if it is missing. The wait for the file lock is
+// bounded by configLockDeadline, and the mutex stays held during that
+// wait.
+//
+// On success the returned function drops the file lock and then the
+// mutex; call it once, as soon as the file access is complete. Nothing
+// slow should run while the lock is held: a caller that needs to do
+// other I/O between reading and writing (e.g. an HTTP token exchange)
+// must call lockConfig itself, release before the slow step, and lock
+// again afterwards, rather than going through atomicWrite. On failure
+// neither lock is held.
+func (s *ConfigStore) lockConfig(scope Scope) (release func(), err error) {
+	s.mu.Lock()
+	// Every failure path below has to give the mutex back; only a
+	// successful return leaves it held for the caller.
+	defer func() {
+		if err != nil {
+			s.mu.Unlock()
+		}
+	}()
+
+	cfgPath, err := s.configPath(scope)
+	if err != nil {
+		return nil, err
+	}
+	// The lock file sits next to the config file, so the directory has
+	// to exist before the lock file can be opened.
+	if mkErr := os.MkdirAll(filepath.Dir(cfgPath), 0o755); mkErr != nil {
+		return nil, fmt.Errorf("create config directory: %w", mkErr)
+	}
+
+	waitCtx, stopWaiting := context.WithTimeout(context.Background(), configLockDeadline)
+	defer stopWaiting()
+	unlockFile, err := lock.File(waitCtx, cfgPath+".lock")
+	if err != nil {
+		return nil, fmt.Errorf("acquire config lock: %w", err)
+	}
+
+	return func() {
+		unlockFile()
+		s.mu.Unlock()
+	}, nil
+}
+
+// atomicWrite runs one locked read-modify-write cycle on the config file
+// for scope. It takes the config lock through lockConfig, reads the file
+// (substituting {} when the file does not exist), hands the bytes to fn,
+// and writes whatever fn returns via atomicWriteFile with mode 0o600.
+// The lock is held until atomicWrite returns, so fn must be pure — no
+// I/O, no network calls. An error from fn is returned as is and nothing
+// is written.
+func (s *ConfigStore) atomicWrite(scope Scope, fn func(current []byte) ([]byte, error)) error {
+	release, err := s.lockConfig(scope)
+	if err != nil {
+		return err
+	}
+	defer release()
+
+	target, err := s.configPath(scope)
+	if err != nil {
+		return err
+	}
+
+	current, readErr := os.ReadFile(target)
+	switch {
+	case readErr == nil:
+		// Use the bytes as read.
+	case os.IsNotExist(readErr):
+		// A missing file is treated as an empty JSON object.
+		current = []byte("{}")
+	default:
+		return fmt.Errorf("read config file: %w", readErr)
+	}
+
+	updated, err := fn(current)
+	if err != nil {
+		return err
+	}
+
+	if mkErr := os.MkdirAll(filepath.Dir(target), 0o755); mkErr != nil {
+		return fmt.Errorf("create config directory: %w", mkErr)
+	}
+	return atomicWriteFile(target, updated, 0o600)
+}
+
// configPath returns the file path for the given scope.
func (s *ConfigStore) configPath(scope Scope) (string, error) {
switch scope {
@@ -134,32 +217,23 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
// config to keep in-memory state fresh. This is preferred over multiple
// SetConfigField calls when writing several fields atomically to avoid
// intermediate reloads with partial state.
+//
+// The write is protected by an in-process mutex and a cross-process flock
+// to prevent races between concurrent writers in different processes.
func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
- path, err := s.configPath(scope)
- if err != nil {
- return fmt.Errorf("%v: %w", kv, err)
- }
- data, err := os.ReadFile(path)
- if err != nil {
- if os.IsNotExist(err) {
- data = []byte("{}")
- } else {
- return fmt.Errorf("failed to read config file: %w", err)
- }
- }
-
- newValue := string(data)
- for key, value := range kv {
- newValue, err = sjson.Set(newValue, key, value)
- if err != nil {
- return fmt.Errorf("failed to set config field %s: %w", key, err)
+ err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
+ v := string(data)
+ for key, value := range kv {
+ var sErr error
+ v, sErr = sjson.Set(v, key, value)
+ if sErr != nil {
+ return nil, fmt.Errorf("failed to set config field %s: %w", key, sErr)
+ }
}
- }
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
- return fmt.Errorf("failed to create config directory %q: %w", path, err)
- }
- if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
+ return []byte(v), nil
+ })
+ if err != nil {
+ return err
}
// Auto-reload to keep in-memory state fresh after config edits.
@@ -176,28 +250,20 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
// RemoveConfigField removes a key from the config file for the given scope.
// After a successful write, it automatically reloads config to keep in-memory
// state fresh.
+//
+// The write is protected by an in-process mutex and a cross-process flock.
func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
- path, err := s.configPath(scope)
- if err != nil {
- return fmt.Errorf("%s: %w", key, err)
- }
- data, err := os.ReadFile(path)
- if err != nil {
- return fmt.Errorf("failed to read config file: %w", err)
- }
-
- newValue, err := sjson.Delete(string(data), key)
+ err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
+ v, sErr := sjson.Delete(string(data), key)
+ if sErr != nil {
+ return nil, fmt.Errorf("failed to delete config field %s: %w", key, sErr)
+ }
+ return []byte(v), nil
+ })
if err != nil {
- return fmt.Errorf("failed to delete config field %s: %w", key, err)
- }
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
- return fmt.Errorf("failed to create config directory %q: %w", path, err)
- }
- if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
- return fmt.Errorf("failed to write config file: %w", err)
+ return err
}
- // Auto-reload to keep in-memory state fresh after config edits.
if err := s.autoReload(context.Background()); err != nil {
slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
}
@@ -300,12 +366,14 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
}
// RefreshOAuthToken refreshes the OAuth token for the given provider.
-// Before making an external refresh request, it checks the config file on
-// disk to see if another Crush session has already refreshed the token. If
-// a newer, unexpired token is found, it is used instead of refreshing. If
-// the exchange fails (e.g. because another session already rotated the
-// refresh token), the disk is re-checked to recover the other session's
-// token.
+//
+// The config lock is held only around disk access, never across the
+// network call: the pre-check (reading the config file to see if another
+// process already refreshed) happens under the config lock, then the
+// HTTP exchange runs without any lock held, and finally the result is
+// persisted via SetConfigFields (which acquires the lock internally).
+// If the exchange fails — e.g. because another process already rotated
+// the refresh token — the disk is re-checked under lock to recover the
+// other process's token.
func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
providerConfig, exists := s.config.Providers.Get(providerID)
if !exists {
@@ -316,16 +384,22 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return fmt.Errorf("provider %s does not have an OAuth token", providerID)
}
- // Check if another session refreshed the token recently by reading
- // the current token from the config file on disk.
- newToken, err := s.loadTokenFromDisk(scope, providerID)
- if err != nil {
- slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
- } else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
- slog.Info("Using token refreshed by another session", "provider", providerID)
- return s.applyToken(providerConfig, newToken, providerID)
+ // Phase 1: Pre-check under lock — did another process already refresh?
+ release, lockErr := s.lockConfig(scope)
+ if lockErr != nil {
+ slog.Warn("Failed to lock config for pre-check, proceeding anyway", "provider", providerID, "error", lockErr)
+ } else {
+ diskToken, err := s.loadTokenFromDisk(scope, providerID)
+ release()
+ if err != nil {
+ slog.Warn("Failed to read token from config file", "provider", providerID, "error", err)
+ } else if diskToken != nil && !diskToken.IsExpired() && diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+ slog.Info("Using token refreshed by another session", "provider", providerID)
+ return s.applyToken(providerConfig, diskToken, providerID)
+ }
}
+ // Phase 2: HTTP exchange — no lock held.
var refreshedToken *oauth.Token
var refreshErr error
switch providerID {
@@ -337,15 +411,19 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
}
if refreshErr != nil {
- // The exchange may have failed because another session already
- // rotated the refresh token. Re-read the config file and use the
- // other session's token if available.
- if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil &&
- diskToken != nil &&
- !diskToken.IsExpired() &&
- diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
- slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
- return s.applyToken(providerConfig, diskToken, providerID)
+ // Phase 3: Fallback — re-check disk under lock. The exchange may
+ // have failed because another process already rotated the refresh
+ // token.
+ if release, lockErr := s.lockConfig(scope); lockErr == nil {
+ diskToken, diskErr := s.loadTokenFromDisk(scope, providerID)
+ release()
+ if diskErr == nil &&
+ diskToken != nil &&
+ !diskToken.IsExpired() &&
+ diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+ slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
+ return s.applyToken(providerConfig, diskToken, providerID)
+ }
}
return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
}
@@ -3,6 +3,7 @@ package config
import (
"context"
"errors"
+ "fmt"
"os"
"path/filepath"
"testing"
@@ -11,6 +12,7 @@ import (
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
func TestConfigStore_ConfigPath_GlobalAlwaysWorks(t *testing.T) {
@@ -729,3 +731,58 @@ func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) {
require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken)
require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken)
}
+
+// TestConfigStore_SetConfigFields_concurrent fires many SetConfigFields
+// calls at one store from separate goroutines, each writing its own set
+// of keys, and then checks that every key is present in the file with
+// the value its writer set — i.e. that no writer's read-modify-write
+// cycle clobbered another's.
+func TestConfigStore_SetConfigFields_concurrent(t *testing.T) {
+	t.Parallel()
+
+	const (
+		numGoroutines    = 20
+		fieldsPerRoutine = 5
+	)
+	keyOf := func(g, f int) string { return fmt.Sprintf("goroutine_%d_field_%d", g, f) }
+	valueOf := func(g, f int) string { return fmt.Sprintf("value_%d_%d", g, f) }
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+	require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+	require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0o600))
+
+	store := &ConfigStore{
+		workingDir:     dir,
+		globalDataPath: configPath,
+		config: &Config{
+			Providers: csync.NewMap[string, ProviderConfig](),
+			Models:    make(map[SelectedModelType]SelectedModel),
+		},
+	}
+
+	// One buffered slot per writer so no goroutine blocks on send.
+	results := make(chan error, numGoroutines)
+	for g := range numGoroutines {
+		go func(id int) {
+			fields := make(map[string]any, fieldsPerRoutine)
+			for f := range fieldsPerRoutine {
+				fields[keyOf(id, f)] = valueOf(id, f)
+			}
+			results <- store.SetConfigFields(ScopeGlobal, fields)
+		}(g)
+	}
+	for range numGoroutines {
+		require.NoError(t, <-results)
+	}
+
+	// Every goroutine's fields must be on disk with the value it wrote.
+	raw, err := os.ReadFile(configPath)
+	require.NoError(t, err)
+	onDisk := string(raw)
+
+	for g := range numGoroutines {
+		for f := range fieldsPerRoutine {
+			key := keyOf(g, f)
+			got := gjson.Get(onDisk, key)
+			require.True(t, got.Exists(), "key %s should exist", key)
+			require.Equal(t, valueOf(g, f), got.String(), "key %s should have the correct value", key)
+		}
+	}
+}
@@ -6,6 +6,7 @@ import (
"path/filepath"
"testing"
+ "github.com/charmbracelet/crush/internal/lock"
"github.com/stretchr/testify/require"
)
@@ -65,7 +66,7 @@ func TestConnect_FailsWhenDataDirLocked(t *testing.T) {
dataDir := t.TempDir()
lockPath := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err, "expected to take the data-dir lock for the first time")
t.Cleanup(release)
@@ -82,7 +83,7 @@ func TestConnect_SucceedsAfterContenderReleases(t *testing.T) {
dataDir := t.TempDir()
lockPath := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err)
_, err = Connect(context.Background(), dataDir, WithDataDirLock(true))
@@ -110,16 +111,16 @@ func TestConnect_LockReleasedOnFinalRelease(t *testing.T) {
require.NoError(t, conn.PingContext(context.Background()))
// Holding the in-process entry must keep the OS lock held so a
- // "second process" (simulated by a fresh tryFileLock call) is
+ // "second process" (simulated by a fresh lock.TryFile call) is
// rejected.
- _, lockErr := tryFileLock(lockPath)
+ _, lockErr := lock.TryFile(lockPath)
require.Error(t, lockErr)
- require.True(t, errors.Is(lockErr, errLockContended), "expected contended lock while pool entry is live")
+ require.True(t, errors.Is(lockErr, lock.ErrContended), "expected contended lock while pool entry is live")
require.NoError(t, Release(dataDir))
// After the final release the lock is free again.
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err, "expected lock to be released after final Release")
release()
}
@@ -143,8 +144,8 @@ func TestConnect_SharedPoolDoesNotReacquireLock(t *testing.T) {
// Drop one reference; lock must still be held.
require.NoError(t, Release(dataDir))
- _, lockErr := tryFileLock(lockPath)
- require.ErrorIs(t, lockErr, errLockContended)
+ _, lockErr := lock.TryFile(lockPath)
+ require.ErrorIs(t, lockErr, lock.ErrContended)
require.NoError(t, Release(dataDir))
}
@@ -157,7 +158,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) {
dataDir := t.TempDir()
lockPath := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err)
t.Cleanup(release)
@@ -171,7 +172,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) {
// TestConnect_DefaultIgnoresContendedLock confirms that without
// WithDataDirLock(true) the lock file is irrelevant: a contender can
-// hold tryFileLock and Connect still succeeds. This pins the
+// hold lock.TryFile and Connect still succeeds. This pins the
// local-mode default to its pre-lock behavior.
func TestConnect_DefaultIgnoresContendedLock(t *testing.T) {
t.Cleanup(ResetPool)
@@ -179,7 +180,7 @@ func TestConnect_DefaultIgnoresContendedLock(t *testing.T) {
dataDir := t.TempDir()
lockPath := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err, "expected to take the data-dir lock for the first time")
t.Cleanup(release)
@@ -199,7 +200,7 @@ func TestConnect_ServerPathFailsWhenDataDirLocked(t *testing.T) {
dataDir := t.TempDir()
lockPath := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(lockPath)
+ release, err := lock.TryFile(lockPath)
require.NoError(t, err, "expected to take the data-dir lock for the first time")
t.Cleanup(release)
@@ -10,6 +10,7 @@ import (
"strconv"
"time"
+ "github.com/charmbracelet/crush/internal/lock"
"github.com/charmbracelet/crush/internal/version"
)
@@ -53,9 +54,9 @@ func acquireDataDirLock(dataDir string) (*dataDirLock, error) {
}
path := filepath.Join(dataDir, dataDirLockFile)
- release, err := tryFileLock(path)
+ release, err := lock.TryFile(path)
if err != nil {
- if errors.Is(err, errLockContended) {
+ if errors.Is(err, lock.ErrContended) {
return nil, contendedLockError(dataDir, path)
}
return nil, fmt.Errorf("failed to lock data directory %q: %w", dataDir, err)
@@ -1,45 +0,0 @@
-//go:build !windows
-
-package db
-
-import (
- "errors"
- "fmt"
- "os"
-
- "golang.org/x/sys/unix"
-)
-
-// errLockContended is returned by tryFileLock when the lock is already
-// held by another open file description (typically another process).
-var errLockContended = errors.New("file lock is held by another process")
-
-// tryFileLock takes an exclusive non-blocking BSD flock on path,
-// creating the file if necessary. On success it returns a release
-// function that drops the lock and closes the descriptor. When the
-// lock is contended it returns errLockContended.
-//
-// BSD flock is advisory and per-open-file-description, so it does not
-// interfere with the byte-range locks SQLite itself uses on the same
-// file's siblings (crush.db, crush.db-wal, crush.db-shm). The lock is
-// also released automatically by the kernel when the file descriptor
-// is closed, including on process crash, so we do not need any
-// explicit stale-lock recovery.
-func tryFileLock(path string) (func(), error) {
- f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
- if err != nil {
- return nil, fmt.Errorf("open lock file: %w", err)
- }
- if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil {
- _ = f.Close()
- if errors.Is(err, unix.EWOULDBLOCK) {
- return nil, errLockContended
- }
- return nil, fmt.Errorf("flock: %w", err)
- }
- return func() {
- // Closing the descriptor releases the flock atomically.
- _ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
- _ = f.Close()
- }, nil
-}
@@ -1,46 +0,0 @@
-//go:build windows
-
-package db
-
-import (
- "errors"
- "fmt"
- "math"
- "os"
-
- "golang.org/x/sys/windows"
-)
-
-// errLockContended is returned by tryFileLock when the lock is held
-// by another process.
-var errLockContended = errors.New("file lock is held by another process")
-
-// tryFileLock takes an exclusive non-blocking lock on path via
-// LockFileEx. On success it returns a release function that unlocks
-// and closes the descriptor.
-//
-// The flags combine LOCKFILE_EXCLUSIVE_LOCK with LOCKFILE_FAIL_IMMEDIATELY
-// to mirror the BSD LOCK_EX|LOCK_NB semantics used on POSIX. The lock
-// is released when the file handle closes, including on process exit,
-// which gives us automatic stale-lock recovery without any bookkeeping.
-func tryFileLock(path string) (func(), error) {
- f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
- if err != nil {
- return nil, fmt.Errorf("open lock file: %w", err)
- }
- h := windows.Handle(f.Fd())
- ol := new(windows.Overlapped)
- flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
- if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
- _ = f.Close()
- if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) {
- return nil, errLockContended
- }
- return nil, fmt.Errorf("LockFileEx: %w", err)
- }
- return func() {
- ol := new(windows.Overlapped)
- _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
- _ = f.Close()
- }, nil
-}
@@ -0,0 +1,75 @@
+// Package lock provides cross-process advisory file locking.
+//
+// File acquires an exclusive lock on the file at path, blocking until
+// the lock is acquired or the context is cancelled (or its deadline
+// elapses). TryFile does the same but returns ErrContended immediately
+// if the lock is already held. In both cases the returned release
+// function drops the lock and closes the underlying file descriptor.
+//
+// The lock is released automatically by the kernel on process
+// termination (including crash), so no stale-lock recovery is needed.
+//
+// The lock file at path is created if it does not exist. It is never
+// unlinked — flock is keyed by inode, not path, and unlinking could
+// create a window where two processes lock different inodes at the
+// same path.
+//
+// This is the canonical file-locking helper for Crush. Callers should
+// prefer it over rolling their own platform-specific code.
+package lock
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+)
+
+// ErrContended is returned by TryFile when the lock is already held,
+// typically by another process. The package tests also expect it when
+// the same process calls TryFile a second time on a path it already
+// holds, so callers should not assume the holder is a different process.
+var ErrContended = errors.New("file lock is held by another process")
+
+// File acquires an exclusive advisory lock on the file at path, blocking
+// until the lock is acquired or ctx is cancelled. The file is created
+// with mode 0o600 if it does not exist. It returns a release function
+// that drops the lock and closes the underlying file descriptor.
+//
+// Pass a context with a deadline (e.g. context.WithTimeout) to bound the
+// wait. Pass context.Background() to block indefinitely. When ctx ends
+// first, the returned error wraps ctx.Err().
+func File(ctx context.Context, path string) (func(), error) {
+	return withLockFile(path, func(f *os.File) (func(), error) {
+		return lockFile(ctx, f)
+	})
+}
+
+// TryFile is like File but returns ErrContended immediately if the lock
+// is already held. Use this when you want to fail fast rather than wait.
+func TryFile(path string) (func(), error) {
+	return withLockFile(path, tryLockFile)
+}
+
+// withLockFile opens the lock file at path (creating it if necessary),
+// runs acquire on it, and on success returns a release function that
+// first drops the lock and then closes the file. If opening or acquiring
+// fails, no descriptor is left open and the error is returned.
+func withLockFile(path string, acquire func(*os.File) (func(), error)) (func(), error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open lock file %q: %w", path, err)
+	}
+
+	unlock, err := acquire(f)
+	if err != nil {
+		// The acquisition error is the one the caller needs; a failure
+		// to close the unused descriptor would add nothing to it.
+		_ = f.Close()
+		return nil, err
+	}
+
+	return func() {
+		unlock()
+		// Best-effort: the lock is already dropped and release has no
+		// way to report a close error.
+		_ = f.Close()
+	}, nil
+}
@@ -0,0 +1,168 @@
+package lock
+
+import (
+ "context"
+ "errors"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// TestTryFile_AcquiresWhenFree checks that TryFile succeeds on a path
+// nothing else has locked and hands back a non-nil release function.
+func TestTryFile_AcquiresWhenFree(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	unlock, err := TryFile(lockPath)
+	require.NoError(t, err)
+	require.NotNil(t, unlock)
+
+	unlock()
+}
+
+// TestTryFile_ReturnsErrContendedWhenHeld checks that a second TryFile
+// on a path whose lock is still held fails with ErrContended.
+func TestTryFile_ReturnsErrContendedWhenHeld(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	held, err := TryFile(lockPath)
+	require.NoError(t, err)
+	t.Cleanup(held)
+
+	_, secondErr := TryFile(lockPath)
+	require.ErrorIs(t, secondErr, ErrContended)
+}
+
+// TestTryFile_ReacquireAfterRelease checks that once a lock has been
+// released, a fresh TryFile on the same path succeeds.
+func TestTryFile_ReacquireAfterRelease(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+
+	first, err := TryFile(lockPath)
+	require.NoError(t, err)
+	first()
+
+	second, err := TryFile(lockPath)
+	require.NoError(t, err)
+	t.Cleanup(second)
+}
+
+// TestFile_AcquiresWhenFree checks that the blocking File call succeeds
+// on an uncontended path.
+func TestFile_AcquiresWhenFree(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	unlock, err := File(context.Background(), lockPath)
+	require.NoError(t, err)
+	t.Cleanup(unlock)
+}
+
+// TestFile_BlocksThenSucceeds checks that File waits for a held lock and
+// acquires it once the holder lets go, provided that happens before the
+// context deadline.
+func TestFile_BlocksThenSucceeds(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	first, err := TryFile(lockPath)
+	require.NoError(t, err)
+
+	// Drop the first lock after a short delay, on a separate goroutine,
+	// so the blocking acquirer below can finish well inside its timeout.
+	time.AfterFunc(150*time.Millisecond, first)
+
+	waitCtx, stop := context.WithTimeout(context.Background(), 2*time.Second)
+	defer stop()
+
+	second, err := File(waitCtx, lockPath)
+	require.NoError(t, err, "should acquire after first releases")
+	second()
+}
+
+// TestFile_RespectsContextDeadline checks that File gives up with
+// context.DeadlineExceeded, and does so promptly, when the lock stays
+// held past the context deadline.
+func TestFile_RespectsContextDeadline(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	held, err := TryFile(lockPath)
+	require.NoError(t, err)
+	t.Cleanup(held)
+
+	waitCtx, stop := context.WithTimeout(context.Background(), 200*time.Millisecond)
+	defer stop()
+
+	began := time.Now()
+	_, err = File(waitCtx, lockPath)
+	waited := time.Since(began)
+
+	require.Error(t, err)
+	require.True(t, errors.Is(err, context.DeadlineExceeded), "expected deadline exceeded, got %v", err)
+	require.Less(t, waited, 1*time.Second, "should return promptly after deadline")
+}
+
+// TestFile_RespectsContextCancellation checks that a File call waiting
+// on a held lock returns context.Canceled once its context is cancelled.
+func TestFile_RespectsContextCancellation(t *testing.T) {
+	t.Parallel()
+
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+	held, err := TryFile(lockPath)
+	require.NoError(t, err)
+	t.Cleanup(held)
+
+	waitCtx, cancel := context.WithCancel(context.Background())
+	result := make(chan error, 1)
+	go func() {
+		_, acquireErr := File(waitCtx, lockPath)
+		result <- acquireErr
+	}()
+
+	// Pause briefly before cancelling; presumably this lets the acquirer
+	// reach its wait first, though the timing is not guaranteed.
+	time.Sleep(50 * time.Millisecond)
+	cancel()
+
+	select {
+	case <-time.After(2 * time.Second):
+		t.Fatal("File did not return after context cancellation")
+	case got := <-result:
+		require.ErrorIs(t, got, context.Canceled)
+	}
+}
+
+// TestFile_ConcurrentAcquirers starts several blocking acquirers on the
+// same path and checks that all of them eventually get the lock and
+// that no two of them ever hold it at the same moment.
+func TestFile_ConcurrentAcquirers(t *testing.T) {
+	t.Parallel()
+	lockPath := filepath.Join(t.TempDir(), "test.lock")
+
+	const n = 5
+	var (
+		mu       sync.Mutex
+		inside   int // acquirers currently holding the file lock
+		maxSeen  int // highest value inside ever reached
+		finished int // acquirers that completed their critical section
+	)
+	enter := func() {
+		mu.Lock()
+		defer mu.Unlock()
+		inside++
+		maxSeen = max(maxSeen, inside)
+	}
+	leave := func() {
+		mu.Lock()
+		defer mu.Unlock()
+		inside--
+		finished++
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(n)
+	for range n {
+		go func() {
+			defer wg.Done()
+
+			waitCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
+			defer stop()
+
+			unlock, err := File(waitCtx, lockPath)
+			if err != nil {
+				t.Errorf("acquire failed: %v", err)
+				return
+			}
+
+			enter()
+			// Stay inside long enough for the others to contend.
+			time.Sleep(20 * time.Millisecond)
+			leave()
+			unlock()
+		}()
+	}
+	wg.Wait()
+
+	require.Equal(t, n, finished)
+	require.Equal(t, 1, maxSeen, "lock must be mutually exclusive")
+}
@@ -0,0 +1,45 @@
+//go:build !windows
+
+package lock
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "time"
+
+ "golang.org/x/sys/unix"
+)
+
+// retrySleep is how long the blocking File path pauses between
+// non-blocking flock attempts: short enough that a freed lock is picked
+// up quickly, long enough that waiting does not spin a CPU.
+const retrySleep = 100 * time.Millisecond
+
+// lockFile polls for an exclusive flock on f, retrying every retrySleep
+// while the lock is contended. It returns a function that unlocks f (it
+// does not close it), or an error wrapping ctx.Err() if ctx ends during
+// a wait. Any flock failure other than contention is returned at once.
+func lockFile(ctx context.Context, f *os.File) (func(), error) {
+	unlock := func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }
+	for {
+		switch err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); {
+		case err == nil:
+			return unlock, nil
+		case !errors.Is(err, unix.EWOULDBLOCK):
+			return nil, fmt.Errorf("flock: %w", err)
+		}
+		// Contended: pause before the next attempt unless ctx ends first.
+		select {
+		case <-time.After(retrySleep):
+		case <-ctx.Done():
+			return nil, fmt.Errorf("acquire lock: %w", ctx.Err())
+		}
+	}
+}
+
+// tryLockFile makes a single non-blocking attempt at an exclusive flock
+// on f. It returns a function that unlocks f (without closing it) on
+// success, ErrContended when the lock is already held, and a wrapped
+// flock error otherwise.
+func tryLockFile(f *os.File) (func(), error) {
+	err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB)
+	switch {
+	case err == nil:
+		return func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }, nil
+	case errors.Is(err, unix.EWOULDBLOCK):
+		return nil, ErrContended
+	default:
+		return nil, fmt.Errorf("flock: %w", err)
+	}
+}
@@ -0,0 +1,57 @@
+//go:build windows
+
+package lock
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "math"
+ "os"
+ "time"
+
+ "golang.org/x/sys/windows"
+)
+
+// retrySleep is how long the blocking File path pauses between
+// non-blocking LockFileEx attempts.
+const retrySleep = 100 * time.Millisecond
+
+// lockFile polls for an exclusive LockFileEx lock on f, retrying every
+// retrySleep while the lock is contended. It returns a function that
+// unlocks f (it does not close it), or an error wrapping ctx.Err() if
+// ctx ends during a wait. Any LockFileEx failure other than
+// ERROR_LOCK_VIOLATION or ERROR_IO_PENDING is returned at once.
+func lockFile(ctx context.Context, f *os.File) (func(), error) {
+	handle := windows.Handle(f.Fd())
+	flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
+	unlock := func() {
+		_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, new(windows.Overlapped))
+	}
+
+	for {
+		err := windows.LockFileEx(handle, flags, 0, math.MaxUint32, math.MaxUint32, new(windows.Overlapped))
+		switch {
+		case err == nil:
+			return unlock, nil
+		case errors.Is(err, windows.ERROR_LOCK_VIOLATION), errors.Is(err, windows.ERROR_IO_PENDING):
+			// Contended: wait below, then try again.
+		default:
+			return nil, fmt.Errorf("LockFileEx: %w", err)
+		}
+
+		select {
+		case <-time.After(retrySleep):
+		case <-ctx.Done():
+			return nil, fmt.Errorf("acquire lock: %w", ctx.Err())
+		}
+	}
+}
+
+// tryLockFile makes a single LockFileEx attempt on f with the
+// exclusive and fail-immediately flags set. It returns a function that
+// unlocks f (without closing it) on success, ErrContended when the call
+// fails with ERROR_LOCK_VIOLATION or ERROR_IO_PENDING, and a wrapped
+// LockFileEx error otherwise.
+func tryLockFile(f *os.File) (func(), error) {
+	flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
+	err := windows.LockFileEx(windows.Handle(f.Fd()), flags, 0, math.MaxUint32, math.MaxUint32, new(windows.Overlapped))
+	switch {
+	case err == nil:
+		return func() {
+			_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, new(windows.Overlapped))
+		}, nil
+	case errors.Is(err, windows.ERROR_LOCK_VIOLATION), errors.Is(err, windows.ERROR_IO_PENDING):
+		return nil, ErrContended
+	default:
+		return nil, fmt.Errorf("LockFileEx: %w", err)
+	}
+}