fix(server): git daemon

Ayman Bagabas created

use custom connection (copied from gliderlabs/ssh) to handle timeouts

Change summary

server/config/config.go          |   9 
server/db/fakedb/db.go           | 181 ++++++++++++++++++++++++++++++++++
server/git/daemon/conn.go        |  55 ++++++++++
server/git/daemon/daemon.go      | 170 ++++++++++++++++++-------------
server/git/daemon/daemon_test.go |  20 ++-
server/git/error.go              |   7 
server/server.go                 |   7 +
7 files changed, 362 insertions(+), 87 deletions(-)

Detailed changes

server/config/config.go 🔗

@@ -27,14 +27,15 @@ type SSHConfig struct {
 	AllowKeyless  bool   `env:"ALLOW_KEYLESS" envDefault:"true"`
 	AllowPassword bool   `env:"ALLOW_PASSWORD" envDefault:"false"`
 	Password      string `env:"PASSWORD"`
+	MaxTimeout    int    `env:"MAX_TIMEOUT" envDefault:"0"`
+	IdleTimeout   int    `env:"IDLE_TIMEOUT" envDefault:"300"`
 }
 
 // GitConfig is the Git protocol configuration for the server.
 type GitConfig struct {
-	Port       int `env:"PORT" envDefault:"9418"`
-	MaxTimeout int `env:"MAX_TIMEOUT" envDefault:"300"`
-	// MaxReadTimeout is the maximum time a client can take to send a request.
-	MaxReadTimeout int `env:"MAX_READ_TIMEOUT" envDefault:"3"`
+	Port           int `env:"PORT" envDefault:"9418"`
+	MaxTimeout     int `env:"MAX_TIMEOUT" envDefault:"0"`
+	IdleTimeout    int `env:"IDLE_TIMEOUT" envDefault:"3"`
 	MaxConnections int `env:"SOFT_SERVE_GIT_MAX_CONNECTIONS" envDefault:"32"`
 }
 

server/db/fakedb/db.go 🔗

