middleware.go

  1package ssh
  2
import (
	"errors"
	"fmt"
	"time"

	"github.com/charmbracelet/log"
	"github.com/charmbracelet/soft-serve/pkg/backend"
	"github.com/charmbracelet/soft-serve/pkg/config"
	"github.com/charmbracelet/soft-serve/pkg/db"
	"github.com/charmbracelet/soft-serve/pkg/proto"
	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
	"github.com/charmbracelet/soft-serve/pkg/sshutils"
	"github.com/charmbracelet/soft-serve/pkg/store"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/spf13/cobra"
	gossh "golang.org/x/crypto/ssh"
)
 22
 23// ErrPermissionDenied is returned when a user is not allowed connect.
 24var ErrPermissionDenied = fmt.Errorf("permission denied")
 25
 26// AuthenticationMiddleware handles authentication.
 27func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 28	return func(s ssh.Session) {
 29		// XXX: The authentication key is set in the context but gossh doesn't
 30		// validate the authentication. We need to verify that the _last_ key
 31		// that was approved is the one that's being used.
 32
 33		pk := s.PublicKey()
 34		if pk != nil {
 35			// There is no public key stored in the context, public-key auth
 36			// was never requested, skip
 37			perms := s.Permissions().Permissions
 38			if perms == nil {
 39				wish.Fatalln(s, ErrPermissionDenied)
 40				return
 41			}
 42
 43			// Check if the key is the same as the one we have in context
 44			fp := perms.Extensions["pubkey-fp"]
 45			if fp != gossh.FingerprintSHA256(pk) {
 46				wish.Fatalln(s, ErrPermissionDenied)
 47				return
 48			}
 49		}
 50
 51		sh(s)
 52	}
 53}
 54
 55// ContextMiddleware adds the config, backend, and logger to the session context.
 56func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 57	return func(sh ssh.Handler) ssh.Handler {
 58		return func(s ssh.Session) {
 59			s.Context().SetValue(sshutils.ContextKeySession, s)
 60			s.Context().SetValue(config.ContextKey, cfg)
 61			s.Context().SetValue(db.ContextKey, dbx)
 62			s.Context().SetValue(store.ContextKey, datastore)
 63			s.Context().SetValue(backend.ContextKey, be)
 64			s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 65			sh(s)
 66		}
 67	}
 68}
 69
 70var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 71	Namespace: "soft_serve",
 72	Subsystem: "cli",
 73	Name:      "commands_total",
 74	Help:      "Total times each command was called",
 75}, []string{"command"})
 76
 77// CommandMiddleware handles git commands and CLI commands.
 78// This middleware must be run after the ContextMiddleware.
 79func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 80	return func(s ssh.Session) {
 81		func() {
 82			_, _, ptyReq := s.Pty()
 83			if ptyReq {
 84				return
 85			}
 86
 87			ctx := s.Context()
 88			cfg := config.FromContext(ctx)
 89
 90			args := s.Command()
 91			cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 92			rootCmd := &cobra.Command{
 93				Short:        "Soft Serve is a self-hostable Git server for the command line.",
 94				SilenceUsage: true,
 95			}
 96			rootCmd.CompletionOptions.DisableDefaultCmd = true
 97
 98			rootCmd.SetUsageTemplate(cmd.UsageTemplate)
 99			rootCmd.SetUsageFunc(cmd.UsageFunc)
100			rootCmd.AddCommand(
101				cmd.GitUploadPackCommand(),
102				cmd.GitUploadArchiveCommand(),
103				cmd.GitReceivePackCommand(),
104				cmd.RepoCommand(),
105				cmd.SettingsCommand(),
106				cmd.UserCommand(),
107				cmd.InfoCommand(),
108				cmd.PubkeyCommand(),
109				cmd.SetUsernameCommand(),
110				cmd.JWTCommand(),
111				cmd.TokenCommand(),
112			)
113
114			if cfg.LFS.Enabled {
115				rootCmd.AddCommand(
116					cmd.GitLFSAuthenticateCommand(),
117				)
118
119				if cfg.LFS.SSHEnabled {
120					rootCmd.AddCommand(
121						cmd.GitLFSTransfer(),
122					)
123				}
124			}
125
126			rootCmd.SetArgs(args)
127			if len(args) == 0 {
128				// otherwise it'll default to os.Args, which is not what we want.
129				rootCmd.SetArgs([]string{"--help"})
130			}
131			rootCmd.SetIn(s)
132			rootCmd.SetOut(s)
133			rootCmd.SetErr(s.Stderr())
134			rootCmd.SetContext(ctx)
135
136			if err := rootCmd.ExecuteContext(ctx); err != nil {
137				s.Exit(1) // nolint: errcheck
138				return
139			}
140		}()
141		sh(s)
142	}
143}
144
145// LoggingMiddleware logs the ssh connection and command.
146func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
147	return func(s ssh.Session) {
148		ctx := s.Context()
149		logger := log.FromContext(ctx).WithPrefix("ssh")
150		ct := time.Now()
151		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
152		ptyReq, _, isPty := s.Pty()
153		addr := s.RemoteAddr().String()
154		user := proto.UserFromContext(ctx)
155		logArgs := []interface{}{
156			"addr",
157			addr,
158			"cmd",
159			s.Command(),
160		}
161
162		if user != nil {
163			logArgs = append([]interface{}{
164				"username",
165				user.Username(),
166			}, logArgs...)
167		}
168
169		if isPty {
170			logArgs = append([]interface{}{
171				"term", ptyReq.Term,
172				"width", ptyReq.Window.Width,
173				"height", ptyReq.Window.Height,
174			}, logArgs...)
175		}
176
177		if config.IsVerbose() {
178			logArgs = append(logArgs,
179				"key", hpk,
180				"envs", s.Environ(),
181			)
182		}
183
184		msg := fmt.Sprintf("user %q", s.User())
185		logger.Debug(msg+" connected", logArgs...)
186		sh(s)
187		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
188	}
189}