refactor: tidy up server git

Ayman Bagabas created

use git services to implement handling git server commands
pass config to git as environment variables

Change summary

server/config/config.go |  34 ++++++++
server/daemon/conn.go   | 105 ++++++++++++++++++++++++++
server/daemon/daemon.go | 171 ++++++++++++++++++------------------------
server/git/git.go       | 142 ++++-------------------------------
server/git/service.go   | 136 ++++++++++++++++++++++++++++++++++
server/ssh/ssh.go       |  48 ++++++++---
6 files changed, 401 insertions(+), 235 deletions(-)

Detailed changes

server/config/config.go 🔗

@@ -114,6 +114,40 @@ type Config struct {
 	Backend backend.Backend `yaml:"-"`
 }
 
+// Environ returns the config as a list of environment variables.
+func (c *Config) Environ() []string {
+	envs := []string{}
+	if c == nil {
+		return envs
+	}
+
+	// TODO: do this dynamically
+	envs = append(envs, []string{
+		fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name),
+		fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath),
+		fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")),
+		fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL),
+		fmt.Sprintf("SOFT_SERVE_SSH_KEY_PATH=%s", c.SSH.KeyPath),
+		fmt.Sprintf("SOFT_SERVE_SSH_CLIENT_KEY_PATH=%s", c.SSH.ClientKeyPath),
+		fmt.Sprintf("SOFT_SERVE_SSH_MAX_TIMEOUT=%d", c.SSH.MaxTimeout),
+		fmt.Sprintf("SOFT_SERVE_SSH_IDLE_TIMEOUT=%d", c.SSH.IdleTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_LISTEN_ADDR=%s", c.Git.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_GIT_MAX_TIMEOUT=%d", c.Git.MaxTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_IDLE_TIMEOUT=%d", c.Git.IdleTimeout),
+		fmt.Sprintf("SOFT_SERVE_GIT_MAX_CONNECTIONS=%d", c.Git.MaxConnections),
+		fmt.Sprintf("SOFT_SERVE_HTTP_LISTEN_ADDR=%s", c.HTTP.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath),
+		fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath),
+		fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL),
+		fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr),
+		fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format),
+		fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat),
+	}...)
+
+	return envs
+}
+
 func parseConfig(path string) (*Config, error) {
 	dataPath := filepath.Dir(path)
 	cfg := &Config{

server/daemon/conn.go 🔗

@@ -0,0 +1,105 @@
+package daemon
+
+import (
+	"context"
+	"errors"
+	"net"
+	"sync"
+	"time"
+)
+
+// connections is a synchronizes access to to a net.Conn pool.
+type connections struct {
+	m  map[net.Conn]struct{}
+	mu sync.Mutex
+}
+
+func (m *connections) Add(c net.Conn) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.m[c] = struct{}{}
+}
+
+func (m *connections) Close(c net.Conn) error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	err := c.Close()
+	delete(m.m, c)
+	return err
+}
+
+func (m *connections) Size() int {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	return len(m.m)
+}
+
+func (m *connections) CloseAll() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	var err error
+	for c := range m.m {
+		err = errors.Join(err, c.Close())
+		delete(m.m, c)
+	}
+
+	return err
+}
+
+// serverConn is a wrapper around a net.Conn that closes the connection when
+// the one of the timeouts is reached.
+type serverConn struct {
+	net.Conn
+
+	initTimeout   time.Duration
+	idleTimeout   time.Duration
+	maxDeadline   time.Time
+	closeCanceler context.CancelFunc
+}
+
+var _ net.Conn = (*serverConn)(nil)
+
+func (c *serverConn) Write(p []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Write(p)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Read(b []byte) (n int, err error) {
+	c.updateDeadline()
+	n, err = c.Conn.Read(b)
+	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) Close() (err error) {
+	err = c.Conn.Close()
+	if c.closeCanceler != nil {
+		c.closeCanceler()
+	}
+	return
+}
+
+func (c *serverConn) updateDeadline() {
+	switch {
+	case c.initTimeout > 0:
+		initTimeout := time.Now().Add(c.initTimeout)
+		c.initTimeout = 0
+		if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+			c.Conn.SetDeadline(initTimeout)
+			return
+		}
+	case c.idleTimeout > 0:
+		idleDeadline := time.Now().Add(c.idleTimeout)
+		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+			c.Conn.SetDeadline(idleDeadline)
+			return
+		}
+	}
+	c.Conn.SetDeadline(c.maxDeadline)
+}

