1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/keygen"
 15	"github.com/charmbracelet/log"
 16	"github.com/charmbracelet/soft-serve/server/backend"
 17	cm "github.com/charmbracelet/soft-serve/server/cmd"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/utils"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
// Prometheus counters for the SSH server. They are registered with the
// default registry at package initialization via promauto.
var (
	// publicKeyCounter counts public key auth attempts with a non-nil key,
	// labelled by the authorized-key string, user name, and whether the
	// attempt was allowed ("true"/"false").
	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "public_key_auth_total",
		Help:      "The total number of public key auth requests",
	}, []string{"key", "user", "allowed"})

	// keyboardInteractiveCounter counts keyboard-interactive auth attempts,
	// labelled by user name and whether the attempt was allowed.
	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "keyboard_interactive_auth_total",
		Help:      "The total number of keyboard interactive auth requests",
	}, []string{"user", "allowed"})

	// uploadPackCounter counts git-upload-pack (fetch/clone) requests,
	// labelled by authorized key, user name, and repository name.
	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_pack_total",
		Help:      "The total number of git-upload-pack requests",
	}, []string{"key", "user", "repo"})

	// receivePackCounter counts git-receive-pack (push) requests, labelled
	// by authorized key, user name, and repository name.
	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_receive_pack_total",
		Help:      "The total number of git-receive-pack requests",
	}, []string{"key", "user", "repo"})

	// uploadArchiveCounter counts git-upload-archive requests, labelled by
	// authorized key, user name, and repository name.
	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_archive_total",
		Help:      "The total number of git-upload-archive requests",
	}, []string{"key", "user", "repo"})

	// createRepoCounter counts repositories created implicitly by a push,
	// labelled by authorized key, user name, and repository name.
	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "create_repo_total",
		Help:      "The total number of create repo requests",
	}, []string{"key", "user", "repo"})
)
 75
