1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"runtime"
  9	"strconv"
 10	"time"
 11
 12	"github.com/charmbracelet/keygen"
 13	"github.com/charmbracelet/log/v2"
 14	"github.com/charmbracelet/soft-serve/pkg/backend"
 15	"github.com/charmbracelet/soft-serve/pkg/config"
 16	"github.com/charmbracelet/soft-serve/pkg/db"
 17	"github.com/charmbracelet/soft-serve/pkg/proto"
 18	"github.com/charmbracelet/soft-serve/pkg/store"
 19	"github.com/charmbracelet/ssh"
 20	"github.com/charmbracelet/wish/v2"
 21	bm "github.com/charmbracelet/wish/v2/bubbletea"
 22	rm "github.com/charmbracelet/wish/v2/recover"
 23	"github.com/prometheus/client_golang/prometheus"
 24	"github.com/prometheus/client_golang/prometheus/promauto"
 25	gossh "golang.org/x/crypto/ssh"
 26)
 27
 28var (
 29	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 30		Namespace: "soft_serve",
 31		Subsystem: "ssh",
 32		Name:      "public_key_auth_total",
 33		Help:      "The total number of public key auth requests",
 34	}, []string{"allowed"})
 35
 36	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 37		Namespace: "soft_serve",
 38		Subsystem: "ssh",
 39		Name:      "keyboard_interactive_auth_total",
 40		Help:      "The total number of keyboard interactive auth requests",
 41	}, []string{"allowed"})
 42)
 43
// SSHServer is a SSH server that implements the git protocol.
//
// Construct it with NewSSHServer; the zero value has a nil srv and is not
// usable.
type SSHServer struct { // nolint: revive
	// srv is the underlying server created by wish.NewServer in NewSSHServer.
	srv    *ssh.Server
	// cfg is the server configuration taken from the context given to
	// NewSSHServer.
	cfg    *config.Config
	// be is the backend used by the auth handlers for user lookup by public
	// key and for keyless-access decisions.
	be     *backend.Backend
	// ctx is the context NewSSHServer was called with.
	ctx    context.Context
	// logger is the "ssh"-prefixed logger derived from that context.
	logger *log.Logger
}
 52
 53// NewSSHServer returns a new SSHServer.
 54func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 55	cfg := config.FromContext(ctx)
 56	logger := log.FromContext(ctx).WithPrefix("ssh")
 57	dbx := db.FromContext(ctx)
 58	datastore := store.FromContext(ctx)
 59	be := backend.FromContext(ctx)
 60
 61	var err error
 62	s := &SSHServer{
 63		cfg:    cfg,
 64		ctx:    ctx,
 65		be:     be,
 66		logger: logger,
 67	}
 68
 69	mw := []wish.Middleware{
 70		rm.MiddlewareWithLogger(
 71			logger,
 72			// BubbleTea middleware.
 73			bm.MiddlewareWithProgramHandler(SessionHandler),
 74			// CLI middleware.
 75			CommandMiddleware,
 76			// Logging middleware.
 77			LoggingMiddleware,
 78			// Context middleware.
 79			ContextMiddleware(cfg, dbx, datastore, be, logger),
 80			// Authentication middleware.
 81			// gossh.PublicKeyHandler doesn't guarantee that the public key
 82			// is in fact the one used for authentication, so we need to
 83			// check it again here.
 84			AuthenticationMiddleware,
 85		),
 86	}
 87
 88	opts := []ssh.Option{
 89		ssh.PublicKeyAuth(s.PublicKeyHandler),
 90		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
 91		wish.WithAddress(cfg.SSH.ListenAddr),
 92		wish.WithHostKeyPath(cfg.SSH.KeyPath),
 93		wish.WithMiddleware(mw...),
 94	}
 95	if runtime.GOOS == "windows" {
 96		opts = append(opts, ssh.EmulatePty())
 97	} else {
 98		opts = append(opts, ssh.AllocatePty())
 99	}
100	s.srv, err = wish.NewServer(opts...)
101	if err != nil {
102		return nil, err
103	}
104
105	if config.IsDebug() {
106		s.srv.ServerConfigCallback = func(_ ssh.Context) *gossh.ServerConfig {
107			return &gossh.ServerConfig{
108				AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
109					logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
110				},
111			}
112		}
113	}
114
115	if cfg.SSH.MaxTimeout > 0 {
116		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
117	}
118
119	if cfg.SSH.IdleTimeout > 0 {
120		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
121	}
122
123	// Create client ssh key
124	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
125		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
126		if err != nil {
127			return nil, fmt.Errorf("client ssh key: %w", err)
128		}
129	}
130
131	return s, nil
132}
133
134// ListenAndServe starts the SSH server.
135func (s *SSHServer) ListenAndServe() error {
136	return s.srv.ListenAndServe()
137}
138
139// Serve starts the SSH server on the given net.Listener.
140func (s *SSHServer) Serve(l net.Listener) error {
141	return s.srv.Serve(l)
142}
143
144// Close closes the SSH server.
145func (s *SSHServer) Close() error {
146	return s.srv.Close()
147}
148
149// Shutdown gracefully shuts down the SSH server.
150func (s *SSHServer) Shutdown(ctx context.Context) error {
151	return s.srv.Shutdown(ctx)
152}
153
154func initializePermissions(ctx ssh.Context) {
155	perms := ctx.Permissions()
156	if perms == nil || perms.Permissions == nil {
157		perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
158	}
159	if perms.Extensions == nil {
160		perms.Extensions = make(map[string]string)
161	}
162	if perms.Permissions.Extensions == nil {
163		perms.Permissions.Extensions = make(map[string]string)
164	}
165}
166
167// PublicKeyAuthHandler handles public key authentication.
168func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
169	if pk == nil {
170		return false
171	}
172
173	allowed = true
174	defer func(allowed *bool) {
175		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
176	}(&allowed)
177
178	user, _ := s.be.UserByPublicKey(ctx, pk)
179	if user != nil {
180		ctx.SetValue(proto.ContextKeyUser, user)
181	}
182
183	// XXX: store the first "approved" public-key fingerprint in the
184	// permissions block to use for authentication later.
185	initializePermissions(ctx)
186	perms := ctx.Permissions()
187
188	// Set the public key fingerprint to be used for authentication.
189	perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
190	ctx.SetValue(ssh.ContextKeyPermissions, perms)
191
192	return
193}
194
195// KeyboardInteractiveHandler handles keyboard interactive authentication.
196// This is used after all public key authentication has failed.
197func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
198	ac := s.be.AllowKeyless(ctx)
199	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
200
201	// If we're allowing keyless access, reset the public key fingerprint
202	if ac {
203		initializePermissions(ctx)
204		perms := ctx.Permissions()
205
206		// XXX: reset the public-key fingerprint. This is used to validate the
207		// public key being used to authenticate.
208		perms.Extensions["pubkey-fp"] = ""
209		ctx.SetValue(ssh.ContextKeyPermissions, perms)
210	}
211	return ac
212}