@@ -0,0 +1,181 @@
+package fakedb
+
+import (
+	"github.com/charmbracelet/soft-serve/proto"
+	"github.com/charmbracelet/soft-serve/server/db"
+	"github.com/charmbracelet/soft-serve/server/db/types"
+)
+
+var _ db.Store = &FakeDB{}
+
+type FakeDB struct{}
+
+// Open implements db.Store
+func (*FakeDB) Open(name string) (proto.RepositoryService, error) {
+	return nil, nil
+}
+
+// GetConfig implements db.Store
+func (*FakeDB) GetConfig() (*types.Config, error) {
+	return nil, nil
+}
+
+// SetConfigAllowKeyless implements db.Store
+func (*FakeDB) SetConfigAllowKeyless(bool) error {
+	return nil
+}
+
+// SetConfigAnonAccess implements db.Store
+func (*FakeDB) SetConfigAnonAccess(string) error {
+	return nil
+}
+
+// SetConfigHost implements db.Store
+func (*FakeDB) SetConfigHost(string) error {
+	return nil
+}
+
+// SetConfigName implements db.Store
+func (*FakeDB) SetConfigName(string) error {
+	return nil
+}
+
+// SetConfigPort implements db.Store
+func (*FakeDB) SetConfigPort(int) error {
+	return nil
+}
+
+// AddUser implements db.Store
+func (*FakeDB) AddUser(name string, login string, email string, password string, isAdmin bool) error {
+	return nil
+}
+
+// CountUsers implements db.Store
+func (*FakeDB) CountUsers() (int, error) {
+	return 0, nil
+}
+
+// DeleteUser implements db.Store
+func (*FakeDB) DeleteUser(int) error {
+	return nil
+}
+
+// GetUser implements db.Store
+func (*FakeDB) GetUser(int) (*types.User, error) {
+	return nil, nil
+}
+
+// GetUserByEmail implements db.Store
+func (*FakeDB) GetUserByEmail(string) (*types.User, error) {
+	return nil, nil
+}
+
+// GetUserByLogin implements db.Store
+func (*FakeDB) GetUserByLogin(string) (*types.User, error) {
+	return nil, nil
+}
+
+// GetUserByPublicKey implements db.Store
+func (*FakeDB) GetUserByPublicKey(string) (*types.User, error) {
+	return nil, nil
+}
+
+// SetUserAdmin implements db.Store
+func (*FakeDB) SetUserAdmin(*types.User, bool) error {
+	return nil
+}
+
+// SetUserEmail implements db.Store
+func (*FakeDB) SetUserEmail(*types.User, string) error {
+	return nil
+}
+
+// SetUserLogin implements db.Store
+func (*FakeDB) SetUserLogin(*types.User, string) error {
+	return nil
+}
+
+// SetUserName implements db.Store
+func (*FakeDB) SetUserName(*types.User, string) error {
+	return nil
+}
+
+// SetUserPassword implements db.Store
+func (*FakeDB) SetUserPassword(*types.User, string) error {
+	return nil
+}
+
+// AddUserPublicKey implements db.Store
+func (*FakeDB) AddUserPublicKey(*types.User, string) error {
+	return nil
+}
+
+// DeleteUserPublicKey implements db.Store
+func (*FakeDB) DeleteUserPublicKey(int) error {
+	return nil
+}
+
+// GetUserPublicKeys implements db.Store
+func (*FakeDB) GetUserPublicKeys(*types.User) ([]*types.PublicKey, error) {
+	return nil, nil
+}
+
+// AddRepo implements db.Store
+func (*FakeDB) AddRepo(name string, projectName string, description string, isPrivate bool) error {
+	return nil
+}
+
+// DeleteRepo implements db.Store
+func (*FakeDB) DeleteRepo(string) error {
+	return nil
+}
+
+// GetRepo implements db.Store
+func (*FakeDB) GetRepo(string) (*types.Repo, error) {
+	return nil, nil
+}
+
+// SetRepoDescription implements db.Store
+func (*FakeDB) SetRepoDescription(string, string) error {
+	return nil
+}
+
+// SetRepoPrivate implements db.Store
+func (*FakeDB) SetRepoPrivate(string, bool) error {
+	return nil
+}
+
+// SetRepoProjectName implements db.Store
+func (*FakeDB) SetRepoProjectName(string, string) error {
+	return nil
+}
+
+// AddRepoCollab implements db.Store
+func (*FakeDB) AddRepoCollab(string, *types.User) error {
+	return nil
+}
+
+// DeleteRepoCollab implements db.Store
+func (*FakeDB) DeleteRepoCollab(int, int) error {
+	return nil
+}
+
+// ListRepoCollabs implements db.Store
+func (*FakeDB) ListRepoCollabs(string) ([]*types.User, error) {
+	return nil, nil
+}
+
+// ListRepoPublicKeys implements db.Store
+func (*FakeDB) ListRepoPublicKeys(string) ([]*types.PublicKey, error) {
+	return nil, nil
+}
+
+// Close implements db.Store
+func (*FakeDB) Close() error {
+	return nil
+}
+
+// CreateDB implements db.Store
+func (*FakeDB) CreateDB() error {
+	return nil
+}

server/git/daemon/conn.go 🔗

