1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"path/filepath"
  8	"strconv"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	cm "github.com/charmbracelet/soft-serve/server/cmd"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/git"
 17	"github.com/charmbracelet/soft-serve/server/hooks"
 18	"github.com/charmbracelet/soft-serve/server/utils"
 19	"github.com/charmbracelet/ssh"
 20	"github.com/charmbracelet/wish"
 21	bm "github.com/charmbracelet/wish/bubbletea"
 22	lm "github.com/charmbracelet/wish/logging"
 23	rm "github.com/charmbracelet/wish/recover"
 24	"github.com/muesli/termenv"
 25	"github.com/prometheus/client_golang/prometheus"
 26	"github.com/prometheus/client_golang/prometheus/promauto"
 27	gossh "golang.org/x/crypto/ssh"
 28)
 29
 30var (
 31	logger = log.WithPrefix("server.ssh")
 32)
 33
 34var (
 35	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 36		Namespace: "soft_serve",
 37		Subsystem: "ssh",
 38		Name:      "public_key_auth_total",
 39		Help:      "The total number of public key auth requests",
 40	}, []string{"key", "user", "allowed"})
 41
 42	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 43		Namespace: "soft_serve",
 44		Subsystem: "ssh",
 45		Name:      "keyboard_interactive_auth_total",
 46		Help:      "The total number of keyboard interactive auth requests",
 47	}, []string{"user", "allowed"})
 48
 49	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 50		Namespace: "soft_serve",
 51		Subsystem: "ssh",
 52		Name:      "git_upload_pack_total",
 53		Help:      "The total number of git-upload-pack requests",
 54	}, []string{"key", "user", "repo"})
 55
 56	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 57		Namespace: "soft_serve",
 58		Subsystem: "ssh",
 59		Name:      "git_receive_pack_total",
 60		Help:      "The total number of git-receive-pack requests",
 61	}, []string{"key", "user", "repo"})
 62
 63	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 64		Namespace: "soft_serve",
 65		Subsystem: "ssh",
 66		Name:      "git_upload_archive_total",
 67		Help:      "The total number of git-upload-archive requests",
 68	}, []string{"key", "user", "repo"})
 69
 70	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 71		Namespace: "soft_serve",
 72		Subsystem: "ssh",
 73		Name:      "create_repo_total",
 74		Help:      "The total number of create repo requests",
 75	}, []string{"key", "user", "repo"})
 76)
 77