// SSHServer is an SSH server that implements the git protocol. Instances are
// created with NewSSHServer.
type SSHServer struct {
	// srv is the underlying charmbracelet/ssh server.
	srv    *ssh.Server
	// cfg is the server configuration taken from the creation context.
	cfg    *config.Config
	// be is the backend taken from the creation context; the git middleware
	// binds it to each session's context.
	be     backend.Backend
	// ctx is the context passed to NewSSHServer.
	ctx    context.Context
	// logger is the context logger with the "ssh" prefix.
	logger *log.Logger
}
 84
 85// NewSSHServer returns a new SSHServer.
 86func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 87	cfg := config.FromContext(ctx)
 88	logger := log.FromContext(ctx).WithPrefix("ssh")
 89
 90	var err error
 91	s := &SSHServer{
 92		cfg:    cfg,
 93		ctx:    ctx,
 94		be:     backend.FromContext(ctx),
 95		logger: logger,
 96	}
 97
 98	mw := []wish.Middleware{
 99		rm.MiddlewareWithLogger(
100			logger,
101			// BubbleTea middleware.
102			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
103			// CLI middleware.
104			cm.Middleware(cfg, logger),
105			// Git middleware.
106			s.Middleware(cfg),
107			// Logging middleware.
108			lm.MiddlewareWithLogger(logger.
109				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
110		),
111	}
112
113	s.srv, err = wish.NewServer(
114		ssh.PublicKeyAuth(s.PublicKeyHandler),
115		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
116		wish.WithAddress(cfg.SSH.ListenAddr),
117		wish.WithHostKeyPath(cfg.SSH.KeyPath),
118		wish.WithMiddleware(mw...),
119	)
120	if err != nil {
121		return nil, err
122	}
123
124	if cfg.SSH.MaxTimeout > 0 {
125		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
126	}
127
128	if cfg.SSH.IdleTimeout > 0 {
129		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
130	}
131
132	// Create client ssh key
133	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
134		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
135		if err != nil {
136			return nil, fmt.Errorf("client ssh key: %w", err)
137		}
138	}
139
140	return s, nil
141}
142
143// ListenAndServe starts the SSH server.
144func (s *SSHServer) ListenAndServe() error {
145	return s.srv.ListenAndServe()
146}
147
148// Serve starts the SSH server on the given net.Listener.
149func (s *SSHServer) Serve(l net.Listener) error {
150	return s.srv.Serve(l)
151}
152
153// Close closes the SSH server.
154func (s *SSHServer) Close() error {
155	return s.srv.Close()
156}
157
158// Shutdown gracefully shuts down the SSH server.
159func (s *SSHServer) Shutdown(ctx context.Context) error {
160	return s.srv.Shutdown(ctx)
161}
162
163// PublicKeyAuthHandler handles public key authentication.
164func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
165	if pk == nil {
166		return s.cfg.Backend.AllowKeyless()
167	}
168
169	ak := backend.MarshalAuthorizedKey(pk)
170	defer func(allowed *bool) {
171		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
172	}(&allowed)
173
174	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
175	s.logger.Debugf("access level for %q: %s", ak, ac)
176	allowed = ac >= backend.ReadOnlyAccess
177	return
178}
179
180// KeyboardInteractiveHandler handles keyboard interactive authentication.
181func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
182	ac := s.cfg.Backend.AllowKeyless()
183	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
184	return ac
185}
186
187// Middleware adds Git server functionality to the ssh.Server. Repos are stored
188// in the specified repo directory. The provided Hooks implementation will be
189// checked for access on a per repo basis for a ssh.Session public key.
190// Hooks.Push and Hooks.Fetch will be called on successful completion of
191// their commands.
192func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
193	return func(sh ssh.Handler) ssh.Handler {
194		return func(s ssh.Session) {
195			func() {
196				cmd := s.Command()
197				ctx := s.Context()
198				be := ss.be.WithContext(ctx)
199				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
200					gc := cmd[0]
201					// repo should be in the form of "repo.git"
202					name := utils.SanitizeRepo(cmd[1])
203					pk := s.PublicKey()
204					ak := backend.MarshalAuthorizedKey(pk)
205					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
206					// git bare repositories should end in ".git"
207					// https://git-scm.com/docs/gitrepository-layout
208					repo := name + ".git"
209					reposDir := filepath.Join(cfg.DataPath, "repos")
210					if err := git.EnsureWithin(reposDir, repo); err != nil {
211						sshFatal(s, err)
212						return
213					}
214
215					// Environment variables to pass down to git hooks.
216					envs := []string{
217						"SOFT_SERVE_REPO_NAME=" + name,
218						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
219						"SOFT_SERVE_PUBLIC_KEY=" + ak,
220					}
221
222					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
223					repoDir := filepath.Join(reposDir, repo)
224					switch gc {
225					case git.ReceivePackBin:
226						if access < backend.ReadWriteAccess {
227							sshFatal(s, git.ErrNotAuthed)
228							return
229						}
230						if _, err := be.Repository(name); err != nil {
231							if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
232								log.Errorf("failed to create repo: %s", err)
233								sshFatal(s, err)
234								return
235							}
236							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
237						}
238						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
239							sshFatal(s, git.ErrSystemMalfunction)
240						}
241						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
242						return
243					case git.UploadPackBin, git.UploadArchiveBin:
244						if access < backend.ReadOnlyAccess {
245							sshFatal(s, git.ErrNotAuthed)
246							return
247						}
248
249						gitPack := git.UploadPack
250						counter := uploadPackCounter
251						if gc == git.UploadArchiveBin {
252							gitPack = git.UploadArchive
253							counter = uploadArchiveCounter
254						}
255
256						err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
257						if errors.Is(err, git.ErrInvalidRepo) {
258							sshFatal(s, git.ErrInvalidRepo)
259						} else if err != nil {
260							sshFatal(s, git.ErrSystemMalfunction)
261						}
262
263						counter.WithLabelValues(ak, s.User(), name).Inc()
264					}
265				}
266			}()
267			sh(s)
268		}
269	}
270}
271
272// sshFatal prints to the session's STDOUT as a git response and exit 1.
273func sshFatal(s ssh.Session, v ...interface{}) {
274	git.WritePktline(s, v...)
275	s.Exit(1) // nolint: errcheck
276}