refactor(lock): add canonical internal/lock package, migrate db and cmd callers

Kieran Klukas created

Introduce internal/lock as the single file-locking helper for Crush,
replacing the duplicated flock implementations in internal/db and
internal/cmd.

- lock.File(ctx, path) blocks until ctx cancels or lock acquired
  (context-aware deadline replaces hardcoded 5s timeout)
- lock.TryFile(path) returns ErrContended immediately if held
- Migrate internal/db.acquireDataDirLock to lock.TryFile
- Migrate internal/cmd.spawnAndWaitReady to lock.File(cmd.Context(), ...)
  (now respects Ctrl+C during spawn wait)
- Delete internal/db/datadirlock_{unix,windows}.go
- Delete internal/cmd/spawnlock_{other,windows}.go
- Add tests: acquire/release, contention, deadline, cancellation,
  concurrent mutual exclusion

Config store uses lock.File with a 5s context deadline, preserving
existing behavior while gaining context cancellation support.

Change summary

internal/cmd/root.go               |   3 
internal/cmd/spawnlock_other.go    |  28 ----
internal/cmd/spawnlock_windows.go  |  32 ----
internal/config/store.go           | 208 ++++++++++++++++++++++----------
internal/config/store_test.go      |  57 ++++++++
internal/db/connect_test.go        |  25 ++-
internal/db/datadirlock.go         |   5 
internal/db/datadirlock_unix.go    |  45 ------
internal/db/datadirlock_windows.go |  46 -------
internal/lock/lock.go              |  75 +++++++++++
internal/lock/lock_test.go         | 168 +++++++++++++++++++++++++
internal/lock/lock_unix.go         |  45 ++++++
internal/lock/lock_windows.go      |  57 ++++++++
13 files changed, 563 insertions(+), 231 deletions(-)

Detailed changes

internal/cmd/root.go 🔗

