1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "strconv"
9 "time"
10
11 "github.com/charmbracelet/keygen"
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/server/backend"
14 "github.com/charmbracelet/soft-serve/server/config"
15 "github.com/charmbracelet/soft-serve/server/db"
16 "github.com/charmbracelet/soft-serve/server/proto"
17 "github.com/charmbracelet/soft-serve/server/store"
18 "github.com/charmbracelet/ssh"
19 "github.com/charmbracelet/wish"
20 bm "github.com/charmbracelet/wish/bubbletea"
21 rm "github.com/charmbracelet/wish/recover"
22 "github.com/muesli/termenv"
23 "github.com/prometheus/client_golang/prometheus"
24 "github.com/prometheus/client_golang/prometheus/promauto"
25 gossh "golang.org/x/crypto/ssh"
26)
27
28var (
29 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
30 Namespace: "soft_serve",
31 Subsystem: "ssh",
32 Name: "public_key_auth_total",
33 Help: "The total number of public key auth requests",
34 }, []string{"allowed"})
35
36 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
37 Namespace: "soft_serve",
38 Subsystem: "ssh",
39 Name: "keyboard_interactive_auth_total",
40 Help: "The total number of keyboard interactive auth requests",
41 }, []string{"allowed"})
42)
43
// SSHServer is an SSH server that implements the git protocol.
//
// Instances are built by NewSSHServer, which fills every field.
// NOTE(review): ctx is the context passed to NewSSHServer; Go guidelines
// discourage storing a context in a struct — confirm how it is used elsewhere.
type SSHServer struct { // nolint: revive
	srv    *ssh.Server      // server created by wish.NewServer in NewSSHServer
	cfg    *config.Config   // configuration taken from the construction context
	be     *backend.Backend // backend used for user and keyless-access lookups
	ctx    context.Context  // context NewSSHServer was called with
	logger *log.Logger      // logger carrying the "ssh" prefix
}
52
53// NewSSHServer returns a new SSHServer.
54func NewSSHServer(ctx context.Context) (*SSHServer, error) {
55 cfg := config.FromContext(ctx)
56 logger := log.FromContext(ctx).WithPrefix("ssh")
57 dbx := db.FromContext(ctx)
58 datastore := store.FromContext(ctx)
59 be := backend.FromContext(ctx)
60
61 var err error
62 s := &SSHServer{
63 cfg: cfg,
64 ctx: ctx,
65 be: be,
66 logger: logger,
67 }
68
69 mw := []wish.Middleware{
70 rm.MiddlewareWithLogger(
71 logger,
72 // BubbleTea middleware.
73 bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
74 // CLI middleware.
75 CommandMiddleware,
76 // Logging middleware.
77 LoggingMiddleware,
78 // Context middleware.
79 ContextMiddleware(cfg, dbx, datastore, be, logger),
80 ),
81 }
82
83 s.srv, err = wish.NewServer(
84 ssh.PublicKeyAuth(s.PublicKeyHandler),
85 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
86 wish.WithAddress(cfg.SSH.ListenAddr),
87 wish.WithHostKeyPath(cfg.SSH.KeyPath),
88 wish.WithMiddleware(mw...),
89 )
90 if err != nil {
91 return nil, err
92 }
93
94 if cfg.SSH.MaxTimeout > 0 {
95 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
96 }
97
98 if cfg.SSH.IdleTimeout > 0 {
99 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
100 }
101
102 // Create client ssh key
103 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
104 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
105 if err != nil {
106 return nil, fmt.Errorf("client ssh key: %w", err)
107 }
108 }
109
110 return s, nil
111}
112
113// ListenAndServe starts the SSH server.
114func (s *SSHServer) ListenAndServe() error {
115 return s.srv.ListenAndServe()
116}
117
118// Serve starts the SSH server on the given net.Listener.
119func (s *SSHServer) Serve(l net.Listener) error {
120 return s.srv.Serve(l)
121}
122
123// Close closes the SSH server.
124func (s *SSHServer) Close() error {
125 return s.srv.Close()
126}
127
128// Shutdown gracefully shuts down the SSH server.
129func (s *SSHServer) Shutdown(ctx context.Context) error {
130 return s.srv.Shutdown(ctx)
131}
132
133// PublicKeyAuthHandler handles public key authentication.
134func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
135 if pk == nil {
136 return false
137 }
138
139 defer func(allowed *bool) {
140 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
141 }(&allowed)
142
143 user, _ := s.be.UserByPublicKey(ctx, pk)
144 if user != nil {
145 ctx.SetValue(proto.ContextKeyUser, user)
146 allowed = true
147 }
148
149 return
150}
151
152// KeyboardInteractiveHandler handles keyboard interactive authentication.
153// This is used after all public key authentication has failed.
154func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
155 ac := s.be.AllowKeyless(ctx)
156 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
157 return ac
158}