feat(server): git daemon

Ayman Bagabas created

Implement a TCP git daemon
Add tests

Reference: https://git-scm.com/book/en/v2/Git-on-the-Server-Git-Daemon

Change summary

cmd/soft/serve.go                |   1 
examples/setuid/main.go          |   2 
git/command.go                   |  11 +
go.mod                           |   1 
server/config/config.go          |  21 +-
server/git/auth.go               |  46 ++++++
server/git/daemon/daemon.go      | 194 ++++++++++++++++++++++++++
server/git/daemon/daemon_test.go |  87 +++++++++++
server/git/error.go              |  18 ++
server/git/pack.go               | 131 +++++++++++++++++
server/git/ssh.go                | 250 ----------------------------------
server/git/ssh/ssh.go            |  88 +++++++++++
server/git/ssh/ssh_test.go       |  23 +-
server/server.go                 |  67 ++++++--
14 files changed, 648 insertions(+), 292 deletions(-)

Detailed changes

cmd/soft/serve.go 🔗

@@ -36,7 +36,6 @@ var (
 			signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
 			<-done
 
-			log.Printf("Stopping SSH server on %s:%d", cfg.BindAddr, cfg.Port)
 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 			defer cancel()
 			if err := s.Shutdown(ctx); err != nil {

examples/setuid/main.go 🔗

@@ -52,7 +52,7 @@ func main() {
 
 	log.Printf("Starting SSH server on %s:%d", cfg.BindAddr, cfg.Port)
 	go func() {
-		if err := s.Serve(ls); err != nil {
+		if err := s.SSHServer.Serve(ls); err != nil {
 			log.Fatalln(err)
 		}
 	}()

git/command.go 🔗

@@ -0,0 +1,11 @@
+package git
+
+import "github.com/gogs/git-module"
+
+// RunInDirOptions are options for RunInDir.
+type RunInDirOptions = git.RunInDirOptions
+
+// NewCommand creates a new git command.
+func NewCommand(args ...string) *git.Command {
+	return git.NewCommand(args...)
+}

go.mod 🔗

@@ -31,6 +31,7 @@ require (
 	github.com/muesli/mango-cobra v1.2.0
 	github.com/muesli/roff v0.1.0
 	github.com/spf13/cobra v1.6.1
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
 	gopkg.in/yaml.v3 v3.0.1
 )
 

server/config/config.go 🔗

@@ -16,14 +16,19 @@ type Callbacks interface {
 
 // Config is the configuration for Soft Serve.
 type Config struct {
-	BindAddr         string   `env:"SOFT_SERVE_BIND_ADDRESS" envDefault:""`
-	Host             string   `env:"SOFT_SERVE_HOST" envDefault:"localhost"`
-	Port             int      `env:"SOFT_SERVE_PORT" envDefault:"23231"`
-	KeyPath          string   `env:"SOFT_SERVE_KEY_PATH"`
-	RepoPath         string   `env:"SOFT_SERVE_REPO_PATH" envDefault:".repos"`
-	InitialAdminKeys []string `env:"SOFT_SERVE_INITIAL_ADMIN_KEY" envSeparator:"\n"`
-	Callbacks        Callbacks
-	ErrorLog         *log.Logger
+	BindAddr      string `env:"SOFT_SERVE_BIND_ADDRESS" envDefault:""`
+	Host          string `env:"SOFT_SERVE_HOST" envDefault:"localhost"`
+	Port          int    `env:"SOFT_SERVE_PORT" envDefault:"23231"`
+	GitPort       int    `env:"SOFT_SERVE_GIT_PORT" envDefault:"9418"`
+	GitMaxTimeout int    `env:"SOFT_SERVE_GIT_MAX_TIMEOUT" envDefault:"300"`
+	// MaxReadTimeout is the maximum time a client can take to send a request.
+	GitMaxReadTimeout int      `env:"SOFT_SERVE_GIT_MAX_READ_TIMEOUT" envDefault:"3"`
+	GitMaxConnections int      `env:"SOFT_SERVE_GIT_MAX_CONNECTIONS" envDefault:"32"`
+	KeyPath           string   `env:"SOFT_SERVE_KEY_PATH"`
+	RepoPath          string   `env:"SOFT_SERVE_REPO_PATH" envDefault:".repos"`
+	InitialAdminKeys  []string `env:"SOFT_SERVE_INITIAL_ADMIN_KEY" envSeparator:"\n"`
+	Callbacks         Callbacks
+	ErrorLog          *log.Logger
 }
 
 // DefaultConfig returns a Config with the values populated with the defaults

server/git/auth.go 🔗

@@ -0,0 +1,46 @@
+package git
+
+import "github.com/gliderlabs/ssh"
+
+// AccessLevel is the level of access allowed to a repo.
+type AccessLevel int
+
+const (
+	// NoAccess does not allow access to the repo.
+	NoAccess AccessLevel = iota
+
+	// ReadOnlyAccess allows read-only access to the repo.
+	ReadOnlyAccess
+
+	// ReadWriteAccess allows read and write access to the repo.
+	ReadWriteAccess
+
+	// AdminAccess allows read, write, and admin access to the repo.
+	AdminAccess
+)
+
+// String implements the Stringer interface for AccessLevel.
+func (a AccessLevel) String() string {
+	switch a {
+	case NoAccess:
+		return "no-access"
+	case ReadOnlyAccess:
+		return "read-only"
+	case ReadWriteAccess:
+		return "read-write"
+	case AdminAccess:
+		return "admin-access"
+	default:
+		return ""
+	}
+}
+
+// Hooks is an interface that allows for custom authorization
+// implementations and post push/fetch notifications. Prior to git access,
+// AuthRepo will be called with the ssh.Session public key and the repo name.
+// Implementers return the appropriate AccessLevel.
+type Hooks interface {
+	AuthRepo(string, ssh.PublicKey) AccessLevel
+	Push(string, ssh.PublicKey)
+	Fetch(string, ssh.PublicKey)
+}

server/git/daemon/daemon.go 🔗

@@ -0,0 +1,194 @@
+package daemon
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"path/filepath"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/git"
+	"github.com/go-git/go-git/v5/plumbing/format/pktline"
+)
+
+// ErrServerClosed indicates that the server has been closed.
+var ErrServerClosed = errors.New("git: Server closed")
+
+// Daemon represents a Git daemon.
+type Daemon struct {
+	auth     git.Hooks
+	listener net.Listener
+	addr     string
+	exit     chan struct{}
+	conns    map[net.Conn]struct{}
+	cfg      *config.Config
+	wg       sync.WaitGroup
+}
+
+// NewDaemon returns a new Git daemon.
+func NewDaemon(cfg *config.Config, auth git.Hooks) (*Daemon, error) {
+	addr := fmt.Sprintf("%s:%d", cfg.BindAddr, cfg.GitPort)
+	d := &Daemon{
+		addr:  addr,
+		auth:  auth,
+		exit:  make(chan struct{}),
+		cfg:   cfg,
+		conns: make(map[net.Conn]struct{}),
+	}
+	listener, err := net.Listen("tcp", d.addr)
+	if err != nil {
+		return nil, err
+	}
+	d.listener = listener
+	d.wg.Add(1)
+	return d, nil
+}
+
+// Start starts the Git TCP daemon.
+func (d *Daemon) Start() error {
+	// set up channel on which to send accepted connections
+	listen := make(chan net.Conn, d.cfg.GitMaxConnections)
+	go d.acceptConnection(d.listener, listen)
+
+	// loop work cycle with accept connections or interrupt
+	// by system signal
+	for {
+		log.Printf("listener len %d cap %d", len(listen), cap(listen))
+		select {
+		case conn := <-listen:
+			d.wg.Add(1)
+			go func() {
+				d.handleClient(conn)
+				d.wg.Done()
+			}()
+		case <-d.exit:
+			if err := d.Close(); err != nil {
+				return err
+			}
+			return ErrServerClosed
+		}
+	}
+}
+
+func fatal(c net.Conn, err error) {
+	git.WritePktline(c, err)
+	if err := c.Close(); err != nil {
+		log.Printf("git: error closing connection: %v", err)
+	}
+}
+
+// acceptConnection accepts connections on the listener.
+func (d *Daemon) acceptConnection(listener net.Listener, listen chan<- net.Conn) {
+	defer d.wg.Done()
+	for {
+		conn, err := listener.Accept()
+		if err != nil {
+			select {
+			case <-d.exit:
+				log.Printf("git: listener closed")
+				return
+			default:
+				log.Printf("git: error accepting connection: %v", err)
+				continue
+			}
+		}
+		listen <- conn
+	}
+}
+
+// handleClient handles a git protocol client.
+func (d *Daemon) handleClient(c net.Conn) {
+	d.conns[c] = struct{}{}
+	defer delete(d.conns, c)
+
+	// Close connection if there are too many open connections.
+	if len(d.conns) >= d.cfg.GitMaxConnections {
+		log.Printf("git: max connections reached, closing %s", c.RemoteAddr())
+		fatal(c, git.ErrMaxConns)
+		return
+	}
+
+	// Set connection timeout.
+	if err := c.SetDeadline(time.Now().Add(time.Duration(d.cfg.GitMaxTimeout) * time.Second)); err != nil {
+		log.Printf("git: error setting deadline: %v", err)
+		fatal(c, git.ErrSystemMalfunction)
+		return
+	}
+
+	readc := make(chan struct{}, 1)
+	go func() {
+		select {
+		case <-time.After(time.Duration(d.cfg.GitMaxReadTimeout) * time.Second):
+			log.Printf("git: read timeout from %s", c.RemoteAddr())
+			fatal(c, git.ErrMaxTimeout)
+		case <-readc:
+		}
+	}()
+
+	s := pktline.NewScanner(c)
+	if !s.Scan() {
+		if err := s.Err(); err != nil {
+			log.Printf("git: error scanning pktline: %v", err)
+			fatal(c, git.ErrSystemMalfunction)
+		}
+		return
+	}
+	readc <- struct{}{}
+
+	line := s.Bytes()
+	split := bytes.SplitN(line, []byte{' '}, 2)
+	if len(split) != 2 {
+		return
+	}
+
+	var repo string
+	cmd := string(split[0])
+	opts := bytes.Split(split[1], []byte{'\x00'})
+	if len(opts) == 0 {
+		return
+	}
+	repo = filepath.Clean(string(opts[0]))
+
+	log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
+	defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
+	repo = strings.TrimPrefix(repo, "/")
+	auth := d.auth.AuthRepo(strings.TrimSuffix(repo, ".git"), nil)
+	if auth < git.ReadOnlyAccess {
+		fatal(c, git.ErrNotAuthed)
+		return
+	}
+	// git bare repositories should end in ".git"
+	// https://git-scm.com/docs/gitrepository-layout
+	if !strings.HasSuffix(repo, ".git") {
+		repo += ".git"
+	}
+
+	err := git.GitPack(c, c, c, cmd, d.cfg.RepoPath, repo)
+	if err == git.ErrInvalidRepo {
+		trimmed := strings.TrimSuffix(repo, ".git")
+		log.Printf("git: invalid repo %q trying again %q", repo, trimmed)
+		err = git.GitPack(c, c, c, cmd, d.cfg.RepoPath, trimmed)
+	}
+	if err != nil {
+		fatal(c, err)
+		return
+	}
+}
+
+// Close closes the underlying listener.
+func (d *Daemon) Close() error {
+	return d.listener.Close()
+}
+
+// Shutdown gracefully shuts down the daemon.
+func (d *Daemon) Shutdown(_ context.Context) error {
+	close(d.exit)
+	d.wg.Wait()
+	return nil
+}

server/git/daemon/daemon_test.go 🔗

@@ -0,0 +1,87 @@
+package daemon
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"log"
+	"net"
+	"os"
+	"testing"
+
+	appCfg "github.com/charmbracelet/soft-serve/config"
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/git"
+	"github.com/go-git/go-git/v5/plumbing/format/pktline"
+)
+
+var testDaemon *Daemon
+
+func TestMain(m *testing.M) {
+	cfg := config.DefaultConfig()
+	// Reduce the max connections to 3 so we can test the timeout.
+	cfg.GitMaxConnections = 3
+	// Reduce the max timeout to 100 second so we can test the timeout.
+	cfg.GitMaxTimeout = 100
+	// Reduce the max read timeout to 1 second so we can test the timeout.
+	cfg.GitMaxReadTimeout = 1
+	ac, err := appCfg.NewConfig(cfg)
+	if err != nil {
+		log.Fatal(err)
+	}
+	d, err := NewDaemon(cfg, ac)
+	if err != nil {
+		log.Fatal(err)
+	}
+	testDaemon = d
+	go func() {
+		if err := d.Start(); err != ErrServerClosed {
+			log.Fatal(err)
+		}
+	}()
+	defer d.Shutdown(context.Background())
+	os.Exit(m.Run())
+}
+
+func TestMaxReadTimeout(t *testing.T) {
+	c, err := net.Dial("tcp", testDaemon.addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	out, err := readPktline(c)
+	if err != nil {
+		t.Fatalf("expected nil, got error: %v", err)
+	}
+	if out != git.ErrMaxTimeout.Error() {
+		t.Fatalf("expected %q error, got nil", git.ErrMaxTimeout)
+	}
+}
+
+func TestInvalidRepo(t *testing.T) {
+	c, err := net.Dial("tcp", testDaemon.addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err := pktline.NewEncoder(c).EncodeString("git-upload-pack /test.git\x00"); err != nil {
+		t.Fatalf("expected nil, got error: %v", err)
+	}
+	out, err := readPktline(c)
+	if err != nil {
+		t.Fatalf("expected nil, got error: %v", err)
+	}
+	if out != git.ErrInvalidRepo.Error() {
+		t.Fatalf("expected %q error, got nil", git.ErrInvalidRepo)
+	}
+}
+
+func readPktline(c net.Conn) (string, error) {
+	buf, err := io.ReadAll(c)
+	if err != nil {
+		return "", err
+	}
+	pktout := pktline.NewScanner(bytes.NewReader(buf))
+	if !pktout.Scan() {
+		return "", pktout.Err()
+	}
+	return string(pktout.Bytes()), nil
+}

server/git/error.go 🔗

@@ -0,0 +1,18 @@
+package git
+
+import "errors"
+
+// ErrNotAuthed represents unauthorized access.
+var ErrNotAuthed = errors.New("you are not authorized to do this")
+
+// ErrSystemMalfunction represents a general system error returned to clients.
+var ErrSystemMalfunction = errors.New("something went wrong")
+
+// ErrInvalidRepo represents an attempt to access a non-existent repo.
+var ErrInvalidRepo = errors.New("invalid repo")
+
+// ErrMaxConns represents a maximum connection limit being reached.
+var ErrMaxConns = errors.New("too many connections, try again later")
+
+// ErrMaxTimeout is returned when the maximum read timeout is exceeded.
+var ErrMaxTimeout = errors.New("git: max timeout reached")

server/git/pack.go 🔗

@@ -0,0 +1,131 @@
+package git
+
+import (
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/charmbracelet/soft-serve/git"
+	"github.com/go-git/go-git/v5/plumbing/format/pktline"
+)
+
+// GitPack runs the git pack protocol against the provided repo.
+func GitPack(out io.Writer, in io.Reader, er io.Writer, gitCmd string, repoDir string, repo string) error {
+	cmd := strings.TrimPrefix(gitCmd, "git-")
+	rp := filepath.Join(repoDir, repo)
+	switch gitCmd {
+	case "git-upload-archive", "git-upload-pack":
+		exists, err := fileExists(rp)
+		if !exists {
+			return ErrInvalidRepo
+		}
+		if err != nil {
+			return err
+		}
+		return RunGit(out, in, er, "", cmd, rp)
+	case "git-receive-pack":
+		err := ensureRepo(repoDir, repo)
+		if err != nil {
+			return err
+		}
+		err = RunGit(out, in, er, "", cmd, rp)
+		if err != nil {
+			return err
+		}
+		err = ensureDefaultBranch(out, in, er, rp)
+		if err != nil {
+			return err
+		}
+		// Needed for git dumb http server
+		return RunGit(out, in, er, rp, "update-server-info")
+	default:
+		return fmt.Errorf("unknown git command: %s", gitCmd)
+	}
+}
+
+// RunGit runs a git command in the given repo.
+func RunGit(out io.Writer, in io.Reader, err io.Writer, dir string, args ...string) error {
+	c := git.NewCommand(args...)
+	return c.RunInDirWithOptions(dir, git.RunInDirOptions{
+		Stdout: out,
+		Stdin:  in,
+		Stderr: err,
+	})
+}
+
+// WritePktline encodes and writes a pktline to the given writer.
+func WritePktline(w io.Writer, v ...interface{}) {
+	msg := fmt.Sprint(v...)
+	pkt := pktline.NewEncoder(w)
+	if err := pkt.EncodeString(msg); err != nil {
+		log.Printf("git: error writing pkt-line message: %s", err)
+	}
+	if err := pkt.Flush(); err != nil {
+		log.Printf("git: error flushing pkt-line message: %s", err)
+	}
+}
+
+func fileExists(path string) (bool, error) {
+	_, err := os.Stat(path)
+	if err == nil {
+		return true, nil
+	}
+	if os.IsNotExist(err) {
+		return false, nil
+	}
+	return true, err
+}
+
+func ensureRepo(dir string, repo string) error {
+	exists, err := fileExists(dir)
+	if err != nil {
+		return err
+	}
+	if !exists {
+		err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
+		if err != nil {
+			return err
+		}
+	}
+	rp := filepath.Join(dir, repo)
+	exists, err = fileExists(rp)
+	if err != nil {
+		return err
+	}
+	if !exists {
+		_, err := git.Init(rp, true)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func ensureDefaultBranch(out io.Writer, in io.Reader, er io.Writer, repoPath string) error {
+	r, err := git.Open(repoPath)
+	if err != nil {
+		return err
+	}
+	brs, err := r.Branches()
+	if err != nil {
+		return err
+	}
+	if len(brs) == 0 {
+		return fmt.Errorf("no branches found")
+	}
+	// Rename the default branch to the first branch available
+	_, err = r.HEAD()
+	if err == git.ErrReferenceNotExist {
+		err = RunGit(out, in, er, repoPath, "branch", "-M", brs[0])
+		if err != nil {
+			return err
+		}
+	}
+	if err != nil && err != git.ErrReferenceNotExist {
+		return err
+	}
+	return nil
+}

server/git/ssh.go 🔗

@@ -1,250 +0,0 @@
-package git
-
-import (
-	"errors"
-	"fmt"
-	"log"
-	"os"
-	"path/filepath"
-	"strings"
-
-	"github.com/charmbracelet/soft-serve/git"
-	"github.com/charmbracelet/wish"
-	"github.com/gliderlabs/ssh"
-	g "github.com/gogs/git-module"
-)
-
-// ErrNotAuthed represents unauthorized access.
-var ErrNotAuthed = errors.New("you are not authorized to do this")
-
-// ErrSystemMalfunction represents a general system error returned to clients.
-var ErrSystemMalfunction = errors.New("something went wrong")
-
-// ErrInvalidRepo represents an attempt to access a non-existent repo.
-var ErrInvalidRepo = errors.New("invalid repo")
-
-// AccessLevel is the level of access allowed to a repo.
-type AccessLevel int
-
-const (
-	// NoAccess does not allow access to the repo.
-	NoAccess AccessLevel = iota
-
-	// ReadOnlyAccess allows read-only access to the repo.
-	ReadOnlyAccess
-
-	// ReadWriteAccess allows read and write access to the repo.
-	ReadWriteAccess
-
-	// AdminAccess allows read, write, and admin access to the repo.
-	AdminAccess
-)
-
-// String implements the Stringer interface for AccessLevel.
-func (a AccessLevel) String() string {
-	switch a {
-	case NoAccess:
-		return "no-access"
-	case ReadOnlyAccess:
-		return "read-only"
-	case ReadWriteAccess:
-		return "read-write"
-	case AdminAccess:
-		return "admin-access"
-	default:
-		return ""
-	}
-}
-
-// Hooks is an interface that allows for custom authorization
-// implementations and post push/fetch notifications. Prior to git access,
-// AuthRepo will be called with the ssh.Session public key and the repo name.
-// Implementers return the appropriate AccessLevel.
-type Hooks interface {
-	AuthRepo(string, ssh.PublicKey) AccessLevel
-	Push(string, ssh.PublicKey)
-	Fetch(string, ssh.PublicKey)
-}
-
-// Middleware adds Git server functionality to the ssh.Server. Repos are stored
-// in the specified repo directory. The provided Hooks implementation will be
-// checked for access on a per repo basis for a ssh.Session public key.
-// Hooks.Push and Hooks.Fetch will be called on successful completion of
-// their commands.
-func Middleware(repoDir string, gh Hooks) wish.Middleware {
-	return func(sh ssh.Handler) ssh.Handler {
-		return func(s ssh.Session) {
-			func() {
-				cmd := s.Command()
-				if len(cmd) == 2 && strings.HasPrefix(cmd[0], "git") {
-					gc := cmd[0]
-					// repo should be in the form of "repo.git"
-					repo := strings.TrimPrefix(cmd[1], "/")
-					repo = filepath.Clean(repo)
-					if strings.Contains(repo, "/") {
-						log.Printf("invalid repo: %s", repo)
-						Fatal(s, fmt.Errorf("%s: %s", ErrInvalidRepo, "user repos not supported"))
-						return
-					}
-					pk := s.PublicKey()
-					access := gh.AuthRepo(strings.TrimSuffix(repo, ".git"), pk)
-					// git bare repositories should end in ".git"
-					// https://git-scm.com/docs/gitrepository-layout
-					if !strings.HasSuffix(repo, ".git") {
-						repo += ".git"
-					}
-					switch gc {
-					case "git-receive-pack":
-						switch access {
-						case ReadWriteAccess, AdminAccess:
-							err := gitPack(s, gc, repoDir, repo)
-							if err != nil {
-								Fatal(s, ErrSystemMalfunction)
-							} else {
-								gh.Push(repo, pk)
-							}
-						default:
-							Fatal(s, ErrNotAuthed)
-						}
-						return
-					case "git-upload-archive", "git-upload-pack":
-						switch access {
-						case ReadOnlyAccess, ReadWriteAccess, AdminAccess:
-							// try to upload <repo>.git first, then <repo>
-							err := gitPack(s, gc, repoDir, repo)
-							if err != nil {
-								err = gitPack(s, gc, repoDir, strings.TrimSuffix(repo, ".git"))
-							}
-							switch err {
-							case ErrInvalidRepo:
-								Fatal(s, ErrInvalidRepo)
-							case nil:
-								gh.Fetch(repo, pk)
-							default:
-								log.Printf("unknown git error: %s", err)
-								Fatal(s, ErrSystemMalfunction)
-							}
-						default:
-							Fatal(s, ErrNotAuthed)
-						}
-						return
-					}
-				}
-			}()
-			sh(s)
-		}
-	}
-}
-
-func gitPack(s ssh.Session, gitCmd string, repoDir string, repo string) error {
-	cmd := strings.TrimPrefix(gitCmd, "git-")
-	rp := filepath.Join(repoDir, repo)
-	switch gitCmd {
-	case "git-upload-archive", "git-upload-pack":
-		exists, err := fileExists(rp)
-		if !exists {
-			return ErrInvalidRepo
-		}
-		if err != nil {
-			return err
-		}
-		return runGit(s, "", cmd, rp)
-	case "git-receive-pack":
-		err := ensureRepo(repoDir, repo)
-		if err != nil {
-			return err
-		}
-		err = runGit(s, "", cmd, rp)
-		if err != nil {
-			return err
-		}
-		err = ensureDefaultBranch(s, rp)
-		if err != nil {
-			return err
-		}
-		// Needed for git dumb http server
-		return runGit(s, rp, "update-server-info")
-	default:
-		return fmt.Errorf("unknown git command: %s", gitCmd)
-	}
-}
-
-func fileExists(path string) (bool, error) {
-	_, err := os.Stat(path)
-	if err == nil {
-		return true, nil
-	}
-	if os.IsNotExist(err) {
-		return false, nil
-	}
-	return true, err
-}
-
-// Fatal prints to the session's STDOUT as a git response and exit 1.
-func Fatal(s ssh.Session, v ...interface{}) {
-	msg := fmt.Sprint(v...)
-	// hex length includes 4 byte length prefix and ending newline
-	pktLine := fmt.Sprintf("%04x%s\n", len(msg)+5, msg)
-	_, _ = wish.WriteString(s, pktLine)
-	s.Exit(1) // nolint: errcheck
-}
-
-func ensureRepo(dir string, repo string) error {
-	exists, err := fileExists(dir)
-	if err != nil {
-		return err
-	}
-	if !exists {
-		err = os.MkdirAll(dir, os.ModeDir|os.FileMode(0700))
-		if err != nil {
-			return err
-		}
-	}
-	rp := filepath.Join(dir, repo)
-	exists, err = fileExists(rp)
-	if err != nil {
-		return err
-	}
-	if !exists {
-		_, err := git.Init(rp, true)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func runGit(s ssh.Session, dir string, args ...string) error {
-	c := g.NewCommand(args...)
-	return c.RunInDirWithOptions(dir, g.RunInDirOptions{
-		Stdout: s,
-		Stdin:  s,
-		Stderr: s.Stderr(),
-	})
-}
-
-func ensureDefaultBranch(s ssh.Session, repoPath string) error {
-	r, err := git.Open(repoPath)
-	if err != nil {
-		return err
-	}
-	brs, err := r.Branches()
-	if err != nil {
-		return err
-	}
-	if len(brs) == 0 {
-		return fmt.Errorf("no branches found")
-	}
-	// Rename the default branch to the first branch available
-	_, err = r.HEAD()
-	if err == git.ErrReferenceNotExist {
-		err = runGit(s, repoPath, "branch", "-M", brs[0])
-		if err != nil {
-			return err
-		}
-	}
-	if err != nil && err != git.ErrReferenceNotExist {
-		return err
-	}
-	return nil
-}

server/git/ssh/ssh.go 🔗

@@ -0,0 +1,88 @@
+package ssh
+
+import (
+	"fmt"
+	"log"
+	"path/filepath"
+	"strings"
+
+	"github.com/charmbracelet/soft-serve/server/git"
+	"github.com/charmbracelet/wish"
+	"github.com/gliderlabs/ssh"
+)
+
+// Middleware adds Git server functionality to the ssh.Server. Repos are stored
+// in the specified repo directory. The provided Hooks implementation will be
+// checked for access on a per repo basis for a ssh.Session public key.
+// Hooks.Push and Hooks.Fetch will be called on successful completion of
+// their commands.
+func Middleware(repoDir string, gh git.Hooks) wish.Middleware {
+	return func(sh ssh.Handler) ssh.Handler {
+		return func(s ssh.Session) {
+			func() {
+				cmd := s.Command()
+				if len(cmd) == 2 && strings.HasPrefix(cmd[0], "git") {
+					gc := cmd[0]
+					// repo should be in the form of "repo.git"
+					repo := strings.TrimPrefix(cmd[1], "/")
+					repo = filepath.Clean(repo)
+					if strings.Contains(repo, "/") {
+						log.Printf("invalid repo: %s", repo)
+						Fatal(s, fmt.Errorf("%s: %s", git.ErrInvalidRepo, "user repos not supported"))
+						return
+					}
+					pk := s.PublicKey()
+					access := gh.AuthRepo(strings.TrimSuffix(repo, ".git"), pk)
+					// git bare repositories should end in ".git"
+					// https://git-scm.com/docs/gitrepository-layout
+					if !strings.HasSuffix(repo, ".git") {
+						repo += ".git"
+					}
+					switch gc {
+					case "git-receive-pack":
+						switch access {
+						case git.ReadWriteAccess, git.AdminAccess:
+							err := git.GitPack(s, s, s.Stderr(), gc, repoDir, repo)
+							if err != nil {
+								Fatal(s, git.ErrSystemMalfunction)
+							} else {
+								gh.Push(repo, pk)
+							}
+						default:
+							Fatal(s, git.ErrNotAuthed)
+						}
+						return
+					case "git-upload-archive", "git-upload-pack":
+						switch access {
+						case git.ReadOnlyAccess, git.ReadWriteAccess, git.AdminAccess:
+							// try to upload <repo>.git first, then <repo>
+							err := git.GitPack(s, s, s.Stderr(), gc, repoDir, repo)
+							if err != nil {
+								err = git.GitPack(s, s, s.Stderr(), gc, repoDir, strings.TrimSuffix(repo, ".git"))
+							}
+							switch err {
+							case git.ErrInvalidRepo:
+								Fatal(s, git.ErrInvalidRepo)
+							case nil:
+								gh.Fetch(repo, pk)
+							default:
+								log.Printf("unknown git error: %s", err)
+								Fatal(s, git.ErrSystemMalfunction)
+							}
+						default:
+							Fatal(s, git.ErrNotAuthed)
+						}
+						return
+					}
+				}
+			}()
+			sh(s)
+		}
+	}
+}
+
+// Fatal prints to the session's STDOUT as a git response and exit 1.
+func Fatal(s ssh.Session, v ...interface{}) {
+	git.WritePktline(s, v...)
+	s.Exit(1) // nolint: errcheck
+}

server/git/ssh_test.go → server/git/ssh/ssh_test.go 🔗

@@ -1,4 +1,4 @@
-package git
+package ssh
 
 import (
 	"fmt"
@@ -11,6 +11,7 @@ import (
 	"testing"
 
 	"github.com/charmbracelet/keygen"
+	"github.com/charmbracelet/soft-serve/server/git"
 	"github.com/charmbracelet/wish"
 	"github.com/gliderlabs/ssh"
 )
@@ -27,13 +28,13 @@ func TestGitMiddleware(t *testing.T) {
 		pushes:  []action{},
 		fetches: []action{},
 		access: []accessDetails{
-			{pubkey, "repo1", AdminAccess},
-			{pubkey, "repo2", AdminAccess},
-			{pubkey, "repo3", AdminAccess},
-			{pubkey, "repo4", AdminAccess},
-			{pubkey, "repo5", NoAccess},
-			{pubkey, "repo6", ReadOnlyAccess},
-			{pubkey, "repo7", AdminAccess},
+			{pubkey, "repo1", git.AdminAccess},
+			{pubkey, "repo2", git.AdminAccess},
+			{pubkey, "repo3", git.AdminAccess},
+			{pubkey, "repo4", git.AdminAccess},
+			{pubkey, "repo5", git.NoAccess},
+			{pubkey, "repo6", git.ReadOnlyAccess},
+			{pubkey, "repo7", git.AdminAccess},
 		},
 	}
 	srv, err := wish.NewServer(
@@ -179,7 +180,7 @@ func createKeyPair(t *testing.T) (ssh.PublicKey, string) {
 type accessDetails struct {
 	key   ssh.PublicKey
 	repo  string
-	level AccessLevel
+	level git.AccessLevel
 }
 
 type action struct {
@@ -194,13 +195,13 @@ type testHooks struct {
 	access  []accessDetails
 }
 
-func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) AccessLevel {
+func (h *testHooks) AuthRepo(repo string, key ssh.PublicKey) git.AccessLevel {
 	for _, dets := range h.access {
 		if dets.repo == repo && ssh.KeysEqual(key, dets.key) {
 			return dets.level
 		}
 	}
-	return NoAccess
+	return git.NoAccess
 }
 
 func (h *testHooks) Push(repo string, key ssh.PublicKey) {

server/server.go 🔗

@@ -4,23 +4,25 @@ import (
 	"context"
 	"fmt"
 	"log"
-	"net"
 
 	appCfg "github.com/charmbracelet/soft-serve/config"
 	cm "github.com/charmbracelet/soft-serve/server/cmd"
 	"github.com/charmbracelet/soft-serve/server/config"
-	gm "github.com/charmbracelet/soft-serve/server/git"
+	"github.com/charmbracelet/soft-serve/server/git/daemon"
+	gm "github.com/charmbracelet/soft-serve/server/git/ssh"
 	"github.com/charmbracelet/wish"
 	bm "github.com/charmbracelet/wish/bubbletea"
 	lm "github.com/charmbracelet/wish/logging"
 	rm "github.com/charmbracelet/wish/recover"
 	"github.com/gliderlabs/ssh"
 	"github.com/muesli/termenv"
+	"golang.org/x/sync/errgroup"
 )
 
 // Server is the Soft Serve server.
 type Server struct {
 	SSHServer *ssh.Server
+	GitServer *daemon.Daemon
 	Config    *config.Config
 	config    *appCfg.Config
 }
@@ -58,40 +60,63 @@ func NewServer(cfg *config.Config) *Server {
 	if err != nil {
 		log.Fatalln(err)
 	}
+	d, err := daemon.NewDaemon(cfg, ac)
+	if err != nil {
+		log.Fatalln(err)
+	}
 	return &Server{
 		SSHServer: s,
+		GitServer: d,
 		Config:    cfg,
 		config:    ac,
 	}
 }
 
 // Reload reloads the server configuration.
-func (srv *Server) Reload() error {
-	return srv.config.Reload()
+func (s *Server) Reload() error {
+	return s.config.Reload()
 }
 
 // Start starts the SSH server.
-func (srv *Server) Start() error {
-	if err := srv.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed {
-		return err
-	}
-	return nil
-}
-
-// Serve serves the SSH server using the provided listener.
-func (srv *Server) Serve(l net.Listener) error {
-	if err := srv.SSHServer.Serve(l); err != ssh.ErrServerClosed {
-		return err
-	}
-	return nil
+func (s *Server) Start() error {
+	var errg errgroup.Group
+	errg.Go(func() error {
+		log.Printf("Starting Git server on %s:%d", s.Config.BindAddr, s.Config.GitPort)
+		if err := s.GitServer.Start(); err != daemon.ErrServerClosed {
+			return err
+		}
+		return nil
+	})
+	errg.Go(func() error {
+		log.Printf("Starting SSH server on %s:%d", s.Config.BindAddr, s.Config.Port)
+		if err := s.SSHServer.ListenAndServe(); err != ssh.ErrServerClosed {
+			return err
+		}
+		return nil
+	})
+	return errg.Wait()
 }
 
 // Shutdown lets the server gracefully shutdown.
-func (srv *Server) Shutdown(ctx context.Context) error {
-	return srv.SSHServer.Shutdown(ctx)
+func (s *Server) Shutdown(ctx context.Context) error {
+	var errg errgroup.Group
+	errg.Go(func() error {
+		return s.SSHServer.Shutdown(ctx)
+	})
+	errg.Go(func() error {
+		return s.GitServer.Shutdown(ctx)
+	})
+	return errg.Wait()
 }
 
 // Close closes the SSH server.
-func (srv *Server) Close() error {
-	return srv.SSHServer.Close()
+func (s *Server) Close() error {
+	var errg errgroup.Group
+	errg.Go(func() error {
+		return s.SSHServer.Close()
+	})
+	errg.Go(func() error {
+		return s.GitServer.Close()
+	})
+	return errg.Wait()
 }