1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/keygen"
 15	"github.com/charmbracelet/log"
 16	"github.com/charmbracelet/soft-serve/server/backend"
 17	cm "github.com/charmbracelet/soft-serve/server/cmd"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/utils"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 34		Namespace: "soft_serve",
 35		Subsystem: "ssh",
 36		Name:      "public_key_auth_total",
 37		Help:      "The total number of public key auth requests",
 38	}, []string{"key", "user", "allowed"})
 39
 40	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 41		Namespace: "soft_serve",
 42		Subsystem: "ssh",
 43		Name:      "keyboard_interactive_auth_total",
 44		Help:      "The total number of keyboard interactive auth requests",
 45	}, []string{"user", "allowed"})
 46
 47	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 48		Namespace: "soft_serve",
 49		Subsystem: "ssh",
 50		Name:      "git_upload_pack_total",
 51		Help:      "The total number of git-upload-pack requests",
 52	}, []string{"key", "user", "repo"})
 53
 54	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 55		Namespace: "soft_serve",
 56		Subsystem: "ssh",
 57		Name:      "git_receive_pack_total",
 58		Help:      "The total number of git-receive-pack requests",
 59	}, []string{"key", "user", "repo"})
 60
 61	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 62		Namespace: "soft_serve",
 63		Subsystem: "ssh",
 64		Name:      "git_upload_archive_total",
 65		Help:      "The total number of git-upload-archive requests",
 66	}, []string{"key", "user", "repo"})
 67
 68	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 69		Namespace: "soft_serve",
 70		Subsystem: "ssh",
 71		Name:      "create_repo_total",
 72		Help:      "The total number of create repo requests",
 73	}, []string{"key", "user", "repo"})
 74)
 75
