auth_test.go

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: AGPL-3.0-or-later
  4
  5package auth_test
  6
  7import (
  8	"net/http"
  9	"net/http/httptest"
 10	"testing"
 11
 12	"git.secluded.site/lune/internal/mcp/auth"
 13)
 14
 15func TestHashAndVerify(t *testing.T) {
 16	t.Parallel()
 17
 18	token := "my-secret-token"
 19
 20	hash, err := auth.Hash(token)
 21	if err != nil {
 22		t.Fatalf("Hash() error = %v", err)
 23	}
 24
 25	if hash == "" {
 26		t.Fatal("Hash() returned empty string")
 27	}
 28
 29	// Verify with correct token
 30	if !auth.Verify(token, hash) {
 31		t.Error("Verify() returned false for correct token")
 32	}
 33
 34	// Verify with wrong token
 35	if auth.Verify("wrong-token", hash) {
 36		t.Error("Verify() returned true for wrong token")
 37	}
 38}
 39
 40func TestHashUniqueSalts(t *testing.T) {
 41	t.Parallel()
 42
 43	token := "same-token"
 44
 45	hash1, err := auth.Hash(token)
 46	if err != nil {
 47		t.Fatalf("Hash() error = %v", err)
 48	}
 49
 50	hash2, err := auth.Hash(token)
 51	if err != nil {
 52		t.Fatalf("Hash() error = %v", err)
 53	}
 54
 55	if hash1 == hash2 {
 56		t.Error("Hash() produced identical hashes for same token (salt not random)")
 57	}
 58
 59	// Both should still verify correctly
 60	if !auth.Verify(token, hash1) {
 61		t.Error("Verify() failed for hash1")
 62	}
 63
 64	if !auth.Verify(token, hash2) {
 65		t.Error("Verify() failed for hash2")
 66	}
 67}
 68
 69func TestVerifyInvalidHash(t *testing.T) {
 70	t.Parallel()
 71
 72	tests := []struct {
 73		name string
 74		hash string
 75	}{
 76		{"empty", ""},
 77		{"garbage", "not-a-valid-hash"},
 78		{"wrong algorithm", "$bcrypt$v=2$something$salt$hash"},
 79		{"too few parts", "$argon2id$v=19$m=19456$salt"},
 80		{"invalid base64 salt", "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"},
 81		{"invalid base64 hash", "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"},
 82	}
 83
 84	for _, tc := range tests {
 85		t.Run(tc.name, func(t *testing.T) {
 86			t.Parallel()
 87
 88			if auth.Verify("token", tc.hash) {
 89				t.Errorf("Verify() returned true for invalid hash %q", tc.hash)
 90			}
 91		})
 92	}
 93}
 94
 95func TestMiddleware(t *testing.T) {
 96	t.Parallel()
 97
 98	token := "test-bearer-token"
 99
100	hash, err := auth.Hash(token)
101	if err != nil {
102		t.Fatalf("Hash() error = %v", err)
103	}
104
105	handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
106		rw.WriteHeader(http.StatusOK)
107		_, _ = rw.Write([]byte("OK"))
108	})
109
110	protected := auth.Middleware(hash)(handler)
111
112	testCases := middlewareTestCases(token)
113	for _, testCase := range testCases {
114		t.Run(testCase.name, func(t *testing.T) {
115			t.Parallel()
116			runMiddlewareTest(t, protected, testCase)
117		})
118	}
119}
120
// middlewareTestCase describes one request sent through the protected handler
// and the response expected back.
type middlewareTestCase struct {
	// name is used as the subtest name.
	name string
	// authHeader is the Authorization header value to send; when empty, no
	// Authorization header is set on the request.
	authHeader string
	// wantStatus is the expected HTTP status code of the response.
	wantStatus int
	// wantAuthHeader reports whether the response is expected to carry a
	// WWW-Authenticate header (checked against `Bearer realm="lune-mcp"`).
	wantAuthHeader bool
}
127
128func middlewareTestCases(token string) []middlewareTestCase {
129	return []middlewareTestCase{
130		{
131			name:           "missing header",
132			authHeader:     "",
133			wantStatus:     http.StatusUnauthorized,
134			wantAuthHeader: true,
135		},
136		{
137			name:           "wrong scheme",
138			authHeader:     "Basic dXNlcjpwYXNz",
139			wantStatus:     http.StatusUnauthorized,
140			wantAuthHeader: true,
141		},
142		{
143			name:           "wrong token",
144			authHeader:     "Bearer wrong-token",
145			wantStatus:     http.StatusUnauthorized,
146			wantAuthHeader: true,
147		},
148		{
149			name:           "valid token",
150			authHeader:     "Bearer " + token,
151			wantStatus:     http.StatusOK,
152			wantAuthHeader: false,
153		},
154	}
155}
156
157func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) {
158	t.Helper()
159
160	req := httptest.NewRequest(http.MethodGet, "/", nil)
161	if testCase.authHeader != "" {
162		req.Header.Set("Authorization", testCase.authHeader)
163	}
164
165	recorder := httptest.NewRecorder()
166	protected.ServeHTTP(recorder, req)
167
168	if recorder.Code != testCase.wantStatus {
169		t.Errorf("status = %d, want %d", recorder.Code, testCase.wantStatus)
170	}
171
172	wwwAuth := recorder.Header().Get("WWW-Authenticate")
173	hasAuthHeader := wwwAuth != ""
174
175	if hasAuthHeader != testCase.wantAuthHeader {
176		t.Errorf("WWW-Authenticate header present = %v, want %v", hasAuthHeader, testCase.wantAuthHeader)
177	}
178
179	if testCase.wantAuthHeader && wwwAuth != `Bearer realm="lune-mcp"` {
180		t.Errorf("WWW-Authenticate = %q, want %q", wwwAuth, `Bearer realm="lune-mcp"`)
181	}
182}
183
184func TestMiddlewareQueryParam(t *testing.T) {
185	t.Parallel()
186
187	token := "query-param-token" //nolint:gosec // test fixture, not a real credential
188
189	hash, err := auth.Hash(token)
190	if err != nil {
191		t.Fatalf("Hash() error = %v", err)
192	}
193
194	handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
195		rw.WriteHeader(http.StatusOK)
196	})
197
198	protected := auth.Middleware(hash)(handler)
199
200	t.Run("valid query param", func(t *testing.T) {
201		t.Parallel()
202
203		req := httptest.NewRequest(http.MethodGet, "/?access_token="+token, nil)
204		recorder := httptest.NewRecorder()
205		protected.ServeHTTP(recorder, req)
206
207		if recorder.Code != http.StatusOK {
208			t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK)
209		}
210	})
211
212	t.Run("invalid query param", func(t *testing.T) {
213		t.Parallel()
214
215		req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
216		recorder := httptest.NewRecorder()
217		protected.ServeHTTP(recorder, req)
218
219		if recorder.Code != http.StatusUnauthorized {
220			t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
221		}
222	})
223
224	t.Run("header takes precedence over query param", func(t *testing.T) {
225		t.Parallel()
226
227		req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
228		req.Header.Set("Authorization", "Bearer "+token)
229
230		recorder := httptest.NewRecorder()
231		protected.ServeHTTP(recorder, req)
232
233		if recorder.Code != http.StatusOK {
234			t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK)
235		}
236	})
237}