From 560cff96b99efa02cbed3a88ee5335e597758540 Mon Sep 17 00:00:00 2001 From: Amolith Date: Fri, 26 Dec 2025 10:48:00 -0700 Subject: [PATCH] feat(mcp): add optional Bearer token auth Adds argon2id-hashed token authentication for SSE and HTTP transports. Tokens accepted via Authorization header or access_token query parameter (RFC 6750). Auth enabled by presence of token_hash in config. New command: lune mcp set-token Assisted-by: Claude Sonnet 4 via Crush --- cmd/mcp/server.go | 15 ++- cmd/mcp/set_token.go | 163 +++++++++++++++++++++++ go.mod | 1 + go.sum | 2 + internal/config/config.go | 9 +- internal/mcp/auth/auth.go | 137 +++++++++++++++++++ internal/mcp/auth/auth_test.go | 237 +++++++++++++++++++++++++++++++++ 7 files changed, 558 insertions(+), 6 deletions(-) create mode 100644 cmd/mcp/set_token.go create mode 100644 internal/mcp/auth/auth.go create mode 100644 internal/mcp/auth/auth_test.go diff --git a/cmd/mcp/server.go b/cmd/mcp/server.go index 3008c8ef965d564a1da733e3f7f0e046b8cb5035..e2e7383ed01943cffc2c8d2b7c093afbb42e290d 100644 --- a/cmd/mcp/server.go +++ b/cmd/mcp/server.go @@ -12,6 +12,7 @@ import ( "strconv" "git.secluded.site/lune/internal/config" + "git.secluded.site/lune/internal/mcp/auth" "git.secluded.site/lune/internal/mcp/resources/areas" "git.secluded.site/lune/internal/mcp/resources/habits" noters "git.secluded.site/lune/internal/mcp/resources/note" @@ -426,10 +427,15 @@ func runSSE(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) error return mcpServer }, nil) + var httpHandler http.Handler = handler + if cfg.MCP.TokenHash != "" { + httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler) + } + fmt.Fprintf(cmd.OutOrStdout(), "SSE server listening on %s\n", hostPort) //nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable - if err := http.ListenAndServe(hostPort, handler); err != nil { + if err := http.ListenAndServe(hostPort, httpHandler); err != nil { return fmt.Errorf("SSE server error: %w", err) } @@ -442,10 +448,15 @@ func runHTTP(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) erro return mcpServer }, nil) + var httpHandler http.Handler = handler + if cfg.MCP.TokenHash != "" { + httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler) + } + fmt.Fprintf(cmd.OutOrStdout(), "HTTP server listening on %s\n", hostPort) //nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable - if err := http.ListenAndServe(hostPort, handler); err != nil { + if err := http.ListenAndServe(hostPort, httpHandler); err != nil { return fmt.Errorf("HTTP server error: %w", err) } diff --git a/cmd/mcp/set_token.go b/cmd/mcp/set_token.go new file mode 100644 index 0000000000000000000000000000000000000000..4168fafa7eddf68e4a9e3e6874fb4f0b0cad2f3f --- /dev/null +++ b/cmd/mcp/set_token.go @@ -0,0 +1,163 @@ +// SPDX-FileCopyrightText: Amolith +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package mcp + +import ( + "errors" + "fmt" + "io" + + "git.secluded.site/lune/internal/config" + "git.secluded.site/lune/internal/mcp/auth" + "git.secluded.site/lune/internal/ui" + "github.com/charmbracelet/huh" + "github.com/spf13/cobra" +) + +var ( + errTokenMismatch = errors.New("tokens do not match") + errTokenEmpty = errors.New("token cannot be empty") +) + +var setTokenCmd = &cobra.Command{ + Use: "set-token", + Short: "Set authentication token for the MCP server", + Long: `Set a Bearer token for authenticating MCP server requests. + +When a token hash is configured, SSE and HTTP transports require clients +to provide the token via the Authorization header: + + Authorization: Bearer + +The token is hashed using argon2id before being stored in the config file. +Only the hash is stored, not the plaintext token. + +Note: stdio transport does not use authentication (local processes only).`, + RunE: runSetToken, +} + +func init() { + Cmd.AddCommand(setTokenCmd) +} + +func runSetToken(cmd *cobra.Command, _ []string) error { + out := cmd.OutOrStdout() + + cfg, err := loadOrCreateConfig() + if err != nil { + return err + } + + if cfg.MCP.TokenHash != "" { + proceed, confirmErr := confirmTokenReplace(out) + if confirmErr != nil { + return confirmErr + } + + if !proceed { + return nil + } + } + + token, err := promptForMCPToken() + if err != nil { + if errors.Is(err, huh.ErrUserAborted) { + return nil + } + + return err + } + + hash, err := auth.Hash(token) + if err != nil { + return fmt.Errorf("hashing token: %w", err) + } + + cfg.MCP.TokenHash = hash + + if err := cfg.Save(); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + fmt.Fprintln(out, ui.Success.Render("MCP authentication token configured.")) + fmt.Fprintln(out, "Restart the MCP server for changes to take effect.") + + return nil +} + +func loadOrCreateConfig() (*config.Config, error) { + cfg, err := config.Load() + if err != nil { + if errors.Is(err, config.ErrNotFound) { + return &config.Config{}, nil + } + + return nil, fmt.Errorf("loading config: %w", err) + } + + return cfg, nil +} + +func confirmTokenReplace(out io.Writer) (bool, error) { + fmt.Fprintln(out, ui.Warning.Render("An authentication token is already configured.")) + fmt.Fprintln(out, "Setting a new token will replace the existing one.") + fmt.Fprintln(out) + + var proceed bool + + err := huh.NewConfirm(). + Title("Replace existing token?"). + Affirmative("Yes"). + Negative("No"). + Value(&proceed). + Run() + if err != nil { + if errors.Is(err, huh.ErrUserAborted) { + return false, nil + } + + return false, err + } + + return proceed, nil +} + +func promptForMCPToken() (string, error) { + var token, confirm string + + err := huh.NewForm( + huh.NewGroup( + huh.NewInput(). + Title("MCP Authentication Token"). + Description("Enter a token to protect your MCP server."). + EchoMode(huh.EchoModePassword). + Value(&token). + Validate(func(input string) error { + if input == "" { + return errTokenEmpty + } + + return nil + }), + huh.NewInput(). + Title("Confirm Token"). + Description("Re-enter the token to confirm."). + EchoMode(huh.EchoModePassword). + Value(&confirm). + Validate(func(input string) error { + if input != token { + return errTokenMismatch + } + + return nil + }), + ), + ).Run() + if err != nil { + return "", err + } + + return token, nil +} diff --git a/go.mod b/go.mod index 41f7a2ca92cf80800943af74e7093ebd98f041dc..ccb94cac308b1708d901be9f2aacc8d1b5e88d03 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/wasilibs/go-re2 v1.3.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/crypto v0.46.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect diff --git a/go.sum b/go.sum index b647a66f962867b3391f7606dc0cc83156bd7f14..932e1a5b19e864ea35889d6e07dd5fdda989e560 100644 --- a/go.sum +++ b/go.sum @@ -150,6 +150,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT0 github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s= github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= diff --git a/internal/config/config.go b/internal/config/config.go index 6ba72c25fa71238979127fecb2bcb1f5fc838213..b458c9ba0b11ae6ecac3c41d9d5393306bb41ddd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,10 +30,11 @@ type Config struct { // MCPConfig holds MCP server settings. type MCPConfig struct { - Host string `toml:"host"` - Port int `toml:"port"` - Timezone string `toml:"timezone"` - Tools ToolsConfig `toml:"tools"` + Host string `toml:"host"` + Port int `toml:"port"` + Timezone string `toml:"timezone"` + TokenHash string `toml:"token_hash,omitempty"` + Tools ToolsConfig `toml:"tools"` } // ToolsConfig controls which MCP tools are enabled. diff --git a/internal/mcp/auth/auth.go b/internal/mcp/auth/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..ecb2e4c29e391cb5d33ae3c45b67e05eb8e9ee38 --- /dev/null +++ b/internal/mcp/auth/auth.go @@ -0,0 +1,137 @@ +// SPDX-FileCopyrightText: Amolith +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +// Package auth provides Bearer token authentication for the MCP server. +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "net/http" + "strings" + + "golang.org/x/crypto/argon2" +) + +// Argon2id parameters per OWASP recommendations. +const ( + argonTime = 2 + argonMemory = 19 * 1024 // 19 MiB + argonThreads = 1 + argonKeyLen = 32 + saltLen = 16 + hashParts = 6 // $argon2id$v=19$m=...,t=...,p=...$salt$hash +) + +// ErrInvalidHash indicates the stored hash is malformed. +var ErrInvalidHash = errors.New("invalid hash format") + +// Hash generates an argon2id hash of the token with a random salt. +// Returns an encoded string containing version, params, salt, and hash. +func Hash(token string) (string, error) { + salt := make([]byte, saltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("generating salt: %w", err) + } + + hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + + // Format: $argon2id$v=19$m=19456,t=2,p=1$$ + encoded := fmt.Sprintf( + "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", + argon2.Version, + argonMemory, + argonTime, + argonThreads, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(hash), + ) + + return encoded, nil +} + +// Verify checks if the token matches the encoded hash. +// Uses constant-time comparison to prevent timing attacks. +func Verify(token, encodedHash string) bool { + salt, storedHash, err := decodeHash(encodedHash) + if err != nil { + return false + } + + computedHash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen) + + return subtle.ConstantTimeCompare(storedHash, computedHash) == 1 +} + +// decodeHash parses the encoded hash string and extracts the salt and hash. +func decodeHash(encoded string) ([]byte, []byte, error) { + parts := strings.Split(encoded, "$") + if len(parts) != hashParts { + return nil, nil, ErrInvalidHash + } + + // parts[0] is empty (leading $) + // parts[1] is "argon2id" + // parts[2] is "v=19" + // parts[3] is "m=19456,t=2,p=1" + // parts[4] is base64-encoded salt + // parts[5] is base64-encoded hash + + if parts[1] != "argon2id" { + return nil, nil, ErrInvalidHash + } + + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return nil, nil, fmt.Errorf("decoding salt: %w", err) + } + + hash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return nil, nil, fmt.Errorf("decoding hash: %w", err) + } + + return salt, hash, nil +} + +// Middleware returns HTTP middleware that validates Bearer tokens. +// Accepts tokens via Authorization header (preferred) or access_token query +// parameter (RFC 6750 Section 2.3 fallback for clients that can't set headers). +// Returns 401 with WWW-Authenticate header on failure. +func Middleware(tokenHash string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) { + token := extractToken(req) + if token == "" || !Verify(token, tokenHash) { + unauthorized(writer) + + return + } + + next.ServeHTTP(writer, req) + }) + } +} + +// extractToken gets the Bearer token from Authorization header or query param. +func extractToken(req *http.Request) string { + // Prefer Authorization header (RFC 6750 Section 2.1) + if authHeader := req.Header.Get("Authorization"); authHeader != "" { + const bearerPrefix = "Bearer " + if strings.HasPrefix(authHeader, bearerPrefix) { + return strings.TrimPrefix(authHeader, bearerPrefix) + } + } + + // Fall back to query parameter (RFC 6750 Section 2.3) + return req.URL.Query().Get("access_token") +} + +func unauthorized(w http.ResponseWriter) { + w.Header().Set("WWW-Authenticate", `Bearer realm="lune-mcp"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} diff --git a/internal/mcp/auth/auth_test.go b/internal/mcp/auth/auth_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0ad3f1ab48755051c7ae53c4cb6c09911d11ef88 --- /dev/null +++ b/internal/mcp/auth/auth_test.go @@ -0,0 +1,237 @@ +// SPDX-FileCopyrightText: Amolith +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package auth_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "git.secluded.site/lune/internal/mcp/auth" +) + +func TestHashAndVerify(t *testing.T) { + t.Parallel() + + token := "my-secret-token" + + hash, err := auth.Hash(token) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + if hash == "" { + t.Fatal("Hash() returned empty string") + } + + // Verify with correct token + if !auth.Verify(token, hash) { + t.Error("Verify() returned false for correct token") + } + + // Verify with wrong token + if auth.Verify("wrong-token", hash) { + t.Error("Verify() returned true for wrong token") + } +} + +func TestHashUniqueSalts(t *testing.T) { + t.Parallel() + + token := "same-token" + + hash1, err := auth.Hash(token) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + hash2, err := auth.Hash(token) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + if hash1 == hash2 { + t.Error("Hash() produced identical hashes for same token (salt not random)") + } + + // Both should still verify correctly + if !auth.Verify(token, hash1) { + t.Error("Verify() failed for hash1") + } + + if !auth.Verify(token, hash2) { + t.Error("Verify() failed for hash2") + } +} + +func TestVerifyInvalidHash(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hash string + }{ + {"empty", ""}, + {"garbage", "not-a-valid-hash"}, + {"wrong algorithm", "$bcrypt$v=2$something$salt$hash"}, + {"too few parts", "$argon2id$v=19$m=19456$salt"}, + {"invalid base64 salt", "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"}, + {"invalid base64 hash", "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if auth.Verify("token", tc.hash) { + t.Errorf("Verify() returned true for invalid hash %q", tc.hash) + } + }) + } +} + +func TestMiddleware(t *testing.T) { + t.Parallel() + + token := "test-bearer-token" + + hash, err := auth.Hash(token) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write([]byte("OK")) + }) + + protected := auth.Middleware(hash)(handler) + + testCases := middlewareTestCases(token) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + runMiddlewareTest(t, protected, testCase) + }) + } +} + +type middlewareTestCase struct { + name string + authHeader string + wantStatus int + wantAuthHeader bool +} + +func middlewareTestCases(token string) []middlewareTestCase { + return []middlewareTestCase{ + { + name: "missing header", + authHeader: "", + wantStatus: http.StatusUnauthorized, + wantAuthHeader: true, + }, + { + name: "wrong scheme", + authHeader: "Basic dXNlcjpwYXNz", + wantStatus: http.StatusUnauthorized, + wantAuthHeader: true, + }, + { + name: "wrong token", + authHeader: "Bearer wrong-token", + wantStatus: http.StatusUnauthorized, + wantAuthHeader: true, + }, + { + name: "valid token", + authHeader: "Bearer " + token, + wantStatus: http.StatusOK, + wantAuthHeader: false, + }, + } +} + +func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) { + t.Helper() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if testCase.authHeader != "" { + req.Header.Set("Authorization", testCase.authHeader) + } + + recorder := httptest.NewRecorder() + protected.ServeHTTP(recorder, req) + + if recorder.Code != testCase.wantStatus { + t.Errorf("status = %d, want %d", recorder.Code, testCase.wantStatus) + } + + wwwAuth := recorder.Header().Get("WWW-Authenticate") + hasAuthHeader := wwwAuth != "" + + if hasAuthHeader != testCase.wantAuthHeader { + t.Errorf("WWW-Authenticate header present = %v, want %v", hasAuthHeader, testCase.wantAuthHeader) + } + + if testCase.wantAuthHeader && wwwAuth != `Bearer realm="lune-mcp"` { + t.Errorf("WWW-Authenticate = %q, want %q", wwwAuth, `Bearer realm="lune-mcp"`) + } +} + +func TestMiddlewareQueryParam(t *testing.T) { + t.Parallel() + + token := "query-param-token" //nolint:gosec // test fixture, not a real credential + + hash, err := auth.Hash(token) + if err != nil { + t.Fatalf("Hash() error = %v", err) + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusOK) + }) + + protected := auth.Middleware(hash)(handler) + + t.Run("valid query param", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/?access_token="+token, nil) + recorder := httptest.NewRecorder() + protected.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK) + } + }) + + t.Run("invalid query param", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil) + recorder := httptest.NewRecorder() + protected.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + }) + + t.Run("header takes precedence over query param", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil) + req.Header.Set("Authorization", "Bearer "+token) + + recorder := httptest.NewRecorder() + protected.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK) + } + }) +}