middleware_test.go

  1package server
  2
  3import (
  4	"bytes"
  5	"compress/gzip"
  6	"io"
  7	"net/http"
  8	"net/http/httptest"
  9	"testing"
 10)
 11
 12func TestCSRFMiddleware_BlocksPostWithoutHeader(t *testing.T) {
 13	handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 14		w.WriteHeader(http.StatusOK)
 15	}))
 16
 17	req := httptest.NewRequest("POST", "/api/test", nil)
 18	w := httptest.NewRecorder()
 19
 20	handler.ServeHTTP(w, req)
 21
 22	if w.Code != http.StatusForbidden {
 23		t.Errorf("expected status 403 for POST without X-Shelley-Request, got %d", w.Code)
 24	}
 25}
 26
 27func TestCSRFMiddleware_AllowsPostWithHeader(t *testing.T) {
 28	handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 29		w.WriteHeader(http.StatusOK)
 30	}))
 31
 32	req := httptest.NewRequest("POST", "/api/test", nil)
 33	req.Header.Set("X-Shelley-Request", "1")
 34	w := httptest.NewRecorder()
 35
 36	handler.ServeHTTP(w, req)
 37
 38	if w.Code != http.StatusOK {
 39		t.Errorf("expected status 200 for POST with X-Shelley-Request, got %d", w.Code)
 40	}
 41}
 42
 43func TestCSRFMiddleware_AllowsGetWithoutHeader(t *testing.T) {
 44	handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 45		w.WriteHeader(http.StatusOK)
 46	}))
 47
 48	req := httptest.NewRequest("GET", "/api/test", nil)
 49	w := httptest.NewRecorder()
 50
 51	handler.ServeHTTP(w, req)
 52
 53	if w.Code != http.StatusOK {
 54		t.Errorf("expected status 200 for GET without X-Shelley-Request, got %d", w.Code)
 55	}
 56}
 57
 58func TestCSRFMiddleware_BlocksPutWithoutHeader(t *testing.T) {
 59	handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 60		w.WriteHeader(http.StatusOK)
 61	}))
 62
 63	req := httptest.NewRequest("PUT", "/api/test", nil)
 64	w := httptest.NewRecorder()
 65
 66	handler.ServeHTTP(w, req)
 67
 68	if w.Code != http.StatusForbidden {
 69		t.Errorf("expected status 403 for PUT without X-Shelley-Request, got %d", w.Code)
 70	}
 71}
 72
 73func TestCSRFMiddleware_BlocksDeleteWithoutHeader(t *testing.T) {
 74	handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 75		w.WriteHeader(http.StatusOK)
 76	}))
 77
 78	req := httptest.NewRequest("DELETE", "/api/test", nil)
 79	w := httptest.NewRecorder()
 80
 81	handler.ServeHTTP(w, req)
 82
 83	if w.Code != http.StatusForbidden {
 84		t.Errorf("expected status 403 for DELETE without X-Shelley-Request, got %d", w.Code)
 85	}
 86}
 87
 88func TestRequireHeaderMiddleware_BlocksAPIWithoutHeader(t *testing.T) {
 89	handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 90		w.WriteHeader(http.StatusOK)
 91	}))
 92
 93	req := httptest.NewRequest("GET", "/api/conversations", nil)
 94	w := httptest.NewRecorder()
 95
 96	handler.ServeHTTP(w, req)
 97
 98	if w.Code != http.StatusForbidden {
 99		t.Errorf("expected status 403 for API request without required header, got %d", w.Code)
100	}
101}
102
103func TestRequireHeaderMiddleware_AllowsAPIWithHeader(t *testing.T) {
104	handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105		w.WriteHeader(http.StatusOK)
106	}))
107
108	req := httptest.NewRequest("GET", "/api/conversations", nil)
109	req.Header.Set("X-Exedev-Userid", "user123")
110	w := httptest.NewRecorder()
111
112	handler.ServeHTTP(w, req)
113
114	if w.Code != http.StatusOK {
115		t.Errorf("expected status 200 for API request with required header, got %d", w.Code)
116	}
117}
118
119func TestRequireHeaderMiddleware_AllowsNonAPIWithoutHeader(t *testing.T) {
120	handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121		w.WriteHeader(http.StatusOK)
122	}))
123
124	req := httptest.NewRequest("GET", "/", nil)
125	w := httptest.NewRecorder()
126
127	handler.ServeHTTP(w, req)
128
129	if w.Code != http.StatusOK {
130		t.Errorf("expected status 200 for non-API request without required header, got %d", w.Code)
131	}
132}
133
134func TestRequireHeaderMiddleware_AllowsVersionEndpointWithoutHeader(t *testing.T) {
135	handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136		w.WriteHeader(http.StatusOK)
137	}))
138
139	req := httptest.NewRequest("GET", "/version", nil)
140	w := httptest.NewRecorder()
141
142	handler.ServeHTTP(w, req)
143
144	if w.Code != http.StatusOK {
145		t.Errorf("expected status 200 for /version without required header, got %d", w.Code)
146	}
147}
148
149func TestGzipHandler_CompressesResponse(t *testing.T) {
150	handler := gzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
151		w.Header().Set("Content-Type", "application/json")
152		w.Write([]byte(`{"message": "hello world"}`))
153	}))
154
155	req := httptest.NewRequest("GET", "/test", nil)
156	req.Header.Set("Accept-Encoding", "gzip")
157	w := httptest.NewRecorder()
158
159	handler.ServeHTTP(w, req)
160
161	if w.Header().Get("Content-Encoding") != "gzip" {
162		t.Errorf("expected Content-Encoding: gzip, got %q", w.Header().Get("Content-Encoding"))
163	}
164
165	// Verify we can decompress the response
166	gr, err := gzip.NewReader(bytes.NewReader(w.Body.Bytes()))
167	if err != nil {
168		t.Fatalf("failed to create gzip reader: %v", err)
169	}
170	defer gr.Close()
171
172	body, err := io.ReadAll(gr)
173	if err != nil {
174		t.Fatalf("failed to read gzip body: %v", err)
175	}
176
177	if !bytes.Contains(body, []byte("hello world")) {
178		t.Errorf("decompressed body doesn't contain expected content: %s", body)
179	}
180}
181
182func TestGzipHandler_SkipsWhenNoAcceptEncoding(t *testing.T) {
183	handler := gzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184		w.Header().Set("Content-Type", "application/json")
185		w.Write([]byte(`{"message": "hello"}`))
186	}))
187
188	req := httptest.NewRequest("GET", "/test", nil)
189	// No Accept-Encoding header
190	w := httptest.NewRecorder()
191
192	handler.ServeHTTP(w, req)
193
194	if w.Header().Get("Content-Encoding") != "" {
195		t.Errorf("expected no Content-Encoding, got %q", w.Header().Get("Content-Encoding"))
196	}
197
198	if !bytes.Contains(w.Body.Bytes(), []byte("hello")) {
199		t.Errorf("body doesn't contain expected content: %s", w.Body.String())
200	}
201}