1package ssh
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "os"
8 "strconv"
9 "time"
10
11 "github.com/charmbracelet/keygen"
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/server/access"
14 "github.com/charmbracelet/soft-serve/server/backend"
15 "github.com/charmbracelet/soft-serve/server/config"
16 "github.com/charmbracelet/soft-serve/server/db"
17 "github.com/charmbracelet/soft-serve/server/git"
18 "github.com/charmbracelet/soft-serve/server/proto"
19 "github.com/charmbracelet/soft-serve/server/sshutils"
20 "github.com/charmbracelet/soft-serve/server/store"
21 "github.com/charmbracelet/ssh"
22 "github.com/charmbracelet/wish"
23 bm "github.com/charmbracelet/wish/bubbletea"
24 lm "github.com/charmbracelet/wish/logging"
25 rm "github.com/charmbracelet/wish/recover"
26 "github.com/muesli/termenv"
27 "github.com/prometheus/client_golang/prometheus"
28 "github.com/prometheus/client_golang/prometheus/promauto"
29 gossh "golang.org/x/crypto/ssh"
30)
31
// Prometheus metrics for the SSH server and the git services it serves.
// Each vector is created with promauto, which registers it as part of
// creation; here that happens during package initialization.
var (
	// publicKeyCounter counts PublicKeyHandler decisions, labelled by the
	// "allowed" outcome ("true"/"false").
	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "public_key_auth_total",
		Help:      "The total number of public key auth requests",
	}, []string{"allowed"})

	// keyboardInteractiveCounter counts KeyboardInteractiveHandler decisions,
	// labelled by the "allowed" outcome ("true"/"false").
	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "keyboard_interactive_auth_total",
		Help:      "The total number of keyboard interactive auth requests",
	}, []string{"allowed"})

	// uploadPackCounter counts git-upload-pack requests per "repo" label.
	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "upload_pack_total",
		Help:      "The total number of git-upload-pack requests",
	}, []string{"repo"})

	// receivePackCounter counts git-receive-pack requests per "repo" label.
	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "receive_pack_total",
		Help:      "The total number of git-receive-pack requests",
	}, []string{"repo"})

	// uploadArchiveCounter counts git-upload-archive requests per "repo" label.
	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "upload_archive_total",
		Help:      "The total number of git-upload-archive requests",
	}, []string{"repo"})

	// uploadPackSeconds accumulates time spent serving git-upload-pack per
	// "repo" label (in seconds, per the metric name; callers are not
	// visible here).
	uploadPackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "upload_pack_seconds_total",
		Help:      "The total time spent on git-upload-pack requests",
	}, []string{"repo"})

	// receivePackSeconds accumulates time spent serving git-receive-pack per
	// "repo" label (in seconds, per the metric name).
	receivePackSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "receive_pack_seconds_total",
		Help:      "The total time spent on git-receive-pack requests",
	}, []string{"repo"})

	// uploadArchiveSeconds accumulates time spent serving git-upload-archive
	// per "repo" label (in seconds, per the metric name).
	uploadArchiveSeconds = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "git",
		Name:      "upload_archive_seconds_total",
		Help:      "The total time spent on git-upload-archive requests",
	}, []string{"repo"})

	// createRepoCounter counts repository-creation requests received over
	// SSH, per "repo" label.
	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "create_repo_total",
		Help:      "The total number of create repo requests",
	}, []string{"repo"})
)
96
// SSHServer is an SSH server that implements the git protocol. Build one
// with NewSSHServer; the zero value has no underlying server.
type SSHServer struct { // nolint: revive
	// srv is the underlying SSH server created by wish.NewServer.
	srv *ssh.Server
	// cfg is the configuration read from the constructor's context.
	cfg *config.Config
	// be is the backend used for user lookup and access decisions.
	be *backend.Backend
	// ctx is the context NewSSHServer was called with.
	// NOTE(review): storing a context in a struct is discouraged in Go;
	// kept here for compatibility with existing callers.
	ctx context.Context
	// logger is the context logger with the "ssh" prefix.
	logger *log.Logger
}
105
106// NewSSHServer returns a new SSHServer.
107func NewSSHServer(ctx context.Context) (*SSHServer, error) {
108 cfg := config.FromContext(ctx)
109 logger := log.FromContext(ctx).WithPrefix("ssh")
110 dbx := db.FromContext(ctx)
111 datastore := store.FromContext(ctx)
112 be := backend.FromContext(ctx)
113
114 var err error
115 s := &SSHServer{
116 cfg: cfg,
117 ctx: ctx,
118 be: be,
119 logger: logger,
120 }
121
122 mw := []wish.Middleware{
123 rm.MiddlewareWithLogger(
124 logger,
125 // BubbleTea middleware.
126 bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
127 // CLI middleware.
128 CommandMiddleware,
129 // Context middleware.
130 ContextMiddleware(cfg, dbx, datastore, be, logger),
131 // Logging middleware.
132 lm.MiddlewareWithLogger(
133 &loggerAdapter{logger, log.DebugLevel},
134 ),
135 ),
136 }
137
138 s.srv, err = wish.NewServer(
139 ssh.PublicKeyAuth(s.PublicKeyHandler),
140 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
141 wish.WithAddress(cfg.SSH.ListenAddr),
142 wish.WithHostKeyPath(cfg.SSH.KeyPath),
143 wish.WithMiddleware(mw...),
144 )
145 if err != nil {
146 return nil, err
147 }
148
149 if cfg.SSH.MaxTimeout > 0 {
150 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
151 }
152
153 if cfg.SSH.IdleTimeout > 0 {
154 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
155 }
156
157 // Create client ssh key
158 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
159 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
160 if err != nil {
161 return nil, fmt.Errorf("client ssh key: %w", err)
162 }
163 }
164
165 return s, nil
166}
167
168// ListenAndServe starts the SSH server.
169func (s *SSHServer) ListenAndServe() error {
170 return s.srv.ListenAndServe()
171}
172
173// Serve starts the SSH server on the given net.Listener.
174func (s *SSHServer) Serve(l net.Listener) error {
175 return s.srv.Serve(l)
176}
177
178// Close closes the SSH server.
179func (s *SSHServer) Close() error {
180 return s.srv.Close()
181}
182
183// Shutdown gracefully shuts down the SSH server.
184func (s *SSHServer) Shutdown(ctx context.Context) error {
185 return s.srv.Shutdown(ctx)
186}
187
188// PublicKeyAuthHandler handles public key authentication.
189func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
190 if pk == nil {
191 return false
192 }
193
194 ak := sshutils.MarshalAuthorizedKey(pk)
195 defer func(allowed *bool) {
196 publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
197 }(&allowed)
198
199 user, _ := s.be.UserByPublicKey(ctx, pk)
200 ctx.SetValue(proto.ContextKeyUser, user)
201
202 ac := s.be.AccessLevelForUser(ctx, "", user)
203 s.logger.Debugf("access level for %q: %s", ak, ac)
204 allowed = ac >= access.ReadWriteAccess
205 return
206}
207
208// KeyboardInteractiveHandler handles keyboard interactive authentication.
209// This is used after all public key authentication has failed.
210func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
211 ac := s.be.AllowKeyless(ctx)
212 keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
213 return ac
214}
215
216// sshFatal prints to the session's STDOUT as a git response and exit 1.
217func sshFatal(s ssh.Session, err error) {
218 git.WritePktlineErr(s, err) // nolint: errcheck
219 s.Exit(1) // nolint: errcheck
220}