1package server
2
3import (
4 "bytes"
5 "compress/gzip"
6 "io"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10)
11
12func TestCSRFMiddleware_BlocksPostWithoutHeader(t *testing.T) {
13 handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14 w.WriteHeader(http.StatusOK)
15 }))
16
17 req := httptest.NewRequest("POST", "/api/test", nil)
18 w := httptest.NewRecorder()
19
20 handler.ServeHTTP(w, req)
21
22 if w.Code != http.StatusForbidden {
23 t.Errorf("expected status 403 for POST without X-Shelley-Request, got %d", w.Code)
24 }
25}
26
27func TestCSRFMiddleware_AllowsPostWithHeader(t *testing.T) {
28 handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29 w.WriteHeader(http.StatusOK)
30 }))
31
32 req := httptest.NewRequest("POST", "/api/test", nil)
33 req.Header.Set("X-Shelley-Request", "1")
34 w := httptest.NewRecorder()
35
36 handler.ServeHTTP(w, req)
37
38 if w.Code != http.StatusOK {
39 t.Errorf("expected status 200 for POST with X-Shelley-Request, got %d", w.Code)
40 }
41}
42
43func TestCSRFMiddleware_AllowsGetWithoutHeader(t *testing.T) {
44 handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
45 w.WriteHeader(http.StatusOK)
46 }))
47
48 req := httptest.NewRequest("GET", "/api/test", nil)
49 w := httptest.NewRecorder()
50
51 handler.ServeHTTP(w, req)
52
53 if w.Code != http.StatusOK {
54 t.Errorf("expected status 200 for GET without X-Shelley-Request, got %d", w.Code)
55 }
56}
57
58func TestCSRFMiddleware_BlocksPutWithoutHeader(t *testing.T) {
59 handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60 w.WriteHeader(http.StatusOK)
61 }))
62
63 req := httptest.NewRequest("PUT", "/api/test", nil)
64 w := httptest.NewRecorder()
65
66 handler.ServeHTTP(w, req)
67
68 if w.Code != http.StatusForbidden {
69 t.Errorf("expected status 403 for PUT without X-Shelley-Request, got %d", w.Code)
70 }
71}
72
73func TestCSRFMiddleware_BlocksDeleteWithoutHeader(t *testing.T) {
74 handler := CSRFMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
75 w.WriteHeader(http.StatusOK)
76 }))
77
78 req := httptest.NewRequest("DELETE", "/api/test", nil)
79 w := httptest.NewRecorder()
80
81 handler.ServeHTTP(w, req)
82
83 if w.Code != http.StatusForbidden {
84 t.Errorf("expected status 403 for DELETE without X-Shelley-Request, got %d", w.Code)
85 }
86}
87
88func TestRequireHeaderMiddleware_BlocksAPIWithoutHeader(t *testing.T) {
89 handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
90 w.WriteHeader(http.StatusOK)
91 }))
92
93 req := httptest.NewRequest("GET", "/api/conversations", nil)
94 w := httptest.NewRecorder()
95
96 handler.ServeHTTP(w, req)
97
98 if w.Code != http.StatusForbidden {
99 t.Errorf("expected status 403 for API request without required header, got %d", w.Code)
100 }
101}
102
103func TestRequireHeaderMiddleware_AllowsAPIWithHeader(t *testing.T) {
104 handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 w.WriteHeader(http.StatusOK)
106 }))
107
108 req := httptest.NewRequest("GET", "/api/conversations", nil)
109 req.Header.Set("X-Exedev-Userid", "user123")
110 w := httptest.NewRecorder()
111
112 handler.ServeHTTP(w, req)
113
114 if w.Code != http.StatusOK {
115 t.Errorf("expected status 200 for API request with required header, got %d", w.Code)
116 }
117}
118
119func TestRequireHeaderMiddleware_AllowsNonAPIWithoutHeader(t *testing.T) {
120 handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121 w.WriteHeader(http.StatusOK)
122 }))
123
124 req := httptest.NewRequest("GET", "/", nil)
125 w := httptest.NewRecorder()
126
127 handler.ServeHTTP(w, req)
128
129 if w.Code != http.StatusOK {
130 t.Errorf("expected status 200 for non-API request without required header, got %d", w.Code)
131 }
132}
133
134func TestRequireHeaderMiddleware_AllowsVersionEndpointWithoutHeader(t *testing.T) {
135 handler := RequireHeaderMiddleware("X-Exedev-Userid")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
136 w.WriteHeader(http.StatusOK)
137 }))
138
139 req := httptest.NewRequest("GET", "/version", nil)
140 w := httptest.NewRecorder()
141
142 handler.ServeHTTP(w, req)
143
144 if w.Code != http.StatusOK {
145 t.Errorf("expected status 200 for /version without required header, got %d", w.Code)
146 }
147}
148
149func TestGzipHandler_CompressesResponse(t *testing.T) {
150 handler := gzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
151 w.Header().Set("Content-Type", "application/json")
152 w.Write([]byte(`{"message": "hello world"}`))
153 }))
154
155 req := httptest.NewRequest("GET", "/test", nil)
156 req.Header.Set("Accept-Encoding", "gzip")
157 w := httptest.NewRecorder()
158
159 handler.ServeHTTP(w, req)
160
161 if w.Header().Get("Content-Encoding") != "gzip" {
162 t.Errorf("expected Content-Encoding: gzip, got %q", w.Header().Get("Content-Encoding"))
163 }
164
165 // Verify we can decompress the response
166 gr, err := gzip.NewReader(bytes.NewReader(w.Body.Bytes()))
167 if err != nil {
168 t.Fatalf("failed to create gzip reader: %v", err)
169 }
170 defer gr.Close()
171
172 body, err := io.ReadAll(gr)
173 if err != nil {
174 t.Fatalf("failed to read gzip body: %v", err)
175 }
176
177 if !bytes.Contains(body, []byte("hello world")) {
178 t.Errorf("decompressed body doesn't contain expected content: %s", body)
179 }
180}
181
182func TestGzipHandler_SkipsWhenNoAcceptEncoding(t *testing.T) {
183 handler := gzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184 w.Header().Set("Content-Type", "application/json")
185 w.Write([]byte(`{"message": "hello"}`))
186 }))
187
188 req := httptest.NewRequest("GET", "/test", nil)
189 // No Accept-Encoding header
190 w := httptest.NewRecorder()
191
192 handler.ServeHTTP(w, req)
193
194 if w.Header().Get("Content-Encoding") != "" {
195 t.Errorf("expected no Content-Encoding, got %q", w.Header().Get("Content-Encoding"))
196 }
197
198 if !bytes.Contains(w.Body.Bytes(), []byte("hello")) {
199 t.Errorf("body doesn't contain expected content: %s", w.Body.String())
200 }
201}