1package daemon
2
3import (
4 "context"
5 "errors"
6 "net"
7 "sync"
8 "time"
9)
10
// connections synchronizes access to a pool of net.Conn values.
//
// The zero value is ready to use: the underlying set is allocated on the
// first call to Add. A connections value must not be copied after first use
// because it contains a sync.Mutex.
type connections struct {
	m  map[net.Conn]struct{}
	mu sync.Mutex // guards m
}

// Add registers c in the pool.
func (cs *connections) Add(c net.Conn) {
	cs.mu.Lock()
	defer cs.mu.Unlock()
	if cs.m == nil {
		// Lazily allocate so that a zero-value pool does not panic on the
		// first insert.
		cs.m = make(map[net.Conn]struct{})
	}
	cs.m[c] = struct{}{}
}

// Close closes c and then removes it from the pool. It returns the error
// reported by c.Close. c is closed even if it was never added to the pool.
func (cs *connections) Close(c net.Conn) error {
	// Close before taking the lock: closing a connection may block, and it
	// needs no access to the map. Holding the mutex here would stall every
	// other user of the pool and deadlock if c.Close calls back into it.
	err := c.Close()

	// Remove only after the close has finished, so an entry that is gone
	// from the pool is always already closed.
	cs.mu.Lock()
	delete(cs.m, c)
	cs.mu.Unlock()
	return err
}

// Size returns the number of connections currently in the pool.
func (cs *connections) Size() int {
	cs.mu.Lock()
	defer cs.mu.Unlock()
	return len(cs.m)
}

// CloseAll closes and removes every connection that is in the pool at the
// time of the call. Close errors are combined with errors.Join; the result is
// nil when every connection closed cleanly.
func (cs *connections) CloseAll() error {
	// Snapshot the members under the lock, then close them without it so
	// that a slow Close does not block Add, Size or Close on other
	// connections.
	cs.mu.Lock()
	conns := make([]net.Conn, 0, len(cs.m))
	for c := range cs.m {
		conns = append(conns, c)
	}
	cs.mu.Unlock()

	var err error
	for _, c := range conns {
		err = errors.Join(err, cs.Close(c))
	}
	return err
}
48
// serverConn wraps a net.Conn and re-arms a deadline on the underlying
// connection before every Read and Write:
//
//   - initTimeout, if positive, bounds the first Read or Write only; it is
//     reset to zero once consumed.
//   - idleTimeout, if positive, bounds every later Read or Write.
//   - maxDeadline, if non-zero, is an absolute cap that neither timeout may
//     exceed.
//
// closeCanceler, when non-nil, is invoked on Close and whenever Read or Write
// returns an error that is, or wraps, a net.Error.
//
// NOTE(review): closeCanceler presumably cancels a context owned by whoever
// created the connection; confirm against the constructor.
// NOTE(review): initTimeout is mutated without synchronization, so a first
// Read and a first Write issued from different goroutines would race.
type serverConn struct {
	net.Conn

	initTimeout   time.Duration
	idleTimeout   time.Duration
	maxDeadline   time.Time
	closeCanceler context.CancelFunc
}

var _ net.Conn = (*serverConn)(nil)

// Write refreshes the deadline and writes p to the underlying connection.
// It returns the byte count and error from the underlying Write unchanged.
func (c *serverConn) Write(p []byte) (int, error) {
	c.updateDeadline()
	n, err := c.Conn.Write(p)
	c.cancelOnNetError(err)
	return n, err
}

// Read refreshes the deadline and reads from the underlying connection into
// b. It returns the byte count and error from the underlying Read unchanged.
func (c *serverConn) Read(b []byte) (int, error) {
	c.updateDeadline()
	n, err := c.Conn.Read(b)
	c.cancelOnNetError(err)
	return n, err
}

// Close closes the underlying connection and then invokes closeCanceler, if
// set. It returns the error from the underlying Close.
func (c *serverConn) Close() error {
	err := c.Conn.Close()
	if c.closeCanceler != nil {
		c.closeCanceler()
	}
	return err
}

// cancelOnNetError invokes closeCanceler when err is, or wraps, a net.Error.
// It does nothing for a nil error, a non-network error, or a nil canceler.
func (c *serverConn) cancelOnNetError(err error) {
	if c.closeCanceler == nil {
		return
	}
	// errors.As also matches wrapped errors, which a plain type assertion
	// on err would miss.
	var netErr net.Error
	if errors.As(err, &netErr) {
		c.closeCanceler()
	}
}

// updateDeadline sets the deadline for the next I/O operation on the
// underlying connection: now plus the applicable timeout, capped by
// maxDeadline. With no applicable timeout it sets maxDeadline itself, which
// clears the deadline when maxDeadline is the zero time.
func (c *serverConn) updateDeadline() {
	var timeout time.Duration
	switch {
	case c.initTimeout > 0:
		timeout = c.initTimeout
		c.initTimeout = 0 // the initial timeout applies to the first operation only
	case c.idleTimeout > 0:
		timeout = c.idleTimeout
	}

	deadline := c.maxDeadline
	if timeout > 0 {
		// Compare full-precision instants; comparing Unix seconds would
		// discard a timeout that expires earlier within the same second.
		if d := time.Now().Add(timeout); c.maxDeadline.IsZero() || d.Before(c.maxDeadline) {
			deadline = d
		}
	}
	// Best effort: a failure to arm the deadline surfaces on the Read or
	// Write that follows.
	c.Conn.SetDeadline(deadline) //nolint: errcheck
}