1package ssh
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/log"
8 "github.com/charmbracelet/soft-serve/pkg/backend"
9 "github.com/charmbracelet/soft-serve/pkg/config"
10 "github.com/charmbracelet/soft-serve/pkg/db"
11 "github.com/charmbracelet/soft-serve/pkg/proto"
12 "github.com/charmbracelet/soft-serve/pkg/ssh/cmd"
13 "github.com/charmbracelet/soft-serve/pkg/sshutils"
14 "github.com/charmbracelet/soft-serve/pkg/store"
15 "github.com/charmbracelet/ssh"
16 "github.com/charmbracelet/wish"
17 bm "github.com/charmbracelet/wish/bubbletea"
18 "github.com/prometheus/client_golang/prometheus"
19 "github.com/prometheus/client_golang/prometheus/promauto"
20 "github.com/spf13/cobra"
21 gossh "golang.org/x/crypto/ssh"
22)
23
// ErrPermissionDenied is returned when a user is not allowed to connect.
var ErrPermissionDenied = fmt.Errorf("permission denied")
26
27// AuthenticationMiddleware handles authentication.
28func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
29 return func(s ssh.Session) {
30 // XXX: The authentication key is set in the context but gossh doesn't
31 // validate the authentication. We need to verify that the _last_ key
32 // that was approved is the one that's being used.
33
34 pk := s.PublicKey()
35 if pk != nil {
36 // There is no public key stored in the context, public-key auth
37 // was never requested, skip
38 perms := s.Permissions().Permissions
39 if perms == nil {
40 wish.Fatalln(s, ErrPermissionDenied)
41 return
42 }
43
44 // Check if the key is the same as the one we have in context
45 fp := perms.Extensions["pubkey-fp"]
46 if fp != gossh.FingerprintSHA256(pk) {
47 wish.Fatalln(s, ErrPermissionDenied)
48 return
49 }
50 }
51
52 sh(s)
53 }
54}
55
56// ContextMiddleware adds the config, backend, and logger to the session context.
57func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
58 return func(sh ssh.Handler) ssh.Handler {
59 return func(s ssh.Session) {
60 s.Context().SetValue(sshutils.ContextKeySession, s)
61 s.Context().SetValue(config.ContextKey, cfg)
62 s.Context().SetValue(db.ContextKey, dbx)
63 s.Context().SetValue(store.ContextKey, datastore)
64 s.Context().SetValue(backend.ContextKey, be)
65 s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
66 sh(s)
67 }
68 }
69}
70
71var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
72 Namespace: "soft_serve",
73 Subsystem: "cli",
74 Name: "commands_total",
75 Help: "Total times each command was called",
76}, []string{"command"})
77
78// CommandMiddleware handles git commands and CLI commands.
79// This middleware must be run after the ContextMiddleware.
80func CommandMiddleware(sh ssh.Handler) ssh.Handler {
81 return func(s ssh.Session) {
82 func() {
83 _, _, ptyReq := s.Pty()
84 if ptyReq {
85 return
86 }
87
88 r := bm.MakeRenderer(s)
89
90 ctx := s.Context()
91 cfg := config.FromContext(ctx)
92
93 args := s.Command()
94 cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
95 rootCmd := &cobra.Command{
96 Short: "Soft Serve is a self-hostable Git server for the command line.",
97 SilenceUsage: true,
98 }
99 rootCmd.CompletionOptions.DisableDefaultCmd = true
100
101 rootCmd.SetUsageTemplate(cmd.UsageTemplate)
102 rootCmd.SetUsageFunc(cmd.UsageFunc)
103 rootCmd.AddCommand(
104 cmd.GitUploadPackCommand(),
105 cmd.GitUploadArchiveCommand(),
106 cmd.GitReceivePackCommand(),
107 cmd.RepoCommand(r),
108 cmd.SettingsCommand(),
109 cmd.UserCommand(),
110 cmd.InfoCommand(),
111 cmd.PubkeyCommand(),
112 cmd.SetUsernameCommand(),
113 cmd.JWTCommand(),
114 cmd.TokenCommand(),
115 )
116
117 if cfg.LFS.Enabled {
118 rootCmd.AddCommand(
119 cmd.GitLFSAuthenticateCommand(),
120 )
121
122 if cfg.LFS.SSHEnabled {
123 rootCmd.AddCommand(
124 cmd.GitLFSTransfer(),
125 )
126 }
127 }
128
129 rootCmd.SetArgs(args)
130 if len(args) == 0 {
131 // otherwise it'll default to os.Args, which is not what we want.
132 rootCmd.SetArgs([]string{"--help"})
133 }
134 rootCmd.SetIn(s)
135 rootCmd.SetOut(s)
136 rootCmd.SetErr(s.Stderr())
137 rootCmd.SetContext(ctx)
138
139 if err := rootCmd.ExecuteContext(ctx); err != nil {
140 s.Exit(1) // nolint: errcheck
141 return
142 }
143 }()
144 sh(s)
145 }
146}
147
148// LoggingMiddleware logs the ssh connection and command.
149func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
150 return func(s ssh.Session) {
151 ctx := s.Context()
152 logger := log.FromContext(ctx).WithPrefix("ssh")
153 ct := time.Now()
154 hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
155 ptyReq, _, isPty := s.Pty()
156 addr := s.RemoteAddr().String()
157 user := proto.UserFromContext(ctx)
158 logArgs := []interface{}{
159 "addr",
160 addr,
161 "cmd",
162 s.Command(),
163 }
164
165 if user != nil {
166 logArgs = append([]interface{}{
167 "username",
168 user.Username(),
169 }, logArgs...)
170 }
171
172 if isPty {
173 logArgs = []interface{}{
174 "term", ptyReq.Term,
175 "width", ptyReq.Window.Width,
176 "height", ptyReq.Window.Height,
177 }
178 }
179
180 if config.IsVerbose() {
181 logArgs = append(logArgs,
182 "key", hpk,
183 "envs", s.Environ(),
184 )
185 }
186
187 msg := fmt.Sprintf("user %q", s.User())
188 logger.Debug(msg+" connected", logArgs...)
189 sh(s)
190 logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
191 }
192}