middleware.go

  1// Package ssh provides SSH server functionality and middleware.
  2package ssh
  3
import (
	"errors"
	"fmt"
	"time"

	"github.com/charmbracelet/log/v2"
	"github.com/charmbracelet/soft-serve/pkg/backend"
	"github.com/charmbracelet/soft-serve/pkg/config"
	"github.com/charmbracelet/soft-serve/pkg/db"
	"github.com/charmbracelet/soft-serve/pkg/proto"
	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
	"github.com/charmbracelet/soft-serve/pkg/sshutils"
	"github.com/charmbracelet/soft-serve/pkg/store"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish/v2"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/spf13/cobra"
	gossh "golang.org/x/crypto/ssh"
)
 23
 24// ErrPermissionDenied is returned when a user is not allowed connect.
 25var ErrPermissionDenied = fmt.Errorf("permission denied")
 26
 27// AuthenticationMiddleware handles authentication.
 28func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 29	return func(s ssh.Session) {
 30		// XXX: The authentication key is set in the context but gossh doesn't
 31		// validate the authentication. We need to verify that the _last_ key
 32		// that was approved is the one that's being used.
 33
 34		pk := s.PublicKey()
 35		if pk != nil {
 36			// There is no public key stored in the context, public-key auth
 37			// was never requested, skip
 38			perms := s.Permissions().Permissions
 39			if perms == nil {
 40				wish.Fatalln(s, ErrPermissionDenied)
 41				return
 42			}
 43
 44			// Check if the key is the same as the one we have in context
 45			fp := perms.Extensions["pubkey-fp"]
 46			if fp == "" || fp != gossh.FingerprintSHA256(pk) {
 47				wish.Fatalln(s, ErrPermissionDenied)
 48				return
 49			}
 50		}
 51
 52		sh(s)
 53	}
 54}
 55
 56// ContextMiddleware adds the config, backend, and logger to the session context.
 57func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 58	return func(sh ssh.Handler) ssh.Handler {
 59		return func(s ssh.Session) {
 60			ctx := s.Context()
 61			ctx.SetValue(sshutils.ContextKeySession, s)
 62			ctx.SetValue(config.ContextKey, cfg)
 63			ctx.SetValue(db.ContextKey, dbx)
 64			ctx.SetValue(store.ContextKey, datastore)
 65			ctx.SetValue(backend.ContextKey, be)
 66			ctx.SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 67			sh(s)
 68		}
 69	}
 70}
 71
 72var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 73	Namespace: "soft_serve",
 74	Subsystem: "cli",
 75	Name:      "commands_total",
 76	Help:      "Total times each command was called",
 77}, []string{"command"})
 78
 79// CommandMiddleware handles git commands and CLI commands.
 80// This middleware must be run after the ContextMiddleware.
 81func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 82	return func(s ssh.Session) {
 83		_, _, ptyReq := s.Pty()
 84		if ptyReq {
 85			sh(s)
 86			return
 87		}
 88
 89		ctx := s.Context()
 90		cfg := config.FromContext(ctx)
 91
 92		args := s.Command()
 93		cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 94		rootCmd := &cobra.Command{
 95			Short:        "Soft Serve is a self-hostable Git server for the command line.",
 96			SilenceUsage: true,
 97		}
 98		rootCmd.CompletionOptions.DisableDefaultCmd = true
 99
100		rootCmd.SetUsageTemplate(cmd.UsageTemplate)
101		rootCmd.SetUsageFunc(cmd.UsageFunc)
102		rootCmd.AddCommand(
103			cmd.GitUploadPackCommand(),
104			cmd.GitUploadArchiveCommand(),
105			cmd.GitReceivePackCommand(),
106			cmd.RepoCommand(),
107			cmd.SettingsCommand(),
108			cmd.UserCommand(),
109			cmd.InfoCommand(),
110			cmd.PubkeyCommand(),
111			cmd.SetUsernameCommand(),
112			cmd.JWTCommand(),
113			cmd.TokenCommand(),
114		)
115
116		if cfg.LFS.Enabled {
117			rootCmd.AddCommand(
118				cmd.GitLFSAuthenticateCommand(),
119			)
120
121			if cfg.LFS.SSHEnabled {
122				rootCmd.AddCommand(
123					cmd.GitLFSTransfer(),
124				)
125			}
126		}
127
128		rootCmd.SetArgs(args)
129		if len(args) == 0 {
130			// otherwise it'll default to os.Args, which is not what we want.
131			rootCmd.SetArgs([]string{"--help"})
132		}
133		rootCmd.SetIn(s)
134		rootCmd.SetOut(s)
135		rootCmd.SetErr(s.Stderr())
136		rootCmd.SetContext(ctx)
137
138		if err := rootCmd.ExecuteContext(ctx); err != nil {
139			s.Exit(1) //nolint:errcheck,gosec
140			return
141		}
142	}
143}
144
145// LoggingMiddleware logs the ssh connection and command.
146func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
147	return func(s ssh.Session) {
148		ctx := s.Context()
149		logger := log.FromContext(ctx).WithPrefix("ssh")
150		ct := time.Now()
151		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
152		ptyReq, _, isPty := s.Pty()
153		addr := s.RemoteAddr().String()
154		user := proto.UserFromContext(ctx)
155		logArgs := []interface{}{
156			"addr",
157			addr,
158			"cmd",
159			s.Command(),
160		}
161
162		if user != nil {
163			logArgs = append([]interface{}{
164				"username",
165				user.Username(),
166			}, logArgs...)
167		}
168
169		if isPty {
170			logArgs = []interface{}{
171				"term", ptyReq.Term,
172				"width", ptyReq.Window.Width,
173				"height", ptyReq.Window.Height,
174			}
175		}
176
177		if config.IsVerbose() {
178			logArgs = append(logArgs,
179				"key", hpk,
180				"envs", s.Environ(),
181			)
182		}
183
184		msg := fmt.Sprintf("user %q", s.User())
185		logger.Debug(msg+" connected", logArgs...)
186		sh(s)
187		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
188	}
189}