feat(mcp): add optional Bearer token auth

Amolith created

Adds argon2id-hashed token authentication for SSE and HTTP transports.
Tokens accepted via Authorization header or access_token query parameter
(RFC 6750). Auth enabled by presence of token_hash in config.

New command: lune mcp set-token

Assisted-by: Claude Sonnet 4 via Crush <crush@charm.land>

Change summary

cmd/mcp/server.go              |  15 +
cmd/mcp/set_token.go           | 163 ++++++++++++++++++++++++
go.mod                         |   1 
go.sum                         |   2 
internal/config/config.go      |   9 
internal/mcp/auth/auth.go      | 137 ++++++++++++++++++++
internal/mcp/auth/auth_test.go | 237 ++++++++++++++++++++++++++++++++++++
7 files changed, 558 insertions(+), 6 deletions(-)

Detailed changes

cmd/mcp/server.go 🔗

@@ -12,6 +12,7 @@ import (
 	"strconv"
 
 	"git.secluded.site/lune/internal/config"
+	"git.secluded.site/lune/internal/mcp/auth"
 	"git.secluded.site/lune/internal/mcp/resources/areas"
 	"git.secluded.site/lune/internal/mcp/resources/habits"
 	noters "git.secluded.site/lune/internal/mcp/resources/note"
@@ -426,10 +427,15 @@ func runSSE(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) error
 		return mcpServer
 	}, nil)
 
+	var httpHandler http.Handler = handler
+	if cfg.MCP.TokenHash != "" {
+		httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler)
+	}
+
 	fmt.Fprintf(cmd.OutOrStdout(), "SSE server listening on %s\n", hostPort)
 
 	//nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable
-	if err := http.ListenAndServe(hostPort, handler); err != nil {
+	if err := http.ListenAndServe(hostPort, httpHandler); err != nil {
 		return fmt.Errorf("SSE server error: %w", err)
 	}
 
@@ -442,10 +448,15 @@ func runHTTP(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) erro
 		return mcpServer
 	}, nil)
 
+	var httpHandler http.Handler = handler
+	if cfg.MCP.TokenHash != "" {
+		httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler)
+	}
+
 	fmt.Fprintf(cmd.OutOrStdout(), "HTTP server listening on %s\n", hostPort)
 
 	//nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable
