// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
//
// SPDX-License-Identifier: AGPL-3.0-or-later

// Package auth provides Bearer token authentication for the MCP server.
package auth

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"golang.org/x/crypto/argon2"
)

// Argon2id cost parameters, following the OWASP Password Storage Cheat
// Sheet's Argon2id configuration (m=19 MiB, t=2, p=1).
const (
	argonMemory  = 19456 // KiB, i.e. 19 MiB
	argonTime    = 2     // passes over memory
	argonThreads = 1     // degree of parallelism
	argonKeyLen  = 32    // bytes of derived key
)

// Layout of an encoded hash string.
const (
	saltLen   = 16 // bytes of random salt generated per hash
	hashParts = 6  // fields from splitting "$argon2id$v=..$m=..,t=..,p=..$salt$hash" on "$"
)

// ErrInvalidHash is returned when an encoded hash string is malformed.
var ErrInvalidHash = errors.New("invalid hash format")

// Hash generates an argon2id hash of the token with a random salt.
// Returns an encoded string containing version, params, salt, and hash.
func Hash(token string) (string, error) {
	salt := make([]byte, saltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", fmt.Errorf("generating salt: %w", err)
	}

	hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)

	// Format: $argon2id$v=19$m=19456,t=2,p=1$<salt>$<hash>
	encoded := fmt.Sprintf(
		"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
		argon2.Version,
		argonMemory,
		argonTime,
		argonThreads,
		base64.RawStdEncoding.EncodeToString(salt),
		base64.RawStdEncoding.EncodeToString(hash),
	)

	return encoded, nil
}

// Verify reports whether token matches encodedHash, an encoded string of the
// form produced by Hash.
//
// The Argon2id cost parameters (memory, passes, parallelism) and the key
// length are read from encodedHash itself rather than from the package
// constants. Hashes generated before those constants are tuned therefore keep
// verifying. A malformed or unsupported encodedHash yields false. The final
// comparison is constant-time, so response timing does not reveal how much of
// the hash matched.
func Verify(token, encodedHash string) bool {
	params, salt, storedHash, err := decodeHash(encodedHash)
	if err != nil {
		return false
	}

	computedHash := argon2.IDKey(
		[]byte(token),
		salt,
		params.time,
		params.memory,
		params.threads,
		uint32(len(storedHash)),
	)

	return subtle.ConstantTimeCompare(storedHash, computedHash) == 1
}

// argonParams holds the Argon2id cost parameters recorded in an encoded hash.
type argonParams struct {
	memory  uint32 // memory cost in KiB (the "m=" field)
	time    uint32 // number of passes (the "t=" field)
	threads uint8  // degree of parallelism (the "p=" field)
}

// decodeHash parses a string in the format written by Hash,
//
//	$argon2id$v=19$m=<memory>,t=<time>,p=<threads>$<salt>$<hash>
//
// where salt and hash are unpadded standard base64. It returns the cost
// parameters, the raw salt, and the raw hash.
//
// It returns ErrInvalidHash when the structure, algorithm, version, or
// parameters are unacceptable. It returns a wrapped base64 error when the
// salt or hash cannot be decoded.
//
// NOTE(review): the parameters are trusted as given, with no upper bound on
// memory. This presumes the encoded hash comes from operator configuration
// rather than request data — confirm at the call sites.
func decodeHash(encoded string) (argonParams, []byte, []byte, error) {
	// RFC 9106 requires a tag of at least 4 bytes. The stored hash's length
	// becomes the key length passed to argon2.IDKey, so shorter ones are rejected.
	const minKeyLen = 4

	// parts[0] is empty (leading $), parts[1] the algorithm, parts[2] the
	// version, parts[3] the parameters, parts[4] the salt, parts[5] the hash.
	parts := strings.Split(encoded, "$")
	if len(parts) != hashParts || parts[0] != "" || parts[1] != "argon2id" {
		return argonParams{}, nil, nil, ErrInvalidHash
	}

	// Only the Argon2 version implemented by x/crypto/argon2 can be verified.
	if parts[2] != fmt.Sprintf("v=%d", argon2.Version) {
		return argonParams{}, nil, nil, ErrInvalidHash
	}

	var params argonParams
	_, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &params.memory, &params.time, &params.threads)
	if err != nil {
		return argonParams{}, nil, nil, ErrInvalidHash
	}

	// argon2.IDKey panics on zero passes or zero parallelism.
	if params.time == 0 || params.threads == 0 {
		return argonParams{}, nil, nil, ErrInvalidHash
	}

	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
	if err != nil {
		return argonParams{}, nil, nil, fmt.Errorf("decoding salt: %w", err)
	}

	hash, err := base64.RawStdEncoding.DecodeString(parts[5])
	if err != nil {
		return argonParams{}, nil, nil, fmt.Errorf("decoding hash: %w", err)
	}

	if len(hash) < minKeyLen {
		return argonParams{}, nil, nil, ErrInvalidHash
	}

	return params, salt, hash, nil
}

// Middleware wraps a handler so that it only runs for requests carrying a
// token that matches tokenHash, an encoded hash as produced by Hash.
//
// The token is taken from a Bearer Authorization header (RFC 6750 §2.1),
// falling back to the access_token query parameter (RFC 6750 §2.3) for
// clients that cannot set headers. Requests with no token, or with one that
// fails Verify, get 401 Unauthorized with a WWW-Authenticate challenge. They
// never reach the wrapped handler.
//
// NOTE(review): every request that presents a token runs a full Argon2id
// derivation inside Verify, so rejecting bogus traffic is comparatively
// expensive. Consider rate limiting in front of this middleware.
func Middleware(tokenHash string) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		handle := func(writer http.ResponseWriter, req *http.Request) {
			if token := extractToken(req); token != "" && Verify(token, tokenHash) {
				next.ServeHTTP(writer, req)

				return
			}

			unauthorized(writer)
		}

		return http.HandlerFunc(handle)
	}

	return wrap
}

// extractToken returns the bearer token carried by req, or "" if there is none.
//
// An Authorization header using the Bearer scheme is preferred
// (RFC 6750 §2.1). The scheme name is matched case-insensitively, as
// RFC 7235 §2.1 requires. It must be followed by at least one space, and
// whitespace around the token is trimmed.
//
// If the header is absent or uses another scheme, the access_token query
// parameter is returned instead (RFC 6750 §2.3). A Bearer header with an
// empty token yields "" without consulting the query.
func extractToken(req *http.Request) string {
	const scheme = "Bearer"

	if authHeader := req.Header.Get("Authorization"); authHeader != "" {
		if len(authHeader) > len(scheme) &&
			strings.EqualFold(authHeader[:len(scheme)], scheme) &&
			authHeader[len(scheme)] == ' ' {
			return strings.TrimSpace(authHeader[len(scheme)+1:])
		}
	}

	// Fall back to query parameter (RFC 6750 Section 2.3)
	return req.URL.Query().Get("access_token")
}

// unauthorized rejects the request with 401 Unauthorized. Its
// WWW-Authenticate challenge names the Bearer scheme and the "lune-mcp" realm.
func unauthorized(writer http.ResponseWriter) {
	const challenge = `Bearer realm="lune-mcp"`

	writer.Header().Set("WWW-Authenticate", challenge)
	http.Error(writer, "Unauthorized", http.StatusUnauthorized)
}
