1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: AGPL-3.0-or-later
4
5// Package auth provides Bearer token authentication for the MCP server.
6package auth
7
8import (
9 "crypto/rand"
10 "crypto/subtle"
11 "encoding/base64"
12 "errors"
13 "fmt"
14 "net/http"
15 "strings"
16
17 "golang.org/x/crypto/argon2"
18)
19
// Argon2id cost parameters. These follow the OWASP Password Storage Cheat
// Sheet recommendation for argon2id: m=19 MiB, t=2, p=1.
const (
	argonMemory  = 19 << 10 // memory cost in KiB (19 MiB), the unit argon2.IDKey expects
	argonTime    = 2        // number of passes over memory
	argonThreads = 1        // degree of parallelism
	argonKeyLen  = 32       // derived key length in bytes
	saltLen      = 16       // random salt length in bytes

	// hashParts is how many fields strings.Split(encoded, "$") yields for a
	// well-formed encoded hash: the empty field before the leading "$",
	// "argon2id", "v=...", "m=...,t=...,p=...", the salt and the hash.
	hashParts = 6
)

// ErrInvalidHash is returned when a stored hash string does not have the
// expected "$argon2id$..." layout.
var ErrInvalidHash = errors.New("invalid hash format")
32
33// Hash generates an argon2id hash of the token with a random salt.
34// Returns an encoded string containing version, params, salt, and hash.
35func Hash(token string) (string, error) {
36 salt := make([]byte, saltLen)
37 if _, err := rand.Read(salt); err != nil {
38 return "", fmt.Errorf("generating salt: %w", err)
39 }
40
41 hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
42
43 // Format: $argon2id$v=19$m=19456,t=2,p=1$<salt>$<hash>
44 encoded := fmt.Sprintf(
45 "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
46 argon2.Version,
47 argonMemory,
48 argonTime,
49 argonThreads,
50 base64.RawStdEncoding.EncodeToString(salt),
51 base64.RawStdEncoding.EncodeToString(hash),
52 )
53
54 return encoded, nil
55}
56
57// Verify checks if the token matches the encoded hash.
58// Uses constant-time comparison to prevent timing attacks.
59func Verify(token, encodedHash string) bool {
60 salt, storedHash, err := decodeHash(encodedHash)
61 if err != nil {
62 return false
63 }
64
65 computedHash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
66
67 return subtle.ConstantTimeCompare(storedHash, computedHash) == 1
68}
69
70// decodeHash parses the encoded hash string and extracts the salt and hash.
71func decodeHash(encoded string) ([]byte, []byte, error) {
72 parts := strings.Split(encoded, "$")
73 if len(parts) != hashParts {
74 return nil, nil, ErrInvalidHash
75 }
76
77 // parts[0] is empty (leading $)
78 // parts[1] is "argon2id"
79 // parts[2] is "v=19"
80 // parts[3] is "m=19456,t=2,p=1"
81 // parts[4] is base64-encoded salt
82 // parts[5] is base64-encoded hash
83
84 if parts[1] != "argon2id" {
85 return nil, nil, ErrInvalidHash
86 }
87
88 salt, err := base64.RawStdEncoding.DecodeString(parts[4])
89 if err != nil {
90 return nil, nil, fmt.Errorf("decoding salt: %w", err)
91 }
92
93 hash, err := base64.RawStdEncoding.DecodeString(parts[5])
94 if err != nil {
95 return nil, nil, fmt.Errorf("decoding hash: %w", err)
96 }
97
98 return salt, hash, nil
99}
100
101// Middleware returns HTTP middleware that validates Bearer tokens.
102// Accepts tokens via Authorization header (preferred) or access_token query
103// parameter (RFC 6750 Section 2.3 fallback for clients that can't set headers).
104// Returns 401 with WWW-Authenticate header on failure.
105func Middleware(tokenHash string) func(http.Handler) http.Handler {
106 return func(next http.Handler) http.Handler {
107 return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
108 token := extractToken(req)
109 if token == "" || !Verify(token, tokenHash) {
110 unauthorized(writer)
111
112 return
113 }
114
115 next.ServeHTTP(writer, req)
116 })
117 }
118}
119
// extractToken returns the bearer token presented by req, or "" if there is
// none.
//
// The Authorization header is preferred (RFC 6750 Section 2.1). Its "Bearer"
// auth-scheme is matched case-insensitively, as RFC 7235 Section 2.1 requires
// for auth-schemes. Whitespace around the credentials is ignored. A header
// whose scheme is Bearer but whose credentials are empty yields "" without
// consulting the query string.
//
// If the header is absent or uses another scheme, the access_token query
// parameter is used instead (RFC 6750 Section 2.3).
func extractToken(req *http.Request) string {
	if authHeader := req.Header.Get("Authorization"); authHeader != "" {
		scheme, credentials, found := strings.Cut(authHeader, " ")
		if found && strings.EqualFold(scheme, "Bearer") {
			return strings.TrimSpace(credentials)
		}
	}

	return req.URL.Query().Get("access_token")
}
133
// unauthorized writes a 401 Unauthorized response. It includes a Bearer
// WWW-Authenticate challenge for the "lune-mcp" realm so clients know which
// authentication scheme to use.
func unauthorized(writer http.ResponseWriter) {
	const challenge = `Bearer realm="lune-mcp"`

	headers := writer.Header()
	headers.Set("WWW-Authenticate", challenge)

	http.Error(writer, "Unauthorized", http.StatusUnauthorized)
}