conn.go

  1// Package daemon provides Git daemon server functionality.
  2package daemon
  3
  4import (
  5	"context"
  6	"errors"
  7	"net"
  8	"sync"
  9	"time"
 10)
 11
 12// connections is a synchronizes access to to a net.Conn pool.
 13type connections struct {
 14	m  map[net.Conn]struct{}
 15	mu sync.Mutex
 16}
 17
 18func (m *connections) Add(c net.Conn) {
 19	m.mu.Lock()
 20	defer m.mu.Unlock()
 21	m.m[c] = struct{}{}
 22}
 23
 24func (m *connections) Close(c net.Conn) error {
 25	m.mu.Lock()
 26	defer m.mu.Unlock()
 27	err := c.Close()
 28	delete(m.m, c)
 29	return err //nolint:wrapcheck
 30}
 31
 32func (m *connections) Size() int {
 33	m.mu.Lock()
 34	defer m.mu.Unlock()
 35	return len(m.m)
 36}
 37
 38func (m *connections) CloseAll() error {
 39	m.mu.Lock()
 40	defer m.mu.Unlock()
 41	var err error
 42	for c := range m.m {
 43		err = errors.Join(err, c.Close())
 44		delete(m.m, c)
 45	}
 46
 47	return err
 48}
 49
 50// serverConn is a wrapper around a net.Conn that closes the connection when
 51// the one of the timeouts is reached.
 52type serverConn struct {
 53	net.Conn
 54
 55	initTimeout   time.Duration
 56	idleTimeout   time.Duration
 57	maxDeadline   time.Time
 58	closeCanceler context.CancelFunc
 59}
 60
 61var _ net.Conn = (*serverConn)(nil)
 62
 63func (c *serverConn) Write(p []byte) (n int, err error) {
 64	c.updateDeadline()
 65	n, err = c.Conn.Write(p)
 66	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
 67		c.closeCanceler()
 68	}
 69	return
 70}
 71
 72func (c *serverConn) Read(b []byte) (n int, err error) {
 73	c.updateDeadline()
 74	n, err = c.Conn.Read(b)
 75	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
 76		c.closeCanceler()
 77	}
 78	return
 79}
 80
 81func (c *serverConn) Close() (err error) {
 82	err = c.Conn.Close()
 83	if c.closeCanceler != nil {
 84		c.closeCanceler()
 85	}
 86	return
 87}
 88
 89func (c *serverConn) updateDeadline() {
 90	switch {
 91	case c.initTimeout > 0:
 92		initTimeout := time.Now().Add(c.initTimeout)
 93		c.initTimeout = 0
 94		if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
 95			c.SetDeadline(initTimeout) //nolint:errcheck,gosec
 96			return
 97		}
 98	case c.idleTimeout > 0:
 99		idleDeadline := time.Now().Add(c.idleTimeout)
100		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
101			c.SetDeadline(idleDeadline) //nolint:errcheck,gosec
102			return
103		}
104	}
105	c.SetDeadline(c.maxDeadline) //nolint:errcheck,gosec
106}