1package ssh
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "net"
9 "os"
10 "os/exec"
11 "path/filepath"
12 "strconv"
13 "strings"
14 "syscall"
15 "time"
16 "unsafe"
17
18 "github.com/charmbracelet/keygen"
19 "github.com/charmbracelet/log"
20 "github.com/charmbracelet/soft-serve/server/access"
21 "github.com/charmbracelet/soft-serve/server/auth"
22 "github.com/charmbracelet/soft-serve/server/backend"
23 "github.com/creack/pty"
24
25 // cm "github.com/charmbracelet/soft-serve/server/cmd"
26
27 "github.com/charmbracelet/soft-serve/server/config"
28 "github.com/charmbracelet/soft-serve/server/git"
29 "github.com/charmbracelet/soft-serve/server/sshutils"
30 "github.com/charmbracelet/soft-serve/server/store"
31 "github.com/charmbracelet/soft-serve/server/utils"
32 "github.com/charmbracelet/ssh"
33 "github.com/charmbracelet/wish"
34 lm "github.com/charmbracelet/wish/logging"
35 rm "github.com/charmbracelet/wish/recover"
36 "github.com/prometheus/client_golang/prometheus"
37 "github.com/prometheus/client_golang/prometheus/promauto"
38 gossh "golang.org/x/crypto/ssh"
39)
40
41var (
42 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
43 Namespace: "soft_serve",
44 Subsystem: "ssh",
45 Name: "public_key_auth_total",
46 Help: "The total number of public key auth requests",
47 }, []string{"key", "user", "allowed"})
48
49 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
50 Namespace: "soft_serve",
51 Subsystem: "ssh",
52 Name: "keyboard_interactive_auth_total",
53 Help: "The total number of keyboard interactive auth requests",
54 }, []string{"user", "allowed"})
55
56 uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
57 Namespace: "soft_serve",
58 Subsystem: "ssh",
59 Name: "git_upload_pack_total",
60 Help: "The total number of git-upload-pack requests",
61 }, []string{"key", "user", "repo"})
62
63 receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
64 Namespace: "soft_serve",
65 Subsystem: "ssh",
66 Name: "git_receive_pack_total",
67 Help: "The total number of git-receive-pack requests",
68 }, []string{"key", "user", "repo"})
69
70 uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
71 Namespace: "soft_serve",
72 Subsystem: "ssh",
73 Name: "git_upload_archive_total",
74 Help: "The total number of git-upload-archive requests",
75 }, []string{"key", "user", "repo"})
76
77 createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
78 Namespace: "soft_serve",
79 Subsystem: "ssh",
80 Name: "create_repo_total",
81 Help: "The total number of create repo requests",
82 }, []string{"key", "user", "repo"})
83)
84
85// SSHServer is a SSH server that implements the git protocol.
86type SSHServer struct {
87 srv *ssh.Server
88 cfg *config.Config
89 be *backend.Backend
90 ctx context.Context
91 logger *log.Logger
92}
93
// setWinsize sets the window size of the terminal referred to by f to w
// columns by h rows using the TIOCSWINSZ ioctl. Any ioctl error is ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order matches the kernel's struct winsize: rows, cols, then pixels.
	type winsize struct{ rows, cols, xpixel, ypixel uint16 }
	ws := winsize{rows: uint16(h), cols: uint16(w)}

	// The unsafe.Pointer -> uintptr conversion must stay inside the Syscall
	// argument list so the referent is kept alive for the duration of the call.
	_, _, _ = syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
