oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
//! 2. **Client registration** ([`resolve_client_registration`]) is separate
//!    because Dynamic Client Registration (DCR, RFC 7591) requires the actual
//!    loopback redirect URI, which includes an ephemeral port that only exists
//!    once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// CIMD is a Client ID Metadata Document. When an authorization server
/// advertises `client_id_metadata_document_supported`,
/// [`determine_registration_strategy`] uses this URL directly as the
/// `client_id`, so no registration request is needed.
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
 131/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
 132/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
 133#[derive(Debug, Clone, Serialize, Deserialize)]
 134pub struct ProtectedResourceMetadata {
 135    pub resource: Url,
 136    pub authorization_servers: Vec<Url>,
 137    pub scopes_supported: Option<Vec<String>>,
 138}
 139
 140/// Parsed from the authorization server's .well-known endpoint
 141/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
 142#[derive(Debug, Clone, Serialize, Deserialize)]
 143pub struct AuthServerMetadata {
 144    pub issuer: Url,
 145    pub authorization_endpoint: Url,
 146    pub token_endpoint: Url,
 147    pub registration_endpoint: Option<Url>,
 148    pub scopes_supported: Option<Vec<String>>,
 149    pub code_challenge_methods_supported: Option<Vec<String>>,
 150    pub client_id_metadata_document_supported: bool,
 151}
 152
 153/// The result of client registration — either CIMD or DCR.
 154#[derive(Clone, Serialize, Deserialize)]
 155pub struct OAuthClientRegistration {
 156    pub client_id: String,
 157    /// Only present for DCR-minted registrations.
 158    pub client_secret: Option<String>,
 159}
 160
 161impl std::fmt::Debug for OAuthClientRegistration {
 162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 163        f.debug_struct("OAuthClientRegistration")
 164            .field("client_id", &self.client_id)
 165            .field(
 166                "client_secret",
 167                &self.client_secret.as_ref().map(|_| "[redacted]"),
 168            )
 169            .finish()
 170    }
 171}
 172
 173/// Access and refresh tokens obtained from the token endpoint.
 174#[derive(Clone, Serialize, Deserialize)]
 175pub struct OAuthTokens {
 176    pub access_token: String,
 177    pub refresh_token: Option<String>,
 178    pub expires_at: Option<SystemTime>,
 179}
 180
 181impl std::fmt::Debug for OAuthTokens {
 182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 183        f.debug_struct("OAuthTokens")
 184            .field("access_token", &"[redacted]")
 185            .field(
 186                "refresh_token",
 187                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 188            )
 189            .field("expires_at", &self.expires_at)
 190            .finish()
 191    }
 192}
 193
 194/// Everything discovered before the browser flow starts. Client registration is
 195/// resolved separately, once the real redirect URI is known.
 196#[derive(Debug, Clone, Serialize, Deserialize)]
 197pub struct OAuthDiscovery {
 198    pub resource_metadata: ProtectedResourceMetadata,
 199    pub auth_server_metadata: AuthServerMetadata,
 200    pub scopes: Vec<String>,
 201}
 202
 203/// The persisted OAuth session for a context server.
 204///
 205/// Stored in the keychain so startup can restore a refresh-capable provider
 206/// without another browser flow. Deliberately excludes the full discovery
 207/// metadata to keep the serialized size well within keychain item limits.
 208#[derive(Clone, Serialize, Deserialize)]
 209pub struct OAuthSession {
 210    pub token_endpoint: Url,
 211    pub resource: Url,
 212    pub client_registration: OAuthClientRegistration,
 213    pub tokens: OAuthTokens,
 214}
 215
 216impl std::fmt::Debug for OAuthSession {
 217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 218        f.debug_struct("OAuthSession")
 219            .field("token_endpoint", &self.token_endpoint)
 220            .field("resource", &self.resource)
 221            .field("client_registration", &self.client_registration)
 222            .field("tokens", &self.tokens)
 223            .finish()
 224    }
 225}
 226
 227/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 228#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 229pub enum BearerError {
 230    /// The request is missing a required parameter, includes an unsupported
 231    /// parameter or parameter value, or is otherwise malformed.
 232    InvalidRequest,
 233    /// The access token provided is expired, revoked, malformed, or invalid.
 234    InvalidToken,
 235    /// The request requires higher privileges than provided by the access token.
 236    InsufficientScope,
 237    /// An unrecognized error code (extension or future spec addition).
 238    Other,
 239}
 240
 241impl BearerError {
 242    fn parse(value: &str) -> Self {
 243        match value {
 244            "invalid_request" => BearerError::InvalidRequest,
 245            "invalid_token" => BearerError::InvalidToken,
 246            "insufficient_scope" => BearerError::InsufficientScope,
 247            _ => BearerError::Other,
 248        }
 249    }
 250}
 251
 252/// Fields extracted from a `WWW-Authenticate: Bearer` header.
 253///
 254/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
 255/// at the Protected Resource Metadata document. The optional `scope` parameter
 256/// (RFC 6750 Section 3) indicates scopes required for the request.
 257#[derive(Debug, Clone, PartialEq, Eq)]
 258pub struct WwwAuthenticate {
 259    pub resource_metadata: Option<Url>,
 260    pub scope: Option<Vec<String>>,
 261    /// The parsed `error` parameter per RFC 6750 Section 3.1.
 262    pub error: Option<BearerError>,
 263    pub error_description: Option<String>,
 264}
 265
 266/// Parse a `WWW-Authenticate` header value.
 267///
 268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 270/// - `resource_metadata` — URL of the Protected Resource Metadata document
 271/// - `scope` — space-separated list of required scopes
 272/// - `error` — error code (e.g. "insufficient_scope")
 273/// - `error_description` — human-readable error description
 274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 275    let header = header.trim();
 276
 277    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 278        header[6..].trim()
 279    } else {
 280        bail!("WWW-Authenticate header does not use Bearer scheme");
 281    };
 282
 283    if params_str.is_empty() {
 284        return Ok(WwwAuthenticate {
 285            resource_metadata: None,
 286            scope: None,
 287            error: None,
 288            error_description: None,
 289        });
 290    }
 291
 292    let params = parse_auth_params(params_str);
 293
 294    let resource_metadata = params
 295        .get("resource_metadata")
 296        .map(|v| Url::parse(v))
 297        .transpose()
 298        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 299
 300    let scope = params
 301        .get("scope")
 302        .map(|v| v.split_whitespace().map(String::from).collect());
 303
 304    let error = params.get("error").map(|v| BearerError::parse(v));
 305    let error_description = params.get("error_description").cloned();
 306
 307    Ok(WwwAuthenticate {
 308        resource_metadata,
 309        scope,
 310        error,
 311        error_description,
 312    })
 313}
 314
