1// Package daemon provides Git daemon server functionality.
2package daemon
3
4import (
5 "context"
6 "errors"
7 "net"
8 "sync"
9 "time"
10)
11
// connections is a mutex-guarded set of net.Conn values.
//
// The zero value is ready to use: Add allocates the underlying map on demand.
type connections struct {
	m  map[net.Conn]struct{}
	mu sync.Mutex
}

// Add registers c in the set. Adding the same connection twice is a no-op.
func (m *connections) Add(c net.Conn) {
	m.mu.Lock()
	defer m.mu.Unlock()
	// Lazily allocate so a zero-value connections does not panic on a nil map.
	if m.m == nil {
		m.m = make(map[net.Conn]struct{})
	}
	m.m[c] = struct{}{}
}

// Close closes c and removes it from the set, returning the error from
// c.Close unchanged. c is closed even if it was never added.
func (m *connections) Close(c net.Conn) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	err := c.Close()
	delete(m.m, c)
	return err //nolint:wrapcheck
}

// Size reports how many connections are currently in the set.
func (m *connections) Size() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return len(m.m)
}

// CloseAll closes every connection in the set and empties it. Errors from
// the individual Close calls are combined with errors.Join; the result is
// nil when every Close succeeded. The lock is held for the whole operation,
// so concurrent Add/Close/Size calls wait until CloseAll returns.
func (m *connections) CloseAll() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	var err error
	for c := range m.m {
		err = errors.Join(err, c.Close())
		// Deleting during range is permitted by the Go spec.
		delete(m.m, c)
	}

	return err
}
49
// serverConn is a wrapper around a net.Conn that enforces timeouts by
// refreshing the connection deadline before every Read and Write:
//   - initTimeout, if positive, bounds only the first I/O call and is then
//     cleared;
//   - idleTimeout, if positive, bounds every later I/O call;
//   - maxDeadline, if non-zero, is an absolute cap no deadline may exceed.
//
// closeCanceler, if set, is invoked when a Read or Write fails with an error
// that is (or wraps) a net.Error, and when the connection is closed.
type serverConn struct {
	net.Conn

	initTimeout   time.Duration
	idleTimeout   time.Duration
	maxDeadline   time.Time
	closeCanceler context.CancelFunc
}

// Compile-time check that serverConn satisfies net.Conn.
var _ net.Conn = (*serverConn)(nil)

// Write refreshes the deadline, then writes p to the underlying connection.
// A net.Error from the write triggers closeCanceler.
func (c *serverConn) Write(p []byte) (int, error) {
	c.updateDeadline()
	n, err := c.Conn.Write(p)
	c.cancelOnNetError(err)
	return n, err //nolint:wrapcheck
}

// Read refreshes the deadline, then reads into b from the underlying
// connection. A net.Error from the read triggers closeCanceler.
func (c *serverConn) Read(b []byte) (int, error) {
	c.updateDeadline()
	n, err := c.Conn.Read(b)
	c.cancelOnNetError(err)
	return n, err //nolint:wrapcheck
}

// Close closes the underlying connection and then invokes closeCanceler, if
// set. The error from the underlying Close is returned unchanged.
func (c *serverConn) Close() error {
	err := c.Conn.Close()
	if c.closeCanceler != nil {
		c.closeCanceler()
	}
	return err //nolint:wrapcheck
}

// cancelOnNetError invokes closeCanceler, if set, when err is or wraps a
// net.Error. Errors that are not net.Errors leave the canceler untouched.
func (c *serverConn) cancelOnNetError(err error) {
	if err == nil || c.closeCanceler == nil {
		return
	}
	var netErr net.Error
	if errors.As(err, &netErr) {
		c.closeCanceler()
	}
}

// updateDeadline sets the deadline for the next I/O call.
//
// The first call uses initTimeout (if positive) and clears it so it applies
// only once; later calls use idleTimeout (if positive). The resulting
// deadline is used only if maxDeadline is zero or the deadline falls strictly
// before maxDeadline; otherwise maxDeadline itself is applied. With no
// timeout configured, maxDeadline is applied as-is (a zero value means no
// deadline, per net.Conn.SetDeadline).
//
// NOTE(review): initTimeout is mutated without synchronization, so this
// assumes the first Read/Write is not issued concurrently with another I/O
// call on the same serverConn — confirm against callers.
func (c *serverConn) updateDeadline() {
	deadline := c.maxDeadline

	var timeout time.Duration
	switch {
	case c.initTimeout > 0:
		timeout = c.initTimeout
		c.initTimeout = 0
	case c.idleTimeout > 0:
		timeout = c.idleTimeout
	}

	if timeout > 0 {
		// Compare full-precision instants: truncating to whole seconds would
		// wrongly pick maxDeadline whenever both land in the same second.
		if next := time.Now().Add(timeout); c.maxDeadline.IsZero() || next.Before(c.maxDeadline) {
			deadline = next
		}
	}

	// Best effort: the SetDeadline error is deliberately ignored.
	c.SetDeadline(deadline) //nolint:errcheck,gosec
}