1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"strconv"
  9	"time"
 10
 11	"github.com/charmbracelet/keygen"
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	"github.com/charmbracelet/soft-serve/server/config"
 15	"github.com/charmbracelet/soft-serve/server/db"
 16	"github.com/charmbracelet/soft-serve/server/git"
 17	"github.com/charmbracelet/soft-serve/server/proto"
 18	"github.com/charmbracelet/soft-serve/server/store"
 19	"github.com/charmbracelet/ssh"
 20	"github.com/charmbracelet/wish"
 21	bm "github.com/charmbracelet/wish/bubbletea"
 22	lm "github.com/charmbracelet/wish/logging"
 23	rm "github.com/charmbracelet/wish/recover"
 24	"github.com/muesli/termenv"
 25	"github.com/prometheus/client_golang/prometheus"
 26	"github.com/prometheus/client_golang/prometheus/promauto"
 27	gossh "golang.org/x/crypto/ssh"
 28)
 29
 30var (
 31	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 32		Namespace: "soft_serve",
 33		Subsystem: "ssh",
 34		Name:      "public_key_auth_total",
 35		Help:      "The total number of public key auth requests",
 36	}, []string{"allowed"})
 37
 38	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 39		Namespace: "soft_serve",
 40		Subsystem: "ssh",
 41		Name:      "keyboard_interactive_auth_total",
 42		Help:      "The total number of keyboard interactive auth requests",
 43	}, []string{"allowed"})
 44
 45	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 46		Namespace: "soft_serve",
 47		Subsystem: "git",
 48		Name:      "upload_pack_total",
 49		Help:      "The total number of git-upload-pack requests",
 50	}, []string{"repo"})
 51
 52	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 53		Namespace: "soft_serve",
 54		Subsystem: "git",
 55		Name:      "receive_pack_total",
 56		Help:      "The total number of git-receive-pack requests",
 57	}, []string{"repo"})
 58
 59	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 60		Namespace: "soft_serve",
 61		Subsystem: "git",
 62		Name:      "upload_archive_total",
 63		Help:      "The total number of git-upload-archive requests",
 64	}, []string{"repo"})
 65
 66	uploadPackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 67		Namespace: "soft_serve",
 68		Subsystem: "git",
 69		Name:      "upload_pack_seconds_total",
 70		Help:      "The total time spent on git-upload-pack requests",
 71	}, []string{"repo"})
 72
 73	receivePackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 74		Namespace: "soft_serve",
 75		Subsystem: "git",
 76		Name:      "receive_pack_seconds_total",
 77		Help:      "The total time spent on git-receive-pack requests",
 78	}, []string{"repo"})
 79
 80	uploadArchiveSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 81		Namespace: "soft_serve",
 82		Subsystem: "git",
 83		Name:      "upload_archive_seconds_total",
 84		Help:      "The total time spent on git-upload-archive requests",
 85	}, []string{"repo"})
 86
 87	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 88		Namespace: "soft_serve",
 89		Subsystem: "ssh",
 90		Name:      "create_repo_total",
 91		Help:      "The total number of create repo requests",
 92	}, []string{"repo"})
 93)
 94
// SSHServer is a SSH server that implements the git protocol.
//
// It bundles the underlying ssh.Server with the configuration, backend,
// context and logger that NewSSHServer obtained from its context argument.
type SSHServer struct { // nolint: revive
	srv    *ssh.Server      // underlying SSH server built by wish.NewServer
	cfg    *config.Config   // configuration read from the constructor's context
	be     *backend.Backend // backend consulted by the auth handlers
	ctx    context.Context  // context that was passed to NewSSHServer
	logger *log.Logger      // logger carrying the "ssh" prefix
}
103
104// NewSSHServer returns a new SSHServer.
105func NewSSHServer(ctx context.Context) (*SSHServer, error) {
106	cfg := config.FromContext(ctx)
107	logger := log.FromContext(ctx).WithPrefix("ssh")
108	dbx := db.FromContext(ctx)
109	datastore := store.FromContext(ctx)
110	be := backend.FromContext(ctx)
111
112	var err error
113	s := &SSHServer{
114		cfg:    cfg,
115		ctx:    ctx,
116		be:     be,
117		logger: logger,
118	}
119
120	mw := []wish.Middleware{
121		rm.MiddlewareWithLogger(
122			logger,
123			// BubbleTea middleware.
124			bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
125			// CLI middleware.
126			CommandMiddleware,
127			// Context middleware.
128			ContextMiddleware(cfg, dbx, datastore, be, logger),
129			// Logging middleware.
130			lm.MiddlewareWithLogger(
131				&loggerAdapter{logger, log.DebugLevel},
132			),
133		),
134	}
135
136	s.srv, err = wish.NewServer(
137		ssh.PublicKeyAuth(s.PublicKeyHandler),
138		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
139		wish.WithAddress(cfg.SSH.ListenAddr),
140		wish.WithHostKeyPath(cfg.SSH.KeyPath),
141		wish.WithMiddleware(mw...),
142	)
143	if err != nil {
144		return nil, err
145	}
146
147	if cfg.SSH.MaxTimeout > 0 {
148		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
149	}
150
151	if cfg.SSH.IdleTimeout > 0 {
152		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
153	}
154
155	// Create client ssh key
156	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
157		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
158		if err != nil {
159			return nil, fmt.Errorf("client ssh key: %w", err)
160		}
161	}
162
163	return s, nil
164}
165
// ListenAndServe starts the SSH server.
//
// It delegates to the underlying server's ListenAndServe, which was
// configured with cfg.SSH.ListenAddr in NewSSHServer, and returns its error
// unchanged.
func (s *SSHServer) ListenAndServe() error {
	return s.srv.ListenAndServe()
}
170
// Serve starts the SSH server on the given net.Listener.
//
// It delegates to the underlying server's Serve and returns its error
// unchanged.
func (s *SSHServer) Serve(l net.Listener) error {
	return s.srv.Serve(l)
}
175
// Close closes the SSH server.
//
// It delegates to the underlying server's Close and returns its error
// unchanged.
func (s *SSHServer) Close() error {
	return s.srv.Close()
}
180
// Shutdown gracefully shuts down the SSH server.
//
// It delegates to the underlying server's Shutdown, forwarding ctx, and
// returns its error unchanged.
func (s *SSHServer) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
185
186// PublicKeyAuthHandler handles public key authentication.
187func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
188	if pk == nil {
189		return false
190	}
191
192	defer func(allowed *bool) {
193		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
194	}(&allowed)
195
196	user, _ := s.be.UserByPublicKey(ctx, pk)
197	if user != nil {
198		ctx.SetValue(proto.ContextKeyUser, user)
199		allowed = true
200	}
201
202	return
203}
204
205// KeyboardInteractiveHandler handles keyboard interactive authentication.
206// This is used after all public key authentication has failed.
207func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
208	ac := s.be.AllowKeyless(ctx)
209	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
210	return ac
211}
212
// sshFatal writes err to the session as a git pktline error response and then
// exits the session with status 1. Failures of either step are deliberately
// ignored.
func sshFatal(s ssh.Session, err error) {
	git.WritePktlineErr(s, err) // nolint: errcheck
	s.Exit(1)                   // nolint: errcheck
}