1package ssh
2
3import (
4 "fmt"
5 "time"
6
7 "github.com/charmbracelet/log"
8 "github.com/charmbracelet/soft-serve/server/backend"
9 "github.com/charmbracelet/soft-serve/server/config"
10 "github.com/charmbracelet/soft-serve/server/db"
11 "github.com/charmbracelet/soft-serve/server/proto"
12 "github.com/charmbracelet/soft-serve/server/ssh/cmd"
13 "github.com/charmbracelet/soft-serve/server/sshutils"
14 "github.com/charmbracelet/soft-serve/server/store"
15 "github.com/charmbracelet/ssh"
16 "github.com/prometheus/client_golang/prometheus"
17 "github.com/prometheus/client_golang/prometheus/promauto"
18 "github.com/spf13/cobra"
19)
20
21// ContextMiddleware adds the config, backend, and logger to the session context.
22func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
23 return func(sh ssh.Handler) ssh.Handler {
24 return func(s ssh.Session) {
25 s.Context().SetValue(sshutils.ContextKeySession, s)
26 s.Context().SetValue(config.ContextKey, cfg)
27 s.Context().SetValue(db.ContextKey, dbx)
28 s.Context().SetValue(store.ContextKey, datastore)
29 s.Context().SetValue(backend.ContextKey, be)
30 s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
31 sh(s)
32 }
33 }
34}
35
36var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
37 Namespace: "soft_serve",
38 Subsystem: "cli",
39 Name: "commands_total",
40 Help: "Total times each command was called",
41}, []string{"command"})
42
43// CommandMiddleware handles git commands and CLI commands.
44// This middleware must be run after the ContextMiddleware.
45func CommandMiddleware(sh ssh.Handler) ssh.Handler {
46 return func(s ssh.Session) {
47 func() {
48 _, _, ptyReq := s.Pty()
49 if ptyReq {
50 return
51 }
52
53 ctx := s.Context()
54 cfg := config.FromContext(ctx)
55
56 args := s.Command()
57 cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
58 rootCmd := &cobra.Command{
59 Short: "Soft Serve is a self-hostable Git server for the command line.",
60 SilenceUsage: true,
61 }
62 rootCmd.CompletionOptions.DisableDefaultCmd = true
63
64 rootCmd.SetUsageTemplate(cmd.UsageTemplate)
65 rootCmd.SetUsageFunc(cmd.UsageFunc)
66 rootCmd.AddCommand(
67 cmd.GitUploadPackCommand(),
68 cmd.GitUploadArchiveCommand(),
69 cmd.GitReceivePackCommand(),
70 cmd.RepoCommand(),
71 cmd.SettingsCommand(),
72 cmd.UserCommand(),
73 cmd.InfoCommand(),
74 cmd.PubkeyCommand(),
75 cmd.SetUsernameCommand(),
76 cmd.JWTCommand(),
77 cmd.TokenCommand(),
78 )
79
80 if cfg.LFS.Enabled {
81 rootCmd.AddCommand(
82 cmd.GitLFSAuthenticateCommand(),
83 )
84
85 if cfg.LFS.SSHEnabled {
86 rootCmd.AddCommand(
87 cmd.GitLFSTransfer(),
88 )
89 }
90 }
91
92 rootCmd.SetArgs(args)
93 if len(args) == 0 {
94 // otherwise it'll default to os.Args, which is not what we want.
95 rootCmd.SetArgs([]string{"--help"})
96 }
97 rootCmd.SetIn(s)
98 rootCmd.SetOut(s)
99 rootCmd.SetErr(s.Stderr())
100 rootCmd.SetContext(ctx)
101
102 if err := rootCmd.ExecuteContext(ctx); err != nil {
103 s.Exit(1) // nolint: errcheck
104 return
105 }
106 }()
107 sh(s)
108 }
109}
110
111// LoggingMiddleware logs the ssh connection and command.
112func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
113 return func(s ssh.Session) {
114 ctx := s.Context()
115 logger := log.FromContext(ctx).WithPrefix("ssh")
116 ct := time.Now()
117 hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
118 ptyReq, _, isPty := s.Pty()
119 addr := s.RemoteAddr().String()
120 user := proto.UserFromContext(ctx)
121 logArgs := []interface{}{
122 "addr",
123 addr,
124 "cmd",
125 s.Command(),
126 }
127
128 if user != nil {
129 logArgs = append([]interface{}{
130 "username",
131 user.Username(),
132 }, logArgs...)
133 }
134
135 if isPty {
136 logArgs = []interface{}{
137 "term", ptyReq.Term,
138 "width", ptyReq.Window.Width,
139 "height", ptyReq.Window.Height,
140 }
141 }
142
143 if config.IsVerbose() {
144 logArgs = append(logArgs,
145 "key", hpk,
146 "envs", s.Environ(),
147 )
148 }
149
150 msg := fmt.Sprintf("user %q", s.User())
151 logger.Debug(msg+" connected", logArgs...)
152 sh(s)
153 logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
154 }
155}