Smart HTTP Git transport & partial clones (#291) (#332)

Ayman Bagabas created

* refactor: tidy up server git

use git services to implement handling git server commands
pass config to git as environment variables

* feat(git): enable partial clones

* feat(server): use smart http git backend

This implements the smart http git protocol which also supports
git-receive-pack service.

Change summary

cmd/soft/root.go                |   3 
git/repo.go                     |   7 
git/server.go                   |  18 +
git/utils.go                    |  23 +
internal/log/log.go             |   4 
server/backend/sqlite/hooks.go  |  24 -
server/backend/sqlite/sqlite.go |   7 
server/config/config.go         |  34 ++
server/daemon/conn.go           | 105 ++++++++
server/daemon/daemon.go         | 171 +++++-------
server/git/git.go               | 144 +---------
server/git/service.go           | 186 ++++++++++++++
server/ssh/ssh.go               |  51 ++-
server/web/git.go               | 459 +++++++++++++++++++++++++++++++++++
server/web/goget.go             |  94 +++++++
server/web/http.go              | 295 ---------------------
server/web/logging.go           |  84 ++++++
server/web/server.go            |  40 +++
18 files changed, 1,189 insertions(+), 560 deletions(-)

Detailed changes

cmd/soft/root.go 🔗

@@ -54,6 +54,9 @@ func init() {
 func main() {
 	logger := NewDefaultLogger()
 
+	// Set global logger
+	log.SetDefault(logger)
+
 	// Set the max number of processes to the number of CPUs
 	// This is useful when running soft serve in a container
 	if _, err := maxprocs.Set(maxprocs.Logger(logger.Debugf)); err != nil {

git/repo.go 🔗

@@ -200,13 +200,6 @@ func (r *Repository) CommitsByPage(ref *Reference, page, size int) (Commits, err
 	return commits, nil
 }
 
-// UpdateServerInfo updates the repository server info.
-func (r *Repository) UpdateServerInfo() error {
-	cmd := git.NewCommand("update-server-info")
-	_, err := cmd.RunInDir(r.Path)
-	return err
-}
-
 // Config returns the config value for the given key.
 func (r *Repository) Config(key string, opts ...ConfigOptions) (string, error) {
 	dir, err := gitDir(r.Repository)

git/server.go 🔗

@@ -0,0 +1,18 @@
+package git
+
+import (
+	"context"
+
+	"github.com/gogs/git-module"
+)
+
+// UpdateServerInfo updates the server info file for the given repo path.
+func UpdateServerInfo(ctx context.Context, path string) error {
+	if !isGitDir(path) {
+		return ErrNotAGitRepository
+	}
+
+	cmd := git.NewCommand("update-server-info").WithContext(ctx).WithTimeout(-1)
+	_, err := cmd.RunInDir(path)
+	return err
+}

git/utils.go 🔗

@@ -1,6 +1,7 @@
 package git
 
 import (
+	"os"
 	"path/filepath"
 
 	"github.com/gobwas/glob"
@@ -49,3 +50,25 @@ func LatestFile(repo *Repository, pattern string) (string, string, error) {
 	}
 	return "", "", ErrFileNotFound
 }
+
+// Returns true if path is a directory containing an `objects` directory and a
+// `HEAD` file.
+func isGitDir(path string) bool {
+	stat, err := os.Stat(filepath.Join(path, "objects"))
+	if err != nil {
+		return false
+	}
+	if !stat.IsDir() {
+		return false
+	}
+
+	stat, err = os.Stat(filepath.Join(path, "HEAD"))
+	if err != nil {
+		return false
+	}
+	if stat.IsDir() {
+		return false
+	}
+
+	return true
+}

internal/log/log.go 🔗

@@ -32,6 +32,10 @@ func NewDefaultLogger() *log.Logger {
 
 	if debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")); debug {
 		logger.SetLevel(log.DebugLevel)
+
+		if verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")); verbose {
+			logger.SetReportCaller(true)
+		}
 	}
 
 	logger.SetTimeFormat(cfg.Log.TimeFormat)

server/backend/sqlite/hooks.go 🔗

@@ -36,16 +36,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri
 
 	var wg sync.WaitGroup
 
-	// Update server info
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		if err := updateServerInfo(d, repo); err != nil {
-			d.logger.Error("error updating server-info", "repo", repo, "err", err)
-			return
-		}
-	}()
-
 	// Populate last-modified file.
 	wg.Add(1)
 	go func() {
@@ -59,20 +49,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri
 	wg.Wait()
 }
 
-func updateServerInfo(d *SqliteBackend, repo string) error {
-	rr, err := d.Repository(repo)
-	if err != nil {
-		return err
-	}
-
-	r, err := rr.Open()
-	if err != nil {
-		return err
-	}
-
-	return r.UpdateServerInfo()
-}
-
 func populateLastModified(d *SqliteBackend, repo string) error {
 	var rr *Repo
 	_rr, err := d.Repository(repo)

server/backend/sqlite/sqlite.go 🔗

@@ -151,17 +151,12 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
 			return err
 		}
 
-		rr, err := git.Init(rp, true)
+		_, err := git.Init(rp, true)
 		if err != nil {
 			d.logger.Debug("failed to create repository", "err", err)
 			return err
 		}
 
-		if err := rr.UpdateServerInfo(); err != nil {
-			d.logger.Debug("failed to update server info", "err", err)
-			return err
-		}
-
 		return nil
 	}); err != nil {
 		d.logger.Debug("failed to create repository in database", "err", err)

server/config/config.go 🔗

@@ -114,6 +114,40 @@ type Config struct {
 	Backend backend.Backend `yaml:"-"`
 }
 
+// Environ returns the config as a list of environment variables.
+func (c *Config) Environ() []string {
+	envs := []string{}
+	if c == nil {
+		return envs
+	}
+
+	// TODO: do this dynamically
+	envs = append(envs, []string{
+		fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name),
+		fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath),
+		fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")),
+		fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL),
+		fmt.Sprintf("SOFT_SERVE_SSH_KEY_PATH=%s", c.SSH.KeyPath),
+		fmt.Sprintf("SOFT_SERVE_SSH_CLIENT_KEY_PATH=%s", c.SSH.ClientKeyPath),
+		fmt.Sprintf("SOFT_SERVE_SSH_MAX_TIMEOUT=%d", c.SSH.MaxTimeout),
+		fmt.Sprintf("SOFT_SERVE_SSH_IDLE_TIMEOUT=%d", c.SSH.IdleTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_LISTEN_ADDR=%s", c.Git.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_GIT_MAX_TIMEOUT=%d", c.Git.MaxTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_IDLE_TIMEOUT=%d", c.Git.IdleTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_MAX_CONNECTIONS=%d", c.Git.MaxConnections),
+		fmt.Sprintf("SOFT_SERVE_HTTP_LISTEN_ADDR=%s", c.HTTP.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath),
+		fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath),
+		fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL),
+		fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format),
+		fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat),
+	}...)
+
+	return envs
+}
+
 func parseConfig(path string) (*Config, error) {
 	dataPath := filepath.Dir(path)
 	cfg := &Config{

server/daemon/conn.go 🔗

@@ -0,0 +1,105 @@
+package daemon
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+	"time"
+)
+
+// connections is a synchronizes access to to a net.Conn pool.
+type connections struct {
+	m  map[net.Conn]struct{}
+	mu sync.Mutex
+}
+
+func (m *connections) Add(c net.Conn) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.m[c] = struct{}{}
+}
+
+func (m *connections) Close(c net.Conn) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	err := c.Close()
+	delete(m.m, c)
+	return err
+}
+
+func (m *connections) Size() int {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return len(m.m)
+}
+
+func (m *connections) CloseAll() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	var err error
+	for c := range m.m {
+		err = errors.Join(err, c.Close())
+		delete(m.m, c)
+	}
+
+	return err
+}
+
+// serverConn is a wrapper around a net.Conn that closes the connection when
+// the one of the timeouts is reached.
+type serverConn struct {
+	net.Conn
+
+	initTimeout   time.Duration
+	idleTimeout   time.Duration
+	maxDeadline   time.Time
+	closeCanceler context.CancelFunc
+}
+
+var _ net.Conn = (*serverConn)(nil)
+
+func (c *serverConn) Write(p []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Write(p)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Read(b []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Read(b)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Close() (err error) {
+	err = c.Conn.Close()
+	if c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) updateDeadline() {
+	switch {
+	case c.initTimeout > 0:
+		initTimeout := time.Now().Add(c.initTimeout)
+		c.initTimeout = 0
+		if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+			c.Conn.SetDeadline(initTimeout)
+			return
+		}
+	case c.idleTimeout > 0:
+		idleDeadline := time.Now().Add(c.idleTimeout)
+		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+			c.Conn.SetDeadline(idleDeadline)
+			return
+		}
+	}
+	c.Conn.SetDeadline(c.maxDeadline)
+}