-	if err := http.ListenAndServe(hostPort, handler); err != nil {
+	if err := http.ListenAndServe(hostPort, httpHandler); err != nil {
 		return fmt.Errorf("HTTP server error: %w", err)
 	}
 

cmd/mcp/set_token.go 🔗

@@ -0,0 +1,163 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package mcp
+
+import (
+	"errors"
+	"fmt"
+	"io"
+
+	"git.secluded.site/lune/internal/config"
+	"git.secluded.site/lune/internal/mcp/auth"
+	"git.secluded.site/lune/internal/ui"
+	"github.com/charmbracelet/huh"
+	"github.com/spf13/cobra"
+)
+
+var (
+	errTokenMismatch = errors.New("tokens do not match")
+	errTokenEmpty    = errors.New("token cannot be empty")
+)
+
+var setTokenCmd = &cobra.Command{
+	Use:   "set-token",
+	Short: "Set authentication token for the MCP server",
+	Long: `Set a Bearer token for authenticating MCP server requests.
+
+When a token hash is configured, SSE and HTTP transports require clients
+to provide the token via the Authorization header:
+
+  Authorization: Bearer <token>
+
+The token is hashed using argon2id before being stored in the config file.
+Only the hash is stored, not the plaintext token.
+
+Note: stdio transport does not use authentication (local processes only).`,
+	RunE: runSetToken,
+}
+
+func init() {
+	Cmd.AddCommand(setTokenCmd)
+}
+
+func runSetToken(cmd *cobra.Command, _ []string) error {
+	out := cmd.OutOrStdout()
+
+	cfg, err := loadOrCreateConfig()
+	if err != nil {
+		return err
+	}
+
+	if cfg.MCP.TokenHash != "" {
+		proceed, confirmErr := confirmTokenReplace(out)
+		if confirmErr != nil {
+			return confirmErr
+		}
+
+		if !proceed {
+			return nil
+		}
+	}
+
+	token, err := promptForMCPToken()
+	if err != nil {
+		if errors.Is(err, huh.ErrUserAborted) {
+			return nil
+		}
+
+		return err
+	}
+
+	hash, err := auth.Hash(token)
+	if err != nil {
+		return fmt.Errorf("hashing token: %w", err)
+	}
+
+	cfg.MCP.TokenHash = hash
+
+	if err := cfg.Save(); err != nil {
+		return fmt.Errorf("saving config: %w", err)
+	}
+
+	fmt.Fprintln(out, ui.Success.Render("MCP authentication token configured."))
+	fmt.Fprintln(out, "Restart the MCP server for changes to take effect.")
+
+	return nil
+}
+
+func loadOrCreateConfig() (*config.Config, error) {
+	cfg, err := config.Load()
+	if err != nil {
+		if errors.Is(err, config.ErrNotFound) {
+			return &config.Config{}, nil
+		}
+
+		return nil, fmt.Errorf("loading config: %w", err)
+	}
+
+	return cfg, nil
+}
+
+func confirmTokenReplace(out io.Writer) (bool, error) {
+	fmt.Fprintln(out, ui.Warning.Render("An authentication token is already configured."))
+	fmt.Fprintln(out, "Setting a new token will replace the existing one.")
+	fmt.Fprintln(out)
+
+	var proceed bool
+
+	err := huh.NewConfirm().
+		Title("Replace existing token?").
+		Affirmative("Yes").
+		Negative("No").
+		Value(&proceed).
+		Run()
+	if err != nil {
+		if errors.Is(err, huh.ErrUserAborted) {
+			return false, nil
+		}
+
+		return false, err
+	}
+
+	return proceed, nil
+}
+
+func promptForMCPToken() (string, error) {
+	var token, confirm string
+
+	err := huh.NewForm(
+		huh.NewGroup(
+			huh.NewInput().
+				Title("MCP Authentication Token").
+				Description("Enter a token to protect your MCP server.").
+				EchoMode(huh.EchoModePassword).
+				Value(&token).
+				Validate(func(input string) error {
+					if input == "" {
+						return errTokenEmpty
+					}
+
+					return nil
+				}),
+			huh.NewInput().
+				Title("Confirm Token").
+				Description("Re-enter the token to confirm.").
+				EchoMode(huh.EchoModePassword).
+				Value(&confirm).
+				Validate(func(input string) error {
+					if input != token {
+						return errTokenMismatch
+					}
+
+					return nil
+				}),
+		),
+	).Run()
+	if err != nil {
+		return "", err
+	}
+
+	return token, nil
+}

go.mod 🔗

@@ -69,6 +69,7 @@ require (
 	github.com/wasilibs/go-re2 v1.3.0 // indirect
 	github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
 	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
+	golang.org/x/crypto v0.46.0 // indirect
 	golang.org/x/oauth2 v0.30.0 // indirect
 	golang.org/x/sync v0.19.0 // indirect
 	golang.org/x/sys v0.39.0 // indirect

go.sum 🔗

@@ -150,6 +150,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT0
 github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s=
 github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI=
 go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
+golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
+golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
 golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
 golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
 golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=

internal/config/config.go 🔗

@@ -30,10 +30,11 @@ type Config struct {
 
 // MCPConfig holds MCP server settings.
 type MCPConfig struct {
-	Host     string      `toml:"host"`
-	Port     int         `toml:"port"`
-	Timezone string      `toml:"timezone"`
-	Tools    ToolsConfig `toml:"tools"`
+	Host      string      `toml:"host"`
+	Port      int         `toml:"port"`
+	Timezone  string      `toml:"timezone"`
+	TokenHash string      `toml:"token_hash,omitempty"`
+	Tools     ToolsConfig `toml:"tools"`
 }
 
 // ToolsConfig controls which MCP tools are enabled.

internal/mcp/auth/auth.go 🔗

@@ -0,0 +1,137 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+// Package auth provides Bearer token authentication for the MCP server.
+package auth
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+
+	"golang.org/x/crypto/argon2"
+)
+
+// Argon2id parameters per OWASP recommendations.
+const (
+	argonTime    = 2
+	argonMemory  = 19 * 1024 // 19 MiB
+	argonThreads = 1
+	argonKeyLen  = 32
+	saltLen      = 16
+	hashParts    = 6 // $argon2id$v=19$m=...,t=...,p=...$salt$hash
+)
+
+// ErrInvalidHash indicates the stored hash is malformed.
+var ErrInvalidHash = errors.New("invalid hash format")
+
+// Hash generates an argon2id hash of the token with a random salt.
+// Returns an encoded string containing version, params, salt, and hash.
+func Hash(token string) (string, error) {
+	salt := make([]byte, saltLen)
+	if _, err := rand.Read(salt); err != nil {
+		return "", fmt.Errorf("generating salt: %w", err)
+	}
+
+	hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
+
+	// Format: $argon2id$v=19$m=19456,t=2,p=1$<salt>$<hash>
+	encoded := fmt.Sprintf(
+		"$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
+		argon2.Version,
+		argonMemory,
+		argonTime,
+		argonThreads,
+		base64.RawStdEncoding.EncodeToString(salt),
+		base64.RawStdEncoding.EncodeToString(hash),
+	)
+
+	return encoded, nil
+}
+
+// Verify checks if the token matches the encoded hash.
+// Uses constant-time comparison to prevent timing attacks.
+func Verify(token, encodedHash string) bool {
+	salt, storedHash, err := decodeHash(encodedHash)
+	if err != nil {
+		return false
+	}
+
+	computedHash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
+
+	return subtle.ConstantTimeCompare(storedHash, computedHash) == 1
+}
+
+// decodeHash parses the encoded hash string and extracts the salt and hash.
+func decodeHash(encoded string) ([]byte, []byte, error) {
+	parts := strings.Split(encoded, "$")
+	if len(parts) != hashParts {
+		return nil, nil, ErrInvalidHash
+	}
+
+	// parts[0] is empty (leading $)
+	// parts[1] is "argon2id"
+	// parts[2] is "v=19"
+	// parts[3] is "m=19456,t=2,p=1"
+	// parts[4] is base64-encoded salt
+	// parts[5] is base64-encoded hash
+
+	if parts[1] != "argon2id" {
+		return nil, nil, ErrInvalidHash
+	}
+
+	salt, err := base64.RawStdEncoding.DecodeString(parts[4])
+	if err != nil {
+		return nil, nil, fmt.Errorf("decoding salt: %w", err)
+	}
+
+	hash, err := base64.RawStdEncoding.DecodeString(parts[5])
+	if err != nil {
+		return nil, nil, fmt.Errorf("decoding hash: %w", err)
+	}
+
+	return salt, hash, nil
+}
+
+// Middleware returns HTTP middleware that validates Bearer tokens.
+// Accepts tokens via Authorization header (preferred) or access_token query
+// parameter (RFC 6750 Section 2.3 fallback for clients that can't set headers).
+// Returns 401 with WWW-Authenticate header on failure.
+func Middleware(tokenHash string) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
+			token := extractToken(req)
+			if token == "" || !Verify(token, tokenHash) {
+				unauthorized(writer)
+
+				return
+			}
+
+			next.ServeHTTP(writer, req)
+		})
+	}
+}
+
+// extractToken gets the Bearer token from Authorization header or query param.
+func extractToken(req *http.Request) string {
+	// Prefer Authorization header (RFC 6750 Section 2.1)
+	if authHeader := req.Header.Get("Authorization"); authHeader != "" {
+		const bearerPrefix = "Bearer "
+		if strings.HasPrefix(authHeader, bearerPrefix) {
+			return strings.TrimPrefix(authHeader, bearerPrefix)
+		}
+	}
+
+	// Fall back to query parameter (RFC 6750 Section 2.3)
+	return req.URL.Query().Get("access_token")
+}
+
+func unauthorized(w http.ResponseWriter) {
+	w.Header().Set("WWW-Authenticate", `Bearer realm="lune-mcp"`)
+	http.Error(w, "Unauthorized", http.StatusUnauthorized)
+}

