middleware.go

  1package ssh
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	log "github.com/charmbracelet/log/v2"
  8	"github.com/charmbracelet/soft-serve/pkg/backend"
  9	"github.com/charmbracelet/soft-serve/pkg/config"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11	"github.com/charmbracelet/soft-serve/pkg/proto"
 12	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
 13	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 14	"github.com/charmbracelet/soft-serve/pkg/store"
 15	"github.com/charmbracelet/ssh"
 16	wish "github.com/charmbracelet/wish/v2"
 17	"github.com/prometheus/client_golang/prometheus"
 18	"github.com/prometheus/client_golang/prometheus/promauto"
 19	"github.com/spf13/cobra"
 20	gossh "golang.org/x/crypto/ssh"
 21)
 22
 23// ErrPermissionDenied is returned when a user is not allowed connect.
 24var ErrPermissionDenied = fmt.Errorf("permission denied")
 25
 26// AuthenticationMiddleware handles authentication.
 27func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 28	return func(s ssh.Session) {
 29		// XXX: The authentication key is set in the context but gossh doesn't
 30		// validate the authentication. We need to verify that the _last_ key
 31		// that was approved is the one that's being used.
 32
 33		pk := s.PublicKey()
 34		if pk != nil {
 35			// There is no public key stored in the context, public-key auth
 36			// was never requested, skip
 37			perms := s.Permissions().Permissions
 38			if perms == nil {
 39				wish.Fatalln(s, ErrPermissionDenied)
 40				return
 41			}
 42
 43			// Check if the key is the same as the one we have in context
 44			fp := perms.Extensions["pubkey-fp"]
 45			if fp == "" || fp != gossh.FingerprintSHA256(pk) {
 46				wish.Fatalln(s, ErrPermissionDenied)
 47				return
 48			}
 49		}
 50
 51		sh(s)
 52	}
 53}
 54
 55// ContextMiddleware adds the config, backend, and logger to the session context.
 56func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 57	return func(sh ssh.Handler) ssh.Handler {
 58		return func(s ssh.Session) {
 59			ctx := s.Context()
 60			ctx.SetValue(sshutils.ContextKeySession, s)
 61			ctx.SetValue(config.ContextKey, cfg)
 62			ctx.SetValue(db.ContextKey, dbx)
 63			ctx.SetValue(store.ContextKey, datastore)
 64			ctx.SetValue(backend.ContextKey, be)
 65			ctx.SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 66			sh(s)
 67		}
 68	}
 69}
 70
 71var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 72	Namespace: "soft_serve",
 73	Subsystem: "cli",
 74	Name:      "commands_total",
 75	Help:      "Total times each command was called",
 76}, []string{"command"})
 77
 78// CommandMiddleware handles git commands and CLI commands.
 79// This middleware must be run after the ContextMiddleware.
 80func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 81	return func(s ssh.Session) {
 82		_, _, ptyReq := s.Pty()
 83		if ptyReq {
 84			sh(s)
 85			return
 86		}
 87
 88		ctx := s.Context()
 89		cfg := config.FromContext(ctx)
 90
 91		args := s.Command()
 92		cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 93		rootCmd := &cobra.Command{
 94			Short:        "Soft Serve is a self-hostable Git server for the command line.",
 95			SilenceUsage: true,
 96		}
 97		rootCmd.CompletionOptions.DisableDefaultCmd = true
 98
 99		rootCmd.SetUsageTemplate(cmd.UsageTemplate)
100		rootCmd.SetUsageFunc(cmd.UsageFunc)
101		rootCmd.AddCommand(
102			cmd.GitUploadPackCommand(),
103			cmd.GitUploadArchiveCommand(),
104			cmd.GitReceivePackCommand(),
105			cmd.RepoCommand(),
106			cmd.SettingsCommand(),
107			cmd.UserCommand(),
108			cmd.InfoCommand(),
109			cmd.PubkeyCommand(),
110			cmd.SetUsernameCommand(),
111			cmd.JWTCommand(),
112			cmd.TokenCommand(),
113		)
114
115		if cfg.LFS.Enabled {
116			rootCmd.AddCommand(
117				cmd.GitLFSAuthenticateCommand(),
118			)
119
120			if cfg.LFS.SSHEnabled {
121				rootCmd.AddCommand(
122					cmd.GitLFSTransfer(),
123				)
124			}
125		}
126
127		rootCmd.SetArgs(args)
128		if len(args) == 0 {
129			// otherwise it'll default to os.Args, which is not what we want.
130			rootCmd.SetArgs([]string{"--help"})
131		}
132		rootCmd.SetIn(s)
133		rootCmd.SetOut(s)
134		rootCmd.SetErr(s.Stderr())
135		rootCmd.SetContext(ctx)
136
137		if err := rootCmd.ExecuteContext(ctx); err != nil {
138			s.Exit(1) // nolint: errcheck
139			return
140		}
141	}
142}
143
144// LoggingMiddleware logs the ssh connection and command.
145func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
146	return func(s ssh.Session) {
147		ctx := s.Context()
148		logger := log.FromContext(ctx).WithPrefix("ssh")
149		ct := time.Now()
150		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
151		ptyReq, _, isPty := s.Pty()
152		addr := s.RemoteAddr().String()
153		user := proto.UserFromContext(ctx)
154		logArgs := []interface{}{
155			"addr",
156			addr,
157			"cmd",
158			s.Command(),
159		}
160
161		if user != nil {
162			logArgs = append([]interface{}{
163				"username",
164				user.Username(),
165			}, logArgs...)
166		}
167
168		if isPty {
169			logArgs = []interface{}{
170				"term", ptyReq.Term,
171				"width", ptyReq.Window.Width,
172				"height", ptyReq.Window.Height,
173			}
174		}
175
176		if config.IsVerbose() {
177			logArgs = append(logArgs,
178				"key", hpk,
179				"envs", s.Environ(),
180			)
181		}
182
183		msg := fmt.Sprintf("user %q", s.User())
184		logger.Debug(msg+" connected", logArgs...)
185		sh(s)
186		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
187	}
188}