ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"strconv"
  9	"time"
 10
 11	"github.com/charmbracelet/keygen"
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/pkg/backend"
 14	"github.com/charmbracelet/soft-serve/pkg/config"
 15	"github.com/charmbracelet/soft-serve/pkg/db"
 16	"github.com/charmbracelet/soft-serve/pkg/proto"
 17	"github.com/charmbracelet/soft-serve/pkg/store"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	rm "github.com/charmbracelet/wish/recover"
 21	"github.com/prometheus/client_golang/prometheus"
 22	"github.com/prometheus/client_golang/prometheus/promauto"
 23	gossh "golang.org/x/crypto/ssh"
 24)
 25
 26var (
 27	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 28		Namespace: "soft_serve",
 29		Subsystem: "ssh",
 30		Name:      "public_key_auth_total",
 31		Help:      "The total number of public key auth requests",
 32	}, []string{"allowed"})
 33
 34	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 35		Namespace: "soft_serve",
 36		Subsystem: "ssh",
 37		Name:      "keyboard_interactive_auth_total",
 38		Help:      "The total number of keyboard interactive auth requests",
 39	}, []string{"allowed"})
 40)
 41
 42// SSHServer is a SSH server that implements the git protocol.
 43type SSHServer struct { // nolint: revive
 44	srv    *ssh.Server
 45	cfg    *config.Config
 46	be     *backend.Backend
 47	ctx    context.Context
 48	logger *log.Logger
 49}
 50
 51// NewSSHServer returns a new SSHServer.
 52func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 53	cfg := config.FromContext(ctx)
 54	logger := log.FromContext(ctx).WithPrefix("ssh")
 55	dbx := db.FromContext(ctx)
 56	datastore := store.FromContext(ctx)
 57	be := backend.FromContext(ctx)
 58
 59	var err error
 60	s := &SSHServer{
 61		cfg:    cfg,
 62		ctx:    ctx,
 63		be:     be,
 64		logger: logger,
 65	}
 66
 67	mw := []wish.Middleware{
 68		rm.MiddlewareWithLogger(
 69			logger,
 70			// BubbleTea middleware.
 71			// bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
 72			// CLI middleware.
 73			// CommandMiddleware,
 74			ShellMiddleware,
 75			// Logging middleware.
 76			LoggingMiddleware,
 77			// Context middleware.
 78			ContextMiddleware(cfg, dbx, datastore, be, logger),
 79			// Authentication middleware.
 80			// gossh.PublicKeyHandler doesn't guarantee that the public key
 81			// is in fact the one used for authentication, so we need to
 82			// check it again here.
 83			AuthenticationMiddleware,
 84		),
 85	}
 86
 87	s.srv, err = wish.NewServer(
 88		ssh.AllocatePty(),
 89		ssh.PublicKeyAuth(s.PublicKeyHandler),
 90		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
 91		wish.WithAddress(cfg.SSH.ListenAddr),
 92		wish.WithHostKeyPath(cfg.SSH.KeyPath),
 93		wish.WithMiddleware(mw...),
 94	)
 95	if err != nil {
 96		return nil, err
 97	}
 98
 99	if config.IsDebug() {
100		s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
101			return &gossh.ServerConfig{
102				AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
103					logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
104				},
105			}
106		}
107	}
108
109	if cfg.SSH.MaxTimeout > 0 {
110		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
111	}
112
113	if cfg.SSH.IdleTimeout > 0 {
114		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
115	}
116
117	// Create client ssh key
118	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
119		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
120		if err != nil {
121			return nil, fmt.Errorf("client ssh key: %w", err)
122		}
123	}
124
125	return s, nil
126}
127
128// ListenAndServe starts the SSH server.
129func (s *SSHServer) ListenAndServe() error {
130	return s.srv.ListenAndServe()
131}
132
133// Serve starts the SSH server on the given net.Listener.
134func (s *SSHServer) Serve(l net.Listener) error {
135	return s.srv.Serve(l)
136}
137
138// Close closes the SSH server.
139func (s *SSHServer) Close() error {
140	return s.srv.Close()
141}
142
143// Shutdown gracefully shuts down the SSH server.
144func (s *SSHServer) Shutdown(ctx context.Context) error {
145	return s.srv.Shutdown(ctx)
146}
147
148func initializePermissions(ctx ssh.Context) {
149	perms := ctx.Permissions()
150	if perms == nil || perms.Permissions == nil {
151		perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
152	}
153	if perms.Extensions == nil {
154		perms.Extensions = make(map[string]string)
155	}
156	if perms.Permissions.Extensions == nil {
157		perms.Permissions.Extensions = make(map[string]string)
158	}
159}
160
161// PublicKeyAuthHandler handles public key authentication.
162func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
163	if pk == nil {
164		return false
165	}
166
167	defer func(allowed *bool) {
168		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
169	}(&allowed)
170
171	user, _ := s.be.UserByPublicKey(ctx, pk)
172	if user != nil {
173		ctx.SetValue(proto.ContextKeyUser, user)
174		allowed = true
175
176		// XXX: store the first "approved" public-key fingerprint in the
177		// permissions block to use for authentication later.
178		initializePermissions(ctx)
179		perms := ctx.Permissions()
180
181		// Set the public key fingerprint to be used for authentication.
182		perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
183		ctx.SetValue(ssh.ContextKeyPermissions, perms)
184	}
185
186	return
187}
188
189// KeyboardInteractiveHandler handles keyboard interactive authentication.
190// This is used after all public key authentication has failed.
191func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
192	ac := s.be.AllowKeyless(ctx)
193	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
194
195	// If we're allowing keyless access, reset the public key fingerprint
196	if ac {
197		initializePermissions(ctx)
198		perms := ctx.Permissions()
199
200		// XXX: reset the public-key fingerprint. This is used to validate the
201		// public key being used to authenticate.
202		perms.Extensions["pubkey-fp"] = ""
203		ctx.SetValue(ssh.ContextKeyPermissions, perms)
204	}
205	return ac
206}