socket_test.go

  1//go:build !windows
  2
  3package server
  4
  5import (
  6	"errors"
  7	"fmt"
  8	"io/fs"
  9	"net"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"sync"
 14	"syscall"
 15	"testing"
 16	"time"
 17
 18	"github.com/stretchr/testify/require"
 19)
 20
 21// fakeTimeoutErr is a minimal net.Error implementation whose Timeout()
 22// returns true. It is used to verify that IsStaleSocketErr never
 23// classifies a timeout as stale.
 24type fakeTimeoutErr struct{}
 25
 26func (fakeTimeoutErr) Error() string   { return "fake timeout" }
 27func (fakeTimeoutErr) Timeout() bool   { return true }
 28func (fakeTimeoutErr) Temporary() bool { return true }
 29
 30func TestIsStaleSocketErr(t *testing.T) {
 31	t.Parallel()
 32
 33	cases := []struct {
 34		name string
 35		err  error
 36		want bool
 37	}{
 38		{name: "nil", err: nil, want: false},
 39		{name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true},
 40		{
 41			name: "wrapped ECONNREFUSED",
 42			err:  fmt.Errorf("dial: %w", syscall.ECONNREFUSED),
 43			want: true,
 44		},
 45		{name: "fs.ErrNotExist", err: fs.ErrNotExist, want: true},
 46		{
 47			name: "wrapped fs.ErrNotExist",
 48			err:  fmt.Errorf("stat: %w", fs.ErrNotExist),
 49			want: true,
 50		},
 51		{name: "timeout net.Error", err: fakeTimeoutErr{}, want: false},
 52		{name: "generic error", err: errors.New("boom"), want: false},
 53	}
 54
 55	for _, tc := range cases {
 56		t.Run(tc.name, func(t *testing.T) {
 57			t.Parallel()
 58			require.Equal(t, tc.want, IsStaleSocketErr(tc.err))
 59		})
 60	}
 61}
 62
 63func TestDefaultHost_XDGRuntimeDir(t *testing.T) {
 64	dir := t.TempDir()
 65	t.Setenv("XDG_RUNTIME_DIR", dir)
 66
 67	host := DefaultHost()
 68
 69	require.True(t, strings.HasPrefix(host, "unix://"),
 70		"DefaultHost should return a unix:// URL, got %q", host)
 71	path := strings.TrimPrefix(host, "unix://")
 72
 73	// The composed path may exceed maxUnixSocketPathLen and fall back
 74	// to /tmp; only assert containment when it did not. Recompose the
 75	// path under dir (rather than checking the returned path length,
 76	// which is short again after a /tmp fallback) to decide whether a
 77	// fallback happened. The socket is named crush-<uid>.sock.
 78	composed := filepath.Join(dir, filepath.Base(path))
 79	if len(composed) <= maxUnixSocketPathLen {
 80		require.True(t, strings.HasPrefix(path, dir),
 81			"socket path %q should live under %q", path, dir)
 82	}
 83	require.True(t, strings.HasSuffix(path, ".sock"),
 84		"socket path %q should end in .sock", path)
 85	require.Contains(t, filepath.Base(path), "crush",
 86		"socket filename should contain 'crush'")
 87}
 88
 89func TestDefaultHost_FallbackTemp(t *testing.T) {
 90	t.Setenv("XDG_RUNTIME_DIR", "")
 91
 92	host := DefaultHost()
 93
 94	require.True(t, strings.HasPrefix(host, "unix://"),
 95		"DefaultHost should return a unix:// URL, got %q", host)
 96	path := strings.TrimPrefix(host, "unix://")
 97	require.NotEmpty(t, path, "fallback socket path must be non-empty")
 98	require.True(t, strings.HasSuffix(path, ".sock"),
 99		"socket path %q should end in .sock", path)
100	require.Contains(t, filepath.Base(path), "crush",
101		"socket filename should contain 'crush'")
102}
103
104// staleSocketPath creates a deterministic stale unix socket file on
105// disk: the socket node exists but no goroutine is accepting on it.
106// It does so by binding a listener, disabling unlink-on-close, then
107// closing the listener. The path is returned so the caller can probe
108// it. A leftover file is best-effort removed via t.Cleanup.
109func staleSocketPath(t *testing.T, path string) {
110	t.Helper()
111	ln, err := net.Listen("unix", path) //nolint:noctx
112	require.NoError(t, err)
113	ul, ok := ln.(*net.UnixListener)
114	require.True(t, ok, "expected *net.UnixListener, got %T", ln)
115	ul.SetUnlinkOnClose(false)
116	require.NoError(t, ul.Close())
117
118	// Verify it is actually stale: dialing should fail.
119	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond) //nolint:noctx
120	if dialErr == nil {
121		conn.Close()
122		t.Fatalf("expected stale socket at %q to refuse connections", path)
123	}
124	require.True(t, IsStaleSocketErr(dialErr),
125		"expected stale-socket dial error, got %v", dialErr)
126
127	t.Cleanup(func() {
128		_ = os.Remove(path)
129	})
130}
131
132func TestListen_RemovesStaleSocket(t *testing.T) {
133	// t.TempDir() yields a path that may already be near the macOS
134	// sun_path limit; use a short filename to stay well under it.
135	dir := t.TempDir()
136	path := filepath.Join(dir, "s.sock")
137
138	staleSocketPath(t, path)
139
140	// Confirm the stale node is present before we call listen.
141	_, statErr := os.Stat(path)
142	require.NoError(t, statErr, "stale socket file should exist on disk")
143
144	ln, removedStale, err := listen("unix", path)
145	require.NoError(t, err)
146	require.NotNil(t, ln)
147	require.True(t, removedStale, "listen should report removedStale=true")
148	t.Cleanup(func() {
149		_ = ln.Close()
150	})
151}
152
153func TestListen_LiveSocketNotRemoved(t *testing.T) {
154	dir := t.TempDir()
155	path := filepath.Join(dir, "s.sock")
156
157	ln1, err := net.Listen("unix", path) //nolint:noctx
158	require.NoError(t, err)
159
160	// Drain accepts so the listener stays alive and responsive without
161	// blocking the test on a stray connection.
162	var wg sync.WaitGroup
163	wg.Add(1)
164	go func() {
165		defer wg.Done()
166		for {
167			c, err := ln1.Accept()
168			if err != nil {
169				return
170			}
171			_ = c.Close()
172		}
173	}()
174	t.Cleanup(func() {
175		_ = ln1.Close()
176		wg.Wait()
177	})
178
179	ln2, removedStale, err := listen("unix", path)
180	if ln2 != nil {
181		_ = ln2.Close()
182	}
183	require.Error(t, err, "listen on a live socket must fail")
184	require.False(t, removedStale,
185		"a live socket must never be removed (got removedStale=true)")
186
187	// The live socket file must still be on disk and dialable.
188	_, statErr := os.Stat(path)
189	require.NoError(t, statErr, "live socket file should still exist")
190	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond) //nolint:noctx
191	require.NoError(t, dialErr, "live socket should still accept dials")
192	_ = conn.Close()
193}