diff --git a/internal/client/client.go b/internal/client/client.go
index e97a0570e42e7176debf3e6ca4d91760483a197d..42dd0243b234bc1c9bfc4801311a728d027eb240 100644
--- a/internal/client/client.go
+++ b/internal/client/client.go
@@ -121,6 +121,14 @@ func (c *Client) ShutdownServer(ctx context.Context) error {
 	return nil
 }
 
+// Dial opens a connection to the server using the same scheme-aware
+// logic the client uses for its HTTP transport. Exposed so callers can
+// reuse the dialer when they need to construct sibling HTTP transports
+// (e.g. a readiness probe in the CLI).
+//
+// It is a thin exported wrapper: ctx, network and address are handed to
+// the unexported dialer unchanged, and the connection and error it
+// produces are returned as-is. The signature has the shape of
+// http.Transport.DialContext.
+func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
+	return c.dialer(ctx, network, address)
+}
+
 func (c *Client) dialer(ctx context.Context, network, address string) (net.Conn, error) {
 	d := net.Dialer{
 		Timeout: 30 * time.Second,
diff --git a/internal/cmd/clientserverrace/race_test.go b/internal/cmd/clientserverrace/race_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..461e799f2b7047ecbdb427b202d492f68d54e0e1
--- /dev/null
+++ b/internal/cmd/clientserverrace/race_test.go
@@ -0,0 +1,361 @@
+// Package clientserverrace_test is a regression test for the
+// CRUSH_CLIENT_SERVER=1 socket-init race documented in
+// docs/notes/2026-05-11-client-server-socket-init-race.md (item F5).
+//
+// It lives in its own directory so it can build even if other test
+// files in internal/cmd are temporarily broken — this test only needs
+// the binary, not the cmd package.
+package clientserverrace_test
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"net/http"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"runtime"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+// readinessErrSubstr is the user-visible error string emitted by
+// ensureServer when it gives up waiting for the server socket /
+// readiness probe (internal/cmd/root.go). Seeing this in any client's
+// output means the race fired.
+const readinessErrSubstr = "failed to initialize crush server"
+
+// numClients is intentionally larger than the typical CPU count to
+// ensure the spawn lock + readiness probe are exercised under
+// contention.
+const numClients = 8
+
+// clientTimeout bounds each child invocation. It only needs to be long
+// enough for the spawn-and-readiness phase to complete on a cold cache;
+// after that, the client may legitimately keep running (e.g.
+// subscribing to server events) and we'll cancel it. The race we care
+// about is observable strictly within ensureServer.
+const clientTimeout = 15 * time.Second
+
+// TestClientServerSpawnRace builds the crush binary, launches numClients
+// copies of it at the same instant against one not-yet-existing unix
+// socket, and fails if any client prints readinessErrSubstr, if any
+// client could not be set up or started, or if no /v1/health probe
+// succeeded while the clients were running.
+func TestClientServerSpawnRace(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping client/server spawn race test in -short mode")
+	}
+	// The race and its fix are unix-socket specific. Windows uses
+	// named pipes via a different code path; not covered here.
+	if runtime.GOOS == "windows" {
+		t.Skip("skipping unix-socket specific race test on windows")
+	}
+	if _, err := exec.LookPath("go"); err != nil {
+		t.Skip("skipping: 'go' not available on PATH")
+	}
+
+	repoRoot := repoRootFromTest(t)
+	bin := buildCrushBinary(t, repoRoot)
+
+	// Use /tmp directly so the unix socket path stays under the
+	// 104-char sockaddr_un limit on darwin. t.TempDir() can return a
+	// path inside /var/folders/... that is too long.
+	runDir, err := os.MkdirTemp("/tmp", "crush-race-")
+	if err != nil {
+		t.Fatalf("mkdtemp: %v", err)
+	}
+	t.Cleanup(func() { _ = os.RemoveAll(runDir) })
+
+	socketPath := filepath.Join(runDir, "crush.sock")
+	host := "unix://" + socketPath
+
+	// Fresh, isolated XDG/HOME so we don't touch the user's real
+	// state or any other test's cache. These all live under runDir
+	// so cleanup is one RemoveAll.
+	cacheHome := filepath.Join(runDir, "cache")
+	dataHome := filepath.Join(runDir, "data")
+	configHome := filepath.Join(runDir, "config")
+	homeDir := filepath.Join(runDir, "home")
+	for _, d := range []string{cacheHome, dataHome, configHome, homeDir} {
+		if err := os.MkdirAll(d, 0o700); err != nil {
+			t.Fatalf("mkdir %s: %v", d, err)
+		}
+	}
+
+	env := append(
+		os.Environ(),
+		"CRUSH_CLIENT_SERVER=1",
+		"XDG_CACHE_HOME="+cacheHome,
+		"XDG_DATA_HOME="+dataHome,
+		"XDG_CONFIG_HOME="+configHome,
+		"HOME="+homeDir,
+		// Belt-and-suspenders: if anything tries to talk to a real
+		// provider, fail loudly rather than make a network call.
+		"CRUSH_DISABLE_PROVIDER_AUTO_UPDATE=1",
+	)
+
+	// Make sure no server is up before we start.
+	if _, err := os.Stat(socketPath); err == nil {
+		t.Fatalf("socket %s exists before test started", socketPath)
+	}
+
+	// Always try to shut down any server we spawned, regardless of
+	// outcome. Cleanups run last-in-first-out, so this runs before the
+	// RemoveAll registered above.
+	t.Cleanup(func() { shutdownServer(t, socketPath) })
+
+	// result carries one client's captured output. setupErr is set when
+	// the client could not be prepared or started at all; in that case
+	// stdout/stderr say nothing about the race and the run is invalid.
+	type result struct {
+		idx      int
+		stdout   string
+		stderr   string
+		setupErr error
+	}
+	results := make(chan result, numClients)
+
+	// Probe /v1/health concurrently while the clients are still
+	// running. The server self-shuts-down when the last workspace is
+	// released (internal/backend/backend.go:DeleteWorkspace), so once
+	// all clients exit cleanly the socket may legitimately be gone —
+	// asserting the socket post-hoc would race with that documented
+	// self-shutdown. Instead we require that during the parallel run
+	// at least one /v1/health probe got a 2xx, which proves the
+	// spawn-and-readiness path actually produced a live server.
+	var sawHealthy atomic.Bool
+	probeDone := make(chan struct{})
+	stopProbe := make(chan struct{})
+
+	var wg sync.WaitGroup
+	start := make(chan struct{})
+
+	go func() {
+		defer close(probeDone)
+		<-start
+		deadline := time.Now().Add(clientTimeout)
+		for time.Now().Before(deadline) {
+			select {
+			case <-stopProbe:
+				return
+			default:
+			}
+			if err := pingHealth(socketPath); err == nil {
+				sawHealthy.Store(true)
+				return
+			}
+			select {
+			case <-stopProbe:
+				return
+			case <-time.After(50 * time.Millisecond):
+			}
+		}
+	}()
+
+	for i := range numClients {
+		wg.Add(1)
+		go func(i int) {
+			defer wg.Done()
+
+			// Each client gets its own working directory so the
+			// per-client workspace registration paths don't collide
+			// in confusing ways.
+			cwd := filepath.Join(runDir, fmt.Sprintf("ws-%d", i))
+			if err := os.MkdirAll(cwd, 0o700); err != nil {
+				results <- result{idx: i, setupErr: fmt.Errorf("mkdir cwd: %w", err)}
+				return
+			}
+
+			ctx, cancel := context.WithTimeout(context.Background(), clientTimeout)
+			defer cancel()
+
+			// `crush run` exercises connectToServer (which is where
+			// the readiness race lives). On a fresh sandbox the
+			// command may legitimately keep running past the race
+			// (e.g. waiting on event subscriptions); the context
+			// timeout above bounds that. We assert race outcomes
+			// purely from output, not exit code.
+			c := exec.CommandContext(
+				ctx, bin,
+				"--host", host,
+				"--cwd", cwd,
+				"run", "hi",
+			)
+			c.Env = env
+			var outBuf, errBuf strings.Builder
+			c.Stdout = &outBuf
+			c.Stderr = &errBuf
+			// Stdout/Stderr are not *os.File, so Wait also waits for
+			// the pipe-copying goroutines. WaitDelay bounds that wait
+			// in case a descendant of the client keeps the write ends
+			// open after the client itself has exited or been killed.
+			c.WaitDelay = 2 * time.Second
+
+			<-start
+			// Start and Wait are split so that "the binary could not
+			// be launched" is reported as a setup failure instead of
+			// being indistinguishable from a quiet, successful run.
+			if err := c.Start(); err != nil {
+				results <- result{idx: i, setupErr: fmt.Errorf("start client: %w", err)}
+				return
+			}
+			_ = c.Wait() // exit status is deliberately ignored, see above
+			results <- result{
+				idx:    i,
+				stdout: outBuf.String(),
+				stderr: errBuf.String(),
+			}
+		}(i)
+	}
+
+	close(start) // release all clients as simultaneously as possible
+	wg.Wait()
+	close(results)
+	close(stopProbe)
+	<-probeDone
+
+	var raceFailures, setupFailures []string
+	for r := range results {
+		if r.setupErr != nil {
+			setupFailures = append(setupFailures,
+				fmt.Sprintf("client %d: %v", r.idx, r.setupErr))
+			continue
+		}
+		if strings.Contains(r.stderr, readinessErrSubstr) ||
+			strings.Contains(r.stdout, readinessErrSubstr) {
+			raceFailures = append(raceFailures, fmt.Sprintf(
+				"client %d: readiness error in output\nstderr:\n%s\nstdout:\n%s",
+				r.idx, r.stderr, r.stdout,
+			))
+		}
+	}
+
+	if len(raceFailures) > 0 {
+		t.Fatalf(
+			"client/server spawn race regressed: %d/%d clients failed\n\n%s",
+			len(raceFailures), numClients,
+			strings.Join(raceFailures, "\n---\n"),
+		)
+	}
+
+	// A client that never ran cannot have exercised the race, so a run
+	// with setup failures must not be reported as a pass.
+	if len(setupFailures) > 0 {
+		t.Fatalf(
+			"%d/%d clients could not be set up or started:\n%s",
+			len(setupFailures), numClients,
+			strings.Join(setupFailures, "\n"),
+		)
+	}
+
+	// Positive sanity check: at some point during the parallel run a
+	// /v1/health probe must have succeeded. We deliberately do *not*
+	// stat the socket post-hoc: when every client returns cleanly
+	// (e.g. exits early because no providers are configured), the
+	// last DeleteWorkspace triggers the server's self-shutdown and
+	// the socket disappears. That is correct behaviour, not a race
+	// regression.
+	if !sawHealthy.Load() {
+		t.Fatalf("no /v1/health probe succeeded on %s while %d clients were running",
+			socketPath, numClients)
+	}
+}
+
+// pingHealth issues a single GET /v1/health over the unix socket and
+// requires a 2xx response.
+func pingHealth(socketPath string) error {
+	// One budget is applied twice, as the client timeout and as the
+	// request context deadline.
+	const limit = 2 * time.Second
+
+	// The network and address the transport asks for are ignored; every
+	// connection goes to the unix socket.
+	dialUnix := func(ctx context.Context, _, _ string) (net.Conn, error) {
+		return (&net.Dialer{}).DialContext(ctx, "unix", socketPath)
+	}
+	transport := &http.Transport{DialContext: dialUnix}
+	defer transport.CloseIdleConnections()
+	httpClient := &http.Client{Transport: transport, Timeout: limit}
+
+	ctx, cancel := context.WithTimeout(context.Background(), limit)
+	defer cancel()
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
+		"http://crush.local/v1/health", nil)
+	if err != nil {
+		return err
+	}
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if code := resp.StatusCode; code >= 200 && code < 300 {
+		return nil
+	}
+	return fmt.Errorf("health check returned %s", resp.Status)
+}
+
+// repoRootFromTest locates the repository root by starting at the
+// current working directory and climbing one parent at a time until a
+// directory containing go.mod is found. The test is aborted if the
+// filesystem root is reached first. Searching for go.mod, rather than
+// climbing a fixed number of levels, keeps this working if the test
+// directory is moved.
+func repoRootFromTest(t *testing.T) string {
+	t.Helper()
+	origin, err := os.Getwd()
+	if err != nil {
+		t.Fatalf("getwd: %v", err)
+	}
+	for cur := origin; ; {
+		_, statErr := os.Stat(filepath.Join(cur, "go.mod"))
+		if statErr == nil {
+			return cur
+		}
+		up := filepath.Dir(cur)
+		if up == cur {
+			// filepath.Dir returned its input: nothing is left to climb.
+			t.Fatalf("could not find go.mod walking up from %s", origin)
+		}
+		cur = up
+	}
+}
+
+// buildCrushBinary builds the crush binary once at the start of the
+// test and returns the absolute path. Subsequent t.Cleanup removes
+// the built artefact.
+func buildCrushBinary(t *testing.T, repoRoot string) string {
+	t.Helper()
+
+	binDir, err := os.MkdirTemp("", "crush-race-bin-")
+	if err != nil {
+		t.Fatalf("mkdtemp bin: %v", err)
+	}
+	t.Cleanup(func() { _ = os.RemoveAll(binDir) })
+
+	binPath := filepath.Join(binDir, "crush")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+
+	cmd := exec.CommandContext(ctx, "go", "build", "-o", binPath, ".")
+	cmd.Dir = repoRoot
+	// Match the project's standard build flags. CGO_ENABLED=0 keeps
+	// the binary statically linked and avoids surprising the test on
+	// hosts without a C toolchain.
+	cmd.Env = append(os.Environ(), "CGO_ENABLED=0")
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("go build crush: %v\n%s", err, out)
+	}
+	return binPath
+}
+
+// shutdownServer best-effort terminates any crush server bound to
+// socketPath by POSTing to /v1/control. We don't import the project's
+// own client package to keep this test free of internal API churn.
+func shutdownServer(t *testing.T, socketPath string) {
+	t.Helper()
+	if _, err := os.Stat(socketPath); err != nil {
+		return
+	}
+
+	tr := &http.Transport{
+		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+			var d net.Dialer
+			return d.DialContext(ctx, "unix", socketPath)
+		},
+	}
+	hc := &http.Client{Transport: tr, Timeout: 5 * time.Second}
+	defer tr.CloseIdleConnections()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	body := strings.NewReader(`{"command":"shutdown"}`)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
+		"http://crush.local/v1/control", body)
+	if err != nil {
+		t.Logf("shutdown: build request: %v", err)
+		return
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := hc.Do(req)
+	if err != nil {
+		// Server may already be gone — not an error.
+		t.Logf("shutdown: %v (probably already exited)", err)
+		return
+	}
+	_ = resp.Body.Close()
+
+	// Wait briefly for the socket to disappear so the next test
+	// using the same path doesn't race.
+	deadline := time.Now().Add(5 * time.Second)
+	for time.Now().Before(deadline) {
+		if _, err := os.Stat(socketPath); err != nil {
+			return
+		}
+		time.Sleep(50 * time.Millisecond)
+	}
+}
diff --git a/internal/cmd/root.go b/internal/cmd/root.go
index 8cc9303b312943c75715057c70c47f47fb50e65b..b9ba4e8241100383492d20c3e64aefe5c0a5b286 100644
--- a/internal/cmd/root.go
+++ b/internal/cmd/root.go
@@ -9,6 +9,8 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
+	"net"
+	"net/http"
 	"net/url"
 	"os"
 	"os/exec"
