1package ssh
2
3import (
4 "context"
5 "errors"
6 "net"
7 "path/filepath"
8 "strconv"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/log"
13 "github.com/charmbracelet/soft-serve/server/backend"
14 cm "github.com/charmbracelet/soft-serve/server/cmd"
15 "github.com/charmbracelet/soft-serve/server/config"
16 "github.com/charmbracelet/soft-serve/server/git"
17 "github.com/charmbracelet/soft-serve/server/hooks"
18 "github.com/charmbracelet/soft-serve/server/utils"
19 "github.com/charmbracelet/ssh"
20 "github.com/charmbracelet/wish"
21 bm "github.com/charmbracelet/wish/bubbletea"
22 lm "github.com/charmbracelet/wish/logging"
23 rm "github.com/charmbracelet/wish/recover"
24 "github.com/muesli/termenv"
25 "github.com/prometheus/client_golang/prometheus"
26 "github.com/prometheus/client_golang/prometheus/promauto"
27 gossh "golang.org/x/crypto/ssh"
28)
29
30var (
31 logger = log.WithPrefix("server.ssh")
32)
33
34var (
35 publicKeyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
36 Namespace: "soft_serve",
37 Subsystem: "ssh",
38 Name: "public_key_auth_total",
39 Help: "The total number of public key auth requests",
40 }, []string{"key", "user", "allowed"})
41
42 keyboardInteractiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
43 Namespace: "soft_serve",
44 Subsystem: "ssh",
45 Name: "keyboard_interactive_auth_total",
46 Help: "The total number of keyboard interactive auth requests",
47 }, []string{"user", "allowed"})
48
49 uploadPackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
50 Namespace: "soft_serve",
51 Subsystem: "ssh",
52 Name: "git_upload_pack_total",
53 Help: "The total number of git-upload-pack requests",
54 }, []string{"key", "user", "repo"})
55
56 receivePackCounter = promauto.NewCounterVec(prometheus.CounterOpts{
57 Namespace: "soft_serve",
58 Subsystem: "ssh",
59 Name: "git_receive_pack_total",
60 Help: "The total number of git-receive-pack requests",
61 }, []string{"key", "user", "repo"})
62
63 uploadArchiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
64 Namespace: "soft_serve",
65 Subsystem: "ssh",
66 Name: "git_upload_archive_total",
67 Help: "The total number of git-upload-archive requests",
68 }, []string{"key", "user", "repo"})
69
70 createRepoCounter = promauto.NewCounterVec(prometheus.CounterOpts{
71 Namespace: "soft_serve",
72 Subsystem: "ssh",
73 Name: "create_repo_total",
74 Help: "The total number of create repo requests",
75 }, []string{"key", "user", "repo"})
76)
77
// SSHServer is an SSH server that implements the git protocol.
type SSHServer struct {
	// srv is the underlying ssh.Server built by NewSSHServer.
	srv *ssh.Server
	// cfg is the server configuration; its Backend is consulted for
	// authentication and repository access decisions.
	cfg *config.Config
}
83
84// NewSSHServer returns a new SSHServer.
85func NewSSHServer(cfg *config.Config, hooks hooks.Hooks) (*SSHServer, error) {
86 var err error
87 s := &SSHServer{cfg: cfg}
88 logger := logger.StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})
89 mw := []wish.Middleware{
90 rm.MiddlewareWithLogger(
91 logger,
92 // BubbleTea middleware.
93 bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
94 // CLI middleware.
95 cm.Middleware(cfg, hooks),
96 // Git middleware.
97 s.Middleware(cfg),
98 // Logging middleware.
99 lm.MiddlewareWithLogger(logger),
100 ),
101 }
102 s.srv, err = wish.NewServer(
103 ssh.PublicKeyAuth(s.PublicKeyHandler),
104 ssh.KeyboardInteractiveAuth(s.KeyboardInteractiveHandler),
105 wish.WithAddress(cfg.SSH.ListenAddr),
106 wish.WithHostKeyPath(cfg.SSH.KeyPath),
107 wish.WithMiddleware(mw...),
108 )
109 if err != nil {
110 return nil, err
111 }
112
113 if cfg.SSH.MaxTimeout > 0 {
114 s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
115 }
116 if cfg.SSH.IdleTimeout > 0 {
117 s.srv.IdleTimeout = time.Duration(cfg.SSH.IdleTimeout) * time.Second
118 }
119
120 return s, nil
121}
122
123// ListenAndServe starts the SSH server.
124func (s *SSHServer) ListenAndServe() error {
125 return s.srv.ListenAndServe()
126}
127
128// Serve starts the SSH server on the given net.Listener.
129func (s *SSHServer) Serve(l net.Listener) error {
130 return s.srv.Serve(l)
131}
132
133// Close closes the SSH server.
134func (s *SSHServer) Close() error {
135 return s.srv.Close()
136}
137
138// Shutdown gracefully shuts down the SSH server.
139func (s *SSHServer) Shutdown(ctx context.Context) error {
140 return s.srv.Shutdown(ctx)
141}
142
143// PublicKeyAuthHandler handles public key authentication.
144func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
145 if pk == nil {
146 return s.cfg.Backend.AllowKeyless()
147 }
148
149 ak := backend.MarshalAuthorizedKey(pk)
150 defer func(allowed *bool) {
151 publicKeyCounter.WithLabelValues(ak, ctx.User(), strconv.FormatBool(*allowed)).Inc()
152 }(&allowed)
153
154 ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
155 logger.Debugf("access level for %q: %s", ak, ac)
156 allowed = ac >= backend.ReadOnlyAccess
157 return
158}
159
160// KeyboardInteractiveHandler handles keyboard interactive authentication.
161func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
162 ac := s.cfg.Backend.AllowKeyless()
163 keyboardInteractiveCounter.WithLabelValues(ctx.User(), strconv.FormatBool(ac)).Inc()
164 return ac
165}
166
167// Middleware adds Git server functionality to the ssh.Server. Repos are stored
168// in the specified repo directory. The provided Hooks implementation will be
169// checked for access on a per repo basis for a ssh.Session public key.
170// Hooks.Push and Hooks.Fetch will be called on successful completion of
171// their commands.
172func (s *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
173 return func(sh ssh.Handler) ssh.Handler {
174 return func(s ssh.Session) {
175 func() {
176 cmd := s.Command()
177 if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
178 gc := cmd[0]
179 // repo should be in the form of "repo.git"
180 name := utils.SanitizeRepo(cmd[1])
181 pk := s.PublicKey()
182 ak := backend.MarshalAuthorizedKey(pk)
183 access := cfg.Backend.AccessLevelByPublicKey(name, pk)
184 // git bare repositories should end in ".git"
185 // https://git-scm.com/docs/gitrepository-layout
186 repo := name + ".git"
187 reposDir := filepath.Join(cfg.DataPath, "repos")
188 if err := git.EnsureWithin(reposDir, repo); err != nil {
189 sshFatal(s, err)
190 return
191 }
192
193 logger.Debug("git middleware", "cmd", gc, "access", access.String())
194 repoDir := filepath.Join(reposDir, repo)
195 switch gc {
196 case git.ReceivePackBin:
197 if access < backend.ReadWriteAccess {
198 sshFatal(s, git.ErrNotAuthed)
199 return
200 }
201 if _, err := cfg.Backend.Repository(name); err != nil {
202 if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
203 log.Errorf("failed to create repo: %s", err)
204 sshFatal(s, err)
205 return
206 }
207 createRepoCounter.WithLabelValues(ak, s.User(), name).Inc()
208 }
209 if err := git.ReceivePack(s, s, s.Stderr(), repoDir); err != nil {
210 sshFatal(s, git.ErrSystemMalfunction)
211 }
212 receivePackCounter.WithLabelValues(ak, s.User(), name).Inc()
213 return
214 case git.UploadPackBin, git.UploadArchiveBin:
215 if access < backend.ReadOnlyAccess {
216 sshFatal(s, git.ErrNotAuthed)
217 return
218 }
219
220 gitPack := git.UploadPack
221 counter := uploadPackCounter
222 if gc == git.UploadArchiveBin {
223 gitPack = git.UploadArchive
224 counter = uploadArchiveCounter
225 }
226
227 err := gitPack(s, s, s.Stderr(), repoDir)
228 if errors.Is(err, git.ErrInvalidRepo) {
229 sshFatal(s, git.ErrInvalidRepo)
230 } else if err != nil {
231 sshFatal(s, git.ErrSystemMalfunction)
232 }
233
234 counter.WithLabelValues(ak, s.User(), name).Inc()
235 }
236 }
237 }()
238 sh(s)
239 }
240 }
241}
242
243// sshFatal prints to the session's STDOUT as a git response and exit 1.
244func sshFatal(s ssh.Session, v ...interface{}) {
245 git.WritePktline(s, v...)
246 s.Exit(1) // nolint: errcheck
247}