1package ssh
2
3import (
4 "context"
5 "errors"
6 "net"
7 "path/filepath"
8 "strconv"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/server/backend"
14 cm "github.com/charmbracelet/soft-serve/server/cmd"
15 "github.com/charmbracelet/soft-serve/server/config"
16 "github.com/charmbracelet/soft-serve/server/git"
17 "github.com/charmbracelet/soft-serve/server/utils"
18 "github.com/charmbracelet/ssh"
19 "github.com/charmbracelet/wish"
20 bm "github.com/charmbracelet/wish/bubbletea"
21 lm "github.com/charmbracelet/wish/logging"
22 rm "github.com/charmbracelet/wish/recover"
23 "github.com/muesli/termenv"
24 "github.com/prometheus/client_golang/prometheus"
25 "github.com/prometheus/client_golang/prometheus/promauto"
26 gossh "golang.org/x/crypto/ssh"
27)
28
29var (
30 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
31 Namespace: "soft_serve",
32 Subsystem: "ssh",
33 Name: "public_key_auth_total",
34 Help: "The total number of public key auth requests",
35 }, []string{"key", "user", "allowed"})
36
37 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
38 Namespace: "soft_serve",
39 Subsystem: "ssh",
40 Name: "keyboard_interactive_auth_total",
41 Help: "The total number of keyboard interactive auth requests",
42 }, []string{"user", "allowed"})
43
44 uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
45 Namespace: "soft_serve",
46 Subsystem: "ssh",
47 Name: "git_upload_pack_total",
48 Help: "The total number of git-upload-pack requests",
49 }, []string{"key", "user", "repo"})
50
51 receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
52 Namespace: "soft_serve",
53 Subsystem: "ssh",
54 Name: "git_receive_pack_total",
55 Help: "The total number of git-receive-pack requests",
56 }, []string{"key", "user", "repo"})
57
58 uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
59 Namespace: "soft_serve",
60 Subsystem: "ssh",
61 Name: "git_upload_archive_total",
62 Help: "The total number of git-upload-archive requests",
63 }, []string{"key", "user", "repo"})
64
65 createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
66 Namespace: "soft_serve",
67 Subsystem: "ssh",
68 Name: "create_repo_total",
69 Help: "The total number of create repo requests",
70 }, []string{"key", "user", "repo"})
71)
72
73// SSHServer is a SSH server that implements the git protocol.
74type SSHServer struct {
75 srv *ssh.Server
76 cfg *config.Config
77 ctx context.Context
78 logger *log.Logger
79}
80
81// NewSSHServer returns a new SSHServer.
82func NewSSHServer(ctx context.Context) (*SSHServer, error) {
83 cfg := config.FromContext(ctx)
84 var err error
85 s := &SSHServer{
86 cfg: cfg,
87 ctx: ctx,
88 logger: log.FromContext(ctx).WithPrefix("ssh"),
89 }
90 logger := s.logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
91 mw := []wish.Middleware{
92 rm.MiddlewareWithLogger(
93 logger,
94 // BubbleTea middleware.
95 bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
96 // CLI middleware.
97 cm.Middleware(cfg),
98 // Git middleware.
99 s.Middleware(cfg),
100 // Logging middleware.
101 lm.MiddlewareWithLogger(logger),
102 ),
103 }
104 s.srv, err = wish.NewServer(
105 ssh.PublicKeyAuth(s.PublicKeyHandler),
106 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
107 wish.WithAddress(cfg.SSH.ListenAddr),
108 wish.WithHostKeyPath(cfg.SSH.KeyPath),
109 wish.WithMiddleware(mw...),
110 )
111 if err != nil {
112 return nil, err
113 }
114
115 if cfg.SSH.MaxTimeout > 0 {
116 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
117 }
118 if cfg.SSH.IdleTimeout > 0 {
119 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
120 }
121
122 return s, nil
123}
124
125// ListenAndServe starts the SSH server.
126func (s *SSHServer) ListenAndServe() error {
127 return s.srv.ListenAndServe()
128}
129
130// Serve starts the SSH server on the given net.Listener.
131func (s *SSHServer) Serve(l net.Listener) error {
132 return s.srv.Serve(l)
133}
134
135// Close closes the SSH server.
136func (s *SSHServer) Close() error {
137 return s.srv.Close()
138}
139
140// Shutdown gracefully shuts down the SSH server.
141func (s *SSHServer) Shutdown(ctx context.Context) error {
142 return s.srv.Shutdown(ctx)
143}
144
145// PublicKeyAuthHandler handles public key authentication.
146func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
147 if pk == nil {
148 return s.cfg.Backend.AllowKeyless()
149 }
150
151 ak := backend.MarshalAuthorizedKey(pk)
152 defer func(allowed *bool) {
153 publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
154 }(&allowed)
155
156 ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
157 s.logger.Debugf("access level for %q: %s", ak, ac)
158 allowed = ac >= backend.ReadOnlyAccess
159 return
160}
161
162// KeyboardInteractiveHandler handles keyboard interactive authentication.
163func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
164 ac := s.cfg.Backend.AllowKeyless()
165 keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
166 return ac
167}
168
169// Middleware adds Git server functionality to the ssh.Server. Repos are stored
170// in the specified repo directory. The provided Hooks implementation will be
171// checked for access on a per repo basis for a ssh.Session public key.
172// Hooks.Push and Hooks.Fetch will be called on successful completion of
173// their commands.
174func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
175 return func(sh ssh.Handler) ssh.Handler {
176 return func(s ssh.Session) {
177 func() {
178 cmd := s.Command()
179 if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
180 gc := cmd[0]
181 // repo should be in the form of "repo.git"
182 name := utils.SanitizeRepo(cmd[1])
183 pk := s.PublicKey()
184 ak := backend.MarshalAuthorizedKey(pk)
185 access := cfg.Backend.AccessLevelByPublicKey(name, pk)
186 // git bare repositories should end in ".git"
187 // https://git-scm.com/docs/gitrepository-layout
188 repo := name + ".git"
189 reposDir := filepath.Join(cfg.DataPath, "repos")
190 if err := git.EnsureWithin(reposDir, repo); err != nil {
191 sshFatal(s, err)
192 return
193 }
194
195 // Environment variables to pass down to git hooks.
196 envs := []string{
197 "SOFT_SERVE_REPO_NAME=" + name,
198 "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
199 "SOFT_SERVE_PUBLIC_KEY=" + ak,
200 }
201
202 ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
203 repoDir := filepath.Join(reposDir, repo)
204 switch gc {
205 case git.ReceivePackBin:
206 if access < backend.ReadWriteAccess {
207 sshFatal(s, git.ErrNotAuthed)
208 return
209 }
210 if _, err := cfg.Backend.Repository(name); err != nil {
211 if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
212 log.Errorf("failed to create repo: %s", err)
213 sshFatal(s, err)
214 return
215 }
216 createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
217 }
218 if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
219 sshFatal(s, git.ErrSystemMalfunction)
220 }
221 receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
222 return
223 case git.UploadPackBin, git.UploadArchiveBin:
224 if access < backend.ReadOnlyAccess {
225 sshFatal(s, git.ErrNotAuthed)
226 return
227 }
228
229 gitPack := git.UploadPack
230 counter := uploadPackCounter
231 if gc == git.UploadArchiveBin {
232 gitPack = git.UploadArchive
233 counter = uploadArchiveCounter
234 }
235
236 err := gitPack(s.Context(), s, s, s.Stderr(), repoDir, envs...)
237 if errors.Is(err, git.ErrInvalidRepo) {
238 sshFatal(s, git.ErrInvalidRepo)
239 } else if err != nil {
240 sshFatal(s, git.ErrSystemMalfunction)
241 }
242
243 counter.WithLabelValues(ak, s.User(), name).Inc()
244 }
245 }
246 }()
247 sh(s)
248 }
249 }
250}
251
252// sshFatal prints to the session's STDOUT as a git response and exit 1.
253func sshFatal(s ssh.Session, v ...interface{}) {
254 git.WritePktline(s, v...)
255 s.Exit(1) // nolint: errcheck
256}