@@ -0,0 +1,55 @@
+package daemon
+
+import (
+	"context"
+	"net"
+	"time"
+)
+
+type serverConn struct {
+	net.Conn
+
+	idleTimeout   time.Duration
+	maxDeadline   time.Time
+	closeCanceler context.CancelFunc
+}
+
+func (c *serverConn) Write(p []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Write(p)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Read(b []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Read(b)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Close() (err error) {
+	err = c.Conn.Close()
+	if c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) updateDeadline() {
+	switch {
+	case c.idleTimeout > 0:
+		idleDeadline := time.Now().Add(c.idleTimeout)
+		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+			c.Conn.SetDeadline(idleDeadline)
+			return
+		}
+		fallthrough
+	default:
+		c.Conn.SetDeadline(c.maxDeadline)
+	}
+}

server/git/daemon/daemon.go 🔗

@@ -25,7 +25,7 @@ var ErrServerClosed = errors.New("git: Server closed")
 type Daemon struct {
 	listener net.Listener
 	addr     string
-	exit     chan struct{}
+	finished chan struct{}
 	conns    map[net.Conn]struct{}
 	cfg      *config.Config
 	wg       sync.WaitGroup
@@ -36,24 +36,25 @@ type Daemon struct {
 func NewDaemon(cfg *config.Config) (*Daemon, error) {
 	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Git.Port)
 	d := &Daemon{
-		addr:  addr,
-		exit:  make(chan struct{}),
-		cfg:   cfg,
-		conns: make(map[net.Conn]struct{}),
+		addr:     addr,
+		finished: make(chan struct{}),
+		cfg:      cfg,
+		conns:    make(map[net.Conn]struct{}),
 	}
 	listener, err := net.Listen("tcp", d.addr)
 	if err != nil {
 		return nil, err
 	}
 	d.listener = listener
-	d.wg.Add(1)
 	return d, nil
 }
 
 // Start starts the Git TCP daemon.
 func (d *Daemon) Start() error {
+	defer d.listener.Close()
 	// set up channel on which to send accepted connections
 	listen := make(chan net.Conn, d.cfg.Git.MaxConnections)
+	d.wg.Add(1)
 	go d.acceptConnection(d.listener, listen)
 
 	// loop work cycle with accept connections or interrupt
@@ -66,10 +67,7 @@ func (d *Daemon) Start() error {
 				d.handleClient(conn)
 				d.wg.Done()
 			}()
-		case <-d.exit:
-			if err := d.Close(); err != nil {
-				return err
-			}
+		case <-d.finished:
 			return ErrServerClosed
 		}
 	}
@@ -89,8 +87,8 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn)
 		conn, err := listener.Accept()
 		if err != nil {
 			select {
-			case <-d.exit:
-				log.Printf("git: listener closed")
+			case <-d.finished:
+				log.Printf("git: %s", ErrServerClosed)
 				return
 			default:
 				log.Printf("git: error accepting connection: %v", err)
@@ -102,7 +100,19 @@ func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn)
 }
 
 // handleClient handles a git protocol client.
