chore(test,server): cover stale socket cleanup and socket location

Christian Rocha and Charm Crush created

Adds tests confirming that a leftover socket is detected and cleared
while a live socket is left untouched, and that the socket is placed
in the expected per-user location.

Co-Authored-By: Charm Crush <crush@charm.land>

Change summary

internal/server/socket_test.go | 189 ++++++++++++++++++++++++++++++++++++
1 file changed, 189 insertions(+)

Detailed changes

internal/server/socket_test.go 🔗

@@ -0,0 +1,189 @@
+//go:build !windows
+
+package server
+
+import (
+	"errors"
+	"fmt"
+	"io/fs"
+	"net"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"syscall"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+// fakeTimeoutErr is a minimal net.Error implementation whose Timeout()
+// returns true. It is used to verify that IsStaleSocketErr never
+// classifies a timeout as stale.
+type fakeTimeoutErr struct{}
+
+func (fakeTimeoutErr) Error() string   { return "fake timeout" }
+func (fakeTimeoutErr) Timeout() bool   { return true }
+func (fakeTimeoutErr) Temporary() bool { return true }
+
+func TestIsStaleSocketErr(t *testing.T) {
+	t.Parallel()
+
+	cases := []struct {
+		name string
+		err  error
+		want bool
+	}{
+		{name: "nil", err: nil, want: false},
+		{name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true},
+		{
+			name: "wrapped ECONNREFUSED",
+			err:  fmt.Errorf("dial: %w", syscall.ECONNREFUSED),
+			want: true,
+		},
+		{name: "fs.ErrNotExist", err: fs.ErrNotExist, want: true},
+		{
+			name: "wrapped fs.ErrNotExist",
+			err:  fmt.Errorf("stat: %w", fs.ErrNotExist),
+			want: true,
+		},
+		{name: "timeout net.Error", err: fakeTimeoutErr{}, want: false},
+		{name: "generic error", err: errors.New("boom"), want: false},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tc.want, IsStaleSocketErr(tc.err))
+		})
+	}
+}
+
+func TestDefaultHost_XDGRuntimeDir(t *testing.T) {
+	dir := t.TempDir()
+	t.Setenv("XDG_RUNTIME_DIR", dir)
+
+	host := DefaultHost()
+
+	require.True(t, strings.HasPrefix(host, "unix://"),
+		"DefaultHost should return a unix:// URL, got %q", host)
+	path := strings.TrimPrefix(host, "unix://")
+
+	// The composed path may exceed maxUnixSocketPathLen and fall back
+	// to /tmp; only assert containment when it did not.
+	if len(filepath.Join(dir, "crush.sock")) <= maxUnixSocketPathLen {
+		require.True(t, strings.HasPrefix(path, dir),
+			"socket path %q should live under %q", path, dir)
+	}
+	require.True(t, strings.HasSuffix(path, ".sock"),
+		"socket path %q should end in .sock", path)
+	require.Contains(t, filepath.Base(path), "crush",
+		"socket filename should contain 'crush'")
+}
+
+func TestDefaultHost_FallbackTemp(t *testing.T) {
+	t.Setenv("XDG_RUNTIME_DIR", "")
+
+	host := DefaultHost()
+
+	require.True(t, strings.HasPrefix(host, "unix://"),
+		"DefaultHost should return a unix:// URL, got %q", host)
+	path := strings.TrimPrefix(host, "unix://")
+	require.NotEmpty(t, path, "fallback socket path must be non-empty")
+	require.True(t, strings.HasSuffix(path, ".sock"),
+		"socket path %q should end in .sock", path)
+	require.Contains(t, filepath.Base(path), "crush",
+		"socket filename should contain 'crush'")
+}
+
+// staleSocketPath creates a deterministic stale unix socket file on
+// disk: the socket node exists but no goroutine is accepting on it.
+// It does so by binding a listener, disabling unlink-on-close, then
+// closing the listener. The path is returned so the caller can probe
+// it. A leftover file is best-effort removed via t.Cleanup.
+func staleSocketPath(t *testing.T, path string) {
+	t.Helper()
+	ln, err := net.Listen("unix", path)
+	require.NoError(t, err)
+	ul, ok := ln.(*net.UnixListener)
+	require.True(t, ok, "expected *net.UnixListener, got %T", ln)
+	ul.SetUnlinkOnClose(false)
+	require.NoError(t, ul.Close())
+
+	// Verify it is actually stale: dialing should fail.
+	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
+	if dialErr == nil {
+		conn.Close()
+		t.Fatalf("expected stale socket at %q to refuse connections", path)
+	}
+	require.True(t, IsStaleSocketErr(dialErr),
+		"expected stale-socket dial error, got %v", dialErr)
+
+	t.Cleanup(func() {
+		_ = os.Remove(path)
+	})
+}
+
+func TestListen_RemovesStaleSocket(t *testing.T) {
+	// t.TempDir() yields a path that may already be near the macOS
+	// sun_path limit; use a short filename to stay well under it.
+	dir := t.TempDir()
+	path := filepath.Join(dir, "s.sock")
+
+	staleSocketPath(t, path)
+
+	// Confirm the stale node is present before we call listen.
+	_, statErr := os.Stat(path)
+	require.NoError(t, statErr, "stale socket file should exist on disk")
+
+	ln, removedStale, err := listen("unix", path)
+	require.NoError(t, err)
+	require.NotNil(t, ln)
+	require.True(t, removedStale, "listen should report removedStale=true")
+	t.Cleanup(func() {
+		_ = ln.Close()
+	})
+}
+
+func TestListen_LiveSocketNotRemoved(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "s.sock")
+
+	ln1, err := net.Listen("unix", path)
+	require.NoError(t, err)
+
+	// Drain accepts so the listener stays alive and responsive without
+	// blocking the test on a stray connection.
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			c, err := ln1.Accept()
+			if err != nil {
+				return
+			}
+			_ = c.Close()
+		}
+	}()
+	t.Cleanup(func() {
+		_ = ln1.Close()
+		wg.Wait()
+	})
+
+	ln2, removedStale, err := listen("unix", path)
+	if ln2 != nil {
+		_ = ln2.Close()
+	}
+	require.Error(t, err, "listen on a live socket must fail")
+	require.False(t, removedStale,
+		"a live socket must never be removed (got removedStale=true)")
+
+	// The live socket file must still be on disk and dialable.
+	_, statErr := os.Stat(path)
+	require.NoError(t, statErr, "live socket file should still exist")
+	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
+	require.NoError(t, dialErr, "live socket should still accept dials")
+	_ = conn.Close()
+}