1package server
2
3import (
4 "compress/gzip"
5 "log/slog"
6 "net/http"
7 "strings"
8 "sync"
9
10 sloghttp "github.com/samber/slog-http"
11)
12
13// LoggerMiddleware adds request logging using slog-http
14func LoggerMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
15 config := sloghttp.Config{
16 DefaultLevel: slog.LevelInfo,
17 ClientErrorLevel: slog.LevelInfo,
18 ServerErrorLevel: slog.LevelInfo,
19 WithRequestID: false,
20 }
21 return sloghttp.NewWithConfig(logger, config)
22}
23
// CSRFMiddleware protects against CSRF attacks by requiring the X-Shelley-Request
// header on state-changing requests (POST, PUT, PATCH, DELETE). Only the header's
// presence is checked; its value is ignored. Requests missing it are rejected
// with 403 Forbidden. All other methods (GET, HEAD, OPTIONS, ...) pass through.
//
// This works because browsers will not add custom headers to simple cross-origin
// requests, and CORS preflight will block complex requests from other origins
// that don't have explicit permission.
func CSRFMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			// PATCH mutates state just like PUT, so it must be guarded too.
			case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
				if r.Header.Get("X-Shelley-Request") == "" {
					http.Error(w, "CSRF protection: X-Shelley-Request header required", http.StatusForbidden)
					return
				}
			}
			next.ServeHTTP(w, r)
		})
	}
}
43
// RequireHeaderMiddleware returns middleware that rejects any request whose URL
// path begins with "/api/" unless it carries a non-empty headerName header.
// Rejections are answered with 403 Forbidden. Paths outside "/api/" are passed
// to the next handler without inspection. The intent is to ensure API
// requests arrived through an authenticated proxy that sets headerName.
func RequireHeaderMiddleware(headerName string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			apiRoute := strings.HasPrefix(r.URL.Path, "/api/")
			if apiRoute && r.Header.Get(headerName) == "" {
				http.Error(w, "missing required header: "+headerName, http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
60
// gzipResponseWriter wraps an http.ResponseWriter so that body bytes written by
// the wrapped handler are routed through gw and compressed. Header() and the
// status code still go to the underlying writer.
//
// Only the methods of http.ResponseWriter are exposed, so type assertions for
// optional interfaces such as http.Flusher will not succeed on this wrapper.
type gzipResponseWriter struct {
	http.ResponseWriter
	gw *gzip.Writer
	// wroteHeader records that WriteHeader has run, either called explicitly
	// by the handler or implicitly by the first Write.
	wroteHeader bool
}

// WriteHeader drops any Content-Length the wrapped handler set before sending
// the status. That length describes the uncompressed body and would be wrong
// for the gzipped bytes actually sent.
func (w *gzipResponseWriter) WriteHeader(code int) {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.Header().Del("Content-Length")
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write compresses b into the gzip stream. As with net/http, the first Write
// without a prior WriteHeader implies a 200 OK status. That status is routed
// through WriteHeader so Content-Length is stripped before headers go out.
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.WriteHeader(http.StatusOK)
	}
	return w.gw.Write(b)
}
70
// gzipWriterPool recycles *gzip.Writer values configured for gzip.BestSpeed so
// that each compressed response does not allocate a fresh compressor. Pooled
// writers are created with a nil destination, so callers must Reset them onto
// a real io.Writer before use.
var gzipWriterPool = sync.Pool{
	New: func() any {
		// NewWriterLevel only returns an error for an invalid level. BestSpeed
		// is always valid, so the error can safely be ignored.
		w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
		return w
	},
}
77
78// gzipHandler wraps a handler to compress responses when the client accepts gzip.
79// Use this to wrap specific handlers that benefit from compression.
80// Do NOT use for SSE or streaming responses.
81func gzipHandler(next http.Handler) http.Handler {
82 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
83 if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
84 next.ServeHTTP(w, r)
85 return
86 }
87
88 gw := gzipWriterPool.Get().(*gzip.Writer)
89 gw.Reset(w)
90 defer func() {
91 gw.Close()
92 gzipWriterPool.Put(gw)
93 }()
94
95 w.Header().Set("Content-Encoding", "gzip")
96 w.Header().Add("Vary", "Accept-Encoding")
97 w.Header().Del("Content-Length") // Compression changes size
98
99 next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, gw: gw}, r)
100 })
101}