server/daemon/daemon.go 🔗

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"path/filepath"
+	"strings"
 	"sync"
 	"time"
 
@@ -41,40 +42,6 @@ var (
 	ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
 )
 
-// connections synchronizes access to to a net.Conn pool.
-type connections struct {
-	m  map[net.Conn]struct{}
-	mu sync.Mutex
-}
-
-func (m *connections) Add(c net.Conn) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.m[c] = struct{}{}
-}
-
-func (m *connections) Close(c net.Conn) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	_ = c.Close()
-	delete(m.m, c)
-}
-
-func (m *connections) Size() int {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	return len(m.m)
-}
-
-func (m *connections) CloseAll() {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	for c := range m.m {
-		_ = c.Close()
-		delete(m.m, c)
-	}
-}
-
 // GitDaemon represents a Git daemon.
 type GitDaemon struct {
 	ctx      context.Context
@@ -213,26 +180,53 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 			return
 		}
 
-		gitPack := git.UploadPack
-		counter := uploadPackGitCounter
-		cmd := string(split[0])
-		switch cmd {
-		case git.UploadPackBin:
-			gitPack = git.UploadPack
-		case git.UploadArchiveBin:
-			gitPack = git.UploadArchive
+		var handler git.ServiceHandler
+		var counter *prometheus.CounterVec
+		service := git.Service(split[0])
+		switch service {
+		case git.UploadPackService:
+			handler = git.UploadPack
+			counter = uploadPackGitCounter
+		case git.UploadArchiveService:
+			handler = git.UploadArchive
 			counter = uploadArchiveGitCounter
 		default:
 			d.fatal(c, git.ErrInvalidRequest)
 			return
 		}
 
-		opts := bytes.Split(split[1], []byte{'\x00'})
-		if len(opts) == 0 {
-			d.fatal(c, git.ErrInvalidRequest)
+		opts := bytes.SplitN(split[1], []byte{0}, 3)
+		if len(opts) < 2 {
+			d.fatal(c, git.ErrInvalidRequest) // nolint: errcheck
 			return
 		}
 
+		host := strings.TrimPrefix(string(opts[1]), "host=")
+		extraParams := map[string]string{}
+
+		if len(opts) > 2 {
+			buf := bytes.TrimPrefix(opts[2], []byte{0})
+			for _, o := range bytes.Split(buf, []byte{0}) {
+				opt := string(o)
+				if opt == "" {
+					continue
+				}
+
+				kv := strings.SplitN(opt, "=", 2)
+				if len(kv) != 2 {
+					d.logger.Errorf("git: invalid option %q", opt)
+					continue
+				}
+
+				extraParams[kv[0]] = kv[1]
+			}
+
+			version := extraParams["version"]
+			if version != "" {
+				d.logger.Debugf("git: protocol version %s", version)
+			}
+		}
+
 		be := d.be.WithContext(ctx)
 		if !be.AllowKeyless() {
 			d.fatal(c, git.ErrNotAuthed)
@@ -240,14 +234,21 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 		}
 
 		name := utils.SanitizeRepo(string(opts[0]))
-		d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), cmd, name)
-		defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, name)
+		d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), service, name)
+		defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), service, name)
+
 		// git bare repositories should end in ".git"
 		// https://git-scm.com/docs/gitrepository-layout
 		repo := name + ".git"
 		reposDir := filepath.Join(d.cfg.DataPath, "repos")
 		if err := git.EnsureWithin(reposDir, repo); err != nil {
-			d.fatal(c, err)
+			d.logger.Debugf("git: error ensuring repo path: %v", err)
+			d.fatal(c, git.ErrInvalidRepo)
+			return
+		}
+
+		if _, err := d.be.Repository(repo); err != nil {
+			d.fatal(c, git.ErrInvalidRepo)
 			return
 		}
 
@@ -261,9 +262,33 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 		envs := []string{
 			"SOFT_SERVE_REPO_NAME=" + name,
 			"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
+			"SOFT_SERVE_HOST=" + host,
 		}
 
-		if err := gitPack(ctx, c, c, c, filepath.Join(reposDir, repo), envs...); err != nil {
+		// Add git protocol environment variable.
+		if len(extraParams) > 0 {
+			var gitProto string
+			for k, v := range extraParams {
+				if len(gitProto) > 0 {
+					gitProto += ":"
+				}
+				gitProto += k + "=" + v
+			}
+			envs = append(envs, "GIT_PROTOCOL="+gitProto)
+		}
+
+		envs = append(envs, d.cfg.Environ()...)
+
+		cmd := git.ServiceCommand{
+			Stdin:  c,
+			Stdout: c,
+			Stderr: c,
+			Env:    envs,
+			Dir:    filepath.Join(reposDir, repo),
+		}
+
+		if err := handler(ctx, cmd); err != nil {
+			d.logger.Debugf("git: error handling request: %v", err)
 			d.fatal(c, err)
 			return
 		}
@@ -296,51 +321,3 @@ func (d *GitDaemon) Shutdown(ctx context.Context) error {
 		return err
 	}
 }
-
-type serverConn struct {
-	net.Conn
-
-	idleTimeout   time.Duration
-	maxDeadline   time.Time
-	closeCanceler context.CancelFunc
-}
-
-func (c *serverConn) Write(p []byte) (n int, err error) {
-	c.updateDeadline()
-	n, err = c.Conn.Write(p)
-	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) Read(b []byte) (n int, err error) {
-	c.updateDeadline()
-	n, err = c.Conn.Read(b)
-	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) Close() (err error) {
-	err = c.Conn.Close()
-	if c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) updateDeadline() {
-	switch {
-	case c.idleTimeout > 0:
-		idleDeadline := time.Now().Add(c.idleTimeout)
-		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
-			c.Conn.SetDeadline(idleDeadline)
-			return
-		}
-		fallthrough
-	default:
-		c.Conn.SetDeadline(c.maxDeadline)
-	}
-}

