1package web
2
3import (
4 "bufio"
5 "fmt"
6 "net"
7 "net/http"
8 "time"
9
10 "charm.land/log/v2"
11 "github.com/dustin/go-humanize"
12)
13
// logWriter is a wrapper around http.ResponseWriter that allows us to capture
// the HTTP status code and bytes written to the response.
//
// code holds the final (non-informational) status sent to the client. Callers
// should seed it with http.StatusOK, because net/http sends 200 implicitly when
// a handler writes a body without calling WriteHeader. bytes counts the body
// bytes the wrapped writer reported as written.
type logWriter struct {
	http.ResponseWriter
	code, bytes int
	// wroteHeader is set once the final status has been committed, either by
	// an explicit WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// Compile-time checks for the optional interfaces this wrapper forwards.
// http.CloseNotifier is deprecated in favor of Request.Context but is kept so
// handlers that still type-assert for it continue to work through the wrapper.
var (
	_ http.ResponseWriter = (*logWriter)(nil)
	_ http.Flusher        = (*logWriter)(nil)
	_ http.Hijacker       = (*logWriter)(nil)
	_ http.CloseNotifier  = (*logWriter)(nil)
)

// Write implements http.ResponseWriter. It forwards p to the wrapped writer
// and adds the number of bytes actually written to the running total.
func (r *logWriter) Write(p []byte) (int, error) {
	// net/http commits an implicit 200 header on the first Write (even for an
	// empty p), so any later WriteHeader call cannot change the sent status.
	r.wroteHeader = true
	written, err := r.ResponseWriter.Write(p)
	r.bytes += written
	return written, err
}

// WriteHeader implements http.ResponseWriter. It records code as the response
// status unless a final status has already been committed, then forwards the
// call to the wrapped writer unchanged.
//
// The recording rules mirror the standard net/http server so the logged status
// matches what the client received: 1xx informational codes other than
// 101 Switching Protocols may precede the final status and are not recorded,
// and calls made after the header is committed are ignored by net/http as
// superfluous.
func (r *logWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !r.wroteHeader && !informational {
		r.code = code
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(code)
}

// Unwrap returns the underlying http.ResponseWriter. http.ResponseController
// uses it to reach optional interfaces this wrapper does not implement itself.
func (r *logWriter) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}

// Flush implements http.Flusher. It is a no-op when the wrapped writer does
// not implement http.Flusher.
func (r *logWriter) Flush() {
	if f, ok := r.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// CloseNotify implements http.CloseNotifier by delegating to the wrapped
// writer. If the wrapped writer does not implement http.CloseNotifier it
// returns a nil channel, which never delivers a value.
func (r *logWriter) CloseNotify() <-chan bool {
	if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
		return cn.CloseNotify()
	}
	return nil
}

// Hijack implements http.Hijacker by delegating to the wrapped writer. If the
// wrapped writer cannot be hijacked, the returned error wraps
// http.ErrNotSupported. This matches the contract http.ResponseController
// documents for unsupported features, since the controller calls this method
// directly instead of unwrapping.
func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	h, ok := r.ResponseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("http.Hijacker not implemented: %w", http.ErrNotSupported)
	}
	return h.Hijack()
}
69
70// NewLoggingMiddleware returns a new logging middleware.
71func NewLoggingMiddleware(next http.Handler, logger *log.Logger) http.Handler {
72 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
73 start := time.Now()
74 writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
75 logger.Debug("request",
76 "method", r.Method,
77 "path", r.URL,
78 "addr", r.RemoteAddr)
79 next.ServeHTTP(writer, r)
80 elapsed := time.Since(start)
81 logger.Debug("response",
82 "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
83 "bytes", humanize.Bytes(uint64(writer.bytes)), //nolint:gosec
84 "time", elapsed)
85 })
86}