server/daemon/daemon.go 🔗

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"path/filepath"
+	"strings"
 	"sync"
 	"time"
 
@@ -41,40 +42,6 @@ var (
 	ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
 )
 
-// connections synchronizes access to to a net.Conn pool.
-type connections struct {
-	m  map[net.Conn]struct{}
-	mu sync.Mutex
-}
-
-func (m *connections) Add(c net.Conn) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	m.m[c] = struct{}{}
-}
-
-func (m *connections) Close(c net.Conn) {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	_ = c.Close()
-	delete(m.m, c)
-}
-
-func (m *connections) Size() int {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	return len(m.m)
-}
-
-func (m *connections) CloseAll() {
-	m.mu.Lock()
-	defer m.mu.Unlock()
-	for c := range m.m {
-		_ = c.Close()
-		delete(m.m, c)
-	}
-}
-
 // GitDaemon represents a Git daemon.
 type GitDaemon struct {
 	ctx      context.Context
@@ -213,26 +180,53 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 			return
 		}
 
-		gitPack := git.UploadPack
-		counter := uploadPackGitCounter
-		cmd := string(split[0])
-		switch cmd {
-		case git.UploadPackBin:
-			gitPack = git.UploadPack
-		case git.UploadArchiveBin:
-			gitPack = git.UploadArchive
+		var handler git.ServiceHandler
+		var counter *prometheus.CounterVec
+		service := git.Service(split[0])
+		switch service {
+		case git.UploadPackService:
+			handler = git.UploadPack
+			counter = uploadPackGitCounter
+		case git.UploadArchiveService:
+			handler = git.UploadArchive
 			counter = uploadArchiveGitCounter
 		default:
 			d.fatal(c, git.ErrInvalidRequest)
 			return
 		}
 
-		opts := bytes.Split(split[1], []byte{'\x00'})
-		if len(opts) == 0 {
-			d.fatal(c, git.ErrInvalidRequest)
+		opts := bytes.SplitN(split[1], []byte{0}, 3)
+		if len(opts) < 2 {
+			d.fatal(c, git.ErrInvalidRequest) // nolint: errcheck
 			return
 		}
 
+		host := strings.TrimPrefix(string(opts[1]), "host=")
+		extraParams := map[string]string{}
+
+		if len(opts) > 2 {
+			buf := bytes.TrimPrefix(opts[2], []byte{0})
+			for _, o := range bytes.Split(buf, []byte{0}) {
+				opt := string(o)
+				if opt == "" {
+					continue
+				}
+
+				kv := strings.SplitN(opt, "=", 2)
+				if len(kv) != 2 {
+					d.logger.Errorf("git: invalid option %q", opt)
+					continue
+				}
+
+				extraParams[kv[0]] = kv[1]
+			}
+
+			version := extraParams["version"]
+			if version != "" {
+				d.logger.Debugf("git: protocol version %s", version)
+			}
+		}
+
 		be := d.be.WithContext(ctx)
 		if !be.AllowKeyless() {
 			d.fatal(c, git.ErrNotAuthed)
@@ -240,14 +234,21 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 		}
 
 		name := utils.SanitizeRepo(string(opts[0]))
-		d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), cmd, name)
-		defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, name)
+		d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), service, name)
+		defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), service, name)
+
 		// git bare repositories should end in ".git"
 		// https://git-scm.com/docs/gitrepository-layout
 		repo := name + ".git"
 		reposDir := filepath.Join(d.cfg.DataPath, "repos")
 		if err := git.EnsureWithin(reposDir, repo); err != nil {
-			d.fatal(c, err)
+			d.logger.Debugf("git: error ensuring repo path: %v", err)
+			d.fatal(c, git.ErrInvalidRepo)
+			return
+		}
+
+		if _, err := d.be.Repository(repo); err != nil {
+			d.fatal(c, git.ErrInvalidRepo)
 			return
 		}
 
@@ -261,9 +262,33 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 		envs := []string{
 			"SOFT_SERVE_REPO_NAME=" + name,
 			"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
+			"SOFT_SERVE_HOST=" + host,
 		}
 