server/git/git.go 🔗

@@ -5,16 +5,12 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"os"
-	"os/exec"
 	"path/filepath"
 	"strings"
 
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/git"
-	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/go-git/go-git/v5/plumbing/format/pktline"
-	"golang.org/x/sync/errgroup"
 )
 
 var (
@@ -38,112 +34,6 @@ var (
 	ErrTimeout = errors.New("I/O timeout reached")
 )
 
-// Git protocol commands.
-const (
-	ReceivePackBin   = "git-receive-pack"
-	UploadPackBin    = "git-upload-pack"
-	UploadArchiveBin = "git-upload-archive"
-)
-
-// UploadPack runs the git upload-pack protocol against the provided repo.
-func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	exists, err := fileExists(repoDir)
-	if !exists {
-		return ErrInvalidRepo
-	}
-	if err != nil {
-		return err
-	}
-	return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir)
-}
-
-// UploadArchive runs the git upload-archive protocol against the provided repo.
-func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	exists, err := fileExists(repoDir)
-	if !exists {
-		return ErrInvalidRepo
-	}
-	if err != nil {
-		return err
-	}
-	return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir)
-}
-
-// ReceivePack runs the git receive-pack protocol against the provided repo.
-func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil {
-		return err
-	}
-	return EnsureDefaultBranch(ctx, in, out, er, repoDir)
-}
-
-// RunGit runs a git command in the given repo.
-func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error {
-	cfg := config.FromContext(ctx)
-	logger := log.FromContext(ctx).WithPrefix("rungit")
-	c := exec.CommandContext(ctx, "git", args...)
-	c.Dir = dir
-	c.Env = append(os.Environ(), envs...)
-	c.Env = append(c.Env, "PATH="+os.Getenv("PATH"))
-	c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG"))
-	if cfg != nil {
-		c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.Log.Format)
-		c.Env = append(c.Env, "SOFT_SERVE_LOG_TIME_FORMAT="+cfg.Log.TimeFormat)
-	}
-
-	stdin, err := c.StdinPipe()
-	if err != nil {
-		logger.Error("failed to get stdin pipe", "err", err)
-		return err
-	}
-
-	stdout, err := c.StdoutPipe()
-	if err != nil {
-		logger.Error("failed to get stdout pipe", "err", err)
-		return err
-	}
-
-	stderr, err := c.StderrPipe()
-	if err != nil {
-		logger.Error("failed to get stderr pipe", "err", err)
-		return err
-	}
-
-	if err := c.Start(); err != nil {
-		logger.Error("failed to start command", "err", err)
-		return err
-	}
-
-	errg, ctx := errgroup.WithContext(ctx)
-
-	// stdin
-	errg.Go(func() error {
-		defer stdin.Close()
-
-		_, err := io.Copy(stdin, in)
-		return err
-	})
-
-	// stdout
-	errg.Go(func() error {
-		_, err := io.Copy(out, stdout)
-		return err
-	})
-
-	// stderr
-	errg.Go(func() error {
-		_, err := io.Copy(er, stderr)
-		return err
-	})
-
-	if err := errg.Wait(); err != nil {
-		logger.Error("while copying output", "err", err)
-	}
-
-	// Wait for the command to finish
-	return c.Wait()
-}
-
 // WritePktline encodes and writes a pktline to the given writer.
 func WritePktline(w io.Writer, v ...interface{}) {
 	msg := fmt.Sprintln(v...)
@@ -179,19 +69,10 @@ func EnsureWithin(reposDir string, repo string) error {
 	return nil
 }
 
-func fileExists(path string) (bool, error) {
-	_, err := os.Stat(path)
-	if err == nil {
-		return true, nil
-	}
-	if os.IsNotExist(err) {
-		return false, nil
-	}
-	return true, err
-}
-
-func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
-	r, err := git.Open(repoPath)
+// EnsureDefaultBranch ensures the repo has a default branch.
+// It will prefer choosing "main" or "master" if available.
+func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
+	r, err := git.Open(scmd.Dir)
 	if err != nil {
 		return err
 	}
@@ -205,8 +86,21 @@ func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io
 	// Rename the default branch to the first branch available
 	_, err = r.HEAD()
 	if err == git.ErrReferenceNotExist {
-		err = RunGit(ctx, in, out, er, repoPath, []string{}, "branch", "-M", brs[0])
-		if err != nil {
+		branch := brs[0]
+		// Prefer "main" or "master" as the default branch
+		for _, b := range brs {
+			if b == "main" || b == "master" {
+				branch = b
+				break
+			}
+		}
+
+		cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx)
+		if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{
+			Stdin:  scmd.Stdin,
+			Stdout: scmd.Stdout,
+			Stderr: scmd.Stderr,
+		}); err != nil {
 			return err
 		}
 	}

server/git/service.go 🔗

