1package ssh
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/log"
8 "github.com/charmbracelet/soft-serve/server/backend"
9 "github.com/charmbracelet/soft-serve/server/config"
10 "github.com/charmbracelet/soft-serve/server/db"
11 "github.com/charmbracelet/soft-serve/server/proto"
12 "github.com/charmbracelet/soft-serve/server/ssh/cmd"
13 "github.com/charmbracelet/soft-serve/server/sshutils"
14 "github.com/charmbracelet/soft-serve/server/store"
15 "github.com/charmbracelet/ssh"
16 "github.com/prometheus/client_golang/prometheus"
17 "github.com/prometheus/client_golang/prometheus/promauto"
18 "github.com/spf13/cobra"
19)
20
21// ContextMiddleware adds the config, backend, and logger to the session context.
22func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
23 return func(sh ssh.Handler) ssh.Handler {
24 return func(s ssh.Session) {
25 s.Context().SetValue(sshutils.ContextKeySession, s)
26 s.Context().SetValue(config.ContextKey, cfg)
27 s.Context().SetValue(db.ContextKey, dbx)
28 s.Context().SetValue(store.ContextKey, datastore)
29 s.Context().SetValue(backend.ContextKey, be)
30 s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
31 sh(s)
32 }
33 }
34}
35
36var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
37 Namespace: "soft_serve",
38 Subsystem: "cli",
39 Name: "commands_total",
40 Help: "Total times each command was called",
41}, []string{"command"})
42
43// CommandMiddleware handles git commands and CLI commands.
44// This middleware must be run after the ContextMiddleware.
45func CommandMiddleware(sh ssh.Handler) ssh.Handler {
46 return func(s ssh.Session) {
47 func() {
48 _, _, ptyReq := s.Pty()
49 if ptyReq {
50 return
51 }
52
53 ctx := s.Context()
54 cfg := config.FromContext(ctx)
55
56 args := s.Command()
57 cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
58 rootCmd := &cobra.Command{
59 Short: "Soft Serve is a self-hostable Git server for the command line.",
60 SilenceUsage: true,
61 }
62 rootCmd.CompletionOptions.DisableDefaultCmd = true
63
64 rootCmd.SetUsageTemplate(cmd.UsageTemplate)
65 rootCmd.SetUsageFunc(cmd.UsageFunc)
66 rootCmd.AddCommand(
67 cmd.GitUploadPackCommand(),
68 cmd.GitUploadArchiveCommand(),
69 cmd.GitReceivePackCommand(),
70 cmd.RepoCommand(),
71 )
72
73 if cfg.LFS.Enabled {
74 rootCmd.AddCommand(
75 cmd.GitLFSAuthenticateCommand(),
76 )
77
78 if cfg.LFS.SSHEnabled {
79 rootCmd.AddCommand(
80 cmd.GitLFSTransfer(),
81 )
82 }
83 }
84
85 rootCmd.SetArgs(args)
86 if len(args) == 0 {
87 // otherwise it'll default to os.Args, which is not what we want.
88 rootCmd.SetArgs([]string{"--help"})
89 }
90 rootCmd.SetIn(s)
91 rootCmd.SetOut(s)
92 rootCmd.SetErr(s.Stderr())
93 rootCmd.SetContext(ctx)
94
95 user := proto.UserFromContext(ctx)
96 isAdmin := cmd.IsPublicKeyAdmin(cfg, s.PublicKey()) || (user != nil && user.IsAdmin())
97 if user != nil || isAdmin {
98 if isAdmin {
99 rootCmd.AddCommand(
100 cmd.SettingsCommand(),
101 cmd.UserCommand(),
102 )
103 }
104
105 rootCmd.AddCommand(
106 cmd.InfoCommand(),
107 cmd.PubkeyCommand(),
108 cmd.SetUsernameCommand(),
109 cmd.JWTCommand(),
110 cmd.TokenCommand(),
111 )
112 }
113
114 if err := rootCmd.ExecuteContext(ctx); err != nil {
115 s.Exit(1) // nolint: errcheck
116 return
117 }
118 }()
119 sh(s)
120 }
121}
122
123// LoggingMiddleware logs the ssh connection and command.
124func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
125 return func(s ssh.Session) {
126 ctx := s.Context()
127 logger := log.FromContext(ctx).WithPrefix("ssh")
128 ct := time.Now()
129 hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
130 ptyReq, _, isPty := s.Pty()
131 addr := s.RemoteAddr().String()
132 user := proto.UserFromContext(ctx)
133 logArgs := []interface{}{
134 "addr",
135 addr,
136 "cmd",
137 s.Command(),
138 }
139
140 if user != nil {
141 logArgs = append([]interface{}{
142 "username",
143 user.Username(),
144 }, logArgs...)
145 }
146
147 if isPty {
148 logArgs = []interface{}{
149 "term", ptyReq.Term,
150 "width", ptyReq.Window.Width,
151 "height", ptyReq.Window.Height,
152 }
153 }
154
155 if config.IsVerbose() {
156 logArgs = append(logArgs,
157 "key", hpk,
158 "envs", s.Environ(),
159 )
160 }
161
162 msg := fmt.Sprintf("user %q", s.User())
163 logger.Debug(msg+" connected", logArgs...)
164 sh(s)
165 logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
166 }
167}