@@ -95,7 +95,7 @@ func (s *Server) Start() error {
errg, _ := errgroup.WithContext(s.ctx)
errg.Go(func() error {
s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
- if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) {
+ if err := s.GitDaemon.ListenAndServe(); !errors.Is(err, daemon.ErrServerClosed) {
return err
}
return nil
@@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/charmbracelet/log"
@@ -43,7 +44,6 @@ var ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
// GitDaemon represents a Git daemon.
type GitDaemon struct {
ctx context.Context
- listener net.Listener
addr string
finished chan struct{}
conns connections
@@ -52,6 +52,7 @@ type GitDaemon struct {
wg sync.WaitGroup
once sync.Once
logger *log.Logger
+ done atomic.Bool // indicates if the server has been closed
}
// NewDaemon returns a new Git daemon.
@@ -70,26 +71,31 @@ func NewGitDaemon(ctx context.Context) (*GitDaemon, error) {
return d, nil
}
-// Start starts the Git TCP daemon.
-func (d *GitDaemon) Start() error {
- // listen on the socket
- {
- listener, err := net.Listen("tcp", d.addr)
- if err != nil {
- return err
- }
- d.listener = listener
+// ListenAndServe starts the Git TCP daemon.
+func (d *GitDaemon) ListenAndServe() error {
+ if d.done.Load() {
+ return ErrServerClosed
+ }
+ listener, err := net.Listen("tcp", d.addr)
+ if err != nil {
+ return err
}
+ return d.Serve(listener)
+}
- // close eventual connections to the socket
- defer d.listener.Close() // nolint: errcheck
+// Serve listens on the TCP network address and serves Git requests.
+func (d *GitDaemon) Serve(listener net.Listener) error {
+ if d.done.Load() {
+ return ErrServerClosed
+ }
d.wg.Add(1)
defer d.wg.Done()
+ defer listener.Close() //nolint:errcheck
var tempDelay time.Duration
for {
- conn, err := d.listener.Accept()
+ conn, err := listener.Accept()
if err != nil {
select {
case <-d.finished:
@@ -305,21 +311,30 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
// Close closes the underlying listener.
func (d *GitDaemon) Close() error {
- d.once.Do(func() { close(d.finished) })
- err := d.listener.Close()
+ err := d.closeListener()
d.conns.CloseAll() // nolint: errcheck
return err
}
+// closeListener closes the listener and the finished channel.
+func (d *GitDaemon) closeListener() error {
+ if d.done.Load() {
+ return ErrServerClosed
+ }
+ d.once.Do(func() {
+ close(d.finished)
+ d.done.Store(true)
+ })
+ return nil
+}
+
// Shutdown gracefully shuts down the daemon.
func (d *GitDaemon) Shutdown(ctx context.Context) error {
- // in the case when git daemon was never started
- if d.listener == nil {
- return nil
+ if d.done.Load() {
+ return ErrServerClosed
}
- d.once.Do(func() { close(d.finished) })
- err := d.listener.Close()
+ err := d.closeListener()
finished := make(chan struct{}, 1)
go func() {
d.wg.Wait()
@@ -59,7 +59,7 @@ func TestMain(m *testing.M) {
}
testDaemon = d
go func() {
- if err := d.Start(); err != ErrServerClosed {
+ if err := d.ListenAndServe(); err != ErrServerClosed {
log.Fatal(err)
}
}()
@@ -75,11 +75,21 @@ func TestMain(m *testing.M) {
}
func TestIdleTimeout(t *testing.T) {
- c, err := net.Dial("tcp", testDaemon.addr)
- if err != nil {
- t.Fatal(err)
+ var err error
+ var c net.Conn
+ var tries int
+ for {
+ c, err = net.Dial("tcp", testDaemon.addr)
+ if err != nil && tries >= 3 {
+ t.Fatal(err)
+ }
+ tries++
+ if testDaemon.conns.Size() != 0 {
+ break
+ }
+ time.Sleep(10 * time.Millisecond)
}
- time.Sleep(time.Second)
+ time.Sleep(2 * time.Second)
_, err = readPktline(c)
if err == nil {
t.Errorf("expected error, got nil")