feat(server): use a separate ssh server for internal commands

Ayman Bagabas created

Change summary

cmd/soft/hook.go                |   7 +
cmd/soft/serve.go               |   3 
server/backend/sqlite/db.go     |   2 
server/backend/sqlite/sqlite.go |   3 
server/backend/sqlite/user.go   |   5 
server/cmd/cmd.go               |  28 ++------
server/cmd/tree.go              |   3 
server/cmd/user.go              |   5 
server/config/config.go         | 107 +++++++++++++++++++++-------------
server/config/file.go           |  24 +++++--
server/errors/errors.go         |  12 +++
server/hooks.go                 |  22 +++---
server/hooks/hooks.go           |   8 +-
server/internal/cmd.go          |  84 +++++++++++++++++++++++++++
server/internal/hook.go         |  48 +++++++--------
server/internal/internal.go     |  86 ++++++++++++++++++++++++++++
server/jobs.go                  |   3 
server/server.go                |  78 +++++++++++++------------
server/ssh/session.go           |   4 
server/ssh/ssh.go               |   5 
20 files changed, 373 insertions(+), 164 deletions(-)

Detailed changes

cmd/soft/hook.go 🔗

@@ -135,6 +135,7 @@ func init() {
 	hookCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to config file")
 }
 
+// TODO: use ssh controlmaster
 func commonInit() (c *gossh.Client, s *gossh.Session, err error) {
 	cfg, err := config.ParseConfig(configPath)
 	if err != nil {
@@ -173,11 +174,11 @@ func commonInit() (c *gossh.Client, s *gossh.Session, err error) {
 
 func newClient(cfg *config.Config) (*gossh.Client, error) {
 	// Only accept the server's host key.
-	pk, err := keygen.New(cfg.SSH.KeyPath, keygen.WithKeyType(keygen.Ed25519))
+	pk, err := keygen.New(cfg.Internal.KeyPath, keygen.WithKeyType(keygen.Ed25519))
 	if err != nil {
 		return nil, err
 	}
-	ik, err := keygen.New(cfg.SSH.InternalKeyPath, keygen.WithKeyType(keygen.Ed25519))
+	ik, err := keygen.New(cfg.Internal.InternalKeyPath, keygen.WithKeyType(keygen.Ed25519))
 	if err != nil {
 		return nil, err
 	}
@@ -188,7 +189,7 @@ func newClient(cfg *config.Config) (*gossh.Client, error) {
 		},
 		HostKeyCallback: gossh.FixedHostKey(pk.PublicKey()),
 	}
-	c, err := gossh.Dial("tcp", cfg.SSH.ListenAddr, cc)
+	c, err := gossh.Dial("tcp", cfg.Internal.ListenAddr, cc)
 	if err != nil {
 		return nil, err
 	}

cmd/soft/serve.go 🔗

@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"os/signal"
 	"syscall"
@@ -24,7 +25,7 @@ var (
 			cfg := config.DefaultConfig()
 			s, err := server.NewServer(ctx, cfg)
 			if err != nil {
-				return err
+				return fmt.Errorf("start server: %w", err)
 			}
 
 			done := make(chan os.Signal, 1)

server/backend/sqlite/db.go 🔗

@@ -61,6 +61,8 @@ func (d *SqliteBackend) init() error {
 			}
 
 			// Add initial keys
+			// Don't use cfg.AdminKeys since it also includes the internal key
+			// used for internal api access.
 			for _, k := range d.cfg.InitialAdminKeys {
 				pk, _, err := backend.ParseAuthorizedKey(k)
 				if err != nil {

server/backend/sqlite/sqlite.go 🔗

@@ -186,8 +186,7 @@ func (d *SqliteBackend) ImportRepository(name string, remote string, opts backen
 			Envs: []string{
 				fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
 					filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
-					// FIXME: upstream keygen appends _ed25519 to the key path.
-					d.cfg.SSH.ClientKeyPath+"_ed25519",
+					d.cfg.Internal.ClientKeyPath,
 				),
 			},
 		},

server/backend/sqlite/user.go 🔗

@@ -118,9 +118,8 @@ func (d *SqliteBackend) AccessLevel(repo string, username string) backend.Access
 //
 // It implements backend.Backend.
 func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
-	for _, k := range append(d.cfg.InitialAdminKeys, d.cfg.InternalPublicKey) {
-		ik, _, err := backend.ParseAuthorizedKey(k)
-		if err == nil && backend.KeysEqual(pk, ik) {
+	for _, k := range d.cfg.AdminKeys() {
+		if backend.KeysEqual(pk, k) {
 			return backend.AdminAccess
 		}
 	}

server/cmd/cmd.go 🔗

@@ -11,7 +11,7 @@ import (
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/config"
-	"github.com/charmbracelet/soft-serve/server/hooks"
+	"github.com/charmbracelet/soft-serve/server/errors"
 	"github.com/charmbracelet/soft-serve/server/utils"
 	"github.com/charmbracelet/ssh"
 	"github.com/charmbracelet/wish"
@@ -35,15 +35,6 @@ var (
 	HooksCtxKey = ContextKey("hooks")
 )
 
-var (
-	// ErrUnauthorized is returned when the user is not authorized to perform action.
-	ErrUnauthorized = fmt.Errorf("Unauthorized")
-	// ErrRepoNotFound is returned when the repo is not found.
-	ErrRepoNotFound = fmt.Errorf("Repository not found")
-	// ErrFileNotFound is returned when the file is not found.
-	ErrFileNotFound = fmt.Errorf("File not found")
-)
-
 var (
 	logger = log.WithPrefix("server.cmd")
 )
@@ -136,7 +127,6 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
 	})
 	rootCmd.CompletionOptions.DisableDefaultCmd = true
 	rootCmd.AddCommand(
-		hookCommand(),
 		repoCommand(),
 	)
 
@@ -176,15 +166,14 @@ func checkIfReadable(cmd *cobra.Command, args []string) error {
 	rn := utils.SanitizeRepo(repo)
 	auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey())
 	if auth < backend.ReadOnlyAccess {
-		return ErrUnauthorized
+		return errors.ErrUnauthorized
 	}
 	return nil
 }
 
 func isPublicKeyAdmin(cfg *config.Config, pk ssh.PublicKey) bool {
-	for _, k := range cfg.InitialAdminKeys {
-		pk2, _, err := backend.ParseAuthorizedKey(k)
-		if err == nil && backend.KeysEqual(pk, pk2) {
+	for _, k := range cfg.AdminKeys() {
+		if backend.KeysEqual(pk, k) {
 			return true
 		}
 	}
@@ -199,11 +188,11 @@ func checkIfAdmin(cmd *cobra.Command, _ []string) error {
 
 	user, _ := cfg.Backend.UserByPublicKey(s.PublicKey())
 	if user == nil {
-		return ErrUnauthorized
+		return errors.ErrUnauthorized
 	}
 
 	if !user.IsAdmin() {
-		return ErrUnauthorized
+		return errors.ErrUnauthorized
 	}
 
 	return nil
@@ -218,13 +207,13 @@ func checkIfCollab(cmd *cobra.Command, args []string) error {
 	rn := utils.SanitizeRepo(repo)
 	auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey())
 	if auth < backend.ReadWriteAccess {
-		return ErrUnauthorized
+		return errors.ErrUnauthorized
 	}
 	return nil
 }
 
 // Middleware is the Soft Serve middleware that handles SSH commands.
-func Middleware(cfg *config.Config, hooks hooks.Hooks) wish.Middleware {
+func Middleware(cfg *config.Config) wish.Middleware {
 	return func(sh ssh.Handler) ssh.Handler {
 		return func(s ssh.Session) {
 			func() {
@@ -245,7 +234,6 @@ func Middleware(cfg *config.Config, hooks hooks.Hooks) wish.Middleware {
 
 				ctx := context.WithValue(s.Context(), ConfigCtxKey, cfg)
 				ctx = context.WithValue(ctx, SessionCtxKey, s)
-				ctx = context.WithValue(ctx, HooksCtxKey, hooks)
 
 				rootCmd := rootCommand(cfg, s)
 				rootCmd.SetArgs(args)

server/cmd/tree.go 🔗

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/charmbracelet/soft-serve/git"
+	"github.com/charmbracelet/soft-serve/server/errors"
 	"github.com/dustin/go-humanize"
 	"github.com/spf13/cobra"
 )
@@ -58,7 +59,7 @@ func treeCommand() *cobra.Command {
 			if path != "" && path != "/" {
 				te, err := tree.TreeEntry(path)
 				if err == git.ErrRevisionNotExist {
-					return ErrFileNotFound
+					return errors.ErrFileNotFound
 				}
 				if err != nil {
 					return err

server/cmd/user.go 🔗

@@ -142,7 +142,6 @@ func userCommand() *cobra.Command {
 		PersistentPreRunE: checkIfAdmin,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			cfg, s := fromContext(cmd)
-			ak := backend.MarshalAuthorizedKey(s.PublicKey())
 			username := args[0]
 
 			user, err := cfg.Backend.User(username)
@@ -151,8 +150,8 @@ func userCommand() *cobra.Command {
 			}
 
 			isAdmin := user.IsAdmin()
-			for _, k := range cfg.InitialAdminKeys {
-				if ak == k {
+			for _, k := range cfg.AdminKeys() {
+				if backend.KeysEqual(k, s.PublicKey()) {
 					isAdmin = true
 					break
 				}

server/config/config.go 🔗

@@ -10,6 +10,7 @@ import (
 	"github.com/caarlos0/env/v7"
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/backend"
+	"golang.org/x/crypto/ssh"
 	"gopkg.in/yaml.v3"
 )
 
@@ -24,12 +25,6 @@ type SSHConfig struct {
 	// KeyPath is the path to the SSH server's private key.
 	KeyPath string `env:"KEY_PATH" yaml:"key_path"`
 
-	// ClientKeyPath is the path to the SSH server's client private key.
-	ClientKeyPath string `env:"CLIENT_KEY_PATH" yaml:"client_key_path"`
-
-	// InternalKeyPath is the path to the SSH server's internal private key.
-	InternalKeyPath string `env:"INTERNAL_KEY_PATH" yaml:"internal_key_path"`
-
 	// MaxTimeout is the maximum number of seconds a connection can take.
 	MaxTimeout int `env:"MAX_TIMEOUT" yaml:"max_timeout`
 
@@ -73,6 +68,22 @@ type StatsConfig struct {
 	ListenAddr string `env:"LISTEN_ADDR" yaml:"listen_addr"`
 }
 
+// InternalConfig is the configuration for the internal server.
+// This is used for internal communication between the Soft Serve client and server.
+type InternalConfig struct {
+	// ListenAddr is the address on which the internal server will listen.
+	ListenAddr string `env:"LISTEN_ADDR" yaml:"listen_addr"`
+
+	// KeyPath is the path to the SSH server's host private key.
+	KeyPath string `env:"KEY_PATH" yaml:"key_path"`
+
+	// InternalKeyPath is the path to the server's internal private key.
+	InternalKeyPath string `env:"INTERNAL_KEY_PATH" yaml:"internal_key_path"`
+
+	// ClientKeyPath is the path to the server's client private key.
+	ClientKeyPath string `env:"CLIENT_KEY_PATH" yaml:"client_key_path"`
+}
+
 // Config is the configuration for Soft Serve.
 type Config struct {
 	// Name is the name of the server.
@@ -90,6 +101,9 @@ type Config struct {
 	// Stats is the configuration for the stats server.
 	Stats StatsConfig `envPrefix:"STATS_" yaml:"stats"`
 
+	// Internal is the configuration for the internal server.
+	Internal InternalConfig `envPrefix:"INTERNAL_" yaml:"internal"`
+
 	// InitialAdminKeys is a list of public keys that will be added to the list of admins.
 	InitialAdminKeys []string `env:"INITIAL_ADMIN_KEYS" envSeparator:"\n" yaml:"initial_admin_keys"`
 
@@ -98,12 +112,6 @@ type Config struct {
 
 	// Backend is the Git backend to use.
 	Backend backend.Backend `yaml:"-"`
-
-	// InternalPublicKey is the public key of the internal SSH key.
-	InternalPublicKey string `yaml:"-"`
-
-	// ClientPublicKey is the public key of the client SSH key.
-	ClientPublicKey string `yaml:"-"`
 }
 
 func parseConfig(path string) (*Config, error) {
@@ -112,13 +120,11 @@ func parseConfig(path string) (*Config, error) {
 		Name:     "Soft Serve",
 		DataPath: dataPath,
 		SSH: SSHConfig{
-			ListenAddr:      ":23231",
-			PublicURL:       "ssh://localhost:23231",
-			KeyPath:         filepath.Join("ssh", "soft_serve_host_ed25519"),
-			ClientKeyPath:   filepath.Join("ssh", "soft_serve_client_ed25519"),
-			InternalKeyPath: filepath.Join("ssh", "soft_serve_internal_ed25519"),
-			MaxTimeout:      0,
-			IdleTimeout:     120,
+			ListenAddr:  ":23231",
+			PublicURL:   "ssh://localhost:23231",
+			KeyPath:     filepath.Join("ssh", "soft_serve_host_ed25519"),
+			MaxTimeout:  0,
+			IdleTimeout: 120,
 		},
 		Git: GitConfig{
 			ListenAddr:     ":9418",
@@ -127,11 +133,17 @@ func parseConfig(path string) (*Config, error) {
 			MaxConnections: 32,
 		},
 		HTTP: HTTPConfig{
-			ListenAddr: ":8080",
-			PublicURL:  "http://localhost:8080",
+			ListenAddr: ":23232",
+			PublicURL:  "http://localhost:23232",
 		},
 		Stats: StatsConfig{
-			ListenAddr: ":8081",
+			ListenAddr: "localhost:23233",
+		},
+		Internal: InternalConfig{
+			ListenAddr:      "localhost:23230",
+			KeyPath:         filepath.Join("ssh", "soft_serve_internal_host_ed25519"),
+			InternalKeyPath: filepath.Join("ssh", "soft_serve_internal_ed25519"),
+			ClientKeyPath:   filepath.Join("ssh", "soft_serve_client_ed25519"),
 		},
 	}
 
@@ -160,20 +172,10 @@ func parseConfig(path string) (*Config, error) {
 
 	// Validate keys
 	pks := make([]string, 0)
-	for _, key := range cfg.InitialAdminKeys {
-		var pk string
-		if bts, err := os.ReadFile(key); err == nil {
-			// key is a file
-			pk = string(bts)
-		}
-		if _, _, err := backend.ParseAuthorizedKey(key); err == nil {
-			pk = key
-		}
-		pk = strings.TrimSpace(pk)
-		if pk != "" {
-			log.Debugf("found initial admin key: %q", key)
-			pks = append(pks, pk)
-		}
+	for _, key := range parseAuthKeys(cfg.InitialAdminKeys) {
+		ak := backend.MarshalAuthorizedKey(key)
+		pks = append(pks, ak)
+		log.Debugf("found initial admin key: %q", ak)
 	}
 
 	cfg.InitialAdminKeys = pks
@@ -259,12 +261,16 @@ func (c *Config) validate() error {
 		c.SSH.KeyPath = filepath.Join(c.DataPath, c.SSH.KeyPath)
 	}
 
-	if c.SSH.ClientKeyPath != "" && !filepath.IsAbs(c.SSH.ClientKeyPath) {
-		c.SSH.ClientKeyPath = filepath.Join(c.DataPath, c.SSH.ClientKeyPath)
+	if c.Internal.KeyPath != "" && !filepath.IsAbs(c.Internal.KeyPath) {
+		c.Internal.KeyPath = filepath.Join(c.DataPath, c.Internal.KeyPath)
 	}
 
-	if c.SSH.InternalKeyPath != "" && !filepath.IsAbs(c.SSH.InternalKeyPath) {
-		c.SSH.InternalKeyPath = filepath.Join(c.DataPath, c.SSH.InternalKeyPath)
+	if c.Internal.ClientKeyPath != "" && !filepath.IsAbs(c.Internal.ClientKeyPath) {
+		c.Internal.ClientKeyPath = filepath.Join(c.DataPath, c.Internal.ClientKeyPath)
+	}
+
+	if c.Internal.InternalKeyPath != "" && !filepath.IsAbs(c.Internal.InternalKeyPath) {
+		c.Internal.InternalKeyPath = filepath.Join(c.DataPath, c.Internal.InternalKeyPath)
 	}
 
 	if c.HTTP.TLSKeyPath != "" && !filepath.IsAbs(c.HTTP.TLSKeyPath) {
@@ -277,3 +283,24 @@ func (c *Config) validate() error {
 
 	return nil
 }
+
+// parseAuthKeys parses authorized keys from either file paths or string authorized_keys.
+func parseAuthKeys(aks []string) []ssh.PublicKey {
+	pks := make([]ssh.PublicKey, 0)
+	for _, key := range aks {
+		var ak string
+		if bts, err := os.ReadFile(key); err == nil {
+			// key is a file
+			ak = strings.TrimSpace(string(bts))
+		}
+		if pk, _, err := backend.ParseAuthorizedKey(ak); err == nil {
+			pks = append(pks, pk)
+		}
+	}
+	return pks
+}
+
+// AdminKeys returns the admin keys including the internal api key.
+func (c *Config) AdminKeys() []ssh.PublicKey {
+	return parseAuthKeys(append(c.InitialAdminKeys, c.Internal.InternalKeyPath))
+}

server/config/file.go 🔗

@@ -24,14 +24,6 @@ ssh:
   # The path to the SSH server's private key.
   key_path: "{{ .SSH.KeyPath }}"
 
-  # The path to the SSH server's client private key.
-  # This key will be used to authenticate the server to make git requests to
-  # ssh remotes.
-  client_key_path: "{{ .SSH.ClientKeyPath }}"
-
-  # The path to the SSH server's internal api private key.
-  internal_key_path: "{{ .SSH.InternalKeyPath }}"
-
   # The maximum number of seconds a connection can take.
   # A value of 0 means no timeout.
   max_timeout: {{ .SSH.MaxTimeout }}
@@ -75,6 +67,22 @@ stats:
   # The address on which the stats server will listen.
   listen_addr: "{{ .Stats.ListenAddr }}"
 
+# The internal server configuration.
+internal:
+  # The address on which the internal server will listen.
+  listen_addr: "{{ .Internal.ListenAddr }}"
+
+  # The path to the Internal server's host private key.
+  key_path: "{{ .Internal.KeyPath }}"
+
+  # The path to the Internal server's client private key.
+  # This key will be used to authenticate the server to make git requests to
+  # ssh remotes.
+  client_key_path: "{{ .Internal.ClientKeyPath }}"
+
+  # The path to the Internal server's internal api private key.
+  internal_key_path: "{{ .Internal.InternalKeyPath }}"
+
 # Additional admin keys.
 #initial_admin_keys:
 #  - "ssh-rsa AAAAB3NzaC1yc2..."

server/errors/errors.go 🔗

@@ -0,0 +1,12 @@
+package errors
+
+import "fmt"
+
+var (
+	// ErrUnauthorized is returned when the user is not authorized to perform action.
+	ErrUnauthorized = fmt.Errorf("Unauthorized")
+	// ErrRepoNotFound is returned when the repo is not found.
+	ErrRepoNotFound = fmt.Errorf("Repository not found")
+	// ErrFileNotFound is returned when the file is not found.
+	ErrFileNotFound = fmt.Errorf("File not found")
+)

server/hooks.go 🔗

@@ -3,6 +3,7 @@ package server
 import (
 	"io"
 
+	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/hooks"
 )
 
@@ -11,42 +12,43 @@ var _ hooks.Hooks = (*Server)(nil)
 // PostReceive is called by the git post-receive hook.
 //
 // It implements Hooks.
-func (*Server) PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []hooks.HookArg) {
-	logger.Debug("post-receive hook called", "repo", repo, "args", args)
+func (*Server) PostReceive(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args []hooks.HookArg) {
+	io.WriteString(stdout, "Hello, world!\n")
+	log.WithPrefix("server.hooks").Debug("post-receive hook called", "repo", repo, "args", args)
 }
 
 // PreReceive is called by the git pre-receive hook.
 //
 // It implements Hooks.
-func (*Server) PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []hooks.HookArg) {
-	logger.Debug("pre-receive hook called", "repo", repo, "args", args)
+func (*Server) PreReceive(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args []hooks.HookArg) {
+	log.WithPrefix("server.hooks").Debug("pre-receive hook called", "repo", repo, "args", args)
 }
 
 // Update is called by the git update hook.
 //
 // It implements Hooks.
-func (*Server) Update(stdout io.Writer, stderr io.Writer, repo string, arg hooks.HookArg) {
-	logger.Debug("update hook called", "repo", repo, "arg", arg)
+func (*Server) Update(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, arg hooks.HookArg) {
+	log.WithPrefix("server.hooks").Debug("update hook called", "repo", repo, "arg", arg)
 }
 
 // PostUpdate is called by the git post-update hook.
 //
 // It implements Hooks.
-func (s *Server) PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string) {
+func (s *Server) PostUpdate(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args ...string) {
 	rr, err := s.Config.Backend.Repository(repo)
 	if err != nil {
-		logger.WithPrefix("server.hooks.post-update").Error("error getting repository", "repo", repo, "err", err)
+		log.WithPrefix("server.hooks.post-update").Error("error getting repository", "repo", repo, "err", err)
 		return
 	}
 
 	r, err := rr.Open()
 	if err != nil {
-		logger.WithPrefix("server.hooks.post-update").Error("error opening repository", "repo", repo, "err", err)
+		log.WithPrefix("server.hooks.post-update").Error("error opening repository", "repo", repo, "err", err)
 		return
 	}
 
 	if err := r.UpdateServerInfo(); err != nil {
-		logger.WithPrefix("server.hooks.post-update").Error("error updating server info", "repo", repo, "err", err)
+		log.WithPrefix("server.hooks.post-update").Error("error updating server info", "repo", repo, "err", err)
 		return
 	}
 }

server/hooks/hooks.go 🔗

@@ -11,8 +11,8 @@ type HookArg struct {
 
 // Hooks provides an interface for git server-side hooks.
 type Hooks interface {
-	PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
-	Update(stdout io.Writer, stderr io.Writer, repo string, arg HookArg)
-	PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
-	PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string)
+	PreReceive(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
+	Update(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, arg HookArg)
+	PostReceive(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
+	PostUpdate(stdin io.Reader, stdout io.Writer, stderr io.Writer, repo string, args ...string)
 }

server/internal/cmd.go 🔗

@@ -0,0 +1,84 @@
+package internal
+
+import (
+	"context"
+
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/hooks"
+	"github.com/charmbracelet/ssh"
+	"github.com/charmbracelet/wish"
+	"github.com/spf13/cobra"
+)
+
+var (
+	hooksCtxKey   = "hooks"
+	sessionCtxKey = "session"
+	configCtxKey  = "config"
+)
+
+// rootCommand is the root command for the server.
+func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
+	rootCmd := &cobra.Command{
+		Short:        "Soft Serve internal API.",
+		SilenceUsage: true,
+	}
+
+	rootCmd.SetIn(s)
+	rootCmd.SetOut(s)
+	rootCmd.SetErr(s)
+	rootCmd.CompletionOptions.DisableDefaultCmd = true
+
+	rootCmd.AddCommand(
+		hookCommand(),
+	)
+
+	return rootCmd
+}
+
+// Middleware returns the middleware for the server.
+func (i *InternalServer) Middleware(hooks hooks.Hooks) wish.Middleware {
+	return func(sh ssh.Handler) ssh.Handler {
+		return func(s ssh.Session) {
+			_, _, active := s.Pty()
+			if active {
+				return
+			}
+
+			// Ignore git server commands.
+			args := s.Command()
+			if len(args) > 0 {
+				if args[0] == "git-receive-pack" ||
+					args[0] == "git-upload-pack" ||
+					args[0] == "git-upload-archive" {
+					return
+				}
+			}
+
+			ctx := context.WithValue(s.Context(), hooksCtxKey, hooks)
+			ctx = context.WithValue(ctx, sessionCtxKey, s)
+			ctx = context.WithValue(ctx, configCtxKey, i.cfg)
+
+			rootCmd := rootCommand(i.cfg, s)
+			rootCmd.SetArgs(args)
+			if len(args) == 0 {
+				// otherwise it'll default to os.Args, which is not what we want.
+				rootCmd.SetArgs([]string{"--help"})
+			}
+			rootCmd.SetIn(s)
+			rootCmd.SetOut(s)
+			rootCmd.CompletionOptions.DisableDefaultCmd = true
+			rootCmd.SetErr(s.Stderr())
+			if err := rootCmd.ExecuteContext(ctx); err != nil {
+				_ = s.Exit(1)
+			}
+			sh(s)
+		}
+	}
+}
+
+func fromContext(cmd *cobra.Command) (*config.Config, ssh.Session) {
+	ctx := cmd.Context()
+	cfg := ctx.Value(configCtxKey).(*config.Config)
+	s := ctx.Value(sessionCtxKey).(ssh.Session)
+	return cfg, s
+}

server/cmd/hook.go → server/internal/hook.go 🔗

@@ -1,4 +1,4 @@
-package cmd
+package internal
 
 import (
 	"bufio"
@@ -6,7 +6,9 @@ import (
 	"strings"
 
 	"github.com/charmbracelet/keygen"
+	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/errors"
 	"github.com/charmbracelet/soft-serve/server/hooks"
 	"github.com/charmbracelet/ssh"
 	"github.com/spf13/cobra"
@@ -15,12 +17,11 @@ import (
 // hookCommand handles Soft Serve internal API git hook requests.
 func hookCommand() *cobra.Command {
 	preReceiveCmd := &cobra.Command{
-		Use:               "pre-receive",
-		Short:             "Run git pre-receive hook",
-		PersistentPreRunE: checkIfInternal,
+		Use:   "pre-receive",
+		Short: "Run git pre-receive hook",
 		RunE: func(cmd *cobra.Command, args []string) error {
 			_, s := fromContext(cmd)
-			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
+			hks := cmd.Context().Value(hooksCtxKey).(hooks.Hooks)
 			repoName := getRepoName(s)
 			opts := make([]hooks.HookArg, 0)
 			scanner := bufio.NewScanner(s)
@@ -35,21 +36,20 @@ func hookCommand() *cobra.Command {
 					RefName: fields[2],
 				})
 			}
-			hks.PreReceive(s, s.Stderr(), repoName, opts)
+			hks.PreReceive(s, s, s.Stderr(), repoName, opts)
 			return nil
 		},
 	}
 
 	updateCmd := &cobra.Command{
-		Use:               "update",
-		Short:             "Run git update hook",
-		Args:              cobra.ExactArgs(3),
-		PersistentPreRunE: checkIfInternal,
+		Use:   "update",
+		Short: "Run git update hook",
+		Args:  cobra.ExactArgs(3),
 		RunE: func(cmd *cobra.Command, args []string) error {
 			_, s := fromContext(cmd)
-			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
+			hks := cmd.Context().Value(hooksCtxKey).(hooks.Hooks)
 			repoName := getRepoName(s)
-			hks.Update(s, s.Stderr(), repoName, hooks.HookArg{
+			hks.Update(s, s, s.Stderr(), repoName, hooks.HookArg{
 				RefName: args[0],
 				OldSha:  args[1],
 				NewSha:  args[2],
@@ -59,12 +59,11 @@ func hookCommand() *cobra.Command {
 	}
 
 	postReceiveCmd := &cobra.Command{
-		Use:               "post-receive",
-		Short:             "Run git post-receive hook",
-		PersistentPreRunE: checkIfInternal,
+		Use:   "post-receive",
+		Short: "Run git post-receive hook",
 		RunE: func(cmd *cobra.Command, _ []string) error {
 			_, s := fromContext(cmd)
-			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
+			hks := cmd.Context().Value(hooksCtxKey).(hooks.Hooks)
 			repoName := getRepoName(s)
 			opts := make([]hooks.HookArg, 0)
 			scanner := bufio.NewScanner(s)
@@ -79,20 +78,19 @@ func hookCommand() *cobra.Command {
 					RefName: fields[2],
 				})
 			}
-			hks.PostReceive(s, s.Stderr(), repoName, opts)
+			hks.PostReceive(s, s, s.Stderr(), repoName, opts)
 			return nil
 		},
 	}
 
 	postUpdateCmd := &cobra.Command{
-		Use:               "post-update",
-		Short:             "Run git post-update hook",
-		PersistentPreRunE: checkIfInternal,
+		Use:   "post-update",
+		Short: "Run git post-update hook",
 		RunE: func(cmd *cobra.Command, args []string) error {
 			_, s := fromContext(cmd)
-			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
+			hks := cmd.Context().Value(hooksCtxKey).(hooks.Hooks)
 			repoName := getRepoName(s)
-			hks.PostUpdate(s, s.Stderr(), repoName, args...)
+			hks.PostUpdate(s, s, s.Stderr(), repoName, args...)
 			return nil
 		},
 	}
@@ -118,13 +116,13 @@ func hookCommand() *cobra.Command {
 func checkIfInternal(cmd *cobra.Command, _ []string) error {
 	cfg, s := fromContext(cmd)
 	pk := s.PublicKey()
-	kp, err := keygen.New(cfg.SSH.InternalKeyPath, keygen.WithKeyType(keygen.Ed25519))
+	kp, err := keygen.New(cfg.Internal.InternalKeyPath, keygen.WithKeyType(keygen.Ed25519))
 	if err != nil {
-		logger.Errorf("failed to read internal key: %v", err)
+		log.WithPrefix("server.internal").Errorf("failed to read internal key: %v", err)
 		return err
 	}
 	if !backend.KeysEqual(pk, kp.PublicKey()) {
-		return ErrUnauthorized
+		return errors.ErrUnauthorized
 	}
 	return nil
 }

server/internal/internal.go 🔗

@@ -0,0 +1,86 @@
+package internal
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/charmbracelet/keygen"
+	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/hooks"
+	"github.com/charmbracelet/ssh"
+	"github.com/charmbracelet/wish"
+)
+
+// InternalServer is a internal interface to communicate with the server.
+type InternalServer struct {
+	cfg *config.Config
+	s   *ssh.Server
+	kp  *keygen.SSHKeyPair
+	ckp *keygen.SSHKeyPair
+}
+
+// NewInternalServer returns a new internal server.
+func NewInternalServer(cfg *config.Config, hooks hooks.Hooks) (*InternalServer, error) {
+	i := &InternalServer{cfg: cfg}
+
+	// Create internal key.
+	ikp, err := keygen.New(
+		cfg.Internal.InternalKeyPath,
+		keygen.WithKeyType(keygen.Ed25519),
+		keygen.WithWrite(),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("internal key: %w", err)
+	}
+
+	i.kp = ikp
+
+	// Create client key.
+	ckp, err := keygen.New(
+		cfg.Internal.ClientKeyPath,
+		keygen.WithKeyType(keygen.Ed25519),
+		keygen.WithWrite(),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("client key: %w", err)
+	}
+
+	i.ckp = ckp
+
+	s, err := wish.NewServer(
+		wish.WithAddress(cfg.Internal.ListenAddr),
+		wish.WithHostKeyPath(cfg.Internal.KeyPath),
+		wish.WithPublicKeyAuth(i.PublicKeyHandler),
+		wish.WithMiddleware(
+			i.Middleware(hooks),
+		),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("wish: %w", err)
+	}
+
+	i.s = s
+
+	return i, nil
+}
+
+// PublicKeyHandler handles public key authentication.
+func (i *InternalServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) bool {
+	return backend.KeysEqual(i.kp.PublicKey(), pk)
+}
+
+// Start starts the internal server.
+func (i *InternalServer) Start() error {
+	return i.s.ListenAndServe()
+}
+
+// Shutdown shuts down the internal server.
+func (i *InternalServer) Shutdown(ctx context.Context) error {
+	return i.s.Shutdown(ctx)
+}
+
+// Close closes the internal server.
+func (i *InternalServer) Close() error {
+	return i.s.Close()
+}

server/jobs.go 🔗

@@ -38,8 +38,7 @@ func mirrorJob(cfg *config.Config) func() {
 				cmd.AddEnvs(
 					fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
 						filepath.Join(cfg.DataPath, "ssh", "known_hosts"),
-						// FIXME: upstream keygen appends _ed25519 to the key path.
-						cfg.SSH.ClientKeyPath+"_ed25519",
+						cfg.Internal.ClientKeyPath,
 					),
 				)
 				if _, err := cmd.RunInDir(r.Path); err != nil {

server/server.go 🔗

@@ -3,9 +3,9 @@ package server
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 
-	"github.com/charmbracelet/keygen"
 	"github.com/charmbracelet/log"
 
 	"github.com/charmbracelet/soft-serve/server/backend"
@@ -13,6 +13,7 @@ import (
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/cron"
 	"github.com/charmbracelet/soft-serve/server/daemon"
+	"github.com/charmbracelet/soft-serve/server/internal"
 	sshsrv "github.com/charmbracelet/soft-serve/server/ssh"
 	"github.com/charmbracelet/soft-serve/server/stats"
 	"github.com/charmbracelet/soft-serve/server/web"
@@ -26,14 +27,15 @@ var (
 
 // Server is the Soft Serve server.
 type Server struct {
-	SSHServer   *sshsrv.SSHServer
-	GitDaemon   *daemon.GitDaemon
-	HTTPServer  *web.HTTPServer
-	StatsServer *stats.StatsServer
-	Cron        *cron.CronScheduler
-	Config      *config.Config
-	Backend     backend.Backend
-	ctx         context.Context
+	SSHServer      *sshsrv.SSHServer
+	GitDaemon      *daemon.GitDaemon
+	HTTPServer     *web.HTTPServer
+	StatsServer    *stats.StatsServer
+	InternalServer *internal.InternalServer
+	Cron           *cron.CronScheduler
+	Config         *config.Config
+	Backend        backend.Backend
+	ctx            context.Context
 }
 
 // NewServer returns a new *ssh.Server configured to serve Soft Serve. The SSH
@@ -46,32 +48,10 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
 	if cfg.Backend == nil {
 		sb, err := sqlite.NewSqliteBackend(ctx, cfg)
 		if err != nil {
-			logger.Fatal(err)
+			return nil, fmt.Errorf("create backend: %w", err)
 		}
 
 		cfg = cfg.WithBackend(sb)
-
-		// Create internal key.
-		ikp, err := keygen.New(
-			cfg.SSH.InternalKeyPath,
-			keygen.WithKeyType(keygen.Ed25519),
-			keygen.WithWrite(),
-		)
-		if err != nil {
-			return nil, err
-		}
-		cfg.InternalPublicKey = ikp.AuthorizedKey()
-
-		// Create client key.
-		ckp, err := keygen.New(
-			cfg.SSH.ClientKeyPath,
-			keygen.WithKeyType(keygen.Ed25519),
-			keygen.WithWrite(),
-		)
-		if err != nil {
-			return nil, err
-		}
-		cfg.ClientPublicKey = ckp.AuthorizedKey()
 	}
 
 	srv := &Server{
@@ -84,24 +64,29 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
 	// Add cron jobs.
 	srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg))
 
-	srv.SSHServer, err = sshsrv.NewSSHServer(cfg, srv)
+	srv.SSHServer, err = sshsrv.NewSSHServer(cfg)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create ssh server: %w", err)
 	}
 
 	srv.GitDaemon, err = daemon.NewGitDaemon(cfg)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create git daemon: %w", err)
 	}
 
 	srv.HTTPServer, err = web.NewHTTPServer(cfg)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create http server: %w", err)
 	}
 
 	srv.StatsServer, err = stats.NewStatsServer(cfg)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("create stats server: %w", err)
+	}
+
+	srv.InternalServer, err = internal.NewInternalServer(cfg, srv)
+	if err != nil {
+		return nil, fmt.Errorf("create internal server: %w", err)
 	}
 
 	return srv, nil
@@ -158,6 +143,13 @@ func (s *Server) Start() error {
 		s.Cron.Start()
 		return nil
 	})
+	errg.Go(func() error {
+		logger.Print("Starting internal server", "addr", s.Config.Internal.ListenAddr)
+		if err := start(ctx, s.InternalServer.Start); !errors.Is(err, http.ErrServerClosed) {
+			return err
+		}
+		return nil
+	})
 	return errg.Wait()
 }
 
@@ -176,6 +168,13 @@ func (s *Server) Shutdown(ctx context.Context) error {
 	errg.Go(func() error {
 		return s.StatsServer.Shutdown(ctx)
 	})
+	errg.Go(func() error {
+		s.Cron.Stop()
+		return nil
+	})
+	errg.Go(func() error {
+		return s.InternalServer.Shutdown(ctx)
+	})
 	return errg.Wait()
 }
 
@@ -186,5 +185,10 @@ func (s *Server) Close() error {
 	errg.Go(s.HTTPServer.Close)
 	errg.Go(s.SSHServer.Close)
 	errg.Go(s.StatsServer.Close)
+	errg.Go(func() error {
+		s.Cron.Stop()
+		return nil
+	})
+	errg.Go(s.InternalServer.Close)
 	return errg.Wait()
 }

server/ssh/session.go 🔗

@@ -6,8 +6,8 @@ import (
 	"github.com/aymanbagabas/go-osc52"
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/charmbracelet/soft-serve/server/backend"
-	cm "github.com/charmbracelet/soft-serve/server/cmd"
 	"github.com/charmbracelet/soft-serve/server/config"
+	"github.com/charmbracelet/soft-serve/server/errors"
 	"github.com/charmbracelet/soft-serve/ui"
 	"github.com/charmbracelet/soft-serve/ui/common"
 	"github.com/charmbracelet/ssh"
@@ -41,7 +41,7 @@ func SessionHandler(cfg *config.Config) bm.ProgramHandler {
 			initialRepo = cmd[0]
 			auth := cfg.Backend.AccessLevelByPublicKey(initialRepo, s.PublicKey())
 			if auth < backend.ReadOnlyAccess {
-				wish.Fatalln(s, cm.ErrUnauthorized)
+				wish.Fatalln(s, errors.ErrUnauthorized)
 				return nil
 			}
 		}

server/ssh/ssh.go 🔗

@@ -14,7 +14,6 @@ import (
 	cm "github.com/charmbracelet/soft-serve/server/cmd"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/git"
-	"github.com/charmbracelet/soft-serve/server/hooks"
 	"github.com/charmbracelet/soft-serve/server/utils"
 	"github.com/charmbracelet/ssh"
 	"github.com/charmbracelet/wish"
@@ -82,7 +81,7 @@ type SSHServer struct {
 }
 
 // NewSSHServer returns a new SSHServer.
-func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
+func NewSSHServer(cfg *config.Config) (*SSHServer, error) {
 	var err error
 	s := &SSHServer{cfg: cfg}
 	logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
@@ -92,7 +91,7 @@ func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
 			// BubbleTea middleware.
 			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 			// CLI middleware.
-			cm.Middleware(cfg, hooks),
+			cm.Middleware(cfg),
 			// Git middleware.
 			s.Middleware(cfg),
 			// Logging middleware.