1package ssh
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net"
8 "os"
9 "path/filepath"
10 "strconv"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/keygen"
15 "github.com/charmbracelet/log"
16 "github.com/charmbracelet/soft-serve/server/backend"
17 cm "github.com/charmbracelet/soft-serve/server/cmd"
18 "github.com/charmbracelet/soft-serve/server/config"
19 "github.com/charmbracelet/soft-serve/server/git"
20 "github.com/charmbracelet/soft-serve/server/utils"
21 "github.com/charmbracelet/ssh"
22 "github.com/charmbracelet/wish"
23 bm "github.com/charmbracelet/wish/bubbletea"
24 lm "github.com/charmbracelet/wish/logging"
25 rm "github.com/charmbracelet/wish/recover"
26 "github.com/muesli/termenv"
27 "github.com/prometheus/client_golang/prometheus"
28 "github.com/prometheus/client_golang/prometheus/promauto"
29 gossh "golang.org/x/crypto/ssh"
30)
31
// Prometheus counters exported by the SSH server. Each is created through
// promauto, which registers it with the default Prometheus registerer when
// this package is initialized. All share the "soft_serve" namespace and the
// "ssh" subsystem.
var (
	// publicKeyCounter counts public key auth requests, labeled by the
	// marshaled authorized key, the SSH user, and whether the request was
	// allowed.
	publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "public_key_auth_total",
		Help:      "The total number of public key auth requests",
	}, []string{"key", "user", "allowed"})

	// noAuthCounter counts "none" auth requests, labeled by the SSH user and
	// whether the request was allowed.
	noAuthCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "none_auth_total",
		Help:      "The total number of none auth requests",
	}, []string{"user", "allowed"})

	// keyboardInteractiveCounter counts keyboard-interactive auth requests,
	// labeled by the SSH user and whether the request was allowed.
	keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "keyboard_interactive_auth_total",
		Help:      "The total number of keyboard interactive auth requests",
	}, []string{"user", "allowed"})

	// uploadPackCounter counts git-upload-pack requests, labeled by the
	// authorized key, the SSH user, and the repository name.
	uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_pack_total",
		Help:      "The total number of git-upload-pack requests",
	}, []string{"key", "user", "repo"})

	// receivePackCounter counts git-receive-pack requests, labeled by the
	// authorized key, the SSH user, and the repository name.
	receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_receive_pack_total",
		Help:      "The total number of git-receive-pack requests",
	}, []string{"key", "user", "repo"})

	// uploadArchiveCounter counts git-upload-archive requests, labeled by
	// the authorized key, the SSH user, and the repository name.
	uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "git_upload_archive_total",
		Help:      "The total number of git-upload-archive requests",
	}, []string{"key", "user", "repo"})

	// createRepoCounter counts repositories created on push, labeled by the
	// authorized key, the SSH user, and the repository name.
	createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: "soft_serve",
		Subsystem: "ssh",
		Name:      "create_repo_total",
		Help:      "The total number of create repo requests",
	}, []string{"key", "user", "repo"})
)
82
// SSHServer is an SSH server that implements the git protocol.
type SSHServer struct {
	// srv is the underlying server built by wish.NewServer in NewSSHServer.
	srv    *ssh.Server
	// cfg is the configuration read from the context given to NewSSHServer.
	cfg    *config.Config
	// be is the backend read from that same context; the git middleware
	// uses it to look up and create repositories.
	be     backend.Backend
	// ctx is the context NewSSHServer was called with.
	// NOTE(review): nothing visible here reads it after construction, and
	// storing a context in a struct is discouraged — confirm it is needed.
	ctx    context.Context
	// logger is the context's logger with the "ssh" prefix applied.
	logger *log.Logger
}
91
92// NewSSHServer returns a new SSHServer.
93func NewSSHServer(ctx context.Context) (*SSHServer, error) {
94 cfg := config.FromContext(ctx)
95 logger := log.FromContext(ctx).WithPrefix("ssh")
96
97 var err error
98 s := &SSHServer{
99 cfg: cfg,
100 ctx: ctx,
101 be: backend.FromContext(ctx),
102 logger: logger,
103 }
104
105 mw := []wish.Middleware{
106 rm.MiddlewareWithLogger(
107 logger,
108 // BubbleTea middleware.
109 bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
110 // CLI middleware.
111 cm.Middleware(cfg, logger),
112 // Git middleware.
113 s.Middleware(cfg),
114 // Logging middleware.
115 lm.MiddlewareWithLogger(logger.
116 StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
117 ),
118 }
119
120 s.srv, err = wish.NewServer(
121 ssh.PublicKeyAuth(s.PublicKeyHandler),
122 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
123 wish.WithAddress(cfg.SSH.ListenAddr),
124 wish.WithHostKeyPath(cfg.SSH.KeyPath),
125 wish.WithMiddleware(mw...),
126 )
127 if err != nil {
128 return nil, err
129 }
130
131 // Custom ServerConfig
132 // which handles "none" auth method
133 s.srv.ServerConfigCallback = s.ServerConfigCallback
134
135 if cfg.SSH.MaxTimeout > 0 {
136 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
137 }
138
139 if cfg.SSH.IdleTimeout > 0 {
140 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
141 }
142
143 // Create client ssh key
144 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
145 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
146 if err != nil {
147 return nil, fmt.Errorf("client ssh key: %w", err)
148 }
149 }
150
151 return s, nil
152}
153
154// ListenAndServe starts the SSH server.
155func (s *SSHServer) ListenAndServe() error {
156 return s.srv.ListenAndServe()
157}
158
159// Serve starts the SSH server on the given net.Listener.
160func (s *SSHServer) Serve(l net.Listener) error {
161 return s.srv.Serve(l)
162}
163
164// Close closes the SSH server.
165func (s *SSHServer) Close() error {
166 return s.srv.Close()
167}
168
169// Shutdown gracefully shuts down the SSH server.
170func (s *SSHServer) Shutdown(ctx context.Context) error {
171 return s.srv.Shutdown(ctx)
172}
173
174// PublicKeyAuthHandler handles public key authentication.
175func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
176 if pk == nil {
177 return false
178 }
179
180 ak := backend.MarshalAuthorizedKey(pk)
181 defer func(allowed *bool) {
182 publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
183 }(&allowed)
184
185 ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
186 s.logger.Debugf("access level for %q: %s", ak, ac)
187 allowed = ac >= backend.ReadWriteAccess
188 return
189}
190
191// KeyboardInteractiveHandler handles keyboard interactive authentication.
192// This is used after all public key authentication has failed.
193func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
194 ac := s.cfg.Backend.AllowKeyless()
195 keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
196 return ac
197}
198
199// NoClientAuthCallback handles "none" auth method.
200func (s *SSHServer) NoClientAuthCallback(meta gossh.ConnMetadata) (*gossh.Permissions, error) {
201 ac := s.cfg.Backend.AllowKeyless()
202 keyboardInteractiveCounter.WithLabelValues(meta.User(), strconv.FormatBool(ac)).Inc()
203 perms := &gossh.Permissions{}
204 if !ac {
205 return perms, fmt.Errorf("keyless access not allowed")
206 }
207
208 return perms, nil
209}
210
211// ServerConfigCallback returns a custom ssh.ServerConfig.
212func (s *SSHServer) ServerConfigCallback(ctx ssh.Context) *gossh.ServerConfig {
213 return &gossh.ServerConfig{
214 NoClientAuth: true,
215 NoClientAuthCallback: s.NoClientAuthCallback,
216 }
217}
218
219// Middleware adds Git server functionality to the ssh.Server. Repos are stored
220// in the specified repo directory. The provided Hooks implementation will be
221// checked for access on a per repo basis for a ssh.Session public key.
222// Hooks.Push and Hooks.Fetch will be called on successful completion of
223// their commands.
224func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
225 return func(sh ssh.Handler) ssh.Handler {
226 return func(s ssh.Session) {
227 func() {
228 cmd := s.Command()
229 ctx := s.Context()
230 be := ss.be.WithContext(ctx)
231 if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
232 gc := cmd[0]
233 // repo should be in the form of "repo.git"
234 name := utils.SanitizeRepo(cmd[1])
235 pk := s.PublicKey()
236 ak := backend.MarshalAuthorizedKey(pk)
237 access := cfg.Backend.AccessLevelByPublicKey(name, pk)
238 // git bare repositories should end in ".git"
239 // https://git-scm.com/docs/gitrepository-layout
240 repo := name + ".git"
241 reposDir := filepath.Join(cfg.DataPath, "repos")
242 if err := git.EnsureWithin(reposDir, repo); err != nil {
243 sshFatal(s, err)
244 return
245 }
246
247 // Environment variables to pass down to git hooks.
248 envs := []string{
249 "SOFT_SERVE_REPO_NAME=" + name,
250 "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
251 "SOFT_SERVE_PUBLIC_KEY=" + ak,
252 }
253
254 ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
255 repoDir := filepath.Join(reposDir, repo)
256 switch gc {
257 case git.ReceivePackBin:
258 if access < backend.ReadWriteAccess {
259 sshFatal(s, git.ErrNotAuthed)
260 return
261 }
262 if _, err := be.Repository(name); err != nil {
263 if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
264 log.Errorf("failed to create repo: %s", err)
265 sshFatal(s, err)
266 return
267 }
268 createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
269 }
270 if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
271 sshFatal(s, git.ErrSystemMalfunction)
272 }
273 receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
274 return
275 case git.UploadPackBin, git.UploadArchiveBin:
276 if access < backend.ReadOnlyAccess {
277 sshFatal(s, git.ErrNotAuthed)
278 return
279 }
280
281 gitPack := git.UploadPack
282 counter := uploadPackCounter
283 if gc == git.UploadArchiveBin {
284 gitPack = git.UploadArchive
285 counter = uploadArchiveCounter
286 }
287
288 err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
289 if errors.Is(err, git.ErrInvalidRepo) {
290 sshFatal(s, git.ErrInvalidRepo)
291 } else if err != nil {
292 sshFatal(s, git.ErrSystemMalfunction)
293 }
294
295 counter.WithLabelValues(ak, s.User(), name).Inc()
296 }
297 }
298 }()
299 sh(s)
300 }
301 }
302}
303
304// sshFatal prints to the session's STDOUT as a git response and exit 1.
305func sshFatal(s ssh.Session, v ...interface{}) {
306 git.WritePktline(s, v...)
307 s.Exit(1) // nolint: errcheck
308}