fix(server): recover from handler panics + return 500

Christian Rocha created

Change summary

internal/server/recover.go      | 59 +++++++++++++++++++++
internal/server/recover_test.go | 94 +++++++++++++++++++++++++++++++++++
internal/server/server.go       |  2 
3 files changed, 154 insertions(+), 1 deletion(-)

Detailed changes

internal/server/recover.go 🔗

@@ -0,0 +1,59 @@
+package server
+
+import (
+	"log/slog"
+	"net/http"
+	"runtime/debug"
+)
+
+// recoverHandler wraps the next handler in a panic-recovery middleware.
+// If a handler panics, the panic is logged with a stack trace and a 500
+// JSON error is written to the client (when no response has been started
+// yet). Without this, a panicking handler closes the connection silently
+// and surfaces as an opaque EOF on the client side.
+func (s *Server) recoverHandler(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		rrw := &recoverResponseWriter{ResponseWriter: w}
+		defer func() {
+			rec := recover()
+			if rec == nil {
+				return
+			}
+			// http.ErrAbortHandler is the documented way to abort a
+			// handler without logging; preserve that contract.
+			if rec == http.ErrAbortHandler {
+				panic(rec)
+			}
+			s.logError(
+				r, "Panic in handler",
+				slog.Any("panic", rec),
+				slog.String("stack", string(debug.Stack())),
+			)
+			if !rrw.wroteHeader {
+				jsonError(rrw, http.StatusInternalServerError, "internal server error")
+			}
+		}()
+		next.ServeHTTP(rrw, r)
+	})
+}
+
+// recoverResponseWriter tracks whether the response has been started so
+// the recovery middleware knows if it can still write a 500 error.
+type recoverResponseWriter struct {
+	http.ResponseWriter
+	wroteHeader bool
+}
+
+func (rrw *recoverResponseWriter) WriteHeader(code int) {
+	rrw.wroteHeader = true
+	rrw.ResponseWriter.WriteHeader(code)
+}
+
+func (rrw *recoverResponseWriter) Write(b []byte) (int, error) {
+	rrw.wroteHeader = true
+	return rrw.ResponseWriter.Write(b)
+}
+
+func (rrw *recoverResponseWriter) Unwrap() http.ResponseWriter {
+	return rrw.ResponseWriter
+}

internal/server/recover_test.go 🔗

@@ -0,0 +1,94 @@
+package server
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/stretchr/testify/require"
+)
+
+// TestRecoverHandler_PanicReturns500 verifies that a panicking handler
+// surfaces as a structured 500 to the client, rather than closing the
+// connection silently and producing an opaque EOF.
+func TestRecoverHandler_PanicReturns500(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{}
+	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		panic("kaboom")
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	h.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusInternalServerError, rec.Code)
+	body, err := io.ReadAll(rec.Body)
+	require.NoError(t, err)
+	require.NotEmpty(t, body)
+
+	var perr proto.Error
+	require.NoError(t, json.Unmarshal(body, &perr))
+	require.NotEmpty(t, perr.Message)
+}
+
+// TestRecoverHandler_NoPanicPassthrough verifies that the middleware
+// does not interfere with successful responses.
+func TestRecoverHandler_NoPanicPassthrough(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{}
+	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusTeapot)
+		_, _ = w.Write([]byte("ok"))
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	h.ServeHTTP(rec, req)
+
+	require.Equal(t, http.StatusTeapot, rec.Code)
+	require.Equal(t, "ok", rec.Body.String())
+}
+
+// TestRecoverHandler_PanicAfterWriteHeader verifies that if a handler
+// panics after it has already started writing the response, the
+// middleware does not attempt to overwrite the status (which would
+// trigger a superfluous WriteHeader warning) but still logs and
+// recovers.
+func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{}
+	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("partial"))
+		panic("late panic")
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	require.NotPanics(t, func() { h.ServeHTTP(rec, req) })
+	require.Equal(t, http.StatusOK, rec.Code)
+	require.Equal(t, "partial", rec.Body.String())
+}
+
+// TestRecoverHandler_AbortHandlerPropagates verifies that the documented
+// http.ErrAbortHandler sentinel is re-panicked so the net/http server
+// can handle it normally (suppress logging, close connection).
+func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) {
+	t.Parallel()
+
+	s := &Server{}
+	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		panic(http.ErrAbortHandler)
+	}))
+
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/test", nil)
+	require.PanicsWithValue(t, http.ErrAbortHandler, func() { h.ServeHTTP(rec, req) })
+}

internal/server/server.go 🔗

@@ -168,7 +168,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server {
 	mux.Handle("/v1/docs/", httpswagger.WrapHandler)
 	s.h = &http.Server{
 		Protocols: &p,
-		Handler:   s.loggingHandler(mux),
+		Handler:   s.recoverHandler(s.loggingHandler(mux)),
 	}
 	if network == "tcp" {
 		s.h.Addr = address