1package web
2
3import (
4 "bufio"
5 "fmt"
6 "net"
7 "net/http"
8 "time"
9
10 "github.com/charmbracelet/log"
11 "github.com/dustin/go-humanize"
12)
13
// logWriter decorates an http.ResponseWriter so that the logging middleware
// can report what a handler sent back: the status code it set and how many
// body bytes the wrapped writer accepted.
type logWriter struct {
	http.ResponseWriter

	code  int // last status code recorded by WriteHeader
	bytes int // running total of bytes reported written by Write
}
20
21var _ http.ResponseWriter = (*logWriter)(nil)
22
23var _ http.Flusher = (*logWriter)(nil)
24
25var _ http.Hijacker = (*logWriter)(nil)
26
27var _ http.CloseNotifier = (*logWriter)(nil) // nolint: staticcheck
28
29// Write implements http.ResponseWriter.
30func (r *logWriter) Write(p []byte) (int, error) {
31 written, err := r.ResponseWriter.Write(p)
32 r.bytes += written
33 return written, err
34}
35
36// Note this is generally only called when sending an HTTP error, so it's
37// important to set the `code` value to 200 as a default.
38func (r *logWriter) WriteHeader(code int) {
39 r.code = code
40 r.ResponseWriter.WriteHeader(code)
41}
42
43// Flush implements http.Flusher.
44func (r *logWriter) Flush() {
45 if f, ok := r.ResponseWriter.(http.Flusher); ok {
46 f.Flush()
47 }
48}
49
50// CloseNotify implements http.CloseNotifier.
51func (r *logWriter) CloseNotify() <-chan bool {
52 if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { // nolint: staticcheck
53 return cn.CloseNotify()
54 }
55 return nil
56}
57
58// Hijack implements http.Hijacker.
59func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
60 if h, ok := r.ResponseWriter.(http.Hijacker); ok {
61 return h.Hijack()
62 }
63 return nil, nil, fmt.Errorf("http.Hijacker not implemented")
64}
65
66// NewLoggingMiddleware returns a new logging middleware.
67func NewLoggingMiddleware(next http.Handler, logger *log.Logger) http.Handler {
68 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
69 start := time.Now()
70 writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
71 logger.Debug("request",
72 "method", r.Method,
73 "path", r.URL,
74 "addr", r.RemoteAddr)
75 next.ServeHTTP(writer, r)
76 elapsed := time.Since(start)
77 logger.Debug("response",
78 "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
79 "bytes", humanize.Bytes(uint64(writer.bytes)),
80 "time", elapsed)
81 })
82}