Detailed changes
@@ -12,6 +12,7 @@ import (
"strconv"
"git.secluded.site/lune/internal/config"
+ "git.secluded.site/lune/internal/mcp/auth"
"git.secluded.site/lune/internal/mcp/resources/areas"
"git.secluded.site/lune/internal/mcp/resources/habits"
noters "git.secluded.site/lune/internal/mcp/resources/note"
@@ -426,10 +427,15 @@ func runSSE(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) error
return mcpServer
}, nil)
+ var httpHandler http.Handler = handler
+ if cfg.MCP.TokenHash != "" {
+ httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler)
+ }
+
fmt.Fprintf(cmd.OutOrStdout(), "SSE server listening on %s\n", hostPort)
//nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable
- if err := http.ListenAndServe(hostPort, handler); err != nil {
+ if err := http.ListenAndServe(hostPort, httpHandler); err != nil {
return fmt.Errorf("SSE server error: %w", err)
}
@@ -442,10 +448,15 @@ func runHTTP(cmd *cobra.Command, mcpServer *mcp.Server, cfg *config.Config) erro
return mcpServer
}, nil)
+ var httpHandler http.Handler = handler
+ if cfg.MCP.TokenHash != "" {
+ httpHandler = auth.Middleware(cfg.MCP.TokenHash)(handler)
+ }
+
fmt.Fprintf(cmd.OutOrStdout(), "HTTP server listening on %s\n", hostPort)
//nolint:gosec // MCP SDK controls server lifecycle; timeouts not applicable
- if err := http.ListenAndServe(hostPort, handler); err != nil {
+ if err := http.ListenAndServe(hostPort, httpHandler); err != nil {
return fmt.Errorf("HTTP server error: %w", err)
}
@@ -0,0 +1,163 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package mcp
+
+import (
+ "errors"
+ "fmt"
+ "io"
+
+ "git.secluded.site/lune/internal/config"
+ "git.secluded.site/lune/internal/mcp/auth"
+ "git.secluded.site/lune/internal/ui"
+ "github.com/charmbracelet/huh"
+ "github.com/spf13/cobra"
+)
+
+var (
+ errTokenMismatch = errors.New("tokens do not match")
+ errTokenEmpty = errors.New("token cannot be empty")
+)
+
+var setTokenCmd = &cobra.Command{
+ Use: "set-token",
+ Short: "Set authentication token for the MCP server",
+ Long: `Set a Bearer token for authenticating MCP server requests.
+
+When a token hash is configured, SSE and HTTP transports require clients
+to provide the token via the Authorization header:
+
+ Authorization: Bearer <token>
+
+The token is hashed using argon2id before being stored in the config file.
+Only the hash is stored, not the plaintext token.
+
+Note: stdio transport does not use authentication (local processes only).`,
+ RunE: runSetToken,
+}
+
+func init() {
+ Cmd.AddCommand(setTokenCmd)
+}
+
+func runSetToken(cmd *cobra.Command, _ []string) error {
+ out := cmd.OutOrStdout()
+
+ cfg, err := loadOrCreateConfig()
+ if err != nil {
+ return err
+ }
+
+ if cfg.MCP.TokenHash != "" {
+ proceed, confirmErr := confirmTokenReplace(out)
+ if confirmErr != nil {
+ return confirmErr
+ }
+
+ if !proceed {
+ return nil
+ }
+ }
+
+ token, err := promptForMCPToken()
+ if err != nil {
+ if errors.Is(err, huh.ErrUserAborted) {
+ return nil
+ }
+
+ return err
+ }
+
+ hash, err := auth.Hash(token)
+ if err != nil {
+ return fmt.Errorf("hashing token: %w", err)
+ }
+
+ cfg.MCP.TokenHash = hash
+
+ if err := cfg.Save(); err != nil {
+ return fmt.Errorf("saving config: %w", err)
+ }
+
+ fmt.Fprintln(out, ui.Success.Render("MCP authentication token configured."))
+ fmt.Fprintln(out, "Restart the MCP server for changes to take effect.")
+
+ return nil
+}
+
+func loadOrCreateConfig() (*config.Config, error) {
+ cfg, err := config.Load()
+ if err != nil {
+ if errors.Is(err, config.ErrNotFound) {
+ return &config.Config{}, nil
+ }
+
+ return nil, fmt.Errorf("loading config: %w", err)
+ }
+
+ return cfg, nil
+}
+
+func confirmTokenReplace(out io.Writer) (bool, error) {
+ fmt.Fprintln(out, ui.Warning.Render("An authentication token is already configured."))
+ fmt.Fprintln(out, "Setting a new token will replace the existing one.")
+ fmt.Fprintln(out)
+
+ var proceed bool
+
+ err := huh.NewConfirm().
+ Title("Replace existing token?").
+ Affirmative("Yes").
+ Negative("No").
+ Value(&proceed).
+ Run()
+ if err != nil {
+ if errors.Is(err, huh.ErrUserAborted) {
+ return false, nil
+ }
+
+ return false, err
+ }
+
+ return proceed, nil
+}
+
+func promptForMCPToken() (string, error) {
+ var token, confirm string
+
+ err := huh.NewForm(
+ huh.NewGroup(
+ huh.NewInput().
+ Title("MCP Authentication Token").
+ Description("Enter a token to protect your MCP server.").
+ EchoMode(huh.EchoModePassword).
+ Value(&token).
+ Validate(func(input string) error {
+ if input == "" {
+ return errTokenEmpty
+ }
+
+ return nil
+ }),
+ huh.NewInput().
+ Title("Confirm Token").
+ Description("Re-enter the token to confirm.").
+ EchoMode(huh.EchoModePassword).
+ Value(&confirm).
+ Validate(func(input string) error {
+ if input != token {
+ return errTokenMismatch
+ }
+
+ return nil
+ }),
+ ),
+ ).Run()
+ if err != nil {
+ return "", err
+ }
+
+ return token, nil
+}
@@ -69,6 +69,7 @@ require (
github.com/wasilibs/go-re2 v1.3.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
+ golang.org/x/crypto v0.46.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
@@ -150,6 +150,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT0
github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8ua9s=
github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
+golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
+golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
@@ -30,10 +30,11 @@ type Config struct {
// MCPConfig holds MCP server settings.
type MCPConfig struct {
- Host string `toml:"host"`
- Port int `toml:"port"`
- Timezone string `toml:"timezone"`
- Tools ToolsConfig `toml:"tools"`
+ Host string `toml:"host"`
+ Port int `toml:"port"`
+ Timezone string `toml:"timezone"`
+ TokenHash string `toml:"token_hash,omitempty"`
+ Tools ToolsConfig `toml:"tools"`
}
// ToolsConfig controls which MCP tools are enabled.
@@ -0,0 +1,137 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+// Package auth provides Bearer token authentication for the MCP server.
+package auth
+
+import (
+ "crypto/rand"
+ "crypto/subtle"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "golang.org/x/crypto/argon2"
+)
+
+// Argon2id parameters per OWASP recommendations.
+const (
+ argonTime = 2
+ argonMemory = 19 * 1024 // 19 MiB
+ argonThreads = 1
+ argonKeyLen = 32
+ saltLen = 16
+ hashParts = 6 // $argon2id$v=19$m=...,t=...,p=...$salt$hash
+)
+
+// ErrInvalidHash indicates the stored hash is malformed.
+var ErrInvalidHash = errors.New("invalid hash format")
+
+// Hash generates an argon2id hash of the token with a random salt.
+// Returns an encoded string containing version, params, salt, and hash.
+func Hash(token string) (string, error) {
+ salt := make([]byte, saltLen)
+ if _, err := rand.Read(salt); err != nil {
+ return "", fmt.Errorf("generating salt: %w", err)
+ }
+
+ hash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
+
+ // Format: $argon2id$v=19$m=19456,t=2,p=1$<salt>$<hash>
+ encoded := fmt.Sprintf(
+ "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s",
+ argon2.Version,
+ argonMemory,
+ argonTime,
+ argonThreads,
+ base64.RawStdEncoding.EncodeToString(salt),
+ base64.RawStdEncoding.EncodeToString(hash),
+ )
+
+ return encoded, nil
+}
+
+// Verify checks if the token matches the encoded hash.
+// Uses constant-time comparison to prevent timing attacks.
+func Verify(token, encodedHash string) bool {
+ salt, storedHash, err := decodeHash(encodedHash)
+ if err != nil {
+ return false
+ }
+
+ computedHash := argon2.IDKey([]byte(token), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
+
+ return subtle.ConstantTimeCompare(storedHash, computedHash) == 1
+}
+
+// decodeHash parses the encoded hash string and extracts the salt and hash.
+func decodeHash(encoded string) ([]byte, []byte, error) {
+ parts := strings.Split(encoded, "$")
+ if len(parts) != hashParts {
+ return nil, nil, ErrInvalidHash
+ }
+
+ // parts[0] is empty (leading $)
+ // parts[1] is "argon2id"
+ // parts[2] is "v=19"
+ // parts[3] is "m=19456,t=2,p=1"
+ // parts[4] is base64-encoded salt
+ // parts[5] is base64-encoded hash
+
+ if parts[1] != "argon2id" {
+ return nil, nil, ErrInvalidHash
+ }
+
+ salt, err := base64.RawStdEncoding.DecodeString(parts[4])
+ if err != nil {
+ return nil, nil, fmt.Errorf("decoding salt: %w", err)
+ }
+
+ hash, err := base64.RawStdEncoding.DecodeString(parts[5])
+ if err != nil {
+ return nil, nil, fmt.Errorf("decoding hash: %w", err)
+ }
+
+ return salt, hash, nil
+}
+
+// Middleware returns HTTP middleware that validates Bearer tokens.
+// Accepts tokens via Authorization header (preferred) or access_token query
+// parameter (RFC 6750 Section 2.3 fallback for clients that can't set headers).
+// Returns 401 with WWW-Authenticate header on failure.
+func Middleware(tokenHash string) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(writer http.ResponseWriter, req *http.Request) {
+ token := extractToken(req)
+ if token == "" || !Verify(token, tokenHash) {
+ unauthorized(writer)
+
+ return
+ }
+
+ next.ServeHTTP(writer, req)
+ })
+ }
+}
+
+// extractToken gets the Bearer token from Authorization header or query param.
+func extractToken(req *http.Request) string {
+ // Prefer Authorization header (RFC 6750 Section 2.1)
+ if authHeader := req.Header.Get("Authorization"); authHeader != "" {
+ const bearerPrefix = "Bearer "
+ if strings.HasPrefix(authHeader, bearerPrefix) {
+ return strings.TrimPrefix(authHeader, bearerPrefix)
+ }
+ }
+
+ // Fall back to query parameter (RFC 6750 Section 2.3)
+ return req.URL.Query().Get("access_token")
+}
+
+func unauthorized(w http.ResponseWriter) {
+ w.Header().Set("WWW-Authenticate", `Bearer realm="lune-mcp"`)
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+}
@@ -0,0 +1,237 @@
+// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
+//
+// SPDX-License-Identifier: AGPL-3.0-or-later
+
+package auth_test
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "git.secluded.site/lune/internal/mcp/auth"
+)
+
+func TestHashAndVerify(t *testing.T) {
+ t.Parallel()
+
+ token := "my-secret-token"
+
+ hash, err := auth.Hash(token)
+ if err != nil {
+ t.Fatalf("Hash() error = %v", err)
+ }
+
+ if hash == "" {
+ t.Fatal("Hash() returned empty string")
+ }
+
+ // Verify with correct token
+ if !auth.Verify(token, hash) {
+ t.Error("Verify() returned false for correct token")
+ }
+
+ // Verify with wrong token
+ if auth.Verify("wrong-token", hash) {
+ t.Error("Verify() returned true for wrong token")
+ }
+}
+
+func TestHashUniqueSalts(t *testing.T) {
+ t.Parallel()
+
+ token := "same-token"
+
+ hash1, err := auth.Hash(token)
+ if err != nil {
+ t.Fatalf("Hash() error = %v", err)
+ }
+
+ hash2, err := auth.Hash(token)
+ if err != nil {
+ t.Fatalf("Hash() error = %v", err)
+ }
+
+ if hash1 == hash2 {
+ t.Error("Hash() produced identical hashes for same token (salt not random)")
+ }
+
+ // Both should still verify correctly
+ if !auth.Verify(token, hash1) {
+ t.Error("Verify() failed for hash1")
+ }
+
+ if !auth.Verify(token, hash2) {
+ t.Error("Verify() failed for hash2")
+ }
+}
+
+func TestVerifyInvalidHash(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ hash string
+ }{
+ {"empty", ""},
+ {"garbage", "not-a-valid-hash"},
+ {"wrong algorithm", "$bcrypt$v=2$something$salt$hash"},
+ {"too few parts", "$argon2id$v=19$m=19456$salt"},
+ {"invalid base64 salt", "$argon2id$v=19$m=19456,t=2,p=1$!!!invalid$hash"},
+ {"invalid base64 hash", "$argon2id$v=19$m=19456,t=2,p=1$c2FsdA$!!!invalid"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ if auth.Verify("token", tc.hash) {
+ t.Errorf("Verify() returned true for invalid hash %q", tc.hash)
+ }
+ })
+ }
+}
+
+func TestMiddleware(t *testing.T) {
+ t.Parallel()
+
+ token := "test-bearer-token"
+
+ hash, err := auth.Hash(token)
+ if err != nil {
+ t.Fatalf("Hash() error = %v", err)
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+ rw.WriteHeader(http.StatusOK)
+ _, _ = rw.Write([]byte("OK"))
+ })
+
+ protected := auth.Middleware(hash)(handler)
+
+ testCases := middlewareTestCases(token)
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ t.Parallel()
+ runMiddlewareTest(t, protected, testCase)
+ })
+ }
+}
+
+type middlewareTestCase struct {
+ name string
+ authHeader string
+ wantStatus int
+ wantAuthHeader bool
+}
+
+func middlewareTestCases(token string) []middlewareTestCase {
+ return []middlewareTestCase{
+ {
+ name: "missing header",
+ authHeader: "",
+ wantStatus: http.StatusUnauthorized,
+ wantAuthHeader: true,
+ },
+ {
+ name: "wrong scheme",
+ authHeader: "Basic dXNlcjpwYXNz",
+ wantStatus: http.StatusUnauthorized,
+ wantAuthHeader: true,
+ },
+ {
+ name: "wrong token",
+ authHeader: "Bearer wrong-token",
+ wantStatus: http.StatusUnauthorized,
+ wantAuthHeader: true,
+ },
+ {
+ name: "valid token",
+ authHeader: "Bearer " + token,
+ wantStatus: http.StatusOK,
+ wantAuthHeader: false,
+ },
+ }
+}
+
+func runMiddlewareTest(t *testing.T, protected http.Handler, testCase middlewareTestCase) {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ if testCase.authHeader != "" {
+ req.Header.Set("Authorization", testCase.authHeader)
+ }
+
+ recorder := httptest.NewRecorder()
+ protected.ServeHTTP(recorder, req)
+
+ if recorder.Code != testCase.wantStatus {
+ t.Errorf("status = %d, want %d", recorder.Code, testCase.wantStatus)
+ }
+
+ wwwAuth := recorder.Header().Get("WWW-Authenticate")
+ hasAuthHeader := wwwAuth != ""
+
+ if hasAuthHeader != testCase.wantAuthHeader {
+ t.Errorf("WWW-Authenticate header present = %v, want %v", hasAuthHeader, testCase.wantAuthHeader)
+ }
+
+ if testCase.wantAuthHeader && wwwAuth != `Bearer realm="lune-mcp"` {
+ t.Errorf("WWW-Authenticate = %q, want %q", wwwAuth, `Bearer realm="lune-mcp"`)
+ }
+}
+
+func TestMiddlewareQueryParam(t *testing.T) {
+ t.Parallel()
+
+ token := "query-param-token" //nolint:gosec // test fixture, not a real credential
+
+ hash, err := auth.Hash(token)
+ if err != nil {
+ t.Fatalf("Hash() error = %v", err)
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+ rw.WriteHeader(http.StatusOK)
+ })
+
+ protected := auth.Middleware(hash)(handler)
+
+ t.Run("valid query param", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "/?access_token="+token, nil)
+ recorder := httptest.NewRecorder()
+ protected.ServeHTTP(recorder, req)
+
+ if recorder.Code != http.StatusOK {
+ t.Errorf("status = %d, want %d", recorder.Code, http.StatusOK)
+ }
+ })
+
+ t.Run("invalid query param", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
+ recorder := httptest.NewRecorder()
+ protected.ServeHTTP(recorder, req)
+
+ if recorder.Code != http.StatusUnauthorized {
+ t.Errorf("status = %d, want %d", recorder.Code, http.StatusUnauthorized)
+ }
+ })
+
+ t.Run("header takes precedence over query param", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "/?access_token=wrong-token", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ recorder := httptest.NewRecorder()
+ protected.ServeHTTP(recorder, req)
+
+ if recorder.Code != http.StatusOK {
+ t.Errorf("status = %d, want %d (header should take precedence)", recorder.Code, http.StatusOK)
+ }
+ })
+}