1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"path/filepath"
  8	"strconv"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	cm "github.com/charmbracelet/soft-serve/server/cmd"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/git"
 17	"github.com/charmbracelet/soft-serve/server/utils"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	bm "github.com/charmbracelet/wish/bubbletea"
 21	lm "github.com/charmbracelet/wish/logging"
 22	rm "github.com/charmbracelet/wish/recover"
 23	"github.com/muesli/termenv"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
 29var (
 30	logger = log.WithPrefix("server.ssh")
 31)
 32
 33var (
 34	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 35		Namespace: "soft_serve",
 36		Subsystem: "ssh",
 37		Name:      "public_key_auth_total",
 38		Help:      "The total number of public key auth requests",
 39	}, []string{"key", "user", "allowed"})
 40
 41	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 42		Namespace: "soft_serve",
 43		Subsystem: "ssh",
 44		Name:      "keyboard_interactive_auth_total",
 45		Help:      "The total number of keyboard interactive auth requests",
 46	}, []string{"user", "allowed"})
 47
 48	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 49		Namespace: "soft_serve",
 50		Subsystem: "ssh",
 51		Name:      "git_upload_pack_total",
 52		Help:      "The total number of git-upload-pack requests",
 53	}, []string{"key", "user", "repo"})
 54
 55	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 56		Namespace: "soft_serve",
 57		Subsystem: "ssh",
 58		Name:      "git_receive_pack_total",
 59		Help:      "The total number of git-receive-pack requests",
 60	}, []string{"key", "user", "repo"})
 61
 62	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 63		Namespace: "soft_serve",
 64		Subsystem: "ssh",
 65		Name:      "git_upload_archive_total",
 66		Help:      "The total number of git-upload-archive requests",
 67	}, []string{"key", "user", "repo"})
 68
 69	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 70		Namespace: "soft_serve",
 71		Subsystem: "ssh",
 72		Name:      "create_repo_total",
 73		Help:      "The total number of create repo requests",
 74	}, []string{"key", "user", "repo"})
 75)
 76
 77// SSHServer is a SSH server that implements the git protocol.
 78type SSHServer struct {
 79	srv *ssh.Server
 80	cfg *config.Config
 81}
 82
 83// NewSSHServer returns a new SSHServer.
 84func NewSSHServer(cfg *config.Config) (*SSHServer, error) {
 85	var err error
 86	s := &SSHServer{cfg: cfg}
 87	logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 88	mw := []wish.Middleware{
 89		rm.MiddlewareWithLogger(
 90			logger,
 91			// BubbleTea middleware.
 92			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 93			// CLI middleware.
 94			cm.Middleware(cfg),
 95			// Git middleware.
 96			s.Middleware(cfg),
 97			// Logging middleware.
 98			lm.MiddlewareWithLogger(logger),
 99		),
100	}
101	s.srv, err = wish.NewServer(
102		ssh.PublicKeyAuth(s.PublicKeyHandler),
103		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
104		wish.WithAddress(cfg.SSH.ListenAddr),
105		wish.WithHostKeyPath(cfg.SSH.KeyPath),
106		wish.WithMiddleware(mw...),
107	)
108	if err != nil {
109		return nil, err
110	}
111
112	if cfg.SSH.MaxTimeout > 0 {
113		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
114	}
115	if cfg.SSH.IdleTimeout > 0 {
116		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
117	}
118
119	return s, nil
120}
121
122// ListenAndServe starts the SSH server.
123func (s *SSHServer) ListenAndServe() error {
124	return s.srv.ListenAndServe()
125}
126
127// Serve starts the SSH server on the given net.Listener.
128func (s *SSHServer) Serve(l net.Listener) error {
129	return s.srv.Serve(l)
130}
131
132// Close closes the SSH server.
133func (s *SSHServer) Close() error {
134	return s.srv.Close()
135}
136
137// Shutdown gracefully shuts down the SSH server.
138func (s *SSHServer) Shutdown(ctx context.Context) error {
139	return s.srv.Shutdown(ctx)
140}
141
142// PublicKeyAuthHandler handles public key authentication.
143func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
144	if pk == nil {
145		return s.cfg.Backend.AllowKeyless()
146	}
147
148	ak := backend.MarshalAuthorizedKey(pk)
149	defer func(allowed *bool) {
150		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
151	}(&allowed)
152
153	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
154	logger.Debugf("access level for %q: %s", ak, ac)
155	allowed = ac >= backend.ReadOnlyAccess
156	return
157}
158
159// KeyboardInteractiveHandler handles keyboard interactive authentication.
160func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
161	ac := s.cfg.Backend.AllowKeyless()
162	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
163	return ac
164}
165
166// Middleware adds Git server functionality to the ssh.Server. Repos are stored
167// in the specified repo directory. The provided Hooks implementation will be
168// checked for access on a per repo basis for a ssh.Session public key.
169// Hooks.Push and Hooks.Fetch will be called on successful completion of
170// their commands.
171func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
172	return func(sh ssh.Handler) ssh.Handler {
173		return func(s ssh.Session) {
174			func() {
175				cmd := s.Command()
176				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
177					gc := cmd[0]
178					// repo should be in the form of "repo.git"
179					name := utils.SanitizeRepo(cmd[1])
180					pk := s.PublicKey()
181					ak := backend.MarshalAuthorizedKey(pk)
182					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
183					// git bare repositories should end in ".git"
184					// https://git-scm.com/docs/gitrepository-layout
185					repo := name + ".git"
186					reposDir := filepath.Join(cfg.DataPath, "repos")
187					if err := git.EnsureWithin(reposDir, repo); err != nil {
188						sshFatal(s, err)
189						return
190					}
191
192					logger.Debug("git middleware", "cmd", gc, "access", access.String())
193					repoDir := filepath.Join(reposDir, repo)
194					switch gc {
195					case git.ReceivePackBin:
196						if access < backend.ReadWriteAccess {
197							sshFatal(s, git.ErrNotAuthed)
198							return
199						}
200						if _, err := cfg.Backend.Repository(name); err != nil {
201							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
202								log.Errorf("failed to create repo: %s", err)
203								sshFatal(s, err)
204								return
205							}
206							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
207						}
208						if err := git.ReceivePack(s, s, s.Stderr(), repoDir); err != nil {
209							sshFatal(s, git.ErrSystemMalfunction)
210						}
211						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
212						return
213					case git.UploadPackBin, git.UploadArchiveBin:
214						if access < backend.ReadOnlyAccess {
215							sshFatal(s, git.ErrNotAuthed)
216							return
217						}
218
219						gitPack := git.UploadPack
220						counter := uploadPackCounter
221						if gc == git.UploadArchiveBin {
222							gitPack = git.UploadArchive
223							counter = uploadArchiveCounter
224						}
225
226						err := gitPack(s, s, s.Stderr(), repoDir)
227						if errors.Is(err, git.ErrInvalidRepo) {
228							sshFatal(s, git.ErrInvalidRepo)
229						} else if err != nil {
230							sshFatal(s, git.ErrSystemMalfunction)
231						}
232
233						counter.WithLabelValues(ak, s.User(), name).Inc()
234					}
235				}
236			}()
237			sh(s)
238		}
239	}
240}
241
242// sshFatal prints to the session's STDOUT as a git response and exit 1.
243func sshFatal(s ssh.Session, v ...interface{}) {
244	git.WritePktline(s, v...)
245	s.Exit(1) // nolint: errcheck
246}