ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/keygen"
 15	"github.com/charmbracelet/log"
 16	"github.com/charmbracelet/soft-serve/server/backend"
 17	cm "github.com/charmbracelet/soft-serve/server/cmd"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/utils"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 34		Namespace: "soft_serve",
 35		Subsystem: "ssh",
 36		Name:      "public_key_auth_total",
 37		Help:      "The total number of public key auth requests",
 38	}, []string{"key", "user", "allowed"})
 39
 40	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 41		Namespace: "soft_serve",
 42		Subsystem: "ssh",
 43		Name:      "keyboard_interactive_auth_total",
 44		Help:      "The total number of keyboard interactive auth requests",
 45	}, []string{"user", "allowed"})
 46
 47	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 48		Namespace: "soft_serve",
 49		Subsystem: "ssh",
 50		Name:      "git_upload_pack_total",
 51		Help:      "The total number of git-upload-pack requests",
 52	}, []string{"key", "user", "repo"})
 53
 54	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 55		Namespace: "soft_serve",
 56		Subsystem: "ssh",
 57		Name:      "git_receive_pack_total",
 58		Help:      "The total number of git-receive-pack requests",
 59	}, []string{"key", "user", "repo"})
 60
 61	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 62		Namespace: "soft_serve",
 63		Subsystem: "ssh",
 64		Name:      "git_upload_archive_total",
 65		Help:      "The total number of git-upload-archive requests",
 66	}, []string{"key", "user", "repo"})
 67
 68	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 69		Namespace: "soft_serve",
 70		Subsystem: "ssh",
 71		Name:      "create_repo_total",
 72		Help:      "The total number of create repo requests",
 73	}, []string{"key", "user", "repo"})
 74)
 75
 76// SSHServer is a SSH server that implements the git protocol.
 77type SSHServer struct {
 78	srv    *ssh.Server
 79	cfg    *config.Config
 80	be     backend.Backend
 81	ctx    context.Context
 82	logger *log.Logger
 83}
 84
 85// NewSSHServer returns a new SSHServer.
 86func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 87	cfg := config.FromContext(ctx)
 88	logger := log.FromContext(ctx).WithPrefix("ssh")
 89
 90	var err error
 91	s := &SSHServer{
 92		cfg:    cfg,
 93		ctx:    ctx,
 94		be:     backend.FromContext(ctx),
 95		logger: logger,
 96	}
 97
 98	mw := []wish.Middleware{
 99		rm.MiddlewareWithLogger(
100			logger,
101			// BubbleTea middleware.
102			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
103			// CLI middleware.
104			cm.Middleware(cfg, logger),
105			// Git middleware.
106			s.Middleware(cfg),
107			// Logging middleware.
108			lm.MiddlewareWithLogger(logger.
109				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
110		),
111	}
112
113	s.srv, err = wish.NewServer(
114		ssh.PublicKeyAuth(s.PublicKeyHandler),
115		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
116		wish.WithAddress(cfg.SSH.ListenAddr),
117		wish.WithHostKeyPath(cfg.SSH.KeyPath),
118		wish.WithMiddleware(mw...),
119	)
120	if err != nil {
121		return nil, err
122	}
123
124	if cfg.SSH.MaxTimeout > 0 {
125		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
126	}
127
128	if cfg.SSH.IdleTimeout > 0 {
129		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
130	}
131
132	// Create client ssh key
133	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
134		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
135		if err != nil {
136			return nil, fmt.Errorf("client ssh key: %w", err)
137		}
138	}
139
140	return s, nil
141}
142
143// ListenAndServe starts the SSH server.
144func (s *SSHServer) ListenAndServe() error {
145	return s.srv.ListenAndServe()
146}
147
148// Serve starts the SSH server on the given net.Listener.
149func (s *SSHServer) Serve(l net.Listener) error {
150	return s.srv.Serve(l)
151}
152
153// Close closes the SSH server.
154func (s *SSHServer) Close() error {
155	return s.srv.Close()
156}
157
158// Shutdown gracefully shuts down the SSH server.
159func (s *SSHServer) Shutdown(ctx context.Context) error {
160	return s.srv.Shutdown(ctx)
161}
162
163// PublicKeyAuthHandler handles public key authentication.
164func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
165	if pk == nil {
166		return false
167	}
168
169	ak := backend.MarshalAuthorizedKey(pk)
170	defer func(allowed *bool) {
171		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
172	}(&allowed)
173
174	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
175	s.logger.Debugf("access level for %q: %s", ak, ac)
176	allowed = ac >= backend.ReadWriteAccess
177	return
178}
179
180// KeyboardInteractiveHandler handles keyboard interactive authentication.
181// This is used after all public key authentication has failed.
182func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
183	ac := s.cfg.Backend.AllowKeyless()
184	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
185	return ac
186}
187
188// Middleware adds Git server functionality to the ssh.Server. Repos are stored
189// in the specified repo directory. The provided Hooks implementation will be
190// checked for access on a per repo basis for a ssh.Session public key.
191// Hooks.Push and Hooks.Fetch will be called on successful completion of
192// their commands.
193func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
194	return func(sh ssh.Handler) ssh.Handler {
195		return func(s ssh.Session) {
196			func() {
197				cmdLine := s.Command()
198				ctx := s.Context()
199				be := ss.be.WithContext(ctx)
200
201				if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
202					// repo should be in the form of "repo.git"
203					name := utils.SanitizeRepo(cmdLine[1])
204					pk := s.PublicKey()
205					ak := backend.MarshalAuthorizedKey(pk)
206					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
207					// git bare repositories should end in ".git"
208					// https://git-scm.com/docs/gitrepository-layout
209					repo := name + ".git"
210					reposDir := filepath.Join(cfg.DataPath, "repos")
211					if err := git.EnsureWithin(reposDir, repo); err != nil {
212						sshFatal(s, err)
213						return
214					}
215
216					// Environment variables to pass down to git hooks.
217					envs := []string{
218						"SOFT_SERVE_REPO_NAME=" + name,
219						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
220						"SOFT_SERVE_PUBLIC_KEY=" + ak,
221						"SOFT_SERVE_USERNAME=" + ctx.User(),
222					}
223
224					// Add ssh session & config environ
225					envs = append(envs, s.Environ()...)
226					envs = append(envs, cfg.Environ()...)
227
228					repoDir := filepath.Join(reposDir, repo)
229					service := git.Service(cmdLine[0])
230					cmd := git.ServiceCommand{
231						Stdin:  s,
232						Stdout: s,
233						Stderr: s.Stderr(),
234						Env:    envs,
235						Dir:    repoDir,
236					}
237
238					ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
239
240					switch service {
241					case git.ReceivePackService:
242						if access < backend.ReadWriteAccess {
243							sshFatal(s, git.ErrNotAuthed)
244							return
245						}
246						if _, err := be.Repository(name); err != nil {
247							if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
248								log.Errorf("failed to create repo: %s", err)
249								sshFatal(s, err)
250								return
251							}
252
253							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
254						}
255
256						if err := git.ReceivePack(ctx, cmd); err != nil {
257							sshFatal(s, git.ErrSystemMalfunction)
258						}
259
260						if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
261							sshFatal(s, git.ErrSystemMalfunction)
262						}
263
264						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
265						return
266					case git.UploadPackService, git.UploadArchiveService:
267						if access < backend.ReadOnlyAccess {
268							sshFatal(s, git.ErrNotAuthed)
269							return
270						}
271
272						handler := git.UploadPack
273						counter := uploadPackCounter
274						if service == git.UploadArchiveService {
275							handler = git.UploadArchive
276							counter = uploadArchiveCounter
277						}
278
279						err := handler(ctx, cmd)
280						if errors.Is(err, git.ErrInvalidRepo) {
281							sshFatal(s, git.ErrInvalidRepo)
282						} else if err != nil {
283							sshFatal(s, git.ErrSystemMalfunction)
284						}
285
286						counter.WithLabelValues(ak, s.User(), name).Inc()
287					}
288				}
289			}()
290			sh(s)
291		}
292	}
293}
294
295// sshFatal prints to the session's STDOUT as a git response and exit 1.
296func sshFatal(s ssh.Session, v ...interface{}) {
297	git.WritePktline(s, v...)
298	s.Exit(1) // nolint: errcheck
299}