@@ -0,0 +1,186 @@
+package git
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"strings"
+
+	"github.com/charmbracelet/log"
+	"golang.org/x/sync/errgroup"
+)
+
+// Service is a Git daemon service.
+type Service string
+
+const (
+	// UploadPackService is the upload-pack service.
+	UploadPackService Service = "git-upload-pack"
+	// UploadArchiveService is the upload-archive service.
+	UploadArchiveService Service = "git-upload-archive"
+	// ReceivePackService is the receive-pack service.
+	ReceivePackService Service = "git-receive-pack"
+)
+
+// String returns the string representation of the service.
+func (s Service) String() string {
+	return string(s)
+}
+
+// Name returns the name of the service.
+func (s Service) Name() string {
+	return strings.TrimPrefix(s.String(), "git-")
+}
+
+// Handler is the service handler.
+func (s Service) Handler(ctx context.Context, cmd ServiceCommand) error {
+	switch s {
+	case UploadPackService, UploadArchiveService, ReceivePackService:
+		return gitServiceHandler(ctx, s, cmd)
+	default:
+		return fmt.Errorf("unsupported service: %s", s)
+	}
+}
+
+// ServiceHandler is a git service command handler.
+type ServiceHandler func(ctx context.Context, cmd ServiceCommand) error
+
+// gitServiceHandler is the default service handler using the git binary.
+func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) error {
+	cmd := exec.CommandContext(ctx, "git", "-c", "uploadpack.allowFilter=true", svc.Name()) // nolint: gosec
+	cmd.Dir = scmd.Dir
+	if len(scmd.Args) > 0 {
+		cmd.Args = append(cmd.Args, scmd.Args...)
+	}
+
+	cmd.Args = append(cmd.Args, ".")
+
+	cmd.Env = os.Environ()
+	if len(scmd.Env) > 0 {
+		cmd.Env = append(cmd.Env, scmd.Env...)
+	}
+
+	if scmd.CmdFunc != nil {
+		scmd.CmdFunc(cmd)
+	}
+
+	var (
+		err    error
+		stdin  io.WriteCloser
+		stdout io.ReadCloser
+		stderr io.ReadCloser
+	)
+
+	if scmd.Stdin != nil {
+		stdin, err = cmd.StdinPipe()
+		if err != nil {
+			return err
+		}
+	}
+
+	if scmd.Stdout != nil {
+		stdout, err = cmd.StdoutPipe()
+		if err != nil {
+			return err
+		}
+	}
+
+	if scmd.Stderr != nil {
+		stderr, err = cmd.StderrPipe()
+		if err != nil {
+			return err
+		}
+	}
+
+	log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String())
+	if err := cmd.Start(); err != nil {
+		return err
+	}
+
+	errg, ctx := errgroup.WithContext(ctx)
+
+	// stdin
+	if scmd.Stdin != nil {
+		errg.Go(func() error {
+			if scmd.StdinHandler != nil {
+				return scmd.StdinHandler(scmd.Stdin, stdin)
+			} else {
+				return defaultStdinHandler(scmd.Stdin, stdin)
+			}
+		})
+	}
+
+	// stdout
+	if scmd.Stdout != nil {
+		errg.Go(func() error {
+			if scmd.StdoutHandler != nil {
+				return scmd.StdoutHandler(scmd.Stdout, stdout)
+			} else {
+				return defaultStdoutHandler(scmd.Stdout, stdout)
+			}
+		})
+	}
+
+	// stderr
+	if scmd.Stderr != nil {
+		errg.Go(func() error {
+			if scmd.StderrHandler != nil {
+				return scmd.StderrHandler(scmd.Stderr, stderr)
+			} else {
+				return defaultStderrHandler(scmd.Stderr, stderr)
+			}
+		})
+	}
+
+	return errors.Join(errg.Wait(), cmd.Wait())
+}
+
+// ServiceCommand is used to run a git service command.
+type ServiceCommand struct {
+	Stdin  io.Reader
+	Stdout io.Writer
+	Stderr io.Writer
+	Dir    string
+	Env    []string
+	Args   []string
+
+	// Modifier functions
+	CmdFunc       func(*exec.Cmd)
+	StdinHandler  func(io.Reader, io.WriteCloser) error
+	StdoutHandler func(io.Writer, io.ReadCloser) error
+	StderrHandler func(io.Writer, io.ReadCloser) error
+}
+
+func defaultStdinHandler(in io.Reader, stdin io.WriteCloser) error {
+	defer stdin.Close() // nolint: errcheck
+	_, err := io.Copy(stdin, in)
+	return err
+}
+
+func defaultStdoutHandler(out io.Writer, stdout io.ReadCloser) error {
+	_, err := io.Copy(out, stdout)
+	return err
+}
+
+func defaultStderrHandler(err io.Writer, stderr io.ReadCloser) error {
+	_, erro := io.Copy(err, stderr)
+	return erro
+}
+
+// UploadPack runs the git upload-pack protocol against the provided repo.
+func UploadPack(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, UploadPackService, cmd)
+}
+
+// UploadArchive runs the git upload-archive protocol against the provided repo.
+func UploadArchive(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, UploadArchiveService, cmd)
+}
+
+// ReceivePack runs the git receive-pack protocol against the provided repo.
+func ReceivePack(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, ReceivePackService, cmd)
+}

server/ssh/ssh.go 🔗

