1package ssh
  2
  3import (
  4	"fmt"
  5	"time"
  6
  7	"github.com/charmbracelet/log"
  8	"github.com/charmbracelet/soft-serve/server/backend"
  9	"github.com/charmbracelet/soft-serve/server/config"
 10	"github.com/charmbracelet/soft-serve/server/db"
 11	"github.com/charmbracelet/soft-serve/server/proto"
 12	"github.com/charmbracelet/soft-serve/server/ssh/cmd"
 13	"github.com/charmbracelet/soft-serve/server/sshutils"
 14	"github.com/charmbracelet/soft-serve/server/store"
 15	"github.com/charmbracelet/ssh"
 16	"github.com/prometheus/client_golang/prometheus"
 17	"github.com/prometheus/client_golang/prometheus/promauto"
 18	"github.com/spf13/cobra"
 19)
 20
 21// ContextMiddleware adds the config, backend, and logger to the session context.
 22func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
 23	return func(sh ssh.Handler) ssh.Handler {
 24		return func(s ssh.Session) {
 25			s.Context().SetValue(sshutils.ContextKeySession, s)
 26			s.Context().SetValue(config.ContextKey, cfg)
 27			s.Context().SetValue(db.ContextKey, dbx)
 28			s.Context().SetValue(store.ContextKey, datastore)
 29			s.Context().SetValue(backend.ContextKey, be)
 30			s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
 31			sh(s)
 32		}
 33	}
 34}
 35
 36var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 37	Namespace: "soft_serve",
 38	Subsystem: "cli",
 39	Name:      "commands_total",
 40	Help:      "Total times each command was called",
 41}, []string{"command"})
 42
 43// CommandMiddleware handles git commands and CLI commands.
 44// This middleware must be run after the ContextMiddleware.
 45func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 46	return func(s ssh.Session) {
 47		func() {
 48			_, _, ptyReq := s.Pty()
 49			if ptyReq {
 50				return
 51			}
 52
 53			ctx := s.Context()
 54			cfg := config.FromContext(ctx)
 55
 56			args := s.Command()
 57			cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
 58			rootCmd := &cobra.Command{
 59				Short:        "Soft Serve is a self-hostable Git server for the command line.",
 60				SilenceUsage: true,
 61			}
 62			rootCmd.CompletionOptions.DisableDefaultCmd = true
 63
 64			rootCmd.SetUsageTemplate(cmd.UsageTemplate)
 65			rootCmd.SetUsageFunc(cmd.UsageFunc)
 66			rootCmd.AddCommand(
 67				cmd.GitUploadPackCommand(),
 68				cmd.GitUploadArchiveCommand(),
 69				cmd.GitReceivePackCommand(),
 70				cmd.RepoCommand(),
 71				cmd.SettingsCommand(),
 72				cmd.UserCommand(),
 73				cmd.InfoCommand(),
 74				cmd.PubkeyCommand(),
 75				cmd.SetUsernameCommand(),
 76				cmd.JWTCommand(),
 77				cmd.TokenCommand(),
 78			)
 79
 80			if cfg.LFS.Enabled {
 81				rootCmd.AddCommand(
 82					cmd.GitLFSAuthenticateCommand(),
 83				)
 84
 85				if cfg.LFS.SSHEnabled {
 86					rootCmd.AddCommand(
 87						cmd.GitLFSTransfer(),
 88					)
 89				}
 90			}
 91
 92			rootCmd.SetArgs(args)
 93			if len(args) == 0 {
 94				// otherwise it'll default to os.Args, which is not what we want.
 95				rootCmd.SetArgs([]string{"--help"})
 96			}
 97			rootCmd.SetIn(s)
 98			rootCmd.SetOut(s)
 99			rootCmd.SetErr(s.Stderr())
100			rootCmd.SetContext(ctx)
101
102			if err := rootCmd.ExecuteContext(ctx); err != nil {
103				s.Exit(1) // nolint: errcheck
104				return
105			}
106		}()
107		sh(s)
108	}
109}
110
111// LoggingMiddleware logs the ssh connection and command.
112func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
113	return func(s ssh.Session) {
114		ctx := s.Context()
115		logger := log.FromContext(ctx).WithPrefix("ssh")
116		ct := time.Now()
117		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
118		ptyReq, _, isPty := s.Pty()
119		addr := s.RemoteAddr().String()
120		user := proto.UserFromContext(ctx)
121		logArgs := []interface{}{
122			"addr",
123			addr,
124			"cmd",
125			s.Command(),
126		}
127
128		if user != nil {
129			logArgs = append([]interface{}{
130				"username",
131				user.Username(),
132			}, logArgs...)
133		}
134
135		if isPty {
136			logArgs = []interface{}{
137				"term", ptyReq.Term,
138				"width", ptyReq.Window.Width,
139				"height", ptyReq.Window.Height,
140			}
141		}
142
143		if config.IsVerbose() {
144			logArgs = append(logArgs,
145				"key", hpk,
146				"envs", s.Environ(),
147			)
148		}
149
150		msg := fmt.Sprintf("user %q", s.User())
151		logger.Debug(msg+" connected", logArgs...)
152		sh(s)
153		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
154	}
155}