1package ssh
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net"
8 "os"
9 "path/filepath"
10 "strconv"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/keygen"
15 "github.com/charmbracelet/log"
16 "github.com/charmbracelet/soft-serve/server/backend"
17 cm "github.com/charmbracelet/soft-serve/server/cmd"
18 "github.com/charmbracelet/soft-serve/server/config"
19 "github.com/charmbracelet/soft-serve/server/git"
20 "github.com/charmbracelet/soft-serve/server/utils"
21 "github.com/charmbracelet/ssh"
22 "github.com/charmbracelet/wish"
23 bm "github.com/charmbracelet/wish/bubbletea"
24 lm "github.com/charmbracelet/wish/logging"
25 rm "github.com/charmbracelet/wish/recover"
26 "github.com/muesli/termenv"
27 "github.com/prometheus/client_golang/prometheus"
28 "github.com/prometheus/client_golang/prometheus/promauto"
29 gossh "golang.org/x/crypto/ssh"
30)
31
32var (
33 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
34 Namespace: "soft_serve",
35 Subsystem: "ssh",
36 Name: "public_key_auth_total",
37 Help: "The total number of public key auth requests",
38 }, []string{"key", "user", "allowed"})
39
40 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
41 Namespace: "soft_serve",
42 Subsystem: "ssh",
43 Name: "keyboard_interactive_auth_total",
44 Help: "The total number of keyboard interactive auth requests",
45 }, []string{"user", "allowed"})
46
47 uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
48 Namespace: "soft_serve",
49 Subsystem: "ssh",
50 Name: "git_upload_pack_total",
51 Help: "The total number of git-upload-pack requests",
52 }, []string{"key", "user", "repo"})
53
54 receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
55 Namespace: "soft_serve",
56 Subsystem: "ssh",
57 Name: "git_receive_pack_total",
58 Help: "The total number of git-receive-pack requests",
59 }, []string{"key", "user", "repo"})
60
61 uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
62 Namespace: "soft_serve",
63 Subsystem: "ssh",
64 Name: "git_upload_archive_total",
65 Help: "The total number of git-upload-archive requests",
66 }, []string{"key", "user", "repo"})
67
68 createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
69 Namespace: "soft_serve",
70 Subsystem: "ssh",
71 Name: "create_repo_total",
72 Help: "The total number of create repo requests",
73 }, []string{"key", "user", "repo"})
74)
75
76// SSHServer is a SSH server that implements the git protocol.
77type SSHServer struct {
78 srv *ssh.Server
79 cfg *config.Config
80 ctx context.Context
81 logger *log.Logger
82}
83
84// NewSSHServer returns a new SSHServer.
85func NewSSHServer(ctx context.Context) (*SSHServer, error) {
86 cfg := config.FromContext(ctx)
87
88 var err error
89 s := &SSHServer{
90 cfg: cfg,
91 ctx: ctx,
92 logger: log.FromContext(ctx).WithPrefix("ssh"),
93 }
94
95 logger := s.logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
96 mw := []wish.Middleware{
97 rm.MiddlewareWithLogger(
98 logger,
99 // BubbleTea middleware.
100 bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
101 // CLI middleware.
102 cm.Middleware(cfg),
103 // Git middleware.
104 s.Middleware(cfg),
105 // Logging middleware.
106 lm.MiddlewareWithLogger(logger),
107 ),
108 }
109
110 s.srv, err = wish.NewServer(
111 ssh.PublicKeyAuth(s.PublicKeyHandler),
112 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
113 wish.WithAddress(cfg.SSH.ListenAddr),
114 wish.WithHostKeyPath(cfg.SSH.KeyPath),
115 wish.WithMiddleware(mw...),
116 )
117 if err != nil {
118 return nil, err
119 }
120
121 if cfg.SSH.MaxTimeout > 0 {
122 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
123 }
124
125 if cfg.SSH.IdleTimeout > 0 {
126 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
127 }
128
129 // Create client ssh key
130 if _, err := os.Stat(cfg.SSH.ClientKeyPath); err != nil && os.IsNotExist(err) {
131 _, err := keygen.New(cfg.SSH.ClientKeyPath, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
132 if err != nil {
133 return nil, fmt.Errorf("client ssh key: %w", err)
134 }
135 }
136
137 return s, nil
138}
139
140// ListenAndServe starts the SSH server.
141func (s *SSHServer) ListenAndServe() error {
142 return s.srv.ListenAndServe()
143}
144
145// Serve starts the SSH server on the given net.Listener.
146func (s *SSHServer) Serve(l net.Listener) error {
147 return s.srv.Serve(l)
148}
149
150// Close closes the SSH server.
151func (s *SSHServer) Close() error {
152 return s.srv.Close()
153}
154
155// Shutdown gracefully shuts down the SSH server.
156func (s *SSHServer) Shutdown(ctx context.Context) error {
157 return s.srv.Shutdown(ctx)
158}
159
160// PublicKeyAuthHandler handles public key authentication.
161func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
162 if pk == nil {
163 return s.cfg.Backend.AllowKeyless()
164 }
165
166 ak := backend.MarshalAuthorizedKey(pk)
167 defer func(allowed *bool) {
168 publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
169 }(&allowed)
170
171 ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
172 s.logger.Debugf("access level for %q: %s", ak, ac)
173 allowed = ac >= backend.ReadOnlyAccess
174 return
175}
176
177// KeyboardInteractiveHandler handles keyboard interactive authentication.
178func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
179 ac := s.cfg.Backend.AllowKeyless()
180 keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
181 return ac
182}
183
184// Middleware adds Git server functionality to the ssh.Server. Repos are stored
185// in the specified repo directory. The provided Hooks implementation will be
186// checked for access on a per repo basis for a ssh.Session public key.
187// Hooks.Push and Hooks.Fetch will be called on successful completion of
188// their commands.
189func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
190 return func(sh ssh.Handler) ssh.Handler {
191 return func(s ssh.Session) {
192 func() {
193 cmd := s.Command()
194 if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
195 gc := cmd[0]
196 // repo should be in the form of "repo.git"
197 name := utils.SanitizeRepo(cmd[1])
198 pk := s.PublicKey()
199 ak := backend.MarshalAuthorizedKey(pk)
200 access := cfg.Backend.AccessLevelByPublicKey(name, pk)
201 // git bare repositories should end in ".git"
202 // https://git-scm.com/docs/gitrepository-layout
203 repo := name + ".git"
204 reposDir := filepath.Join(cfg.DataPath, "repos")
205 if err := git.EnsureWithin(reposDir, repo); err != nil {
206 sshFatal(s, err)
207 return
208 }
209
210 // Environment variables to pass down to git hooks.
211 envs := []string{
212 "SOFT_SERVE_REPO_NAME=" + name,
213 "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
214 "SOFT_SERVE_PUBLIC_KEY=" + ak,
215 }
216
217 ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
218 repoDir := filepath.Join(reposDir, repo)
219 switch gc {
220 case git.ReceivePackBin:
221 if access < backend.ReadWriteAccess {
222 sshFatal(s, git.ErrNotAuthed)
223 return
224 }
225 if _, err := cfg.Backend.Repository(name); err != nil {
226 if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
227 log.Errorf("failed to create repo: %s", err)
228 sshFatal(s, err)
229 return
230 }
231 createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
232 }
233 if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
234 sshFatal(s, git.ErrSystemMalfunction)
235 }
236 receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
237 return
238 case git.UploadPackBin, git.UploadArchiveBin:
239 if access < backend.ReadOnlyAccess {
240 sshFatal(s, git.ErrNotAuthed)
241 return
242 }
243
244 gitPack := git.UploadPack
245 counter := uploadPackCounter
246 if gc == git.UploadArchiveBin {
247 gitPack = git.UploadArchive
248 counter = uploadArchiveCounter
249 }
250
251 err := gitPack(s.Context(), s, s, s.Stderr(), repoDir, envs...)
252 if errors.Is(err, git.ErrInvalidRepo) {
253 sshFatal(s, git.ErrInvalidRepo)
254 } else if err != nil {
255 sshFatal(s, git.ErrSystemMalfunction)
256 }
257
258 counter.WithLabelValues(ak, s.User(), name).Inc()
259 }
260 }
261 }()
262 sh(s)
263 }
264 }
265}
266
267// sshFatal prints to the session's STDOUT as a git response and exit 1.
268func sshFatal(s ssh.Session, v ...interface{}) {
269 git.WritePktline(s, v...)
270 s.Exit(1) // nolint: errcheck
271}