internal/mcp/auth/auth_test.go 🔗

@@ -0,0 +1,237 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package auth_test
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"git.secluded.site/lune/internal/mcp/auth"
+)
+
+func TestHashAndVerify(t *testing.T) {
+	t.Parallel()
+
+	token := "my-secret-token"
+
+	hash, err := auth.Hash(token)
+	if err != nil {
+		t.Fatalf("Hash() error = %v", err)
+	}
+
+	if hash == "" {
+		t.Fatal("Hash() returned empty string")
+	}
+
+	// Verify with correct token
+	if !auth.Verify(token, hash) {
+		t.Error("Verify() returned false for correct token")
+	}
+
+	// Verify with wrong token
+	if auth.Verify("wrong-token", hash) {
+		t.Error("Verify() returned true for wrong token")
+	}
+}
+
+func TestHashUniqueSalts(t *testing.T) {
+	t.Parallel()
+
+	token := "same-token"
+
+	hash1, err := auth.Hash(token)
+	if err != nil {
+		t.Fatalf("Hash() error = %v", err)
+	}
+
+	hash2, err := auth.Hash(token)
+	if err != nil {
+		t.Fatalf("Hash() error = %v", err)
+	}
+
+	if hash1 == hash2 {
+		t.Error("Hash() produced identical hashes for same token (salt not random)")
+	}
+
+	// Both should still verify correctly
+	if !auth.Verify(token, hash1) {
+		t.Error("Verify() failed for hash1")
+	}
+
+	if !auth.Verify(token, hash2) {
+		t.Error("Verify() failed for hash2")
+	}
+}
+
+func TestVerifyInvalidHash(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		hash string
+	}{
+		{"empty", ""},
+		{"garbage", "not-a-valid-hash"},
+		{"wrong algorithm", "$bcrypt$v=2$something$salt$hash"},
+		{"too few parts", "$argon2id$v=19$m=19456$salt"},
+		{"invalid base64 salt", "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"},
+		{"invalid base64 hash", "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			if auth.Verify("token", tc.hash) {
+				t.Errorf("Verify() returned true for invalid hash %q", tc.hash)
+			}
+		})
+	}
+}
+
+func TestMiddleware(t *testing.T) {
+	t.Parallel()
+
+	token := "test-bearer-token"
+
+	hash, err := auth.Hash(token)
+	if err != nil {
+		t.Fatalf("Hash() error = %v", err)
+	}
+
+	handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+		rw.WriteHeader(http.StatusOK)
+		_, _ = rw.Write([]byte("OK"))
+	})
+
+	protected := auth.Middleware(hash)(handler)
+
+	testCases := middlewareTestCases(token)
+	for _, testCase := range testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			t.Parallel()
+			runMiddlewareTest(t, protected, testCase)
+		})
+	}
+}
+
+type middlewareTestCase struct {
+	name           string
+	authHeader     string
+	wantStatus     int
+	wantAuthHeader bool
+}
+
+func middlewareTestCases(token string) []middlewareTestCase {
+	return []middlewareTestCase{
+		{
+			name:           "missing header",
+			authHeader:     "",
+			wantStatus:     http.StatusUnauthorized,
+			wantAuthHeader: true,
+		},
+		{
+			name:           "wrong scheme",
+			authHeader:     "Basic dXNlcjpwYXNz",
+			wantStatus:     http.StatusUnauthorized,
+			wantAuthHeader: true,
+		},
+		{
+			name:           "wrong token",
+			authHeader:     "Bearer wrong-token",
+			wantStatus:     http.StatusUnauthorized,
+			wantAuthHeader: true,
+		},
+		{
+			name:           "valid token",
+			authHeader:     "Bearer " + token,
+			wantStatus:     http.StatusOK,
+			wantAuthHeader: false,
+		},
+	}
+}
+
+func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) {
+	t.Helper()
+
+	req := httptest.NewRequest(http.MethodGet, "/", nil)
+	if testCase.authHeader != "" {
+		req.Header.Set("Authorization", testCase.authHeader)
+	}
+
+	recorder := httptest.NewRecorder()
+	protected.ServeHTTP(recorder, req)
+
+	if recorder.Code != testCase.wantStatus {
+		t.Errorf("status = %d, want %d", recorder.Code, testCase.wantStatus)
+	}
+
+	wwwAuth := recorder.Header().Get("WWW-Authenticate")
+	hasAuthHeader := wwwAuth != ""
+
+	if hasAuthHeader != testCase.wantAuthHeader {
+		t.Errorf("WWW-Authenticate header present = %v, want %v", hasAuthHeader, testCase.wantAuthHeader)
+	}
+
+	if testCase.wantAuthHeader && wwwAuth != `Bearer realm="lune-mcp"` {
+		t.Errorf("WWW-Authenticate = %q, want %q", wwwAuth, `Bearer realm="lune-mcp"`)
+	}
+}
+
+func TestMiddlewareQueryParam(t *testing.T) {
+	t.Parallel()
+
+	token := "query-param-token" //nolint:gosec // test fixture, not a real credential
+
+	hash, err := auth.Hash(token)
+	if err != nil {
+		t.Fatalf("Hash() error = %v", err)
+	}
+
+	handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+		rw.WriteHeader(http.StatusOK)
+	})
+
+	protected := auth.Middleware(hash)(handler)
+
+	t.Run("valid query param", func(t *testing.T) {
+		t.Parallel()
+
+		req := httptest.NewRequest(http.MethodGet, "/?access_token="+token, nil)
+		recorder := httptest.NewRecorder()
+		protected.ServeHTTP(recorder, req)
+
+		if recorder.Code != http.StatusOK {
+			t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK)
+		}
+	})
+
+	t.Run("invalid query param", func(t *testing.T) {
+		t.Parallel()
+
+		req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
+		recorder := httptest.NewRecorder()
+		protected.ServeHTTP(recorder, req)
+
+		if recorder.Code != http.StatusUnauthorized {
+			t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
+		}
+	})
+
+	t.Run("header takes precedence over query param", func(t *testing.T) {
+		t.Parallel()
+
+		req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
+		req.Header.Set("Authorization", "Bearer "+token)
+
+		recorder := httptest.NewRecorder()
+		protected.ServeHTTP(recorder, req)
+
+		if recorder.Code != http.StatusOK {
+			t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK)
+		}
+	})
+}