1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: AGPL-3.0-or-later
4
5package auth_test
6
7import (
8 "net/http"
9 "net/http/httptest"
10 "testing"
11
12 "git.secluded.site/lune/internal/mcp/auth"
13)
14
15func TestHashAndVerify(t *testing.T) {
16 t.Parallel()
17
18 token := "my-secret-token"
19
20 hash, err := auth.Hash(token)
21 if err != nil {
22 t.Fatalf("Hash() error = %v", err)
23 }
24
25 if hash == "" {
26 t.Fatal("Hash() returned empty string")
27 }
28
29 // Verify with correct token
30 if !auth.Verify(token, hash) {
31 t.Error("Verify() returned false for correct token")
32 }
33
34 // Verify with wrong token
35 if auth.Verify("wrong-token", hash) {
36 t.Error("Verify() returned true for wrong token")
37 }
38}
39
40func TestHashUniqueSalts(t *testing.T) {
41 t.Parallel()
42
43 token := "same-token"
44
45 hash1, err := auth.Hash(token)
46 if err != nil {
47 t.Fatalf("Hash() error = %v", err)
48 }
49
50 hash2, err := auth.Hash(token)
51 if err != nil {
52 t.Fatalf("Hash() error = %v", err)
53 }
54
55 if hash1 == hash2 {
56 t.Error("Hash() produced identical hashes for same token (salt not random)")
57 }
58
59 // Both should still verify correctly
60 if !auth.Verify(token, hash1) {
61 t.Error("Verify() failed for hash1")
62 }
63
64 if !auth.Verify(token, hash2) {
65 t.Error("Verify() failed for hash2")
66 }
67}
68
69func TestVerifyInvalidHash(t *testing.T) {
70 t.Parallel()
71
72 tests := []struct {
73 name string
74 hash string
75 }{
76 {"empty", ""},
77 {"garbage", "not-a-valid-hash"},
78 {"wrong algorithm", "$bcrypt$v=2$something$salt$hash"},
79 {"too few parts", "$argon2id$v=19$m=19456$salt"},
80 {"invalid base64 salt", "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"},
81 {"invalid base64 hash", "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"},
82 }
83
84 for _, tc := range tests {
85 t.Run(tc.name, func(t *testing.T) {
86 t.Parallel()
87
88 if auth.Verify("token", tc.hash) {
89 t.Errorf("Verify() returned true for invalid hash %q", tc.hash)
90 }
91 })
92 }
93}
94
95func TestMiddleware(t *testing.T) {
96 t.Parallel()
97
98 token := "test-bearer-token"
99
100 hash, err := auth.Hash(token)
101 if err != nil {
102 t.Fatalf("Hash() error = %v", err)
103 }
104
105 handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
106 rw.WriteHeader(http.StatusOK)
107 _, _ = rw.Write([]byte("OK"))
108 })
109
110 protected := auth.Middleware(hash)(handler)
111
112 testCases := middlewareTestCases(token)
113 for _, testCase := range testCases {
114 t.Run(testCase.name, func(t *testing.T) {
115 t.Parallel()
116 runMiddlewareTest(t, protected, testCase)
117 })
118 }
119}
120
// middlewareTestCase describes one request sent through the protected
// handler by runMiddlewareTest, together with the response expected back.
type middlewareTestCase struct {
	// name is the subtest name passed to t.Run.
	name string
	// authHeader is the Authorization header value; when empty, the header
	// is not set on the request at all.
	authHeader string
	// wantStatus is the expected HTTP status code of the response.
	wantStatus int
	// wantAuthHeader is true when the response is expected to carry a
	// WWW-Authenticate header.
	wantAuthHeader bool
}
127
128func middlewareTestCases(token string) []middlewareTestCase {
129 return []middlewareTestCase{
130 {
131 name: "missing header",
132 authHeader: "",
133 wantStatus: http.StatusUnauthorized,
134 wantAuthHeader: true,
135 },
136 {
137 name: "wrong scheme",
138 authHeader: "Basic dXNlcjpwYXNz",
139 wantStatus: http.StatusUnauthorized,
140 wantAuthHeader: true,
141 },
142 {
143 name: "wrong token",
144 authHeader: "Bearer wrong-token",
145 wantStatus: http.StatusUnauthorized,
146 wantAuthHeader: true,
147 },
148 {
149 name: "valid token",
150 authHeader: "Bearer " + token,
151 wantStatus: http.StatusOK,
152 wantAuthHeader: false,
153 },
154 }
155}
156
157func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) {
158 t.Helper()
159
160 req := httptest.NewRequest(http.MethodGet, "/", nil)
161 if testCase.authHeader != "" {
162 req.Header.Set("Authorization", testCase.authHeader)
163 }
164
165 recorder := httptest.NewRecorder()
166 protected.ServeHTTP(recorder, req)
167
168 if recorder.Code != testCase.wantStatus {
169 t.Errorf("status = %d, want %d", recorder.Code, testCase.wantStatus)
170 }
171
172 wwwAuth := recorder.Header().Get("WWW-Authenticate")
173 hasAuthHeader := wwwAuth != ""
174
175 if hasAuthHeader != testCase.wantAuthHeader {
176 t.Errorf("WWW-Authenticate header present = %v, want %v", hasAuthHeader, testCase.wantAuthHeader)
177 }
178
179 if testCase.wantAuthHeader && wwwAuth != `Bearer realm="lune-mcp"` {
180 t.Errorf("WWW-Authenticate = %q, want %q", wwwAuth, `Bearer realm="lune-mcp"`)
181 }
182}
183
184func TestMiddlewareQueryParam(t *testing.T) {
185 t.Parallel()
186
187 token := "query-param-token" //nolint:gosec // test fixture, not a real credential
188
189 hash, err := auth.Hash(token)
190 if err != nil {
191 t.Fatalf("Hash() error = %v", err)
192 }
193
194 handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
195 rw.WriteHeader(http.StatusOK)
196 })
197
198 protected := auth.Middleware(hash)(handler)
199
200 t.Run("valid query param", func(t *testing.T) {
201 t.Parallel()
202
203 req := httptest.NewRequest(http.MethodGet, "/?access_token="+token, nil)
204 recorder := httptest.NewRecorder()
205 protected.ServeHTTP(recorder, req)
206
207 if recorder.Code != http.StatusOK {
208 t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK)
209 }
210 })
211
212 t.Run("invalid query param", func(t *testing.T) {
213 t.Parallel()
214
215 req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
216 recorder := httptest.NewRecorder()
217 protected.ServeHTTP(recorder, req)
218
219 if recorder.Code != http.StatusUnauthorized {
220 t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
221 }
222 })
223
224 t.Run("header takes precedence over query param", func(t *testing.T) {
225 t.Parallel()
226
227 req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
228 req.Header.Set("Authorization", "Bearer "+token)
229
230 recorder := httptest.NewRecorder()
231 protected.ServeHTTP(recorder, req)
232
233 if recorder.Code != http.StatusOK {
234 t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK)
235 }
236 })
237}