1package server
  2
  3import (
  4	"context"
  5	"errors"
  6	"net"
  7	"path/filepath"
  8	"strconv"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/log"
 13	"github.com/charmbracelet/soft-serve/server/backend"
 14	cm "github.com/charmbracelet/soft-serve/server/cmd"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/charmbracelet/soft-serve/server/hooks"
 17	"github.com/charmbracelet/soft-serve/server/utils"
 18	"github.com/charmbracelet/ssh"
 19	"github.com/charmbracelet/wish"
 20	bm "github.com/charmbracelet/wish/bubbletea"
 21	lm "github.com/charmbracelet/wish/logging"
 22	rm "github.com/charmbracelet/wish/recover"
 23	"github.com/muesli/termenv"
 24	"github.com/prometheus/client_golang/prometheus"
 25	"github.com/prometheus/client_golang/prometheus/promauto"
 26	gossh "golang.org/x/crypto/ssh"
 27)
 28
// Prometheus counters for SSH activity. Each one is created with promauto
// under the "soft_serve" namespace and the "ssh" subsystem.
var (
	// publicKeyCounter counts public key auth requests, labeled by the
	// key, the user, and whether the request was allowed.
	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "public_key_auth_total",
		Help:      "The total number of public key auth requests",
	}, []string{"key", "user", "allowed"})

	// keyboardInteractiveCounter counts keyboard interactive auth requests,
	// labeled by the user and whether the request was allowed.
	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "keyboard_interactive_auth_total",
		Help:      "The total number of keyboard interactive auth requests",
	}, []string{"user", "allowed"})

	// uploadPackCounter counts git-upload-pack requests, labeled by key,
	// user, and repo.
	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_pack_total",
		Help:      "The total number of git-upload-pack requests",
	}, []string{"key", "user", "repo"})

	// receivePackCounter counts git-receive-pack requests, labeled by key,
	// user, and repo.
	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_receive_pack_total",
		Help:      "The total number of git-receive-pack requests",
	}, []string{"key", "user", "repo"})

	// uploadArchiveCounter counts git-upload-archive requests, labeled by
	// key, user, and repo.
	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_archive_total",
		Help:      "The total number of git-upload-archive requests",
	}, []string{"key", "user", "repo"})

	// createRepoCounter counts create repo requests, labeled by key, user,
	// and repo.
	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "create_repo_total",
		Help:      "The total number of create repo requests",
	}, []string{"key", "user", "repo"})
)
 72
