diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 1674fe0ca8b2569915543853527be79153e6b237..7ee034d4be674ad9938b370a93b284c36e7abd9e 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -382,13 +382,23 @@ func ensureServer(cmd *cobra.Command, hostURL *url.URL) error { switch hostURL.Scheme { case "unix", "npipe": needsStart := false - if _, err := os.Stat(hostURL.Host); err != nil && errors.Is(err, fs.ErrNotExist) { + _, statErr := os.Stat(hostURL.Host) + switch { + case statErr == nil: + restarted, err := restartIfStale(cmd, hostURL) + if err != nil { + slog.Warn("Failed to check server version", "error", err) + } + needsStart = restarted || err != nil + case errors.Is(statErr, fs.ErrNotExist): needsStart = true - } else if err == nil { - if err := restartIfStale(cmd, hostURL); err != nil { - slog.Warn("Failed to check server version, restarting", "error", err) - needsStart = true + default: + slog.Warn("Unexpected error stat'ing server socket, attempting cleanup", + "path", hostURL.Host, "error", statErr) + if err := os.Remove(hostURL.Host); err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("failed to remove stale server socket %q: %v", hostURL.Host, err) } + needsStart = true } if needsStart { @@ -517,17 +527,21 @@ func probeHealth(ctx context.Context, h *http.Client, reqURL string, hostURL *ur // restartIfStale checks whether the running server matches the current // client version. When they differ, it sends a shutdown command and // removes the stale socket so the caller can start a fresh server. -func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error { +// +// It returns restarted=true when it has shut down a stale server and the +// caller must spawn a new one. When the server matches the client version +// (or the check itself fails), restarted is false. +func restartIfStale(cmd *cobra.Command, hostURL *url.URL) (restarted bool, err error) { c, err := client.NewClient("", hostURL.Scheme, hostURL.Host) if err != nil { - return err + return false, err } vi, err := c.VersionInfo(cmd.Context()) if err != nil { - return err + return false, err } if vi.Version == version.Version { - return nil + return false, nil } slog.Info("Server version mismatch, restarting", "server", vi.Version, @@ -541,13 +555,13 @@ func restartIfStale(cmd *cobra.Command, hostURL *url.URL) error { } select { case <-cmd.Context().Done(): - return cmd.Context().Err() + return true, cmd.Context().Err() case <-time.After(100 * time.Millisecond): } } // Force-remove if the socket is still lingering. _ = os.Remove(hostURL.Host) - return nil + return true, nil } var safeNameRegexp = regexp.MustCompile(`[^a-zA-Z0-9._-]`)