-		if err := gitPack(ctx, c, c, c, filepath.Join(reposDir, repo), envs...); err != nil {
+		// Add git protocol environment variable.
+		if len(extraParams) > 0 {
+			var gitProto string
+			for k, v := range extraParams {
+				if len(gitProto) > 0 {
+					gitProto += ":"
+				}
+				gitProto += k + "=" + v
+			}
+			envs = append(envs, "GIT_PROTOCOL="+gitProto)
+		}
+
+		envs = append(envs, d.cfg.Environ()...)
+
+		cmd := git.ServiceCommand{
+			Stdin:  c,
+			Stdout: c,
+			Stderr: c,
+			Env:    envs,
+			Dir:    filepath.Join(reposDir, repo),
+		}
+
+		if err := handler(ctx, cmd); err != nil {
+			d.logger.Debugf("git: error handling request: %v", err)
 			d.fatal(c, err)
 			return
 		}
@@ -296,51 +321,3 @@ func (d *GitDaemon) Shutdown(ctx context.Context) error {
 		return err
 	}
 }
-
-type serverConn struct {
-	net.Conn
-
-	idleTimeout   time.Duration
-	maxDeadline   time.Time
-	closeCanceler context.CancelFunc
-}
-
-func (c *serverConn) Write(p []byte) (n int, err error) {
-	c.updateDeadline()
-	n, err = c.Conn.Write(p)
-	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) Read(b []byte) (n int, err error) {
-	c.updateDeadline()
-	n, err = c.Conn.Read(b)
-	if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) Close() (err error) {
-	err = c.Conn.Close()
-	if c.closeCanceler != nil {
-		c.closeCanceler()
-	}
-	return
-}
-
-func (c *serverConn) updateDeadline() {
-	switch {
-	case c.idleTimeout > 0:
-		idleDeadline := time.Now().Add(c.idleTimeout)
-		if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
-			c.Conn.SetDeadline(idleDeadline)
-			return
-		}
-		fallthrough
-	default:
-		c.Conn.SetDeadline(c.maxDeadline)
-	}
-}

server/git/git.go 🔗

@@ -5,16 +5,12 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"os"
-	"os/exec"
 	"path/filepath"
 	"strings"
 
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/git"
-	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/go-git/go-git/v5/plumbing/format/pktline"
-	"golang.org/x/sync/errgroup"
 )
 
 var (
@@ -38,112 +34,6 @@ var (
 	ErrTimeout = errors.New("I/O timeout reached")
 )
 
-// Git protocol commands.
-const (
-	ReceivePackBin   = "git-receive-pack"
-	UploadPackBin    = "git-upload-pack"
-	UploadArchiveBin = "git-upload-archive"
-)
-
-// UploadPack runs the git upload-pack protocol against the provided repo.
-func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	exists, err := fileExists(repoDir)
-	if !exists {
-		return ErrInvalidRepo
-	}
-	if err != nil {
-		return err
-	}
-	return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir)
-}
-
-// UploadArchive runs the git upload-archive protocol against the provided repo.
-func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	exists, err := fileExists(repoDir)
-	if !exists {
-		return ErrInvalidRepo
-	}
-	if err != nil {
-		return err
-	}
-	return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir)
-}
-
-// ReceivePack runs the git receive-pack protocol against the provided repo.
-func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
-	if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil {
-		return err
-	}
-	return EnsureDefaultBranch(ctx, in, out, er, repoDir)
-}
-
-// RunGit runs a git command in the given repo.
-func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error {
-	cfg := config.FromContext(ctx)
-	logger := log.FromContext(ctx).WithPrefix("rungit")
-	c := exec.CommandContext(ctx, "git", args...)
-	c.Dir = dir
-	c.Env = append(os.Environ(), envs...)
-	c.Env = append(c.Env, "PATH="+os.Getenv("PATH"))
-	c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG"))
-	if cfg != nil {
-		c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.Log.Format)
-		c.Env = append(c.Env, "SOFT_SERVE_LOG_TIME_FORMAT="+cfg.Log.TimeFormat)
-	}
-
-	stdin, err := c.StdinPipe()
-	if err != nil {
-		logger.Error("failed to get stdin pipe", "err", err)
-		return err
-	}
-
-	stdout, err := c.StdoutPipe()
-	if err != nil {
-		logger.Error("failed to get stdout pipe", "err", err)
-		return err
-	}
-
-	stderr, err := c.StderrPipe()
-	if err != nil {
-		logger.Error("failed to get stderr pipe", "err", err)
-		return err
-	}
-
-	if err := c.Start(); err != nil {
-		logger.Error("failed to start command", "err", err)
-		return err
-	}
-
-	errg, ctx := errgroup.WithContext(ctx)
-
-	// stdin
-	errg.Go(func() error {
-		defer stdin.Close()
-
-		_, err := io.Copy(stdin, in)
-		return err
-	})
-
-	// stdout
-	errg.Go(func() error {
-		_, err := io.Copy(out, stdout)
-		return err
-	})
-
-	// stderr
-	errg.Go(func() error {
-		_, err := io.Copy(er, stderr)
-		return err
-	})
-
-	if err := errg.Wait(); err != nil {
-		logger.Error("while copying output", "err", err)
-	}
-
-	// Wait for the command to finish
-	return c.Wait()
-}
-
 // WritePktline encodes and writes a pktline to the given writer.
 func WritePktline(w io.Writer, v ...interface{}) {
 	msg := fmt.Sprintln(v...)
@@ -179,19 +69,8 @@ func EnsureWithin(reposDir string, repo string) error {
 	return nil
 }
 
-func fileExists(path string) (bool, error) {
-	_, err := os.Stat(path)
-	if err == nil {
-		return true, nil
-	}
-	if os.IsNotExist(err) {
-		return false, nil
-	}
-	return true, err
-}
-
-func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
-	r, err := git.Open(repoPath)
+func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
+	r, err := git.Open(scmd.Dir)
 	if err != nil {
 		return err
 	}
@@ -205,8 +84,21 @@ func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io
 	// Rename the default branch to the first branch available
 	_, err = r.HEAD()
 	if err == git.ErrReferenceNotExist {
-		err = RunGit(ctx, in, out, er, repoPath, []string{}, "branch", "-M", brs[0])
-		if err != nil {
+		branch := brs[0]
+		// Prefer "main" or "master" as the default branch
+		for _, b := range brs {
+			if b == "main" || b == "master" {
+				branch = b
+				break
+			}
+		}
+
+		cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx)
+		if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{
+			Stdin:  scmd.Stdin,
+			Stdout: scmd.Stdout,
+			Stderr: scmd.Stderr,
+		}); err != nil {
 			return err
 		}
 	}

