ssh.go

  1package ssh
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"net"
  9	"os"
 10	"os/exec"
 11	"path/filepath"
 12	"strconv"
 13	"strings"
 14	"syscall"
 15	"time"
 16	"unsafe"
 17
 18	"github.com/charmbracelet/keygen"
 19	"github.com/charmbracelet/log"
 20	"github.com/charmbracelet/soft-serve/server/access"
 21	"github.com/charmbracelet/soft-serve/server/auth"
 22	"github.com/charmbracelet/soft-serve/server/backend"
 23	"github.com/creack/pty"
 24
 25	// cm "github.com/charmbracelet/soft-serve/server/cmd"
 26
 27	"github.com/charmbracelet/soft-serve/server/config"
 28	"github.com/charmbracelet/soft-serve/server/git"
 29	"github.com/charmbracelet/soft-serve/server/sshutils"
 30	"github.com/charmbracelet/soft-serve/server/store"
 31	"github.com/charmbracelet/soft-serve/server/utils"
 32	"github.com/charmbracelet/ssh"
 33	"github.com/charmbracelet/wish"
 34	lm "github.com/charmbracelet/wish/logging"
 35	rm "github.com/charmbracelet/wish/recover"
 36	"github.com/prometheus/client_golang/prometheus"
 37	"github.com/prometheus/client_golang/prometheus/promauto"
 38	gossh "golang.org/x/crypto/ssh"
 39)
 40
 41var (
 42	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 43		Namespace: "soft_serve",
 44		Subsystem: "ssh",
 45		Name:      "public_key_auth_total",
 46		Help:      "The total number of public key auth requests",
 47	}, []string{"key", "user", "allowed"})
 48
 49	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 50		Namespace: "soft_serve",
 51		Subsystem: "ssh",
 52		Name:      "keyboard_interactive_auth_total",
 53		Help:      "The total number of keyboard interactive auth requests",
 54	}, []string{"user", "allowed"})
 55
 56	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 57		Namespace: "soft_serve",
 58		Subsystem: "ssh",
 59		Name:      "git_upload_pack_total",
 60		Help:      "The total number of git-upload-pack requests",
 61	}, []string{"key", "user", "repo"})
 62
 63	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 64		Namespace: "soft_serve",
 65		Subsystem: "ssh",
 66		Name:      "git_receive_pack_total",
 67		Help:      "The total number of git-receive-pack requests",
 68	}, []string{"key", "user", "repo"})
 69
 70	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 71		Namespace: "soft_serve",
 72		Subsystem: "ssh",
 73		Name:      "git_upload_archive_total",
 74		Help:      "The total number of git-upload-archive requests",
 75	}, []string{"key", "user", "repo"})
 76
 77	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
 78		Namespace: "soft_serve",
 79		Subsystem: "ssh",
 80		Name:      "create_repo_total",
 81		Help:      "The total number of create repo requests",
 82	}, []string{"key", "user", "repo"})
 83)
 84
 85// SSHServer is a SSH server that implements the git protocol.
 86type SSHServer struct {
 87	srv    *ssh.Server
 88	cfg    *config.Config
 89	be     *backend.Backend
 90	ctx    context.Context
 91	logger *log.Logger
 92}
 93
 94func setWinsize(f *os.File, w, h int) {
 95	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
 96		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
 97}
 98
 99// NewSSHServer returns a new SSHServer.