98
99// NewSSHServer returns a new SSHServer.
100func NewSSHServer(ctx context.Context) (*SSHServer, error) {
101 cfg := config.FromContext(ctx)
102 logger := log.FromContext(ctx).WithPrefix("ssh")
103
104 var err error
105 s := &SSHServer{
106 cfg: cfg,
107 ctx: ctx,
108 be: backend.FromContext(ctx),
109 logger: logger,
110 }
111
112 mw := []wish.Middleware{
113 rm.MiddlewareWithLogger(
114 logger,
115 // BubbleTea middleware.
116 // bm.MiddlewareWithProgramHandler(SessionHandler(ctx), termenv.ANSI256),
117 // CLI middleware.
118 // cm.Middleware(ctx, logger),
119 // Git middleware.
120 // s.Middleware(cfg),
121 func(h ssh.Handler) ssh.Handler {
122 return func(s ssh.Session) {
123 ptyReq, winCh, isPty := s.Pty()
124 cmds := s.Command()
125
126 exe, err := os.Executable()
127 if err != nil {
128 s.Exit(1)
129 return
130 }
131
132 cmd := exec.Command(exe, cmds...)
133 if isPty {
134 cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
135 }
136 cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_ORIGINAL_COMMAND=%s", strings.Join(cmds, " ")))
137 cmd.Env = append(cmd.Env, cfg.Environ()...)
138
139 ptyf, tty, err := pty.Open()
140 if err != nil {
141 os.Exit(1)
142 return
143 }
144 defer tty.Close()
145
146 cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_TTY=%s", tty.Name()))
147
148 if cmd.Stdout == nil {
149 cmd.Stdout = tty
150 }
151 if cmd.Stderr == nil {
152 cmd.Stderr = tty
153 }
154 if cmd.Stdin == nil {
155 cmd.Stdin = tty
156 }
157
158 cmd.SysProcAttr = &syscall.SysProcAttr{
159 Setsid: true,
160 Setctty: true,
161 }
162
163 if err := cmd.Start(); err != nil {
164 _ = ptyf.Close()
165 os.Exit(1)
166 return
167 }
168 go func() {
169 for win := range winCh {
170 setWinsize(ptyf, win.Width, win.Height)
171 }
172 }()
173 go func() {
174 io.Copy(ptyf, s) // stdin
175 }()
176 io.Copy(s, ptyf) // stdout
177
178 cmd.Wait()
179 h(s)
180 }
181 },
182 // Logging middleware.
183 lm.MiddlewareWithLogger(logger.
184 StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
185 ),
186 }
187
188 s.srv, err = wish.NewServer(
189 ssh.PublicKeyAuth(s.PublicKeyHandler),
190 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
191 wish.WithAddress(cfg.SSH.ListenAddr),
192 wish.WithHostKeyPath(cfg.SSH.KeyPath),
193 wish.WithMiddleware(mw...),
194 )
195 if err != nil {
196 return nil, err
197 }
198
199 if cfg.SSH.MaxTimeout > 0 {
200 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
201 }
202
203 if cfg.SSH.IdleTimeout > 0 {
204 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
205 }
206
207 // Create client ssh key
208 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
209 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
210 if err != nil {
211 return nil, fmt.Errorf("client ssh key: %w", err)
212 }
213 }
214
215 return s, nil
216}
217
218// ListenAndServe starts the SSH server.
219func (s *SSHServer) ListenAndServe() error {
220 return s.srv.ListenAndServe()
221}
222
223// Serve starts the SSH server on the given net.Listener.
224func (s *SSHServer) Serve(l net.Listener) error {
225 return s.srv.Serve(l)
226}
227
228// Close closes the SSH server.
229func (s *SSHServer) Close() error {
230 return s.srv.Close()
231}
232
233// Shutdown gracefully shuts down the SSH server.
234func (s *SSHServer) Shutdown(ctx context.Context) error {
235 return s.srv.Shutdown(ctx)
236}
237
238// PublicKeyAuthHandler handles public key authentication.
239func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
240 ctx.SetValue(config.ContextKeyConfig, s.cfg)
241 ctx.SetValue(ssh.ContextKeyPublicKey, pk)
242
243 if pk == nil {
244 return false
245 }
246
247 var ac access.AccessLevel
248 var user auth.User
249 ak := sshutils.MarshalAuthorizedKey(pk)
250
251 defer func(allowed *bool) {
252 publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
253 s.logger.Debugf("access level for %q: %s", ak, ac)
254 ctx.SetValue(auth.ContextKeyUser, user)
255 }(&allowed)
256
257 user, _ = s.be.Authenticate(ctx, auth.NewPublicKey(pk))
258 ac, _ = s.be.AccessLevel(ctx, "", user)
259 allowed = ac >= access.ReadWriteAccess
260 return
261}
262
263// KeyboardInteractiveHandler handles keyboard interactive authentication.
264// This is used after all public key authentication has failed.
265func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
266 ctx.SetValue(config.ContextKeyConfig, s.cfg)
267 ac := s.be.AllowKeyless(ctx)
268 keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
269 return ac
270}
271
272// Middleware adds Git server functionality to the ssh.Server. Repos are stored
273// in the specified repo directory. The provided Hooks implementation will be
274// checked for access on a per repo basis for a ssh.Session public key.
275// Hooks.Push and Hooks.Fetch will be called on successful completion of
276// their commands.
277func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
278 return func(sh ssh.Handler) ssh.Handler {
279 return func(s ssh.Session) {
280 func() {
281 cmdLine := s.Command()
282 ctx := s.Context()
283
284 if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
285 // repo should be in the form of "repo.git"
286 name := utils.SanitizeRepo(cmdLine[1])
287 pk := s.PublicKey()
288 ak := sshutils.MarshalAuthorizedKey(pk)
289 user, _ := ss.be.Authenticate(ctx, auth.NewPublicKey(pk))
290 ac, _ := ss.be.AccessLevel(ctx, name, user)
291
292 // git bare repositories should end in ".git"
293 // https://git-scm.com/docs/gitrepository-layout
294 repo := name + ".git"
295 reposDir := filepath.Join(cfg.DataPath, "repos")
296 if err := git.EnsureWithin(reposDir, repo); err != nil {
297 sshFatal(s, err)
298 return
299 }
300
301 // Environment variables to pass down to git hooks.
302 envs := []string{
303 "SOFT_SERVE_REPO_NAME=" + name,
304 "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
305 "SOFT_SERVE_PUBLIC_KEY=" + ak,
306 "SOFT_SERVE_USERNAME=" + ctx.User(),
307 }
308
309 // Add ssh session & config environ
310 envs = append(envs, s.Environ()...)
311 envs = append(envs, cfg.Environ()...)
312
313 repoDir := filepath.Join(reposDir, repo)
314 service := git.Service(cmdLine[0])
315 cmd := git.ServiceCommand{
316 Stdin: s,
317 Stdout: s,
318 Stderr: s.Stderr(),
319 Env: envs,
320 Dir: repoDir,
321 }
322
323 ss.logger.Debug("git middleware", "cmd", service, "access", ac.String())
324
325 switch service {
326 case git.ReceivePackService:
327 if ac < access.ReadWriteAccess {
328 sshFatal(s, git.ErrUnauthorized)
329 return
330 }
331 if _, err := ss.be.Repository(ctx, name); err != nil {
332 if _, err := ss.be.CreateRepository(ctx, name, store.RepositoryOptions{Private: false}); err != nil {
333 log.Errorf("failed to create repo: %s", err)
334 sshFatal(s, err)
335 return
336 }
337
338 createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
339 }
340
341 if err := git.ReceivePack(ctx, cmd); err != nil {
342 sshFatal(s, git.ErrSystemMalfunction)
343 }
344
345 if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
346 sshFatal(s, git.ErrSystemMalfunction)
347 }
348
349 receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
350 return
351 case git.UploadPackService, git.UploadArchiveService:
352 if ac < access.ReadOnlyAccess {
353 sshFatal(s, git.ErrUnauthorized)
354 return
355 }
356
357 handler := git.UploadPack
358 counter := uploadPackCounter
359 if service == git.UploadArchiveService {
360 handler = git.UploadArchive
361 counter = uploadArchiveCounter
362 }
363
364 err := handler(ctx, cmd)
365 if errors.Is(err, git.ErrNotExist) {
366 sshFatal(s, git.ErrNotExist)
367 } else if err != nil {
368 sshFatal(s, git.ErrSystemMalfunction)
369 }
370
371 counter.WithLabelValues(ak, s.User(), name).Inc()
372 }
373 }
374 }()
375 sh(s)
376 }
377 }
378}
379
380// sshFatal prints to the session's STDOUT as a git response and exit 1.
381func sshFatal(s ssh.Session, v ...interface{}) {
382 git.WritePktline(s, v...)
383 s.Exit(1) // nolint: errcheck
384}