ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"strconv"
  9	"time"
 10
 11	"github.com/charmbracelet/keygen"
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/access"
 14	"github.com/charmbracelet/soft-serve/server/backend"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/git"
 17	"github.com/charmbracelet/soft-serve/server/sshutils"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	bm "github.com/charmbracelet/wish/bubbletea"
 21	lm "github.com/charmbracelet/wish/logging"
 22	rm "github.com/charmbracelet/wish/recover"
 23	"github.com/muesli/termenv"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
 29var (
 30	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 31		Namespace: "soft_serve",
 32		Subsystem: "ssh",
 33		Name:      "public_key_auth_total",
 34		Help:      "The total number of public key auth requests",
 35	}, []string{"allowed"})
 36
 37	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 38		Namespace: "soft_serve",
 39		Subsystem: "ssh",
 40		Name:      "keyboard_interactive_auth_total",
 41		Help:      "The total number of keyboard interactive auth requests",
 42	}, []string{"allowed"})
 43
 44	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 45		Namespace: "soft_serve",
 46		Subsystem: "git",
 47		Name:      "upload_pack_total",
 48		Help:      "The total number of git-upload-pack requests",
 49	}, []string{"repo"})
 50
 51	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 52		Namespace: "soft_serve",
 53		Subsystem: "git",
 54		Name:      "receive_pack_total",
 55		Help:      "The total number of git-receive-pack requests",
 56	}, []string{"repo"})
 57
 58	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 59		Namespace: "soft_serve",
 60		Subsystem: "git",
 61		Name:      "upload_archive_total",
 62		Help:      "The total number of git-upload-archive requests",
 63	}, []string{"repo"})
 64
 65	uploadPackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 66		Namespace: "soft_serve",
 67		Subsystem: "git",
 68		Name:      "upload_pack_seconds_total",
 69		Help:      "The total time spent on git-upload-pack requests",
 70	}, []string{"repo"})
 71
 72	receivePackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 73		Namespace: "soft_serve",
 74		Subsystem: "git",
 75		Name:      "receive_pack_seconds_total",
 76		Help:      "The total time spent on git-receive-pack requests",
 77	}, []string{"repo"})
 78
 79	uploadArchiveSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 80		Namespace: "soft_serve",
 81		Subsystem: "git",
 82		Name:      "upload_archive_seconds_total",
 83		Help:      "The total time spent on git-upload-archive requests",
 84	}, []string{"repo"})
 85
 86	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 87		Namespace: "soft_serve",
 88		Subsystem: "ssh",
 89		Name:      "create_repo_total",
 90		Help:      "The total number of create repo requests",
 91	}, []string{"repo"})
 92)
 93
// SSHServer is an SSH server that implements the git protocol.
//
// Instances are created by NewSSHServer, which fills every field.
type SSHServer struct { // nolint: revive
	// srv is the underlying server built by wish.NewServer; the lifecycle
	// methods on SSHServer delegate to it.
	srv *ssh.Server
	// cfg is the configuration obtained from the constructor's context.
	cfg *config.Config
	// be is the backend consulted by the authentication handlers.
	be *backend.Backend
	// ctx is the context NewSSHServer was called with.
	ctx context.Context
	// logger is the context logger with the "ssh" prefix applied.
	logger *log.Logger
}
102
103// NewSSHServer returns a new SSHServer.
104func NewSSHServer(ctx context.Context) (*SSHServer, error) {
105	cfg := config.FromContext(ctx)
106	logger := log.FromContext(ctx).WithPrefix("ssh")
107	be := backend.FromContext(ctx)
108
109	var err error
110	s := &SSHServer{
111		cfg:    cfg,
112		ctx:    ctx,
113		be:     be,
114		logger: logger,
115	}
116
117	mw := []wish.Middleware{
118		rm.MiddlewareWithLogger(
119			logger,
120			// BubbleTea middleware.
121			bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
122			// CLI middleware.
123			CommandMiddleware,
124			// Context middleware.
125			ContextMiddleware(cfg, be, logger),
126			// Logging middleware.
127			lm.MiddlewareWithLogger(
128				&loggerAdapter{logger, log.DebugLevel},
129			),
130		),
131	}
132
133	s.srv, err = wish.NewServer(
134		ssh.PublicKeyAuth(s.PublicKeyHandler),
135		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
136		wish.WithAddress(cfg.SSH.ListenAddr),
137		wish.WithHostKeyPath(cfg.SSH.KeyPath),
138		wish.WithMiddleware(mw...),
139	)
140	if err != nil {
141		return nil, err
142	}
143
144	if cfg.SSH.MaxTimeout > 0 {
145		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
146	}
147
148	if cfg.SSH.IdleTimeout > 0 {
149		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
150	}
151
152	// Create client ssh key
153	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
154		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
155		if err != nil {
156			return nil, fmt.Errorf("client ssh key: %w", err)
157		}
158	}
159
160	return s, nil
161}
162
// ListenAndServe starts the SSH server.
//
// It delegates to the underlying server, which NewSSHServer configures with
// the cfg.SSH.ListenAddr address, and returns whatever error that call
// reports.
func (s *SSHServer) ListenAndServe() error {
	return s.srv.ListenAndServe()
}
167
// Serve starts the SSH server on the given net.Listener.
//
// The listener l is handed unchanged to the underlying server's Serve
// method, and that call's error is returned.
func (s *SSHServer) Serve(l net.Listener) error {
	return s.srv.Serve(l)
}
172
// Close closes the SSH server.
//
// It delegates to the underlying server's Close method and returns its
// error.
func (s *SSHServer) Close() error {
	return s.srv.Close()
}
177
// Shutdown gracefully shuts down the SSH server.
//
// The context ctx is passed through to the underlying server's Shutdown
// method, and that call's error is returned.
func (s *SSHServer) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
182
183// PublicKeyAuthHandler handles public key authentication.
184func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
185	if pk == nil {
186		return false
187	}
188
189	ak := sshutils.MarshalAuthorizedKey(pk)
190	defer func(allowed *bool) {
191		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
192	}(&allowed)
193
194	ac := s.be.AccessLevelByPublicKey(ctx, "", pk)
195	s.logger.Debugf("access level for %q: %s", ak, ac)
196	allowed = ac >= access.ReadWriteAccess
197	return
198}
199
200// KeyboardInteractiveHandler handles keyboard interactive authentication.
201// This is used after all public key authentication has failed.
202func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
203	ac := s.be.AllowKeyless(ctx)
204	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
205	return ac
206}
207
208// sshFatal prints to the session's STDOUT as a git response and exit 1.
209func sshFatal(s ssh.Session, err error) {
210	git.WritePktlineErr(s, err) // nolint: errcheck
211	s.Exit(1)                   // nolint: errcheck
212}