server/git/service.go 🔗

@@ -0,0 +1,136 @@
+package git
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"strings"
+
+	"golang.org/x/sync/errgroup"
+)
+
+// Service is a Git daemon service.
+type Service string
+
+const (
+	// UploadPackService is the upload-pack service.
+	UploadPackService Service = "git-upload-pack"
+	// UploadArchiveService is the upload-archive service.
+	UploadArchiveService Service = "git-upload-archive"
+	// ReceivePackService is the receive-pack service.
+	ReceivePackService Service = "git-receive-pack"
+)
+
+// String returns the string representation of the service.
+func (s Service) String() string {
+	return string(s)
+}
+
+// Name returns the name of the service.
+func (s Service) Name() string {
+	return strings.TrimPrefix(s.String(), "git-")
+}
+
+// Handler is the service handler.
+func (s Service) Handler(ctx context.Context, cmd ServiceCommand) error {
+	switch s {
+	case UploadPackService, UploadArchiveService, ReceivePackService:
+		return gitServiceHandler(ctx, s, cmd)
+	default:
+		return fmt.Errorf("unsupported service: %s", s)
+	}
+}
+
+// ServiceHandler is a git service command handler.
+type ServiceHandler func(ctx context.Context, cmd ServiceCommand) error
+
+// gitServiceHandler is the default service handler using the git binary.
+func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) error {
+	cmd := exec.CommandContext(ctx, "git", svc.Name()) // nolint: gosec
+	cmd.Dir = scmd.Dir
+	if len(scmd.Args) > 0 {
+		cmd.Args = append(cmd.Args, scmd.Args...)
+	}
+
+	cmd.Args = append(cmd.Args, ".")
+
+	cmd.Env = os.Environ()
+	if len(scmd.Env) > 0 {
+		cmd.Env = append(cmd.Env, scmd.Env...)
+	}
+
+	if scmd.CmdFunc != nil {
+		scmd.CmdFunc(cmd)
+	}
+
+	stdin, err := cmd.StdinPipe()
+	if err != nil {
+		return err
+	}
+
+	stdout, err := cmd.StdoutPipe()
+	if err != nil {
+		return err
+	}
+
+	stderr, err := cmd.StderrPipe()
+	if err != nil {
+		return err
+	}
+
+	if err := cmd.Start(); err != nil {
+		return err
+	}
+
+	errg, ctx := errgroup.WithContext(ctx)
+
+	// stdin
+	errg.Go(func() error {
+		defer stdin.Close() // nolint: errcheck
+		_, err := io.Copy(stdin, scmd.Stdin)
+		return err
+	})
+
+	// stdout
+	errg.Go(func() error {
+		_, err := io.Copy(scmd.Stdout, stdout)
+		return err
+	})
+
+	// stderr
+	errg.Go(func() error {
+		_, err := io.Copy(scmd.Stderr, stderr)
+		return err
+	})
+
+	return errors.Join(errg.Wait(), cmd.Wait())
+}
+
+// ServiceCommand is used to run a git service command.
+type ServiceCommand struct {
+	Stdin   io.Reader
+	Stdout  io.Writer
+	Stderr  io.Writer
+	Dir     string
+	Env     []string
+	Args    []string
+	CmdFunc func(*exec.Cmd)
+}
+
+// UploadPack runs the git upload-pack protocol against the provided repo.
+func UploadPack(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, UploadPackService, cmd)
+}
+
+// UploadArchive runs the git upload-archive protocol against the provided repo.
+func UploadArchive(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, UploadArchiveService, cmd)
+}
+
+// ReceivePack runs the git receive-pack protocol against the provided repo.
+func ReceivePack(ctx context.Context, cmd ServiceCommand) error {
+	return gitServiceHandler(ctx, ReceivePackService, cmd)
+}

