1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "runtime"
9 "strconv"
10 "time"
11
12 "github.com/charmbracelet/keygen"
13 "github.com/charmbracelet/log"
14 "github.com/charmbracelet/soft-serve/pkg/backend"
15 "github.com/charmbracelet/soft-serve/pkg/config"
16 "github.com/charmbracelet/soft-serve/pkg/db"
17 "github.com/charmbracelet/soft-serve/pkg/proto"
18 "github.com/charmbracelet/soft-serve/pkg/store"
19 "github.com/charmbracelet/soft-serve/pkg/ui/common"
20 "github.com/charmbracelet/ssh"
21 "github.com/charmbracelet/wish"
22 bm "github.com/charmbracelet/wish/bubbletea"
23 rm "github.com/charmbracelet/wish/recover"
24 "github.com/prometheus/client_golang/prometheus"
25 "github.com/prometheus/client_golang/prometheus/promauto"
26 gossh "golang.org/x/crypto/ssh"
27)
28
29var (
30 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
31 Namespace: "soft_serve",
32 Subsystem: "ssh",
33 Name: "public_key_auth_total",
34 Help: "The total number of public key auth requests",
35 }, []string{"allowed"})
36
37 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
38 Namespace: "soft_serve",
39 Subsystem: "ssh",
40 Name: "keyboard_interactive_auth_total",
41 Help: "The total number of keyboard interactive auth requests",
42 }, []string{"allowed"})
43)
44
45// SSHServer is a SSH server that implements the git protocol.
46type SSHServer struct { // nolint: revive
47 srv *ssh.Server
48 cfg *config.Config
49 be *backend.Backend
50 ctx context.Context
51 logger *log.Logger
52}
53
54// NewSSHServer returns a new SSHServer.
55func NewSSHServer(ctx context.Context) (*SSHServer, error) {
56 cfg := config.FromContext(ctx)
57 logger := log.FromContext(ctx).WithPrefix("ssh")
58 dbx := db.FromContext(ctx)
59 datastore := store.FromContext(ctx)
60 be := backend.FromContext(ctx)
61
62 var err error
63 s := &SSHServer{
64 cfg: cfg,
65 ctx: ctx,
66 be: be,
67 logger: logger,
68 }
69
70 mw := []wish.Middleware{
71 rm.MiddlewareWithLogger(
72 logger,
73 // BubbleTea middleware.
74 bm.MiddlewareWithProgramHandler(SessionHandler, common.DefaultColorProfile),
75 // CLI middleware.
76 CommandMiddleware,
77 // Logging middleware.
78 LoggingMiddleware,
79 // Context middleware.
80 ContextMiddleware(cfg, dbx, datastore, be, logger),
81 // Authentication middleware.
82 // gossh.PublicKeyHandler doesn't guarantee that the public key
83 // is in fact the one used for authentication, so we need to
84 // check it again here.
85 AuthenticationMiddleware,
86 ),
87 }
88
89 opts := []ssh.Option{
90 ssh.PublicKeyAuth(s.PublicKeyHandler),
91 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
92 wish.WithAddress(cfg.SSH.ListenAddr),
93 wish.WithHostKeyPath(cfg.SSH.KeyPath),
94 wish.WithMiddleware(mw...),
95 }
96 if runtime.GOOS == "windows" {
97 opts = append(opts, ssh.EmulatePty())
98 } else {
99 opts = append(opts, ssh.AllocatePty())
100 }
101 s.srv, err = wish.NewServer(opts...)
102 if err != nil {
103 return nil, err
104 }
105
106 if config.IsDebug() {
107 s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
108 return &gossh.ServerConfig{
109 AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
110 logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
111 },
112 }
113 }
114 }
115
116 if cfg.SSH.MaxTimeout > 0 {
117 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
118 }
119
120 if cfg.SSH.IdleTimeout > 0 {
121 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
122 }
123
124 // Create client ssh key
125 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
126 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
127 if err != nil {
128 return nil, fmt.Errorf("client ssh key: %w", err)
129 }
130 }
131
132 return s, nil
133}
134
135// ListenAndServe starts the SSH server.
136func (s *SSHServer) ListenAndServe() error {
137 return s.srv.ListenAndServe()
138}
139
140// Serve starts the SSH server on the given net.Listener.
141func (s *SSHServer) Serve(l net.Listener) error {
142 return s.srv.Serve(l)
143}
144
145// Close closes the SSH server.
146func (s *SSHServer) Close() error {
147 return s.srv.Close()
148}
149
150// Shutdown gracefully shuts down the SSH server.
151func (s *SSHServer) Shutdown(ctx context.Context) error {
152 return s.srv.Shutdown(ctx)
153}
154
155func initializePermissions(ctx ssh.Context) {
156 perms := ctx.Permissions()
157 if perms == nil || perms.Permissions == nil {
158 perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
159 }
160 if perms.Extensions == nil {
161 perms.Extensions = make(map[string]string)
162 }
163 if perms.Permissions.Extensions == nil {
164 perms.Permissions.Extensions = make(map[string]string)
165 }
166}
167
168// PublicKeyAuthHandler handles public key authentication.
169func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
170 if pk == nil {
171 return false
172 }
173
174 defer func(allowed *bool) {
175 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
176 }(&allowed)
177
178 user, _ := s.be.UserByPublicKey(ctx, pk)
179 if user != nil {
180 ctx.SetValue(proto.ContextKeyUser, user)
181 allowed = true
182
183 // XXX: store the first "approved" public-key fingerprint in the
184 // permissions block to use for authentication later.
185 initializePermissions(ctx)
186 perms := ctx.Permissions()
187
188 // Set the public key fingerprint to be used for authentication.
189 perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
190 ctx.SetValue(ssh.ContextKeyPermissions, perms)
191 }
192
193 return
194}
195
196// KeyboardInteractiveHandler handles keyboard interactive authentication.
197// This is used after all public key authentication has failed.
198func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
199 ac := s.be.AllowKeyless(ctx)
200 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
201
202 // If we're allowing keyless access, reset the public key fingerprint
203 if ac {
204 initializePermissions(ctx)
205 perms := ctx.Permissions()
206
207 // XXX: reset the public-key fingerprint. This is used to validate the
208 // public key being used to authenticate.
209 perms.Extensions["pubkey-fp"] = ""
210 ctx.SetValue(ssh.ContextKeyPermissions, perms)
211 }
212 return ac
213}