ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net"
  8	"os"
  9	"path/filepath"
 10	"strconv"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/keygen"
 15	"github.com/charmbracelet/log"
 16	"github.com/charmbracelet/soft-serve/server/backend"
 17	cm "github.com/charmbracelet/soft-serve/server/cmd"
 18	"github.com/charmbracelet/soft-serve/server/config"
 19	"github.com/charmbracelet/soft-serve/server/git"
 20	"github.com/charmbracelet/soft-serve/server/utils"
 21	"github.com/charmbracelet/ssh"
 22	"github.com/charmbracelet/wish"
 23	bm "github.com/charmbracelet/wish/bubbletea"
 24	lm "github.com/charmbracelet/wish/logging"
 25	rm "github.com/charmbracelet/wish/recover"
 26	"github.com/muesli/termenv"
 27	"github.com/prometheus/client_golang/prometheus"
 28	"github.com/prometheus/client_golang/prometheus/promauto"
 29	gossh "golang.org/x/crypto/ssh"
 30)
 31
 32var (
 33	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 34		Namespace: "soft_serve",
 35		Subsystem: "ssh",
 36		Name:      "public_key_auth_total",
 37		Help:      "The total number of public key auth requests",
 38	}, []string{"allowed"})
 39
 40	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 41		Namespace: "soft_serve",
 42		Subsystem: "ssh",
 43		Name:      "keyboard_interactive_auth_total",
 44		Help:      "The total number of keyboard interactive auth requests",
 45	}, []string{"allowed"})
 46
 47	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 48		Namespace: "soft_serve",
 49		Subsystem: "git",
 50		Name:      "upload_pack_total",
 51		Help:      "The total number of git-upload-pack requests",
 52	}, []string{"repo"})
 53
 54	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 55		Namespace: "soft_serve",
 56		Subsystem: "git",
 57		Name:      "receive_pack_total",
 58		Help:      "The total number of git-receive-pack requests",
 59	}, []string{"repo"})
 60
 61	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 62		Namespace: "soft_serve",
 63		Subsystem: "git",
 64		Name:      "upload_archive_total",
 65		Help:      "The total number of git-upload-archive requests",
 66	}, []string{"repo"})
 67
 68	uploadPackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 69		Namespace: "soft_serve",
 70		Subsystem: "git",
 71		Name:      "upload_pack_seconds_total",
 72		Help:      "The total time spent on git-upload-pack requests",
 73	}, []string{"repo"})
 74
 75	receivePackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 76		Namespace: "soft_serve",
 77		Subsystem: "git",
 78		Name:      "receive_pack_seconds_total",
 79		Help:      "The total time spent on git-receive-pack requests",
 80	}, []string{"repo"})
 81
 82	uploadArchiveSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
 83		Namespace: "soft_serve",
 84		Subsystem: "git",
 85		Name:      "upload_archive_seconds_total",
 86		Help:      "The total time spent on git-upload-archive requests",
 87	}, []string{"repo"})
 88
 89	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 90		Namespace: "soft_serve",
 91		Subsystem: "ssh",
 92		Name:      "create_repo_total",
 93		Help:      "The total number of create repo requests",
 94	}, []string{"repo"})
 95)
 96
 97// SSHServer is a SSH server that implements the git protocol.
 98type SSHServer struct {
 99	srv    *ssh.Server
100	cfg    *config.Config
101	be     backend.Backend
102	ctx    context.Context
103	logger *log.Logger
104}
105
106// NewSSHServer returns a new SSHServer.
107func NewSSHServer(ctx context.Context) (*SSHServer, error) {
108	cfg := config.FromContext(ctx)
109	logger := log.FromContext(ctx).WithPrefix("ssh")
110
111	var err error
112	s := &SSHServer{
113		cfg:    cfg,
114		ctx:    ctx,
115		be:     backend.FromContext(ctx),
116		logger: logger,
117	}
118
119	mw := []wish.Middleware{
120		rm.MiddlewareWithLogger(
121			logger,
122			// BubbleTea middleware.
123			bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
124			// CLI middleware.
125			cm.Middleware(cfg, logger),
126			// Git middleware.
127			s.Middleware(cfg),
128			// Logging middleware.
129			lm.MiddlewareWithLogger(logger.
130				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
131		),
132	}
133
134	s.srv, err = wish.NewServer(
135		ssh.PublicKeyAuth(s.PublicKeyHandler),
136		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
137		wish.WithAddress(cfg.SSH.ListenAddr),
138		wish.WithHostKeyPath(cfg.SSH.KeyPath),
139		wish.WithMiddleware(mw...),
140	)
141	if err != nil {
142		return nil, err
143	}
144
145	if cfg.SSH.MaxTimeout > 0 {
146		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
147	}
148
149	if cfg.SSH.IdleTimeout > 0 {
150		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
151	}
152
153	// Create client ssh key
154	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
155		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
156		if err != nil {
157			return nil, fmt.Errorf("client ssh key: %w", err)
158		}
159	}
160
161	return s, nil
162}
163
164// ListenAndServe starts the SSH server.
165func (s *SSHServer) ListenAndServe() error {
166	return s.srv.ListenAndServe()
167}
168
169// Serve starts the SSH server on the given net.Listener.
170func (s *SSHServer) Serve(l net.Listener) error {
171	return s.srv.Serve(l)
172}
173
174// Close closes the SSH server.
175func (s *SSHServer) Close() error {
176	return s.srv.Close()
177}
178
179// Shutdown gracefully shuts down the SSH server.
180func (s *SSHServer) Shutdown(ctx context.Context) error {
181	return s.srv.Shutdown(ctx)
182}
183
184// PublicKeyAuthHandler handles public key authentication.
185func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
186	if pk == nil {
187		return false
188	}
189
190	ak := backend.MarshalAuthorizedKey(pk)
191	defer func(allowed *bool) {
192		publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
193	}(&allowed)
194
195	ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
196	s.logger.Debugf("access level for %q: %s", ak, ac)
197	allowed = ac >= backend.ReadWriteAccess
198	return
199}
200
201// KeyboardInteractiveHandler handles keyboard interactive authentication.
202// This is used after all public key authentication has failed.
203func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
204	ac := s.cfg.Backend.AllowKeyless()
205	keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
206	return ac
207}
208
209// Middleware adds Git server functionality to the ssh.Server. Repos are stored
210// in the specified repo directory. The provided Hooks implementation will be
211// checked for access on a per repo basis for a ssh.Session public key.
212// Hooks.Push and Hooks.Fetch will be called on successful completion of
213// their commands.
214func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
215	return func(sh ssh.Handler) ssh.Handler {
216		return func(s ssh.Session) {
217			func() {
218				start := time.Now()
219				cmdLine := s.Command()
220				ctx := s.Context()
221				be := ss.be.WithContext(ctx)
222
223				if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
224					// repo should be in the form of "repo.git"
225					name := utils.SanitizeRepo(cmdLine[1])
226					pk := s.PublicKey()
227					ak := backend.MarshalAuthorizedKey(pk)
228					access := cfg.Backend.AccessLevelByPublicKey(name, pk)
229					// git bare repositories should end in ".git"
230					// https://git-scm.com/docs/gitrepository-layout
231					repo := name + ".git"
232					reposDir := filepath.Join(cfg.DataPath, "repos")
233					if err := git.EnsureWithin(reposDir, repo); err != nil {
234						sshFatal(s, err)
235						return
236					}
237
238					// Environment variables to pass down to git hooks.
239					envs := []string{
240						"SOFT_SERVE_REPO_NAME=" + name,
241						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
242						"SOFT_SERVE_PUBLIC_KEY=" + ak,
243						"SOFT_SERVE_USERNAME=" + ctx.User(),
244					}
245
246					// Add ssh session & config environ
247					envs = append(envs, s.Environ()...)
248					envs = append(envs, cfg.Environ()...)
249
250					repoDir := filepath.Join(reposDir, repo)
251					service := git.Service(cmdLine[0])
252					cmd := git.ServiceCommand{
253						Stdin:  s,
254						Stdout: s,
255						Stderr: s.Stderr(),
256						Env:    envs,
257						Dir:    repoDir,
258					}
259
260					ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
261
262					switch service {
263					case git.ReceivePackService:
264						receivePackCounter.WithLabelValues(name).Inc()
265						defer func() {
266							receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
267						}()
268						if access < backend.ReadWriteAccess {
269							sshFatal(s, git.ErrNotAuthed)
270							return
271						}
272						if _, err := be.Repository(name); err != nil {
273							if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
274								log.Errorf("failed to create repo: %s", err)
275								sshFatal(s, err)
276								return
277							}
278							createRepoCounter.WithLabelValues(name).Inc()
279						}
280
281						if err := git.ReceivePack(ctx, cmd); err != nil {
282							sshFatal(s, git.ErrSystemMalfunction)
283						}
284
285						if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
286							sshFatal(s, git.ErrSystemMalfunction)
287						}
288
289						receivePackCounter.WithLabelValues(name).Inc()
290						return
291					case git.UploadPackService, git.UploadArchiveService:
292						if access < backend.ReadOnlyAccess {
293							sshFatal(s, git.ErrNotAuthed)
294							return
295						}
296
297						handler := git.UploadPack
298						switch service {
299						case git.UploadArchiveService:
300							handler = git.UploadArchive
301							uploadArchiveCounter.WithLabelValues(name).Inc()
302							defer func() {
303								uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
304							}()
305						default:
306							uploadPackCounter.WithLabelValues(name).Inc()
307							defer func() {
308								uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
309							}()
310						}
311
312						err := handler(ctx, cmd)
313						if errors.Is(err, git.ErrInvalidRepo) {
314							sshFatal(s, git.ErrInvalidRepo)
315						} else if err != nil {
316							sshFatal(s, git.ErrSystemMalfunction)
317						}
318
319					}
320				}
321			}()
322			sh(s)
323		}
324	}
325}
326
327// sshFatal prints to the session's STDOUT as a git response and exit 1.
328func sshFatal(s ssh.Session, v ...interface{}) {
329	git.WritePktline(s, v...)
330	s.Exit(1) // nolint: errcheck
331}