ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"strconv"
  9	"time"
 10
 11	"github.com/charmbracelet/keygen"
 12	log "github.com/charmbracelet/log/v2"
 13	"github.com/charmbracelet/soft-serve/pkg/backend"
 14	"github.com/charmbracelet/soft-serve/pkg/config"
 15	"github.com/charmbracelet/soft-serve/pkg/db"
 16	"github.com/charmbracelet/soft-serve/pkg/proto"
 17	"github.com/charmbracelet/soft-serve/pkg/store"
 18	"github.com/charmbracelet/ssh"
 19		wish "github.com/charmbracelet/wish/v2"
 20	bm "github.com/charmbracelet/wish/v2/bubbletea"
 21	rm "github.com/charmbracelet/wish/v2/recover"
 22	"github.com/prometheus/client_golang/prometheus"
 23	"github.com/prometheus/client_golang/prometheus/promauto"
 24	gossh "golang.org/x/crypto/ssh"
 25)
 26
 27var (
 28	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 29		Namespace: "soft_serve",
 30		Subsystem: "ssh",
 31		Name:      "public_key_auth_total",
 32		Help:      "The total number of public key auth requests",
 33	}, []string{"allowed"})
 34
 35	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 36		Namespace: "soft_serve",
 37		Subsystem: "ssh",
 38		Name:      "keyboard_interactive_auth_total",
 39		Help:      "The total number of keyboard interactive auth requests",
 40	}, []string{"allowed"})
 41)
 42
 43// SSHServer is a SSH server that implements the git protocol.
 44type SSHServer struct { //nolint:revive
 45	srv    *ssh.Server
 46	cfg    *config.Config
 47	be     *backend.Backend
 48	ctx    context.Context
 49	logger *log.Logger
 50}
 51
 52// NewSSHServer returns a new SSHServer.
 53func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 54	cfg := config.FromContext(ctx)
 55	logger := log.FromContext(ctx).WithPrefix("ssh")
 56	dbx := db.FromContext(ctx)
 57	datastore := store.FromContext(ctx)
 58	be := backend.FromContext(ctx)
 59
 60	var err error
 61	s := &SSHServer{
 62		cfg:    cfg,
 63		ctx:    ctx,
 64		be:     be,
 65		logger: logger,
 66	}
 67
 68	mw := []wish.Middleware{
 69		rm.MiddlewareWithLogger(
 70			logger,
 71			// BubbleTea middleware.
 72			bm.MiddlewareWithProgramHandler(SessionHandler),
 73			// CLI middleware.
 74			CommandMiddleware,
 75			// Logging middleware.
 76			LoggingMiddleware,
 77			// Context middleware.
 78			ContextMiddleware(cfg, dbx, datastore, be, logger),
 79			// Authentication middleware.
 80			// gossh.PublicKeyHandler doesn't guarantee that the public key
 81			// is in fact the one used for authentication, so we need to
 82			// check it again here.
 83			AuthenticationMiddleware,
 84		),
 85	}
 86
 87	opts := []ssh.Option{
 88		ssh.PublicKeyAuth(s.PublicKeyHandler),
 89		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
 90		wish.WithAddress(cfg.SSH.ListenAddr),
 91		wish.WithHostKeyPath(cfg.SSH.KeyPath),
 92		wish.WithMiddleware(mw...),
 93	}
 94
 95	// TODO: Support a real PTY in future version.
 96	opts = append(opts, ssh.EmulatePty())
 97
 98	s.srv, err = wish.NewServer(opts...)
 99	if err != nil {
100		return nil, err
101	}
102
103	if config.IsDebug() {
104		s.srv.ServerConfigCallback = func(_ ssh.Context) *gossh.ServerConfig {
105			return &gossh.ServerConfig{
106				AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
107					logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
108				},
109			}
110		}
111	}
112
113	if cfg.SSH.MaxTimeout > 0 {
114		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
115	}
116
117	if cfg.SSH.IdleTimeout > 0 {
118		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
119	}
120
121	// Create client ssh key
122	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
123		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
124		if err != nil {
125			return nil, fmt.Errorf("client ssh key: %w", err)
126		}
127	}
128
129	return s, nil
130}
131
132// ListenAndServe starts the SSH server.
133func (s *SSHServer) ListenAndServe() error {
134	return s.srv.ListenAndServe()
135}
136
137// Serve starts the SSH server on the given net.Listener.
138func (s *SSHServer) Serve(l net.Listener) error {
139	return s.srv.Serve(l)
140}
141
142// Close closes the SSH server.
143func (s *SSHServer) Close() error {
144	return s.srv.Close()
145}
146
147// Shutdown gracefully shuts down the SSH server.
148func (s *SSHServer) Shutdown(ctx context.Context) error {
149	return s.srv.Shutdown(ctx)
150}
151
152func initializePermissions(ctx ssh.Context) {
153	perms := ctx.Permissions()
154	if perms == nil || perms.Permissions == nil {
155		perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
156	}
157	if perms.Extensions == nil {
158		perms.Extensions = make(map[string]string)
159	}
160	if perms.Permissions.Extensions == nil {
161		perms.Permissions.Extensions = make(map[string]string)
162	}
163}
164
165// PublicKeyHandler handles public key authentication.
166func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
167	if pk == nil {
168		return false
169	}
170
171	allowed = true
172	defer func(allowed *bool) {
173		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
174	}(&allowed)
175
176	user, _ := s.be.UserByPublicKey(ctx, pk)
177	if user != nil {
178		ctx.SetValue(proto.ContextKeyUser, user)
179	}
180
181	// XXX: store the first "approved" public-key fingerprint in the
182	// permissions block to use for authentication later.
183	initializePermissions(ctx)
184	perms := ctx.Permissions()
185
186	// Set the public key fingerprint to be used for authentication.
187	perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
188	ctx.SetValue(ssh.ContextKeyPermissions, perms)
189
190	return
191}
192
193// KeyboardInteractiveHandler handles keyboard interactive authentication.
194// This is used after all public key authentication has failed.
195func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
196	ac := s.be.AllowKeyless(ctx)
197	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
198
199	// If we're allowing keyless access, reset the public key fingerprint
200	if ac {
201		initializePermissions(ctx)
202		perms := ctx.Permissions()
203
204		// XXX: reset the public-key fingerprint. This is used to validate the
205		// public key being used to authenticate.
206		perms.Extensions["pubkey-fp"] = ""
207		ctx.SetValue(ssh.ContextKeyPermissions, perms)
208	}
209	return ac
210}