socket_test.go

  1//go:build !windows
  2
  3package server
  4
  5import (
  6	"errors"
  7	"fmt"
  8	"io/fs"
  9	"net"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"sync"
 14	"syscall"
 15	"testing"
 16	"time"
 17
 18	"github.com/stretchr/testify/require"
 19)
 20
 21// fakeTimeoutErr is a minimal net.Error implementation whose Timeout()
 22// returns true. It is used to verify that IsStaleSocketErr never
 23// classifies a timeout as stale.
 24type fakeTimeoutErr struct{}
 25
 26func (fakeTimeoutErr) Error() string   { return "fake timeout" }
 27func (fakeTimeoutErr) Timeout() bool   { return true }
 28func (fakeTimeoutErr) Temporary() bool { return true }
 29
 30func TestIsStaleSocketErr(t *testing.T) {
 31	t.Parallel()
 32
 33	cases := []struct {
 34		name string
 35		err  error
 36		want bool
 37	}{
 38		{name: "nil", err: nil, want: false},
 39		{name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true},
 40		{
 41			name: "wrapped ECONNREFUSED",
 42			err:  fmt.Errorf("dial: %w", syscall.ECONNREFUSED),
 43			want: true,
 44		},
 45		{name: "fs.ErrNotExist", err: fs.ErrNotExist, want: true},
 46		{
 47			name: "wrapped fs.ErrNotExist",
 48			err:  fmt.Errorf("stat: %w", fs.ErrNotExist),
 49			want: true,
 50		},
 51		{name: "timeout net.Error", err: fakeTimeoutErr{}, want: false},
 52		{name: "generic error", err: errors.New("boom"), want: false},
 53	}
 54
 55	for _, tc := range cases {
 56		t.Run(tc.name, func(t *testing.T) {
 57			t.Parallel()
 58			require.Equal(t, tc.want, IsStaleSocketErr(tc.err))
 59		})
 60	}
 61}
 62
 63func TestDefaultHost_XDGRuntimeDir(t *testing.T) {
 64	dir := t.TempDir()
 65	t.Setenv("XDG_RUNTIME_DIR", dir)
 66
 67	host := DefaultHost()
 68
 69	require.True(t, strings.HasPrefix(host, "unix://"),
 70		"DefaultHost should return a unix:// URL, got %q", host)
 71	path := strings.TrimPrefix(host, "unix://")
 72
 73	// The composed path may exceed maxUnixSocketPathLen and fall back
 74	// to /tmp; only assert containment when it did not.
 75	if len(filepath.Join(dir, "crush.sock")) <= maxUnixSocketPathLen {
 76		require.True(t, strings.HasPrefix(path, dir),
 77			"socket path %q should live under %q", path, dir)
 78	}
 79	require.True(t, strings.HasSuffix(path, ".sock"),
 80		"socket path %q should end in .sock", path)
 81	require.Contains(t, filepath.Base(path), "crush",
 82		"socket filename should contain 'crush'")
 83}
 84
 85func TestDefaultHost_FallbackTemp(t *testing.T) {
 86	t.Setenv("XDG_RUNTIME_DIR", "")
 87
 88	host := DefaultHost()
 89
 90	require.True(t, strings.HasPrefix(host, "unix://"),
 91		"DefaultHost should return a unix:// URL, got %q", host)
 92	path := strings.TrimPrefix(host, "unix://")
 93	require.NotEmpty(t, path, "fallback socket path must be non-empty")
 94	require.True(t, strings.HasSuffix(path, ".sock"),
 95		"socket path %q should end in .sock", path)
 96	require.Contains(t, filepath.Base(path), "crush",
 97		"socket filename should contain 'crush'")
 98}
 99
100// staleSocketPath creates a deterministic stale unix socket file on
101// disk: the socket node exists but no goroutine is accepting on it.
102// It does so by binding a listener, disabling unlink-on-close, then
103// closing the listener. The path is returned so the caller can probe
104// it. A leftover file is best-effort removed via t.Cleanup.
105func staleSocketPath(t *testing.T, path string) {
106	t.Helper()
107	ln, err := net.Listen("unix", path)
108	require.NoError(t, err)
109	ul, ok := ln.(*net.UnixListener)
110	require.True(t, ok, "expected *net.UnixListener, got %T", ln)
111	ul.SetUnlinkOnClose(false)
112	require.NoError(t, ul.Close())
113
114	// Verify it is actually stale: dialing should fail.
115	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
116	if dialErr == nil {
117		conn.Close()
118		t.Fatalf("expected stale socket at %q to refuse connections", path)
119	}
120	require.True(t, IsStaleSocketErr(dialErr),
121		"expected stale-socket dial error, got %v", dialErr)
122
123	t.Cleanup(func() {
124		_ = os.Remove(path)
125	})
126}
127
128func TestListen_RemovesStaleSocket(t *testing.T) {
129	// t.TempDir() yields a path that may already be near the macOS
130	// sun_path limit; use a short filename to stay well under it.
131	dir := t.TempDir()
132	path := filepath.Join(dir, "s.sock")
133
134	staleSocketPath(t, path)
135
136	// Confirm the stale node is present before we call listen.
137	_, statErr := os.Stat(path)
138	require.NoError(t, statErr, "stale socket file should exist on disk")
139
140	ln, removedStale, err := listen("unix", path)
141	require.NoError(t, err)
142	require.NotNil(t, ln)
143	require.True(t, removedStale, "listen should report removedStale=true")
144	t.Cleanup(func() {
145		_ = ln.Close()
146	})
147}
148
149func TestListen_LiveSocketNotRemoved(t *testing.T) {
150	dir := t.TempDir()
151	path := filepath.Join(dir, "s.sock")
152
153	ln1, err := net.Listen("unix", path)
154	require.NoError(t, err)
155
156	// Drain accepts so the listener stays alive and responsive without
157	// blocking the test on a stray connection.
158	var wg sync.WaitGroup
159	wg.Add(1)
160	go func() {
161		defer wg.Done()
162		for {
163			c, err := ln1.Accept()
164			if err != nil {
165				return
166			}
167			_ = c.Close()
168		}
169	}()
170	t.Cleanup(func() {
171		_ = ln1.Close()
172		wg.Wait()
173	})
174
175	ln2, removedStale, err := listen("unix", path)
176	if ln2 != nil {
177		_ = ln2.Close()
178	}
179	require.Error(t, err, "listen on a live socket must fail")
180	require.False(t, removedStale,
181		"a live socket must never be removed (got removedStale=true)")
182
183	// The live socket file must still be on disk and dialable.
184	_, statErr := os.Stat(path)
185	require.NoError(t, statErr, "live socket file should still exist")
186	conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
187	require.NoError(t, dialErr, "live socket should still accept dials")
188	_ = conn.Close()
189}