fix(server): clear leftover sockets so server can always start

Christian Rocha and Charm Crush created

When the server starts and finds a socket file with no process behind
it, it now removes the dead socket before binding instead of failing
with an address-already-in-use error. A live server is never disturbed,
and the cleanup is recorded in the log.

Co-Authored-By: Charm Crush <crush@charm.land>

Change summary

internal/server/net_other.go   | 57 ++++++++++++++++++++++++++++++++++-
internal/server/net_windows.go | 19 ++++++++++-
internal/server/server.go      |  5 ++-
3 files changed, 74 insertions(+), 7 deletions(-)

Detailed changes

internal/server/net_other.go 🔗

@@ -2,9 +2,60 @@
 
 package server
 
-import "net"
+import (
+	"errors"
+	"io/fs"
+	"net"
+	"os"
+	"time"
+)
 
-func listen(network, address string) (net.Listener, error) {
+// staleSocketDialTimeout bounds the probe used to detect whether a Unix
+// socket file on disk is backed by a live listener.
+const staleSocketDialTimeout = 200 * time.Millisecond
+
+// listen binds a net.Listener on the given network and address.
+//
+// For unix sockets it self-heals from stale socket files: if os.Stat
+// finds the path on disk, it probes it with net.DialTimeout bounded by
+// staleSocketDialTimeout. A successful dial means a live server owns
+// the socket, so the probe is closed and net.Listen is left to fail.
+// A failed dial that isStaleSocketErr classifies as stale triggers
+// os.Remove (fs.ErrNotExist is ignored; any other removal error is
+// returned as-is). Other dial errors leave the path untouched.
+//
+// The second result reports whether a stale socket file was removed,
+// so callers can log it; it is still reported when net.Listen fails.
+// NOTE(review): the probe and the remove are separate steps, so two
+// servers starting at once could race between them — confirm.
+func listen(network, address string) (net.Listener, bool, error) {
+	var removedStale bool
+	if network == "unix" && address != "" {
+		if _, err := os.Stat(address); err == nil {
+			conn, dialErr := net.DialTimeout(network, address, staleSocketDialTimeout)
+			if dialErr == nil {
+				// A live server answered the probe. Close it and fall
+				// through to net.Listen so the caller sees the
+				// standard "address already in use" error.
+				conn.Close()
+			} else if isStaleSocketErr(dialErr) {
+				rmErr := os.Remove(address)
+				switch {
+				case rmErr == nil:
+					removedStale = true
+				case errors.Is(rmErr, fs.ErrNotExist):
+					// The file vanished between our stat and remove
+					// (e.g. another process removed it); a no-op.
+				default:
+					return nil, false, rmErr
+				}
+			}
+		}
+	}
 	//nolint:noctx
-	return net.Listen(network, address)
+	ln, err := net.Listen(network, address)
+	if err != nil {
+		return nil, removedStale, err
+	}
+	return ln, removedStale, nil
 }

internal/server/net_windows.go 🔗

@@ -9,7 +9,12 @@ import (
 	"github.com/Microsoft/go-winio"
 )
 
-func listen(network, address string) (net.Listener, error) {
+// listen binds a net.Listener on the given network and address.
+//
+// On Windows there is no Unix-socket stale-file recovery to perform,
+// so removedStale is always false. The signature matches the
+// non-Windows implementation so callers can use a single code path.
+func listen(network, address string) (net.Listener, bool, error) {
 	switch network {
 	case "npipe":
 		cfg := &winio.PipeConfig{
@@ -17,8 +22,16 @@ func listen(network, address string) (net.Listener, error) {
 			InputBufferSize:  65536,
 			OutputBufferSize: 65536,
 		}
-		return winio.ListenPipe(address, cfg)
+		ln, err := winio.ListenPipe(address, cfg)
+		if err != nil {
+			return nil, false, err
+		}
+		return ln, false, nil
 	default:
-		return net.Listen(network, address) //nolint:noctx
+		ln, err := net.Listen(network, address) //nolint:noctx
+		if err != nil {
+			return nil, false, err
+		}
+		return ln, false, nil
 	}
 }

internal/server/server.go 🔗

@@ -234,10 +234,13 @@ func (s *Server) ListenAndServe() error {
 	if s.ln != nil {
 		return fmt.Errorf("server already started")
 	}
-	ln, err := listen(s.network, s.Addr)
+	ln, removedStale, err := listen(s.network, s.Addr)
 	if err != nil {
 		return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
 	}
+	if removedStale && s.logger != nil {
+		s.logger.Warn("Removed stale socket before binding", "address", s.Addr)
+	}
 	return s.Serve(ln)
 }