@@ -216,13 +216,13 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 		return func(s ssh.Session) {
 			func() {
 				start := time.Now()
-				cmd := s.Command()
+				cmdLine := s.Command()
 				ctx := s.Context()
 				be := ss.be.WithContext(ctx)
-				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
-					gc := cmd[0]
+
+				if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
 					// repo should be in the form of "repo.git"
-					name := utils.SanitizeRepo(cmd[1])
+					name := utils.SanitizeRepo(cmdLine[1])
 					pk := s.PublicKey()
 					ak := backend.MarshalAuthorizedKey(pk)
 					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
@@ -240,12 +240,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 						"SOFT_SERVE_REPO_NAME=" + name,
 						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
 						"SOFT_SERVE_PUBLIC_KEY=" + ak,
+						"SOFT_SERVE_USERNAME=" + ctx.User(),
 					}
 
-					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
+					// Add ssh session & config environ
+					envs = append(envs, s.Environ()...)
+					envs = append(envs, cfg.Environ()...)
+
 					repoDir := filepath.Join(reposDir, repo)
-					switch gc {
-					case git.ReceivePackBin:
+					service := git.Service(cmdLine[0])
+					cmd := git.ServiceCommand{
+						Stdin:  s,
+						Stdout: s,
+						Stderr: s.Stderr(),
+						Env:    envs,
+						Dir:    repoDir,
+					}
+
+					ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
+
+					switch service {
+					case git.ReceivePackService:
 						receivePackCounter.WithLabelValues(name).Inc()
 						defer func() {
 							receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
@@ -262,20 +277,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 							}
 							createRepoCounter.WithLabelValues(name).Inc()
 						}
-						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
+
+						if err := git.ReceivePack(ctx, cmd); err != nil {
+							sshFatal(s, git.ErrSystemMalfunction)
+						}
+
+						if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
 							sshFatal(s, git.ErrSystemMalfunction)
 						}
+
+						receivePackCounter.WithLabelValues(name).Inc()
 						return
-					case git.UploadPackBin, git.UploadArchiveBin:
+					case git.UploadPackService, git.UploadArchiveService:
 						if access < backend.ReadOnlyAccess {
 							sshFatal(s, git.ErrNotAuthed)
 							return
 						}
 
-						gitPack := git.UploadPack
-						switch gc {
-						case git.UploadArchiveBin:
-							gitPack = git.UploadArchive
+						handler := git.UploadPack
+						switch service {
+						case git.UploadArchiveService:
+							handler = git.UploadArchive
 							uploadArchiveCounter.WithLabelValues(name).Inc()
 							defer func() {
 								uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
@@ -285,10 +307,9 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 							defer func() {
 								uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
 							}()
-
 						}
 
-						err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
+						err := handler(ctx, cmd)
 						if errors.Is(err, git.ErrInvalidRepo) {
 							sshFatal(s, git.ErrInvalidRepo)
 						} else if err != nil {

server/web/git.go 🔗

@@ -0,0 +1,459 @@
+package web
+
+import (
+	"bytes"
+	"compress/gzip"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/charmbracelet/log"
+	gitb "github.com/charmbracelet/soft-serve/git"
+	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/git"
+	"github.com/charmbracelet/soft-serve/server/utils"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+	"goji.io/pat"
+	"goji.io/pattern"
+)
+
+// GitRoute is a route for git services.
+type GitRoute struct {
+	method  string
+	pattern *regexp.Regexp
+	handler http.HandlerFunc
+
+	cfg    *config.Config
+	be     backend.Backend
+	logger *log.Logger
+}
+
+var _ Route = GitRoute{}
+
+// Match implements goji.Pattern.
+func (g GitRoute) Match(r *http.Request) *http.Request {
+	if g.method != r.Method {
+		return nil
+	}
+
+	re := g.pattern
+	ctx := r.Context()
+	if m := re.FindStringSubmatch(r.URL.Path); m != nil {
+		file := strings.Replace(r.URL.Path, m[1]+"/", "", 1)
+		repo := utils.SanitizeRepo(m[1]) + ".git"
+
+		var service git.Service
+		switch {
+		case strings.HasSuffix(r.URL.Path, git.UploadPackService.String()):
+			service = git.UploadPackService
+		case strings.HasSuffix(r.URL.Path, git.ReceivePackService.String()):
+			service = git.ReceivePackService
+		}
+
+		ctx = context.WithValue(ctx, pattern.Variable("service"), service.String())
+		ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(g.cfg.DataPath, "repos", repo))
+		ctx = context.WithValue(ctx, pattern.Variable("repo"), repo)
+		ctx = context.WithValue(ctx, pattern.Variable("file"), file)
+
+		if g.cfg != nil {
+			ctx = config.WithContext(ctx, g.cfg)
+		}
+
+		if g.be != nil {
+			ctx = backend.WithContext(ctx, g.be.WithContext(ctx))
+		}
+
+		if g.logger != nil {
+			ctx = log.WithContext(ctx, g.logger)
+		}
+
+		return r.WithContext(ctx)
+	}
+
+	return nil
+}
+
+// ServeHTTP implements http.Handler.
+func (g GitRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	g.handler(w, r)
+}
+
+var (
+	gitHttpReceiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: "soft_serve",
+		Subsystem: "http",
+		Name:      "git_receive_pack_total",
+		Help:      "The total number of git push requests",
+	}, []string{"repo"})
+
+	gitHttpUploadCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+		Namespace: "soft_serve",
+		Subsystem: "http",
+		Name:      "git_upload_pack_total",
+		Help:      "The total number of git fetch/pull requests",
+	}, []string{"repo", "file"})
+)
+
+func gitRoutes(ctx context.Context, logger *log.Logger) []Route {
+	routes := make([]Route, 0)
+	cfg := config.FromContext(ctx)
+	be := backend.FromContext(ctx)
+
+	// Git services
+	// These routes don't handle authentication/authorization.
+	// This is handled through wrapping the handlers for each route.
+	// See below (withAccess).
+	// TODO: add lfs support
+	for _, route := range []GitRoute{
+		{
+			pattern: regexp.MustCompile("(.*?)/git-upload-pack$"),
+			method:  http.MethodPost,
+			handler: serviceRpc,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/git-receive-pack$"),
+			method:  http.MethodPost,
+			handler: serviceRpc,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/info/refs$"),
+			method:  http.MethodGet,
+			handler: getInfoRefs,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/HEAD$"),
+			method:  http.MethodGet,
+			handler: getTextFile,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/info/alternates$"),
+			method:  http.MethodGet,
+			handler: getTextFile,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/info/http-alternates$"),
+			method:  http.MethodGet,
+			handler: getTextFile,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/info/packs$"),
+			method:  http.MethodGet,
+			handler: getInfoPacks,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/info/[^/]*$"),
+			method:  http.MethodGet,
+			handler: getTextFile,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/[0-9a-f]{2}/[0-9a-f]{38}$"),
+			method:  http.MethodGet,
+			handler: getLooseObject,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.pack$"),
+			method:  http.MethodGet,
+			handler: getPackFile,
+		},
+		{
+			pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$"),
+			method:  http.MethodGet,
+			handler: getIdxFile,
+		},
+	} {
+		route.cfg = cfg
+		route.be = be
+		route.logger = logger
+		route.handler = withAccess(route.handler)
+		routes = append(routes, route)
+	}
+
+	return routes
+}
+
+// withAccess handles auth.
+func withAccess(fn http.HandlerFunc) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		ctx := r.Context()
+		be := backend.FromContext(ctx)
+		logger := log.FromContext(ctx)
+
+		if !be.AllowKeyless() {
+			renderForbidden(w)
+			return
+		}
+
+		repo := pat.Param(r, "repo")
+		service := git.Service(pat.Param(r, "service"))
+		access := be.AccessLevel(repo, "")
+
+		switch service {
+		case git.ReceivePackService:
+			if access < backend.ReadWriteAccess {
+				renderUnauthorized(w)
+				return
+			}
+
+			// Create the repo if it doesn't exist.
+			if _, err := be.Repository(repo); err != nil {
+				if _, err := be.CreateRepository(repo, backend.RepositoryOptions{}); err != nil {
+					logger.Error("failed to create repository", "repo", repo, "err", err)
+					renderInternalServerError(w)
+					return
+				}
+			}
+		default:
+			if access < backend.ReadOnlyAccess {
+				renderUnauthorized(w)
+				return
+			}
+		}
+
+		fn(w, r)
+	}
+}
+
+func serviceRpc(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
+	logger := log.FromContext(ctx)
+	service, dir, repo := git.Service(pat.Param(r, "service")), pat.Param(r, "dir"), pat.Param(r, "repo")
+
+	if !isSmart(r, service) {
+		renderForbidden(w)
+		return
+	}
+
+	if service == git.ReceivePackService {
+		gitHttpReceiveCounter.WithLabelValues(repo)
+	}
+
+	w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", service))
+	w.Header().Set("Connection", "Keep-Alive")
+	w.Header().Set("Transfer-Encoding", "chunked")
+	w.Header().Set("X-Content-Type-Options", "nosniff")
+	w.WriteHeader(http.StatusOK)
+
+	version := r.Header.Get("Git-Protocol")
+
+	cmd := git.ServiceCommand{
+		Stdin:  r.Body,
+		Stdout: w,
+		Dir:    dir,
+		Args:   []string{"--stateless-rpc"},
+	}
+
+	if len(version) != 0 {
+		cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version))
+	}
+
+	// Handle gzip encoding
+	cmd.StdinHandler = func(in io.Reader, stdin io.WriteCloser) (err error) {
+		// We know that `in` is an `io.ReadCloser` because it's `r.Body`.
+		reader := in.(io.ReadCloser)
+		defer reader.Close() // nolint: errcheck
+		switch r.Header.Get("Content-Encoding") {
+		case "gzip":
+			reader, err = gzip.NewReader(reader)
+			if err != nil {
+				return err
+			}
+			defer reader.Close() // nolint: errcheck
+		}
+
+		_, err = io.Copy(stdin, reader)
+		return err
+	}
+
+	// Handle buffered output
+	// Useful when using proxies
+	cmd.StdoutHandler = func(out io.Writer, stdout io.ReadCloser) error {
+		// We know that `out` is an `http.ResponseWriter`.
+		flusher, ok := out.(http.Flusher)
+		if !ok {
+			return fmt.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", out)
+		}
+
+		p := make([]byte, 1024)
+		for {
+			nRead, err := stdout.Read(p)
+			if err == io.EOF {
+				break
+			}
+			nWrite, err := out.Write(p[:nRead])
+			if err != nil {
+				return err
+			}
+			if nRead != nWrite {
+				return fmt.Errorf("failed to write data: %d read, %d written", nRead, nWrite)
+			}
+			flusher.Flush()
+		}
+
+		return nil
+	}
+
+	if err := service.Handler(ctx, cmd); err != nil {
+		logger.Errorf("error executing service: %s", err)
+	}
+}
+
+func getInfoRefs(w http.ResponseWriter, r *http.Request) {
+	ctx := r.Context()
+	logger := log.FromContext(ctx)
+	dir, repo, file := pat.Param(r, "dir"), pat.Param(r, "repo"), pat.Param(r, "file")
+	service := getServiceType(r)
+	version := r.Header.Get("Git-Protocol")
+
+	gitHttpUploadCounter.WithLabelValues(repo, file).Inc()
+
+	if service != "" && (service == git.UploadPackService || service == git.ReceivePackService) {
+		// Smart HTTP
+		var refs bytes.Buffer
+		cmd := git.ServiceCommand{
+			Stdout: &refs,
+			Dir:    dir,
+			Args:   []string{"--stateless-rpc", "--advertise-refs"},
+		}
+
+		if len(version) != 0 {
+			cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version))
+		}
+
+		if err := service.Handler(ctx, cmd); err != nil {
+			logger.Errorf("error executing service: %s", err)
+			renderNotFound(w)
+			return
+		}
+
+		hdrNocache(w)
+		w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", service))
+		w.WriteHeader(http.StatusOK)
+		if len(version) == 0 {
+			git.WritePktline(w, "# service="+service.String())
+		}
+
+		w.Write(refs.Bytes()) // nolint: errcheck
+	} else {
+		// Dumb HTTP
+		updateServerInfo(ctx, dir) // nolint: errcheck
+		hdrNocache(w)
+		sendFile("text/plain; charset=utf-8", w, r)
+	}
+}
+
+func getInfoPacks(w http.ResponseWriter, r *http.Request) {
+	hdrCacheForever(w)
+	sendFile("text/plain; charset=utf-8", w, r)
+}
+
+func getLooseObject(w http.ResponseWriter, r *http.Request) {
+	hdrCacheForever(w)
+	sendFile("application/x-git-loose-object", w, r)
+}
+
+func getPackFile(w http.ResponseWriter, r *http.Request) {
+	hdrCacheForever(w)
+	sendFile("application/x-git-packed-objects", w, r)
+}
+
+func getIdxFile(w http.ResponseWriter, r *http.Request) {
+	hdrCacheForever(w)
+	sendFile("application/x-git-packed-objects-toc", w, r)
+}
+
+func getTextFile(w http.ResponseWriter, r *http.Request) {
+	hdrNocache(w)
+	sendFile("text/plain", w, r)
+}
+
+func sendFile(contentType string, w http.ResponseWriter, r *http.Request) {
+	dir, file := pat.Param(r, "dir"), pat.Param(r, "file")
+	reqFile := filepath.Join(dir, file)
+
+	f, err := os.Stat(reqFile)
+	if os.IsNotExist(err) {
+		renderNotFound(w)
+		return
+	}
+
+	w.Header().Set("Content-Type", contentType)
+	w.Header().Set("Content-Length", fmt.Sprintf("%d", f.Size()))
+	w.Header().Set("Last-Modified", f.ModTime().Format(http.TimeFormat))
+	http.ServeFile(w, r, reqFile)
+}
+
+func getServiceType(r *http.Request) git.Service {
+	service := r.FormValue("service")
+	if !strings.HasPrefix(service, "git-") {
+		return ""
+	}
+
+	return git.Service(service)
+}
+
+func isSmart(r *http.Request, service git.Service) bool {
+	if r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) {
+		return true
+	}
+	return false
+}
+
+func updateServerInfo(ctx context.Context, dir string) error {
+	return gitb.UpdateServerInfo(ctx, dir)
+}
+
+// HTTP error response handling functions
+
+func renderMethodNotAllowed(w http.ResponseWriter, r *http.Request) {
+	if r.Proto == "HTTP/1.1" {
+		w.WriteHeader(http.StatusMethodNotAllowed)
+		w.Write([]byte("Method Not Allowed")) // nolint: errcheck
+	} else {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("Bad Request")) // nolint: errcheck
+	}
+}
+
+func renderNotFound(w http.ResponseWriter) {
+	w.WriteHeader(http.StatusNotFound)
+	w.Write([]byte("Not Found")) // nolint: errcheck
+}
+
+func renderUnauthorized(w http.ResponseWriter) {
+	w.WriteHeader(http.StatusUnauthorized)
+	w.Write([]byte("Unauthorized")) // nolint: errcheck
+}
+
+func renderForbidden(w http.ResponseWriter) {
+	w.WriteHeader(http.StatusForbidden)
+	w.Write([]byte("Forbidden")) // nolint: errcheck
+}
+
+func renderInternalServerError(w http.ResponseWriter) {
+	w.WriteHeader(http.StatusInternalServerError)
+	w.Write([]byte("Internal Server Error")) // nolint: errcheck
+}
+
+// Header writing functions
+
+func hdrNocache(w http.ResponseWriter) {
+	w.Header().Set("Expires", "Fri, 01 Jan 1980 00:00:00 GMT")
+	w.Header().Set("Pragma", "no-cache")
+	w.Header().Set("Cache-Control", "no-cache, max-age=0, must-revalidate")
+}
+
+func hdrCacheForever(w http.ResponseWriter) {
+	now := time.Now().Unix()
+	expires := now + 31536000
+	w.Header().Set("Date", fmt.Sprintf("%d", now))
+	w.Header().Set("Expires", fmt.Sprintf("%d", expires))
+	w.Header().Set("Cache-Control", "public, max-age=31536000")
+}

