middleware.go

  1package server
  2
import (
	"compress/gzip"
	"log/slog"
	"net/http"
	"strconv"
	"strings"
	"sync"

	sloghttp "github.com/samber/slog-http"
)
 12
 13// LoggerMiddleware adds request logging using slog-http
 14func LoggerMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
 15	config := sloghttp.Config{
 16		DefaultLevel:     slog.LevelInfo,
 17		ClientErrorLevel: slog.LevelInfo,
 18		ServerErrorLevel: slog.LevelInfo,
 19		WithRequestID:    false,
 20	}
 21	return sloghttp.NewWithConfig(logger, config)
 22}
 23
 24// CSRFMiddleware protects against CSRF attacks by requiring the X-Shelley-Request header
 25// on state-changing requests (POST, PUT, DELETE). This works because browsers will not
 26// add custom headers to simple cross-origin requests, and CORS preflight will block
 27// complex requests from other origins that don't have explicit permission.
 28func CSRFMiddleware() func(http.Handler) http.Handler {
 29	return func(next http.Handler) http.Handler {
 30		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 31			// Only check state-changing methods
 32			if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete {
 33				// Require X-Shelley-Request header (value doesn't matter, just presence)
 34				if r.Header.Get("X-Shelley-Request") == "" {
 35					http.Error(w, "CSRF protection: X-Shelley-Request header required", http.StatusForbidden)
 36					return
 37				}
 38			}
 39			next.ServeHTTP(w, r)
 40		})
 41	}
 42}
 43
 44// RequireHeaderMiddleware requires a specific header to be present on all API requests.
 45// This is used to ensure requests come through an authenticated proxy.
 46func RequireHeaderMiddleware(headerName string) func(http.Handler) http.Handler {
 47	return func(next http.Handler) http.Handler {
 48		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 49			// Only check API routes
 50			if strings.HasPrefix(r.URL.Path, "/api/") {
 51				if r.Header.Get(headerName) == "" {
 52					http.Error(w, "missing required header: "+headerName, http.StatusForbidden)
 53					return
 54				}
 55			}
 56			next.ServeHTTP(w, r)
 57		})
 58	}
 59}
 60
 61// gzipResponseWriter wraps http.ResponseWriter to compress responses
 62type gzipResponseWriter struct {
 63	http.ResponseWriter
 64	gw *gzip.Writer
 65}
 66
 67func (w *gzipResponseWriter) Write(b []byte) (int, error) {
 68	return w.gw.Write(b)
 69}
 70
 71var gzipWriterPool = sync.Pool{
 72	New: func() interface{} {
 73		gw, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
 74		return gw
 75	},
 76}
 77
 78// gzipHandler wraps a handler to compress responses when the client accepts gzip.
 79// Use this to wrap specific handlers that benefit from compression.
 80// Do NOT use for SSE or streaming responses.
 81func gzipHandler(next http.Handler) http.Handler {
 82	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 83		if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
 84			next.ServeHTTP(w, r)
 85			return
 86		}
 87
 88		gw := gzipWriterPool.Get().(*gzip.Writer)
 89		gw.Reset(w)
 90		defer func() {
 91			gw.Close()
 92			gzipWriterPool.Put(gw)
 93		}()
 94
 95		w.Header().Set("Content-Encoding", "gzip")
 96		w.Header().Add("Vary", "Accept-Encoding")
 97		w.Header().Del("Content-Length") // Compression changes size
 98
 99		next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, gw: gw}, r)
100	})
101}