auth.go

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: AGPL-3.0-or-later
  4
  5// Package auth provides Bearer token authentication for the MCP server.
  6package auth
  7
  8import (
  9	"crypto/rand"
 10	"crypto/subtle"
 11	"encoding/base64"
 12	"errors"
 13	"fmt"
 14	"net/http"
 15	"strings"
 16
 17	"golang.org/x/crypto/argon2"
 18)
 19
 20// Argon2id parameters per OWASP recommendations.
 21const (
 22	argonTime    = 2
 23	argonMemory  = 19 * 1024 // 19 MiB
 24	argonThreads = 1
 25	argonKeyLen  = 32
 26	saltLen      = 16
 27	hashParts    = 6 // $argon2id$v=19$m=...,t=...,p=...$salt$hash
 28)
 29
 30// ErrInvalidHash indicates the stored hash is malformed.
 31var ErrInvalidHash = errors.New("invalid hash format")
 32
 33// Hash generates an argon2id hash of the token with a random salt.
 34// Returns an encoded string containing version, params, salt, and hash.
 35func Hash(token string) (string, error) {
 36	salt := make([]byte, saltLen)
 37	if _, err := rand.Read(salt); err != nil {
 38		return "", fmt.Errorf("generating salt: %w", err)
 39	}
 40
 41	hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
 42
 43	// Format: $argon2id$v=19$m=19456,t=2,p=1$<salt>$<hash>
 44	encoded := fmt.Sprintf(
 45		"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
 46		argon2.Version,
 47		argonMemory,
 48		argonTime,
 49		argonThreads,
 50		base64.RawStdEncoding.EncodeToString(salt),
 51		base64.RawStdEncoding.EncodeToString(hash),
 52	)
 53
 54	return encoded, nil
 55}
 56
 57// Verify checks if the token matches the encoded hash.
 58// Uses constant-time comparison to prevent timing attacks.
 59func Verify(token, encodedHash string) bool {
 60	salt, storedHash, err := decodeHash(encodedHash)
 61	if err != nil {
 62		return false
 63	}
 64
 65	computedHash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
 66
 67	return subtle.ConstantTimeCompare(storedHash, computedHash) == 1
 68}
 69
 70// decodeHash parses the encoded hash string and extracts the salt and hash.
 71func decodeHash(encoded string) ([]byte, []byte, error) {
 72	parts := strings.Split(encoded, "$")
 73	if len(parts) != hashParts {
 74		return nil, nil, ErrInvalidHash
 75	}
 76
 77	// parts[0] is empty (leading $)
 78	// parts[1] is "argon2id"
 79	// parts[2] is "v=19"
 80	// parts[3] is "m=19456,t=2,p=1"
 81	// parts[4] is base64-encoded salt
 82	// parts[5] is base64-encoded hash
 83
 84	if parts[1] != "argon2id" {
 85		return nil, nil, ErrInvalidHash
 86	}
 87
 88	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
 89	if err != nil {
 90		return nil, nil, fmt.Errorf("decoding salt: %w", err)
 91	}
 92
 93	hash, err := base64.RawStdEncoding.DecodeString(parts[5])
 94	if err != nil {
 95		return nil, nil, fmt.Errorf("decoding hash: %w", err)
 96	}
 97
 98	return salt, hash, nil
 99}
100
101// Middleware returns HTTP middleware that validates Bearer tokens.
102// Accepts tokens via Authorization header (preferred) or access_token query
103// parameter (RFC 6750 Section 2.3 fallback for clients that can't set headers).
104// Returns 401 with WWW-Authenticate header on failure.
105func Middleware(tokenHash string) func(http.Handler) http.Handler {
106	return func(next http.Handler) http.Handler {
107		return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
108			token := extractToken(req)
109			if token == "" || !Verify(token, tokenHash) {
110				unauthorized(writer)
111
112				return
113			}
114
115			next.ServeHTTP(writer, req)
116		})
117	}
118}
119
// extractToken returns the bearer token carried by req, or "" when none is
// present.
//
// The Authorization header is preferred (RFC 6750 Section 2.1). The "Bearer"
// scheme name is matched case-insensitively, because HTTP authentication
// scheme names are case-insensitive (RFC 9110 Section 11.1), and any extra
// spaces between the scheme and the token are skipped. When the header is
// absent or uses another scheme, the access_token query parameter is used
// instead (RFC 6750 Section 2.3).
func extractToken(req *http.Request) string {
	const bearerPrefix = "Bearer "

	authHeader := req.Header.Get("Authorization")
	if len(authHeader) >= len(bearerPrefix) &&
		strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		// The grammar allows one or more spaces after the scheme name.
		return strings.TrimLeft(authHeader[len(bearerPrefix):], " ")
	}

	return req.URL.Query().Get("access_token")
}
133
// unauthorized answers the request with 401 Unauthorized and a Bearer
// challenge in the WWW-Authenticate header.
func unauthorized(w http.ResponseWriter) {
	const challenge = `Bearer realm="lune-mcp"`

	headers := w.Header()
	headers.Set("WWW-Authenticate", challenge)

	http.Error(w, "Unauthorized", http.StatusUnauthorized)
}