server/ssh/ssh.go 🔗

@@ -194,13 +194,13 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 	return func(sh ssh.Handler) ssh.Handler {
 		return func(s ssh.Session) {
 			func() {
-				cmd := s.Command()
+				cmdLine := s.Command()
 				ctx := s.Context()
 				be := ss.be.WithContext(ctx)
-				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
-					gc := cmd[0]
+
+				if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
 					// repo should be in the form of "repo.git"
-					name := utils.SanitizeRepo(cmd[1])
+					name := utils.SanitizeRepo(cmdLine[1])
 					pk := s.PublicKey()
 					ak := backend.MarshalAuthorizedKey(pk)
 					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
@@ -218,12 +218,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 						"SOFT_SERVE_REPO_NAME=" + name,
 						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
 						"SOFT_SERVE_PUBLIC_KEY=" + ak,
+						"SOFT_SERVE_USERNAME=" + ctx.User(),
 					}
 
-					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
+					// Add ssh session & config environ
+					envs = append(envs, s.Environ()...)
+					envs = append(envs, cfg.Environ()...)
+
 					repoDir := filepath.Join(reposDir, repo)
-					switch gc {
-					case git.ReceivePackBin:
+					service := git.Service(cmdLine[0])
+					cmd := git.ServiceCommand{
+						Stdin:  s,
+						Stdout: s,
+						Stderr: s.Stderr(),
+						Env:    envs,
+						Dir:    repoDir,
+					}
+
+					ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
+
+					switch service {
+					case git.ReceivePackService:
 						if access < backend.ReadWriteAccess {
 							sshFatal(s, git.ErrNotAuthed)
 							return
@@ -234,27 +249,34 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 								sshFatal(s, err)
 								return
 							}
+
 							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
 						}
-						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
+
+						if err := git.ReceivePack(ctx, cmd); err != nil {
 							sshFatal(s, git.ErrSystemMalfunction)
 						}
+
+						if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
+							sshFatal(s, git.ErrSystemMalfunction)
+						}
+
 						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
 						return
-					case git.UploadPackBin, git.UploadArchiveBin:
+					case git.UploadPackService, git.UploadArchiveService:
 						if access < backend.ReadOnlyAccess {
 							sshFatal(s, git.ErrNotAuthed)
 							return
 						}
 
-						gitPack := git.UploadPack
+						handler := git.UploadPack
 						counter := uploadPackCounter
-						if gc == git.UploadArchiveBin {
-							gitPack = git.UploadArchive
+						if service == git.UploadArchiveService {
+							handler = git.UploadArchive
 							counter = uploadArchiveCounter
 						}
 
-						err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
+						err := handler(ctx, cmd)
 						if errors.Is(err, git.ErrInvalidRepo) {
 							sshFatal(s, git.ErrInvalidRepo)
 						} else if err != nil {