conn.go

  1package daemon
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"sync"
  8	"time"
  9)
 10
 11// connections is a synchronizes access to to a net.Conn pool.
 12type connections struct {
 13	m  map[net.Conn]struct{}
 14	mu sync.Mutex
 15}
 16
 17func (m *connections) Add(c net.Conn) {
 18	m.mu.Lock()
 19	defer m.mu.Unlock()
 20	m.m[c] = struct{}{}
 21}
 22
 23func (m *connections) Close(c net.Conn) error {
 24	m.mu.Lock()
 25	defer m.mu.Unlock()
 26	err := c.Close()
 27	delete(m.m, c)
 28	return err
 29}
 30
 31func (m *connections) Size() int {
 32	m.mu.Lock()
 33	defer m.mu.Unlock()
 34	return len(m.m)
 35}
 36
 37func (m *connections) CloseAll() error {
 38	m.mu.Lock()
 39	defer m.mu.Unlock()
 40	var err error
 41	for c := range m.m {
 42		err = errors.Join(err, c.Close())
 43		delete(m.m, c)
 44	}
 45
 46	return err
 47}
 48
 49// serverConn is a wrapper around a net.Conn that closes the connection when
 50// the one of the timeouts is reached.
 51type serverConn struct {
 52	net.Conn
 53
 54	initTimeout   time.Duration
 55	idleTimeout   time.Duration
 56	maxDeadline   time.Time
 57	closeCanceler context.CancelFunc
 58}
 59
 60var _ net.Conn = (*serverConn)(nil)
 61
 62func (c *serverConn) Write(p []byte) (n int, err error) {
 63	c.updateDeadline()
 64	n, err = c.Conn.Write(p)
 65	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
 66		c.closeCanceler()
 67	}
 68	return
 69}
 70
 71func (c *serverConn) Read(b []byte) (n int, err error) {
 72	c.updateDeadline()
 73	n, err = c.Conn.Read(b)
 74	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
 75		c.closeCanceler()
 76	}
 77	return
 78}
 79
 80func (c *serverConn) Close() (err error) {
 81	err = c.Conn.Close()
 82	if c.closeCanceler != nil {
 83		c.closeCanceler()
 84	}
 85	return
 86}
 87
 88func (c *serverConn) updateDeadline() {
 89	switch {
 90	case c.initTimeout > 0:
 91		initTimeout := time.Now().Add(c.initTimeout)
 92		c.initTimeout = 0
 93		if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
 94			c.Conn.SetDeadline(initTimeout) //nolint: errcheck
 95			return
 96		}
 97	case c.idleTimeout > 0:
 98		idleDeadline := time.Now().Add(c.idleTimeout)
 99		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
100			c.Conn.SetDeadline(idleDeadline) //nolint: errcheck
101			return
102		}
103	}
104	c.Conn.SetDeadline(c.maxDeadline) //nolint: errcheck
105}