From 9c35ee0141f5afb99f69cd912797d018d9681c9f Mon Sep 17 00:00:00 2001
From: Christian Rocha
Date: Sun, 17 May 2026 09:43:53 -0400
Subject: [PATCH] fix(server): recover from handler panics + return 500

---
 internal/server/recover.go      |  91 ++++++++++++++++++++++++++
 internal/server/recover_test.go | 114 +++++++++++++++++++++++++++++++++
 internal/server/server.go       |   2 +-
 3 files changed, 206 insertions(+), 1 deletion(-)
 create mode 100644 internal/server/recover.go
 create mode 100644 internal/server/recover_test.go

diff --git a/internal/server/recover.go b/internal/server/recover.go
new file mode 100644
index 0000000000000000000000000000000000000000..547dfb0c84fa7a8026ef0c91d761b196c9703de5
--- /dev/null
+++ b/internal/server/recover.go
@@ -0,0 +1,91 @@
+package server
+
+import (
+	"errors"
+	"log/slog"
+	"net/http"
+	"runtime/debug"
+)
+
+// recoverHandler wraps the next handler in a panic-recovery middleware.
+// If a handler panics, the panic is logged with a stack trace and a 500
+// JSON error is written to the client (when no response has been started
+// yet). Without this, a panicking handler closes the connection silently
+// and surfaces as an opaque EOF on the client side.
+func (s *Server) recoverHandler(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		rrw := &recoverResponseWriter{ResponseWriter: w}
+		defer func() {
+			rec := recover()
+			if rec == nil {
+				return
+			}
+			// http.ErrAbortHandler is the documented way to abort a
+			// handler without logging; preserve that contract.
+			if rec == http.ErrAbortHandler {
+				panic(rec)
+			}
+			s.logError(
+				r, "Panic in handler",
+				slog.Any("panic", rec),
+				slog.String("stack", string(debug.Stack())),
+			)
+			// A response that has already started cannot change its
+			// status; writing an error body now would only corrupt it.
+			if !rrw.wroteHeader {
+				jsonError(rrw, http.StatusInternalServerError, "internal server error")
+			}
+		}()
+		next.ServeHTTP(rrw, r)
+	})
+}
+
+// recoverResponseWriter tracks whether the response has been started so
+// the recovery middleware knows if it can still write a 500 error. The
+// response counts as started once a final status code, body bytes, or a
+// flush has passed through the wrapper.
+type recoverResponseWriter struct {
+	http.ResponseWriter
+	wroteHeader bool
+}
+
+// WriteHeader forwards the status code and records that the response has
+// started. Informational 1xx codes other than 101 Switching Protocols
+// precede the final header, so a 500 may still follow them.
+func (rrw *recoverResponseWriter) WriteHeader(code int) {
+	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
+	if !informational {
+		rrw.wroteHeader = true
+	}
+	rrw.ResponseWriter.WriteHeader(code)
+}
+
+// Write forwards the bytes and records that the response has started; per
+// net/http, a Write without a prior WriteHeader sends an implicit 200.
+func (rrw *recoverResponseWriter) Write(b []byte) (int, error) {
+	rrw.wroteHeader = true
+	return rrw.ResponseWriter.Write(b)
+}
+
+// FlushError flushes through the wrapped writer. A flush sends the
+// response headers, so it is tracked like a write unless the wrapped
+// writer cannot flush at all. Defining it here keeps
+// http.ResponseController from skipping this wrapper via Unwrap.
+func (rrw *recoverResponseWriter) FlushError() error {
+	err := http.NewResponseController(rrw.ResponseWriter).Flush()
+	if !errors.Is(err, http.ErrNotSupported) {
+		rrw.wroteHeader = true
+	}
+	return err
+}
+
+// Flush implements http.Flusher for handlers that type-assert the writer
+// directly instead of going through http.ResponseController.
+func (rrw *recoverResponseWriter) Flush() {
+	_ = rrw.FlushError() // http.Flusher cannot report the error.
+}
+
+// Unwrap exposes the wrapped writer to http.ResponseController.
+func (rrw *recoverResponseWriter) Unwrap() http.ResponseWriter {
+	return rrw.ResponseWriter
+}
diff --git a/internal/server/recover_test.go b/internal/server/recover_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..2bbd61efd592761dffaf73e724b513b6aa4561cf
--- /dev/null
+++ b/internal/server/recover_test.go
@@ -0,0 +1,114 @@
+package server
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/stretchr/testify/require"
+)
+
+// TestRecoverHandler_PanicAfterFlush verifies that flushing counts as
+// starting the response: a handler that flushes and then panics must
+// not get a 500 error body appended to output it already committed.
+func TestRecoverHandler_PanicAfterFlush(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{}
+	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		require.NoError(t, http.NewResponseController(w).Flush())
+		panic("panic after flush")
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	require.NotPanics(t, func() { h.ServeHTTP(rec, req) })
+	require.True(t, rec.Flushed)
+	require.Equal(t, http.StatusOK, rec.Code)
+	require.Empty(t, rec.Body.String())
+}
+
+// TestRecoverHandler_PanicReturns500 verifies that a panicking handler
+// surfaces as a structured 500 to the client, rather than closing the
+// connection silently and producing an opaque EOF.
+func TestRecoverHandler_PanicReturns500(t *testing.T) {
+	t.Parallel()
+
+	srv := &Server{}
+	handler := srv.recoverHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
+		panic("kaboom")
+	}))
+
+	recorder := httptest.NewRecorder()
+	request := httptest.NewRequest(http.MethodGet, "/test", nil)
+	handler.ServeHTTP(recorder, request)
+
+	require.Equal(t, http.StatusInternalServerError, recorder.Code)
+	payload, readErr := io.ReadAll(recorder.Body)
+	require.NoError(t, readErr)
+	require.NotEmpty(t, payload)
+
+	var apiErr proto.Error
+	require.NoError(t, json.Unmarshal(payload, &apiErr))
+	require.NotEmpty(t, apiErr.Message)
+}
+
+// TestRecoverHandler_NoPanicPassthrough checks that a handler which
+// completes normally has its status and body delivered untouched.
+func TestRecoverHandler_NoPanicPassthrough(t *testing.T) {
+	t.Parallel()
+
+	srv := &Server{}
+	handler := srv.recoverHandler(http.HandlerFunc(func(out http.ResponseWriter, _ *http.Request) {
+		out.WriteHeader(http.StatusTeapot)
+		_, _ = out.Write([]byte("ok"))
+	}))
+
+	recorder := httptest.NewRecorder()
+	request := httptest.NewRequest(http.MethodGet, "/test", nil)
+	handler.ServeHTTP(recorder, request)
+
+	require.Equal(t, http.StatusTeapot, recorder.Code)
+	require.Equal(t, "ok", recorder.Body.String())
+}
+
+// TestRecoverHandler_PanicAfterWriteHeader verifies that when a handler
+// panics after it has already started writing the response, the
+// middleware leaves the committed status and body alone (a second
+// WriteHeader would only trigger a superfluous-WriteHeader warning)
+// while still recovering from the panic.
+func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) {
+	t.Parallel()
+
+	srv := &Server{}
+	handler := srv.recoverHandler(http.HandlerFunc(func(out http.ResponseWriter, _ *http.Request) {
+		out.WriteHeader(http.StatusOK)
+		_, _ = out.Write([]byte("partial"))
+		panic("late panic")
+	}))
+
+	recorder := httptest.NewRecorder()
+	request := httptest.NewRequest(http.MethodGet, "/test", nil)
+	require.NotPanics(t, func() { handler.ServeHTTP(recorder, request) })
+	require.Equal(t, http.StatusOK, recorder.Code)
+	require.Equal(t, "partial", recorder.Body.String())
+}
+
+// TestRecoverHandler_AbortHandlerPropagates checks that a panic carrying
+// the http.ErrAbortHandler sentinel is re-raised unchanged, leaving the
+// net/http server to apply its own handling for that value.
+func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) {
+	t.Parallel()
+
+	srv := &Server{}
+	handler := srv.recoverHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
+		panic(http.ErrAbortHandler)
+	}))
+
+	recorder := httptest.NewRecorder()
+	request := httptest.NewRequest(http.MethodGet, "/test", nil)
+	require.PanicsWithValue(t, http.ErrAbortHandler, func() { handler.ServeHTTP(recorder, request) })
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 9ac4dba4c908050a0381b49258941d1b3a931970..75ef626d952af7183bcad5681dce7b0fdd85975c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -168,7 +168,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server {
 	mux.Handle("/v1/docs/", httpswagger.WrapHandler)
 	s.h = &http.Server{
 		Protocols: &p,
-		Handler:   s.loggingHandler(mux),
+		Handler:   s.recoverHandler(s.loggingHandler(mux)),
 	}
 	if network == "tcp" {
 		s.h.Addr = address