server/web/goget.go 🔗

@@ -0,0 +1,94 @@
+package web
+
+import (
+	"net/http"
+	"net/url"
+	"path"
+	"text/template"
+
+	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/utils"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/prometheus/client_golang/prometheus/promauto"
+	"goji.io/pattern"
+)
+
+var goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+	Namespace: "soft_serve",
+	Subsystem: "http",
+	Name:      "go_get_total",
+	Help:      "The total number of go get requests",
+}, []string{"repo"})
+
+var repoIndexHTMLTpl = template.Must(template.New("index").Parse(`<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+    <meta http-equiv="refresh" content="0; url=https://godoc.org/{{ .ImportRoot }}/{{.Repo}}">
+    <meta name="go-import" content="{{ .ImportRoot }}/{{ .Repo }} git {{ .Config.HTTP.PublicURL }}/{{ .Repo }}">
+</head>
+<body>
+Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}">godoc.org/{{ .ImportRoot }}/{{ .Repo }}</a>...
+</body>
+</html>`))
+
+// GoGetHandler handles go get requests.
+type GoGetHandler struct {
+	cfg *config.Config
+	be  backend.Backend
+}
+
+var _ http.Handler = (*GoGetHandler)(nil)
+
+func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	repo := pattern.Path(r.Context())
+	repo = utils.SanitizeRepo(repo)
+	be := g.be.WithContext(r.Context())
+
+	// Handle go get requests.
+	//
+	// Always return a 200 status code, even if the repo doesn't exist.
+	//
+	// https://golang.org/cmd/go/#hdr-Remote_import_paths
+	// https://go.dev/ref/mod#vcs-branch
+	if r.URL.Query().Get("go-get") == "1" {
+		repo := repo
+		importRoot, err := url.Parse(g.cfg.HTTP.PublicURL)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		// find the repo
+		for {
+			if _, err := be.Repository(repo); err == nil {
+				break
+			}
+
+			if repo == "" || repo == "." || repo == "/" {
+				return
+			}
+
+			repo = path.Dir(repo)
+		}
+
+		if err := repoIndexHTMLTpl.Execute(w, struct {
+			Repo       string
+			Config     *config.Config
+			ImportRoot string
+		}{
+			Repo:       url.PathEscape(repo),
+			Config:     g.cfg,
+			ImportRoot: importRoot.Host,
+		}); err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		goGetCounter.WithLabelValues(repo).Inc()
+		return
+	}
+
+	http.NotFound(w, r)
+}