-func (d *Daemon) handleClient(c net.Conn) {
+func (d *Daemon) handleClient(conn net.Conn) {
+	ctx, cancel := context.WithCancel(context.Background())
+	idleTimeout := time.Duration(d.cfg.Git.IdleTimeout) * time.Second
+	c := &serverConn{
+		Conn:          conn,
+		idleTimeout:   idleTimeout,
+		closeCanceler: cancel,
+	}
+	if d.cfg.Git.MaxTimeout > 0 {
+		dur := time.Duration(d.cfg.Git.MaxTimeout) * time.Second
+		c.maxDeadline = time.Now().Add(dur)
+	}
+	defer c.Close()
 	d.conns[c] = struct{}{}
 	defer delete(d.conns, c)
 
@@ -110,85 +120,99 @@ func (d *Daemon) handleClient(c net.Conn) {
 	if len(d.conns) >= d.cfg.Git.MaxConnections {
 		log.Printf("git: max connections reached, closing %s", c.RemoteAddr())
 		fatal(c, git.ErrMaxConns)
-		return
-	}
-
-	// Set connection timeout.
-	if err := c.SetDeadline(time.Now().Add(time.Duration(d.cfg.Git.MaxTimeout) * time.Second)); err != nil {
-		log.Printf("git: error setting deadline: %v", err)
-		fatal(c, git.ErrSystemMalfunction)
+		c.closeCanceler()
 		return
 	}
 
 	readc := make(chan struct{}, 1)
+	s := pktline.NewScanner(c)
 	go func() {
-		select {
-		case <-time.After(time.Duration(d.cfg.Git.MaxReadTimeout) * time.Second):
-			log.Printf("git: read timeout from %s", c.RemoteAddr())
-			fatal(c, git.ErrMaxTimeout)
-		case <-readc:
+		if !s.Scan() {
+			if err := s.Err(); err != nil {
+				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+					fatal(c, git.ErrTimeout)
+				} else {
+					log.Printf("git: error scanning pktline: %v", err)
+					fatal(c, git.ErrSystemMalfunction)
+				}
+			}
+			return
 		}
+		readc <- struct{}{}
 	}()
 
-	s := pktline.NewScanner(c)
-	if !s.Scan() {
-		if err := s.Err(); err != nil {
-			log.Printf("git: error scanning pktline: %v", err)
-			fatal(c, git.ErrSystemMalfunction)
+	select {
+	case <-ctx.Done():
+		if err := ctx.Err(); err != nil {
+			log.Printf("git: connection context error: %v", err)
 		}
 		return
-	}
-	readc <- struct{}{}
-
-	line := s.Bytes()
-	split := bytes.SplitN(line, []byte{' '}, 2)
-	if len(split) != 2 {
-		return
-	}
+	case <-readc:
+		line := s.Bytes()
+		split := bytes.SplitN(line, []byte{' '}, 2)
+		if len(split) != 2 {
+			fatal(c, git.ErrInvalidRequest)
+			return
+		}
 
-	var repo string
-	cmd := string(split[0])
-	opts := bytes.Split(split[1], []byte{'\x00'})
-	if len(opts) == 0 {
-		return
-	}
-	repo = filepath.Clean(string(opts[0]))
-
-	log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
-	defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
-	repo = strings.TrimPrefix(repo, "/")
-	auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil)
-	if auth < proto.ReadOnlyAccess {
-		fatal(c, git.ErrNotAuthed)
-		return
-	}
-	// git bare repositories should end in ".git"
-	// https://git-scm.com/docs/gitrepository-layout
-	if !strings.HasSuffix(repo, ".git") {
-		repo += ".git"
-	}
+		var repo string
+		cmd := string(split[0])
+		opts := bytes.Split(split[1], []byte{'\x00'})
+		if len(opts) == 0 {
+			fatal(c, git.ErrInvalidRequest)
+			return
+		}
+		repo = filepath.Clean(string(opts[0]))
+
+		log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
+		defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
+		repo = strings.TrimPrefix(repo, "/")
+		auth := d.cfg.AuthRepo(strings.TrimSuffix(repo, ".git"), nil)
+		if auth < proto.ReadOnlyAccess {
+			fatal(c, git.ErrNotAuthed)
+			return
+		}
+		// git bare repositories should end in ".git"
+		// https://git-scm.com/docs/gitrepository-layout
+		if !strings.HasSuffix(repo, ".git") {
+			repo += ".git"
+		}
 
-	err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo)
-	if err == git.ErrInvalidRepo {
-		trimmed := strings.TrimSuffix(repo, ".git")
-		log.Printf("git: invalid repo %q trying again %q", repo, trimmed)
-		err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed)
-	}
-	if err != nil {
-		fatal(c, err)
-		return
+		err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), repo)
+		if err == git.ErrInvalidRepo {
+			trimmed := strings.TrimSuffix(repo, ".git")
+			log.Printf("git: invalid repo %q trying again %q", repo, trimmed)
+			err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath(), trimmed)
+		}
+		if err != nil {
+			fatal(c, err)
+			return
+		}
 	}
 }
 
 // Close closes the underlying listener.
 func (d *Daemon) Close() error {
-	d.once.Do(func() { close(d.exit) })
+	d.once.Do(func() { close(d.finished) })
+	for c := range d.conns {
+		c.Close()
+		delete(d.conns, c)
+	}
 	return d.listener.Close()
 }
 
 // Shutdown gracefully shuts down the daemon.
