fix(ssh): add authentication middleware

Ayman Bagabas created

We need to verify that the key used to establish the connection is the
same key used for authentication, otherwise, refuse connection.

Change summary

server/ssh/middleware.go | 34 +++++++++++++++++++++++++++++
server/ssh/ssh.go        | 48 ++++++++++++++++++++++++++++++++++++++++++
2 files changed, 82 insertions(+)

Detailed changes

server/ssh/middleware.go 🔗

@@ -13,11 +13,45 @@ import (
 	"github.com/charmbracelet/soft-serve/server/sshutils"
 	"github.com/charmbracelet/soft-serve/server/store"
 	"github.com/charmbracelet/ssh"
+	"github.com/charmbracelet/wish"
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
 	"github.com/spf13/cobra"
+	gossh "golang.org/x/crypto/ssh"
 )
 
+// ErrPermissionDenied is returned when a user is not allowed connect.
+var ErrPermissionDenied = fmt.Errorf("permission denied")
+
+// AuthenticationMiddleware handles authentication.
+func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
+	return func(s ssh.Session) {
+		// XXX: The authentication key is set in the context but gossh doesn't
+		// validate the authentication. We need to verify that the _last_ key
+		// that was approved is the one that's being used.
+
+		pk := s.PublicKey()
+		if pk != nil {
+			// There is no public key stored in the context, public-key auth
+			// was never requested, skip
+			perms := s.Permissions().Permissions
+			if perms == nil {
+				wish.Fatalln(s, ErrPermissionDenied)
+				return
+			}
+
+			// Check if the key is the same as the one we have in context
+			fp := perms.Extensions["pubkey-fp"]
+			if fp != gossh.FingerprintSHA256(pk) {
+				wish.Fatalln(s, ErrPermissionDenied)
+				return
+			}
+		}
+
+		sh(s)
+	}
+}
+
 // ContextMiddleware adds the config, backend, and logger to the session context.
 func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 	return func(sh ssh.Handler) ssh.Handler {

server/ssh/ssh.go 🔗

@@ -77,6 +77,11 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 			LoggingMiddleware,
 			// Context middleware.
 			ContextMiddleware(cfg, dbx, datastore, be, logger),
+			// Authentication middleware.
+			// gossh.PublicKeyHandler doesn't guarantee that the public key
+			// is in fact the one used for authentication, so we need to
+			// check it again here.
+			AuthenticationMiddleware,
 		),
 	}
 
@@ -91,6 +96,16 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 		return nil, err
 	}
 
+	if config.IsDebug() {
+		s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
+			return &gossh.ServerConfig{
+				AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
+					logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
+				},
+			}
+		}
+	}
+
 	if cfg.SSH.MaxTimeout > 0 {
 		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
 	}
@@ -130,6 +145,19 @@ func (s *SSHServer) Shutdown(ctx context.Context) error {
 	return s.srv.Shutdown(ctx)
 }
 
+func initializePermissions(ctx ssh.Context) {
+	perms := ctx.Permissions()
+	if perms == nil || perms.Permissions == nil {
+		perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
+	}
+	if perms.Extensions == nil {
+		perms.Extensions = make(map[string]string)
+	}
+	if perms.Permissions.Extensions == nil {
+		perms.Permissions.Extensions = make(map[string]string)
+	}
+}
+
 // PublicKeyAuthHandler handles public key authentication.
 func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
 	if pk == nil {
@@ -144,6 +172,15 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed
 	if user != nil {
 		ctx.SetValue(proto.ContextKeyUser, user)
 		allowed = true
+
+		// XXX: store the first "approved" public-key fingerprint in the
+		// permissions block to use for authentication later.
+		initializePermissions(ctx)
+		perms := ctx.Permissions()
+
+		// Set the public key fingerprint to be used for authentication.
+		perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
+		ctx.SetValue(ssh.ContextKeyPermissions, perms)
 	}
 
 	return
@@ -154,5 +191,16 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed
 func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
 	ac := s.be.AllowKeyless(ctx)
 	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
+
+	// If we're allowing keyless access, reset the public key fingerprint
+	if ac {
+		initializePermissions(ctx)
+		perms := ctx.Permissions()
+
+		// XXX: reset the public-key fingerprint. This is used to validate the
+		// public key being used to authenticate.
+		perms.Extensions["pubkey-fp"] = ""
+		ctx.SetValue(ssh.ContextKeyPermissions, perms)
+	}
 	return ac
 }