1//go:build !windows
2
3package server
4
5import (
6 "errors"
7 "fmt"
8 "io/fs"
9 "net"
10 "os"
11 "path/filepath"
12 "strings"
13 "sync"
14 "syscall"
15 "testing"
16 "time"
17
18 "github.com/stretchr/testify/require"
19)
20
// fakeTimeoutErr is a minimal net.Error implementation whose Timeout()
// and Temporary() both report true and whose message is fixed. It is
// used to verify that IsStaleSocketErr never classifies a timeout as
// stale.
type fakeTimeoutErr struct{}

// Compile-time guarantee that the fake really satisfies net.Error. If a
// method were dropped, the table case built on this type would quietly
// degrade into another "generic error" case and keep passing.
var _ net.Error = fakeTimeoutErr{}

// Error returns the fixed message "fake timeout".
func (fakeTimeoutErr) Error() string { return "fake timeout" }

// Timeout always reports true.
func (fakeTimeoutErr) Timeout() bool { return true }

// Temporary always reports true.
func (fakeTimeoutErr) Temporary() bool { return true }
29
30func TestIsStaleSocketErr(t *testing.T) {
31 t.Parallel()
32
33 cases := []struct {
34 name string
35 err error
36 want bool
37 }{
38 {name: "nil", err: nil, want: false},
39 {name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true},
40 {
41 name: "wrapped ECONNREFUSED",
42 err: fmt.Errorf("dial: %w", syscall.ECONNREFUSED),
43 want: true,
44 },
45 {name: "fs.ErrNotExist", err: fs.ErrNotExist, want: true},
46 {
47 name: "wrapped fs.ErrNotExist",
48 err: fmt.Errorf("stat: %w", fs.ErrNotExist),
49 want: true,
50 },
51 {name: "timeout net.Error", err: fakeTimeoutErr{}, want: false},
52 {name: "generic error", err: errors.New("boom"), want: false},
53 }
54
55 for _, tc := range cases {
56 t.Run(tc.name, func(t *testing.T) {
57 t.Parallel()
58 require.Equal(t, tc.want, IsStaleSocketErr(tc.err))
59 })
60 }
61}
62
63func TestDefaultHost_XDGRuntimeDir(t *testing.T) {
64 dir := t.TempDir()
65 t.Setenv("XDG_RUNTIME_DIR", dir)
66
67 host := DefaultHost()
68
69 require.True(t, strings.HasPrefix(host, "unix://"),
70 "DefaultHost should return a unix:// URL, got %q", host)
71 path := strings.TrimPrefix(host, "unix://")
72
73 // The composed path may exceed maxUnixSocketPathLen and fall back
74 // to /tmp; only assert containment when it did not. Recompose the
75 // path under dir (rather than checking the returned path length,
76 // which is short again after a /tmp fallback) to decide whether a
77 // fallback happened. The socket is named crush-<uid>.sock.
78 composed := filepath.Join(dir, filepath.Base(path))
79 if len(composed) <= maxUnixSocketPathLen {
80 require.True(t, strings.HasPrefix(path, dir),
81 "socket path %q should live under %q", path, dir)
82 }
83 require.True(t, strings.HasSuffix(path, ".sock"),
84 "socket path %q should end in .sock", path)
85 require.Contains(t, filepath.Base(path), "crush",
86 "socket filename should contain 'crush'")
87}
88
89func TestDefaultHost_FallbackTemp(t *testing.T) {
90 t.Setenv("XDG_RUNTIME_DIR", "")
91
92 host := DefaultHost()
93
94 require.True(t, strings.HasPrefix(host, "unix://"),
95 "DefaultHost should return a unix:// URL, got %q", host)
96 path := strings.TrimPrefix(host, "unix://")
97 require.NotEmpty(t, path, "fallback socket path must be non-empty")
98 require.True(t, strings.HasSuffix(path, ".sock"),
99 "socket path %q should end in .sock", path)
100 require.Contains(t, filepath.Base(path), "crush",
101 "socket filename should contain 'crush'")
102}
103
104// staleSocketPath creates a deterministic stale unix socket file on
105// disk: the socket node exists but no goroutine is accepting on it.
106// It does so by binding a listener, disabling unlink-on-close, then
107// closing the listener. The path is returned so the caller can probe
108// it. A leftover file is best-effort removed via t.Cleanup.
109func staleSocketPath(t *testing.T, path string) {
110 t.Helper()
111 ln, err := net.Listen("unix", path) //nolint:noctx
112 require.NoError(t, err)
113 ul, ok := ln.(*net.UnixListener)
114 require.True(t, ok, "expected *net.UnixListener, got %T", ln)
115 ul.SetUnlinkOnClose(false)
116 require.NoError(t, ul.Close())
117
118 // Verify it is actually stale: dialing should fail.
119 conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond) //nolint:noctx
120 if dialErr == nil {
121 conn.Close()
122 t.Fatalf("expected stale socket at %q to refuse connections", path)
123 }
124 require.True(t, IsStaleSocketErr(dialErr),
125 "expected stale-socket dial error, got %v", dialErr)
126
127 t.Cleanup(func() {
128 _ = os.Remove(path)
129 })
130}
131
132func TestListen_RemovesStaleSocket(t *testing.T) {
133 // t.TempDir() yields a path that may already be near the macOS
134 // sun_path limit; use a short filename to stay well under it.
135 dir := t.TempDir()
136 path := filepath.Join(dir, "s.sock")
137
138 staleSocketPath(t, path)
139
140 // Confirm the stale node is present before we call listen.
141 _, statErr := os.Stat(path)
142 require.NoError(t, statErr, "stale socket file should exist on disk")
143
144 ln, removedStale, err := listen("unix", path)
145 require.NoError(t, err)
146 require.NotNil(t, ln)
147 require.True(t, removedStale, "listen should report removedStale=true")
148 t.Cleanup(func() {
149 _ = ln.Close()
150 })
151}
152
153func TestListen_LiveSocketNotRemoved(t *testing.T) {
154 dir := t.TempDir()
155 path := filepath.Join(dir, "s.sock")
156
157 ln1, err := net.Listen("unix", path) //nolint:noctx
158 require.NoError(t, err)
159
160 // Drain accepts so the listener stays alive and responsive without
161 // blocking the test on a stray connection.
162 var wg sync.WaitGroup
163 wg.Add(1)
164 go func() {
165 defer wg.Done()
166 for {
167 c, err := ln1.Accept()
168 if err != nil {
169 return
170 }
171 _ = c.Close()
172 }
173 }()
174 t.Cleanup(func() {
175 _ = ln1.Close()
176 wg.Wait()
177 })
178
179 ln2, removedStale, err := listen("unix", path)
180 if ln2 != nil {
181 _ = ln2.Close()
182 }
183 require.Error(t, err, "listen on a live socket must fail")
184 require.False(t, removedStale,
185 "a live socket must never be removed (got removedStale=true)")
186
187 // The live socket file must still be on disk and dialable.
188 _, statErr := os.Stat(path)
189 require.NoError(t, statErr, "live socket file should still exist")
190 conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond) //nolint:noctx
191 require.NoError(t, dialErr, "live socket should still accept dials")
192 _ = conn.Close()
193}