-func (d *Daemon) Shutdown(_ context.Context) error {
-	d.once.Do(func() { close(d.exit) })
-	d.wg.Wait()
-	return nil
+func (d *Daemon) Shutdown(ctx context.Context) error {
+	d.once.Do(func() { close(d.finished) })
+	finished := make(chan struct{}, 1)
+	go func() {
+		d.wg.Wait()
+		finished <- struct{}{}
+	}()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-finished:
+		return d.listener.Close()
+	}
 }

server/git/daemon/daemon_test.go 🔗

@@ -8,7 +8,9 @@ import (
 	"os"
 	"testing"
 
+	"github.com/charmbracelet/soft-serve/proto"
 	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/db/fakedb"
 	"github.com/charmbracelet/soft-serve/server/git"
 	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 )
@@ -22,18 +24,20 @@ func TestMain(m *testing.M) {
 	}
 	defer os.RemoveAll(tmp)
 	cfg := &config.Config{
-		Host:     "",
-		DataPath: tmp,
+		Host:       "",
+		DataPath:   tmp,
+		AnonAccess: proto.ReadOnlyAccess,
 		Git: config.GitConfig{
+			// Reduce the max read timeout to 1 second so we can test the timeout.
+			IdleTimeout: 3,
 			// Reduce the max timeout to 100 second so we can test the timeout.
 			MaxTimeout: 100,
-			// Reduce the max read timeout to 1 second so we can test the timeout.
-			MaxReadTimeout: 1,
 			// Reduce the max connections to 3 so we can test the timeout.
 			MaxConnections: 3,
 			Port:           9418,
 		},
 	}
+	cfg = cfg.WithDB(&fakedb.FakeDB{})
 	d, err := NewDaemon(cfg)
 	if err != nil {
 		log.Fatal(err)
@@ -48,7 +52,7 @@ func TestMain(m *testing.M) {
 	os.Exit(m.Run())
 }
 
-func TestMaxReadTimeout(t *testing.T) {
+func TestIdleTimeout(t *testing.T) {
 	c, err := net.Dial("tcp", testDaemon.addr)
 	if err != nil {
 		t.Fatal(err)
@@ -57,8 +61,8 @@ func TestMaxReadTimeout(t *testing.T) {
 	if err != nil {
 		t.Fatalf("expected nil, got error: %v", err)
 	}
-	if out != git.ErrMaxTimeout.Error() {
-		t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
+	if out != git.ErrTimeout.Error() {
+		t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
 	}
 }
 
@@ -75,7 +79,7 @@ func TestInvalidRepo(t *testing.T) {
 		t.Fatalf("expected nil, got error: %v", err)
 	}
 	if out != git.ErrInvalidRepo.Error() {
-		t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
+		t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
 	}
 }
 

server/git/error.go 🔗

@@ -11,8 +11,11 @@ var ErrSystemMalfunction = errors.New("something went wrong")
 // ErrInvalidRepo represents an attempt to access a non-existent repo.
 var ErrInvalidRepo = errors.New("invalid repo")
 
+// ErrInvalidRequest represents an invalid request.
+var ErrInvalidRequest = errors.New("invalid request")
+
 // ErrMaxConns represents a maximum connection limit being reached.
 var ErrMaxConns = errors.New("too many connections, try again later")
 
-// ErrMaxTimeout is returned when the maximum read timeout is exceeded.
-var ErrMaxTimeout = errors.New("git: max timeout reached")
+// ErrTimeout is returned when the maximum read timeout is exceeded.
+var ErrTimeout = errors.New("I/O timeout reached")

server/server.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"log"
+	"time"
 
 	appCfg "github.com/charmbracelet/soft-serve/config"
 	cm "github.com/charmbracelet/soft-serve/server/cmd"
@@ -68,6 +69,12 @@ func NewServer(cfg *config.Config) *Server {
 	if err != nil {
 		log.Fatalln(err)
 	}
+	if cfg.SSH.MaxTimeout > 0 {
+		sh.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
+	}
+	if cfg.SSH.IdleTimeout > 0 {
+		sh.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
+	}
 	s.SSHServer = sh
 	d, err := daemon.NewDaemon(cfg)
 	if err != nil {