server/web/http.go 🔗

@@ -2,103 +2,31 @@ package web
 
 import (
 	"context"
-	"fmt"
 	"net/http"
-	"net/url"
-	"path"
-	"path/filepath"
-	"regexp"
-	"strings"
-	"text/template"
 	"time"
 
-	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/config"
-	"github.com/charmbracelet/soft-serve/server/utils"
-	"github.com/dustin/go-humanize"
-	"github.com/prometheus/client_golang/prometheus"
-	"github.com/prometheus/client_golang/prometheus/promauto"
-	"goji.io"
-	"goji.io/pat"
-	"goji.io/pattern"
 )
 
-var (
-	gitHttpCounter = promauto.NewCounterVec(prometheus.CounterOpts{
-		Namespace: "soft_serve",
-		Subsystem: "http",
-		Name:      "git_fetch_pull_total",
-		Help:      "The total number of git fetch/pull requests",
-	}, []string{"repo", "file"})
-
-	goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{
-		Namespace: "soft_serve",
-		Subsystem: "http",
-		Name:      "go_get_total",
-		Help:      "The total number of go get requests",
-	}, []string{"repo"})
-)
-
-// logWriter is a wrapper around http.ResponseWriter that allows us to capture
-// the HTTP status code and bytes written to the response.
-type logWriter struct {
-	http.ResponseWriter
-	code, bytes int
-}
-
-func (r *logWriter) Write(p []byte) (int, error) {
-	written, err := r.ResponseWriter.Write(p)
-	r.bytes += written
-	return written, err
-}
-
-// Note this is generally only called when sending an HTTP error, so it's
-// important to set the `code` value to 200 as a default
-func (r *logWriter) WriteHeader(code int) {
-	r.code = code
-	r.ResponseWriter.WriteHeader(code)
-}
-
-func (s *HTTPServer) loggingMiddleware(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		start := time.Now()
-		writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
-		s.logger.Debug("request",
-			"method", r.Method,
-			"uri", r.RequestURI,
-			"addr", r.RemoteAddr)
-		next.ServeHTTP(writer, r)
-		elapsed := time.Since(start)
-		s.logger.Debug("response",
-			"status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
-			"bytes", humanize.Bytes(uint64(writer.bytes)),
-			"time", elapsed)
-	})
-}
-
 // HTTPServer is an http server.
 type HTTPServer struct {
-	ctx        context.Context
-	cfg        *config.Config
-	be         backend.Backend
-	server     *http.Server
-	dirHandler http.Handler
-	logger     *log.Logger
+	ctx    context.Context
+	cfg    *config.Config
+	be     backend.Backend
+	server *http.Server
 }
 
+// NewHTTPServer creates a new HTTP server.
 func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
 	cfg := config.FromContext(ctx)
-	mux := goji.NewMux()
 	s := &HTTPServer{
-		ctx:        ctx,
-		cfg:        cfg,
-		be:         backend.FromContext(ctx),
-		logger:     log.FromContext(ctx).WithPrefix("http"),
-		dirHandler: http.FileServer(http.Dir(filepath.Join(cfg.DataPath, "repos"))),
+		ctx: ctx,
+		cfg: cfg,
+		be:  backend.FromContext(ctx),
 		server: &http.Server{
 			Addr:              cfg.HTTP.ListenAddr,
-			Handler:           mux,
+			Handler:           NewRouter(ctx),
 			ReadHeaderTimeout: time.Second * 10,
 			ReadTimeout:       time.Second * 10,
 			WriteTimeout:      time.Second * 10,
@@ -106,21 +34,6 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
 		},
 	}
 
-	mux.Use(s.loggingMiddleware)
-	for _, m := range []Matcher{
-		getInfoRefs,
-		getHead,
-		getAlternates,
-		getHTTPAlternates,
-		getInfoPacks,
-		getInfoFile,
-		getLooseObject,
-		getPackFile,
-		getIdxFile,
-	} {
-		mux.HandleFunc(NewPattern(m), s.handleGit)
-	}
-	mux.HandleFunc(pat.Get("/*"), s.handleIndex)
 	return s, nil
 }
 
