recover_test.go

 1package server
 2
 3import (
 4	"encoding/json"
 5	"io"
 6	"net/http"
 7	"net/http/httptest"
 8	"testing"
 9
10	"github.com/charmbracelet/crush/internal/proto"
11	"github.com/stretchr/testify/require"
12)
13
14// TestRecoverHandler_PanicReturns500 verifies that a panicking handler
15// surfaces as a structured 500 to the client, rather than closing the
16// connection silently and producing an opaque EOF.
17func TestRecoverHandler_PanicReturns500(t *testing.T) {
18	t.Parallel()
19
20	s := &Server{}
21	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22		panic("kaboom")
23	}))
24
25	rec := httptest.NewRecorder()
26	req := httptest.NewRequest(http.MethodGet, "/test", nil)
27	h.ServeHTTP(rec, req)
28
29	require.Equal(t, http.StatusInternalServerError, rec.Code)
30	body, err := io.ReadAll(rec.Body)
31	require.NoError(t, err)
32	require.NotEmpty(t, body)
33
34	var perr proto.Error
35	require.NoError(t, json.Unmarshal(body, &perr))
36	require.NotEmpty(t, perr.Message)
37}
38
39// TestRecoverHandler_NoPanicPassthrough verifies that the middleware
40// does not interfere with successful responses.
41func TestRecoverHandler_NoPanicPassthrough(t *testing.T) {
42	t.Parallel()
43
44	s := &Server{}
45	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46		w.WriteHeader(http.StatusTeapot)
47		_, _ = w.Write([]byte("ok"))
48	}))
49
50	rec := httptest.NewRecorder()
51	req := httptest.NewRequest(http.MethodGet, "/test", nil)
52	h.ServeHTTP(rec, req)
53
54	require.Equal(t, http.StatusTeapot, rec.Code)
55	require.Equal(t, "ok", rec.Body.String())
56}
57
58// TestRecoverHandler_PanicAfterWriteHeader verifies that if a handler
59// panics after it has already started writing the response, the
60// middleware does not attempt to overwrite the status (which would
61// trigger a superfluous WriteHeader warning) but still logs and
62// recovers.
63func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) {
64	t.Parallel()
65
66	s := &Server{}
67	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68		w.WriteHeader(http.StatusOK)
69		_, _ = w.Write([]byte("partial"))
70		panic("late panic")
71	}))
72
73	rec := httptest.NewRecorder()
74	req := httptest.NewRequest(http.MethodGet, "/test", nil)
75	require.NotPanics(t, func() { h.ServeHTTP(rec, req) })
76	require.Equal(t, http.StatusOK, rec.Code)
77	require.Equal(t, "partial", rec.Body.String())
78}
79
80// TestRecoverHandler_AbortHandlerPropagates verifies that the documented
81// http.ErrAbortHandler sentinel is re-panicked so the net/http server
82// can handle it normally (suppress logging, close connection).
83func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) {
84	t.Parallel()
85
86	s := &Server{}
87	h := s.recoverHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88		panic(http.ErrAbortHandler)
89	}))
90
91	rec := httptest.NewRecorder()
92	req := httptest.NewRequest(http.MethodGet, "/test", nil)
93	require.PanicsWithValue(t, http.ErrAbortHandler, func() { h.ServeHTTP(rec, req) })
94}