@@ -29,6 +29,7 @@ import (
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/event"
+	"github.com/charmbracelet/crush/internal/lock"
 	crushlog "github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/projects"
 	"github.com/charmbracelet/crush/internal/proto"
@@ -466,7 +467,7 @@ func spawnAndWaitReady(cmd *cobra.Command, hostURL *url.URL) error {
 	if err != nil {
 		return err
 	}
-	release, err := acquireSpawnLock(filepath.Join(chDir, "start.lock"))
+	release, err := lock.File(cmd.Context(), filepath.Join(chDir, "start.lock"))
 	if err != nil {
 		// If the lock itself is unavailable, fall back to the
 		// unsynchronized path rather than blocking the user.

internal/cmd/spawnlock_other.go 🔗

@@ -1,28 +0,0 @@
-//go:build !windows
-
-package cmd
-
-import (
-	"fmt"
-	"os"
-
-	"golang.org/x/sys/unix"
-)
-
-// acquireSpawnLock takes an exclusive flock on the given file (creating
-// it if necessary) and returns a release function that unlocks and
-// closes the file. Blocks until the lock is acquired.
-func acquireSpawnLock(path string) (func(), error) {
-	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
-	if err != nil {
-		return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
-	}
-	if err := unix.Flock(int(f.Fd()), unix.LOCK_EX); err != nil {
-		_ = f.Close()
-		return nil, fmt.Errorf("flock spawn lock %q: %v", path, err)
-	}
-	return func() {
-		_ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
-		_ = f.Close()
-	}, nil
-}

internal/cmd/spawnlock_windows.go 🔗

@@ -1,32 +0,0 @@
-//go:build windows
-
-package cmd
-
-import (
-	"fmt"
-	"math"
-	"os"
-
-	"golang.org/x/sys/windows"
-)
-
-// acquireSpawnLock takes an exclusive lock on the given file (creating
-// it if necessary) using LockFileEx, and returns a release function
-// that unlocks and closes the file. Blocks until the lock is acquired.
-func acquireSpawnLock(path string) (func(), error) {
-	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
-	if err != nil {
-		return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
-	}
-	h := windows.Handle(f.Fd())
-	ol := new(windows.Overlapped)
-	if err := windows.LockFileEx(h, windows.LOCKFILE_EXCLUSIVE_LOCK, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
-		_ = f.Close()
-		return nil, fmt.Errorf("LockFileEx spawn lock %q: %v", path, err)
-	}
-	return func() {
-		ol := new(windows.Overlapped)
-		_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
-		_ = f.Close()
-	}, nil
-}

internal/config/store.go 🔗

@@ -8,10 +8,13 @@ import (
 	"os"
 	"path/filepath"
 	"slices"
+	"sync"
+	"time"
 
 	"charm.land/catwalk/pkg/catwalk"
 	hyperp "github.com/charmbracelet/crush/internal/agent/hyper"
 	"github.com/charmbracelet/crush/internal/env"
+	"github.com/charmbracelet/crush/internal/lock"
 	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/charmbracelet/crush/internal/oauth/copilot"
 	"github.com/charmbracelet/crush/internal/oauth/hyper"
@@ -19,6 +22,11 @@ import (
 	"github.com/tidwall/sjson"
 )
 
+// configLockDeadline bounds how long lockConfig waits for the
+// cross-process flock before giving up. A few seconds is plenty for
+// honest contention; longer suggests something is wedged.
+const configLockDeadline = 5 * time.Second
+
 // fileSnapshot captures metadata about a config file at a point in time.
 type fileSnapshot struct {
 	Path    string
@@ -37,6 +45,11 @@ type RuntimeOverrides struct {
 // ConfigStore is the single entry point for all config access. It owns the
 // pure-data Config, runtime state (working directory, resolver, known
 // providers), and persistence to both global and workspace config files.
+//
+// mu serialises all config file mutations (SetConfigFields,
+// RemoveConfigField, RefreshOAuthToken) to prevent both in-process
+// goroutine races and, together with the shared lock.File, cross-process
+// races on the config file.
 type ConfigStore struct {
 	config             *Config
 	workingDir         string
@@ -50,6 +63,8 @@ type ConfigStore struct {
 	snapshots          map[string]fileSnapshot // path -> snapshot at last capture
 	autoReloadDisabled bool                    // set during load/reload to prevent re-entrancy
 	reloadInProgress   bool                    // set during reload to avoid disk writes mid-reload
+
+	mu sync.Mutex
 }
 
 // Config returns the pure-data config struct (read-only after load).
@@ -95,6 +110,74 @@ func (s *ConfigStore) LoadedPaths() []string {
 	return slices.Clone(s.loadedPaths)
 }
 
+// lockConfig acquires both the in-process mutex and a cross-process flock
+// on the config file for the given scope. Callers that need to do I/O
+// between reading and writing (e.g. an HTTP token exchange) must use
+// lockConfig explicitly rather than atomicWrite.
+//
+// The returned release function drops both locks. Callers must call it
+// as soon as the file access is complete — no I/O should be performed
+// while the lock is held.
+func (s *ConfigStore) lockConfig(scope Scope) (func(), error) {
+	s.mu.Lock()
+	path, err := s.configPath(scope)
+	if err != nil {
+		s.mu.Unlock()
+		return nil, err
+	}
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		s.mu.Unlock()
+		return nil, fmt.Errorf("create config directory: %w", err)
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), configLockDeadline)
+	defer cancel()
+	release, err := lock.File(ctx, path+".lock")
+	if err != nil {
+		s.mu.Unlock()
+		return nil, fmt.Errorf("acquire config lock: %w", err)
+	}
+	return func() {
+		release()
+		s.mu.Unlock()
+	}, nil
+}
+
+// atomicWrite handles the lock-read-transform-write-unlock cycle for
+// config file mutations. The fn callback receives the current file
+// contents (raw bytes, or {} if the file is missing) and must return the
+// new contents. fn must be pure — no I/O, no network calls.
+func (s *ConfigStore) atomicWrite(scope Scope, fn func(current []byte) ([]byte, error)) error {
+	unlock, err := s.lockConfig(scope)
+	if err != nil {
+		return err
+	}
+	defer unlock()
+
+	path, err := s.configPath(scope)
+	if err != nil {
+		return err
+	}
+
+	data, err := os.ReadFile(path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			data = []byte("{}")
+		} else {
+			return fmt.Errorf("read config file: %w", err)
+		}
+	}
+
+	newData, err := fn(data)
+	if err != nil {
+		return err
+	}
+
+	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+		return fmt.Errorf("create config directory: %w", err)
+	}
+	return atomicWriteFile(path, newData, 0o600)
+}
+
 // configPath returns the file path for the given scope.
 func (s *ConfigStore) configPath(scope Scope) (string, error) {
 	switch scope {
@@ -134,32 +217,23 @@ func (s *ConfigStore) SetConfigField(scope Scope, key string, value any) error {
 // config to keep in-memory state fresh. This is preferred over multiple
 // SetConfigField calls when writing several fields atomically to avoid
 // intermediate reloads with partial state.
+//
+// The write is protected by an in-process mutex and a cross-process flock
+// to prevent races between concurrent writers in different processes.
 func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
-	path, err := s.configPath(scope)
-	if err != nil {
-		return fmt.Errorf("%v: %w", kv, err)
-	}
-	data, err := os.ReadFile(path)
-	if err != nil {
-		if os.IsNotExist(err) {
-			data = []byte("{}")
-		} else {
-			return fmt.Errorf("failed to read config file: %w", err)
-		}
-	}
-
-	newValue := string(data)
-	for key, value := range kv {
-		newValue, err = sjson.Set(newValue, key, value)
-		if err != nil {
-			return fmt.Errorf("failed to set config field %s: %w", key, err)
+	err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
+		v := string(data)
+		for key, value := range kv {
+			var sErr error
+			v, sErr = sjson.Set(v, key, value)
+			if sErr != nil {
+				return nil, fmt.Errorf("failed to set config field %s: %w", key, sErr)
+			}
 		}
-	}
-	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
-		return fmt.Errorf("failed to create config directory %q: %w", path, err)
-	}
-	if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
-		return fmt.Errorf("failed to write config file: %w", err)
+		return []byte(v), nil
+	})
+	if err != nil {
+		return err
 	}
 
 	// Auto-reload to keep in-memory state fresh after config edits.
@@ -176,28 +250,20 @@ func (s *ConfigStore) SetConfigFields(scope Scope, kv map[string]any) error {
 // RemoveConfigField removes a key from the config file for the given scope.
 // After a successful write, it automatically reloads config to keep in-memory
 // state fresh.
+//
+// The write is protected by an in-process mutex and a cross-process flock.
 func (s *ConfigStore) RemoveConfigField(scope Scope, key string) error {
-	path, err := s.configPath(scope)
-	if err != nil {
-		return fmt.Errorf("%s: %w", key, err)
-	}
-	data, err := os.ReadFile(path)
-	if err != nil {
-		return fmt.Errorf("failed to read config file: %w", err)
-	}
-
-	newValue, err := sjson.Delete(string(data), key)
+	err := s.atomicWrite(scope, func(data []byte) ([]byte, error) {
+		v, sErr := sjson.Delete(string(data), key)
+		if sErr != nil {
+			return nil, fmt.Errorf("failed to delete config field %s: %w", key, sErr)
+		}
+		return []byte(v), nil
+	})
 	if err != nil {
-		return fmt.Errorf("failed to delete config field %s: %w", key, err)
-	}
-	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
-		return fmt.Errorf("failed to create config directory %q: %w", path, err)
-	}
-	if err := atomicWriteFile(path, []byte(newValue), 0o600); err != nil {
-		return fmt.Errorf("failed to write config file: %w", err)
+		return err
 	}
 
-	// Auto-reload to keep in-memory state fresh after config edits.
 	if err := s.autoReload(context.Background()); err != nil {
 		slog.Warn("Config file updated but failed to reload in-memory state", "error", err)
 	}
@@ -300,12 +366,14 @@ func (s *ConfigStore) SetProviderAPIKey(scope Scope, providerID string, apiKey a
 }
 
 // RefreshOAuthToken refreshes the OAuth token for the given provider.
-// Before making an external refresh request, it checks the config file on
-// disk to see if another Crush session has already refreshed the token. If
-// a newer, unexpired token is found, it is used instead of refreshing. If
-// the exchange fails (e.g. because another session already rotated the
-// refresh token), the disk is re-checked to recover the other session's
-// token.
+//
+// It uses two-phase locking: the pre-check (reading the config file to
+// see if another process already refreshed) happens under the config
+// lock, then the HTTP exchange runs without any lock held, and finally
+// the result is persisted via SetConfigFields (which acquires the lock
+// internally). If the exchange fails — e.g. because another process
+// already rotated the refresh token — the disk is re-checked under lock
+// to recover the other process's token.
 func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, providerID string) error {
 	providerConfig, exists := s.config.Providers.Get(providerID)
 	if !exists {
@@ -316,16 +384,22 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 		return fmt.Errorf("provider %s does not have an OAuth token", providerID)
 	}
 
-	// Check if another session refreshed the token recently by reading
-	// the current token from the config file on disk.
-	newToken, err := s.loadTokenFromDisk(scope, providerID)
-	if err != nil {
-		slog.Warn("Failed to read token from config file, proceeding with refresh", "provider", providerID, "error", err)
-	} else if newToken != nil && !newToken.IsExpired() && newToken.AccessToken != providerConfig.OAuthToken.AccessToken {
-		slog.Info("Using token refreshed by another session", "provider", providerID)
-		return s.applyToken(providerConfig, newToken, providerID)
+	// Phase 1: Pre-check under lock — did another process already refresh?
+	release, lockErr := s.lockConfig(scope)
+	if lockErr != nil {
+		slog.Warn("Failed to lock config for pre-check, proceeding anyway", "provider", providerID, "error", lockErr)
+	} else {
+		diskToken, err := s.loadTokenFromDisk(scope, providerID)
+		release()
+		if err != nil {
+			slog.Warn("Failed to read token from config file", "provider", providerID, "error", err)
+		} else if diskToken != nil && !diskToken.IsExpired() && diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+			slog.Info("Using token refreshed by another session", "provider", providerID)
+			return s.applyToken(providerConfig, diskToken, providerID)
+		}
 	}
 
+	// Phase 2: HTTP exchange — no lock held.
 	var refreshedToken *oauth.Token
 	var refreshErr error
 	switch providerID {
@@ -337,15 +411,19 @@ func (s *ConfigStore) RefreshOAuthToken(ctx context.Context, scope Scope, provid
 		return fmt.Errorf("OAuth refresh not supported for provider %s", providerID)
 	}
 	if refreshErr != nil {
-		// The exchange may have failed because another session already
-		// rotated the refresh token. Re-read the config file and use the
-		// other session's token if available.
-		if diskToken, diskErr := s.loadTokenFromDisk(scope, providerID); diskErr == nil &&
-			diskToken != nil &&
-			!diskToken.IsExpired() &&
-			diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
-			slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
-			return s.applyToken(providerConfig, diskToken, providerID)
+		// Phase 3: Fallback — re-check disk under lock. The exchange may
+		// have failed because another process already rotated the refresh
+		// token.
+		if release, lockErr := s.lockConfig(scope); lockErr == nil {
+			diskToken, diskErr := s.loadTokenFromDisk(scope, providerID)
+			release()
+			if diskErr == nil &&
+				diskToken != nil &&
+				!diskToken.IsExpired() &&
+				diskToken.AccessToken != providerConfig.OAuthToken.AccessToken {
+				slog.Info("Using token refreshed by another session after exchange failure", "provider", providerID)
+				return s.applyToken(providerConfig, diskToken, providerID)
+			}
 		}
 		return fmt.Errorf("failed to refresh OAuth token for provider %s: %w", providerID, refreshErr)
 	}

internal/config/store_test.go 🔗

@@ -3,6 +3,7 @@ package config
 import (
 	"context"
 	"errors"
+	"fmt"
 	"os"
 	"path/filepath"
 	"testing"
@@ -11,6 +12,7 @@ import (
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/oauth"
 	"github.com/stretchr/testify/require"
+	"github.com/tidwall/gjson"
 )
 
 func TestConfigStore_ConfigPath_GlobalAlwaysWorks(t *testing.T) {
@@ -729,3 +731,58 @@ func TestRefreshOAuthToken_UsesDiskTokenWhenDifferent(t *testing.T) {
 	require.Equal(t, "newer-access-token", updatedConfig.OAuthToken.AccessToken)
 	require.Equal(t, "refresh-abc", updatedConfig.OAuthToken.RefreshToken)
 }
+
+// TestConfigStore_SetConfigFields_concurrent verifies that concurrent writes do
+// not lose data when protected by the in-process mutex and cross-process flock.
+func TestConfigStore_SetConfigFields_concurrent(t *testing.T) {
+	t.Parallel()
+
+	dir := t.TempDir()
+	configPath := filepath.Join(dir, "crush.json")
+	require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o755))
+	require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0o600))
+
+	store := &ConfigStore{
+		config: &Config{
+			Providers: csync.NewMap[string, ProviderConfig](),
+			Models:    make(map[SelectedModelType]SelectedModel),
+		},
+		globalDataPath: configPath,
+		workingDir:     dir,
+	}
+
+	const (
+		numGoroutines    = 20
+		fieldsPerRoutine = 5
+	)
+
+	errs := make(chan error, numGoroutines)
+	for i := 0; i < numGoroutines; i++ {
+		go func(id int) {
+			kv := make(map[string]any, fieldsPerRoutine)
+			for j := 0; j < fieldsPerRoutine; j++ {
+				key := fmt.Sprintf("goroutine_%d_field_%d", id, j)
+				kv[key] = fmt.Sprintf("value_%d_%d", id, j)
+			}
+			errs <- store.SetConfigFields(ScopeGlobal, kv)
+		}(i)
+	}
+
+	for i := 0; i < numGoroutines; i++ {
+		require.NoError(t, <-errs)
+	}
+
+	// Verify all fields are present in the config file.
+	data, err := os.ReadFile(configPath)
+	require.NoError(t, err)
+
+	for i := 0; i < numGoroutines; i++ {
+		for j := 0; j < fieldsPerRoutine; j++ {
+			key := fmt.Sprintf("goroutine_%d_field_%d", i, j)
+			expectedValue := fmt.Sprintf("value_%d_%d", i, j)
+			result := gjson.Get(string(data), key)
+			require.True(t, result.Exists(), "key %s should exist", key)
+			require.Equal(t, expectedValue, result.String(), "key %s should have the correct value", key)
+		}
+	}
+}

internal/db/connect_test.go 🔗

@@ -6,6 +6,7 @@ import (
 	"path/filepath"
 	"testing"
 
+	"github.com/charmbracelet/crush/internal/lock"
 	"github.com/stretchr/testify/require"
 )
 
@@ -65,7 +66,7 @@ func TestConnect_FailsWhenDataDirLocked(t *testing.T) {
 	dataDir := t.TempDir()
 	lockPath := filepath.Join(dataDir, dataDirLockFile)
 
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err, "expected to take the data-dir lock for the first time")
 	t.Cleanup(release)
 
@@ -82,7 +83,7 @@ func TestConnect_SucceedsAfterContenderReleases(t *testing.T) {
 	dataDir := t.TempDir()
 	lockPath := filepath.Join(dataDir, dataDirLockFile)
 
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err)
 
 	_, err = Connect(context.Background(), dataDir, WithDataDirLock(true))
@@ -110,16 +111,16 @@ func TestConnect_LockReleasedOnFinalRelease(t *testing.T) {
 	require.NoError(t, conn.PingContext(context.Background()))
 
 	// Holding the in-process entry must keep the OS lock held so a
-	// "second process" (simulated by a fresh tryFileLock call) is
+	// "second process" (simulated by a fresh lock.TryFile call) is
 	// rejected.
-	_, lockErr := tryFileLock(lockPath)
+	_, lockErr := lock.TryFile(lockPath)
 	require.Error(t, lockErr)
-	require.True(t, errors.Is(lockErr, errLockContended), "expected contended lock while pool entry is live")
+	require.True(t, errors.Is(lockErr, lock.ErrContended), "expected contended lock while pool entry is live")
 
 	require.NoError(t, Release(dataDir))
 
 	// After the final release the lock is free again.
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err, "expected lock to be released after final Release")
 	release()
 }
@@ -143,8 +144,8 @@ func TestConnect_SharedPoolDoesNotReacquireLock(t *testing.T) {
 
 	// Drop one reference; lock must still be held.
 	require.NoError(t, Release(dataDir))
-	_, lockErr := tryFileLock(lockPath)
-	require.ErrorIs(t, lockErr, errLockContended)
+	_, lockErr := lock.TryFile(lockPath)
+	require.ErrorIs(t, lockErr, lock.ErrContended)
 
 	require.NoError(t, Release(dataDir))
 }
@@ -157,7 +158,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) {
 	dataDir := t.TempDir()
 	lockPath := filepath.Join(dataDir, dataDirLockFile)
 
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err)
 	t.Cleanup(release)
 
@@ -171,7 +172,7 @@ func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) {
 
 // TestConnect_DefaultIgnoresContendedLock confirms that without
 // WithDataDirLock(true) the lock file is irrelevant: a contender can
-// hold tryFileLock and Connect still succeeds. This pins the
+// hold lock.TryFile and Connect still succeeds. This pins the
 // local-mode default to its pre-lock behavior.
 func TestConnect_DefaultIgnoresContendedLock(t *testing.T) {
 	t.Cleanup(ResetPool)
@@ -179,7 +180,7 @@ func TestConnect_DefaultIgnoresContendedLock(t *testing.T) {
 	dataDir := t.TempDir()
 	lockPath := filepath.Join(dataDir, dataDirLockFile)
 
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err, "expected to take the data-dir lock for the first time")
 	t.Cleanup(release)
 
@@ -199,7 +200,7 @@ func TestConnect_ServerPathFailsWhenDataDirLocked(t *testing.T) {
 	dataDir := t.TempDir()
 	lockPath := filepath.Join(dataDir, dataDirLockFile)
 
-	release, err := tryFileLock(lockPath)
+	release, err := lock.TryFile(lockPath)
 	require.NoError(t, err, "expected to take the data-dir lock for the first time")
 	t.Cleanup(release)
 

internal/db/datadirlock.go 🔗

@@ -10,6 +10,7 @@ import (
 	"strconv"
 	"time"
 
+	"github.com/charmbracelet/crush/internal/lock"
 	"github.com/charmbracelet/crush/internal/version"
 )
 
@@ -53,9 +54,9 @@ func acquireDataDirLock(dataDir string) (*dataDirLock, error) {
 	}
 
 	path := filepath.Join(dataDir, dataDirLockFile)
-	release, err := tryFileLock(path)
+	release, err := lock.TryFile(path)
 	if err != nil {
-		if errors.Is(err, errLockContended) {
+		if errors.Is(err, lock.ErrContended) {
 			return nil, contendedLockError(dataDir, path)
 		}
 		return nil, fmt.Errorf("failed to lock data directory %q: %w", dataDir, err)

internal/db/datadirlock_unix.go 🔗

@@ -1,45 +0,0 @@
-//go:build !windows
-
-package db
-
-import (
-	"errors"
-	"fmt"
-	"os"
-
-	"golang.org/x/sys/unix"
-)
-
-// errLockContended is returned by tryFileLock when the lock is already
-// held by another open file description (typically another process).
-var errLockContended = errors.New("file lock is held by another process")
-
-// tryFileLock takes an exclusive non-blocking BSD flock on path,
-// creating the file if necessary. On success it returns a release
-// function that drops the lock and closes the descriptor. When the
-// lock is contended it returns errLockContended.
-//
-// BSD flock is advisory and per-open-file-description, so it does not
-// interfere with the byte-range locks SQLite itself uses on the same
-// file's siblings (crush.db, crush.db-wal, crush.db-shm). The lock is
-// also released automatically by the kernel when the file descriptor
-// is closed, including on process crash, so we do not need any
-// explicit stale-lock recovery.
-func tryFileLock(path string) (func(), error) {
-	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
-	if err != nil {
-		return nil, fmt.Errorf("open lock file: %w", err)
-	}
-	if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil {
-		_ = f.Close()
-		if errors.Is(err, unix.EWOULDBLOCK) {
-			return nil, errLockContended
-		}
-		return nil, fmt.Errorf("flock: %w", err)
-	}
-	return func() {
-		// Closing the descriptor releases the flock atomically.
-		_ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
-		_ = f.Close()
-	}, nil
-}

internal/db/datadirlock_windows.go 🔗

@@ -1,46 +0,0 @@
-//go:build windows
-
-package db
-
-import (
-	"errors"
-	"fmt"
-	"math"
-	"os"
-
-	"golang.org/x/sys/windows"
-)
-
-// errLockContended is returned by tryFileLock when the lock is held
-// by another process.
-var errLockContended = errors.New("file lock is held by another process")
-
-// tryFileLock takes an exclusive non-blocking lock on path via
-// LockFileEx. On success it returns a release function that unlocks
-// and closes the descriptor.
-//
-// The flags combine LOCKFILE_EXCLUSIVE_LOCK with LOCKFILE_FAIL_IMMEDIATELY
-// to mirror the BSD LOCK_EX|LOCK_NB semantics used on POSIX. The lock
-// is released when the file handle closes, including on process exit,
-// which gives us automatic stale-lock recovery without any bookkeeping.
-func tryFileLock(path string) (func(), error) {
-	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
-	if err != nil {
-		return nil, fmt.Errorf("open lock file: %w", err)
-	}
-	h := windows.Handle(f.Fd())
-	ol := new(windows.Overlapped)
-	flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
-	if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
-		_ = f.Close()
-		if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) {
-			return nil, errLockContended
-		}
-		return nil, fmt.Errorf("LockFileEx: %w", err)
-	}
-	return func() {
-		ol := new(windows.Overlapped)
-		_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
-		_ = f.Close()
-	}, nil
-}

internal/lock/lock.go 🔗

@@ -0,0 +1,75 @@
+// Package lock provides cross-process advisory file locking.
+//
+// File acquires an exclusive lock on the file at path, blocking until
+// the context is cancelled (or its deadline elapses). TryFile does the
+// same but returns ErrContended immediately if the lock is already
+// held. In both cases the returned release function drops the lock and
+// closes the underlying file descriptor.
+//
+// The lock is released automatically by the kernel on process
+// termination (including crash), so no stale-lock recovery is needed.
+//
+// The lock file at path is created if it does not exist. It is never
+// unlinked — flock is keyed by inode, not path, and unlinking could
+// create a window where two processes lock different inodes at the
+// same path.
+//
+// This is the canonical file-locking helper for Crush. Callers should
+// prefer it over rolling their own platform-specific code.
+package lock
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+)
+
+// ErrContended is returned by TryFile when the lock is already held by
+// another process.
+var ErrContended = errors.New("file lock is held by another process")
+
+// File acquires an exclusive advisory lock on the file at path, blocking
+// until the lock is acquired or ctx is cancelled. It returns a release
+// function that drops the lock and closes the underlying file descriptor.
+//
+// Pass a context with a deadline (e.g. context.WithTimeout) to bound the
+// wait. Pass context.Background() to block indefinitely.
+func File(ctx context.Context, path string) (func(), error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open lock file %q: %w", path, err)
+	}
+
+	release, err := lockFile(ctx, f)
+	if err != nil {
+		f.Close()
+		return nil, err
+	}
+
+	return func() {
+		release()
+		f.Close()
+	}, nil
+}
+
+// TryFile is like File but returns ErrContended immediately if the lock
+// is already held by another process. Use this when you want to fail
+// fast rather than wait.
+func TryFile(path string) (func(), error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open lock file %q: %w", path, err)
+	}
+
+	release, err := tryLockFile(f)
+	if err != nil {
+		f.Close()
+		return nil, err
+	}
+
+	return func() {
+		release()
+		f.Close()
+	}, nil
+}

internal/lock/lock_test.go 🔗

@@ -0,0 +1,168 @@
+package lock
+
+import (
+	"context"
+	"errors"
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestTryFile_AcquiresWhenFree(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+	require.NotNil(t, release)
+	release()
+}
+
+func TestTryFile_ReturnsErrContendedWhenHeld(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+	t.Cleanup(release)
+
+	_, err = TryFile(path)
+	require.ErrorIs(t, err, ErrContended)
+}
+
+func TestTryFile_ReacquireAfterRelease(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+	release()
+
+	release2, err := TryFile(path)
+	require.NoError(t, err)
+	t.Cleanup(release2)
+}
+
+func TestFile_AcquiresWhenFree(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := File(context.Background(), path)
+	require.NoError(t, err)
+	t.Cleanup(release)
+}
+
+func TestFile_BlocksThenSucceeds(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+
+	// Release the lock after a short delay so the blocking acquirer
+	// can complete within the test timeout.
+	go func() {
+		time.Sleep(150 * time.Millisecond)
+		release()
+	}()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	release2, err := File(ctx, path)
+	require.NoError(t, err, "should acquire after first releases")
+	release2()
+}
+
+func TestFile_RespectsContextDeadline(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+	t.Cleanup(release)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+	defer cancel()
+	start := time.Now()
+	_, err = File(ctx, path)
+	elapsed := time.Since(start)
+	require.Error(t, err)
+	require.True(t, errors.Is(err, context.DeadlineExceeded), "expected deadline exceeded, got %v", err)
+	require.Less(t, elapsed, 1*time.Second, "should return promptly after deadline")
+}
+
+func TestFile_RespectsContextCancellation(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	release, err := TryFile(path)
+	require.NoError(t, err)
+	t.Cleanup(release)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	done := make(chan error, 1)
+	go func() {
+		_, err := File(ctx, path)
+		done <- err
+	}()
+
+	time.Sleep(50 * time.Millisecond)
+	cancel()
+
+	select {
+	case err := <-done:
+		require.ErrorIs(t, err, context.Canceled)
+	case <-time.After(2 * time.Second):
+		t.Fatal("File did not return after context cancellation")
+	}
+}
+
+// TestFile_ConcurrentAcquirers verifies that multiple blocking acquirers
+// queue up correctly: each gets the lock in turn, exactly one at a time.
+func TestFile_ConcurrentAcquirers(t *testing.T) {
+	t.Parallel()
+	path := filepath.Join(t.TempDir(), "test.lock")
+
+	const n = 5
+	var (
+		mu       sync.Mutex
+		inside   int
+		maxSeen  int
+		finished int
+	)
+	var wg sync.WaitGroup
+	wg.Add(n)
+	for range n {
+		go func() {
+			defer wg.Done()
+			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			release, err := File(ctx, path)
+			if err != nil {
+				t.Errorf("acquire failed: %v", err)
+				return
+			}
+			mu.Lock()
+			inside++
+			if inside > maxSeen {
+				maxSeen = inside
+			}
+			mu.Unlock()
+
+			time.Sleep(20 * time.Millisecond)
+
+			mu.Lock()
+			inside--
+			finished++
+			mu.Unlock()
+			release()
+		}()
+	}
+	wg.Wait()
+
+	require.Equal(t, n, finished)
+	require.Equal(t, 1, maxSeen, "lock must be mutually exclusive")
+}

internal/lock/lock_unix.go 🔗

@@ -0,0 +1,45 @@
+//go:build !windows
+
+package lock
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"time"
+
+	"golang.org/x/sys/unix"
+)
+
+// retrySleep is the interval between non-blocking flock retries in the
+// blocking File path. Small enough that contention resolution feels
+// snappy; large enough that we don't burn a CPU spinning.
+const retrySleep = 100 * time.Millisecond
+
+func lockFile(ctx context.Context, f *os.File) (func(), error) {
+	for {
+		err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB)
+		if err == nil {
+			return func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }, nil
+		}
+		if !errors.Is(err, unix.EWOULDBLOCK) {
+			return nil, fmt.Errorf("flock: %w", err)
+		}
+		select {
+		case <-ctx.Done():
+			return nil, fmt.Errorf("acquire lock: %w", ctx.Err())
+		case <-time.After(retrySleep):
+		}
+	}
+}
+
+func tryLockFile(f *os.File) (func(), error) {
+	if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil {
+		if errors.Is(err, unix.EWOULDBLOCK) {
+			return nil, ErrContended
+		}
+		return nil, fmt.Errorf("flock: %w", err)
+	}
+	return func() { _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) }, nil
+}

