middleware.go

  1package ssh
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/log"
  8	"github.com/charmbracelet/soft-serve/pkg/backend"
  9	"github.com/charmbracelet/soft-serve/pkg/config"
 10	"github.com/charmbracelet/soft-serve/pkg/db"
 11	"github.com/charmbracelet/soft-serve/pkg/proto"
 12	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
 13	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 14	"github.com/charmbracelet/soft-serve/pkg/store"
 15	"github.com/charmbracelet/ssh"
 16	"github.com/charmbracelet/wish"
 17	"github.com/prometheus/client_golang/prometheus"
 18	"github.com/prometheus/client_golang/prometheus/promauto"
 19	"github.com/spf13/cobra"
 20	gossh "golang.org/x/crypto/ssh"
 21)
 22
// ErrPermissionDenied is returned when a user is not allowed to connect.
var ErrPermissionDenied = fmt.Errorf("permission denied")
 25
 26// AuthenticationMiddleware handles authentication.
 27func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 28	return func(s ssh.Session) {
 29		// XXX: The authentication key is set in the context but gossh doesn't
 30		// validate the authentication. We need to verify that the _last_ key
 31		// that was approved is the one that's being used.
 32
 33		pk := s.PublicKey()
 34		if pk != nil {
 35			// There is no public key stored in the context, public-key auth
 36			// was never requested, skip
 37			perms := s.Permissions().Permissions
 38			if perms == nil {
 39				wish.Fatalln(s, ErrPermissionDenied)
 40				return
 41			}
 42
 43			// Check if the key is the same as the one we have in context
 44			fp := perms.Extensions["pubkey-fp"]
 45			if fp != gossh.FingerprintSHA256(pk) {
 46				wish.Fatalln(s, ErrPermissionDenied)
 47				return
 48			}
 49		}
 50
 51		sh(s)
 52	}
 53}
 54
 55// ContextMiddleware adds the config, backend, and logger to the session context.
 56func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 57	return func(sh ssh.Handler) ssh.Handler {
 58		return func(s ssh.Session) {
 59			s.Context().SetValue(sshutils.ContextKeySession, s)
 60			s.Context().SetValue(config.ContextKey, cfg)
 61			s.Context().SetValue(db.ContextKey, dbx)
 62			s.Context().SetValue(store.ContextKey, datastore)
 63			s.Context().SetValue(backend.ContextKey, be)
 64			s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 65			sh(s)
 66		}
 67	}
 68}
 69
// cliCommandCounter is a Prometheus counter vector, registered through
// promauto, that tracks how many times each command was called. It carries a
// single "command" label and is incremented by CommandMiddleware.
var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
	Namespace: "soft_serve",
	Subsystem: "cli",
	Name:      "commands_total",
	Help:      "Total times each command was called",
}, []string{"command"})
 76
 77// CommandMiddleware handles git commands and CLI commands.
 78// This middleware must be run after the ContextMiddleware.
 79func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 80	return func(s ssh.Session) {
 81		_, _, ptyReq := s.Pty()
 82		if ptyReq {
 83			sh(s)
 84			return
 85		}
 86
 87		ctx := s.Context()
 88		cfg := config.FromContext(ctx)
 89
 90		args := s.Command()
 91		cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 92		rootCmd := &cobra.Command{
 93			Short:        "Soft Serve is a self-hostable Git server for the command line.",
 94			SilenceUsage: true,
 95		}
 96		rootCmd.CompletionOptions.DisableDefaultCmd = true
 97
 98		rootCmd.SetUsageTemplate(cmd.UsageTemplate)
 99		rootCmd.SetUsageFunc(cmd.UsageFunc)
100		rootCmd.AddCommand(
101			cmd.GitUploadPackCommand(),
102			cmd.GitUploadArchiveCommand(),
103			cmd.GitReceivePackCommand(),
104			cmd.RepoCommand(),
105			cmd.SettingsCommand(),
106			cmd.UserCommand(),
107			cmd.InfoCommand(),
108			cmd.PubkeyCommand(),
109			cmd.SetUsernameCommand(),
110			cmd.JWTCommand(),
111			cmd.TokenCommand(),
112		)
113
114		if cfg.LFS.Enabled {
115			rootCmd.AddCommand(
116				cmd.GitLFSAuthenticateCommand(),
117			)
118
119			if cfg.LFS.SSHEnabled {
120				rootCmd.AddCommand(
121					cmd.GitLFSTransfer(),
122				)
123			}
124		}
125
126		rootCmd.SetArgs(args)
127		if len(args) == 0 {
128			// otherwise it'll default to os.Args, which is not what we want.
129			rootCmd.SetArgs([]string{"--help"})
130		}
131		rootCmd.SetIn(s)
132		rootCmd.SetOut(s)
133		rootCmd.SetErr(s.Stderr())
134		rootCmd.SetContext(ctx)
135
136		if err := rootCmd.ExecuteContext(ctx); err != nil {
137			s.Exit(1) // nolint: errcheck
138			return
139		}
140	}
141}
142
143// LoggingMiddleware logs the ssh connection and command.
144func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
145	return func(s ssh.Session) {
146		ctx := s.Context()
147		logger := log.FromContext(ctx).WithPrefix("ssh")
148		ct := time.Now()
149		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
150		ptyReq, _, isPty := s.Pty()
151		addr := s.RemoteAddr().String()
152		user := proto.UserFromContext(ctx)
153		logArgs := []interface{}{
154			"addr",
155			addr,
156			"cmd",
157			s.Command(),
158		}
159
160		if user != nil {
161			logArgs = append([]interface{}{
162				"username",
163				user.Username(),
164			}, logArgs...)
165		}
166
167		if isPty {
168			logArgs = []interface{}{
169				"term", ptyReq.Term,
170				"width", ptyReq.Window.Width,
171				"height", ptyReq.Window.Height,
172			}
173		}
174
175		if config.IsVerbose() {
176			logArgs = append(logArgs,
177				"key", hpk,
178				"envs", s.Environ(),
179			)
180		}
181
182		msg := fmt.Sprintf("user %q", s.User())
183		logger.Debug(msg+" connected", logArgs...)
184		sh(s)
185		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
186	}
187}