1package server
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"path/filepath"
  8	"strconv"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	cm "github.com/charmbracelet/soft-serve/server/cmd"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/hooks"
 17	"github.com/charmbracelet/soft-serve/server/utils"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	bm "github.com/charmbracelet/wish/bubbletea"
 21	lm "github.com/charmbracelet/wish/logging"
 22	rm "github.com/charmbracelet/wish/recover"
 23	"github.com/muesli/termenv"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
 29var (
 30	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 31		Namespace: "soft_serve",
 32		Subsystem: "ssh",
 33		Name:      "public_key_auth_total",
 34		Help:      "The total number of public key auth requests",
 35	}, []string{"key", "user", "allowed"})
 36
 37	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 38		Namespace: "soft_serve",
 39		Subsystem: "ssh",
 40		Name:      "keyboard_interactive_auth_total",
 41		Help:      "The total number of keyboard interactive auth requests",
 42	}, []string{"user", "allowed"})
 43
 44	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 45		Namespace: "soft_serve",
 46		Subsystem: "ssh",
 47		Name:      "git_upload_pack_total",
 48		Help:      "The total number of git-upload-pack requests",
 49	}, []string{"key", "user", "repo"})
 50
 51	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 52		Namespace: "soft_serve",
 53		Subsystem: "ssh",
 54		Name:      "git_receive_pack_total",
 55		Help:      "The total number of git-receive-pack requests",
 56	}, []string{"key", "user", "repo"})
 57
 58	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 59		Namespace: "soft_serve",
 60		Subsystem: "ssh",
 61		Name:      "git_upload_archive_total",
 62		Help:      "The total number of git-upload-archive requests",
 63	}, []string{"key", "user", "repo"})
 64
 65	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 66		Namespace: "soft_serve",
 67		Subsystem: "ssh",
 68		Name:      "create_repo_total",
 69		Help:      "The total number of create repo requests",
 70	}, []string{"key", "user", "repo"})
 71)
 72
