ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/keygen"
 15	"github.com/charmbracelet/log"
 16	"github.com/charmbracelet/soft-serve/server/backend"
 17	cm "github.com/charmbracelet/soft-serve/server/cmd"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/utils"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 34		Namespace: "soft_serve",
 35		Subsystem: "ssh",
 36		Name:      "public_key_auth_total",
 37		Help:      "The total number of public key auth requests",
 38	}, []string{"key", "user", "allowed"})
 39
 40	noAuthCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 41		Namespace: "soft_serve",
 42		Subsystem: "ssh",
 43		Name:      "none_auth_total",
 44		Help:      "The total number of none auth requests",
 45	}, []string{"user", "allowed"})
 46
 47	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 48		Namespace: "soft_serve",
 49		Subsystem: "ssh",
 50		Name:      "keyboard_interactive_auth_total",
 51		Help:      "The total number of keyboard interactive auth requests",
 52	}, []string{"user", "allowed"})
 53
 54	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 55		Namespace: "soft_serve",
 56		Subsystem: "ssh",
 57		Name:      "git_upload_pack_total",
 58		Help:      "The total number of git-upload-pack requests",
 59	}, []string{"key", "user", "repo"})
 60
 61	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 62		Namespace: "soft_serve",
 63		Subsystem: "ssh",
 64		Name:      "git_receive_pack_total",
 65		Help:      "The total number of git-receive-pack requests",
 66	}, []string{"key", "user", "repo"})
 67
 68	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 69		Namespace: "soft_serve",
 70		Subsystem: "ssh",
 71		Name:      "git_upload_archive_total",
 72		Help:      "The total number of git-upload-archive requests",
 73	}, []string{"key", "user", "repo"})
 74
 75	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 76		Namespace: "soft_serve",
 77		Subsystem: "ssh",
 78		Name:      "create_repo_total",
 79		Help:      "The total number of create repo requests",
 80	}, []string{"key", "user", "repo"})
 81)
 82
 83// SSHServer is a SSH server that implements the git protocol.
 84type SSHServer struct {
 85	srv    *ssh.Server
 86	cfg    *config.Config
 87	be     backend.Backend
 88	ctx    context.Context
 89	logger *log.Logger
 90}
 91
 92// NewSSHServer returns a new SSHServer.
 93func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 94	cfg := config.FromContext(ctx)
 95	logger := log.FromContext(ctx).WithPrefix("ssh")
 96
 97	var err error
 98	s := &SSHServer{
 99		cfg:    cfg,
100		ctx:    ctx,
101		be:     backend.FromContext(ctx),
102		logger: logger,
103	}
104
105	mw := []wish.Middleware{
106		rm.MiddlewareWithLogger(
107			logger,
108			// BubbleTea middleware.
109			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
110			// CLI middleware.
111			cm.Middleware(cfg, logger),
112			// Git middleware.
113			s.Middleware(cfg),
114			// Logging middleware.
115			lm.MiddlewareWithLogger(logger.
116				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
117		),
118	}
119
120	s.srv, err = wish.NewServer(
121		ssh.PublicKeyAuth(s.PublicKeyHandler),
122		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
123		wish.WithAddress(cfg.SSH.ListenAddr),
124		wish.WithHostKeyPath(cfg.SSH.KeyPath),
125		wish.WithMiddleware(mw...),
126	)
127	if err != nil {
128		return nil, err
129	}
130
131	// Custom ServerConfig
132	// which handles "none" auth method
133	s.srv.ServerConfigCallback = s.ServerConfigCallback
134
135	if cfg.SSH.MaxTimeout > 0 {
136		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
137	}
138
139	if cfg.SSH.IdleTimeout > 0 {
140		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
141	}
142
143	// Create client ssh key
144	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
145		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
146		if err != nil {
147			return nil, fmt.Errorf("client ssh key: %w", err)
148		}
149	}
150
151	return s, nil
152}
153
154// ListenAndServe starts the SSH server.
155func (s *SSHServer) ListenAndServe() error {
156	return s.srv.ListenAndServe()
157}
158
159// Serve starts the SSH server on the given net.Listener.
160func (s *SSHServer) Serve(l net.Listener) error {
161	return s.srv.Serve(l)
162}
163
164// Close closes the SSH server.
165func (s *SSHServer) Close() error {
166	return s.srv.Close()
167}
168
169// Shutdown gracefully shuts down the SSH server.
170func (s *SSHServer) Shutdown(ctx context.Context) error {
171	return s.srv.Shutdown(ctx)
172}
173
174// PublicKeyAuthHandler handles public key authentication.
175func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
176	if pk == nil {
177		return false
178	}
179
180	ak := backend.MarshalAuthorizedKey(pk)
181	defer func(allowed *bool) {
182		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
183	}(&allowed)
184
185	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
186	s.logger.Debugf("access level for %q: %s", ak, ac)
187	allowed = ac >= backend.ReadWriteAccess
188	return
189}
190
191// KeyboardInteractiveHandler handles keyboard interactive authentication.
192// This is used after all public key authentication has failed.
193func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
194	ac := s.cfg.Backend.AllowKeyless()
195	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
196	return ac
197}
198
199// NoClientAuthCallback handles "none" auth method.
200func (s *SSHServer) NoClientAuthCallback(meta gossh.ConnMetadata) (*gossh.Permissions, error) {
201	ac := s.cfg.Backend.AllowKeyless()
202	keyboardInteractiveCounter.WithLabelValues(meta.User(), strconv.FormatBool(ac)).Inc()
203	perms := &gossh.Permissions{}
204	if !ac {
205		return perms, fmt.Errorf("keyless access not allowed")
206	}
207
208	return perms, nil
209}
210
211// ServerConfigCallback returns a custom ssh.ServerConfig.
212func (s *SSHServer) ServerConfigCallback(ctx ssh.Context) *gossh.ServerConfig {
213	return &gossh.ServerConfig{
214		NoClientAuth:         true,
215		NoClientAuthCallback: s.NoClientAuthCallback,
216	}
217}
218
219// Middleware adds Git server functionality to the ssh.Server. Repos are stored
220// in the specified repo directory. The provided Hooks implementation will be
221// checked for access on a per repo basis for a ssh.Session public key.
222// Hooks.Push and Hooks.Fetch will be called on successful completion of
223// their commands.
224func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
225	return func(sh ssh.Handler) ssh.Handler {
226		return func(s ssh.Session) {
227			func() {
228				cmd := s.Command()
229				ctx := s.Context()
230				be := ss.be.WithContext(ctx)
231				if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
232					gc := cmd[0]
233					// repo should be in the form of "repo.git"
234					name := utils.SanitizeRepo(cmd[1])
235					pk := s.PublicKey()
236					ak := backend.MarshalAuthorizedKey(pk)
237					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
238					// git bare repositories should end in ".git"
239					// https://git-scm.com/docs/gitrepository-layout
240					repo := name + ".git"
241					reposDir := filepath.Join(cfg.DataPath, "repos")
242					if err := git.EnsureWithin(reposDir, repo); err != nil {
243						sshFatal(s, err)
244						return
245					}
246
247					// Environment variables to pass down to git hooks.
248					envs := []string{
249						"SOFT_SERVE_REPO_NAME=" + name,
250						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
251						"SOFT_SERVE_PUBLIC_KEY=" + ak,
252					}
253
254					ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
255					repoDir := filepath.Join(reposDir, repo)
256					switch gc {
257					case git.ReceivePackBin:
258						if access < backend.ReadWriteAccess {
259							sshFatal(s, git.ErrNotAuthed)
260							return
261						}
262						if _, err := be.Repository(name); err != nil {
263							if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
264								log.Errorf("failed to create repo: %s", err)
265								sshFatal(s, err)
266								return
267							}
268							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
269						}
270						if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
271							sshFatal(s, git.ErrSystemMalfunction)
272						}
273						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
274						return
275					case git.UploadPackBin, git.UploadArchiveBin:
276						if access < backend.ReadOnlyAccess {
277							sshFatal(s, git.ErrNotAuthed)
278							return
279						}
280
281						gitPack := git.UploadPack
282						counter := uploadPackCounter
283						if gc == git.UploadArchiveBin {
284							gitPack = git.UploadArchive
285							counter = uploadArchiveCounter
286						}
287
288						err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
289						if errors.Is(err, git.ErrInvalidRepo) {
290							sshFatal(s, git.ErrInvalidRepo)
291						} else if err != nil {
292							sshFatal(s, git.ErrSystemMalfunction)
293						}
294
295						counter.WithLabelValues(ak, s.User(), name).Inc()
296					}
297				}
298			}()
299			sh(s)
300		}
301	}
302}
303
304// sshFatal prints to the session's STDOUT as a git response and exit 1.
305func sshFatal(s ssh.Session, v ...interface{}) {
306	git.WritePktline(s, v...)
307	s.Exit(1) // nolint: errcheck
308}