/// Parse comma-separated `key="value"` or `key=token` parameters from an
/// auth-param list (RFC 7235 Section 2.1).
///
/// Parsing rules:
/// - Keys are trimmed and lowercased.
/// - Quoted values are returned without their quotes. A backslash makes the
///   following character literal.
/// - If a key occurs more than once, the last occurrence wins.
/// - Pairs whose key is empty are dropped.
/// - Parsing stops at the first segment that contains no `=`.
/// - An unterminated quoted string consumes the rest of the input.
///
/// NOTE(review): the key is taken as everything before the next `=`. A stray
/// value-less token (e.g. `foo, error="x"`) is therefore folded into the
/// following key (`"foo, error"`). Confirm this leniency is acceptable.
fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
    let mut params = collections::HashMap::default();
    let mut remaining = input.trim();

    while !remaining.is_empty() {
        // Skip leading whitespace and commas.
        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
        if remaining.is_empty() {
            break;
        }

        // Find the key (everything before '='). No '=' left means there are no
        // further key/value pairs to read.
        let eq_pos = match remaining.find('=') {
            Some(pos) => pos,
            None => break,
        };

        let key = remaining[..eq_pos].trim().to_lowercase();
        remaining = &remaining[eq_pos + 1..];
        remaining = remaining.trim_start();

        // Parse the value: either quoted or unquoted (token).
        let value;
        if remaining.starts_with('"') {
            // Quoted string: find the closing quote, handling escaped chars.
            remaining = &remaining[1..]; // skip opening quote
            let mut val = String::new();
            let mut chars = remaining.char_indices();
            loop {
                match chars.next() {
                    Some((_, '\\')) => {
                        // Escaped character — take the next char literally.
                        // A trailing backslash at end of input is dropped.
                        if let Some((_, c)) = chars.next() {
                            val.push(c);
                        }
                    }
                    Some((i, '"')) => {
                        // `i` is the byte offset of the ASCII closing quote, so
                        // `i + 1` is a valid char boundary.
                        remaining = &remaining[i + 1..];
                        break;
                    }
                    Some((_, c)) => val.push(c),
                    None => {
                        // Unterminated quote: keep what was read, stop parsing.
                        remaining = "";
                        break;
                    }
                }
            }
            value = val;
        } else {
            // Unquoted token: read until comma or whitespace.
            let end = remaining
                .find(|c: char| c == ',' || c.is_whitespace())
                .unwrap_or(remaining.len());
            value = remaining[..end].to_string();
            remaining = &remaining[end..];
        }

        if !key.is_empty() {
            params.insert(key, value);
        }
    }

    params
}
 381
 382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 383/// server URL, per RFC 9728 Section 3.
 384///
 385/// Returns URIs in priority order:
 386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 389    let mut urls = Vec::new();
 390    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 391
 392    let path = server_url.path().trim_start_matches('/');
 393    if !path.is_empty() {
 394        if let Ok(url) = Url::parse(&format!(
 395            "{}/.well-known/oauth-protected-resource/{}",
 396            base, path
 397        )) {
 398            urls.push(url);
 399        }
 400    }
 401
 402    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 403        urls.push(url);
 404    }
 405
 406    urls
 407}
 408
 409/// Construct the well-known Authorization Server Metadata URIs for a given
 410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 411///
 412/// Returns URIs in priority order, which differs depending on whether the
 413/// issuer URL has a path component.
 414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 415    let mut urls = Vec::new();
 416    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 417    let path = issuer.path().trim_matches('/');
 418
 419    if !path.is_empty() {
 420        // Issuer with path: try path-inserted variants first.
 421        if let Ok(url) = Url::parse(&format!(
 422            "{}/.well-known/oauth-authorization-server/{}",
 423            base, path
 424        )) {
 425            urls.push(url);
 426        }
 427        if let Ok(url) = Url::parse(&format!(
 428            "{}/.well-known/openid-configuration/{}",
 429            base, path
 430        )) {
 431            urls.push(url);
 432        }
 433        if let Ok(url) = Url::parse(&format!(
 434            "{}/{}/.well-known/openid-configuration",
 435            base, path
 436        )) {
 437            urls.push(url);
 438        }
 439    } else {
 440        // No path: standard well-known locations.
 441        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 442            urls.push(url);
 443        }
 444        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 445            urls.push(url);
 446        }
 447    }
 448
 449    urls
 450}
 451
 452// -- Canonical server URI (RFC 8707) -----------------------------------------
 453
 454/// Derive the canonical resource URI for an MCP server URL, suitable for the
 455/// `resource` parameter in authorization and token requests per RFC 8707.
 456///
 457/// Lowercases the scheme and host, preserves the path (without trailing slash),
 458/// strips fragments and query strings.
 459pub fn canonical_server_uri(server_url: &Url) -> String {
 460    let mut uri = format!(
 461        "{}://{}",
 462        server_url.scheme().to_ascii_lowercase(),
 463        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 464    );
 465    if let Some(port) = server_url.port() {
 466        uri.push_str(&format!(":{}", port));
 467    }
 468    let path = server_url.path();
 469    if path != "/" {
 470        uri.push_str(path.trim_end_matches('/'));
 471    }
 472    uri
 473}
 474
 475// -- Scope selection ---------------------------------------------------------
 476
 477/// Select scopes following the MCP spec's Scope Selection Strategy:
 478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 480/// 3. Return empty if neither is available.
 481pub fn select_scopes(
 482    www_authenticate: &WwwAuthenticate,
 483    resource_metadata: &ProtectedResourceMetadata,
 484) -> Vec<String> {
 485    if let Some(ref scopes) = www_authenticate.scope {
 486        if !scopes.is_empty() {
 487            return scopes.clone();
 488        }
 489    }
 490    resource_metadata
 491        .scopes_supported
 492        .clone()
 493        .unwrap_or_default()
 494}
 495
 496// -- Client registration strategy --------------------------------------------
 497
 498/// The registration approach to use, determined from auth server metadata.
 499#[derive(Debug, Clone, PartialEq, Eq)]
 500pub enum ClientRegistrationStrategy {
 501    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
 502    Cimd { client_id: String },
 503    /// The auth server has a registration endpoint. Caller must POST to it.
 504    Dcr { registration_endpoint: Url },
 505    /// No supported registration mechanism.
 506    Unavailable,
 507}
 508
 509/// Determine how to register with the authorization server, following the
 510/// spec's recommended priority: CIMD first, DCR fallback.
 511pub fn determine_registration_strategy(
 512    auth_server_metadata: &AuthServerMetadata,
 513) -> ClientRegistrationStrategy {
 514    if auth_server_metadata.client_id_metadata_document_supported {
 515        ClientRegistrationStrategy::Cimd {
 516            client_id: CIMD_URL.to_string(),
 517        }
 518    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 519        ClientRegistrationStrategy::Dcr {
 520            registration_endpoint: endpoint.clone(),
 521        }
 522    } else {
 523        ClientRegistrationStrategy::Unavailable
 524    }
 525}
 526
 527// -- PKCE (RFC 7636) ---------------------------------------------------------
 528
 529/// A PKCE code verifier and its S256 challenge.
 530#[derive(Clone)]
 531pub struct PkceChallenge {
 532    pub verifier: String,
 533    pub challenge: String,
 534}
 535
 536impl std::fmt::Debug for PkceChallenge {
 537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 538        f.debug_struct("PkceChallenge")
 539            .field("verifier", &"[redacted]")
 540            .field("challenge", &self.challenge)
 541            .finish()
 542    }
 543}
 544
 545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 546///
 547/// The verifier is 43 base64url characters derived from 32 random bytes.
 548/// The challenge is `BASE64URL(SHA256(verifier))`.
 549pub fn generate_pkce_challenge() -> PkceChallenge {
 550    let mut random_bytes = [0u8; 32];
 551    rand::rng().fill(&mut random_bytes);
 552    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 553    let verifier = engine.encode(&random_bytes);
 554
 555    let digest = Sha256::digest(verifier.as_bytes());
 556    let challenge = engine.encode(digest);
 557
 558    PkceChallenge {
 559        verifier,
 560        challenge,
 561    }
 562}
 563
 564// -- Authorization URL construction ------------------------------------------
 565
 566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 567pub fn build_authorization_url(
 568    auth_server_metadata: &AuthServerMetadata,
 569    client_id: &str,
 570    redirect_uri: &str,
 571    scopes: &[String],
 572    resource: &str,
 573    pkce: &PkceChallenge,
 574    state: &str,
 575) -> Url {
 576    let mut url = auth_server_metadata.authorization_endpoint.clone();
 577    {
 578        let mut query = url.query_pairs_mut();
 579        query.append_pair("response_type", "code");
 580        query.append_pair("client_id", client_id);
 581        query.append_pair("redirect_uri", redirect_uri);
 582        if !scopes.is_empty() {
 583            query.append_pair("scope", &scopes.join(" "));
 584        }
 585        query.append_pair("resource", resource);
 586        query.append_pair("code_challenge", &pkce.challenge);
 587        query.append_pair("code_challenge_method", "S256");
 588        query.append_pair("state", state);
 589    }
 590    url
 591}
 592
 593// -- Token endpoint request bodies -------------------------------------------
 594
 595/// The JSON body returned by the token endpoint on success.
 596#[derive(Deserialize)]
 597pub struct TokenResponse {
 598    pub access_token: String,
 599    #[serde(default)]
 600    pub refresh_token: Option<String>,
 601    #[serde(default)]
 602    pub expires_in: Option<u64>,
 603    #[serde(default)]
 604    pub token_type: Option<String>,
 605}
 606
 607impl std::fmt::Debug for TokenResponse {
 608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 609        f.debug_struct("TokenResponse")
 610            .field("access_token", &"[redacted]")
 611            .field(
 612                "refresh_token",
 613                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 614            )
 615            .field("expires_in", &self.expires_in)
 616            .field("token_type", &self.token_type)
 617            .finish()
 618    }
 619}
 620
 621impl TokenResponse {
 622    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 623    pub fn into_tokens(self) -> OAuthTokens {
 624        let expires_at = self
 625            .expires_in
 626            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 627        OAuthTokens {
 628            access_token: self.access_token,
 629            refresh_token: self.refresh_token,
 630            expires_at,
 631        }
 632    }
 633}
 634
 635/// Build the form-encoded body for an authorization code token exchange.
 636pub fn token_exchange_params(
 637    code: &str,
 638    client_id: &str,
 639    redirect_uri: &str,
 640    code_verifier: &str,
 641    resource: &str,
 642    client_secret: Option<&str>,
 643) -> Vec<(&'static str, String)> {
 644    let mut params = vec![
 645        ("grant_type", "authorization_code".to_string()),
 646        ("code", code.to_string()),
 647        ("redirect_uri", redirect_uri.to_string()),
 648        ("client_id", client_id.to_string()),
 649        ("code_verifier", code_verifier.to_string()),
 650        ("resource", resource.to_string()),
 651    ];
 652    if let Some(secret) = client_secret {
 653        params.push(("client_secret", secret.to_string()));
 654    }
 655    params
 656}
 657
 658/// Build the form-encoded body for a token refresh request.
 659pub fn token_refresh_params(
 660    refresh_token: &str,
 661    client_id: &str,
 662    resource: &str,
 663    client_secret: Option<&str>,
 664) -> Vec<(&'static str, String)> {
 665    let mut params = vec![
 666        ("grant_type", "refresh_token".to_string()),
 667        ("refresh_token", refresh_token.to_string()),
 668        ("client_id", client_id.to_string()),
 669        ("resource", resource.to_string()),
 670    ];
 671    if let Some(secret) = client_secret {
 672        params.push(("client_secret", secret.to_string()));
 673    }
 674    params
 675}
 676
 677// -- DCR request body (RFC 7591) ---------------------------------------------
 678
 679/// Build the JSON body for a Dynamic Client Registration request.
 680///
 681/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 682/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 683/// redirect URI matching even for loopback addresses, so we register the
 684/// exact URI we intend to use.
 685pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 686    serde_json::json!({
 687        "client_name": "Zed",
 688        "redirect_uris": [redirect_uri],
 689        "grant_types": ["authorization_code"],
 690        "response_types": ["code"],
 691        "token_endpoint_auth_method": "none"
 692    })
 693}
 694
 695// -- Discovery (async, hits real endpoints) ----------------------------------
 696
 697/// Fetch Protected Resource Metadata from the MCP server.
 698///
 699/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 700/// then falls back to well-known URIs constructed from `server_url`.
 701pub async fn fetch_protected_resource_metadata(
 702    http_client: &Arc<dyn HttpClient>,
 703    server_url: &Url,
 704    www_authenticate: &WwwAuthenticate,
 705) -> Result<ProtectedResourceMetadata> {
 706    let candidate_urls = match &www_authenticate.resource_metadata {
 707        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 708        Some(url) => {
 709            log::warn!(
 710                "Ignoring cross-origin resource_metadata URL {} \
 711                 (server origin: {})",
 712                url,
 713                server_url.origin().unicode_serialization()
 714            );
 715            protected_resource_metadata_urls(server_url)
 716        }
 717        None => protected_resource_metadata_urls(server_url),
 718    };
 719
 720    for url in &candidate_urls {
 721        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 722            Ok(response) => {
 723                if response.authorization_servers.is_empty() {
 724                    bail!(
 725                        "Protected Resource Metadata at {} has no authorization_servers",
 726                        url
 727                    );
 728                }
 729                return Ok(ProtectedResourceMetadata {
 730                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 731                    authorization_servers: response.authorization_servers,
 732                    scopes_supported: response.scopes_supported,
 733                });
 734            }
 735            Err(err) => {
 736                log::debug!(
 737                    "Failed to fetch Protected Resource Metadata from {}: {}",
 738                    url,
 739                    err
 740                );
 741            }
 742        }
 743    }
 744
 745    bail!(
 746        "Could not fetch Protected Resource Metadata for {}",
 747        server_url
 748    )
 749}
 750
 751/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 752/// endpoints in the priority order specified by the MCP spec.
 753pub async fn fetch_auth_server_metadata(
 754    http_client: &Arc<dyn HttpClient>,
 755    issuer: &Url,
 756) -> Result<AuthServerMetadata> {
 757    let candidate_urls = auth_server_metadata_urls(issuer);
 758
 759    for url in &candidate_urls {
 760        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 761            Ok(response) => {
 762                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 763                // if reported_issuer != *issuer {
 764                //     bail!(
 765                //         "Auth server metadata issuer mismatch: expected {}, got {}",
 766                //         issuer,
 767                //         reported_issuer
 768                //     );
 769                // }
 770
 771                return Ok(AuthServerMetadata {
 772                    issuer: reported_issuer,
 773                    authorization_endpoint: response
 774                        .authorization_endpoint
 775                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 776                    token_endpoint: response
 777                        .token_endpoint
 778                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 779                    registration_endpoint: response.registration_endpoint,
 780                    scopes_supported: response.scopes_supported,
 781                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 782                    client_id_metadata_document_supported: response
 783                        .client_id_metadata_document_supported
 784                        .unwrap_or(false),
 785                });
 786            }
 787            Err(err) => {
 788                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 789            }
 790        }
 791    }
 792
 793    bail!(
 794        "Could not fetch Authorization Server Metadata for {}",
 795        issuer
 796    )
 797}
 798
 799/// Run the full discovery flow: fetch resource metadata, then auth server
 800/// metadata, then select scopes. Client registration is resolved separately,
 801/// once the real redirect URI is known.
 802pub async fn discover(
 803    http_client: &Arc<dyn HttpClient>,
 804    server_url: &Url,
 805    www_authenticate: &WwwAuthenticate,
 806) -> Result<OAuthDiscovery> {
 807    let resource_metadata =
 808        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 809
 810    let auth_server_url = resource_metadata
 811        .authorization_servers
 812        .first()
 813        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 814
 815    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 816
 817    // Verify PKCE S256 support (spec requirement).
 818    match &auth_server_metadata.code_challenge_methods_supported {
 819        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 820        Some(_) => bail!("authorization server does not support S256 PKCE"),
 821        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 822    }
 823
 824    let scopes = select_scopes(www_authenticate, &resource_metadata);
 825
 826    Ok(OAuthDiscovery {
 827        resource_metadata,
 828        auth_server_metadata,
 829        scopes,
 830    })
 831}
 832
 833/// Resolve the OAuth client registration for an authorization flow.
 834///
 835/// CIMD uses the static client metadata document directly. For DCR, a fresh
 836/// registration is performed each time because the loopback redirect URI
 837/// includes an ephemeral port that changes every flow.
 838pub async fn resolve_client_registration(
 839    http_client: &Arc<dyn HttpClient>,
 840    discovery: &OAuthDiscovery,
 841    redirect_uri: &str,
 842) -> Result<OAuthClientRegistration> {
 843    match determine_registration_strategy(&discovery.auth_server_metadata) {
 844        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 845            client_id,
 846            client_secret: None,
 847        }),
 848        ClientRegistrationStrategy::Dcr {
 849            registration_endpoint,
 850        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 851        ClientRegistrationStrategy::Unavailable => {
 852            bail!("authorization server supports neither CIMD nor DCR")
 853        }
 854    }
 855}
 856
 857// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 858
 859/// Perform Dynamic Client Registration with the authorization server.
 860pub async fn perform_dcr(
 861    http_client: &Arc<dyn HttpClient>,
 862    registration_endpoint: &Url,
 863    redirect_uri: &str,
 864) -> Result<OAuthClientRegistration> {
 865    validate_oauth_url(registration_endpoint)?;
 866
 867    let body = dcr_registration_body(redirect_uri);
 868    let body_bytes = serde_json::to_vec(&body)?;
 869
 870    let request = Request::builder()
 871        .method(http_client::http::Method::POST)
 872        .uri(registration_endpoint.as_str())
 873        .header("Content-Type", "application/json")
 874        .header("Accept", "application/json")
 875        .body(AsyncBody::from(body_bytes))?;
 876
 877    let mut response = http_client.send(request).await?;
 878
 879    if !response.status().is_success() {
 880        let mut error_body = String::new();
 881        response.body_mut().read_to_string(&mut error_body).await?;
 882        bail!(
 883            "DCR failed with status {}: {}",
 884            response.status(),
 885            error_body
 886        );
 887    }
 888
 889    let mut response_body = String::new();
 890    response
 891        .body_mut()
 892        .read_to_string(&mut response_body)
 893        .await?;
 894
 895    let dcr_response: DcrResponse =
 896        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 897
 898    Ok(OAuthClientRegistration {
 899        client_id: dcr_response.client_id,
 900        client_secret: dcr_response.client_secret,
 901    })
 902}
 903
 904// -- Token exchange and refresh (async) --------------------------------------
 905
 906/// Exchange an authorization code for tokens at the token endpoint.
 907pub async fn exchange_code(
 908    http_client: &Arc<dyn HttpClient>,
 909    auth_server_metadata: &AuthServerMetadata,
 910    code: &str,
 911    client_id: &str,
 912    redirect_uri: &str,
 913    code_verifier: &str,
 914    resource: &str,
 915    client_secret: Option<&str>,
 916) -> Result<OAuthTokens> {
 917    let params = token_exchange_params(
 918        code,
 919        client_id,
 920        redirect_uri,
 921        code_verifier,
 922        resource,
 923        client_secret,
 924    );
 925    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 926}
 927
 928/// Refresh tokens using a refresh token.
 929pub async fn refresh_tokens(
 930    http_client: &Arc<dyn HttpClient>,
 931    token_endpoint: &Url,
 932    refresh_token: &str,
 933    client_id: &str,
 934    resource: &str,
 935    client_secret: Option<&str>,
 936) -> Result<OAuthTokens> {
 937    let params = token_refresh_params(refresh_token, client_id, resource, client_secret);
 938    post_token_request(http_client, token_endpoint, &params).await
 939}
 940
 941/// POST form-encoded parameters to a token endpoint and parse the response.
 942async fn post_token_request(
 943    http_client: &Arc<dyn HttpClient>,
 944    token_endpoint: &Url,
 945    params: &[(&str, String)],
 946) -> Result<OAuthTokens> {
 947    validate_oauth_url(token_endpoint)?;
 948
 949    let body = url::form_urlencoded::Serializer::new(String::new())
 950        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 951        .finish();
 952
 953    let request = Request::builder()
 954        .method(http_client::http::Method::POST)
 955        .uri(token_endpoint.as_str())
 956        .header("Content-Type", "application/x-www-form-urlencoded")
 957        .header("Accept", "application/json")
 958        .body(AsyncBody::from(body.into_bytes()))?;
 959
 960    let mut response = http_client.send(request).await?;
 961
 962    if !response.status().is_success() {
 963        let mut error_body = String::new();
 964        response.body_mut().read_to_string(&mut error_body).await?;
 965        bail!(
 966            "token request failed with status {}: {}",
 967            response.status(),
 968            error_body
 969        );
 970    }
 971
 972    let mut response_body = String::new();
 973    response
 974        .body_mut()
 975        .read_to_string(&mut response_body)
 976        .await?;
 977
 978    let token_response: TokenResponse =
 979        serde_json::from_str(&response_body).context("failed to parse token response")?;
 980
 981    Ok(token_response.into_tokens())
 982}
 983
 984// -- Loopback HTTP callback server -------------------------------------------
 985
 986/// An OAuth authorization callback received via the loopback HTTP server.
 987pub struct OAuthCallback {
 988    pub code: String,
 989    pub state: String,
 990}
 991
 992impl std::fmt::Debug for OAuthCallback {
 993    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 994        f.debug_struct("OAuthCallback")
 995            .field("code", &"[redacted]")
 996            .field("state", &"[redacted]")
 997            .finish()
 998    }
 999}
