middleware.go

  1package ssh
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/log"
  8	"github.com/charmbracelet/soft-serve/pkg/backend"
  9	"github.com/charmbracelet/soft-serve/pkg/config"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11	"github.com/charmbracelet/soft-serve/pkg/proto"
 12	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
 13	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 14	"github.com/charmbracelet/soft-serve/pkg/store"
 15	"github.com/charmbracelet/ssh"
 16	"github.com/charmbracelet/wish"
 17	"github.com/prometheus/client_golang/prometheus"
 18	"github.com/prometheus/client_golang/prometheus/promauto"
 19	"github.com/spf13/cobra"
 20	gossh "golang.org/x/crypto/ssh"
 21)
 22
 23// ErrPermissionDenied is returned when a user is not allowed connect.
 24var ErrPermissionDenied = fmt.Errorf("permission denied")
 25
 26// AuthenticationMiddleware handles authentication.
 27func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 28	return func(s ssh.Session) {
 29		// XXX: The authentication key is set in the context but gossh doesn't
 30		// validate the authentication. We need to verify that the _last_ key
 31		// that was approved is the one that's being used.
 32
 33		pk := s.PublicKey()
 34		if pk != nil {
 35			// There is no public key stored in the context, public-key auth
 36			// was never requested, skip
 37			perms := s.Permissions().Permissions
 38			if perms == nil {
 39				wish.Fatalln(s, ErrPermissionDenied)
 40				return
 41			}
 42
 43			// Check if the key is the same as the one we have in context
 44			fp := perms.Extensions["pubkey-fp"]
 45			if fp != gossh.FingerprintSHA256(pk) {
 46				wish.Fatalln(s, ErrPermissionDenied)
 47				return
 48			}
 49		}
 50
 51		sh(s)
 52	}
 53}
 54
 55// ContextMiddleware adds the config, backend, and logger to the session context.
 56func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 57	return func(sh ssh.Handler) ssh.Handler {
 58		return func(s ssh.Session) {
 59			s.Context().SetValue(sshutils.ContextKeySession, s)
 60			s.Context().SetValue(config.ContextKey, cfg)
 61			s.Context().SetValue(db.ContextKey, dbx)
 62			s.Context().SetValue(store.ContextKey, datastore)
 63			s.Context().SetValue(backend.ContextKey, be)
 64			s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 65			sh(s)
 66		}
 67	}
 68}
 69
 70var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 71	Namespace: "soft_serve",
 72	Subsystem: "cli",
 73	Name:      "commands_total",
 74	Help:      "Total times each command was called",
 75}, []string{"command"})
 76
 77// CommandMiddleware handles git commands and CLI commands.
 78// This middleware must be run after the ContextMiddleware.
 79func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 80	return func(s ssh.Session) {
 81		func() {
 82			_, _, ptyReq := s.Pty()
 83			if ptyReq {
 84				return
 85			}
 86
 87			ctx := s.Context()
 88			cfg := config.FromContext(ctx)
 89
 90			args := s.Command()
 91			cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 92			rootCmd := &cobra.Command{
 93				Short:        "Soft Serve is a self-hostable Git server for the command line.",
 94				SilenceUsage: true,
 95			}
 96			rootCmd.CompletionOptions.DisableDefaultCmd = true
 97
 98			rootCmd.SetUsageTemplate(cmd.UsageTemplate)
 99			rootCmd.SetUsageFunc(cmd.UsageFunc)
100			rootCmd.AddCommand(
101				cmd.GitUploadPackCommand(),
102				cmd.GitUploadArchiveCommand(),
103				cmd.GitReceivePackCommand(),
104				cmd.RepoCommand(),
105				cmd.SettingsCommand(),
106				cmd.UserCommand(),
107				cmd.OrgCommand(),
108				cmd.InfoCommand(),
109				cmd.PubkeyCommand(),
110				cmd.SetUsernameCommand(),
111				cmd.JWTCommand(),
112				cmd.TokenCommand(),
113			)
114
115			if cfg.LFS.Enabled {
116				rootCmd.AddCommand(
117					cmd.GitLFSAuthenticateCommand(),
118				)
119
120				if cfg.LFS.SSHEnabled {
121					rootCmd.AddCommand(
122						cmd.GitLFSTransfer(),
123					)
124				}
125			}
126
127			rootCmd.SetArgs(args)
128			if len(args) == 0 {
129				// otherwise it'll default to os.Args, which is not what we want.
130				rootCmd.SetArgs([]string{"--help"})
131			}
132			rootCmd.SetIn(s)
133			rootCmd.SetOut(s)
134			rootCmd.SetErr(s.Stderr())
135			rootCmd.SetContext(ctx)
136
137			if err := rootCmd.ExecuteContext(ctx); err != nil {
138				s.Exit(1) // nolint: errcheck
139				return
140			}
141		}()
142		sh(s)
143	}
144}
145
146// LoggingMiddleware logs the ssh connection and command.
147func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
148	return func(s ssh.Session) {
149		ctx := s.Context()
150		logger := log.FromContext(ctx).WithPrefix("ssh")
151		ct := time.Now()
152		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
153		ptyReq, _, isPty := s.Pty()
154		addr := s.RemoteAddr().String()
155		user := proto.UserFromContext(ctx)
156		logArgs := []interface{}{
157			"addr",
158			addr,
159			"cmd",
160			s.Command(),
161		}
162
163		if user != nil {
164			logArgs = append([]interface{}{
165				"username",
166				user.Username(),
167			}, logArgs...)
168		}
169
170		if isPty {
171			logArgs = []interface{}{
172				"term", ptyReq.Term,
173				"width", ptyReq.Window.Width,
174				"height", ptyReq.Window.Height,
175			}
176		}
177
178		if config.IsVerbose() {
179			logArgs = append(logArgs,
180				"key", hpk,
181				"envs", s.Environ(),
182			)
183		}
184
185		msg := fmt.Sprintf("user %q", s.User())
186		logger.Debug(msg+" connected", logArgs...)
187		sh(s)
188		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
189	}
190}