fix(daemon): close listener only once (#615)

Ayman Bagabas created

* fix(daemon): close listener only once

* refactor(daemon): rename Start to ListenAndServe and implement Serve

* fix(daemon): use atomic.Bool for server

* fix(daemon): attempt to fix idle timeout test

Change summary

cmd/soft/serve/server.go  |  2 
pkg/daemon/daemon.go      | 55 ++++++++++++++++++++++++++--------------
pkg/daemon/daemon_test.go | 20 +++++++++++---
3 files changed, 51 insertions(+), 26 deletions(-)

Detailed changes

cmd/soft/serve/server.go 🔗

@@ -95,7 +95,7 @@ func (s *Server) Start() error {
 	errg, _ := errgroup.WithContext(s.ctx)
 	errg.Go(func() error {
 		s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
-		if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) {
+		if err := s.GitDaemon.ListenAndServe(); !errors.Is(err, daemon.ErrServerClosed) {
 			return err
 		}
 		return nil

pkg/daemon/daemon.go 🔗

@@ -8,6 +8,7 @@ import (
 	"path/filepath"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/charmbracelet/log"
@@ -43,7 +44,6 @@ var ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
 // GitDaemon represents a Git daemon.
 type GitDaemon struct {
 	ctx      context.Context
-	listener net.Listener
 	addr     string
 	finished chan struct{}
 	conns    connections
@@ -52,6 +52,7 @@ type GitDaemon struct {
 	wg       sync.WaitGroup
 	once     sync.Once
 	logger   *log.Logger
+	done     atomic.Bool // indicates if the server has been closed
 }
 
 // NewDaemon returns a new Git daemon.
@@ -70,26 +71,31 @@ func NewGitDaemon(ctx context.Context) (*GitDaemon, error) {
 	return d, nil
 }
 
-// Start starts the Git TCP daemon.
-func (d *GitDaemon) Start() error {
-	// listen on the socket
-	{
-		listener, err := net.Listen("tcp", d.addr)
-		if err != nil {
-			return err
-		}
-		d.listener = listener
+// ListenAndServe starts the Git TCP daemon.
+func (d *GitDaemon) ListenAndServe() error {
+	if d.done.Load() {
+		return ErrServerClosed
+	}
+	listener, err := net.Listen("tcp", d.addr)
+	if err != nil {
+		return err
 	}
+	return d.Serve(listener)
+}
 
-	// close eventual connections to the socket
-	defer d.listener.Close() // nolint: errcheck
+// Serve listens on the TCP network address and serves Git requests.
+func (d *GitDaemon) Serve(listener net.Listener) error {
+	if d.done.Load() {
+		return ErrServerClosed
+	}
 
 	d.wg.Add(1)
 	defer d.wg.Done()
+	defer listener.Close() //nolint:errcheck
 
 	var tempDelay time.Duration
 	for {
-		conn, err := d.listener.Accept()
+		conn, err := listener.Accept()
 		if err != nil {
 			select {
 			case <-d.finished:
@@ -305,21 +311,30 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 
 // Close closes the underlying listener.
 func (d *GitDaemon) Close() error {
-	d.once.Do(func() { close(d.finished) })
-	err := d.listener.Close()
+	err := d.closeListener()
 	d.conns.CloseAll() // nolint: errcheck
 	return err
 }
 
+// closeListener closes the listener and the finished channel.
+func (d *GitDaemon) closeListener() error {
+	if d.done.Load() {
+		return ErrServerClosed
+	}
+	d.once.Do(func() {
+		close(d.finished)
+		d.done.Store(true)
+	})
+	return nil
+}
+
 // Shutdown gracefully shuts down the daemon.
 func (d *GitDaemon) Shutdown(ctx context.Context) error {
-	// in the case when git daemon was never started
-	if d.listener == nil {
-		return nil
+	if d.done.Load() {
+		return ErrServerClosed
 	}
 
-	d.once.Do(func() { close(d.finished) })
-	err := d.listener.Close()
+	err := d.closeListener()
 	finished := make(chan struct{}, 1)
 	go func() {
 		d.wg.Wait()

pkg/daemon/daemon_test.go 🔗

@@ -59,7 +59,7 @@ func TestMain(m *testing.M) {
 	}
 	testDaemon = d
 	go func() {
-		if err := d.Start(); err != ErrServerClosed {
+		if err := d.ListenAndServe(); err != ErrServerClosed {
 			log.Fatal(err)
 		}
 	}()
@@ -75,11 +75,21 @@ func TestMain(m *testing.M) {
 }
 
 func TestIdleTimeout(t *testing.T) {
-	c, err := net.Dial("tcp", testDaemon.addr)
-	if err != nil {
-		t.Fatal(err)
+	var err error
+	var c net.Conn
+	var tries int
+	for {
+		c, err = net.Dial("tcp", testDaemon.addr)
+		if err != nil && tries >= 3 {
+			t.Fatal(err)
+		}
+		tries++
+		if testDaemon.conns.Size() != 0 {
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
 	}
-	time.Sleep(time.Second)
+	time.Sleep(2 * time.Second)
 	_, err = readPktline(c)
 	if err == nil {
 		t.Errorf("expected error, got nil")