1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "runtime"
9 "strconv"
10 "time"
11
12 "github.com/charmbracelet/keygen"
13 "github.com/charmbracelet/log"
14 "github.com/charmbracelet/soft-serve/pkg/backend"
15 "github.com/charmbracelet/soft-serve/pkg/config"
16 "github.com/charmbracelet/soft-serve/pkg/db"
17 "github.com/charmbracelet/soft-serve/pkg/proto"
18 "github.com/charmbracelet/soft-serve/pkg/store"
19 "github.com/charmbracelet/ssh"
20 "github.com/charmbracelet/wish/v2"
21 bm "github.com/charmbracelet/wish/v2/bubbletea"
22 rm "github.com/charmbracelet/wish/v2/recover"
23 "github.com/prometheus/client_golang/prometheus"
24 "github.com/prometheus/client_golang/prometheus/promauto"
25 gossh "golang.org/x/crypto/ssh"
26)
27
28var (
29 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
30 Namespace: "soft_serve",
31 Subsystem: "ssh",
32 Name: "public_key_auth_total",
33 Help: "The total number of public key auth requests",
34 }, []string{"allowed"})
35
36 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
37 Namespace: "soft_serve",
38 Subsystem: "ssh",
39 Name: "keyboard_interactive_auth_total",
40 Help: "The total number of keyboard interactive auth requests",
41 }, []string{"allowed"})
42)
43
44// SSHServer is a SSH server that implements the git protocol.
45type SSHServer struct { // nolint: revive
46 srv *ssh.Server
47 cfg *config.Config
48 be *backend.Backend
49 ctx context.Context
50 logger *log.Logger
51}
52
53// NewSSHServer returns a new SSHServer.
54func NewSSHServer(ctx context.Context) (*SSHServer, error) {
55 cfg := config.FromContext(ctx)
56 logger := log.FromContext(ctx).WithPrefix("ssh")
57 dbx := db.FromContext(ctx)
58 datastore := store.FromContext(ctx)
59 be := backend.FromContext(ctx)
60
61 var err error
62 s := &SSHServer{
63 cfg: cfg,
64 ctx: ctx,
65 be: be,
66 logger: logger,
67 }
68
69 mw := []wish.Middleware{
70 rm.MiddlewareWithLogger(
71 logger,
72 // BubbleTea middleware.
73 bm.MiddlewareWithProgramHandler(SessionHandler),
74 // CLI middleware.
75 CommandMiddleware,
76 // Logging middleware.
77 LoggingMiddleware,
78 // Context middleware.
79 ContextMiddleware(cfg, dbx, datastore, be, logger),
80 // Authentication middleware.
81 // gossh.PublicKeyHandler doesn't guarantee that the public key
82 // is in fact the one used for authentication, so we need to
83 // check it again here.
84 AuthenticationMiddleware,
85 ),
86 }
87
88 opts := []ssh.Option{
89 ssh.PublicKeyAuth(s.PublicKeyHandler),
90 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
91 wish.WithAddress(cfg.SSH.ListenAddr),
92 wish.WithHostKeyPath(cfg.SSH.KeyPath),
93 wish.WithMiddleware(mw...),
94 }
95 if runtime.GOOS == "windows" {
96 opts = append(opts, ssh.EmulatePty())
97 } else {
98 opts = append(opts, ssh.AllocatePty())
99 }
100 s.srv, err = wish.NewServer(opts...)
101 if err != nil {
102 return nil, err
103 }
104
105 if config.IsDebug() {
106 s.srv.ServerConfigCallback = func(_ ssh.Context) *gossh.ServerConfig {
107 return &gossh.ServerConfig{
108 AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
109 logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
110 },
111 }
112 }
113 }
114
115 if cfg.SSH.MaxTimeout > 0 {
116 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
117 }
118
119 if cfg.SSH.IdleTimeout > 0 {
120 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
121 }
122
123 // Create client ssh key
124 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
125 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
126 if err != nil {
127 return nil, fmt.Errorf("client ssh key: %w", err)
128 }
129 }
130
131 return s, nil
132}
133
134// ListenAndServe starts the SSH server.
135func (s *SSHServer) ListenAndServe() error {
136 return s.srv.ListenAndServe()
137}
138
139// Serve starts the SSH server on the given net.Listener.
140func (s *SSHServer) Serve(l net.Listener) error {
141 return s.srv.Serve(l)
142}
143
144// Close closes the SSH server.
145func (s *SSHServer) Close() error {
146 return s.srv.Close()
147}
148
149// Shutdown gracefully shuts down the SSH server.
150func (s *SSHServer) Shutdown(ctx context.Context) error {
151 return s.srv.Shutdown(ctx)
152}
153
154func initializePermissions(ctx ssh.Context) {
155 perms := ctx.Permissions()
156 if perms == nil || perms.Permissions == nil {
157 perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
158 }
159 if perms.Extensions == nil {
160 perms.Extensions = make(map[string]string)
161 }
162 if perms.Permissions.Extensions == nil {
163 perms.Permissions.Extensions = make(map[string]string)
164 }
165}
166
167// PublicKeyAuthHandler handles public key authentication.
168func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
169 if pk == nil {
170 return false
171 }
172
173 allowed = true
174 defer func(allowed *bool) {
175 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
176 }(&allowed)
177
178 user, _ := s.be.UserByPublicKey(ctx, pk)
179 if user != nil {
180 ctx.SetValue(proto.ContextKeyUser, user)
181 }
182
183 // XXX: store the first "approved" public-key fingerprint in the
184 // permissions block to use for authentication later.
185 initializePermissions(ctx)
186 perms := ctx.Permissions()
187
188 // Set the public key fingerprint to be used for authentication.
189 perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
190 ctx.SetValue(ssh.ContextKeyPermissions, perms)
191
192 return
193}
194
195// KeyboardInteractiveHandler handles keyboard interactive authentication.
196// This is used after all public key authentication has failed.
197func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
198 ac := s.be.AllowKeyless(ctx)
199 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
200
201 // If we're allowing keyless access, reset the public key fingerprint
202 if ac {
203 initializePermissions(ctx)
204 perms := ctx.Permissions()
205
206 // XXX: reset the public-key fingerprint. This is used to validate the
207 // public key being used to authenticate.
208 perms.Extensions["pubkey-fp"] = ""
209 ctx.SetValue(ssh.ContextKeyPermissions, perms)
210 }
211 return ac
212}