// SSHServer is an SSH server that implements the git protocol.
type SSHServer struct {
	// srv is the underlying SSH server that every lifecycle method
	// delegates to.
	srv *ssh.Server
	// cfg is the server configuration; the auth handlers read the backend
	// and the initial admin keys from it.
	cfg *config.Config
}
 78
 79// NewSSHServer returns a new SSHServer.
 80func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
 81	var err error
 82	s := &SSHServer{cfg: cfg}
 83	logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
 84	mw := []wish.Middleware{
 85		rm.MiddlewareWithLogger(
 86			logger,
 87			// BubbleTea middleware.
 88			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
 89			// CLI middleware.
 90			cm.Middleware(cfg, hooks),
 91			// Git middleware.
 92			s.Middleware(cfg),
 93			// Logging middleware.
 94			lm.MiddlewareWithLogger(logger),
 95		),
 96	}
 97	s.srv, err = wish.NewServer(
 98		ssh.PublicKeyAuth(s.PublicKeyHandler),
 99		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
100		wish.WithAddress(cfg.SSH.ListenAddr),
101		wish.WithHostKeyPath(filepath.Join(cfg.DataPath, cfg.SSH.KeyPath)),
102		wish.WithMiddleware(mw...),
103	)
104	if err != nil {
105		return nil, err
106	}
107
108	if cfg.SSH.MaxTimeout > 0 {
109		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
110	}
111	if cfg.SSH.IdleTimeout > 0 {
112		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
113	}
114
115	return s, nil
116}
117
// ListenAndServe starts the SSH server by delegating to the underlying
// ssh.Server. The error from the underlying call is returned unchanged.
func (s *SSHServer) ListenAndServe() error {
	return s.srv.ListenAndServe()
}
122
// Serve starts the SSH server on the given net.Listener by delegating to the
// underlying ssh.Server. The error from the underlying call is returned
// unchanged.
func (s *SSHServer) Serve(l net.Listener) error {
	return s.srv.Serve(l)
}
127
// Close closes the SSH server by delegating to the underlying ssh.Server. The
// error from the underlying call is returned unchanged.
func (s *SSHServer) Close() error {
	return s.srv.Close()
}
132
// Shutdown gracefully shuts down the SSH server. ctx is passed straight
// through to the underlying ssh.Server, whose error is returned unchanged.
func (s *SSHServer) Shutdown(ctx context.Context) error {
	return s.srv.Shutdown(ctx)
}
137
138// PublicKeyAuthHandler handles public key authentication.
139func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
140	if pk == nil {
141		return s.cfg.Backend.AllowKeyless()
142	}
143
144	ak := backend.MarshalAuthorizedKey(pk)
145	defer func() {
146		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(allowed)).Inc()
147	}()
148
149	for _, k := range s.cfg.InitialAdminKeys {
150		if k == ak {
151			allowed = true
152			return
153		}
154	}
155
156	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
157	logger.Debugf("access level for %s: %d", ak, ac)
158	allowed = ac >= backend.ReadOnlyAccess
159	return
160}
161
162// KeyboardInteractiveHandler handles keyboard interactive authentication.
163func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
164	ac := s.cfg.Backend.AllowKeyless()
165	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
166	return ac
167}
168
169// Middleware adds Git server functionality to the ssh.Server. Repos are stored
170// in the specified repo directory. The provided Hooks implementation will be
171// checked for access on a per repo basis for a ssh.Session public key.
172// Hooks.Push and Hooks.Fetch will be called on successful completion of
173// their commands.
174func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
175	return func(sh ssh.Handler) ssh.Handler {
176		return func(s ssh.Session) {
177			func() {
178				cmd := s.Command()
179				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
180					gc := cmd[0]
181					// repo should be in the form of "repo.git"
182					name := utils.SanitizeRepo(cmd[1])
183					pk := s.PublicKey()
184					ak := backend.MarshalAuthorizedKey(pk)
185					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
186					// git bare repositories should end in ".git"
187					// https://git-scm.com/docs/gitrepository-layout
188					repo := name + ".git"
189					reposDir := filepath.Join(cfg.DataPath, "repos")
190					if err := ensureWithin(reposDir, repo); err != nil {
191						sshFatal(s, err)
192						return
193					}
194
195					logger.Debug("git middleware", "cmd", gc, "access", access.String())
196					repoDir := filepath.Join(reposDir, repo)
197					switch gc {
198					case receivePackBin:
199						if access < backend.ReadWriteAccess {
200							sshFatal(s, ErrNotAuthed)
201							return
202						}
203						if _, err := cfg.Backend.Repository(name); err != nil {
204							if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
205								log.Errorf("failed to create repo: %s", err)
206								sshFatal(s, err)
207								return
208							}
209							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
210						}
211						if err := receivePack(s, s, s.Stderr(), repoDir); err != nil {
212							sshFatal(s, ErrSystemMalfunction)
213						}
214						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
215						return
216					case uploadPackBin, uploadArchiveBin:
217						if access < backend.ReadOnlyAccess {
218							sshFatal(s, ErrNotAuthed)
219							return
220						}
221
222						gitPack := uploadPack
223						counter := uploadPackCounter
224						if gc == uploadArchiveBin {
225							gitPack = uploadArchive
226							counter = uploadArchiveCounter
227						}
228
229						err := gitPack(s, s, s.Stderr(), repoDir)
230						if errors.Is(err, ErrInvalidRepo) {
231							sshFatal(s, ErrInvalidRepo)
232						} else if err != nil {
233							sshFatal(s, ErrSystemMalfunction)
234						}
235
236						counter.WithLabelValues(ak, s.User(), name).Inc()
237					}
238				}
239			}()
240			sh(s)
241		}
242	}
243}
244
// sshFatal reports v to the client and ends the session with exit status 1.
// The values are written to the session through writePktline, i.e. as a git
// response on the session's STDOUT.
func sshFatal(s ssh.Session, v ...interface{}) {
	writePktline(s, v...)
	// The error from Exit is deliberately ignored.
	s.Exit(1) // nolint: errcheck
}