fix(server): bound the current context to the underlying operation

Ayman Bagabas created

Use the connection context when running external commands.

Signed-off-by: Ayman Bagabas <ayman.bagabas@gmail.com>

Change summary

go.mod                          |  2 +
go.sum                          |  5 +--
server/backend/backend.go       |  4 +++
server/backend/context.go       | 19 +++++++++++++++++
server/backend/sqlite/sqlite.go | 12 +++++++++++
server/cmd/cmd.go               | 38 +++++++++++++---------------------
server/daemon/daemon.go         |  7 ++++-
server/daemon/daemon_test.go    |  8 ++++--
server/server.go                |  1 
server/ssh/ssh.go               | 19 +++++++++++------
server/web/http.go              |  8 +++++-
11 files changed, 83 insertions(+), 40 deletions(-)

Detailed changes

go.mod 🔗

@@ -93,3 +93,5 @@ require (
 	modernc.org/strutil v1.1.3 // indirect
 	modernc.org/token v1.0.1 // indirect
 )
+
+replace github.com/gogs/git-module => github.com/aymanbagabas/git-module v1.4.1-0.20230509180555-975c24cdb79a

go.sum 🔗

@@ -51,6 +51,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuW
 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
 github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
 github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
+github.com/aymanbagabas/git-module v1.4.1-0.20230509180555-975c24cdb79a h1:rY724fIR0NNR/UTXJufwKwYz+sNYVve/ZdzWX39xMqM=
+github.com/aymanbagabas/git-module v1.4.1-0.20230509180555-975c24cdb79a/go.mod h1:GUSSUH+RM7fZOtjhS6Obh4B9aAvs3EeROpazfMNMF8g=
 github.com/aymanbagabas/go-osc52 v1.0.3/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4=
 github.com/aymanbagabas/go-osc52 v1.2.1 h1:q2sWUyDcozPLcLabEMd+a+7Ea2DitxZVN9hTxab9L4E=
 github.com/aymanbagabas/go-osc52 v1.2.1/go.mod h1:zT8H+Rk4VSabYN90pWyugflM3ZhpTZNC7cASDfUCdT4=
@@ -148,8 +150,6 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me
 github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
 github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
 github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
-github.com/gogs/git-module v1.8.1 h1:yC5BZ3unJOXC8N6/FgGQ8EtJXpOd217lgDcd2aPOxkc=
-github.com/gogs/git-module v1.8.1/go.mod h1:Y3rsSqtFZEbn7lp+3gWf42GKIY1eNTtLt7JrmOy0yAQ=
 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
 github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
 github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
@@ -392,7 +392,6 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=

server/backend/backend.go 🔗

@@ -2,6 +2,7 @@ package backend
 
 import (
 	"bytes"
+	"context"
 
 	"github.com/charmbracelet/ssh"
 	gossh "golang.org/x/crypto/ssh"
@@ -17,6 +18,9 @@ type Backend interface {
 	UserStore
 	UserAccess
 	Hooks
+
+	// WithContext returns a copy Backend with the given context.
+	WithContext(ctx context.Context) Backend
 }
 
 // ParseAuthorizedKey parses an authorized key string into a public key.

server/backend/context.go 🔗

@@ -0,0 +1,19 @@
+package backend
+
+import "context"
+
+var contextKey = &struct{ string }{"backend"}
+
+// FromContext returns the backend from a context.
+func FromContext(ctx context.Context) Backend {
+	if b, ok := ctx.Value(contextKey).(Backend); ok {
+		return b
+	}
+
+	return nil
+}
+
+// WithContext returns a new context with the backend attached.
+func WithContext(ctx context.Context, b Backend) context.Context {
+	return context.WithValue(ctx, contextKey, b)
+}

server/backend/sqlite/sqlite.go 🔗

@@ -2,6 +2,7 @@ package sqlite
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -67,6 +68,12 @@ func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) {
 	return d, d.initRepos()
 }
 
+// WithContext returns a copy of SqliteBackend with the given context.
+func (d SqliteBackend) WithContext(ctx context.Context) backend.Backend {
+	d.ctx = ctx
+	return &d
+}
+
 // AllowKeyless returns whether or not keyless access is allowed.
 //
 // It implements backend.Backend.
