1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "strconv"
9 "time"
10
11 "github.com/charmbracelet/keygen"
12 log "github.com/charmbracelet/log/v2"
13 "github.com/charmbracelet/soft-serve/pkg/backend"
14 "github.com/charmbracelet/soft-serve/pkg/config"
15 "github.com/charmbracelet/soft-serve/pkg/db"
16 "github.com/charmbracelet/soft-serve/pkg/proto"
17 "github.com/charmbracelet/soft-serve/pkg/store"
18 "github.com/charmbracelet/ssh"
19 wish "github.com/charmbracelet/wish/v2"
20 bm "github.com/charmbracelet/wish/v2/bubbletea"
21 rm "github.com/charmbracelet/wish/v2/recover"
22 "github.com/prometheus/client_golang/prometheus"
23 "github.com/prometheus/client_golang/prometheus/promauto"
24 gossh "golang.org/x/crypto/ssh"
25)
26
27var (
28 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
29 Namespace: "soft_serve",
30 Subsystem: "ssh",
31 Name: "public_key_auth_total",
32 Help: "The total number of public key auth requests",
33 }, []string{"allowed"})
34
35 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
36 Namespace: "soft_serve",
37 Subsystem: "ssh",
38 Name: "keyboard_interactive_auth_total",
39 Help: "The total number of keyboard interactive auth requests",
40 }, []string{"allowed"})
41)
42
// SSHServer is an SSH server that implements the git protocol.
type SSHServer struct { //nolint:revive
	// srv is the underlying SSH server built by NewSSHServer.
	srv *ssh.Server
	// cfg is the configuration taken from the construction context.
	cfg *config.Config
	// be is the backend used for public-key user lookup and the
	// keyless-access policy.
	be *backend.Backend
	// ctx is the context NewSSHServer was called with.
	// NOTE(review): it is not read anywhere in this file; storing a context
	// in a struct is discouraged — confirm usage elsewhere before removing.
	ctx context.Context
	// logger is the server logger, prefixed with "ssh".
	logger *log.Logger
}
51
52// NewSSHServer returns a new SSHServer.
53func NewSSHServer(ctx context.Context) (*SSHServer, error) {
54 cfg := config.FromContext(ctx)
55 logger := log.FromContext(ctx).WithPrefix("ssh")
56 dbx := db.FromContext(ctx)
57 datastore := store.FromContext(ctx)
58 be := backend.FromContext(ctx)
59
60 var err error
61 s := &SSHServer{
62 cfg: cfg,
63 ctx: ctx,
64 be: be,
65 logger: logger,
66 }
67
68 mw := []wish.Middleware{
69 rm.MiddlewareWithLogger(
70 logger,
71 // BubbleTea middleware.
72 bm.MiddlewareWithProgramHandler(SessionHandler),
73 // CLI middleware.
74 CommandMiddleware,
75 // Logging middleware.
76 LoggingMiddleware,
77 // Context middleware.
78 ContextMiddleware(cfg, dbx, datastore, be, logger),
79 // Authentication middleware.
80 // gossh.PublicKeyHandler doesn't guarantee that the public key
81 // is in fact the one used for authentication, so we need to
82 // check it again here.
83 AuthenticationMiddleware,
84 ),
85 }
86
87 opts := []ssh.Option{
88 ssh.PublicKeyAuth(s.PublicKeyHandler),
89 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
90 wish.WithAddress(cfg.SSH.ListenAddr),
91 wish.WithHostKeyPath(cfg.SSH.KeyPath),
92 wish.WithMiddleware(mw...),
93 }
94
95 // TODO: Support a real PTY in future version.
96 opts = append(opts, ssh.EmulatePty())
97
98 s.srv, err = wish.NewServer(opts...)
99 if err != nil {
100 return nil, err
101 }
102
103 if config.IsDebug() {
104 s.srv.ServerConfigCallback = func(_ ssh.Context) *gossh.ServerConfig {
105 return &gossh.ServerConfig{
106 AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
107 logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
108 },
109 }
110 }
111 }
112
113 if cfg.SSH.MaxTimeout > 0 {
114 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
115 }
116
117 if cfg.SSH.IdleTimeout > 0 {
118 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
119 }
120
121 // Create client ssh key
122 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
123 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
124 if err != nil {
125 return nil, fmt.Errorf("client ssh key: %w", err)
126 }
127 }
128
129 return s, nil
130}
131
132// ListenAndServe starts the SSH server.
133func (s *SSHServer) ListenAndServe() error {
134 return s.srv.ListenAndServe()
135}
136
137// Serve starts the SSH server on the given net.Listener.
138func (s *SSHServer) Serve(l net.Listener) error {
139 return s.srv.Serve(l)
140}
141
142// Close closes the SSH server.
143func (s *SSHServer) Close() error {
144 return s.srv.Close()
145}
146
147// Shutdown gracefully shuts down the SSH server.
148func (s *SSHServer) Shutdown(ctx context.Context) error {
149 return s.srv.Shutdown(ctx)
150}
151
152func initializePermissions(ctx ssh.Context) {
153 perms := ctx.Permissions()
154 if perms == nil || perms.Permissions == nil {
155 perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
156 }
157 if perms.Extensions == nil {
158 perms.Extensions = make(map[string]string)
159 }
160 if perms.Permissions.Extensions == nil {
161 perms.Permissions.Extensions = make(map[string]string)
162 }
163}
164
165// PublicKeyHandler handles public key authentication.
166func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
167 if pk == nil {
168 return false
169 }
170
171 allowed = true
172 defer func(allowed *bool) {
173 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
174 }(&allowed)
175
176 user, _ := s.be.UserByPublicKey(ctx, pk)
177 if user != nil {
178 ctx.SetValue(proto.ContextKeyUser, user)
179 }
180
181 // XXX: store the first "approved" public-key fingerprint in the
182 // permissions block to use for authentication later.
183 initializePermissions(ctx)
184 perms := ctx.Permissions()
185
186 // Set the public key fingerprint to be used for authentication.
187 perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
188 ctx.SetValue(ssh.ContextKeyPermissions, perms)
189
190 return
191}
192
193// KeyboardInteractiveHandler handles keyboard interactive authentication.
194// This is used after all public key authentication has failed.
195func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
196 ac := s.be.AllowKeyless(ctx)
197 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
198
199 // If we're allowing keyless access, reset the public key fingerprint
200 if ac {
201 initializePermissions(ctx)
202 perms := ctx.Permissions()
203
204 // XXX: reset the public-key fingerprint. This is used to validate the
205 // public key being used to authenticate.
206 perms.Extensions["pubkey-fp"] = ""
207 ctx.SetValue(ssh.ContextKeyPermissions, perms)
208 }
209 return ac
210}