feat(ssh): use custom logging middleware

Ayman Bagabas created

Change summary

server/git/errors.go     | 23 ++++++++++++++++++
server/git/git.go        | 22 ------------------
server/git/service.go    |  2 -
server/ssh/cmd/git.go    | 18 +++++++++-----
server/ssh/middleware.go | 51 ++++++++++++++++++++++++++++++++++++++++-
server/ssh/ssh.go        |  7 +----
server/web/git.go        |  1 
7 files changed, 85 insertions(+), 39 deletions(-)

Detailed changes

server/git/errors.go 🔗

@@ -0,0 +1,23 @@
+package git
+
+import "errors"
+
+var (
+	// ErrNotAuthed represents unauthorized access.
+	ErrNotAuthed = errors.New("you are not authorized to do this")
+
+	// ErrSystemMalfunction represents a general system error returned to clients.
+	ErrSystemMalfunction = errors.New("something went wrong")
+
+	// ErrInvalidRepo represents an attempt to access a non-existent repo.
+	ErrInvalidRepo = errors.New("invalid repo")
+
+	// ErrInvalidRequest represents an invalid request.
+	ErrInvalidRequest = errors.New("invalid request")
+
+	// ErrMaxConnections represents a maximum connection limit being reached.
+	ErrMaxConnections = errors.New("too many connections, try again later")
+
+	// ErrTimeout is returned when the maximum read timeout is exceeded.
+	ErrTimeout = errors.New("I/O timeout reached")
+)

server/git/git.go 🔗

@@ -2,7 +2,6 @@ package git
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"io"
 	"path/filepath"
@@ -14,27 +13,6 @@ import (
 	gitm "github.com/gogs/git-module"
 )
 
-var (
-
-	// ErrNotAuthed represents unauthorized access.
-	ErrNotAuthed = errors.New("you are not authorized to do this")
-
-	// ErrSystemMalfunction represents a general system error returned to clients.
-	ErrSystemMalfunction = errors.New("something went wrong")
-
-	// ErrInvalidRepo represents an attempt to access a non-existent repo.
-	ErrInvalidRepo = errors.New("invalid repo")
-
-	// ErrInvalidRequest represents an invalid request.
-	ErrInvalidRequest = errors.New("invalid request")
-
-	// ErrMaxConnections represents a maximum connection limit being reached.
-	ErrMaxConnections = errors.New("too many connections, try again later")
-
-	// ErrTimeout is returned when the maximum read timeout is exceeded.
-	ErrTimeout = errors.New("I/O timeout reached")
-)
-
 // WritePktline encodes and writes a pktline to the given writer.
 func WritePktline(w io.Writer, v ...interface{}) error {
 	msg := fmt.Sprintln(v...)

server/git/service.go 🔗

@@ -9,7 +9,6 @@ import (
 	"os/exec"
 	"strings"
 
-	"github.com/charmbracelet/log"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -112,7 +111,6 @@ func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) er
 		}
 	}
 