100func NewSSHServer(ctx context.Context) (*SSHServer, error) {
101	cfg := config.FromContext(ctx)
102	logger := log.FromContext(ctx).WithPrefix("ssh")
103
104	var err error
105	s := &SSHServer{
106		cfg:    cfg,
107		ctx:    ctx,
108		be:     backend.FromContext(ctx),
109		logger: logger,
110	}
111
112	mw := []wish.Middleware{
113		rm.MiddlewareWithLogger(
114			logger,
115			// BubbleTea middleware.
116			// bm.MiddlewareWithProgramHandler(SessionHandler(ctx), termenv.ANSI256),
117			// CLI middleware.
118			// cm.Middleware(ctx, logger),
119			// Git middleware.
120			// s.Middleware(cfg),
121			func(h ssh.Handler) ssh.Handler {
122				return func(s ssh.Session) {
123					ptyReq, winCh, isPty := s.Pty()
124					cmds := s.Command()
125
126					exe, err := os.Executable()
127					if err != nil {
128						s.Exit(1)
129						return
130					}
131
132					cmd := exec.Command(exe, cmds...)
133					if isPty {
134						cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
135					}
136					cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_ORIGINAL_COMMAND=%s", strings.Join(cmds, " ")))
137					cmd.Env = append(cmd.Env, cfg.Environ()...)
138
139					ptyf, tty, err := pty.Open()
140					if err != nil {
141						os.Exit(1)
142						return
143					}
144					defer tty.Close()
145
146					cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", tty.Name()))
147
148					if cmd.Stdout == nil {
149						cmd.Stdout = tty
150					}
151					if cmd.Stderr == nil {
152						cmd.Stderr = tty
153					}
154					if cmd.Stdin == nil {
155						cmd.Stdin = tty
156					}
157
158					cmd.SysProcAttr = &syscall.SysProcAttr{
159						Setsid:  true,
160						Setctty: true,
161					}
162
163					if err := cmd.Start(); err != nil {
164						_ = ptyf.Close()
165						os.Exit(1)
166						return
167					}
168					go func() {
169						for win := range winCh {
170							setWinsize(ptyf, win.Width, win.Height)
171						}
172					}()
173					go func() {
174						io.Copy(ptyf, s) // stdin
175					}()
176					io.Copy(s, ptyf) // stdout
177
178					cmd.Wait()
179					h(s)
180				}
181			},
182			// Logging middleware.
183			lm.MiddlewareWithLogger(logger.
184				StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
185		),
186	}
187
188	s.srv, err = wish.NewServer(
189		ssh.PublicKeyAuth(s.PublicKeyHandler),
190		ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
191		wish.WithAddress(cfg.SSH.ListenAddr),
192		wish.WithHostKeyPath(cfg.SSH.KeyPath),
193		wish.WithMiddleware(mw...),
194	)
195	if err != nil {
196		return nil, err
197	}
198
199	if cfg.SSH.MaxTimeout > 0 {
200		s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
201	}
202
203	if cfg.SSH.IdleTimeout > 0 {
204		s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
205	}
206
207	// Create client ssh key
208	if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
209		_, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
210		if err != nil {
211			return nil, fmt.Errorf("client ssh key: %w", err)
212		}
213	}
214
215	return s, nil
216}
217
218// ListenAndServe starts the SSH server.
219func (s *SSHServer) ListenAndServe() error {
220	return s.srv.ListenAndServe()
221}
222
223// Serve starts the SSH server on the given net.Listener.
224func (s *SSHServer) Serve(l net.Listener) error {
225	return s.srv.Serve(l)
226}
227
228// Close closes the SSH server.
229func (s *SSHServer) Close() error {
230	return s.srv.Close()
231}
232
233// Shutdown gracefully shuts down the SSH server.
234func (s *SSHServer) Shutdown(ctx context.Context) error {
235	return s.srv.Shutdown(ctx)
236}
237
238// PublicKeyAuthHandler handles public key authentication.
239func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
240	ctx.SetValue(config.ContextKeyConfig, s.cfg)
241	ctx.SetValue(ssh.ContextKeyPublicKey, pk)
242
243	if pk == nil {
244		return false
245	}
246
247	var ac access.AccessLevel
248	var user auth.User
249	ak := sshutils.MarshalAuthorizedKey(pk)
250
251	defer func(allowed *bool) {
252		publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
253		s.logger.Debugf("access level for %q: %s", ak, ac)
254		ctx.SetValue(auth.ContextKeyUser, user)
255	}(&allowed)
256
257	user, _ = s.be.Authenticate(ctx, auth.NewPublicKey(pk))
258	ac, _ = s.be.AccessLevel(ctx, "", user)
259	allowed = ac >= access.ReadWriteAccess
260	return
261}
262
263// KeyboardInteractiveHandler handles keyboard interactive authentication.
264// This is used after all public key authentication has failed.
265func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
266	ctx.SetValue(config.ContextKeyConfig, s.cfg)
267	ac := s.be.AllowKeyless(ctx)
268	keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
269	return ac
270}
271
272// Middleware adds Git server functionality to the ssh.Server. Repos are stored
273// in the specified repo directory. The provided Hooks implementation will be
274// checked for access on a per repo basis for a ssh.Session public key.
275// Hooks.Push and Hooks.Fetch will be called on successful completion of
276// their commands.
277func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
278	return func(sh ssh.Handler) ssh.Handler {
279		return func(s ssh.Session) {
280			func() {
281				cmdLine := s.Command()
282				ctx := s.Context()
283
284				if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
285					// repo should be in the form of "repo.git"
286					name := utils.SanitizeRepo(cmdLine[1])
287					pk := s.PublicKey()
288					ak := sshutils.MarshalAuthorizedKey(pk)
289					user, _ := ss.be.Authenticate(ctx, auth.NewPublicKey(pk))
290					ac, _ := ss.be.AccessLevel(ctx, name, user)
291
292					// git bare repositories should end in ".git"
293					// https://git-scm.com/docs/gitrepository-layout
294					repo := name + ".git"
295					reposDir := filepath.Join(cfg.DataPath, "repos")
296					if err := git.EnsureWithin(reposDir, repo); err != nil {
297						sshFatal(s, err)
298						return
299					}
300
301					// Environment variables to pass down to git hooks.
302					envs := []string{
303						"SOFT_SERVE_REPO_NAME=" + name,
304						"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
305						"SOFT_SERVE_PUBLIC_KEY=" + ak,
306						"SOFT_SERVE_USERNAME=" + ctx.User(),
307					}
308
309					// Add ssh session & config environ
310					envs = append(envs, s.Environ()...)
311					envs = append(envs, cfg.Environ()...)
312
313					repoDir := filepath.Join(reposDir, repo)
314					service := git.Service(cmdLine[0])
315					cmd := git.ServiceCommand{
316						Stdin:  s,
317						Stdout: s,
318						Stderr: s.Stderr(),
319						Env:    envs,
320						Dir:    repoDir,
321					}
322
323					ss.logger.Debug("git middleware", "cmd", service, "access", ac.String())
324
325					switch service {
326					case git.ReceivePackService:
327						if ac < access.ReadWriteAccess {
328							sshFatal(s, git.ErrUnauthorized)
329							return
330						}
331						if _, err := ss.be.Repository(ctx, name); err != nil {
332							if _, err := ss.be.CreateRepository(ctx, name, store.RepositoryOptions{Private: false}); err != nil {
333								log.Errorf("failed to create repo: %s", err)
334								sshFatal(s, err)
335								return
336							}
337
338							createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
339						}
340
341						if err := git.ReceivePack(ctx, cmd); err != nil {
342							sshFatal(s, git.ErrSystemMalfunction)
343						}
344
345						if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
346							sshFatal(s, git.ErrSystemMalfunction)
347						}
348
349						receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
350						return
351					case git.UploadPackService, git.UploadArchiveService:
352						if ac < access.ReadOnlyAccess {
353							sshFatal(s, git.ErrUnauthorized)
354							return
355						}
356
357						handler := git.UploadPack
358						counter := uploadPackCounter
359						if service == git.UploadArchiveService {
360							handler = git.UploadArchive
361							counter = uploadArchiveCounter
362						}
363
364						err := handler(ctx, cmd)
365						if errors.Is(err, git.ErrNotExist) {
366							sshFatal(s, git.ErrNotExist)
367						} else if err != nil {
368							sshFatal(s, git.ErrSystemMalfunction)
369						}
370
371						counter.WithLabelValues(ak, s.User(), name).Inc()
372					}
373				}
374			}()
375			sh(s)
376		}
377	}
378}
379
380// sshFatal prints to the session's STDOUT as a git response and exit 1.
381func sshFatal(s ssh.Session, v ...interface{}) {
382	git.WritePktline(s, v...)
383	s.Exit(1) // nolint: errcheck
384}