ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"runtime"
  9	"strconv"
 10	"time"
 11
 12	"github.com/charmbracelet/keygen"
 13	"github.com/charmbracelet/log"
 14	"github.com/charmbracelet/soft-serve/pkg/backend"
 15	"github.com/charmbracelet/soft-serve/pkg/config"
 16	"github.com/charmbracelet/soft-serve/pkg/db"
 17	"github.com/charmbracelet/soft-serve/pkg/proto"
 18	"github.com/charmbracelet/soft-serve/pkg/store"
 19	"github.com/charmbracelet/soft-serve/pkg/ui/common"
 20	"github.com/charmbracelet/ssh"
 21	"github.com/charmbracelet/wish"
 22	bm "github.com/charmbracelet/wish/bubbletea"
 23	rm "github.com/charmbracelet/wish/recover"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
 29var (
 30	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 31		Namespace: "soft_serve",
 32		Subsystem: "ssh",
 33		Name:      "public_key_auth_total",
 34		Help:      "The total number of public key auth requests",
 35	}, []string{"allowed"})
 36
 37	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 38		Namespace: "soft_serve",
 39		Subsystem: "ssh",
 40		Name:      "keyboard_interactive_auth_total",
 41		Help:      "The total number of keyboard interactive auth requests",
 42	}, []string{"allowed"})
 43)
 44
 45// SSHServer is a SSH server that implements the git protocol.
 46type SSHServer struct { // nolint: revive
 47	srv    *ssh.Server
 48	cfg    *config.Config
 49	be     *backend.Backend
 50	ctx    context.Context
 51	logger *log.Logger
 52}
 53
 54// NewSSHServer returns a new SSHServer.
 55func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 56	cfg := config.FromContext(ctx)
 57	logger := log.FromContext(ctx).WithPrefix("ssh")
 58	dbx := db.FromContext(ctx)
 59	datastore := store.FromContext(ctx)
 60	be := backend.FromContext(ctx)
 61
 62	var err error
 63	s := &SSHServer{
 64		cfg:    cfg,
 65		ctx:    ctx,
 66		be:     be,
 67		logger: logger,
 68	}
 69
 70	mw := []wish.Middleware{
 71		rm.MiddlewareWithLogger(
 72			logger,
 73			// BubbleTea middleware.
 74			bm.MiddlewareWithProgramHandler(SessionHandler, common.DefaultColorProfile),
 75			// CLI middleware.
 76			CommandMiddleware,
 77			// Logging middleware.
 78			LoggingMiddleware,
 79			// Context middleware.
 80			ContextMiddleware(cfg, dbx, datastore, be, logger),
 81			// Authentication middleware.
 82			// gossh.PublicKeyHandler doesn't guarantee that the public key
 83			// is in fact the one used for authentication, so we need to
 84			// check it again here.
 85			AuthenticationMiddleware,
 86		),
 87	}
 88
 89	opts := []ssh.Option{
 90		ssh.PublicKeyAuth(s.PublicKeyHandler),
 91		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
 92		wish.WithAddress(cfg.SSH.ListenAddr),
 93		wish.WithHostKeyPath(cfg.SSH.KeyPath),
 94		wish.WithMiddleware(mw...),
 95	}
 96	if runtime.GOOS == "windows" {
 97		opts = append(opts, ssh.EmulatePty())
 98	} else {
 99		opts = append(opts, ssh.AllocatePty())
100	}
101	s.srv, err = wish.NewServer(opts...)
102	if err != nil {
103		return nil, err
104	}
105
106	if config.IsDebug() {
107		s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
108			return &gossh.ServerConfig{
109				AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
110					logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
111				},
112			}
113		}
114	}
115
116	if cfg.SSH.MaxTimeout > 0 {
117		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
118	}
119
120	if cfg.SSH.IdleTimeout > 0 {
121		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
122	}
123
124	// Create client ssh key
125	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
126		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
127		if err != nil {
128			return nil, fmt.Errorf("client ssh key: %w", err)
129		}
130	}
131
132	return s, nil
133}
134
135// ListenAndServe starts the SSH server.
136func (s *SSHServer) ListenAndServe() error {
137	return s.srv.ListenAndServe()
138}
139
140// Serve starts the SSH server on the given net.Listener.
141func (s *SSHServer) Serve(l net.Listener) error {
142	return s.srv.Serve(l)
143}
144
145// Close closes the SSH server.
146func (s *SSHServer) Close() error {
147	return s.srv.Close()
148}
149
150// Shutdown gracefully shuts down the SSH server.
151func (s *SSHServer) Shutdown(ctx context.Context) error {
152	return s.srv.Shutdown(ctx)
153}
154
155func initializePermissions(ctx ssh.Context) {
156	perms := ctx.Permissions()
157	if perms == nil || perms.Permissions == nil {
158		perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
159	}
160	if perms.Extensions == nil {
161		perms.Extensions = make(map[string]string)
162	}
163	if perms.Permissions.Extensions == nil {
164		perms.Permissions.Extensions = make(map[string]string)
165	}
166}
167
168// PublicKeyAuthHandler handles public key authentication.
169func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
170	if pk == nil {
171		return false
172	}
173
174	defer func(allowed *bool) {
175		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
176	}(&allowed)
177
178	user, _ := s.be.UserByPublicKey(ctx, pk)
179	if user != nil {
180		ctx.SetValue(proto.ContextKeyUser, user)
181		allowed = true
182
183		// XXX: store the first "approved" public-key fingerprint in the
184		// permissions block to use for authentication later.
185		initializePermissions(ctx)
186		perms := ctx.Permissions()
187
188		// Set the public key fingerprint to be used for authentication.
189		perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
190		ctx.SetValue(ssh.ContextKeyPermissions, perms)
191	}
192
193	return
194}
195
196// KeyboardInteractiveHandler handles keyboard interactive authentication.
197// This is used after all public key authentication has failed.
198func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
199	ac := s.be.AllowKeyless(ctx)
200	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
201
202	// If we're allowing keyless access, reset the public key fingerprint
203	if ac {
204		initializePermissions(ctx)
205		perms := ctx.Permissions()
206
207		// XXX: reset the public-key fingerprint. This is used to validate the
208		// public key being used to authenticate.
209		perms.Extensions["pubkey-fp"] = ""
210		ctx.SetValue(ssh.ContextKeyPermissions, perms)
211	}
212	return ac
213}