@@ -357,22 +359,7 @@ func connectToServer(cmd *cobra.Command) (*client.Client, *proto.Workspace, func
 
 	ws, err := c.CreateWorkspace(ctx, wsReq)
 	if err != nil {
-		// The server socket may exist before the HTTP handler is ready.
-		// Retry a few times with a short backoff.
-		for range 5 {
-			select {
-			case <-ctx.Done():
-				return nil, nil, nil, ctx.Err()
-			case <-time.After(200 * time.Millisecond):
-			}
-			ws, err = c.CreateWorkspace(ctx, wsReq)
-			if err == nil {
-				break
-			}
-		}
-		if err != nil {
-			return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
-		}
+		return nil, nil, nil, fmt.Errorf("failed to create workspace: %v", err)
 	}
 
 	if shouldEnableMetrics(ws.Config) {
@@ -396,55 +383,239 @@ func ensureServer(cmd *cobra.Command, hostURL *url.URL) error {
 	switch hostURL.Scheme {
 	case "unix", "npipe":
 		needsStart := false
-		if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) {
+		_, statErr := os.Stat(hostURL.Host)
+		switch {
+		case statErr == nil:
+			restarted, err := restartIfStale(cmd, hostURL)
+			if err != nil {
+				slog.Warn("Failed to check server version", "error", err)
+			}
+			needsStart = restarted || err != nil
+		case errors.Is(statErr, fs.ErrNotExist):
 			needsStart = true
-		} else if err == nil {
-			if err := restartIfStale(cmd, hostURL); err != nil {
-				slog.Warn("Failed to check server version, restarting", "error", err)
-				needsStart = true
+		default:
+			slog.Warn("Unexpected error stat'ing server socket, attempting cleanup",
+				"path", hostURL.Host, "error", statErr)
+			if err := os.Remove(hostURL.Host); err != nil && !errors.Is(err, fs.ErrNotExist) {
+				return fmt.Errorf("failed to remove stale server socket %q: %v", hostURL.Host, err)
 			}
+			needsStart = true
 		}
 
 		if needsStart {
-			if err := startDetachedServer(cmd); err != nil {
-				return err
+			if err := spawnAndWaitReady(cmd, hostURL); err != nil {
+				return fmt.Errorf("failed to initialize crush server: %v", err)
 			}
+			return nil
 		}
 
-		var err error
-		for range 10 {
-			_, err = os.Stat(hostURL.Host)
-			if err == nil {
-				break
-			}
-			select {
-			case <-cmd.Context().Done():
-				return cmd.Context().Err()
-			case <-time.After(100 * time.Millisecond):
+		if err := waitForServerReady(cmd.Context(), hostURL); err != nil {
+			return fmt.Errorf("failed to initialize crush server: %v", err)
+		}
+	}
+
+	return nil
+}
+
+// spawnAndWaitReady serializes the spawn-and-wait-for-readiness sequence
+// across concurrent clients via an exclusive flock on
+// $XDG_CACHE_HOME/crush/server-<host>/start.lock, where <host> is the
+// safeHostName of hostURL (see perHostServerDir).
+//
+// After acquiring the lock it re-probes readiness so that a client that
+// blocked while another client was spawning can skip its own spawn and
+// just use the now-running server. The lock is held only for the
+// duration of "spawn + readiness probe" and released before the caller
+// resumes its normal lifetime.
+//
+// If the command context is already done once the lock has been
+// obtained, the context error is returned and no server is spawned.
+func spawnAndWaitReady(cmd *cobra.Command, hostURL *url.URL) error {
+	chDir, err := perHostServerDir(hostURL)
+	if err != nil {
+		return err
+	}
+	release, err := acquireSpawnLock(filepath.Join(chDir, "start.lock"))
+	if err != nil {
+		// If the lock itself is unavailable, fall back to the
+		// unsynchronized path rather than blocking the user.
+		slog.Warn("Failed to acquire spawn lock, proceeding without single-flight", "error", err)
+		if err := startDetachedServer(cmd, hostURL); err != nil {
+			return err
+		}
+		return waitForServerReady(cmd.Context(), hostURL)
+	}
+	defer release()
+
+	// acquireSpawnLock blocks without observing the context, so the
+	// user may have cancelled while we were queued behind another
+	// client. Bail out before spawning: the detached server is not tied
+	// to this context and would otherwise be started for a client that
+	// has already given up.
+	if err := cmd.Context().Err(); err != nil {
+		return err
+	}
+
+	// Another client may have just finished spawning while we were
+	// waiting on the lock; if the server is already responsive, skip
+	// the spawn entirely.
+	probeCtx, cancel := context.WithTimeout(cmd.Context(), 200*time.Millisecond)
+	probeErr := quickHealthProbe(probeCtx, hostURL)
+	cancel()
+	if probeErr == nil {
+		return nil
+	}
+
+	if err := startDetachedServer(cmd, hostURL); err != nil {
+		return err
+	}
+	return waitForServerReady(cmd.Context(), hostURL)
+}
+
+// quickHealthProbe issues a single readiness request with the caller's
+// context and returns nil iff the server is responsive right now.
+func quickHealthProbe(ctx context.Context, hostURL *url.URL) error {
+	httpClient, reqURL, err := readinessHTTPClient(hostURL)
+	if err != nil {
+		return err
+	}
+	// The client and its transport exist only for this call; close any
+	// kept-alive connection on return so it is not left sitting in the
+	// transport's idle pool.
+	defer httpClient.CloseIdleConnections()
+	return probeHealth(ctx, httpClient, reqURL, hostURL)
+}
+
+// perHostServerDir returns (and creates) the cache directory used for
+// per-host server state (logs, start.lock, etc.). The path is derived
+// from the parsed host URL rather than the global flag so the same key
+// is computed regardless of where the host came from.
+func perHostServerDir(hostURL *url.URL) (string, error) {
+	chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeHostName(hostURL))
+	if err := os.MkdirAll(chDir, 0o700); err != nil {
+		return "", fmt.Errorf("failed to create server working directory: %v", err)
+	}
+	return chDir, nil
+}
+
+// safeHostName returns a filesystem-safe identifier for hostURL,
+// suitable for use as a directory name. It mirrors the input shape of
+// the --host flag so client and server compute the same key.
+func safeHostName(hostURL *url.URL) string {
+	return safeNameRegexp.ReplaceAllString(
+		hostURL.Scheme+"://"+hostURL.Host+hostURL.Path, "_")
+}
+
+// serverReadyTimeout returns the total budget for the readiness probe.
+// Overridable via CRUSH_SERVER_READY_TIMEOUT (parsed as a Go duration,
+// e.g. "30s"). An unset variable yields the 10s default; a value that
+// does not parse or is not positive is reported with a warning and the
+// default is used.
+func serverReadyTimeout() time.Duration {
+	const def = 10 * time.Second
+	v := os.Getenv("CRUSH_SERVER_READY_TIMEOUT")
+	if v == "" {
+		return def
+	}
+	d, err := time.ParseDuration(v)
+	if err != nil || d <= 0 {
+		// A bare number such as "30" is not a valid Go duration. Say
+		// so instead of silently running with a different budget than
+		// the one the user asked for.
+		slog.Warn("Ignoring invalid CRUSH_SERVER_READY_TIMEOUT, using default",
+			"value", v, "default", def)
+		return def
+	}
+	return d
+}
+
+// waitForServerReady polls GET /v1/health until the server responds with
+// any 2xx status or the total timeout elapses. Each attempt uses a short
+// per-attempt timeout so a hung listener doesn't burn the whole budget.
+//
+// The HTTP transport is built to mirror how *client.Client dials so the
+// same unix socket / npipe / tcp setups all work uniformly here.
+//
+// It returns nil on the first successful probe, ctx's error if ctx is
+// done, and otherwise a "timed out waiting for server readiness" error
+// that wraps the last probe failure when there was one.
+func waitForServerReady(ctx context.Context, hostURL *url.URL) error {
+	httpClient, reqURL, err := readinessHTTPClient(hostURL)
+	if err != nil {
+		return err
+	}
+	// The client exists only for this polling loop; close whatever
+	// connection the last successful probe left in the idle pool.
+	defer httpClient.CloseIdleConnections()
+
+	const perAttempt = 100 * time.Millisecond
+	deadline := time.Now().Add(serverReadyTimeout())
+
+	var lastErr error
+	for {
+		if err := ctx.Err(); err != nil {
+			return err
+		}
+		if time.Now().After(deadline) {
+			if lastErr != nil {
+				// lastErr on its own is usually the per-attempt
+				// "context deadline exceeded", which does not say
+				// that the overall readiness budget ran out. %w keeps
+				// it reachable through errors.Is / errors.As.
+				return fmt.Errorf("timed out waiting for server readiness: %w", lastErr)
 			}
+			return fmt.Errorf("timed out waiting for server readiness")
 		}
-		if err != nil {
-			return fmt.Errorf("failed to initialize crush server: %v", err)
+
+		attemptCtx, cancel := context.WithTimeout(ctx, perAttempt)
+		err := probeHealth(attemptCtx, httpClient, reqURL, hostURL)
+		cancel()
+		if err == nil {
+			return nil
+		}
+		lastErr = err
+
+		// Pause between attempts, waking early if the caller cancels.
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(perAttempt):
 		}
 	}
+}
 
+// readinessHTTPClient builds an *http.Client whose transport dials the
+// server using the same scheme-aware logic as *client.Client (unix
+// socket, named pipe, or tcp).
+func readinessHTTPClient(hostURL *url.URL) (*http.Client, string, error) {
+	c, err := client.NewClient("", hostURL.Scheme, hostURL.Host)
+	if err != nil {
+		return nil, "", err
+	}
+
+	// unix sockets and named pipes are local IPC endpoints; several
+	// settings below apply only to them.
+	localIPC := hostURL.Scheme == "unix" || hostURL.Scheme == "npipe"
+
+	// http.DefaultTransport is a replaceable package variable of type
+	// http.RoundTripper. Fall back to a zero Transport rather than
+	// panicking when something has swapped in another implementation.
+	base, ok := http.DefaultTransport.(*http.Transport)
+	if !ok {
+		base = &http.Transport{}
+	}
+	tr := base.Clone()
+	tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+		return c.Dial(ctx, network, addr)
+	}
+	if localIPC {
+		tr.DisableCompression = true
+		// The default transport carries Proxy: ProxyFromEnvironment.
+		// With HTTP_PROXY set, the probe to the placeholder host would
+		// be framed as a proxy request even though the bytes go to a
+		// local socket or pipe, so proxying is switched off here.
+		tr.Proxy = nil
+	}
+
+	httpClient := &http.Client{Transport: tr}
+
+	// For unix sockets / named pipes we still need a syntactically valid
+	// HTTP URL; the actual address is resolved by the dialer.
+	host := hostURL.Host
+	if localIPC {
+		host = client.DummyHost
+	}
+	reqURL := (&url.URL{Scheme: "http", Host: host, Path: "/v1/health"}).String()
+	return httpClient, reqURL, nil
+}
+
+// probeHealth issues a single GET to the readiness endpoint and treats
+// any 2xx response as success. For unix and npipe hosts the request's
+// Host header is set to client.DummyHost. The response body is drained
+// and closed before the status is checked. A non-2xx status is reported
+// as an error carrying the status line.
+func probeHealth(ctx context.Context, h *http.Client, reqURL string, hostURL *url.URL) error {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
+	if err != nil {
+		return err
+	}
+	if hostURL.Scheme == "unix" || hostURL.Scheme == "npipe" {
+		req.Host = client.DummyHost
+	}
+	rsp, err := h.Do(req)
+	if err != nil {
+		return err
+	}
+	defer rsp.Body.Close()
+	_, _ = io.Copy(io.Discard, rsp.Body)
+	if rsp.StatusCode < 200 || rsp.StatusCode >= 300 {
+		return fmt.Errorf("server health check failed: %s", rsp.Status)
+	}
 	return nil
 }
 
 // restartIfStale checks whether the running server matches the current
 // client version. When they differ, it sends a shutdown command and
 // removes the stale socket so the caller can start a fresh server.
-func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error {
+//
+// It returns restarted=true when it has shut down a stale server and the
+// caller must spawn a new one.
When the server matches the client version +// (or the check itself fails), restarted is false. +func restartIfStale(cmd *cobra.Command, hostURL *url.URL) (restarted bool, err error) { c, err := client.NewClient("", hostURL.Scheme, hostURL.Host) if err != nil { - return err + return false, err } vi, err := c.VersionInfo(cmd.Context()) if err != nil { - return err + return false, err } if vi.Version == version.Version { - return nil + return false, nil } slog.Info("Server version mismatch, restarting", "server", vi.Version, @@ -458,27 +629,26 @@ func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error { } select { case <-cmd.Context().Done(): - return cmd.Context().Err() + return true, cmd.Context().Err() case <-time.After(100 * time.Millisecond): } } // Force-remove if the socket is still lingering. _ = os.Remove(hostURL.Host) - return nil + return true, nil } var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`) -func startDetachedServer(cmd *cobra.Command) error { +func startDetachedServer(cmd *cobra.Command, hostURL *url.URL) error { exe, err := os.Executable() if err != nil { return fmt.Errorf("failed to get executable path: %v", err) } - safeClientHost := safeNameRegexp.ReplaceAllString(clientHost, "_") - chDir := filepath.Join(config.GlobalCacheDir(), "server-"+safeClientHost) - if err := os.MkdirAll(chDir, 0o700); err != nil { - return fmt.Errorf("failed to create server working directory: %v", err) + chDir, err := perHostServerDir(hostURL) + if err != nil { + return err } cmdArgs := []string{"server"} @@ -486,7 +656,11 @@ func startDetachedServer(cmd *cobra.Command) error { cmdArgs = append(cmdArgs, "--host", clientHost) } - c := exec.CommandContext(cmd.Context(), exe, cmdArgs...) + // Use context.Background() so the parent's context cancellation does not + // kill the spawned server. detachProcess (Setsid on !windows, + // DETACHED_PROCESS on windows) is what truly detaches the child from + // this process's lifetime. 
+ c := exec.CommandContext(context.Background(), exe, cmdArgs...) stdoutPath := filepath.Join(chDir, "stdout.log") stderrPath := filepath.Join(chDir, "stderr.log") detachProcess(c) diff --git a/internal/cmd/server.go b/internal/cmd/server.go index 460d5280e18930c2008db1199aac18a5b281a83d..0033632af2e547711a60dbc8d1314abf393f56a5 100644 --- a/internal/cmd/server.go +++ b/internal/cmd/server.go @@ -42,7 +42,12 @@ var serverCmd = &cobra.Command{ return fmt.Errorf("failed to load configuration: %v", err) } - logFile := filepath.Join(config.GlobalCacheDir(), "server-"+safeNameRegexp.ReplaceAllString(serverHost, "_"), "crush.log") + hostURL, err := server.ParseHostURL(serverHost) + if err != nil { + return fmt.Errorf("invalid server host: %v", err) + } + + logFile := filepath.Join(config.GlobalCacheDir(), "server-"+safeHostName(hostURL), "crush.log") if term.IsTerminal(os.Stderr.Fd()) { crushlog.Setup(logFile, debug, os.Stderr) @@ -50,11 +55,6 @@ var serverCmd = &cobra.Command{ crushlog.Setup(logFile, debug) } - hostURL, err := server.ParseHostURL(serverHost) - if err != nil { - return fmt.Errorf("invalid server host: %v", err) - } - srv := server.NewServer(cfg, hostURL.Scheme, hostURL.Host) srv.SetLogger(slog.Default()) slog.Info("Starting Crush server...", "addr", serverHost) diff --git a/internal/cmd/spawnlock_other.go b/internal/cmd/spawnlock_other.go new file mode 100644 index 0000000000000000000000000000000000000000..1e07b7728a26e51e0ffaee16af1d685c13e5f424 --- /dev/null +++ b/internal/cmd/spawnlock_other.go @@ -0,0 +1,28 @@ +//go:build !windows + +package cmd + +import ( + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +// acquireSpawnLock takes an exclusive flock on the given file (creating +// it if necessary) and returns a release function that unlocks and +// closes the file. Blocks until the lock is acquired. 
+func acquireSpawnLock(path string) (func(), error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
+	}
+	// A blocking flock can be interrupted by a signal and fail with
+	// EINTR even though nothing is wrong. Retry in that case: reporting
+	// it as an error would send the caller down its unsynchronized
+	// fallback path.
+	for {
+		err = unix.Flock(int(f.Fd()), unix.LOCK_EX)
+		if err != unix.EINTR {
+			break
+		}
+	}
+	if err != nil {
+		_ = f.Close()
+		return nil, fmt.Errorf("flock spawn lock %q: %v", path, err)
+	}
+	return func() {
+		// Unlock explicitly, then close; both are best-effort.
+		_ = unix.Flock(int(f.Fd()), unix.LOCK_UN)
+		_ = f.Close()
+	}, nil
+}
diff --git a/internal/cmd/spawnlock_windows.go b/internal/cmd/spawnlock_windows.go
new file mode 100644
index 0000000000000000000000000000000000000000..d3e7492b229ac4bb5b3eca815711d5bc14ddcf0c
--- /dev/null
+++ b/internal/cmd/spawnlock_windows.go
@@ -0,0 +1,32 @@
+//go:build windows
+
+package cmd
+
+import (
+	"fmt"
+	"math"
+	"os"
+
+	"golang.org/x/sys/windows"
+)
+
+// acquireSpawnLock takes an exclusive lock on the given file (creating
+// it if necessary) using LockFileEx, and returns a release function
+// that unlocks and closes the file. Blocks until the lock is acquired.
+func acquireSpawnLock(path string) (func(), error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
+	if err != nil {
+		return nil, fmt.Errorf("open spawn lock %q: %v", path, err)
+	}
+	h := windows.Handle(f.Fd())
+	ol := new(windows.Overlapped)
+	if err := windows.LockFileEx(h, windows.LOCKFILE_EXCLUSIVE_LOCK, 0, math.MaxUint32, math.MaxUint32, ol); err != nil {
+		_ = f.Close()
+		return nil, fmt.Errorf("LockFileEx spawn lock %q: %v", path, err)
+	}
+	return func() {
+		ol := new(windows.Overlapped)
+		_ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol)
+		_ = f.Close()
+	}, nil
+}