middleware.go

  1package ssh
  2
import (
	"errors"
	"fmt"
	"strconv"
	"time"

	"charm.land/log/v2"
	"charm.land/wish/v2"
	"github.com/charmbracelet/soft-serve/pkg/backend"
	"github.com/charmbracelet/soft-serve/pkg/config"
	"github.com/charmbracelet/soft-serve/pkg/db"
	"github.com/charmbracelet/soft-serve/pkg/proto"
	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
	"github.com/charmbracelet/soft-serve/pkg/sshutils"
	"github.com/charmbracelet/soft-serve/pkg/store"
	"github.com/charmbracelet/ssh"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/spf13/cobra"
	gossh "golang.org/x/crypto/ssh"
)
 23
 24// ErrPermissionDenied is returned when a user is not allowed connect.
 25var ErrPermissionDenied = fmt.Errorf("permission denied")
 26
 27// AuthenticationMiddleware handles authentication.
 28func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 29	return func(s ssh.Session) {
 30		// XXX: The authentication key is set in the context but gossh doesn't
 31		// validate the authentication. We need to verify that the _last_ key
 32		// that was approved is the one that's being used.
 33
 34		ctx := s.Context()
 35		be := backend.FromContext(ctx)
 36
 37		var pkFp string
 38		perms := s.Permissions().Permissions
 39		pk := s.PublicKey()
 40		if pk != nil {
 41			// There is no public key stored in the context, public-key auth
 42			// was never requested, skip
 43			if perms == nil {
 44				wish.Fatalln(s, ErrPermissionDenied)
 45				return
 46			}
 47
 48			pkFp = gossh.FingerprintSHA256(pk)
 49		}
 50
 51		// Check if the key is the same as the one we have in context
 52		fp := perms.Extensions["pubkey-fp"]
 53		if fp != "" && fp != pkFp {
 54			wish.Fatalln(s, ErrPermissionDenied)
 55			return
 56		}
 57
 58		ac := be.AllowKeyless(ctx)
 59		publicKeyCounter.WithLabelValues(strconv.FormatBool(ac || pk != nil)).Inc()
 60		if !ac && pk == nil {
 61			wish.Fatalln(s, ErrPermissionDenied)
 62			return
 63		}
 64
 65		// Set the auth'd user, or anon, in the context
 66		var user proto.User
 67		if pk != nil {
 68			user, _ = be.UserByPublicKey(ctx, pk)
 69		}
 70		ctx.SetValue(proto.ContextKeyUser, user)
 71
 72		sh(s)
 73	}
 74}
 75
 76// ContextMiddleware adds the config, backend, and logger to the session context.
 77func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 78	return func(sh ssh.Handler) ssh.Handler {
 79		return func(s ssh.Session) {
 80			ctx := s.Context()
 81			ctx.SetValue(sshutils.ContextKeySession, s)
 82			ctx.SetValue(config.ContextKey, cfg)
 83			ctx.SetValue(db.ContextKey, dbx)
 84			ctx.SetValue(store.ContextKey, datastore)
 85			ctx.SetValue(backend.ContextKey, be)
 86			ctx.SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 87			sh(s)
 88		}
 89	}
 90}
 91
 92var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 93	Namespace: "soft_serve",
 94	Subsystem: "cli",
 95	Name:      "commands_total",
 96	Help:      "Total times each command was called",
 97}, []string{"command"})
 98
 99// CommandMiddleware handles git commands and CLI commands.
100// This middleware must be run after the ContextMiddleware.
101func CommandMiddleware(sh ssh.Handler) ssh.Handler {
102	return func(s ssh.Session) {
103		_, _, ptyReq := s.Pty()
104		if ptyReq {
105			sh(s)
106			return
107		}
108
109		ctx := s.Context()
110		cfg := config.FromContext(ctx)
111
112		args := s.Command()
113		cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
114		rootCmd := &cobra.Command{
115			Short:        "Soft Serve is a self-hostable Git server for the command line.",
116			SilenceUsage: true,
117		}
118		rootCmd.CompletionOptions.DisableDefaultCmd = true
119
120		rootCmd.SetUsageTemplate(cmd.UsageTemplate)
121		rootCmd.SetUsageFunc(cmd.UsageFunc)
122		rootCmd.AddCommand(
123			cmd.GitUploadPackCommand(),
124			cmd.GitUploadArchiveCommand(),
125			cmd.GitReceivePackCommand(),
126			cmd.RepoCommand(),
127			cmd.SettingsCommand(),
128			cmd.UserCommand(),
129			cmd.InfoCommand(),
130			cmd.PubkeyCommand(),
131			cmd.SetUsernameCommand(),
132			cmd.JWTCommand(),
133			cmd.TokenCommand(),
134		)
135
136		if cfg.LFS.Enabled {
137			rootCmd.AddCommand(
138				cmd.GitLFSAuthenticateCommand(),
139			)
140
141			if cfg.LFS.SSHEnabled {
142				rootCmd.AddCommand(
143					cmd.GitLFSTransfer(),
144				)
145			}
146		}
147
148		rootCmd.SetArgs(args)
149		if len(args) == 0 {
150			// otherwise it'll default to os.Args, which is not what we want.
151			rootCmd.SetArgs([]string{"--help"})
152		}
153		rootCmd.SetIn(s)
154		rootCmd.SetOut(s)
155		rootCmd.SetErr(s.Stderr())
156		rootCmd.SetContext(ctx)
157
158		if err := rootCmd.ExecuteContext(ctx); err != nil {
159			s.Exit(1) //nolint: errcheck
160			return
161		}
162	}
163}
164
165// LoggingMiddleware logs the ssh connection and command.
166func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
167	return func(s ssh.Session) {
168		ctx := s.Context()
169		logger := log.FromContext(ctx).WithPrefix("ssh")
170		ct := time.Now()
171		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
172		ptyReq, _, isPty := s.Pty()
173		addr := s.RemoteAddr().String()
174		user := proto.UserFromContext(ctx)
175		logArgs := []interface{}{
176			"addr",
177			addr,
178			"cmd",
179			s.Command(),
180		}
181
182		if user != nil {
183			logArgs = append([]interface{}{
184				"username",
185				user.Username(),
186			}, logArgs...)
187		}
188
189		if isPty {
190			logArgs = []interface{}{
191				"term", ptyReq.Term,
192				"width", ptyReq.Window.Width,
193				"height", ptyReq.Window.Height,
194			}
195		}
196
197		if config.IsVerbose() {
198			logArgs = append(logArgs,
199				"key", hpk,
200				"envs", s.Environ(),
201			)
202		}
203
204		msg := fmt.Sprintf("user %q", s.User())
205		logger.Debug(msg+" connected", logArgs...)
206		sh(s)
207		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
208	}
209}