1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "strconv"
9 "time"
10
11 "github.com/charmbracelet/keygen"
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/pkg/backend"
14 "github.com/charmbracelet/soft-serve/pkg/config"
15 "github.com/charmbracelet/soft-serve/pkg/db"
16 "github.com/charmbracelet/soft-serve/pkg/proto"
17 "github.com/charmbracelet/soft-serve/pkg/store"
18 "github.com/charmbracelet/ssh"
19 "github.com/charmbracelet/wish"
20 rm "github.com/charmbracelet/wish/recover"
21 "github.com/prometheus/client_golang/prometheus"
22 "github.com/prometheus/client_golang/prometheus/promauto"
23 gossh "golang.org/x/crypto/ssh"
24)
25
26var (
27 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
28 Namespace: "soft_serve",
29 Subsystem: "ssh",
30 Name: "public_key_auth_total",
31 Help: "The total number of public key auth requests",
32 }, []string{"allowed"})
33
34 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
35 Namespace: "soft_serve",
36 Subsystem: "ssh",
37 Name: "keyboard_interactive_auth_total",
38 Help: "The total number of keyboard interactive auth requests",
39 }, []string{"allowed"})
40)
41
42// SSHServer is a SSH server that implements the git protocol.
43type SSHServer struct { // nolint: revive
44 srv *ssh.Server
45 cfg *config.Config
46 be *backend.Backend
47 ctx context.Context
48 logger *log.Logger
49}
50
51// NewSSHServer returns a new SSHServer.
52func NewSSHServer(ctx context.Context) (*SSHServer, error) {
53 cfg := config.FromContext(ctx)
54 logger := log.FromContext(ctx).WithPrefix("ssh")
55 dbx := db.FromContext(ctx)
56 datastore := store.FromContext(ctx)
57 be := backend.FromContext(ctx)
58
59 var err error
60 s := &SSHServer{
61 cfg: cfg,
62 ctx: ctx,
63 be: be,
64 logger: logger,
65 }
66
67 mw := []wish.Middleware{
68 rm.MiddlewareWithLogger(
69 logger,
70 // BubbleTea middleware.
71 // bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
72 // CLI middleware.
73 // CommandMiddleware,
74 ShellMiddleware,
75 // Logging middleware.
76 LoggingMiddleware,
77 // Context middleware.
78 ContextMiddleware(cfg, dbx, datastore, be, logger),
79 // Authentication middleware.
80 // gossh.PublicKeyHandler doesn't guarantee that the public key
81 // is in fact the one used for authentication, so we need to
82 // check it again here.
83 AuthenticationMiddleware,
84 ),
85 }
86
87 s.srv, err = wish.NewServer(
88 ssh.AllocatePty(),
89 ssh.PublicKeyAuth(s.PublicKeyHandler),
90 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
91 wish.WithAddress(cfg.SSH.ListenAddr),
92 wish.WithHostKeyPath(cfg.SSH.KeyPath),
93 wish.WithMiddleware(mw...),
94 )
95 if err != nil {
96 return nil, err
97 }
98
99 if config.IsDebug() {
100 s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
101 return &gossh.ServerConfig{
102 AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
103 logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
104 },
105 }
106 }
107 }
108
109 if cfg.SSH.MaxTimeout > 0 {
110 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
111 }
112
113 if cfg.SSH.IdleTimeout > 0 {
114 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
115 }
116
117 // Create client ssh key
118 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
119 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
120 if err != nil {
121 return nil, fmt.Errorf("client ssh key: %w", err)
122 }
123 }
124
125 return s, nil
126}
127
128// ListenAndServe starts the SSH server.
129func (s *SSHServer) ListenAndServe() error {
130 return s.srv.ListenAndServe()
131}
132
133// Serve starts the SSH server on the given net.Listener.
134func (s *SSHServer) Serve(l net.Listener) error {
135 return s.srv.Serve(l)
136}
137
138// Close closes the SSH server.
139func (s *SSHServer) Close() error {
140 return s.srv.Close()
141}
142
143// Shutdown gracefully shuts down the SSH server.
144func (s *SSHServer) Shutdown(ctx context.Context) error {
145 return s.srv.Shutdown(ctx)
146}
147
148func initializePermissions(ctx ssh.Context) {
149 perms := ctx.Permissions()
150 if perms == nil || perms.Permissions == nil {
151 perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
152 }
153 if perms.Extensions == nil {
154 perms.Extensions = make(map[string]string)
155 }
156 if perms.Permissions.Extensions == nil {
157 perms.Permissions.Extensions = make(map[string]string)
158 }
159}
160
161// PublicKeyAuthHandler handles public key authentication.
162func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
163 if pk == nil {
164 return false
165 }
166
167 defer func(allowed *bool) {
168 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
169 }(&allowed)
170
171 user, _ := s.be.UserByPublicKey(ctx, pk)
172 if user != nil {
173 ctx.SetValue(proto.ContextKeyUser, user)
174 allowed = true
175
176 // XXX: store the first "approved" public-key fingerprint in the
177 // permissions block to use for authentication later.
178 initializePermissions(ctx)
179 perms := ctx.Permissions()
180
181 // Set the public key fingerprint to be used for authentication.
182 perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
183 ctx.SetValue(ssh.ContextKeyPermissions, perms)
184 }
185
186 return
187}
188
189// KeyboardInteractiveHandler handles keyboard interactive authentication.
190// This is used after all public key authentication has failed.
191func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
192 ac := s.be.AllowKeyless(ctx)
193 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
194
195 // If we're allowing keyless access, reset the public key fingerprint
196 if ac {
197 initializePermissions(ctx)
198 perms := ctx.Permissions()
199
200 // XXX: reset the public-key fingerprint. This is used to validate the
201 // public key being used to authenticate.
202 perms.Extensions["pubkey-fp"] = ""
203 ctx.SetValue(ssh.ContextKeyPermissions, perms)
204 }
205 return ac
206}