1//go:build !windows
2
3package server
4
5import (
6 "errors"
7 "fmt"
8 "io/fs"
9 "net"
10 "os"
11 "path/filepath"
12 "strings"
13 "sync"
14 "syscall"
15 "testing"
16 "time"
17
18 "github.com/stretchr/testify/require"
19)
20
// fakeTimeoutErr is a minimal net.Error implementation whose Timeout()
// always reports true. TestIsStaleSocketErr uses it to verify that
// IsStaleSocketErr never classifies a timeout as stale.
type fakeTimeoutErr struct{}

// Compile-time proof that fakeTimeoutErr satisfies net.Error. Without
// it, dropping one of the methods below would silently turn the value
// into a plain error, and the timeout test case would exercise the
// generic-error path instead of the timeout path.
var _ net.Error = fakeTimeoutErr{}

// Error returns a fixed message identifying the fake.
func (fakeTimeoutErr) Error() string { return "fake timeout" }

// Timeout always reports true so the value is seen as a timeout.
func (fakeTimeoutErr) Timeout() bool { return true }

// Temporary always reports true; it is present because net.Error
// requires it.
func (fakeTimeoutErr) Temporary() bool { return true }
29
30func TestIsStaleSocketErr(t *testing.T) {
31 t.Parallel()
32
33 cases := []struct {
34 name string
35 err error
36 want bool
37 }{
38 {name: "nil", err: nil, want: false},
39 {name: "ECONNREFUSED", err: syscall.ECONNREFUSED, want: true},
40 {
41 name: "wrapped ECONNREFUSED",
42 err: fmt.Errorf("dial: %w", syscall.ECONNREFUSED),
43 want: true,
44 },
45 {name: "fs.ErrNotExist", err: fs.ErrNotExist, want: true},
46 {
47 name: "wrapped fs.ErrNotExist",
48 err: fmt.Errorf("stat: %w", fs.ErrNotExist),
49 want: true,
50 },
51 {name: "timeout net.Error", err: fakeTimeoutErr{}, want: false},
52 {name: "generic error", err: errors.New("boom"), want: false},
53 }
54
55 for _, tc := range cases {
56 t.Run(tc.name, func(t *testing.T) {
57 t.Parallel()
58 require.Equal(t, tc.want, IsStaleSocketErr(tc.err))
59 })
60 }
61}
62
63func TestDefaultHost_XDGRuntimeDir(t *testing.T) {
64 dir := t.TempDir()
65 t.Setenv("XDG_RUNTIME_DIR", dir)
66
67 host := DefaultHost()
68
69 require.True(t, strings.HasPrefix(host, "unix://"),
70 "DefaultHost should return a unix:// URL, got %q", host)
71 path := strings.TrimPrefix(host, "unix://")
72
73 // The composed path may exceed maxUnixSocketPathLen and fall back
74 // to /tmp; only assert containment when it did not.
75 if len(filepath.Join(dir, "crush.sock")) <= maxUnixSocketPathLen {
76 require.True(t, strings.HasPrefix(path, dir),
77 "socket path %q should live under %q", path, dir)
78 }
79 require.True(t, strings.HasSuffix(path, ".sock"),
80 "socket path %q should end in .sock", path)
81 require.Contains(t, filepath.Base(path), "crush",
82 "socket filename should contain 'crush'")
83}
84
85func TestDefaultHost_FallbackTemp(t *testing.T) {
86 t.Setenv("XDG_RUNTIME_DIR", "")
87
88 host := DefaultHost()
89
90 require.True(t, strings.HasPrefix(host, "unix://"),
91 "DefaultHost should return a unix:// URL, got %q", host)
92 path := strings.TrimPrefix(host, "unix://")
93 require.NotEmpty(t, path, "fallback socket path must be non-empty")
94 require.True(t, strings.HasSuffix(path, ".sock"),
95 "socket path %q should end in .sock", path)
96 require.Contains(t, filepath.Base(path), "crush",
97 "socket filename should contain 'crush'")
98}
99
100// staleSocketPath creates a deterministic stale unix socket file on
101// disk: the socket node exists but no goroutine is accepting on it.
102// It does so by binding a listener, disabling unlink-on-close, then
103// closing the listener. The path is returned so the caller can probe
104// it. A leftover file is best-effort removed via t.Cleanup.
105func staleSocketPath(t *testing.T, path string) {
106 t.Helper()
107 ln, err := net.Listen("unix", path)
108 require.NoError(t, err)
109 ul, ok := ln.(*net.UnixListener)
110 require.True(t, ok, "expected *net.UnixListener, got %T", ln)
111 ul.SetUnlinkOnClose(false)
112 require.NoError(t, ul.Close())
113
114 // Verify it is actually stale: dialing should fail.
115 conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
116 if dialErr == nil {
117 conn.Close()
118 t.Fatalf("expected stale socket at %q to refuse connections", path)
119 }
120 require.True(t, IsStaleSocketErr(dialErr),
121 "expected stale-socket dial error, got %v", dialErr)
122
123 t.Cleanup(func() {
124 _ = os.Remove(path)
125 })
126}
127
128func TestListen_RemovesStaleSocket(t *testing.T) {
129 // t.TempDir() yields a path that may already be near the macOS
130 // sun_path limit; use a short filename to stay well under it.
131 dir := t.TempDir()
132 path := filepath.Join(dir, "s.sock")
133
134 staleSocketPath(t, path)
135
136 // Confirm the stale node is present before we call listen.
137 _, statErr := os.Stat(path)
138 require.NoError(t, statErr, "stale socket file should exist on disk")
139
140 ln, removedStale, err := listen("unix", path)
141 require.NoError(t, err)
142 require.NotNil(t, ln)
143 require.True(t, removedStale, "listen should report removedStale=true")
144 t.Cleanup(func() {
145 _ = ln.Close()
146 })
147}
148
149func TestListen_LiveSocketNotRemoved(t *testing.T) {
150 dir := t.TempDir()
151 path := filepath.Join(dir, "s.sock")
152
153 ln1, err := net.Listen("unix", path)
154 require.NoError(t, err)
155
156 // Drain accepts so the listener stays alive and responsive without
157 // blocking the test on a stray connection.
158 var wg sync.WaitGroup
159 wg.Add(1)
160 go func() {
161 defer wg.Done()
162 for {
163 c, err := ln1.Accept()
164 if err != nil {
165 return
166 }
167 _ = c.Close()
168 }
169 }()
170 t.Cleanup(func() {
171 _ = ln1.Close()
172 wg.Wait()
173 })
174
175 ln2, removedStale, err := listen("unix", path)
176 if ln2 != nil {
177 _ = ln2.Close()
178 }
179 require.Error(t, err, "listen on a live socket must fail")
180 require.False(t, removedStale,
181 "a live socket must never be removed (got removedStale=true)")
182
183 // The live socket file must still be on disk and dialable.
184 _, statErr := os.Stat(path)
185 require.NoError(t, statErr, "live socket file should still exist")
186 conn, dialErr := net.DialTimeout("unix", path, 200*time.Millisecond)
187 require.NoError(t, dialErr, "live socket should still accept dials")
188 _ = conn.Close()
189}