// SSHServer is an SSH server that implements the git protocol.
//
// Besides git transport, NewSSHServer also wires in a Bubble Tea UI session
// handler and the CLI command middleware.
type SSHServer struct {
	// srv is the underlying server created by wish.NewServer in NewSSHServer.
	srv *ssh.Server
	// cfg is the server configuration; its Backend is consulted for
	// authentication and repository access decisions.
	cfg *config.Config
}
 83
 84// NewSSHServer returns a new SSHServer.
 85func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
 86	var err error
 87	s := &SSHServer{cfg: cfg}
 88	logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 89	mw := []wish.Middleware{
 90		rm.MiddlewareWithLogger(
 91			logger,
 92			// BubbleTea middleware.
 93			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 94			// CLI middleware.
 95			cm.Middleware(cfg, hooks),
 96			// Git middleware.
 97			s.Middleware(cfg),
 98			// Logging middleware.
 99			lm.MiddlewareWithLogger(logger),
100		),
101	}
102	s.srv, err = wish.NewServer(
103		ssh.PublicKeyAuth(s.PublicKeyHandler),
104		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
105		wish.WithAddress(cfg.SSH.ListenAddr),
106		wish.WithHostKeyPath(filepath.Join(cfg.DataPath, cfg.SSH.KeyPath)),
107		wish.WithMiddleware(mw...),
108	)
109	if err != nil {
110		return nil, err
111	}
112
113	if cfg.SSH.MaxTimeout > 0 {
114		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
115	}
116	if cfg.SSH.IdleTimeout > 0 {
117		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
118	}
119
120	return s, nil
121}
122
123// ListenAndServe starts the SSH server.
124func (s *SSHServer) ListenAndServe() error {
125	return s.srv.ListenAndServe()
126}
127
128// Serve starts the SSH server on the given net.Listener.
129func (s *SSHServer) Serve(l net.Listener) error {
130	return s.srv.Serve(l)
131}
132
133// Close closes the SSH server.
134func (s *SSHServer) Close() error {
135	return s.srv.Close()
136}
137
138// Shutdown gracefully shuts down the SSH server.
139func (s *SSHServer) Shutdown(ctx context.Context) error {
140	return s.srv.Shutdown(ctx)
141}
142
143// PublicKeyAuthHandler handles public key authentication.
144func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
145	if pk == nil {
146		return s.cfg.Backend.AllowKeyless()
147	}
148
149	ak := backend.MarshalAuthorizedKey(pk)
150	defer func() {
151		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(allowed)).Inc()
152	}()
153
154	for _, k := range s.cfg.InitialAdminKeys {
155		if k == ak {
156			allowed = true
157			return
158		}
159	}
160
161	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
162	logger.Debugf("access level for %s: %d", ak, ac)
163	allowed = ac >= backend.ReadOnlyAccess
164	return
165}
166
167// KeyboardInteractiveHandler handles keyboard interactive authentication.
168func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
169	ac := s.cfg.Backend.AllowKeyless()
170	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
171	return ac
172}
173
174// Middleware adds Git server functionality to the ssh.Server. Repos are stored
175// in the specified repo directory. The provided Hooks implementation will be
176// checked for access on a per repo basis for a ssh.Session public key.
177// Hooks.Push and Hooks.Fetch will be called on successful completion of
178// their commands.
179func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
180	return func(sh ssh.Handler) ssh.Handler {
181		return func(s ssh.Session) {
182			func() {
183				cmd := s.Command()
184				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
185					gc := cmd[0]
186					// repo should be in the form of "repo.git"
187					name := utils.SanitizeRepo(cmd[1])
188					pk := s.PublicKey()
189					ak := backend.MarshalAuthorizedKey(pk)
190					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
191					// git bare repositories should end in ".git"
192					// https://git-scm.com/docs/gitrepository-layout
193					repo := name + ".git"
194					reposDir := filepath.Join(cfg.DataPath, "repos")
195					if err := git.EnsureWithin(reposDir, repo); err != nil {
196						sshFatal(s, err)
197						return
198					}
199
200					logger.Debug("git middleware", "cmd", gc, "access", access.String())
201					repoDir := filepath.Join(reposDir, repo)
202					switch gc {
203					case git.ReceivePackBin:
204						if access < backend.ReadWriteAccess {
205							sshFatal(s, git.ErrNotAuthed)
206							return
207						}
208						if _, err := cfg.Backend.Repository(name); err != nil {
209							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
210								log.Errorf("failed to create repo: %s", err)
211								sshFatal(s, err)
212								return
213							}
214							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
215						}
216						if err := git.ReceivePack(s, s, s.Stderr(), repoDir); err != nil {
217							sshFatal(s, git.ErrSystemMalfunction)
218						}
219						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
220						return
221					case git.UploadPackBin, git.UploadArchiveBin:
222						if access < backend.ReadOnlyAccess {
223							sshFatal(s, git.ErrNotAuthed)
224							return
225						}
226
227						gitPack := git.UploadPack
228						counter := uploadPackCounter
229						if gc == git.UploadArchiveBin {
230							gitPack = git.UploadArchive
231							counter = uploadArchiveCounter
232						}
233
234						err := gitPack(s, s, s.Stderr(), repoDir)
235						if errors.Is(err, git.ErrInvalidRepo) {
236							sshFatal(s, git.ErrInvalidRepo)
237						} else if err != nil {
238							sshFatal(s, git.ErrSystemMalfunction)
239						}
240
241						counter.WithLabelValues(ak, s.User(), name).Inc()
242					}
243				}
244			}()
245			sh(s)
246		}
247	}
248}
249
250// sshFatal prints to the session's STDOUT as a git response and exit 1.
251func sshFatal(s ssh.Session, v ...interface{}) {
252	git.WritePktline(s, v...)
253	s.Exit(1) // nolint: errcheck
254}