refactor(git): sanitize repo name

Ayman Bagabas created

Change summary

server/daemon.go | 40 +++++++++++++++++++++-------------------
server/git.go    | 30 +++++++++++++++++++++++++++---
server/ssh.go    | 20 ++++++++++----------
3 files changed, 58 insertions(+), 32 deletions(-)

Detailed changes

server/daemon.go 🔗

@@ -5,10 +5,8 @@ import (
 	"context"
 	"errors"
 	"io"
-	"log"
 	"net"
 	"path/filepath"
-	"strings"
 	"sync"
 	"time"
 
@@ -97,7 +95,7 @@ func (d *GitDaemon) Start() error {
 			case <-d.finished:
 				return ErrServerClosed
 			default:
-				log.Printf("git: error accepting connection: %v", err)
+				logger.Debugf("git: error accepting connection: %v", err)
 			}
 			if ne, ok := err.(net.Error); ok && ne.Temporary() {
 				if tempDelay == 0 {
@@ -116,7 +114,7 @@ func (d *GitDaemon) Start() error {
 
 		// Close connection if there are too many open connections.
 		if d.conns.Size()+1 >= d.cfg.Git.MaxConnections {
-			log.Printf("git: max connections reached, closing %s", conn.RemoteAddr())
+			logger.Debugf("git: max connections reached, closing %s", conn.RemoteAddr())
 			fatal(conn, ErrMaxConnections)
 			continue
 		}
@@ -132,7 +130,7 @@ func (d *GitDaemon) Start() error {
 func fatal(c net.Conn, err error) {
 	WritePktline(c, err)
 	if err := c.Close(); err != nil {
-		log.Printf("git: error closing connection: %v", err)
+		logger.Debugf("git: error closing connection: %v", err)
 	}
 }
 
@@ -162,7 +160,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 				if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
 					fatal(c, ErrTimeout)
 				} else {
-					log.Printf("git: error scanning pktline: %v", err)
+					logger.Debugf("git: error scanning pktline: %v", err)
 					fatal(c, ErrSystemMalfunction)
 				}
 			}
@@ -174,7 +172,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 	select {
 	case <-ctx.Done():
 		if err := ctx.Err(); err != nil {
-			log.Printf("git: connection context error: %v", err)
+			logger.Debugf("git: connection context error: %v", err)
 		}
 		return
 	case <-readc:
@@ -186,7 +184,6 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 		}
 
 		var gitPack func(io.Reader, io.Writer, io.Writer, string) error
-		var repo string
 		cmd := string(split[0])
 		switch cmd {
 		case UploadPackBin:
@@ -204,21 +201,26 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 			return
 		}
 
-		repo = filepath.Clean(string(opts[0]))
-		log.Printf("git: connect %s %s %s", c.RemoteAddr(), cmd, repo)
-		defer log.Printf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, repo)
-		repo = strings.TrimPrefix(repo, "/")
-		auth := d.cfg.Access.AccessLevel(strings.TrimSuffix(repo, ".git"), nil)
+		name := sanitizeRepoName(string(opts[0]))
+		logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), cmd, name)
+		defer logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, name)
+		// git bare repositories should end in ".git"
+		// https://git-scm.com/docs/gitrepository-layout
+		repo := name + ".git"
+		// FIXME: determine repositories path
+		reposDir := filepath.Join(d.cfg.DataPath, "repos")
+		if err := ensureWithin(reposDir, repo); err != nil {
+			fatal(c, err)
+			return
+		}
+
+		auth := d.cfg.Access.AccessLevel(name, nil)
 		if auth < backend.ReadOnlyAccess {
 			fatal(c, ErrNotAuthed)
 			return
 		}
-		// git bare repositories should end in ".git"
-		// https://git-scm.com/docs/gitrepository-layout
-		repo = strings.TrimSuffix(repo, ".git") + ".git"
-		// FIXME: determine repositories path
-		repoDir := filepath.Join(d.cfg.DataPath, "repos", repo)
-		if err := gitPack(c, c, c, repoDir); err != nil {
+
+		if err := gitPack(c, c, c, filepath.Join(reposDir, repo)); err != nil {
 			fatal(c, err)
 			return
 		}

server/git.go 🔗

@@ -4,10 +4,11 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
 	"os"
 	"path/filepath"
+	"strings"
 
+	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/git"
 	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 )
@@ -90,13 +91,36 @@ func WritePktline(w io.Writer, v ...interface{}) {
 	msg := fmt.Sprintln(v...)
 	pkt := pktline.NewEncoder(w)
 	if err := pkt.EncodeString(msg); err != nil {
-		log.Printf("git: error writing pkt-line message: %s", err)
+		log.Debugf("git: error writing pkt-line message: %s", err)
 	}
 	if err := pkt.Flush(); err != nil {
-		log.Printf("git: error flushing pkt-line message: %s", err)
+		log.Debugf("git: error flushing pkt-line message: %s", err)
 	}
 }
 
+// ensureWithin ensures the given repo is within the repos directory.
+func ensureWithin(reposDir string, repo string) error {
+	repoDir := filepath.Join(reposDir, repo)
+	absRepos, err := filepath.Abs(reposDir)
+	if err != nil {
+		log.Debugf("failed to get absolute path for repo: %s", err)
+		return ErrSystemMalfunction
+	}
+	absRepo, err := filepath.Abs(repoDir)
+	if err != nil {
+		log.Debugf("failed to get absolute path for repos: %s", err)
+		return ErrSystemMalfunction
+	}
+
+	// ensure the repo is within the repos directory
+	if !strings.HasPrefix(absRepo, absRepos) {
+		log.Debugf("repo path is outside of repos directory: %s", absRepo)
+		return ErrInvalidRepo
+	}
+
+	return nil
+}
+
 func fileExists(path string) (bool, error) {
 	_, err := os.Stat(path)
 	if err == nil {

server/ssh.go 🔗

@@ -2,7 +2,6 @@ package server
 
 import (
 	"errors"
-	"fmt"
 	"path/filepath"
 	"strings"
 	"time"
@@ -88,20 +87,21 @@ func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
 					gc := cmd[0]
 					// repo should be in the form of "repo.git"
-					repo := sanitizeRepoName(cmd[1])
-					name := repo
-					if strings.Contains(repo, "/") {
-						log.Printf("invalid repo: %s", repo)
-						sshFatal(s, fmt.Errorf("%s: %s", ErrInvalidRepo, "user repos not supported"))
-						return
-					}
+					name := sanitizeRepoName(cmd[1])
 					pk := s.PublicKey()
 					access := cfg.Access.AccessLevel(name, pk)
 					// git bare repositories should end in ".git"
 					// https://git-scm.com/docs/gitrepository-layout
-					repo = strings.TrimSuffix(repo, ".git") + ".git"
+					repo := name + ".git"
+
 					// FIXME: determine repositories path
-					repoDir := filepath.Join(cfg.DataPath, "repos", repo)
+					reposDir := filepath.Join(cfg.DataPath, "repos")
+					if err := ensureWithin(reposDir, repo); err != nil {
+						sshFatal(s, err)
+						return
+					}
+
+					repoDir := filepath.Join(reposDir, repo)
 					switch gc {
 					case ReceivePackBin:
 						if access < backend.ReadWriteAccess {