-	log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String())
 	if err := cmd.Start(); err != nil {
 		if errors.Is(err, os.ErrNotExist) {
 			return ErrInvalidRepo

server/ssh/cmd/git.go 🔗

@@ -209,16 +209,17 @@ func gitRunE(cmd *cobra.Command, args []string) error {
 
 	repoPath := filepath.Join(reposDir, repoDir)
 	service := git.Service(cmd.Name())
+	stdin := cmd.InOrStdin()
+	stdout := cmd.OutOrStdout()
+	stderr := cmd.ErrOrStderr()
 	scmd := git.ServiceCommand{
-		Stdin:  cmd.InOrStdin(),
-		Stdout: s,
-		Stderr: s.Stderr(),
+		Stdin:  stdin,
+		Stdout: stdout,
+		Stderr: stderr,
 		Env:    envs,
 		Dir:    repoPath,
 	}
 
-	logger.Debug("git middleware", "cmd", service, "access", accessLevel.String())
-
 	switch service {
 	case git.ReceivePackService:
 		receivePackCounter.WithLabelValues(name).Inc()
@@ -237,16 +238,19 @@ func gitRunE(cmd *cobra.Command, args []string) error {
 		}
 
 		if err := service.Handler(ctx, scmd); err != nil {
+			logger.Error("failed to handle git service", "service", service, "err", err, "repo", name)
 			defer func() {
 				if repo == nil {
 					// If the repo was created, but the request failed, delete it.
 					be.DeleteRepository(ctx, name) // nolint: errcheck
 				}
 			}()
+
 			return git.ErrSystemMalfunction
 		}
 
 		if err := git.EnsureDefaultBranch(ctx, scmd); err != nil {
+			logger.Error("failed to ensure default branch", "err", err, "repo", name)
 			return git.ErrSystemMalfunction
 		}
 
@@ -279,7 +283,7 @@ func gitRunE(cmd *cobra.Command, args []string) error {
 		if errors.Is(err, git.ErrInvalidRepo) {
 			return git.ErrInvalidRepo
 		} else if err != nil {
-			logger.Error("git middleware", "err", err)
+			logger.Error("failed to handle git service", "service", service, "err", err, "repo", name)
 			return git.ErrSystemMalfunction
 		}
 
@@ -322,7 +326,7 @@ func gitRunE(cmd *cobra.Command, args []string) error {
 		}
 
 		if err := service.Handler(ctx, scmd); err != nil {
-			logger.Error("git middleware", "err", err)
+			logger.Error("failed to handle lfs service", "service", service, "err", err, "repo", name)
 			return git.ErrSystemMalfunction
 		}
 

server/ssh/middleware.go 🔗

@@ -1,6 +1,9 @@
 package ssh
 
 import (
+	"fmt"
+	"time"
+
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/config"
@@ -49,7 +52,6 @@ func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 
 			ctx := s.Context()
 			cfg := config.FromContext(ctx)
-			logger := log.FromContext(ctx)
 
 			args := s.Command()
 			cliCommandCounter.WithLabelValues(cmd.CommandName(args)).Inc()
@@ -110,7 +112,6 @@ func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 			}
 
 			if err := rootCmd.ExecuteContext(ctx); err != nil {
-				logger.Error("error executing command", "err", err)
 				s.Exit(1) // nolint: errcheck
 				return
 			}
@@ -118,3 +119,49 @@ func CommandMiddleware(sh ssh.Handler) ssh.Handler {
 		sh(s)
 	}
 }
+
+// LoggingMiddleware logs the ssh connection and command.
+func LoggingMiddleware(sh ssh.Handler) ssh.Handler {
+	return func(s ssh.Session) {
+		ctx := s.Context()
+		logger := log.FromContext(ctx).WithPrefix("ssh")
+		ct := time.Now()
+		hpk := sshutils.MarshalAuthorizedKey(s.PublicKey())
+		ptyReq, _, isPty := s.Pty()
+		addr := s.RemoteAddr().String()
+		user := proto.UserFromContext(ctx)
+		logArgs := []interface{}{
+			"addr",
+			addr,
+			"cmd",
+			s.Command(),
+		}
+
+		if user != nil {
+			logArgs = append([]interface{}{
+				"username",
+				user.Username(),
+			}, logArgs...)
+		}
+
+		if isPty {
+			logArgs = []interface{}{
+				"term", ptyReq.Term,
+				"width", ptyReq.Window.Width,
+				"height", ptyReq.Window.Height,
+			}
+		}
+
+		if config.IsVerbose() {
+			logArgs = append(logArgs,
+				"key", hpk,
+				"envs", s.Environ(),
+			)
+		}
+
+		msg := fmt.Sprintf("user %q", s.User())
+		logger.Debug(msg+" connected", logArgs...)
+		sh(s)
+		logger.Debug(msg+" disconnected", append(logArgs, "duration", time.Since(ct))...)
+	}
+}

server/ssh/ssh.go 🔗

@@ -18,7 +18,6 @@ import (
 	"github.com/charmbracelet/ssh"
 	"github.com/charmbracelet/wish"
 	bm "github.com/charmbracelet/wish/bubbletea"
-	lm "github.com/charmbracelet/wish/logging"
 	rm "github.com/charmbracelet/wish/recover"
 	"github.com/muesli/termenv"
 	"github.com/prometheus/client_golang/prometheus"
@@ -74,12 +73,10 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
 			bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
 			// CLI middleware.
 			CommandMiddleware,
+			// Logging middleware.
+			LoggingMiddleware,
 			// Context middleware.
 			ContextMiddleware(cfg, dbx, datastore, be, logger),
-			// Logging middleware.
-			lm.MiddlewareWithLogger(
-				&loggerAdapter{logger, log.DebugLevel},
-			),
 		),
 	}
 

server/web/git.go 🔗

@@ -75,7 +75,6 @@ var (
 func withParams(h http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		ctx := r.Context()
-		logger := log.FromContext(ctx)
 		cfg := config.FromContext(ctx)
 		vars := mux.Vars(r)
 		repo := vars["repo"]