1000
1001impl OAuthCallback {
1002    /// Parse the query string from a callback URL like
1003    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1004    pub fn parse_query(query: &str) -> Result<Self> {
1005        let mut code: Option<String> = None;
1006        let mut state: Option<String> = None;
1007        let mut error: Option<String> = None;
1008        let mut error_description: Option<String> = None;
1009
1010        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1011            match key.as_ref() {
1012                "code" => {
1013                    if !value.is_empty() {
1014                        code = Some(value.into_owned());
1015                    }
1016                }
1017                "state" => {
1018                    if !value.is_empty() {
1019                        state = Some(value.into_owned());
1020                    }
1021                }
1022                "error" => {
1023                    if !value.is_empty() {
1024                        error = Some(value.into_owned());
1025                    }
1026                }
1027                "error_description" => {
1028                    if !value.is_empty() {
1029                        error_description = Some(value.into_owned());
1030                    }
1031                }
1032                _ => {}
1033            }
1034        }
1035
1036        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1037        // checking for missing code/state.
1038        if let Some(error_code) = error {
1039            bail!(
1040                "OAuth authorization failed: {} ({})",
1041                error_code,
1042                error_description.as_deref().unwrap_or("no description")
1043            );
1044        }
1045
1046        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1047        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1048
1049        Ok(Self { code, state })
1050    }
1051}
1052
/// Upper bound on how long the loopback server waits for the browser to
/// deliver the OAuth callback (two minutes). After this the server gives up and
/// releases the loopback port.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1056
1057/// Start a loopback HTTP server to receive the OAuth authorization callback.
1058///
1059/// Binds to an ephemeral loopback port for each flow.
1060///
1061/// Returns `(redirect_uri, callback_future)`. The caller should use the
1062/// redirect URI in the authorization request, open the browser, then await
1063/// the future to receive the callback.
1064///
1065/// The server accepts exactly one request on `/callback`, validates that it
1066/// contains `code` and `state` query parameters, responds with a minimal
1067/// HTML page telling the user they can close the tab, and shuts down.
1068///
1069/// The callback server shuts down when the returned oneshot receiver is dropped
1070/// (e.g. because the authentication task was cancelled), or after a timeout
1071/// ([CALLBACK_TIMEOUT]).
1072pub async fn start_callback_server() -> Result<(
1073    String,
1074    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1075)> {
1076    let server = tiny_http::Server::http("127.0.0.1:0")
1077        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1078    let port = server
1079        .server_addr()
1080        .to_ip()
1081        .context("server not bound to a TCP address")?
1082        .port();
1083
1084    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1085
1086    let (tx, rx) = futures::channel::oneshot::channel();
1087
1088    // `tiny_http` is blocking, so we run it on a background thread.
1089    // The `recv_timeout` loop lets us check for cancellation (the receiver
1090    // being dropped) and enforce an overall timeout.
1091    std::thread::spawn(move || {
1092        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1093
1094        loop {
1095            if tx.is_canceled() {
1096                return;
1097            }
1098            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1099            if remaining.is_zero() {
1100                return;
1101            }
1102
1103            let timeout = remaining.min(Duration::from_millis(500));
1104            let Some(request) = (match server.recv_timeout(timeout) {
1105                Ok(req) => req,
1106                Err(_) => {
1107                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1108                    return;
1109                }
1110            }) else {
1111                // Timeout with no request — loop back and check cancellation.
1112                continue;
1113            };
1114
1115            let result = handle_callback_request(&request);
1116
1117            let (status_code, body) = match &result {
1118                Ok(_) => (
1119                    200,
1120                    "<html><body><h1>Authorization successful</h1>\
1121                     <p>You can close this tab and return to Zed.</p></body></html>",
1122                ),
1123                Err(err) => {
1124                    log::error!("OAuth callback error: {}", err);
1125                    (
1126                        400,
1127                        "<html><body><h1>Authorization failed</h1>\
1128                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
1129                    )
1130                }
1131            };
1132
1133            let response = tiny_http::Response::from_string(body)
1134                .with_status_code(status_code)
1135                .with_header(
1136                    tiny_http::Header::from_str("Content-Type: text/html")
1137                        .expect("failed to construct response header"),
1138                )
1139                .with_header(
1140                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1141                        .expect("failed to construct response header"),
1142                );
1143            request.respond(response).log_err();
1144
1145            let _ = tx.send(result);
1146            return;
1147        }
1148    });
1149
1150    Ok((redirect_uri, rx))
1151}
1152
1153/// Extract the `code` and `state` query parameters from an OAuth callback
1154/// request to `/callback`.
1155fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1156    let url = Url::parse(&format!("http://localhost{}", request.url()))
1157        .context("malformed callback request URL")?;
1158
1159    if url.path() != "/callback" {
1160        bail!("unexpected path in OAuth callback: {}", url.path());
1161    }
1162
1163    let query = url
1164        .query()
1165        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1166    OAuthCallback::parse_query(query)
1167}
1168
1169// -- JSON fetch helper -------------------------------------------------------
1170
1171async fn fetch_json<T: serde::de::DeserializeOwned>(
1172    http_client: &Arc<dyn HttpClient>,
1173    url: &Url,
1174) -> Result<T> {
1175    validate_oauth_url(url)?;
1176
1177    let request = Request::builder()
1178        .method(http_client::http::Method::GET)
1179        .uri(url.as_str())
1180        .header("Accept", "application/json")
1181        .body(AsyncBody::default())?;
1182
1183    let mut response = http_client.send(request).await?;
1184
1185    if !response.status().is_success() {
1186        bail!("HTTP {} fetching {}", response.status(), url);
1187    }
1188
1189    let mut body = String::new();
1190    response.body_mut().read_to_string(&mut body).await?;
1191    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1192}
1193
1194// -- Serde response types for discovery --------------------------------------
1195
1196#[derive(Debug, Deserialize)]
1197struct ProtectedResourceMetadataResponse {
1198    #[serde(default)]
1199    resource: Option<Url>,
1200    #[serde(default)]
1201    authorization_servers: Vec<Url>,
1202    #[serde(default)]
1203    scopes_supported: Option<Vec<String>>,
1204}
1205
1206#[derive(Debug, Deserialize)]
1207struct AuthServerMetadataResponse {
1208    #[serde(default)]
1209    issuer: Option<Url>,
1210    #[serde(default)]
1211    authorization_endpoint: Option<Url>,
1212    #[serde(default)]
1213    token_endpoint: Option<Url>,
1214    #[serde(default)]
1215    registration_endpoint: Option<Url>,
1216    #[serde(default)]
1217    scopes_supported: Option<Vec<String>>,
1218    #[serde(default)]
1219    code_challenge_methods_supported: Option<Vec<String>>,
1220    #[serde(default)]
1221    client_id_metadata_document_supported: Option<bool>,
1222}
1223
1224#[derive(Debug, Deserialize)]
1225struct DcrResponse {
1226    client_id: String,
1227    #[serde(default)]
1228    client_secret: Option<String>,
1229}
1230
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// The `Send + Sync` bound allows one provider to be shared across threads and
/// tasks. Implementations should therefore expect these methods to be called
/// concurrently (NOTE(review): confirm how many requests the transport keeps
/// in flight at once).
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// This method is synchronous, so it can only report state that is already
    /// held; obtaining a new token is the job of `try_refresh`.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token is available and the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1244
