1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "strconv"
9 "time"
10
11 "github.com/charmbracelet/keygen"
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/server/backend"
14 "github.com/charmbracelet/soft-serve/server/config"
15 "github.com/charmbracelet/soft-serve/server/db"
16 "github.com/charmbracelet/soft-serve/server/proto"
17 "github.com/charmbracelet/soft-serve/server/store"
18 "github.com/charmbracelet/ssh"
19 "github.com/charmbracelet/wish"
20 bm "github.com/charmbracelet/wish/bubbletea"
21 lm "github.com/charmbracelet/wish/logging"
22 rm "github.com/charmbracelet/wish/recover"
23 "github.com/muesli/termenv"
24 "github.com/prometheus/client_golang/prometheus"
25 "github.com/prometheus/client_golang/prometheus/promauto"
26 gossh "golang.org/x/crypto/ssh"
27)
28
29var (
30 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
31 Namespace: "soft_serve",
32 Subsystem: "ssh",
33 Name: "public_key_auth_total",
34 Help: "The total number of public key auth requests",
35 }, []string{"allowed"})
36
37 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
38 Namespace: "soft_serve",
39 Subsystem: "ssh",
40 Name: "keyboard_interactive_auth_total",
41 Help: "The total number of keyboard interactive auth requests",
42 }, []string{"allowed"})
43)
44
// SSHServer is an SSH server that implements the git protocol.
//
// Instances are built by NewSSHServer, which fills every field from the
// context it is given.
type SSHServer struct { // nolint: revive
	srv    *ssh.Server      // underlying server; ListenAndServe, Serve, Close and Shutdown delegate to it
	cfg    *config.Config   // configuration taken from the context passed to NewSSHServer
	be     *backend.Backend // backend consulted by the authentication handlers
	ctx    context.Context  // context passed to NewSSHServer
	logger *log.Logger      // logger taken from the context, prefixed with "ssh"
}
53
54// NewSSHServer returns a new SSHServer.
55func NewSSHServer(ctx context.Context) (*SSHServer, error) {
56 cfg := config.FromContext(ctx)
57 logger := log.FromContext(ctx).WithPrefix("ssh")
58 dbx := db.FromContext(ctx)
59 datastore := store.FromContext(ctx)
60 be := backend.FromContext(ctx)
61
62 var err error
63 s := &SSHServer{
64 cfg: cfg,
65 ctx: ctx,
66 be: be,
67 logger: logger,
68 }
69
70 mw := []wish.Middleware{
71 rm.MiddlewareWithLogger(
72 logger,
73 // BubbleTea middleware.
74 bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
75 // CLI middleware.
76 CommandMiddleware,
77 // Context middleware.
78 ContextMiddleware(cfg, dbx, datastore, be, logger),
79 // Logging middleware.
80 lm.MiddlewareWithLogger(
81 &loggerAdapter{logger, log.DebugLevel},
82 ),
83 ),
84 }
85
86 s.srv, err = wish.NewServer(
87 ssh.PublicKeyAuth(s.PublicKeyHandler),
88 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
89 wish.WithAddress(cfg.SSH.ListenAddr),
90 wish.WithHostKeyPath(cfg.SSH.KeyPath),
91 wish.WithMiddleware(mw...),
92 )
93 if err != nil {
94 return nil, err
95 }
96
97 if cfg.SSH.MaxTimeout > 0 {
98 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
99 }
100
101 if cfg.SSH.IdleTimeout > 0 {
102 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
103 }
104
105 // Create client ssh key
106 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
107 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
108 if err != nil {
109 return nil, fmt.Errorf("client ssh key: %w", err)
110 }
111 }
112
113 return s, nil
114}
115
116// ListenAndServe starts the SSH server.
117func (s *SSHServer) ListenAndServe() error {
118 return s.srv.ListenAndServe()
119}
120
121// Serve starts the SSH server on the given net.Listener.
122func (s *SSHServer) Serve(l net.Listener) error {
123 return s.srv.Serve(l)
124}
125
126// Close closes the SSH server.
127func (s *SSHServer) Close() error {
128 return s.srv.Close()
129}
130
131// Shutdown gracefully shuts down the SSH server.
132func (s *SSHServer) Shutdown(ctx context.Context) error {
133 return s.srv.Shutdown(ctx)
134}
135
136// PublicKeyAuthHandler handles public key authentication.
137func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
138 if pk == nil {
139 return false
140 }
141
142 defer func(allowed *bool) {
143 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
144 }(&allowed)
145
146 user, _ := s.be.UserByPublicKey(ctx, pk)
147 if user != nil {
148 ctx.SetValue(proto.ContextKeyUser, user)
149 allowed = true
150 }
151
152 return
153}
154
155// KeyboardInteractiveHandler handles keyboard interactive authentication.
156// This is used after all public key authentication has failed.
157func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
158 ac := s.be.AllowKeyless(ctx)
159 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
160 return ac
161}