middleware.go

  1package ssh
  2
import (
	"errors"
	"fmt"
	"time"

	"github.com/charmbracelet/log"
	"github.com/charmbracelet/soft-serve/pkg/backend"
	"github.com/charmbracelet/soft-serve/pkg/config"
	"github.com/charmbracelet/soft-serve/pkg/db"
	"github.com/charmbracelet/soft-serve/pkg/proto"
	"github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
	"github.com/charmbracelet/soft-serve/pkg/sshutils"
	"github.com/charmbracelet/soft-serve/pkg/store"
	"github.com/charmbracelet/ssh"
	"github.com/charmbracelet/wish"
	bm "github.com/charmbracelet/wish/bubbletea"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/spf13/cobra"
	gossh "golang.org/x/crypto/ssh"
)
 23
 24// ErrPermissionDenied is returned when a user is not allowed connect.
 25var ErrPermissionDenied = fmt.Errorf("permission denied")
 26
 27// AuthenticationMiddleware handles authentication.
 28func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
 29	return func(s ssh.Session) {
 30		// XXX: The authentication key is set in the context but gossh doesn't
 31		// validate the authentication. We need to verify that the _last_ key
 32		// that was approved is the one that's being used.
 33
 34		pk := s.PublicKey()
 35		if pk != nil {
 36			// There is no public key stored in the context, public-key auth
 37			// was never requested, skip
 38			perms := s.Permissions().Permissions
 39			if perms == nil {
 40				wish.Fatalln(s, ErrPermissionDenied)
 41				return
 42			}
 43
 44			// Check if the key is the same as the one we have in context
 45			fp := perms.Extensions["pubkey-fp"]
 46			if fp != gossh.FingerprintSHA256(pk) {
 47				wish.Fatalln(s, ErrPermissionDenied)
 48				return
 49			}
 50		}
 51
 52		sh(s)
 53	}
 54}
 55
 56// ContextMiddleware adds the config, backend, and logger to the session context.
 57func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 58	return func(sh ssh.Handler) ssh.Handler {
 59		return func(s ssh.Session) {
 60			s.Context().SetValue(sshutils.ContextKeySession, s)
 61			s.Context().SetValue(config.ContextKey, cfg)
 62			s.Context().SetValue(db.ContextKey, dbx)
 63			s.Context().SetValue(store.ContextKey, datastore)
 64			s.Context().SetValue(backend.ContextKey, be)
 65			s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 66			sh(s)
 67		}
 68	}
 69}
 70
 71var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 72	Namespace: "soft_serve",
 73	Subsystem: "cli",
 74	Name:      "commands_total",
 75	Help:      "Total times each command was called",
 76}, []string{"command"})
 77
 78// CommandMiddleware handles git commands and CLI commands.
 79// This middleware must be run after the ContextMiddleware.
 80func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 81	return func(s ssh.Session) {
 82		func() {
 83			_, _, ptyReq := s.Pty()
 84			if ptyReq {
 85				return
 86			}
 87
 88			r := bm.MakeRenderer(s)
 89
 90			ctx := s.Context()
 91			cfg := config.FromContext(ctx)
 92
 93			args := s.Command()
 94			cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 95			rootCmd := &cobra.Command{
 96				Short:        "Soft Serve is a self-hostable Git server for the command line.",
 97				SilenceUsage: true,
 98			}
 99			rootCmd.CompletionOptions.DisableDefaultCmd = true
100
101			rootCmd.SetUsageTemplate(cmd.UsageTemplate)
102			rootCmd.SetUsageFunc(cmd.UsageFunc)
103			rootCmd.AddCommand(
104				cmd.GitUploadPackCommand(),
105				cmd.GitUploadArchiveCommand(),
106				cmd.GitReceivePackCommand(),
107				cmd.RepoCommand(r),
108				cmd.SettingsCommand(),
109				cmd.UserCommand(),
110				cmd.InfoCommand(),
111				cmd.PubkeyCommand(),
112				cmd.SetUsernameCommand(),
113				cmd.JWTCommand(),
114				cmd.TokenCommand(),
115			)
116
117			if cfg.LFS.Enabled {
118				rootCmd.AddCommand(
119					cmd.GitLFSAuthenticateCommand(),
120				)
121
122				if cfg.LFS.SSHEnabled {
123					rootCmd.AddCommand(
124						cmd.GitLFSTransfer(),
125					)
126				}
127			}
128
129			rootCmd.SetArgs(args)
130			if len(args) == 0 {
131				// otherwise it'll default to os.Args, which is not what we want.
132				rootCmd.SetArgs([]string{"--help"})
133			}
134			rootCmd.SetIn(s)
135			rootCmd.SetOut(s)
136			rootCmd.SetErr(s.Stderr())
137			rootCmd.SetContext(ctx)
138
139			if err := rootCmd.ExecuteContext(ctx); err != nil {
140				s.Exit(1) // nolint: errcheck
141				return
142			}
143		}()
144		sh(s)
145	}
146}
147
148// LoggingMiddleware logs the ssh connection and command.
149func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
150	return func(s ssh.Session) {
151		ctx := s.Context()
152		logger := log.FromContext(ctx).WithPrefix("ssh")
153		ct := time.Now()
154		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
155		ptyReq, _, isPty := s.Pty()
156		addr := s.RemoteAddr().String()
157		user := proto.UserFromContext(ctx)
158		logArgs := []interface{}{
159			"addr",
160			addr,
161			"cmd",
162			s.Command(),
163		}
164
165		if user != nil {
166			logArgs = append([]interface{}{
167				"username",
168				user.Username(),
169			}, logArgs...)
170		}
171
172		if isPty {
173			logArgs = []interface{}{
174				"term", ptyReq.Term,
175				"width", ptyReq.Window.Width,
176				"height", ptyReq.Window.Height,
177			}
178		}
179
180		if config.IsVerbose() {
181			logArgs = append(logArgs,
182				"key", hpk,
183				"envs", s.Environ(),
184			)
185		}
186
187		msg := fmt.Sprintf("user %q", s.User())
188		logger.Debug(msg+" connected", logArgs...)
189		sh(s)
190		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
191	}
192}