1245/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
1246/// an HTTP client for token refresh. The same provider type is used both after
1247/// an interactive authentication flow and when restoring a saved session from
1248/// the keychain on startup.
1249pub struct McpOAuthTokenProvider {
1250    session: SyncMutex<OAuthSession>,
1251    http_client: Arc<dyn HttpClient>,
1252    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1253}
1254
1255impl McpOAuthTokenProvider {
1256    pub fn new(
1257        session: OAuthSession,
1258        http_client: Arc<dyn HttpClient>,
1259        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1260    ) -> Self {
1261        Self {
1262            session: SyncMutex::new(session),
1263            http_client,
1264            token_refresh_tx,
1265        }
1266    }
1267
1268    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1269        tokens.expires_at.is_some_and(|expires_at| {
1270            SystemTime::now()
1271                .checked_add(Duration::from_secs(30))
1272                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1273        })
1274    }
1275}
1276
1277#[async_trait]
1278impl OAuthTokenProvider for McpOAuthTokenProvider {
1279    fn access_token(&self) -> Option<String> {
1280        let session = self.session.lock();
1281        if Self::access_token_is_expired(&session.tokens) {
1282            return None;
1283        }
1284        Some(session.tokens.access_token.clone())
1285    }
1286
1287    async fn try_refresh(&self) -> Result<bool> {
1288        let (refresh_token, token_endpoint, resource, client_id, client_secret) = {
1289            let session = self.session.lock();
1290            match session.tokens.refresh_token.clone() {
1291                Some(refresh_token) => (
1292                    refresh_token,
1293                    session.token_endpoint.clone(),
1294                    session.resource.clone(),
1295                    session.client_registration.client_id.clone(),
1296                    session.client_registration.client_secret.clone(),
1297                ),
1298                None => return Ok(false),
1299            }
1300        };
1301
1302        let resource_str = canonical_server_uri(&resource);
1303
1304        match refresh_tokens(
1305            &self.http_client,
1306            &token_endpoint,
1307            &refresh_token,
1308            &client_id,
1309            &resource_str,
1310            client_secret.as_deref(),
1311        )
1312        .await
1313        {
1314            Ok(mut new_tokens) => {
1315                if new_tokens.refresh_token.is_none() {
1316                    new_tokens.refresh_token = Some(refresh_token);
1317                }
1318
1319                {
1320                    let mut session = self.session.lock();
1321                    session.tokens = new_tokens;
1322
1323                    if let Some(ref tx) = self.token_refresh_tx {
1324                        tx.unbounded_send(session.clone()).ok();
1325                    }
1326                }
1327
1328                Ok(true)
1329            }
1330            Err(err) => {
1331                log::warn!("OAuth token refresh failed: {}", err);
1332                Ok(false)
1333            }
1334        }
1335    }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340    use super::*;
1341    use http_client::Response;
1342
1343    // -- require_https_or_loopback tests ------------------------------------
1344
1345    #[test]
1346    fn test_require_https_or_loopback_accepts_https() {
1347        let url = Url::parse("https://auth.example.com/token").unwrap();
1348        assert!(require_https_or_loopback(&url).is_ok());
1349    }
1350
1351    #[test]
1352    fn test_require_https_or_loopback_rejects_http_remote() {
1353        let url = Url::parse("http://auth.example.com/token").unwrap();
1354        assert!(require_https_or_loopback(&url).is_err());
1355    }
1356
1357    #[test]
1358    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1359        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1360        assert!(require_https_or_loopback(&url).is_ok());
1361    }
1362
1363    #[test]
1364    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1365        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1366        assert!(require_https_or_loopback(&url).is_ok());
1367    }
1368
1369    #[test]
1370    fn test_require_https_or_loopback_accepts_http_localhost() {
1371        let url = Url::parse("http://localhost:8080/callback").unwrap();
1372        assert!(require_https_or_loopback(&url).is_ok());
1373    }
1374
1375    #[test]
1376    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1377        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1378        assert!(require_https_or_loopback(&url).is_ok());
1379    }
1380
1381    #[test]
1382    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1383        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1384        assert!(require_https_or_loopback(&url).is_err());
1385    }
1386
1387    #[test]
1388    fn test_require_https_or_loopback_rejects_ftp() {
1389        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1390        assert!(require_https_or_loopback(&url).is_err());
1391    }
1392
1393    // -- validate_oauth_url (SSRF) tests ------------------------------------
1394
1395    #[test]
1396    fn test_validate_oauth_url_accepts_https_public() {
1397        let url = Url::parse("https://auth.example.com/token").unwrap();
1398        assert!(validate_oauth_url(&url).is_ok());
1399    }
1400
1401    #[test]
1402    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1403        let url = Url::parse("https://10.0.0.1/token").unwrap();
1404        assert!(validate_oauth_url(&url).is_err());
1405    }
1406
1407    #[test]
1408    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1409        let url = Url::parse("https://172.16.0.1/token").unwrap();
1410        assert!(validate_oauth_url(&url).is_err());
1411    }
1412
1413    #[test]
1414    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1415        let url = Url::parse("https://192.168.1.1/token").unwrap();
1416        assert!(validate_oauth_url(&url).is_err());
1417    }
1418
1419    #[test]
1420    fn test_validate_oauth_url_rejects_link_local() {
1421        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1422        assert!(validate_oauth_url(&url).is_err());
1423    }
1424
1425    #[test]
1426    fn test_validate_oauth_url_rejects_ipv6_ula() {
1427        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1428        assert!(validate_oauth_url(&url).is_err());
1429    }
1430
1431    #[test]
1432    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1433        let url = Url::parse("https://[::]/token").unwrap();
1434        assert!(validate_oauth_url(&url).is_err());
1435    }
1436
1437    #[test]
1438    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1439        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1440        assert!(validate_oauth_url(&url).is_err());
1441    }
1442
1443    #[test]
1444    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1445        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1446        assert!(validate_oauth_url(&url).is_err());
1447    }
1448
1449    #[test]
1450    fn test_validate_oauth_url_allows_http_loopback() {
1451        // Loopback is permitted (it's our callback server).
1452        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1453        assert!(validate_oauth_url(&url).is_ok());
1454    }
1455
1456    #[test]
1457    fn test_validate_oauth_url_allows_https_public_ip() {
1458        let url = Url::parse("https://93.184.216.34/token").unwrap();
1459        assert!(validate_oauth_url(&url).is_ok());
1460    }
1461
1462    // -- parse_www_authenticate tests ----------------------------------------
1463
1464    #[test]
1465    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1466        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1467        let result = parse_www_authenticate(header).unwrap();
1468
1469        assert_eq!(
1470            result.resource_metadata.as_ref().map(|u| u.as_str()),
1471            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1472        );
1473        assert_eq!(
1474            result.scope,
1475            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1476        );
1477        assert_eq!(result.error, None);
1478    }
1479
1480    #[test]
1481    fn test_parse_www_authenticate_resource_metadata_only() {
1482        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1483        let result = parse_www_authenticate(header).unwrap();
1484
1485        assert_eq!(
1486            result.resource_metadata.as_ref().map(|u| u.as_str()),
1487            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1488        );
1489        assert_eq!(result.scope, None);
1490    }
1491
1492    #[test]
1493    fn test_parse_www_authenticate_bare_bearer() {
1494        let result = parse_www_authenticate("Bearer").unwrap();
1495        assert_eq!(result.resource_metadata, None);
1496        assert_eq!(result.scope, None);
1497    }
1498
1499    #[test]
1500    fn test_parse_www_authenticate_with_error() {
1501        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1502        let result = parse_www_authenticate(header).unwrap();
1503
1504        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1505        assert_eq!(
1506            result.error_description.as_deref(),
1507            Some("Additional file write permission required")
1508        );
1509        assert_eq!(
1510            result.scope,
1511            Some(vec!["files:read".to_string(), "files:write".to_string()])
1512        );
1513        assert!(result.resource_metadata.is_some());
1514    }
1515
1516    #[test]
1517    fn test_parse_www_authenticate_invalid_token_error() {
1518        let header =
1519            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1520        let result = parse_www_authenticate(header).unwrap();
1521        assert_eq!(result.error, Some(BearerError::InvalidToken));
1522    }
1523
1524    #[test]
1525    fn test_parse_www_authenticate_invalid_request_error() {
1526        let header = r#"Bearer error="invalid_request""#;
1527        let result = parse_www_authenticate(header).unwrap();
1528        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1529    }
1530
1531    #[test]
1532    fn test_parse_www_authenticate_unknown_error() {
1533        let header = r#"Bearer error="some_future_error""#;
1534        let result = parse_www_authenticate(header).unwrap();
1535        assert_eq!(result.error, Some(BearerError::Other));
1536    }
1537
1538    #[test]
1539    fn test_parse_www_authenticate_rejects_non_bearer() {
1540        let result = parse_www_authenticate("Basic realm=\"example\"");
1541        assert!(result.is_err());
1542    }
1543
1544    #[test]
1545    fn test_parse_www_authenticate_case_insensitive_scheme() {
1546        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1547        let result = parse_www_authenticate(header).unwrap();
1548        assert!(result.resource_metadata.is_some());
1549    }
1550
1551    #[test]
1552    fn test_parse_www_authenticate_multiline_style() {
1553        // Some servers emit the header spread across multiple lines joined by
1554        // whitespace, as shown in the spec examples.
1555        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1556        let result = parse_www_authenticate(header).unwrap();
1557        assert!(result.resource_metadata.is_some());
1558        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1559    }
1560
1561    #[test]
1562    fn test_protected_resource_metadata_urls_with_path() {
1563        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1564        let urls = protected_resource_metadata_urls(&server_url);
1565
1566        assert_eq!(urls.len(), 2);
1567        assert_eq!(
1568            urls[0].as_str(),
1569            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1570        );
1571        assert_eq!(
1572            urls[1].as_str(),
1573            "https://api.example.com/.well-known/oauth-protected-resource"
1574        );
1575    }
1576
1577    #[test]
1578    fn test_protected_resource_metadata_urls_without_path() {
1579        let server_url = Url::parse("https://mcp.example.com").unwrap();
1580        let urls = protected_resource_metadata_urls(&server_url);
1581
1582        assert_eq!(urls.len(), 1);
1583        assert_eq!(
1584            urls[0].as_str(),
1585            "https://mcp.example.com/.well-known/oauth-protected-resource"
1586        );
1587    }
1588
1589    #[test]
1590    fn test_auth_server_metadata_urls_with_path() {
1591        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1592        let urls = auth_server_metadata_urls(&issuer);
1593
1594        assert_eq!(urls.len(), 3);
1595        assert_eq!(
1596            urls[0].as_str(),
1597            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1598        );
1599        assert_eq!(
1600            urls[1].as_str(),
1601            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1602        );
1603        assert_eq!(
1604            urls[2].as_str(),
1605            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1606        );
1607    }
1608
1609    #[test]
1610    fn test_auth_server_metadata_urls_without_path() {
1611        let issuer = Url::parse("https://auth.example.com").unwrap();
1612        let urls = auth_server_metadata_urls(&issuer);
1613
1614        assert_eq!(urls.len(), 2);
1615        assert_eq!(
1616            urls[0].as_str(),
1617            "https://auth.example.com/.well-known/oauth-authorization-server"
1618        );
1619        assert_eq!(
1620            urls[1].as_str(),
1621            "https://auth.example.com/.well-known/openid-configuration"
1622        );
1623    }
1624
1625    // -- Canonical server URI tests ------------------------------------------
1626
1627    #[test]
1628    fn test_canonical_server_uri_simple() {
1629        let url = Url::parse("https://mcp.example.com").unwrap();
1630        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1631    }
1632
1633    #[test]
1634    fn test_canonical_server_uri_with_path() {
1635        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1636        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1637    }
1638
1639    #[test]
1640    fn test_canonical_server_uri_strips_trailing_slash() {
1641        let url = Url::parse("https://mcp.example.com/").unwrap();
1642        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1643    }
1644
1645    #[test]
1646    fn test_canonical_server_uri_preserves_port() {
1647        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1648        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1649    }
1650
1651    #[test]
1652    fn test_canonical_server_uri_lowercases() {
1653        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1654        assert_eq!(
1655            canonical_server_uri(&url),
1656            "https://mcp.example.com/Server/MCP"
1657        );
1658    }
1659
1660    // -- Scope selection tests -----------------------------------------------
1661
1662    #[test]
1663    fn test_select_scopes_prefers_www_authenticate() {
1664        let www_auth = WwwAuthenticate {
1665            resource_metadata: None,
1666            scope: Some(vec!["files:read".into()]),
1667            error: None,
1668            error_description: None,
1669        };
1670        let resource_meta = ProtectedResourceMetadata {
1671            resource: Url::parse("https://example.com").unwrap(),
1672            authorization_servers: vec![],
1673            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1674        };
1675        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1676    }
1677
1678    #[test]
1679    fn test_select_scopes_falls_back_to_resource_metadata() {
1680        let www_auth = WwwAuthenticate {
1681            resource_metadata: None,
1682            scope: None,
1683            error: None,
1684            error_description: None,
1685        };
1686        let resource_meta = ProtectedResourceMetadata {
1687            resource: Url::parse("https://example.com").unwrap(),
1688            authorization_servers: vec![],
1689            scopes_supported: Some(vec!["admin".into()]),
1690        };
1691        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1692    }
1693
1694    #[test]
1695    fn test_select_scopes_empty_when_nothing_available() {
1696        let www_auth = WwwAuthenticate {
1697            resource_metadata: None,
1698            scope: None,
1699            error: None,
1700            error_description: None,
1701        };
1702        let resource_meta = ProtectedResourceMetadata {
1703            resource: Url::parse("https://example.com").unwrap(),
1704            authorization_servers: vec![],
1705            scopes_supported: None,
1706        };
1707        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1708    }
1709
1710    // -- Client registration strategy tests ----------------------------------
1711
1712    #[test]
1713    fn test_registration_strategy_prefers_cimd() {
1714        let metadata = AuthServerMetadata {
1715            issuer: Url::parse("https://auth.example.com").unwrap(),
1716            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1717            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1718            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1719            scopes_supported: None,
1720            code_challenge_methods_supported: Some(vec!["S256".into()]),
1721            client_id_metadata_document_supported: true,
1722        };
1723        assert_eq!(
1724            determine_registration_strategy(&metadata),
1725            ClientRegistrationStrategy::Cimd {
1726                client_id: CIMD_URL.to_string(),
1727            }
1728        );
1729    }
1730
1731    #[test]
1732    fn test_registration_strategy_falls_back_to_dcr() {
1733        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1734        let metadata = AuthServerMetadata {
1735            issuer: Url::parse("https://auth.example.com").unwrap(),
1736            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1737            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1738            registration_endpoint: Some(reg_endpoint.clone()),
1739            scopes_supported: None,
1740            code_challenge_methods_supported: Some(vec!["S256".into()]),
1741            client_id_metadata_document_supported: false,
1742        };
1743        assert_eq!(
1744            determine_registration_strategy(&metadata),
1745            ClientRegistrationStrategy::Dcr {
1746                registration_endpoint: reg_endpoint,
1747            }
1748        );
1749    }
1750
1751    #[test]
1752    fn test_registration_strategy_unavailable() {
1753        let metadata = AuthServerMetadata {
1754            issuer: Url::parse("https://auth.example.com").unwrap(),
1755            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1756            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1757            registration_endpoint: None,
1758            scopes_supported: None,
1759            code_challenge_methods_supported: Some(vec!["S256".into()]),
1760            client_id_metadata_document_supported: false,
1761        };
1762        assert_eq!(
1763            determine_registration_strategy(&metadata),
1764            ClientRegistrationStrategy::Unavailable,
1765        );
1766    }
1767
1768    // -- PKCE tests ----------------------------------------------------------
1769
1770    #[test]
1771    fn test_pkce_challenge_verifier_length() {
1772        let pkce = generate_pkce_challenge();
1773        // 32 random bytes → 43 base64url chars (no padding).
1774        assert_eq!(pkce.verifier.len(), 43);
1775    }
1776
1777    #[test]
1778    fn test_pkce_challenge_is_valid_base64url() {
1779        let pkce = generate_pkce_challenge();
1780        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1781            assert!(
1782                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1783                "invalid base64url character: {}",
1784                c
1785            );
1786        }
1787    }
1788
1789    #[test]
1790    fn test_pkce_challenge_is_s256_of_verifier() {
1791        let pkce = generate_pkce_challenge();
1792        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1793        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1794        let expected_challenge = engine.encode(expected_digest);
1795        assert_eq!(pkce.challenge, expected_challenge);
1796    }
1797
1798    #[test]
1799    fn test_pkce_challenges_are_unique() {
1800        let a = generate_pkce_challenge();
1801        let b = generate_pkce_challenge();
1802        assert_ne!(a.verifier, b.verifier);
1803    }
1804
1805    // -- Authorization URL tests ---------------------------------------------
1806
1807    #[test]
1808    fn test_build_authorization_url() {
1809        let metadata = AuthServerMetadata {
1810            issuer: Url::parse("https://auth.example.com").unwrap(),
1811            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1812            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1813            registration_endpoint: None,
1814            scopes_supported: None,
1815            code_challenge_methods_supported: Some(vec!["S256".into()]),
1816            client_id_metadata_document_supported: true,
1817        };
1818        let pkce = PkceChallenge {
1819            verifier: "test_verifier".into(),
1820            challenge: "test_challenge".into(),
1821        };
1822        let url = build_authorization_url(
1823            &metadata,
1824            "https://zed.dev/oauth/client-metadata.json",
1825            "http://127.0.0.1:12345/callback",
1826            &["files:read".into(), "files:write".into()],
1827            "https://mcp.example.com",
1828            &pkce,
1829            "random_state_123",
1830        );
1831
1832        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1833        assert_eq!(pairs.get("response_type").unwrap(), "code");
1834        assert_eq!(
1835            pairs.get("client_id").unwrap(),
1836            "https://zed.dev/oauth/client-metadata.json"
1837        );
1838        assert_eq!(
1839            pairs.get("redirect_uri").unwrap(),
1840            "http://127.0.0.1:12345/callback"
1841        );
1842        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1843        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1844        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1845        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1846        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1847    }
1848
1849    #[test]
1850    fn test_build_authorization_url_omits_empty_scope() {
1851        let metadata = AuthServerMetadata {
1852            issuer: Url::parse("https://auth.example.com").unwrap(),
1853            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1854            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1855            registration_endpoint: None,
1856            scopes_supported: None,
1857            code_challenge_methods_supported: Some(vec!["S256".into()]),
1858            client_id_metadata_document_supported: false,
1859        };
1860        let pkce = PkceChallenge {
1861            verifier: "v".into(),
1862            challenge: "c".into(),
1863        };
1864        let url = build_authorization_url(
1865            &metadata,
1866            "client_123",
1867            "http://127.0.0.1:9999/callback",
1868            &[],
1869            "https://mcp.example.com",
1870            &pkce,
1871            "state",
1872        );
1873
1874        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1875        assert!(!pairs.contains_key("scope"));
1876    }
1877
1878    // -- Token exchange / refresh param tests --------------------------------
1879
1880    #[test]
1881    fn test_token_exchange_params() {
1882        let params = token_exchange_params(
1883            "auth_code_abc",
1884            "client_xyz",
1885            "http://127.0.0.1:5555/callback",
1886            "verifier_123",
1887            "https://mcp.example.com",
1888            None,
1889        );
1890        let map: std::collections::HashMap<&str, &str> =
1891            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1892
1893        assert_eq!(map["grant_type"], "authorization_code");
1894        assert_eq!(map["code"], "auth_code_abc");
1895        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1896        assert_eq!(map["client_id"], "client_xyz");
1897        assert_eq!(map["code_verifier"], "verifier_123");
1898        assert_eq!(map["resource"], "https://mcp.example.com");
1899    }
1900
1901    #[test]
1902    fn test_token_refresh_params() {
1903        let params = token_refresh_params(
1904            "refresh_token_abc",
1905            "client_xyz",
1906            "https://mcp.example.com",
1907            None,
1908        );
1909        let map: std::collections::HashMap<&str, &str> =
1910            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1911
1912        assert_eq!(map["grant_type"], "refresh_token");
1913        assert_eq!(map["refresh_token"], "refresh_token_abc");
1914        assert_eq!(map["client_id"], "client_xyz");
1915        assert_eq!(map["resource"], "https://mcp.example.com");
1916    }
1917
1918    // -- Token response tests ------------------------------------------------
1919
1920    #[test]
1921    fn test_token_response_into_tokens_with_expiry() {
1922        let response: TokenResponse = serde_json::from_str(
1923            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1924        )
1925        .unwrap();
1926
1927        let tokens = response.into_tokens();
1928        assert_eq!(tokens.access_token, "at_123");
1929        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1930        assert!(tokens.expires_at.is_some());
1931    }
1932
1933    #[test]
1934    fn test_token_response_into_tokens_minimal() {
1935        let response: TokenResponse =
1936            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1937
1938        let tokens = response.into_tokens();
1939        assert_eq!(tokens.access_token, "at_789");
1940        assert_eq!(tokens.refresh_token, None);
1941        assert_eq!(tokens.expires_at, None);
1942    }
1943
1944    // -- DCR body test -------------------------------------------------------
1945
1946    #[test]
1947    fn test_dcr_registration_body_shape() {
1948        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1949        assert_eq!(body["client_name"], "Zed");
1950        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1951        assert_eq!(body["grant_types"][0], "authorization_code");
1952        assert_eq!(body["response_types"][0], "code");
1953        assert_eq!(body["token_endpoint_auth_method"], "none");
1954    }
1955
1956    // -- Test helpers for async/HTTP tests -----------------------------------
1957
1958    fn make_fake_http_client(
1959        handler: impl Fn(
1960            http_client::Request<AsyncBody>,
1961        ) -> std::pin::Pin<
1962            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1963        > + Send
1964        + Sync
1965        + 'static,
1966    ) -> Arc<dyn HttpClient> {
1967        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1968    }
1969
1970    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1971        Ok(Response::builder()
1972            .status(status)
1973            .header("Content-Type", "application/json")
1974            .body(AsyncBody::from(body.as_bytes().to_vec()))
1975            .unwrap())
1976    }
1977
1978    // -- Discovery integration tests -----------------------------------------
1979
1980    #[test]
1981    fn test_fetch_protected_resource_metadata() {
1982        smol::block_on(async {
1983            let client = make_fake_http_client(|req| {
1984                Box::pin(async move {
1985                    let uri = req.uri().to_string();
1986                    if uri.contains(".well-known/oauth-protected-resource") {
1987                        json_response(
1988                            200,
1989                            r#"{
1990                                "resource": "https://mcp.example.com",
1991                                "authorization_servers": ["https://auth.example.com"],
1992                                "scopes_supported": ["read", "write"]
1993                            }"#,
1994                        )
1995                    } else {
1996                        json_response(404, "{}")
1997                    }
1998                })
1999            });
2000
2001            let server_url = Url::parse("https://mcp.example.com").unwrap();
2002            let www_auth = WwwAuthenticate {
2003                resource_metadata: None,
2004                scope: None,
2005                error: None,
2006                error_description: None,
2007            };
2008
2009            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2010                .await
2011                .unwrap();
2012
2013            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2014            assert_eq!(metadata.authorization_servers.len(), 1);
2015            assert_eq!(
2016                metadata.authorization_servers[0].as_str(),
2017                "https://auth.example.com/"
2018            );
2019            assert_eq!(
2020                metadata.scopes_supported,
2021                Some(vec!["read".to_string(), "write".to_string()])
2022            );
2023        });
2024    }
2025
2026    #[test]
2027    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2028        smol::block_on(async {
2029            let client = make_fake_http_client(|req| {
2030                Box::pin(async move {
2031                    let uri = req.uri().to_string();
2032                    if uri == "https://mcp.example.com/custom-resource-metadata" {
2033                        json_response(
2034                            200,
2035                            r#"{
2036                                "resource": "https://mcp.example.com",
2037                                "authorization_servers": ["https://auth.example.com"]
2038                            }"#,
2039                        )
2040                    } else {
2041                        json_response(500, r#"{"error": "should not be called"}"#)
2042                    }
2043                })
2044            });
2045
2046            let server_url = Url::parse("https://mcp.example.com").unwrap();
2047            let www_auth = WwwAuthenticate {
2048                resource_metadata: Some(
2049                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2050                ),
2051                scope: None,
2052                error: None,
2053                error_description: None,
2054            };
2055
2056            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2057                .await
2058                .unwrap();
2059
2060            assert_eq!(metadata.authorization_servers.len(), 1);
2061        });
2062    }
2063
2064    #[test]
2065    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2066        smol::block_on(async {
2067            let client = make_fake_http_client(|req| {
2068                Box::pin(async move {
2069                    let uri = req.uri().to_string();
2070                    // The cross-origin URL should NOT be fetched; only the
2071                    // well-known fallback at the server's own origin should be.
2072                    if uri.contains("attacker.example.com") {
2073                        panic!("should not fetch cross-origin resource_metadata URL");
2074                    } else if uri.contains(".well-known/oauth-protected-resource") {
2075                        json_response(
2076                            200,
2077                            r#"{
2078                                "resource": "https://mcp.example.com",
2079                                "authorization_servers": ["https://auth.example.com"]
2080                            }"#,
2081                        )
2082                    } else {
2083                        json_response(404, "{}")
2084                    }
2085                })
2086            });
2087
2088            let server_url = Url::parse("https://mcp.example.com").unwrap();
2089            let www_auth = WwwAuthenticate {
2090                resource_metadata: Some(
2091                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2092                ),
2093                scope: None,
2094                error: None,
2095                error_description: None,
2096            };
2097
2098            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2099                .await
2100                .unwrap();
2101
2102            // Should have used the fallback well-known URL, not the attacker's.
2103            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2104        });
2105    }
2106
2107    #[test]
2108    fn test_fetch_auth_server_metadata() {
2109        smol::block_on(async {
2110            let client = make_fake_http_client(|req| {
2111                Box::pin(async move {
2112                    let uri = req.uri().to_string();
2113                    if uri.contains(".well-known/oauth-authorization-server") {
2114                        json_response(
2115                            200,
2116                            r#"{
2117                                "issuer": "https://auth.example.com",
2118                                "authorization_endpoint": "https://auth.example.com/authorize",
2119                                "token_endpoint": "https://auth.example.com/token",
2120                                "registration_endpoint": "https://auth.example.com/register",
2121                                "code_challenge_methods_supported": ["S256"],
2122                                "client_id_metadata_document_supported": true
2123                            }"#,
2124                        )
2125                    } else {
2126                        json_response(404, "{}")
2127                    }
2128                })
2129            });
2130
2131            let issuer = Url::parse("https://auth.example.com").unwrap();
2132            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2133
2134            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2135            assert_eq!(
2136                metadata.authorization_endpoint.as_str(),
2137                "https://auth.example.com/authorize"
2138            );
2139            assert_eq!(
2140                metadata.token_endpoint.as_str(),
2141                "https://auth.example.com/token"
2142            );
2143            assert!(metadata.registration_endpoint.is_some());
2144            assert!(metadata.client_id_metadata_document_supported);
2145            assert_eq!(
2146                metadata.code_challenge_methods_supported,
2147                Some(vec!["S256".to_string()])
2148            );
2149        });
2150    }
2151
2152    #[test]
2153    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2154        smol::block_on(async {
2155            let client = make_fake_http_client(|req| {
2156                Box::pin(async move {
2157                    let uri = req.uri().to_string();
2158                    if uri.contains("openid-configuration") {
2159                        json_response(
2160                            200,
2161                            r#"{
2162                                "issuer": "https://auth.example.com",
2163                                "authorization_endpoint": "https://auth.example.com/authorize",
2164                                "token_endpoint": "https://auth.example.com/token",
2165                                "code_challenge_methods_supported": ["S256"]
2166                            }"#,
2167                        )
2168                    } else {
2169                        json_response(404, "{}")
2170                    }
2171                })
2172            });
2173
2174            let issuer = Url::parse("https://auth.example.com").unwrap();
2175            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2176
2177            assert_eq!(
2178                metadata.authorization_endpoint.as_str(),
2179                "https://auth.example.com/authorize"
2180            );
2181            assert!(!metadata.client_id_metadata_document_supported);
2182        });
2183    }
2184
2185    #[test]
2186    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2187        smol::block_on(async {
2188            let client = make_fake_http_client(|req| {
2189                Box::pin(async move {
2190                    let uri = req.uri().to_string();
2191                    if uri.contains(".well-known/oauth-authorization-server") {
2192                        // Response claims to be a different issuer.
2193                        json_response(
2194                            200,
2195                            r#"{
2196                                "issuer": "https://evil.example.com",
2197                                "authorization_endpoint": "https://evil.example.com/authorize",
2198                                "token_endpoint": "https://evil.example.com/token",
2199                                "code_challenge_methods_supported": ["S256"]
2200                            }"#,
2201                        )
2202                    } else {
2203                        json_response(404, "{}")
2204                    }
2205                })
2206            });
2207
2208            let issuer = Url::parse("https://auth.example.com").unwrap();
2209            let result = fetch_auth_server_metadata(&client, &issuer).await;
2210
2211            assert!(result.is_err());
2212            let err_msg = result.unwrap_err().to_string();
2213            assert!(
2214                err_msg.contains("issuer mismatch"),
2215                "unexpected error: {}",
2216                err_msg
2217            );
2218        });
2219    }
2220
2221    // -- Full discover integration tests -------------------------------------
2222
2223    #[test]
2224    fn test_full_discover_with_cimd() {
2225        smol::block_on(async {
2226            let client = make_fake_http_client(|req| {
2227                Box::pin(async move {
2228                    let uri = req.uri().to_string();
2229                    if uri.contains("oauth-protected-resource") {
2230                        json_response(
2231                            200,
2232                            r#"{
2233                                "resource": "https://mcp.example.com",
2234                                "authorization_servers": ["https://auth.example.com"],
2235                                "scopes_supported": ["mcp:read"]
2236                            }"#,
2237                        )
2238                    } else if uri.contains("oauth-authorization-server") {
2239                        json_response(
2240                            200,
2241                            r#"{
2242                                "issuer": "https://auth.example.com",
2243                                "authorization_endpoint": "https://auth.example.com/authorize",
2244                                "token_endpoint": "https://auth.example.com/token",
2245                                "code_challenge_methods_supported": ["S256"],
2246                                "client_id_metadata_document_supported": true
2247                            }"#,
2248                        )
2249                    } else {
2250                        json_response(404, "{}")
2251                    }
2252                })
2253            });
2254
2255            let server_url = Url::parse("https://mcp.example.com").unwrap();
2256            let www_auth = WwwAuthenticate {
2257                resource_metadata: None,
2258                scope: None,
2259                error: None,
2260                error_description: None,
2261            };
2262
2263            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2264            let registration =
2265                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2266                    .await
2267                    .unwrap();
2268
2269            assert_eq!(registration.client_id, CIMD_URL);
2270            assert_eq!(registration.client_secret, None);
2271            assert_eq!(discovery.scopes, vec!["mcp:read"]);
2272        });
2273    }
2274
2275    #[test]
2276    fn test_full_discover_with_dcr_fallback() {
2277        smol::block_on(async {
2278            let client = make_fake_http_client(|req| {
2279                Box::pin(async move {
2280                    let uri = req.uri().to_string();
2281                    if uri.contains("oauth-protected-resource") {
2282                        json_response(
2283                            200,
2284                            r#"{
2285                                "resource": "https://mcp.example.com",
2286                                "authorization_servers": ["https://auth.example.com"]
2287                            }"#,
2288                        )
2289                    } else if uri.contains("oauth-authorization-server") {
2290                        json_response(
2291                            200,
2292                            r#"{
2293                                "issuer": "https://auth.example.com",
2294                                "authorization_endpoint": "https://auth.example.com/authorize",
2295                                "token_endpoint": "https://auth.example.com/token",
2296                                "registration_endpoint": "https://auth.example.com/register",
2297                                "code_challenge_methods_supported": ["S256"],
2298                                "client_id_metadata_document_supported": false
2299                            }"#,
2300                        )
2301                    } else if uri.contains("/register") {
2302                        json_response(
2303                            201,
2304                            r#"{
2305                                "client_id": "dcr-minted-id-123",
2306                                "client_secret": "dcr-secret-456"
2307                            }"#,
2308                        )
2309                    } else {
2310                        json_response(404, "{}")
2311                    }
2312                })
2313            });
2314
2315            let server_url = Url::parse("https://mcp.example.com").unwrap();
2316            let www_auth = WwwAuthenticate {
2317                resource_metadata: None,
2318                scope: Some(vec!["files:read".into()]),
2319                error: None,
2320                error_description: None,
2321            };
2322
2323            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2324            let registration =
2325                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2326                    .await
2327                    .unwrap();
2328
2329            assert_eq!(registration.client_id, "dcr-minted-id-123");
2330            assert_eq!(
2331                registration.client_secret.as_deref(),
2332                Some("dcr-secret-456")
2333            );
2334            assert_eq!(discovery.scopes, vec!["files:read"]);
2335        });
2336    }
2337
2338    #[test]
2339    fn test_discover_fails_without_pkce_support() {
2340        smol::block_on(async {
2341            let client = make_fake_http_client(|req| {
2342                Box::pin(async move {
2343                    let uri = req.uri().to_string();
2344                    if uri.contains("oauth-protected-resource") {
2345                        json_response(
2346                            200,
2347                            r#"{
2348                                "resource": "https://mcp.example.com",
2349                                "authorization_servers": ["https://auth.example.com"]
2350                            }"#,
2351                        )
2352                    } else if uri.contains("oauth-authorization-server") {
2353                        json_response(
2354                            200,
2355                            r#"{
2356                                "issuer": "https://auth.example.com",
2357                                "authorization_endpoint": "https://auth.example.com/authorize",
2358                                "token_endpoint": "https://auth.example.com/token"
2359                            }"#,
2360                        )
2361                    } else {
2362                        json_response(404, "{}")
2363                    }
2364                })
2365            });
2366
2367            let server_url = Url::parse("https://mcp.example.com").unwrap();
2368            let www_auth = WwwAuthenticate {
2369                resource_metadata: None,
2370                scope: None,
2371                error: None,
2372                error_description: None,
2373            };
2374
2375            let result = discover(&client, &server_url, &www_auth).await;
2376            assert!(result.is_err());
2377            let err_msg = result.unwrap_err().to_string();
2378            assert!(
2379                err_msg.contains("code_challenge_methods_supported"),
2380                "unexpected error: {}",
2381                err_msg
2382            );
2383        });
2384    }
2385
2386    // -- Token exchange integration tests ------------------------------------
2387
2388    #[test]
2389    fn test_exchange_code_success() {
2390        smol::block_on(async {
2391            let client = make_fake_http_client(|req| {
2392                Box::pin(async move {
2393                    let uri = req.uri().to_string();
2394                    if uri.contains("/token") {
2395                        json_response(
2396                            200,
2397                            r#"{
2398                                "access_token": "new_access_token",
2399                                "refresh_token": "new_refresh_token",
2400                                "expires_in": 3600,
2401                                "token_type": "Bearer"
2402                            }"#,
2403                        )
2404                    } else {
2405                        json_response(404, "{}")
2406                    }
2407                })
2408            });
2409
2410            let metadata = AuthServerMetadata {
2411                issuer: Url::parse("https://auth.example.com").unwrap(),
2412                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2413                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2414                registration_endpoint: None,
2415                scopes_supported: None,
2416                code_challenge_methods_supported: Some(vec!["S256".into()]),
2417                client_id_metadata_document_supported: true,
2418            };
2419
2420            let tokens = exchange_code(
2421                &client,
2422                &metadata,
2423                "auth_code_123",
2424                CIMD_URL,
2425                "http://127.0.0.1:9999/callback",
2426                "verifier_abc",
2427                "https://mcp.example.com",
2428                None,
2429            )
2430            .await
2431            .unwrap();
2432
2433            assert_eq!(tokens.access_token, "new_access_token");
2434            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2435            assert!(tokens.expires_at.is_some());
2436        });
2437    }
2438
2439    #[test]
2440    fn test_refresh_tokens_success() {
2441        smol::block_on(async {
2442            let client = make_fake_http_client(|req| {
2443                Box::pin(async move {
2444                    let uri = req.uri().to_string();
2445                    if uri.contains("/token") {
2446                        json_response(
2447                            200,
2448                            r#"{
2449                                "access_token": "refreshed_token",
2450                                "expires_in": 1800,
2451                                "token_type": "Bearer"
2452                            }"#,
2453                        )
2454                    } else {
2455                        json_response(404, "{}")
2456                    }
2457                })
2458            });
2459
2460            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2461
2462            let tokens = refresh_tokens(
2463                &client,
2464                &token_endpoint,
2465                "old_refresh_token",
2466                CIMD_URL,
2467                "https://mcp.example.com",
2468                None,
2469            )
2470            .await
2471            .unwrap();
2472
2473            assert_eq!(tokens.access_token, "refreshed_token");
2474            assert_eq!(tokens.refresh_token, None);
2475            assert!(tokens.expires_at.is_some());
2476        });
2477    }
2478
2479    #[test]
2480    fn test_exchange_code_failure() {
2481        smol::block_on(async {
2482            let client = make_fake_http_client(|_req| {
2483                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2484            });
2485
2486            let metadata = AuthServerMetadata {
2487                issuer: Url::parse("https://auth.example.com").unwrap(),
2488                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2489                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2490                registration_endpoint: None,
2491                scopes_supported: None,
2492                code_challenge_methods_supported: Some(vec!["S256".into()]),
2493                client_id_metadata_document_supported: true,
2494            };
2495
2496            let result = exchange_code(
2497                &client,
2498                &metadata,
2499                "bad_code",
2500                "client",
2501                "http://127.0.0.1:1/callback",
2502                "verifier",
2503                "https://mcp.example.com",
2504                None,
2505            )
2506            .await;
2507
2508            assert!(result.is_err());
2509            assert!(result.unwrap_err().to_string().contains("400"));
2510        });
2511    }
2512
2513    // -- DCR integration tests -----------------------------------------------
2514
2515    #[test]
2516    fn test_perform_dcr() {
2517        smol::block_on(async {
2518            let client = make_fake_http_client(|_req| {
2519                Box::pin(async move {
2520                    json_response(
2521                        201,
2522                        r#"{
2523                            "client_id": "dynamic-client-001",
2524                            "client_secret": "dynamic-secret-001"
2525                        }"#,
2526                    )
2527                })
2528            });
2529
2530            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2531            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2532                .await
2533                .unwrap();
2534
2535            assert_eq!(registration.client_id, "dynamic-client-001");
2536            assert_eq!(
2537                registration.client_secret.as_deref(),
2538                Some("dynamic-secret-001")
2539            );
2540        });
2541    }
2542
2543    #[test]
2544    fn test_perform_dcr_failure() {
2545        smol::block_on(async {
2546            let client = make_fake_http_client(|_req| {
2547                Box::pin(
2548                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2549                )
2550            });
2551
2552            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2553            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2554
2555            assert!(result.is_err());
2556            assert!(result.unwrap_err().to_string().contains("403"));
2557        });
2558    }
2559
2560    // -- OAuthCallback parse tests -------------------------------------------
2561
2562    #[test]
2563    fn test_oauth_callback_parse_query() {
2564        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2565        assert_eq!(callback.code, "test_auth_code");
2566        assert_eq!(callback.state, "test_state");
2567    }
2568
2569    #[test]
2570    fn test_oauth_callback_parse_query_reversed_order() {
2571        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2572        assert_eq!(callback.code, "test_auth_code");
2573        assert_eq!(callback.state, "test_state");
2574    }
2575
2576    #[test]
2577    fn test_oauth_callback_parse_query_with_extra_params() {
2578        let callback =
2579            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2580                .unwrap();
2581        assert_eq!(callback.code, "test_auth_code");
2582        assert_eq!(callback.state, "test_state");
2583    }
2584
2585    #[test]
2586    fn test_oauth_callback_parse_query_missing_code() {
2587        let result = OAuthCallback::parse_query("state=test_state");
2588        assert!(result.is_err());
2589        assert!(result.unwrap_err().to_string().contains("code"));
2590    }
2591
2592    #[test]
2593    fn test_oauth_callback_parse_query_missing_state() {
2594        let result = OAuthCallback::parse_query("code=test_auth_code");
2595        assert!(result.is_err());
2596        assert!(result.unwrap_err().to_string().contains("state"));
2597    }
2598
2599    #[test]
2600    fn test_oauth_callback_parse_query_empty_code() {
2601        let result = OAuthCallback::parse_query("code=&state=test_state");
2602        assert!(result.is_err());
2603    }
2604
2605    #[test]
2606    fn test_oauth_callback_parse_query_empty_state() {
2607        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2608        assert!(result.is_err());
2609    }
2610
2611    #[test]
2612    fn test_oauth_callback_parse_query_url_encoded_values() {
2613        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2614        assert_eq!(callback.code, "abc def");
2615        assert_eq!(callback.state, "test=state");
2616    }
2617
2618    #[test]
2619    fn test_oauth_callback_parse_query_error_response() {
2620        let result = OAuthCallback::parse_query(
2621            "error=access_denied&error_description=User%20denied%20access&state=abc",
2622        );
2623        assert!(result.is_err());
2624        let err_msg = result.unwrap_err().to_string();
2625        assert!(
2626            err_msg.contains("access_denied"),
2627            "unexpected error: {}",
2628            err_msg
2629        );
2630        assert!(
2631            err_msg.contains("User denied access"),
2632            "unexpected error: {}",
2633            err_msg
2634        );
2635    }
2636
2637    #[test]
2638    fn test_oauth_callback_parse_query_error_without_description() {
2639        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2640        assert!(result.is_err());
2641        let err_msg = result.unwrap_err().to_string();
2642        assert!(
2643            err_msg.contains("server_error"),
2644            "unexpected error: {}",
2645            err_msg
2646        );
2647        assert!(
2648            err_msg.contains("no description"),
2649            "unexpected error: {}",
2650            err_msg
2651        );
2652    }
2653
2654    // -- McpOAuthTokenProvider tests -----------------------------------------
2655
2656    fn make_test_session(
2657        access_token: &str,
2658        refresh_token: Option<&str>,
2659        expires_at: Option<SystemTime>,
2660    ) -> OAuthSession {
2661        OAuthSession {
2662            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2663            resource: Url::parse("https://mcp.example.com").unwrap(),
2664            client_registration: OAuthClientRegistration {
2665                client_id: "test-client".into(),
2666                client_secret: None,
2667            },
2668            tokens: OAuthTokens {
2669                access_token: access_token.into(),
2670                refresh_token: refresh_token.map(String::from),
2671                expires_at,
2672            },
2673        }
2674    }
2675
2676    #[test]
2677    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2678        let expired = SystemTime::now() - Duration::from_secs(60);
2679        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2680        let provider = McpOAuthTokenProvider::new(
2681            session,
2682            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2683            None,
2684        );
2685
2686        assert_eq!(provider.access_token(), None);
2687    }
2688
2689    #[test]
2690    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2691        let far_future = SystemTime::now() + Duration::from_secs(3600);
2692        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2693        let provider = McpOAuthTokenProvider::new(
2694            session,
2695            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2696            None,
2697        );
2698
2699        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2700    }
2701
2702    #[test]
2703    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2704        let session = make_test_session("no-expiry-token", Some("rt"), None);
2705        let provider = McpOAuthTokenProvider::new(
2706            session,
2707            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2708            None,
2709        );
2710
2711        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2712    }
2713
2714    #[test]
2715    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2716        smol::block_on(async {
2717            let session = make_test_session("token", None, None);
2718            let provider = McpOAuthTokenProvider::new(
2719                session,
2720                make_fake_http_client(|_| {
2721                    Box::pin(async { unreachable!("no HTTP call expected") })
2722                }),
2723                None,
2724            );
2725
2726            let refreshed = provider.try_refresh().await.unwrap();
2727            assert!(!refreshed);
2728        });
2729    }
2730
2731    #[test]
2732    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2733        smol::block_on(async {
2734            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2735            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2736
2737            let http_client = make_fake_http_client(|_req| {
2738                Box::pin(async {
2739                    json_response(
2740                        200,
2741                        r#"{
2742                            "access_token": "new-access",
2743                            "refresh_token": "new-refresh",
2744                            "expires_in": 1800
2745                        }"#,
2746                    )
2747                })
2748            });
2749
2750            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2751
2752            let refreshed = provider.try_refresh().await.unwrap();
2753            assert!(refreshed);
2754            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2755
2756            let notified_session = rx.try_recv().expect("channel should have a session");
2757            assert_eq!(notified_session.tokens.access_token, "new-access");
2758            assert_eq!(
2759                notified_session.tokens.refresh_token.as_deref(),
2760                Some("new-refresh")
2761            );
2762        });
2763    }
2764
2765    #[test]
2766    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2767        smol::block_on(async {
2768            let session = make_test_session("old-access", Some("original-refresh"), None);
2769            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2770
2771            let http_client = make_fake_http_client(|_req| {
2772                Box::pin(async {
2773                    json_response(
2774                        200,
2775                        r#"{
2776                            "access_token": "new-access",
2777                            "expires_in": 900
2778                        }"#,
2779                    )
2780                })
2781            });
2782
2783            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2784
2785            let refreshed = provider.try_refresh().await.unwrap();
2786            assert!(refreshed);
2787
2788            let notified_session = rx.try_recv().expect("channel should have a session");
2789            assert_eq!(notified_session.tokens.access_token, "new-access");
2790            assert_eq!(
2791                notified_session.tokens.refresh_token.as_deref(),
2792                Some("original-refresh"),
2793            );
2794        });
2795    }
2796
2797    #[test]
2798    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2799        smol::block_on(async {
2800            let session = make_test_session("old-access", Some("my-refresh"), None);
2801
2802            let http_client = make_fake_http_client(|_req| {
2803                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2804            });
2805
2806            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2807
2808            let refreshed = provider.try_refresh().await.unwrap();
2809            assert!(!refreshed);
2810            // The old token should still be in place.
2811            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2812        });
2813    }
2814}