// SSHServer is an SSH server that implements the git protocol.
type SSHServer struct {
	// srv is the underlying SSH server; NewSSHServer sets it from
	// wish.NewServer.
	srv *ssh.Server
	// cfg is the configuration passed to NewSSHServer; the auth handlers
	// read the backend and the initial admin keys from it.
	cfg *config.Config
}
 78
 79// NewSSHServer returns a new SSHServer.
 80func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
 81	var err error
 82	s := &SSHServer{cfg: cfg}
 83	logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 84	mw := []wish.Middleware{
 85		rm.MiddlewareWithLogger(
 86			logger,
 87			// BubbleTea middleware.
 88			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 89			// CLI middleware.
 90			cm.Middleware(cfg, hooks),
 91			// Git middleware.
 92			s.Middleware(cfg),
 93			// Logging middleware.
 94			lm.MiddlewareWithLogger(logger),
 95		),
 96	}
 97	s.srv, err = wish.NewServer(
 98		ssh.PublicKeyAuth(s.PublicKeyHandler),
 99		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
100		wish.WithAddress(cfg.SSH.ListenAddr),
101		wish.WithHostKeyPath(filepath.Join(cfg.DataPath, cfg.SSH.KeyPath)),
102		wish.WithMiddleware(mw...),
103	)
104	if err != nil {
105		return nil, err
106	}
107
108	if cfg.SSH.MaxTimeout > 0 {
109		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
110	}
111	if cfg.SSH.IdleTimeout > 0 {
112		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
113	}
114
115	return s, nil
116}
117
// ListenAndServe starts the SSH server.
//
// It delegates to the underlying ssh.Server's ListenAndServe and returns
// that call's error unchanged.
func (s *SSHServer) ListenAndServe() error {
	return s.srv.ListenAndServe()
}
122
// Serve starts the SSH server on the given net.Listener.
//
// It delegates to the underlying ssh.Server's Serve with l and returns that
// call's error unchanged.
func (s *SSHServer) Serve(l net.Listener) error {
	return s.srv.Serve(l)
}
127
// Close closes the SSH server.
//
// It delegates to the underlying ssh.Server's Close and returns that call's
// error unchanged.
func (s *SSHServer) Close() error {
	return s.srv.Close()
}
132
// Shutdown gracefully shuts down the SSH server.
//
// ctx is passed straight to the underlying ssh.Server's Shutdown, whose
// error is returned unchanged.
func (s *SSHServer) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
137
138// PublicKeyAuthHandler handles public key authentication.
139func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
140	ak := backend.MarshalAuthorizedKey(pk)
141	defer func() {
142		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(allowed)).Inc()
143	}()
144	for _, k := range s.cfg.InitialAdminKeys {
145		if k == ak {
146			allowed = true
147			return
148		}
149	}
150
151	user, _ := s.cfg.Backend.UserByPublicKey(pk)
152	if user == nil {
153		logger.Debug("public key auth user not found")
154		return s.cfg.Backend.AnonAccess() >= backend.ReadOnlyAccess
155	}
156
157	allowed = s.cfg.Backend.AccessLevel("", user.Username()) >= backend.ReadOnlyAccess
158	return
159}
160
161// KeyboardInteractiveHandler handles keyboard interactive authentication.
162func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
163	ac := s.cfg.Backend.AllowKeyless() && s.PublicKeyHandler(ctx, nil)
164	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
165	return ac
166}
167
168// Middleware adds Git server functionality to the ssh.Server. Repos are stored
169// in the specified repo directory. The provided Hooks implementation will be
170// checked for access on a per repo basis for a ssh.Session public key.
171// Hooks.Push and Hooks.Fetch will be called on successful completion of
172// their commands.
173func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
174	return func(sh ssh.Handler) ssh.Handler {
175		return func(s ssh.Session) {
176			func() {
177				cmd := s.Command()
178				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
179					gc := cmd[0]
180					// repo should be in the form of "repo.git"
181					name := utils.SanitizeRepo(cmd[1])
182					pk := s.PublicKey()
183					ak := backend.MarshalAuthorizedKey(pk)
184					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
185					// git bare repositories should end in ".git"
186					// https://git-scm.com/docs/gitrepository-layout
187					repo := name + ".git"
188					reposDir := filepath.Join(cfg.DataPath, "repos")
189					if err := ensureWithin(reposDir, repo); err != nil {
190						sshFatal(s, err)
191						return
192					}
193
194					repoDir := filepath.Join(reposDir, repo)
195					switch gc {
196					case receivePackBin:
197						if access < backend.ReadWriteAccess {
198							sshFatal(s, ErrNotAuthed)
199							return
200						}
201						if _, err := cfg.Backend.Repository(name); err != nil {
202							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
203								log.Errorf("failed to create repo: %s", err)
204								sshFatal(s, err)
205								return
206							}
207							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
208						}
209						if err := receivePack(s, s, s.Stderr(), repoDir); err != nil {
210							sshFatal(s, ErrSystemMalfunction)
211						}
212						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
213						return
214					case uploadPackBin, uploadArchiveBin:
215						if access < backend.ReadOnlyAccess {
216							sshFatal(s, ErrNotAuthed)
217							return
218						}
219
220						gitPack := uploadPack
221						counter := uploadPackCounter
222						if gc == uploadArchiveBin {
223							gitPack = uploadArchive
224							counter = uploadArchiveCounter
225						}
226
227						err := gitPack(s, s, s.Stderr(), repoDir)
228						if errors.Is(err, ErrInvalidRepo) {
229							sshFatal(s, ErrInvalidRepo)
230						} else if err != nil {
231							sshFatal(s, ErrSystemMalfunction)
232						}
233
234						counter.WithLabelValues(ak, s.User(), name).Inc()
235					}
236				}
237			}()
238			sh(s)
239		}
240	}
241}
242
243// sshFatal prints to the session's STDOUT as a git response and exit 1.
244func sshFatal(s ssh.Session, v ...interface{}) {
245	writePktline(s, v...)
246	s.Exit(1) // nolint: errcheck
247}