@@ -183,6 +190,8 @@ func (d *SqliteBackend) ImportRepository(name string, remote string, opts backen
 		Quiet:   true,
 		Timeout: 15 * time.Minute,
 		CommandOptions: git.CommandOptions{
+			Timeout: -1,
+			Context: d.ctx,
 			Envs: []string{
 				fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
 					filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
@@ -194,6 +203,9 @@ func (d *SqliteBackend) ImportRepository(name string, remote string, opts backen
 
 	if err := git.Clone(remote, rp, copts); err != nil {
 		d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
+		if rerr := os.RemoveAll(rp); rerr != nil {
+			err = errors.Join(err, rerr)
+		}
 		return nil, err
 	}
 

server/cmd/cmd.go 🔗

@@ -18,25 +18,9 @@ import (
 	"github.com/spf13/cobra"
 )
 
-// ContextKey is a type that can be used as a key in a context.
-type ContextKey string
-
-// String returns the string representation of the ContextKey.
-func (c ContextKey) String() string {
-	return string(c) + "ContextKey"
-}
-
-var (
-	// ConfigCtxKey is the key for the config in the context.
-	ConfigCtxKey = ContextKey("config")
-	// SessionCtxKey is the key for the session in the context.
-	SessionCtxKey = ContextKey("session")
-	// HooksCtxKey is the key for the git hooks in the context.
-	HooksCtxKey = ContextKey("hooks")
-)
-
 var (
-	logger = log.WithPrefix("server.cmd")
+	// sessionCtxKey is the key for the session in the context.
+	sessionCtxKey = &struct{ string }{"session"}
 )
 
 var templateFuncs = template.FuncMap{
@@ -152,8 +136,8 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
 
 func fromContext(cmd *cobra.Command) (*config.Config, ssh.Session) {
 	ctx := cmd.Context()
-	cfg := ctx.Value(ConfigCtxKey).(*config.Config)
-	s := ctx.Value(SessionCtxKey).(ssh.Session)
+	cfg := config.FromContext(ctx)
+	s := ctx.Value(sessionCtxKey).(ssh.Session)
 	return cfg, s
 }
 
@@ -213,7 +197,7 @@ func checkIfCollab(cmd *cobra.Command, args []string) error {
 }
 
 // Middleware is the Soft Serve middleware that handles SSH commands.
-func Middleware(cfg *config.Config) wish.Middleware {
+func Middleware(cfg *config.Config, logger *log.Logger) wish.Middleware {
 	return func(sh ssh.Handler) ssh.Handler {
 		return func(s ssh.Session) {
 			func() {
@@ -232,8 +216,16 @@ func Middleware(cfg *config.Config) wish.Middleware {
 					}
 				}
 
-				ctx := context.WithValue(s.Context(), ConfigCtxKey, cfg)
-				ctx = context.WithValue(ctx, SessionCtxKey, s)
+				// Here we copy the server's config and replace the backend
+				// with a new one that uses the session's context.
+				var ctx context.Context = s.Context()
+				scfg := *cfg
+				cfg = &scfg
+				be := cfg.Backend.WithContext(ctx)
+				cfg.Backend = be
+				ctx = config.WithContext(ctx, cfg)
+				ctx = backend.WithContext(ctx, be)
+				ctx = context.WithValue(ctx, sessionCtxKey, s)
 
 				rootCmd := rootCommand(cfg, s)
 				rootCmd.SetArgs(args)

server/daemon/daemon.go 🔗

@@ -83,6 +83,7 @@ type GitDaemon struct {
 	finished chan struct{}
 	conns    connections
 	cfg      *config.Config
+	be       backend.Backend
 	wg       sync.WaitGroup
 	once     sync.Once
 	logger   *log.Logger
@@ -97,6 +98,7 @@ func NewGitDaemon(ctx context.Context) (*GitDaemon, error) {
 		addr:     addr,
 		finished: make(chan struct{}, 1),
 		cfg:      cfg,
+		be:       backend.FromContext(ctx),
 		conns:    connections{m: make(map[net.Conn]struct{})},
 		logger:   log.FromContext(ctx).WithPrefix("gitdaemon"),
 	}
@@ -231,7 +233,8 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 			return
 		}
 
-		if !d.cfg.Backend.AllowKeyless() {
+		be := d.be.WithContext(ctx)
+		if !be.AllowKeyless() {
 			d.fatal(c, git.ErrNotAuthed)
 			return
 		}
@@ -248,7 +251,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
 			return
 		}
 
-		auth := d.cfg.Backend.AccessLevel(name, "")
+		auth := be.AccessLevel(name, "")
 		if auth < backend.ReadOnlyAccess {
 			d.fatal(c, git.ErrNotAuthed)
 			return

server/daemon/daemon_test.go 🔗

@@ -12,6 +12,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/backend/sqlite"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/git"
@@ -35,15 +36,16 @@ func TestMain(m *testing.M) {
 	ctx := context.TODO()
 	cfg := config.DefaultConfig()
 	ctx = config.WithContext(ctx, cfg)
-	d, err := NewGitDaemon(ctx)
+	fb, err := sqlite.NewSqliteBackend(ctx)
 	if err != nil {
 		log.Fatal(err)
 	}
-	fb, err := sqlite.NewSqliteBackend(ctx)
+	cfg = cfg.WithBackend(fb)
+	ctx = backend.WithContext(ctx, fb)
+	d, err := NewGitDaemon(ctx)
 	if err != nil {
 		log.Fatal(err)
 	}
-	cfg = cfg.WithBackend(fb)
 	testDaemon = d
 	go func() {
 		if err := d.Start(); err != ErrServerClosed {

server/server.go 🔗

@@ -50,6 +50,7 @@ func NewServer(ctx context.Context) (*Server, error) {
 		}
 
 		cfg = cfg.WithBackend(sb)
+		ctx = backend.WithContext(ctx, sb)
 	}
 
 	srv := &Server{

server/ssh/ssh.go 🔗

@@ -77,6 +77,7 @@ var (
 type SSHServer struct {
 	srv    *ssh.Server
 	cfg    *config.Config
+	be     backend.Backend
 	ctx    context.Context
 	logger *log.Logger
 }
@@ -84,26 +85,28 @@ type SSHServer struct {
 // NewSSHServer returns a new SSHServer.
 func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 	cfg := config.FromContext(ctx)
+	logger := log.FromContext(ctx).WithPrefix("ssh")
 
 	var err error
 	s := &SSHServer{
 		cfg:    cfg,
 		ctx:    ctx,
-		logger: log.FromContext(ctx).WithPrefix("ssh"),
+		be:     backend.FromContext(ctx),
+		logger: logger,
 	}
 
-	logger := s.logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 	mw := []wish.Middleware{
 		rm.MiddlewareWithLogger(
 			logger,
 			// BubbleTea middleware.
 			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 			// CLI middleware.
-			cm.Middleware(cfg),
+			cm.Middleware(cfg, logger),
 			// Git middleware.
 			s.Middleware(cfg),
 			// Logging middleware.
-			lm.MiddlewareWithLogger(logger),
+			lm.MiddlewareWithLogger(logger.
+				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
 		),
 	}
 
@@ -191,6 +194,8 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 		return func(s ssh.Session) {
 			func() {
 				cmd := s.Command()
+				ctx := s.Context()
+				be := ss.be.WithContext(ctx)
 				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
 					gc := cmd[0]
 					// repo should be in the form of "repo.git"
@@ -222,8 +227,8 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 							sshFatal(s, git.ErrNotAuthed)
 							return
 						}
-						if _, err := cfg.Backend.Repository(name); err != nil {
-							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
+						if _, err := be.Repository(name); err != nil {
+							if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
 								log.Errorf("failed to create repo: %s", err)
 								sshFatal(s, err)
 								return
@@ -248,7 +253,7 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
 							counter = uploadArchiveCounter
 						}
 
-						err := gitPack(s.Context(), s, s, s.Stderr(), repoDir, envs...)
+						err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
 						if errors.Is(err, git.ErrInvalidRepo) {
 							sshFatal(s, git.ErrInvalidRepo)
 						} else if err != nil {

server/web/http.go 🔗

@@ -81,6 +81,7 @@ func (s *HTTPServer) loggingMiddleware(next http.Handler) http.Handler {
 type HTTPServer struct {
 	ctx        context.Context
 	cfg        *config.Config
+	be         backend.Backend
 	server     *http.Server
 	dirHandler http.Handler
 	logger     *log.Logger
@@ -92,6 +93,7 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
 	s := &HTTPServer{
 		ctx:        ctx,
 		cfg:        cfg,
+		be:         backend.FromContext(ctx),
 		logger:     log.FromContext(ctx).WithPrefix("http"),
 		dirHandler: http.FileServer(http.Dir(filepath.Join(cfg.DataPath, "repos"))),
 		server: &http.Server{
@@ -254,6 +256,7 @@ Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}"
 func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
 	repo := pattern.Path(r.Context())
 	repo = utils.SanitizeRepo(repo)
+	be := s.be.WithContext(r.Context())
 
 	// Handle go get requests.
 	//
@@ -271,7 +274,7 @@ func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
 
 		// find the repo
 		for {
-			if _, err := s.cfg.Backend.Repository(repo); err == nil {
+			if _, err := be.Repository(repo); err == nil {
 				break
 			}
 
@@ -305,7 +308,8 @@ func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
 func (s *HTTPServer) handleGit(w http.ResponseWriter, r *http.Request) {
 	repo := pat.Param(r, "repo")
 	repo = utils.SanitizeRepo(repo) + ".git"
-	if _, err := s.cfg.Backend.Repository(repo); err != nil {
+	be := s.be.WithContext(r.Context())
+	if _, err := be.Repository(repo); err != nil {
 		s.logger.Debug("repository not found", "repo", repo, "err", err)
 		http.NotFound(w, r)
 		return