@@ -141,193 +54,3 @@ func (s *HTTPServer) ListenAndServe() error {
 func (s *HTTPServer) Shutdown(ctx context.Context) error {
 	return s.server.Shutdown(ctx)
 }
-
-// Pattern is a pattern for matching a URL.
-// It matches against GET requests.
-type Pattern struct {
-	match func(*url.URL) *match
-}
-
-// NewPattern returns a new Pattern with the given matcher.
-func NewPattern(m Matcher) *Pattern {
-	return &Pattern{
-		match: m,
-	}
-}
-
-// Match is a match for a URL.
-//
-// It implements goji.Pattern.
-func (p *Pattern) Match(r *http.Request) *http.Request {
-	if r.Method != "GET" {
-		return nil
-	}
-
-	if m := p.match(r.URL); m != nil {
-		ctx := context.WithValue(r.Context(), pattern.Variable("repo"), m.RepoPath)
-		ctx = context.WithValue(ctx, pattern.Variable("file"), m.FilePath)
-		return r.WithContext(ctx)
-	}
-	return nil
-}
-
-// Matcher finds a match in a *url.URL.
-type Matcher = func(*url.URL) *match
-
-var (
-	getInfoRefs = func(u *url.URL) *match {
-		return matchSuffix(u.Path, "/info/refs")
-	}
-
-	getHead = func(u *url.URL) *match {
-		return matchSuffix(u.Path, "/HEAD")
-	}
-
-	getAlternates = func(u *url.URL) *match {
-		return matchSuffix(u.Path, "/objects/info/alternates")
-	}
-
-	getHTTPAlternates = func(u *url.URL) *match {
-		return matchSuffix(u.Path, "/objects/info/http-alternates")
-	}
-
-	getInfoPacks = func(u *url.URL) *match {
-		return matchSuffix(u.Path, "/objects/info/packs")
-	}
-
-	getInfoFileRegexp = regexp.MustCompile(".*?(/objects/info/[^/]*)$")
-	getInfoFile       = func(u *url.URL) *match {
-		return findStringSubmatch(u.Path, getInfoFileRegexp)
-	}
-
-	getLooseObjectRegexp = regexp.MustCompile(".*?(/objects/[0-9a-f]{2}/[0-9a-f]{38})$")
-	getLooseObject       = func(u *url.URL) *match {
-		return findStringSubmatch(u.Path, getLooseObjectRegexp)
-	}
-
-	getPackFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.pack)$`)
-	getPackFile       = func(u *url.URL) *match {
-		return findStringSubmatch(u.Path, getPackFileRegexp)
-	}
-
-	getIdxFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.idx)$`)
-	getIdxFile       = func(u *url.URL) *match {
-		return findStringSubmatch(u.Path, getIdxFileRegexp)
-	}
-)
-
-// match represents a match for a URL.
-type match struct {
-	RepoPath, FilePath string
-}
-
-func matchSuffix(path, suffix string) *match {
-	if !strings.HasSuffix(path, suffix) {
-		return nil
-	}
-	repoPath := strings.Replace(path, suffix, "", 1)
-	filePath := strings.Replace(path, repoPath+"/", "", 1)
-	return &match{repoPath, filePath}
-}
-
-func findStringSubmatch(path string, prefix *regexp.Regexp) *match {
-	m := prefix.FindStringSubmatch(path)
-	if m == nil {
-		return nil
-	}
-	suffix := m[1]
-	repoPath := strings.Replace(path, suffix, "", 1)
-	filePath := strings.Replace(path, repoPath+"/", "", 1)
-	return &match{repoPath, filePath}
-}
-
-var repoIndexHTMLTpl = template.Must(template.New("index").Parse(`<!DOCTYPE html>
-<html lang="en">
-<head>
-    <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
-    <meta http-equiv="refresh" content="0; url=https://godoc.org/{{ .ImportRoot }}/{{.Repo}}">
-    <meta name="go-import" content="{{ .ImportRoot }}/{{ .Repo }} git {{ .Config.HTTP.PublicURL }}/{{ .Repo }}">
-</head>
-<body>
-Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}">godoc.org/{{ .ImportRoot }}/{{ .Repo }}</a>...
-</body>
-</html>`))
-
-func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
-	repo := pattern.Path(r.Context())
-	repo = utils.SanitizeRepo(repo)
-	be := s.be.WithContext(r.Context())
-
-	// Handle go get requests.
-	//
-	// Always return a 200 status code, even if the repo doesn't exist.
-	//
-	// https://golang.org/cmd/go/#hdr-Remote_import_paths
-	// https://go.dev/ref/mod#vcs-branch
-	if r.URL.Query().Get("go-get") == "1" {
-		repo := repo
-		importRoot, err := url.Parse(s.cfg.HTTP.PublicURL)
-		if err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-
-		// find the repo
-		for {
-			if _, err := be.Repository(repo); err == nil {
-				break
-			}
-
-			if repo == "" || repo == "." || repo == "/" {
-				return
-			}
-
-			repo = path.Dir(repo)
-		}
-
-		if err := repoIndexHTMLTpl.Execute(w, struct {
-			Repo       string
-			Config     *config.Config
-			ImportRoot string
-		}{
-			Repo:       url.PathEscape(repo),
-			Config:     s.cfg,
-			ImportRoot: importRoot.Host,
-		}); err != nil {
-			http.Error(w, err.Error(), http.StatusInternalServerError)
-			return
-		}
-
-		goGetCounter.WithLabelValues(repo).Inc()
-		return
-	}
-
-	http.NotFound(w, r)
-}
-
-func (s *HTTPServer) handleGit(w http.ResponseWriter, r *http.Request) {
-	repo := pat.Param(r, "repo")
-	repo = utils.SanitizeRepo(repo) + ".git"
-	be := s.be.WithContext(r.Context())
-	if _, err := be.Repository(repo); err != nil {
-		s.logger.Debug("repository not found", "repo", repo, "err", err)
-		http.NotFound(w, r)
-		return
-	}
-
-	if !s.cfg.Backend.AllowKeyless() {
-		http.Error(w, "Forbidden", http.StatusForbidden)
-		return
-	}
-
-	access := s.cfg.Backend.AccessLevel(repo, "")
-	if access < backend.ReadOnlyAccess {
-		http.Error(w, "Unauthorized", http.StatusUnauthorized)
-		return
-	}
-
-	file := pat.Param(r, "file")
-	gitHttpCounter.WithLabelValues(repo, file).Inc()
-	r.URL.Path = fmt.Sprintf("/%s/%s", repo, file)
-	s.dirHandler.ServeHTTP(w, r)
-}

server/web/logging.go 🔗

@@ -0,0 +1,84 @@
+package web
+
+import (
+	"bufio"
+	"fmt"
+	"net"
+	"net/http"
+	"time"
+
+	"github.com/charmbracelet/log"
+	"github.com/dustin/go-humanize"
+)
+
+// logWriter is a wrapper around http.ResponseWriter that allows us to capture
+// the HTTP status code and bytes written to the response.
+type logWriter struct {
+	http.ResponseWriter
+	code, bytes int
+}
+
+var _ http.ResponseWriter = (*logWriter)(nil)
+
+var _ http.Flusher = (*logWriter)(nil)
+
+var _ http.Hijacker = (*logWriter)(nil)
+
+var _ http.CloseNotifier = (*logWriter)(nil)
+
+// Write implements http.ResponseWriter.
+func (r *logWriter) Write(p []byte) (int, error) {
+	written, err := r.ResponseWriter.Write(p)
+	r.bytes += written
+	return written, err
+}
+
+// Note this is generally only called when sending an HTTP error, so it's
+// important to set the `code` value to 200 as a default.
+func (r *logWriter) WriteHeader(code int) {
+	r.code = code
+	r.ResponseWriter.WriteHeader(code)
+}
+
+// Flush implements http.Flusher.
+func (r *logWriter) Flush() {
+	if f, ok := r.ResponseWriter.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+// CloseNotify implements http.CloseNotifier.
+func (r *logWriter) CloseNotify() <-chan bool {
+	if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
+		return cn.CloseNotify()
+	}
+	return nil
+}
+
+// Hijack implements http.Hijacker.
+func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	if h, ok := r.ResponseWriter.(http.Hijacker); ok {
+		return h.Hijack()
+	}
+	return nil, nil, fmt.Errorf("http.Hijacker not implemented")
+}
+
+// NewLoggingMiddleware returns a new logging middleware.
+func NewLoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			start := time.Now()
+			writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
+			logger.Debug("request",
+				"method", r.Method,
+				"uri", r.RequestURI,
+				"addr", r.RemoteAddr)
+			next.ServeHTTP(writer, r)
+			elapsed := time.Since(start)
+			logger.Debug("response",
+				"status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
+				"bytes", humanize.Bytes(uint64(writer.bytes)),
+				"time", elapsed)
+		})
+	}
+}

server/web/server.go 🔗

@@ -0,0 +1,40 @@
+// Package server is the reusable server
+package web
+
+import (
+	"context"
+	"net/http"
+
+	"github.com/charmbracelet/log"
+	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
+	"goji.io"
+	"goji.io/pat"
+)
+
+// Route is an interface for a route.
+type Route interface {
+	http.Handler
+	goji.Pattern
+}
+
+// NewRouter returns a new HTTP router.
+func NewRouter(ctx context.Context) *goji.Mux {
+	mux := goji.NewMux()
+	cfg := config.FromContext(ctx)
+	be := backend.FromContext(ctx)
+	logger := log.FromContext(ctx).WithPrefix("http")
+
+	// Middlewares
+	mux.Use(NewLoggingMiddleware(logger))
+
+	// Git routes
+	for _, service := range gitRoutes(ctx, logger) {
+		mux.Handle(service, service)
+	}
+
+	// go-get handler
+	mux.Handle(pat.Get("/*"), GoGetHandler{cfg, be})
+
+	return mux
+}