ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"fmt"
  6	"net"
  7	"os"
  8	"strconv"
  9	"time"
 10
 11	"github.com/charmbracelet/keygen"
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/access"
 14	"github.com/charmbracelet/soft-serve/server/backend"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/db"
 17	"github.com/charmbracelet/soft-serve/server/git"
 18	"github.com/charmbracelet/soft-serve/server/proto"
 19	"github.com/charmbracelet/soft-serve/server/sshutils"
 20	"github.com/charmbracelet/soft-serve/server/store"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 34		Namespace: "soft_serve",
 35		Subsystem: "ssh",
 36		Name:      "public_key_auth_total",
 37		Help:      "The total number of public key auth requests",
 38	}, []string{"allowed"})
 39
 40	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 41		Namespace: "soft_serve",
 42		Subsystem: "ssh",
 43		Name:      "keyboard_interactive_auth_total",
 44		Help:      "The total number of keyboard interactive auth requests",
 45	}, []string{"allowed"})
 46
 47	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 48		Namespace: "soft_serve",
 49		Subsystem: "git",
 50		Name:      "upload_pack_total",
 51		Help:      "The total number of git-upload-pack requests",
 52	}, []string{"repo"})
 53
 54	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 55		Namespace: "soft_serve",
 56		Subsystem: "git",
 57		Name:      "receive_pack_total",
 58		Help:      "The total number of git-receive-pack requests",
 59	}, []string{"repo"})
 60
 61	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 62		Namespace: "soft_serve",
 63		Subsystem: "git",
 64		Name:      "upload_archive_total",
 65		Help:      "The total number of git-upload-archive requests",
 66	}, []string{"repo"})
 67
 68	uploadPackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 69		Namespace: "soft_serve",
 70		Subsystem: "git",
 71		Name:      "upload_pack_seconds_total",
 72		Help:      "The total time spent on git-upload-pack requests",
 73	}, []string{"repo"})
 74
 75	receivePackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 76		Namespace: "soft_serve",
 77		Subsystem: "git",
 78		Name:      "receive_pack_seconds_total",
 79		Help:      "The total time spent on git-receive-pack requests",
 80	}, []string{"repo"})
 81
 82	uploadArchiveSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 83		Namespace: "soft_serve",
 84		Subsystem: "git",
 85		Name:      "upload_archive_seconds_total",
 86		Help:      "The total time spent on git-upload-archive requests",
 87	}, []string{"repo"})
 88
 89	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 90		Namespace: "soft_serve",
 91		Subsystem: "ssh",
 92		Name:      "create_repo_total",
 93		Help:      "The total number of create repo requests",
 94	}, []string{"repo"})
 95)
 96
// SSHServer is an SSH server that implements the git protocol.
//
// It bundles the underlying ssh.Server with the configuration, backend,
// context and logger it was constructed with. Use NewSSHServer to create one.
type SSHServer struct { // nolint: revive
	// srv is the underlying server; ListenAndServe, Serve, Close and
	// Shutdown delegate to it.
	srv    *ssh.Server
	// cfg is the configuration the server was built from.
	cfg    *config.Config
	// be is the backend consulted by the authentication handlers.
	be     *backend.Backend
	// ctx is the context passed to NewSSHServer.
	ctx    context.Context
	// logger is the "ssh"-prefixed logger used by the handlers.
	logger *log.Logger
}
105
106// NewSSHServer returns a new SSHServer.
107func NewSSHServer(ctx context.Context) (*SSHServer, error) {
108	cfg := config.FromContext(ctx)
109	logger := log.FromContext(ctx).WithPrefix("ssh")
110	dbx := db.FromContext(ctx)
111	datastore := store.FromContext(ctx)
112	be := backend.FromContext(ctx)
113
114	var err error
115	s := &SSHServer{
116		cfg:    cfg,
117		ctx:    ctx,
118		be:     be,
119		logger: logger,
120	}
121
122	mw := []wish.Middleware{
123		rm.MiddlewareWithLogger(
124			logger,
125			// BubbleTea middleware.
126			bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
127			// CLI middleware.
128			CommandMiddleware,
129			// Context middleware.
130			ContextMiddleware(cfg, dbx, datastore, be, logger),
131			// Logging middleware.
132			lm.MiddlewareWithLogger(
133				&loggerAdapter{logger, log.DebugLevel},
134			),
135		),
136	}
137
138	s.srv, err = wish.NewServer(
139		ssh.PublicKeyAuth(s.PublicKeyHandler),
140		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
141		wish.WithAddress(cfg.SSH.ListenAddr),
142		wish.WithHostKeyPath(cfg.SSH.KeyPath),
143		wish.WithMiddleware(mw...),
144	)
145	if err != nil {
146		return nil, err
147	}
148
149	if cfg.SSH.MaxTimeout > 0 {
150		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
151	}
152
153	if cfg.SSH.IdleTimeout > 0 {
154		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
155	}
156
157	// Create client ssh key
158	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
159		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
160		if err != nil {
161			return nil, fmt.Errorf("client ssh key: %w", err)
162		}
163	}
164
165	return s, nil
166}
167
// ListenAndServe starts the SSH server.
//
// It delegates to the underlying server's ListenAndServe, whose listen
// address was set from cfg.SSH.ListenAddr in NewSSHServer, and returns the
// error reported by that call.
func (s *SSHServer) ListenAndServe() error {
	return s.srv.ListenAndServe()
}
172
// Serve starts the SSH server on the given net.Listener.
//
// The listener l is handed to the underlying server's Serve as is, and the
// error reported by that call is returned.
func (s *SSHServer) Serve(l net.Listener) error {
	return s.srv.Serve(l)
}
177
// Close closes the SSH server.
//
// It delegates to the underlying server's Close and returns its error.
func (s *SSHServer) Close() error {
	return s.srv.Close()
}
182
// Shutdown gracefully shuts down the SSH server.
//
// ctx is passed straight to the underlying server's Shutdown, and that call's
// error is returned. How ctx bounds the shutdown is decided by the underlying
// server, not by this wrapper.
func (s *SSHServer) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
187
188// PublicKeyAuthHandler handles public key authentication.
189func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
190	if pk == nil {
191		return false
192	}
193
194	ak := sshutils.MarshalAuthorizedKey(pk)
195	defer func(allowed *bool) {
196		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
197	}(&allowed)
198
199	user, _ := s.be.UserByPublicKey(ctx, pk)
200	ctx.SetValue(proto.ContextKeyUser, user)
201
202	ac := s.be.AccessLevelForUser(ctx, "", user)
203	s.logger.Debugf("access level for %q: %s", ak, ac)
204	allowed = ac >= access.ReadWriteAccess
205	return
206}
207
208// KeyboardInteractiveHandler handles keyboard interactive authentication.
209// This is used after all public key authentication has failed.
210func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
211	ac := s.be.AllowKeyless(ctx)
212	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
213	return ac
214}
215
// sshFatal prints err to the session's STDOUT as a git response and exits
// the session with status 1.
//
// The results of both calls are deliberately not checked (see the nolint
// markers), so nothing is reported back to the caller.
func sshFatal(s ssh.Session, err error) {
	git.WritePktlineErr(s, err) // nolint: errcheck
	s.Exit(1)                   // nolint: errcheck
}