// SSHServer is an SSH server that implements the git protocol. NewSSHServer
// also wires in the soft-serve CLI and the BubbleTea TUI as middleware.
type SSHServer struct {
	// srv is the underlying SSH server created by wish.NewServer.
	srv    *ssh.Server
	// cfg is the server configuration taken from the context passed to
	// NewSSHServer.
	cfg    *config.Config
	// ctx is the context passed to NewSSHServer.
	// NOTE(review): Go guidance discourages storing a context in a struct, and
	// it is not read elsewhere in this file; consider removing it.
	ctx    context.Context
	// logger is the context logger with the "ssh" prefix.
	logger *log.Logger
}
 83
 84// NewSSHServer returns a new SSHServer.
 85func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 86	cfg := config.FromContext(ctx)
 87
 88	var err error
 89	s := &SSHServer{
 90		cfg:    cfg,
 91		ctx:    ctx,
 92		logger: log.FromContext(ctx).WithPrefix("ssh"),
 93	}
 94
 95	logger := s.logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 96	mw := []wish.Middleware{
 97		rm.MiddlewareWithLogger(
 98			logger,
 99			// BubbleTea middleware.
100			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
101			// CLI middleware.
102			cm.Middleware(cfg),
103			// Git middleware.
104			s.Middleware(cfg),
105			// Logging middleware.
106			lm.MiddlewareWithLogger(logger),
107		),
108	}
109
110	s.srv, err = wish.NewServer(
111		ssh.PublicKeyAuth(s.PublicKeyHandler),
112		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
113		wish.WithAddress(cfg.SSH.ListenAddr),
114		wish.WithHostKeyPath(cfg.SSH.KeyPath),
115		wish.WithMiddleware(mw...),
116	)
117	if err != nil {
118		return nil, err
119	}
120
121	if cfg.SSH.MaxTimeout > 0 {
122		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
123	}
124
125	if cfg.SSH.IdleTimeout > 0 {
126		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
127	}
128
129	// Create client ssh key
130	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
131		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
132		if err != nil {
133			return nil, fmt.Errorf("client ssh key: %w", err)
134		}
135	}
136
137	return s, nil
138}
139
140// ListenAndServe starts the SSH server.
141func (s *SSHServer) ListenAndServe() error {
142	return s.srv.ListenAndServe()
143}
144
145// Serve starts the SSH server on the given net.Listener.
146func (s *SSHServer) Serve(l net.Listener) error {
147	return s.srv.Serve(l)
148}
149
150// Close closes the SSH server.
151func (s *SSHServer) Close() error {
152	return s.srv.Close()
153}
154
155// Shutdown gracefully shuts down the SSH server.
156func (s *SSHServer) Shutdown(ctx context.Context) error {
157	return s.srv.Shutdown(ctx)
158}
159
160// PublicKeyAuthHandler handles public key authentication.
161func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
162	if pk == nil {
163		return s.cfg.Backend.AllowKeyless()
164	}
165
166	ak := backend.MarshalAuthorizedKey(pk)
167	defer func(allowed *bool) {
168		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
169	}(&allowed)
170
171	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
172	s.logger.Debugf("access level for %q: %s", ak, ac)
173	allowed = ac >= backend.ReadOnlyAccess
174	return
175}
176
177// KeyboardInteractiveHandler handles keyboard interactive authentication.
178func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
179	ac := s.cfg.Backend.AllowKeyless()
180	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
181	return ac
182}
183
184// Middleware adds Git server functionality to the ssh.Server. Repos are stored
185// in the specified repo directory. The provided Hooks implementation will be
186// checked for access on a per repo basis for a ssh.Session public key.
187// Hooks.Push and Hooks.Fetch will be called on successful completion of
188// their commands.
189func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
190	return func(sh ssh.Handler) ssh.Handler {
191		return func(s ssh.Session) {
192			func() {
193				cmd := s.Command()
194				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
195					gc := cmd[0]
196					// repo should be in the form of "repo.git"
197					name := utils.SanitizeRepo(cmd[1])
198					pk := s.PublicKey()
199					ak := backend.MarshalAuthorizedKey(pk)
200					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
201					// git bare repositories should end in ".git"
202					// https://git-scm.com/docs/gitrepository-layout
203					repo := name + ".git"
204					reposDir := filepath.Join(cfg.DataPath, "repos")
205					if err := git.EnsureWithin(reposDir, repo); err != nil {
206						sshFatal(s, err)
207						return
208					}
209
210					// Environment variables to pass down to git hooks.
211					envs := []string{
212						"SOFT_SERVE_REPO_NAME=" + name,
213						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
214						"SOFT_SERVE_PUBLIC_KEY=" + ak,
215					}
216
217					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
218					repoDir := filepath.Join(reposDir, repo)
219					switch gc {
220					case git.ReceivePackBin:
221						if access < backend.ReadWriteAccess {
222							sshFatal(s, git.ErrNotAuthed)
223							return
224						}
225						if _, err := cfg.Backend.Repository(name); err != nil {
226							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
227								log.Errorf("failed to create repo: %s", err)
228								sshFatal(s, err)
229								return
230							}
231							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
232						}
233						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
234							sshFatal(s, git.ErrSystemMalfunction)
235						}
236						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
237						return
238					case git.UploadPackBin, git.UploadArchiveBin:
239						if access < backend.ReadOnlyAccess {
240							sshFatal(s, git.ErrNotAuthed)
241							return
242						}
243
244						gitPack := git.UploadPack
245						counter := uploadPackCounter
246						if gc == git.UploadArchiveBin {
247							gitPack = git.UploadArchive
248							counter = uploadArchiveCounter
249						}
250
251						err := gitPack(s.Context(), s, s, s.Stderr(), repoDir, envs...)
252						if errors.Is(err, git.ErrInvalidRepo) {
253							sshFatal(s, git.ErrInvalidRepo)
254						} else if err != nil {
255							sshFatal(s, git.ErrSystemMalfunction)
256						}
257
258						counter.WithLabelValues(ak, s.User(), name).Inc()
259					}
260				}
261			}()
262			sh(s)
263		}
264	}
265}
266
267// sshFatal prints to the session's STDOUT as a git response and exit 1.
268func sshFatal(s ssh.Session, v ...interface{}) {
269	git.WritePktline(s, v...)
270	s.Exit(1) // nolint: errcheck
271}