internal/lock/lock_windows.go 🔗

@@ -0,0 +1,57 @@
+//go:build windows
+
+package lock
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math"
+	"os"
+	"time"
+
+	"golang.org/x/sys/windows"
+)
+
+// retrySleep is the interval between non-blocking lock retries in the
+// blocking File path.
+const retrySleep = 100 * time.Millisecond
+
+func lockFile(ctx context.Context, f *os.File) (func(), error) {
+	h := windows.Handle(f.Fd())
+	for {
+		ol := new(windows.Overlapped)
+		flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
+		err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol)
+		if err == nil {
+			return func() {
+				ol := new(windows.Overlapped)
+				_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
+			}, nil
+		}
+		if !errors.Is(err, windows.ERROR_LOCK_VIOLATION) && !errors.Is(err, windows.ERROR_IO_PENDING) {
+			return nil, fmt.Errorf("LockFileEx: %w", err)
+		}
+		select {
+		case <-ctx.Done():
+			return nil, fmt.Errorf("acquire lock: %w", ctx.Err())
+		case <-time.After(retrySleep):
+		}
+	}
+}
+
+func tryLockFile(f *os.File) (func(), error) {
+	h := windows.Handle(f.Fd())
+	ol := new(windows.Overlapped)
+	flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY)
+	if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
+		if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) {
+			return nil, ErrContended
+		}
+		return nil, fmt.Errorf("LockFileEx: %w", err)
+	}
+	return func() {
+		ol := new(windows.Overlapped)
+		_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
+	}, nil
+}