ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"path/filepath"
  8	"strconv"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	cm "github.com/charmbracelet/soft-serve/server/cmd"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/git"
 17	"github.com/charmbracelet/soft-serve/server/utils"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	bm "github.com/charmbracelet/wish/bubbletea"
 21	lm "github.com/charmbracelet/wish/logging"
 22	rm "github.com/charmbracelet/wish/recover"
 23	"github.com/muesli/termenv"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
 29var (
 30	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 31		Namespace: "soft_serve",
 32		Subsystem: "ssh",
 33		Name:      "public_key_auth_total",
 34		Help:      "The total number of public key auth requests",
 35	}, []string{"key", "user", "allowed"})
 36
 37	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 38		Namespace: "soft_serve",
 39		Subsystem: "ssh",
 40		Name:      "keyboard_interactive_auth_total",
 41		Help:      "The total number of keyboard interactive auth requests",
 42	}, []string{"user", "allowed"})
 43
 44	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 45		Namespace: "soft_serve",
 46		Subsystem: "ssh",
 47		Name:      "git_upload_pack_total",
 48		Help:      "The total number of git-upload-pack requests",
 49	}, []string{"key", "user", "repo"})
 50
 51	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 52		Namespace: "soft_serve",
 53		Subsystem: "ssh",
 54		Name:      "git_receive_pack_total",
 55		Help:      "The total number of git-receive-pack requests",
 56	}, []string{"key", "user", "repo"})
 57
 58	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 59		Namespace: "soft_serve",
 60		Subsystem: "ssh",
 61		Name:      "git_upload_archive_total",
 62		Help:      "The total number of git-upload-archive requests",
 63	}, []string{"key", "user", "repo"})
 64
 65	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 66		Namespace: "soft_serve",
 67		Subsystem: "ssh",
 68		Name:      "create_repo_total",
 69		Help:      "The total number of create repo requests",
 70	}, []string{"key", "user", "repo"})
 71)
 72
 73// SSHServer is a SSH server that implements the git protocol.
 74type SSHServer struct {
 75	srv    *ssh.Server
 76	cfg    *config.Config
 77	ctx    context.Context
 78	logger *log.Logger
 79}
 80
 81// NewSSHServer returns a new SSHServer.
 82func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 83	cfg := config.FromContext(ctx)
 84	var err error
 85	s := &SSHServer{
 86		cfg:    cfg,
 87		ctx:    ctx,
 88		logger: log.FromContext(ctx).WithPrefix("ssh"),
 89	}
 90	logger := s.logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 91	mw := []wish.Middleware{
 92		rm.MiddlewareWithLogger(
 93			logger,
 94			// BubbleTea middleware.
 95			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 96			// CLI middleware.
 97			cm.Middleware(cfg),
 98			// Git middleware.
 99			s.Middleware(cfg),
100			// Logging middleware.
101			lm.MiddlewareWithLogger(logger),
102		),
103	}
104	s.srv, err = wish.NewServer(
105		ssh.PublicKeyAuth(s.PublicKeyHandler),
106		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
107		wish.WithAddress(cfg.SSH.ListenAddr),
108		wish.WithHostKeyPath(cfg.SSH.KeyPath),
109		wish.WithMiddleware(mw...),
110	)
111	if err != nil {
112		return nil, err
113	}
114
115	if cfg.SSH.MaxTimeout > 0 {
116		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
117	}
118	if cfg.SSH.IdleTimeout > 0 {
119		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
120	}
121
122	return s, nil
123}
124
125// ListenAndServe starts the SSH server.
126func (s *SSHServer) ListenAndServe() error {
127	return s.srv.ListenAndServe()
128}
129
130// Serve starts the SSH server on the given net.Listener.
131func (s *SSHServer) Serve(l net.Listener) error {
132	return s.srv.Serve(l)
133}
134
135// Close closes the SSH server.
136func (s *SSHServer) Close() error {
137	return s.srv.Close()
138}
139
140// Shutdown gracefully shuts down the SSH server.
141func (s *SSHServer) Shutdown(ctx context.Context) error {
142	return s.srv.Shutdown(ctx)
143}
144
145// PublicKeyAuthHandler handles public key authentication.
146func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
147	if pk == nil {
148		return s.cfg.Backend.AllowKeyless()
149	}
150
151	ak := backend.MarshalAuthorizedKey(pk)
152	defer func(allowed *bool) {
153		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
154	}(&allowed)
155
156	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
157	s.logger.Debugf("access level for %q: %s", ak, ac)
158	allowed = ac >= backend.ReadOnlyAccess
159	return
160}
161
162// KeyboardInteractiveHandler handles keyboard interactive authentication.
163func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
164	ac := s.cfg.Backend.AllowKeyless()
165	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
166	return ac
167}
168
169// Middleware adds Git server functionality to the ssh.Server. Repos are stored
170// in the specified repo directory. The provided Hooks implementation will be
171// checked for access on a per repo basis for a ssh.Session public key.
172// Hooks.Push and Hooks.Fetch will be called on successful completion of
173// their commands.
174func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
175	return func(sh ssh.Handler) ssh.Handler {
176		return func(s ssh.Session) {
177			func() {
178				cmd := s.Command()
179				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
180					gc := cmd[0]
181					// repo should be in the form of "repo.git"
182					name := utils.SanitizeRepo(cmd[1])
183					pk := s.PublicKey()
184					ak := backend.MarshalAuthorizedKey(pk)
185					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
186					// git bare repositories should end in ".git"
187					// https://git-scm.com/docs/gitrepository-layout
188					repo := name + ".git"
189					reposDir := filepath.Join(cfg.DataPath, "repos")
190					if err := git.EnsureWithin(reposDir, repo); err != nil {
191						sshFatal(s, err)
192						return
193					}
194
195					// Environment variables to pass down to git hooks.
196					envs := []string{
197						"SOFT_SERVE_REPO_NAME=" + name,
198						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
199						"SOFT_SERVE_PUBLIC_KEY=" + ak,
200					}
201
202					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
203					repoDir := filepath.Join(reposDir, repo)
204					switch gc {
205					case git.ReceivePackBin:
206						if access < backend.ReadWriteAccess {
207							sshFatal(s, git.ErrNotAuthed)
208							return
209						}
210						if _, err := cfg.Backend.Repository(name); err != nil {
211							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
212								log.Errorf("failed to create repo: %s", err)
213								sshFatal(s, err)
214								return
215							}
216							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
217						}
218						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
219							sshFatal(s, git.ErrSystemMalfunction)
220						}
221						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
222						return
223					case git.UploadPackBin, git.UploadArchiveBin:
224						if access < backend.ReadOnlyAccess {
225							sshFatal(s, git.ErrNotAuthed)
226							return
227						}
228
229						gitPack := git.UploadPack
230						counter := uploadPackCounter
231						if gc == git.UploadArchiveBin {
232							gitPack = git.UploadArchive
233							counter = uploadArchiveCounter
234						}
235
236						err := gitPack(s.Context(), s, s, s.Stderr(), repoDir, envs...)
237						if errors.Is(err, git.ErrInvalidRepo) {
238							sshFatal(s, git.ErrInvalidRepo)
239						} else if err != nil {
240							sshFatal(s, git.ErrSystemMalfunction)
241						}
242
243						counter.WithLabelValues(ak, s.User(), name).Inc()
244					}
245				}
246			}()
247			sh(s)
248		}
249	}
250}
251
252// sshFatal prints to the session's STDOUT as a git response and exit 1.
253func sshFatal(s ssh.Session, v ...interface{}) {
254	git.WritePktline(s, v...)
255	s.Exit(1) // nolint: errcheck
256}