oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::sync::Arc;
  31use std::time::{Duration, SystemTime};
  32use url::Url;
  33
  34/// The CIMD URL where Zed's OAuth client metadata document is hosted.
  35pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  36
  37/// Validate that a URL is safe to use as an OAuth endpoint.
  38///
  39/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  40/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  41/// loopback addresses, per RFC 8252 Section 8.3.
  42fn require_https_or_loopback(url: &Url) -> Result<()> {
  43    if url.scheme() == "https" {
  44        return Ok(());
  45    }
  46    if url.scheme() == "http" {
  47        if let Some(host) = url.host() {
  48            match host {
  49                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  50                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  51                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  52                _ => {}
  53            }
  54        }
  55    }
  56    bail!(
  57        "OAuth endpoint must use HTTPS (got {}://{})",
  58        url.scheme(),
  59        url.host_str().unwrap_or("?")
  60    )
  61}
  62
  63/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  64/// protections against private/reserved IP ranges.
  65///
  66/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  67/// an attacker-controlled MCP server from directing Zed to fetch internal
  68/// network resources via metadata URLs.
  69///
  70/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  71/// blocked here — full mitigation requires resolver-level validation (e.g. a
  72/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  73fn validate_oauth_url(url: &Url) -> Result<()> {
  74    require_https_or_loopback(url)?;
  75
  76    if let Some(host) = url.host() {
  77        match host {
  78            url::Host::Ipv4(ip) => {
  79                // Loopback is already allowed by require_https_or_loopback.
  80                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  81                {
  82                    bail!(
  83                        "OAuth endpoint must not point to private/reserved IP: {}",
  84                        ip
  85                    );
  86                }
  87            }
  88            url::Host::Ipv6(ip) => {
  89                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  90                // could bypass the IPv4 checks above.
  91                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  92                    if mapped_v4.is_private()
  93                        || mapped_v4.is_link_local()
  94                        || mapped_v4.is_broadcast()
  95                        || mapped_v4.is_unspecified()
  96                    {
  97                        bail!(
  98                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
  99                            mapped_v4
 100                        );
 101                    }
 102                }
 103
 104                if ip.is_unspecified() || ip.is_multicast() {
 105                    bail!(
 106                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 107                        ip
 108                    );
 109                }
 110                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 111                // nightly-only, so check the prefix manually.
 112                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 113                    bail!(
 114                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 115                        ip
 116                    );
 117                }
 118            }
 119            url::Host::Domain(_) => {
 120                // Domain-based SSRF prevention requires resolver-level checks.
 121                // See known limitation in the doc comment above.
 122            }
 123        }
 124    }
 125
 126    Ok(())
 127}
 128
/// Protected Resource Metadata for an MCP server, per RFC 9728 (OAuth 2.0
/// Protected Resource Metadata).
///
/// Fetched from the document referenced by the `resource_metadata` parameter of
/// a `WWW-Authenticate` challenge, or from a well-known URI derived from the
/// server URL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource's identifier. Discovery falls back to the MCP
    /// server URL when the fetched document omits it.
    pub resource: Url,
    /// Issuer URLs of the authorization servers for this resource. Discovery
    /// requires this to be non-empty and uses only the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource. These are used as a fallback when
    /// the challenge does not name any scopes.
    pub scopes_supported: Option<Vec<String>>,
}
 137
/// Authorization Server Metadata fetched from the authorization server's
/// well-known endpoint, per RFC 8414 (OAuth 2.0 Authorization Server Metadata)
/// or OIDC Discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier. During discovery it must equal the issuer URL the
    /// metadata was requested for.
    pub issuer: Url,
    /// Endpoint the user's browser is sent to for the authorization request.
    pub authorization_endpoint: Url,
    /// Endpoint used for the code exchange and for token refresh.
    pub token_endpoint: Url,
    /// Dynamic Client Registration (RFC 7591) endpoint, if the server offers DCR.
    pub registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods the server accepts. Discovery requires this to be present
    /// and to include `S256`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// `client_id`. Treated as `false` when the metadata omits it.
    pub client_id_metadata_document_supported: bool,
}
 150
 151/// The result of client registration — either CIMD or DCR.
 152#[derive(Clone, Serialize, Deserialize)]
 153pub struct OAuthClientRegistration {
 154    pub client_id: String,
 155    /// Only present for DCR-minted registrations.
 156    pub client_secret: Option<String>,
 157}
 158
 159impl std::fmt::Debug for OAuthClientRegistration {
 160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 161        f.debug_struct("OAuthClientRegistration")
 162            .field("client_id", &self.client_id)
 163            .field(
 164                "client_secret",
 165                &self.client_secret.as_ref().map(|_| "[redacted]"),
 166            )
 167            .finish()
 168    }
 169}
 170
 171/// Access and refresh tokens obtained from the token endpoint.
 172#[derive(Clone, Serialize, Deserialize)]
 173pub struct OAuthTokens {
 174    pub access_token: String,
 175    pub refresh_token: Option<String>,
 176    pub expires_at: Option<SystemTime>,
 177}
 178
 179impl std::fmt::Debug for OAuthTokens {
 180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 181        f.debug_struct("OAuthTokens")
 182            .field("access_token", &"[redacted]")
 183            .field(
 184                "refresh_token",
 185                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 186            )
 187            .field("expires_at", &self.expires_at)
 188            .finish()
 189    }
 190}
 191
/// Everything learned during discovery, before the browser flow starts.
///
/// Client registration is not part of this: it is resolved separately by
/// [`resolve_client_registration`] once the real loopback redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched for the MCP server (RFC 9728).
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata for the first authorization server listed in
    /// `resource_metadata` (RFC 8414 / OIDC Discovery).
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes to request, as chosen by [`select_scopes`].
    pub scopes: Vec<String>,
}
 200
 201/// The persisted OAuth session for a context server.
 202///
 203/// Stored in the keychain so startup can restore a refresh-capable provider
 204/// without another browser flow. Deliberately excludes the full discovery
 205/// metadata to keep the serialized size well within keychain item limits.
 206#[derive(Clone, Serialize, Deserialize)]
 207pub struct OAuthSession {
 208    pub token_endpoint: Url,
 209    pub resource: Url,
 210    pub client_registration: OAuthClientRegistration,
 211    pub tokens: OAuthTokens,
 212}
 213
 214impl std::fmt::Debug for OAuthSession {
 215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 216        f.debug_struct("OAuthSession")
 217            .field("token_endpoint", &self.token_endpoint)
 218            .field("resource", &self.resource)
 219            .field("client_registration", &self.client_registration)
 220            .field("tokens", &self.tokens)
 221            .finish()
 222    }
 223}
 224
 225/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 226#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 227pub enum BearerError {
 228    /// The request is missing a required parameter, includes an unsupported
 229    /// parameter or parameter value, or is otherwise malformed.
 230    InvalidRequest,
 231    /// The access token provided is expired, revoked, malformed, or invalid.
 232    InvalidToken,
 233    /// The request requires higher privileges than provided by the access token.
 234    InsufficientScope,
 235    /// An unrecognized error code (extension or future spec addition).
 236    Other,
 237}
 238
 239impl BearerError {
 240    fn parse(value: &str) -> Self {
 241        match value {
 242            "invalid_request" => BearerError::InvalidRequest,
 243            "invalid_token" => BearerError::InvalidToken,
 244            "insufficient_scope" => BearerError::InsufficientScope,
 245            _ => BearerError::Other,
 246        }
 247    }
 248}
 249
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if the challenge
    /// supplied one.
    pub resource_metadata: Option<Url>,
    /// Scopes from the space-separated `scope` parameter, if present.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, with quoting and escapes removed.
    pub error_description: Option<String>,
}
 263
 264/// Parse a `WWW-Authenticate` header value.
 265///
 266/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 267/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 268/// - `resource_metadata` — URL of the Protected Resource Metadata document
 269/// - `scope` — space-separated list of required scopes
 270/// - `error` — error code (e.g. "insufficient_scope")
 271/// - `error_description` — human-readable error description
 272pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 273    let header = header.trim();
 274
 275    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 276        header[6..].trim()
 277    } else {
 278        bail!("WWW-Authenticate header does not use Bearer scheme");
 279    };
 280
 281    if params_str.is_empty() {
 282        return Ok(WwwAuthenticate {
 283            resource_metadata: None,
 284            scope: None,
 285            error: None,
 286            error_description: None,
 287        });
 288    }
 289
 290    let params = parse_auth_params(params_str);
 291
 292    let resource_metadata = params
 293        .get("resource_metadata")
 294        .map(|v| Url::parse(v))
 295        .transpose()
 296        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 297
 298    let scope = params
 299        .get("scope")
 300        .map(|v| v.split_whitespace().map(String::from).collect());
 301
 302    let error = params.get("error").map(|v| BearerError::parse(v));
 303    let error_description = params.get("error_description").cloned();
 304
 305    Ok(WwwAuthenticate {
 306        resource_metadata,
 307        scope,
 308        error,
 309        error_description,
 310    })
 311}
 312
 313/// Parse comma-separated `key="value"` or `key=token` parameters from an
 314/// auth-param list (RFC 7235 Section 2.1).
 315fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
 316    let mut params = collections::HashMap::default();
 317    let mut remaining = input.trim();
 318
 319    while !remaining.is_empty() {
 320        // Skip leading whitespace and commas.
 321        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
 322        if remaining.is_empty() {
 323            break;
 324        }
 325
 326        // Find the key (everything before '=').
 327        let eq_pos = match remaining.find('=') {
 328            Some(pos) => pos,
 329            None => break,
 330        };
 331
 332        let key = remaining[..eq_pos].trim().to_lowercase();
 333        remaining = &remaining[eq_pos + 1..];
 334        remaining = remaining.trim_start();
 335
 336        // Parse the value: either quoted or unquoted (token).
 337        let value;
 338        if remaining.starts_with('"') {
 339            // Quoted string: find the closing quote, handling escaped chars.
 340            remaining = &remaining[1..]; // skip opening quote
 341            let mut val = String::new();
 342            let mut chars = remaining.char_indices();
 343            loop {
 344                match chars.next() {
 345                    Some((_, '\\')) => {
 346                        // Escaped character — take the next char literally.
 347                        if let Some((_, c)) = chars.next() {
 348                            val.push(c);
 349                        }
 350                    }
 351                    Some((i, '"')) => {
 352                        remaining = &remaining[i + 1..];
 353                        break;
 354                    }
 355                    Some((_, c)) => val.push(c),
 356                    None => {
 357                        remaining = "";
 358                        break;
 359                    }
 360                }
 361            }
 362            value = val;
 363        } else {
 364            // Unquoted token: read until comma or whitespace.
 365            let end = remaining
 366                .find(|c: char| c == ',' || c.is_whitespace())
 367                .unwrap_or(remaining.len());
 368            value = remaining[..end].to_string();
 369            remaining = &remaining[end..];
 370        }
 371
 372        if !key.is_empty() {
 373            params.insert(key, value);
 374        }
 375    }
 376
 377    params
 378}
 379
 380/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 381/// server URL, per RFC 9728 Section 3.
 382///
 383/// Returns URIs in priority order:
 384/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 385/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 386pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 387    let mut urls = Vec::new();
 388    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 389
 390    let path = server_url.path().trim_start_matches('/');
 391    if !path.is_empty() {
 392        if let Ok(url) = Url::parse(&format!(
 393            "{}/.well-known/oauth-protected-resource/{}",
 394            base, path
 395        )) {
 396            urls.push(url);
 397        }
 398    }
 399
 400    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 401        urls.push(url);
 402    }
 403
 404    urls
 405}
 406
 407/// Construct the well-known Authorization Server Metadata URIs for a given
 408/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 409///
 410/// Returns URIs in priority order, which differs depending on whether the
 411/// issuer URL has a path component.
 412pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 413    let mut urls = Vec::new();
 414    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 415    let path = issuer.path().trim_matches('/');
 416
 417    if !path.is_empty() {
 418        // Issuer with path: try path-inserted variants first.
 419        if let Ok(url) = Url::parse(&format!(
 420            "{}/.well-known/oauth-authorization-server/{}",
 421            base, path
 422        )) {
 423            urls.push(url);
 424        }
 425        if let Ok(url) = Url::parse(&format!(
 426            "{}/.well-known/openid-configuration/{}",
 427            base, path
 428        )) {
 429            urls.push(url);
 430        }
 431        if let Ok(url) = Url::parse(&format!(
 432            "{}/{}/.well-known/openid-configuration",
 433            base, path
 434        )) {
 435            urls.push(url);
 436        }
 437    } else {
 438        // No path: standard well-known locations.
 439        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 440            urls.push(url);
 441        }
 442        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 443            urls.push(url);
 444        }
 445    }
 446
 447    urls
 448}
 449
 450// -- Canonical server URI (RFC 8707) -----------------------------------------
 451
 452/// Derive the canonical resource URI for an MCP server URL, suitable for the
 453/// `resource` parameter in authorization and token requests per RFC 8707.
 454///
 455/// Lowercases the scheme and host, preserves the path (without trailing slash),
 456/// strips fragments and query strings.
 457pub fn canonical_server_uri(server_url: &Url) -> String {
 458    let mut uri = format!(
 459        "{}://{}",
 460        server_url.scheme().to_ascii_lowercase(),
 461        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 462    );
 463    if let Some(port) = server_url.port() {
 464        uri.push_str(&format!(":{}", port));
 465    }
 466    let path = server_url.path();
 467    if path != "/" {
 468        uri.push_str(path.trim_end_matches('/'));
 469    }
 470    uri
 471}
 472
 473// -- Scope selection ---------------------------------------------------------
 474
 475/// Select scopes following the MCP spec's Scope Selection Strategy:
 476/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 477/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 478/// 3. Return empty if neither is available.
 479pub fn select_scopes(
 480    www_authenticate: &WwwAuthenticate,
 481    resource_metadata: &ProtectedResourceMetadata,
 482) -> Vec<String> {
 483    if let Some(ref scopes) = www_authenticate.scope {
 484        if !scopes.is_empty() {
 485            return scopes.clone();
 486        }
 487    }
 488    resource_metadata
 489        .scopes_supported
 490        .clone()
 491        .unwrap_or_default()
 492}
 493
 494// -- Client registration strategy --------------------------------------------
 495
/// How Zed will obtain a `client_id` from an authorization server, as chosen
/// by [`determine_registration_strategy`] from the server's metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly;
    /// no registration request is needed.
    Cimd { client_id: String },
    /// The auth server has a Dynamic Client Registration (RFC 7591) endpoint.
    /// Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
 506
 507/// Determine how to register with the authorization server, following the
 508/// spec's recommended priority: CIMD first, DCR fallback.
 509pub fn determine_registration_strategy(
 510    auth_server_metadata: &AuthServerMetadata,
 511) -> ClientRegistrationStrategy {
 512    if auth_server_metadata.client_id_metadata_document_supported {
 513        ClientRegistrationStrategy::Cimd {
 514            client_id: CIMD_URL.to_string(),
 515        }
 516    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 517        ClientRegistrationStrategy::Dcr {
 518            registration_endpoint: endpoint.clone(),
 519        }
 520    } else {
 521        ClientRegistrationStrategy::Unavailable
 522    }
 523}
 524
 525// -- PKCE (RFC 7636) ---------------------------------------------------------
 526
 527/// A PKCE code verifier and its S256 challenge.
 528#[derive(Clone)]
 529pub struct PkceChallenge {
 530    pub verifier: String,
 531    pub challenge: String,
 532}
 533
 534impl std::fmt::Debug for PkceChallenge {
 535    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 536        f.debug_struct("PkceChallenge")
 537            .field("verifier", &"[redacted]")
 538            .field("challenge", &self.challenge)
 539            .finish()
 540    }
 541}
 542
 543/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 544///
 545/// The verifier is 43 base64url characters derived from 32 random bytes.
 546/// The challenge is `BASE64URL(SHA256(verifier))`.
 547pub fn generate_pkce_challenge() -> PkceChallenge {
 548    let mut random_bytes = [0u8; 32];
 549    rand::rng().fill(&mut random_bytes);
 550    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 551    let verifier = engine.encode(&random_bytes);
 552
 553    let digest = Sha256::digest(verifier.as_bytes());
 554    let challenge = engine.encode(digest);
 555
 556    PkceChallenge {
 557        verifier,
 558        challenge,
 559    }
 560}
 561
 562// -- Authorization URL construction ------------------------------------------
 563
 564/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 565pub fn build_authorization_url(
 566    auth_server_metadata: &AuthServerMetadata,
 567    client_id: &str,
 568    redirect_uri: &str,
 569    scopes: &[String],
 570    resource: &str,
 571    pkce: &PkceChallenge,
 572    state: &str,
 573) -> Url {
 574    let mut url = auth_server_metadata.authorization_endpoint.clone();
 575    {
 576        let mut query = url.query_pairs_mut();
 577        query.append_pair("response_type", "code");
 578        query.append_pair("client_id", client_id);
 579        query.append_pair("redirect_uri", redirect_uri);
 580        if !scopes.is_empty() {
 581            query.append_pair("scope", &scopes.join(" "));
 582        }
 583        query.append_pair("resource", resource);
 584        query.append_pair("code_challenge", &pkce.challenge);
 585        query.append_pair("code_challenge_method", "S256");
 586        query.append_pair("state", state);
 587    }
 588    url
 589}
 590
 591// -- Token endpoint request bodies -------------------------------------------
 592
 593/// The JSON body returned by the token endpoint on success.
 594#[derive(Deserialize)]
 595pub struct TokenResponse {
 596    pub access_token: String,
 597    #[serde(default)]
 598    pub refresh_token: Option<String>,
 599    #[serde(default)]
 600    pub expires_in: Option<u64>,
 601    #[serde(default)]
 602    pub token_type: Option<String>,
 603}
 604
 605impl std::fmt::Debug for TokenResponse {
 606    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 607        f.debug_struct("TokenResponse")
 608            .field("access_token", &"[redacted]")
 609            .field(
 610                "refresh_token",
 611                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 612            )
 613            .field("expires_in", &self.expires_in)
 614            .field("token_type", &self.token_type)
 615            .finish()
 616    }
 617}
 618
 619impl TokenResponse {
 620    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 621    pub fn into_tokens(self) -> OAuthTokens {
 622        let expires_at = self
 623            .expires_in
 624            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 625        OAuthTokens {
 626            access_token: self.access_token,
 627            refresh_token: self.refresh_token,
 628            expires_at,
 629        }
 630    }
 631}
 632
 633/// Build the form-encoded body for an authorization code token exchange.
 634pub fn token_exchange_params(
 635    code: &str,
 636    client_id: &str,
 637    redirect_uri: &str,
 638    code_verifier: &str,
 639    resource: &str,
 640) -> Vec<(&'static str, String)> {
 641    vec![
 642        ("grant_type", "authorization_code".to_string()),
 643        ("code", code.to_string()),
 644        ("redirect_uri", redirect_uri.to_string()),
 645        ("client_id", client_id.to_string()),
 646        ("code_verifier", code_verifier.to_string()),
 647        ("resource", resource.to_string()),
 648    ]
 649}
 650
 651/// Build the form-encoded body for a token refresh request.
 652pub fn token_refresh_params(
 653    refresh_token: &str,
 654    client_id: &str,
 655    resource: &str,
 656) -> Vec<(&'static str, String)> {
 657    vec![
 658        ("grant_type", "refresh_token".to_string()),
 659        ("refresh_token", refresh_token.to_string()),
 660        ("client_id", client_id.to_string()),
 661        ("resource", resource.to_string()),
 662    ]
 663}
 664
 665// -- DCR request body (RFC 7591) ---------------------------------------------
 666
 667/// Build the JSON body for a Dynamic Client Registration request.
 668///
 669/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 670/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 671/// redirect URI matching even for loopback addresses, so we register the
 672/// exact URI we intend to use.
 673pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 674    serde_json::json!({
 675        "client_name": "Zed",
 676        "redirect_uris": [redirect_uri],
 677        "grant_types": ["authorization_code"],
 678        "response_types": ["code"],
 679        "token_endpoint_auth_method": "none"
 680    })
 681}
 682
 683// -- Discovery (async, hits real endpoints) ----------------------------------
 684
 685/// Fetch Protected Resource Metadata from the MCP server.
 686///
 687/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 688/// then falls back to well-known URIs constructed from `server_url`.
 689pub async fn fetch_protected_resource_metadata(
 690    http_client: &Arc<dyn HttpClient>,
 691    server_url: &Url,
 692    www_authenticate: &WwwAuthenticate,
 693) -> Result<ProtectedResourceMetadata> {
 694    let candidate_urls = match &www_authenticate.resource_metadata {
 695        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 696        Some(url) => {
 697            log::warn!(
 698                "Ignoring cross-origin resource_metadata URL {} \
 699                 (server origin: {})",
 700                url,
 701                server_url.origin().unicode_serialization()
 702            );
 703            protected_resource_metadata_urls(server_url)
 704        }
 705        None => protected_resource_metadata_urls(server_url),
 706    };
 707
 708    for url in &candidate_urls {
 709        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 710            Ok(response) => {
 711                if response.authorization_servers.is_empty() {
 712                    bail!(
 713                        "Protected Resource Metadata at {} has no authorization_servers",
 714                        url
 715                    );
 716                }
 717                return Ok(ProtectedResourceMetadata {
 718                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 719                    authorization_servers: response.authorization_servers,
 720                    scopes_supported: response.scopes_supported,
 721                });
 722            }
 723            Err(err) => {
 724                log::debug!(
 725                    "Failed to fetch Protected Resource Metadata from {}: {}",
 726                    url,
 727                    err
 728                );
 729            }
 730        }
 731    }
 732
 733    bail!(
 734        "Could not fetch Protected Resource Metadata for {}",
 735        server_url
 736    )
 737}
 738
 739/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 740/// endpoints in the priority order specified by the MCP spec.
 741pub async fn fetch_auth_server_metadata(
 742    http_client: &Arc<dyn HttpClient>,
 743    issuer: &Url,
 744) -> Result<AuthServerMetadata> {
 745    let candidate_urls = auth_server_metadata_urls(issuer);
 746
 747    for url in &candidate_urls {
 748        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 749            Ok(response) => {
 750                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 751                if reported_issuer != *issuer {
 752                    bail!(
 753                        "Auth server metadata issuer mismatch: expected {}, got {}",
 754                        issuer,
 755                        reported_issuer
 756                    );
 757                }
 758
 759                return Ok(AuthServerMetadata {
 760                    issuer: reported_issuer,
 761                    authorization_endpoint: response
 762                        .authorization_endpoint
 763                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 764                    token_endpoint: response
 765                        .token_endpoint
 766                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 767                    registration_endpoint: response.registration_endpoint,
 768                    scopes_supported: response.scopes_supported,
 769                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 770                    client_id_metadata_document_supported: response
 771                        .client_id_metadata_document_supported
 772                        .unwrap_or(false),
 773                });
 774            }
 775            Err(err) => {
 776                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 777            }
 778        }
 779    }
 780
 781    bail!(
 782        "Could not fetch Authorization Server Metadata for {}",
 783        issuer
 784    )
 785}
 786
 787/// Run the full discovery flow: fetch resource metadata, then auth server
 788/// metadata, then select scopes. Client registration is resolved separately,
 789/// once the real redirect URI is known.
 790pub async fn discover(
 791    http_client: &Arc<dyn HttpClient>,
 792    server_url: &Url,
 793    www_authenticate: &WwwAuthenticate,
 794) -> Result<OAuthDiscovery> {
 795    let resource_metadata =
 796        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 797
 798    let auth_server_url = resource_metadata
 799        .authorization_servers
 800        .first()
 801        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 802
 803    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 804
 805    // Verify PKCE S256 support (spec requirement).
 806    match &auth_server_metadata.code_challenge_methods_supported {
 807        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 808        Some(_) => bail!("authorization server does not support S256 PKCE"),
 809        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 810    }
 811
 812    // Verify there is at least one supported registration strategy before we
 813    // present the server as ready to authenticate.
 814    match determine_registration_strategy(&auth_server_metadata) {
 815        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
 816        ClientRegistrationStrategy::Unavailable => {
 817            bail!("authorization server supports neither CIMD nor DCR")
 818        }
 819    }
 820
 821    let scopes = select_scopes(www_authenticate, &resource_metadata);
 822
 823    Ok(OAuthDiscovery {
 824        resource_metadata,
 825        auth_server_metadata,
 826        scopes,
 827    })
 828}
 829
 830/// Resolve the OAuth client registration for an authorization flow.
 831///
 832/// CIMD uses the static client metadata document directly. For DCR, a fresh
 833/// registration is performed each time because the loopback redirect URI
 834/// includes an ephemeral port that changes every flow.
 835pub async fn resolve_client_registration(
 836    http_client: &Arc<dyn HttpClient>,
 837    discovery: &OAuthDiscovery,
 838    redirect_uri: &str,
 839) -> Result<OAuthClientRegistration> {
 840    match determine_registration_strategy(&discovery.auth_server_metadata) {
 841        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 842            client_id,
 843            client_secret: None,
 844        }),
 845        ClientRegistrationStrategy::Dcr {
 846            registration_endpoint,
 847        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 848        ClientRegistrationStrategy::Unavailable => {
 849            bail!("authorization server supports neither CIMD nor DCR")
 850        }
 851    }
 852}
 853
 854// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 855
 856/// Perform Dynamic Client Registration with the authorization server.
 857pub async fn perform_dcr(
 858    http_client: &Arc<dyn HttpClient>,
 859    registration_endpoint: &Url,
 860    redirect_uri: &str,
 861) -> Result<OAuthClientRegistration> {
 862    validate_oauth_url(registration_endpoint)?;
 863
 864    let body = dcr_registration_body(redirect_uri);
 865    let body_bytes = serde_json::to_vec(&body)?;
 866
 867    let request = Request::builder()
 868        .method(http_client::http::Method::POST)
 869        .uri(registration_endpoint.as_str())
 870        .header("Content-Type", "application/json")
 871        .header("Accept", "application/json")
 872        .body(AsyncBody::from(body_bytes))?;
 873
 874    let mut response = http_client.send(request).await?;
 875
 876    if !response.status().is_success() {
 877        let mut error_body = String::new();
 878        response.body_mut().read_to_string(&mut error_body).await?;
 879        bail!(
 880            "DCR failed with status {}: {}",
 881            response.status(),
 882            error_body
 883        );
 884    }
 885
 886    let mut response_body = String::new();
 887    response
 888        .body_mut()
 889        .read_to_string(&mut response_body)
 890        .await?;
 891
 892    let dcr_response: DcrResponse =
 893        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 894
 895    Ok(OAuthClientRegistration {
 896        client_id: dcr_response.client_id,
 897        client_secret: dcr_response.client_secret,
 898    })
 899}
 900
 901// -- Token exchange and refresh (async) --------------------------------------
 902
 903/// Exchange an authorization code for tokens at the token endpoint.
 904pub async fn exchange_code(
 905    http_client: &Arc<dyn HttpClient>,
 906    auth_server_metadata: &AuthServerMetadata,
 907    code: &str,
 908    client_id: &str,
 909    redirect_uri: &str,
 910    code_verifier: &str,
 911    resource: &str,
 912) -> Result<OAuthTokens> {
 913    let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
 914    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 915}
 916
 917/// Refresh tokens using a refresh token.
 918pub async fn refresh_tokens(
 919    http_client: &Arc<dyn HttpClient>,
 920    token_endpoint: &Url,
 921    refresh_token: &str,
 922    client_id: &str,
 923    resource: &str,
 924) -> Result<OAuthTokens> {
 925    let params = token_refresh_params(refresh_token, client_id, resource);
 926    post_token_request(http_client, token_endpoint, &params).await
 927}
 928
 929/// POST form-encoded parameters to a token endpoint and parse the response.
 930async fn post_token_request(
 931    http_client: &Arc<dyn HttpClient>,
 932    token_endpoint: &Url,
 933    params: &[(&str, String)],
 934) -> Result<OAuthTokens> {
 935    validate_oauth_url(token_endpoint)?;
 936
 937    let body = url::form_urlencoded::Serializer::new(String::new())
 938        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 939        .finish();
 940
 941    let request = Request::builder()
 942        .method(http_client::http::Method::POST)
 943        .uri(token_endpoint.as_str())
 944        .header("Content-Type", "application/x-www-form-urlencoded")
 945        .header("Accept", "application/json")
 946        .body(AsyncBody::from(body.into_bytes()))?;
 947
 948    let mut response = http_client.send(request).await?;
 949
 950    if !response.status().is_success() {
 951        let mut error_body = String::new();
 952        response.body_mut().read_to_string(&mut error_body).await?;
 953        bail!(
 954            "token request failed with status {}: {}",
 955            response.status(),
 956            error_body
 957        );
 958    }
 959
 960    let mut response_body = String::new();
 961    response
 962        .body_mut()
 963        .read_to_string(&mut response_body)
 964        .await?;
 965
 966    let token_response: TokenResponse =
 967        serde_json::from_str(&response_body).context("failed to parse token response")?;
 968
 969    Ok(token_response.into_tokens())
 970}
 971
 972// -- Loopback HTTP callback server -------------------------------------------
 973
 974/// An OAuth authorization callback received via the loopback HTTP server.
 975pub struct OAuthCallback {
 976    pub code: String,
 977    pub state: String,
 978}
 979
 980impl std::fmt::Debug for OAuthCallback {
 981    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 982        f.debug_struct("OAuthCallback")
 983            .field("code", &"[redacted]")
 984            .field("state", &"[redacted]")
 985            .finish()
 986    }
 987}
 988
 989impl OAuthCallback {
 990    /// Parse the query string from a callback URL like
 991    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
 992    pub fn parse_query(query: &str) -> Result<Self> {
 993        let params = http_client::OAuthCallbackParams::parse_query(query)?;
 994        Ok(Self {
 995            code: params.code,
 996            state: params.state,
 997        })
 998    }
 999}
1000
1001/// Start a loopback HTTP server to receive the OAuth authorization callback.
1002///
1003/// Binds to an ephemeral loopback port for each flow.
1004///
1005/// Returns `(redirect_uri, callback_future)`. The caller should use the
1006/// redirect URI in the authorization request, open the browser, then await
1007/// the future to receive the callback.
1008///
1009/// The server accepts exactly one request on `/callback`, validates that it
1010/// contains `code` and `state` query parameters, responds with a minimal
1011/// HTML page telling the user they can close the tab, and shuts down.
1012///
1013/// The callback server shuts down when the returned oneshot receiver is dropped
1014/// (e.g. because the authentication task was cancelled), or after a timeout.
1015pub async fn start_callback_server() -> Result<(
1016    String,
1017    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1018)> {
1019    let (redirect_uri, rx) = http_client::start_oauth_callback_server()?;
1020    let (tx, mapped_rx) = futures::channel::oneshot::channel();
1021    smol::spawn(async move {
1022        let result = match rx.await {
1023            Ok(Ok(params)) => Ok(OAuthCallback {
1024                code: params.code,
1025                state: params.state,
1026            }),
1027            Ok(Err(e)) => Err(e),
1028            Err(_) => Err(anyhow!("OAuth callback channel was cancelled")),
1029        };
1030        let _ = tx.send(result);
1031    })
1032    .detach();
1033    Ok((redirect_uri, mapped_rx))
1034}
1035
1036// -- JSON fetch helper -------------------------------------------------------
1037
1038async fn fetch_json<T: serde::de::DeserializeOwned>(
1039    http_client: &Arc<dyn HttpClient>,
1040    url: &Url,
1041) -> Result<T> {
1042    validate_oauth_url(url)?;
1043
1044    let request = Request::builder()
1045        .method(http_client::http::Method::GET)
1046        .uri(url.as_str())
1047        .header("Accept", "application/json")
1048        .body(AsyncBody::default())?;
1049
1050    let mut response = http_client.send(request).await?;
1051
1052    if !response.status().is_success() {
1053        bail!("HTTP {} fetching {}", response.status(), url);
1054    }
1055
1056    let mut body = String::new();
1057    response.body_mut().read_to_string(&mut body).await?;
1058    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1059}
1060
1061// -- Serde response types for discovery --------------------------------------
1062
1063#[derive(Debug, Deserialize)]
1064struct ProtectedResourceMetadataResponse {
1065    #[serde(default)]
1066    resource: Option<Url>,
1067    #[serde(default)]
1068    authorization_servers: Vec<Url>,
1069    #[serde(default)]
1070    scopes_supported: Option<Vec<String>>,
1071}
1072
1073#[derive(Debug, Deserialize)]
1074struct AuthServerMetadataResponse {
1075    #[serde(default)]
1076    issuer: Option<Url>,
1077    #[serde(default)]
1078    authorization_endpoint: Option<Url>,
1079    #[serde(default)]
1080    token_endpoint: Option<Url>,
1081    #[serde(default)]
1082    registration_endpoint: Option<Url>,
1083    #[serde(default)]
1084    scopes_supported: Option<Vec<String>>,
1085    #[serde(default)]
1086    code_challenge_methods_supported: Option<Vec<String>>,
1087    #[serde(default)]
1088    client_id_metadata_document_supported: Option<bool>,
1089}
1090
1091#[derive(Debug, Deserialize)]
1092struct DcrResponse {
1093    client_id: String,
1094    #[serde(default)]
1095    client_secret: Option<String>,
1096}
1097
1098/// Provides OAuth tokens to the HTTP transport layer.
1099///
1100/// The transport calls `access_token()` before each request. On a 401 response
1101/// it calls `try_refresh()` and retries once if the refresh succeeds.
1102#[async_trait]
1103pub trait OAuthTokenProvider: Send + Sync {
1104    /// Returns the current access token, if one is available.
1105    fn access_token(&self) -> Option<String>;
1106
1107    /// Attempts to refresh the access token. Returns `true` if a new token was
1108    /// obtained and the request should be retried.
1109    async fn try_refresh(&self) -> Result<bool>;
1110}
1111
1112/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
1113/// an HTTP client for token refresh. The same provider type is used both after
1114/// an interactive authentication flow and when restoring a saved session from
1115/// the keychain on startup.
1116pub struct McpOAuthTokenProvider {
1117    session: SyncMutex<OAuthSession>,
1118    http_client: Arc<dyn HttpClient>,
1119    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1120}
1121
1122impl McpOAuthTokenProvider {
1123    pub fn new(
1124        session: OAuthSession,
1125        http_client: Arc<dyn HttpClient>,
1126        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1127    ) -> Self {
1128        Self {
1129            session: SyncMutex::new(session),
1130            http_client,
1131            token_refresh_tx,
1132        }
1133    }
1134
1135    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1136        tokens.expires_at.is_some_and(|expires_at| {
1137            SystemTime::now()
1138                .checked_add(Duration::from_secs(30))
1139                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1140        })
1141    }
1142}
1143
1144#[async_trait]
1145impl OAuthTokenProvider for McpOAuthTokenProvider {
1146    fn access_token(&self) -> Option<String> {
1147        let session = self.session.lock();
1148        if Self::access_token_is_expired(&session.tokens) {
1149            return None;
1150        }
1151        Some(session.tokens.access_token.clone())
1152    }
1153
1154    async fn try_refresh(&self) -> Result<bool> {
1155        let (refresh_token, token_endpoint, resource, client_id) = {
1156            let session = self.session.lock();
1157            match session.tokens.refresh_token.clone() {
1158                Some(refresh_token) => (
1159                    refresh_token,
1160                    session.token_endpoint.clone(),
1161                    session.resource.clone(),
1162                    session.client_registration.client_id.clone(),
1163                ),
1164                None => return Ok(false),
1165            }
1166        };
1167
1168        let resource_str = canonical_server_uri(&resource);
1169
1170        match refresh_tokens(
1171            &self.http_client,
1172            &token_endpoint,
1173            &refresh_token,
1174            &client_id,
1175            &resource_str,
1176        )
1177        .await
1178        {
1179            Ok(mut new_tokens) => {
1180                if new_tokens.refresh_token.is_none() {
1181                    new_tokens.refresh_token = Some(refresh_token);
1182                }
1183
1184                {
1185                    let mut session = self.session.lock();
1186                    session.tokens = new_tokens;
1187
1188                    if let Some(ref tx) = self.token_refresh_tx {
1189                        tx.unbounded_send(session.clone()).ok();
1190                    }
1191                }
1192
1193                Ok(true)
1194            }
1195            Err(err) => {
1196                log::warn!("OAuth token refresh failed: {}", err);
1197                Ok(false)
1198            }
1199        }
1200    }
1201}
1202
1203#[cfg(test)]
1204mod tests {
1205    use super::*;
1206    use http_client::Response;
1207
1208    // -- require_https_or_loopback tests ------------------------------------
1209
1210    #[test]
1211    fn test_require_https_or_loopback_accepts_https() {
1212        let url = Url::parse("https://auth.example.com/token").unwrap();
1213        assert!(require_https_or_loopback(&url).is_ok());
1214    }
1215
1216    #[test]
1217    fn test_require_https_or_loopback_rejects_http_remote() {
1218        let url = Url::parse("http://auth.example.com/token").unwrap();
1219        assert!(require_https_or_loopback(&url).is_err());
1220    }
1221
1222    #[test]
1223    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1224        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1225        assert!(require_https_or_loopback(&url).is_ok());
1226    }
1227
1228    #[test]
1229    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1230        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1231        assert!(require_https_or_loopback(&url).is_ok());
1232    }
1233
1234    #[test]
1235    fn test_require_https_or_loopback_accepts_http_localhost() {
1236        let url = Url::parse("http://localhost:8080/callback").unwrap();
1237        assert!(require_https_or_loopback(&url).is_ok());
1238    }
1239
1240    #[test]
1241    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1242        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1243        assert!(require_https_or_loopback(&url).is_ok());
1244    }
1245
1246    #[test]
1247    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1248        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1249        assert!(require_https_or_loopback(&url).is_err());
1250    }
1251
1252    #[test]
1253    fn test_require_https_or_loopback_rejects_ftp() {
1254        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1255        assert!(require_https_or_loopback(&url).is_err());
1256    }
1257
1258    // -- validate_oauth_url (SSRF) tests ------------------------------------
1259
1260    #[test]
1261    fn test_validate_oauth_url_accepts_https_public() {
1262        let url = Url::parse("https://auth.example.com/token").unwrap();
1263        assert!(validate_oauth_url(&url).is_ok());
1264    }
1265
1266    #[test]
1267    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1268        let url = Url::parse("https://10.0.0.1/token").unwrap();
1269        assert!(validate_oauth_url(&url).is_err());
1270    }
1271
1272    #[test]
1273    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1274        let url = Url::parse("https://172.16.0.1/token").unwrap();
1275        assert!(validate_oauth_url(&url).is_err());
1276    }
1277
1278    #[test]
1279    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1280        let url = Url::parse("https://192.168.1.1/token").unwrap();
1281        assert!(validate_oauth_url(&url).is_err());
1282    }
1283
1284    #[test]
1285    fn test_validate_oauth_url_rejects_link_local() {
1286        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1287        assert!(validate_oauth_url(&url).is_err());
1288    }
1289
1290    #[test]
1291    fn test_validate_oauth_url_rejects_ipv6_ula() {
1292        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1293        assert!(validate_oauth_url(&url).is_err());
1294    }
1295
1296    #[test]
1297    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1298        let url = Url::parse("https://[::]/token").unwrap();
1299        assert!(validate_oauth_url(&url).is_err());
1300    }
1301
1302    #[test]
1303    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1304        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1305        assert!(validate_oauth_url(&url).is_err());
1306    }
1307
1308    #[test]
1309    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1310        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1311        assert!(validate_oauth_url(&url).is_err());
1312    }
1313
1314    #[test]
1315    fn test_validate_oauth_url_allows_http_loopback() {
1316        // Loopback is permitted (it's our callback server).
1317        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1318        assert!(validate_oauth_url(&url).is_ok());
1319    }
1320
1321    #[test]
1322    fn test_validate_oauth_url_allows_https_public_ip() {
1323        let url = Url::parse("https://93.184.216.34/token").unwrap();
1324        assert!(validate_oauth_url(&url).is_ok());
1325    }
1326
1327    // -- parse_www_authenticate tests ----------------------------------------
1328
1329    #[test]
1330    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1331        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1332        let result = parse_www_authenticate(header).unwrap();
1333
1334        assert_eq!(
1335            result.resource_metadata.as_ref().map(|u| u.as_str()),
1336            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1337        );
1338        assert_eq!(
1339            result.scope,
1340            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1341        );
1342        assert_eq!(result.error, None);
1343    }
1344
1345    #[test]
1346    fn test_parse_www_authenticate_resource_metadata_only() {
1347        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1348        let result = parse_www_authenticate(header).unwrap();
1349
1350        assert_eq!(
1351            result.resource_metadata.as_ref().map(|u| u.as_str()),
1352            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1353        );
1354        assert_eq!(result.scope, None);
1355    }
1356
1357    #[test]
1358    fn test_parse_www_authenticate_bare_bearer() {
1359        let result = parse_www_authenticate("Bearer").unwrap();
1360        assert_eq!(result.resource_metadata, None);
1361        assert_eq!(result.scope, None);
1362    }
1363
1364    #[test]
1365    fn test_parse_www_authenticate_with_error() {
1366        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1367        let result = parse_www_authenticate(header).unwrap();
1368
1369        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1370        assert_eq!(
1371            result.error_description.as_deref(),
1372            Some("Additional file write permission required")
1373        );
1374        assert_eq!(
1375            result.scope,
1376            Some(vec!["files:read".to_string(), "files:write".to_string()])
1377        );
1378        assert!(result.resource_metadata.is_some());
1379    }
1380
1381    #[test]
1382    fn test_parse_www_authenticate_invalid_token_error() {
1383        let header =
1384            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1385        let result = parse_www_authenticate(header).unwrap();
1386        assert_eq!(result.error, Some(BearerError::InvalidToken));
1387    }
1388
1389    #[test]
1390    fn test_parse_www_authenticate_invalid_request_error() {
1391        let header = r#"Bearer error="invalid_request""#;
1392        let result = parse_www_authenticate(header).unwrap();
1393        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1394    }
1395
1396    #[test]
1397    fn test_parse_www_authenticate_unknown_error() {
1398        let header = r#"Bearer error="some_future_error""#;
1399        let result = parse_www_authenticate(header).unwrap();
1400        assert_eq!(result.error, Some(BearerError::Other));
1401    }
1402
1403    #[test]
1404    fn test_parse_www_authenticate_rejects_non_bearer() {
1405        let result = parse_www_authenticate("Basic realm=\"example\"");
1406        assert!(result.is_err());
1407    }
1408
1409    #[test]
1410    fn test_parse_www_authenticate_case_insensitive_scheme() {
1411        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1412        let result = parse_www_authenticate(header).unwrap();
1413        assert!(result.resource_metadata.is_some());
1414    }
1415
1416    #[test]
1417    fn test_parse_www_authenticate_multiline_style() {
1418        // Some servers emit the header spread across multiple lines joined by
1419        // whitespace, as shown in the spec examples.
1420        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1421        let result = parse_www_authenticate(header).unwrap();
1422        assert!(result.resource_metadata.is_some());
1423        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1424    }
1425
1426    #[test]
1427    fn test_protected_resource_metadata_urls_with_path() {
1428        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1429        let urls = protected_resource_metadata_urls(&server_url);
1430
1431        assert_eq!(urls.len(), 2);
1432        assert_eq!(
1433            urls[0].as_str(),
1434            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1435        );
1436        assert_eq!(
1437            urls[1].as_str(),
1438            "https://api.example.com/.well-known/oauth-protected-resource"
1439        );
1440    }
1441
1442    #[test]
1443    fn test_protected_resource_metadata_urls_without_path() {
1444        let server_url = Url::parse("https://mcp.example.com").unwrap();
1445        let urls = protected_resource_metadata_urls(&server_url);
1446
1447        assert_eq!(urls.len(), 1);
1448        assert_eq!(
1449            urls[0].as_str(),
1450            "https://mcp.example.com/.well-known/oauth-protected-resource"
1451        );
1452    }
1453
1454    #[test]
1455    fn test_auth_server_metadata_urls_with_path() {
1456        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1457        let urls = auth_server_metadata_urls(&issuer);
1458
1459        assert_eq!(urls.len(), 3);
1460        assert_eq!(
1461            urls[0].as_str(),
1462            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1463        );
1464        assert_eq!(
1465            urls[1].as_str(),
1466            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1467        );
1468        assert_eq!(
1469            urls[2].as_str(),
1470            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1471        );
1472    }
1473
1474    #[test]
1475    fn test_auth_server_metadata_urls_without_path() {
1476        let issuer = Url::parse("https://auth.example.com").unwrap();
1477        let urls = auth_server_metadata_urls(&issuer);
1478
1479        assert_eq!(urls.len(), 2);
1480        assert_eq!(
1481            urls[0].as_str(),
1482            "https://auth.example.com/.well-known/oauth-authorization-server"
1483        );
1484        assert_eq!(
1485            urls[1].as_str(),
1486            "https://auth.example.com/.well-known/openid-configuration"
1487        );
1488    }
1489
1490    // -- Canonical server URI tests ------------------------------------------
1491
1492    #[test]
1493    fn test_canonical_server_uri_simple() {
1494        let url = Url::parse("https://mcp.example.com").unwrap();
1495        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1496    }
1497
1498    #[test]
1499    fn test_canonical_server_uri_with_path() {
1500        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1501        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1502    }
1503
1504    #[test]
1505    fn test_canonical_server_uri_strips_trailing_slash() {
1506        let url = Url::parse("https://mcp.example.com/").unwrap();
1507        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1508    }
1509
1510    #[test]
1511    fn test_canonical_server_uri_preserves_port() {
1512        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1513        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1514    }
1515
1516    #[test]
1517    fn test_canonical_server_uri_lowercases() {
1518        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1519        assert_eq!(
1520            canonical_server_uri(&url),
1521            "https://mcp.example.com/Server/MCP"
1522        );
1523    }
1524
1525    // -- Scope selection tests -----------------------------------------------
1526
1527    #[test]
1528    fn test_select_scopes_prefers_www_authenticate() {
1529        let www_auth = WwwAuthenticate {
1530            resource_metadata: None,
1531            scope: Some(vec!["files:read".into()]),
1532            error: None,
1533            error_description: None,
1534        };
1535        let resource_meta = ProtectedResourceMetadata {
1536            resource: Url::parse("https://example.com").unwrap(),
1537            authorization_servers: vec![],
1538            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1539        };
1540        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1541    }
1542
1543    #[test]
1544    fn test_select_scopes_falls_back_to_resource_metadata() {
1545        let www_auth = WwwAuthenticate {
1546            resource_metadata: None,
1547            scope: None,
1548            error: None,
1549            error_description: None,
1550        };
1551        let resource_meta = ProtectedResourceMetadata {
1552            resource: Url::parse("https://example.com").unwrap(),
1553            authorization_servers: vec![],
1554            scopes_supported: Some(vec!["admin".into()]),
1555        };
1556        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1557    }
1558
1559    #[test]
1560    fn test_select_scopes_empty_when_nothing_available() {
1561        let www_auth = WwwAuthenticate {
1562            resource_metadata: None,
1563            scope: None,
1564            error: None,
1565            error_description: None,
1566        };
1567        let resource_meta = ProtectedResourceMetadata {
1568            resource: Url::parse("https://example.com").unwrap(),
1569            authorization_servers: vec![],
1570            scopes_supported: None,
1571        };
1572        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1573    }
1574
1575    // -- Client registration strategy tests ----------------------------------
1576
1577    #[test]
1578    fn test_registration_strategy_prefers_cimd() {
1579        let metadata = AuthServerMetadata {
1580            issuer: Url::parse("https://auth.example.com").unwrap(),
1581            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1582            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1583            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1584            scopes_supported: None,
1585            code_challenge_methods_supported: Some(vec!["S256".into()]),
1586            client_id_metadata_document_supported: true,
1587        };
1588        assert_eq!(
1589            determine_registration_strategy(&metadata),
1590            ClientRegistrationStrategy::Cimd {
1591                client_id: CIMD_URL.to_string(),
1592            }
1593        );
1594    }
1595
1596    #[test]
1597    fn test_registration_strategy_falls_back_to_dcr() {
1598        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1599        let metadata = AuthServerMetadata {
1600            issuer: Url::parse("https://auth.example.com").unwrap(),
1601            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1602            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1603            registration_endpoint: Some(reg_endpoint.clone()),
1604            scopes_supported: None,
1605            code_challenge_methods_supported: Some(vec!["S256".into()]),
1606            client_id_metadata_document_supported: false,
1607        };
1608        assert_eq!(
1609            determine_registration_strategy(&metadata),
1610            ClientRegistrationStrategy::Dcr {
1611                registration_endpoint: reg_endpoint,
1612            }
1613        );
1614    }
1615
1616    #[test]
1617    fn test_registration_strategy_unavailable() {
1618        let metadata = AuthServerMetadata {
1619            issuer: Url::parse("https://auth.example.com").unwrap(),
1620            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1621            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1622            registration_endpoint: None,
1623            scopes_supported: None,
1624            code_challenge_methods_supported: Some(vec!["S256".into()]),
1625            client_id_metadata_document_supported: false,
1626        };
1627        assert_eq!(
1628            determine_registration_strategy(&metadata),
1629            ClientRegistrationStrategy::Unavailable,
1630        );
1631    }
1632
1633    // -- PKCE tests ----------------------------------------------------------
1634
1635    #[test]
1636    fn test_pkce_challenge_verifier_length() {
1637        let pkce = generate_pkce_challenge();
1638        // 32 random bytes → 43 base64url chars (no padding).
1639        assert_eq!(pkce.verifier.len(), 43);
1640    }
1641
1642    #[test]
1643    fn test_pkce_challenge_is_valid_base64url() {
1644        let pkce = generate_pkce_challenge();
1645        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1646            assert!(
1647                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1648                "invalid base64url character: {}",
1649                c
1650            );
1651        }
1652    }
1653
1654    #[test]
1655    fn test_pkce_challenge_is_s256_of_verifier() {
1656        let pkce = generate_pkce_challenge();
1657        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1658        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1659        let expected_challenge = engine.encode(expected_digest);
1660        assert_eq!(pkce.challenge, expected_challenge);
1661    }
1662
1663    #[test]
1664    fn test_pkce_challenges_are_unique() {
1665        let a = generate_pkce_challenge();
1666        let b = generate_pkce_challenge();
1667        assert_ne!(a.verifier, b.verifier);
1668    }
1669
1670    // -- Authorization URL tests ---------------------------------------------
1671
1672    #[test]
1673    fn test_build_authorization_url() {
1674        let metadata = AuthServerMetadata {
1675            issuer: Url::parse("https://auth.example.com").unwrap(),
1676            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1677            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1678            registration_endpoint: None,
1679            scopes_supported: None,
1680            code_challenge_methods_supported: Some(vec!["S256".into()]),
1681            client_id_metadata_document_supported: true,
1682        };
1683        let pkce = PkceChallenge {
1684            verifier: "test_verifier".into(),
1685            challenge: "test_challenge".into(),
1686        };
1687        let url = build_authorization_url(
1688            &metadata,
1689            "https://zed.dev/oauth/client-metadata.json",
1690            "http://127.0.0.1:12345/callback",
1691            &["files:read".into(), "files:write".into()],
1692            "https://mcp.example.com",
1693            &pkce,
1694            "random_state_123",
1695        );
1696
1697        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1698        assert_eq!(pairs.get("response_type").unwrap(), "code");
1699        assert_eq!(
1700            pairs.get("client_id").unwrap(),
1701            "https://zed.dev/oauth/client-metadata.json"
1702        );
1703        assert_eq!(
1704            pairs.get("redirect_uri").unwrap(),
1705            "http://127.0.0.1:12345/callback"
1706        );
1707        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1708        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1709        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1710        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1711        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1712    }
1713
1714    #[test]
1715    fn test_build_authorization_url_omits_empty_scope() {
1716        let metadata = AuthServerMetadata {
1717            issuer: Url::parse("https://auth.example.com").unwrap(),
1718            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1719            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1720            registration_endpoint: None,
1721            scopes_supported: None,
1722            code_challenge_methods_supported: Some(vec!["S256".into()]),
1723            client_id_metadata_document_supported: false,
1724        };
1725        let pkce = PkceChallenge {
1726            verifier: "v".into(),
1727            challenge: "c".into(),
1728        };
1729        let url = build_authorization_url(
1730            &metadata,
1731            "client_123",
1732            "http://127.0.0.1:9999/callback",
1733            &[],
1734            "https://mcp.example.com",
1735            &pkce,
1736            "state",
1737        );
1738
1739        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1740        assert!(!pairs.contains_key("scope"));
1741    }
1742
1743    // -- Token exchange / refresh param tests --------------------------------
1744
1745    #[test]
1746    fn test_token_exchange_params() {
1747        let params = token_exchange_params(
1748            "auth_code_abc",
1749            "client_xyz",
1750            "http://127.0.0.1:5555/callback",
1751            "verifier_123",
1752            "https://mcp.example.com",
1753        );
1754        let map: std::collections::HashMap<&str, &str> =
1755            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1756
1757        assert_eq!(map["grant_type"], "authorization_code");
1758        assert_eq!(map["code"], "auth_code_abc");
1759        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1760        assert_eq!(map["client_id"], "client_xyz");
1761        assert_eq!(map["code_verifier"], "verifier_123");
1762        assert_eq!(map["resource"], "https://mcp.example.com");
1763    }
1764
1765    #[test]
1766    fn test_token_refresh_params() {
1767        let params =
1768            token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1769        let map: std::collections::HashMap<&str, &str> =
1770            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1771
1772        assert_eq!(map["grant_type"], "refresh_token");
1773        assert_eq!(map["refresh_token"], "refresh_token_abc");
1774        assert_eq!(map["client_id"], "client_xyz");
1775        assert_eq!(map["resource"], "https://mcp.example.com");
1776    }
1777
1778    // -- Token response tests ------------------------------------------------
1779
1780    #[test]
1781    fn test_token_response_into_tokens_with_expiry() {
1782        let response: TokenResponse = serde_json::from_str(
1783            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1784        )
1785        .unwrap();
1786
1787        let tokens = response.into_tokens();
1788        assert_eq!(tokens.access_token, "at_123");
1789        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1790        assert!(tokens.expires_at.is_some());
1791    }
1792
1793    #[test]
1794    fn test_token_response_into_tokens_minimal() {
1795        let response: TokenResponse =
1796            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1797
1798        let tokens = response.into_tokens();
1799        assert_eq!(tokens.access_token, "at_789");
1800        assert_eq!(tokens.refresh_token, None);
1801        assert_eq!(tokens.expires_at, None);
1802    }
1803
1804    // -- DCR body test -------------------------------------------------------
1805
1806    #[test]
1807    fn test_dcr_registration_body_shape() {
1808        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1809        assert_eq!(body["client_name"], "Zed");
1810        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1811        assert_eq!(body["grant_types"][0], "authorization_code");
1812        assert_eq!(body["response_types"][0], "code");
1813        assert_eq!(body["token_endpoint_auth_method"], "none");
1814    }
1815
1816    // -- Test helpers for async/HTTP tests -----------------------------------
1817
1818    fn make_fake_http_client(
1819        handler: impl Fn(
1820            http_client::Request<AsyncBody>,
1821        ) -> std::pin::Pin<
1822            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1823        > + Send
1824        + Sync
1825        + 'static,
1826    ) -> Arc<dyn HttpClient> {
1827        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1828    }
1829
1830    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1831        Ok(Response::builder()
1832            .status(status)
1833            .header("Content-Type", "application/json")
1834            .body(AsyncBody::from(body.as_bytes().to_vec()))
1835            .unwrap())
1836    }
1837
1838    // -- Discovery integration tests -----------------------------------------
1839
1840    #[test]
1841    fn test_fetch_protected_resource_metadata() {
1842        smol::block_on(async {
1843            let client = make_fake_http_client(|req| {
1844                Box::pin(async move {
1845                    let uri = req.uri().to_string();
1846                    if uri.contains(".well-known/oauth-protected-resource") {
1847                        json_response(
1848                            200,
1849                            r#"{
1850                                "resource": "https://mcp.example.com",
1851                                "authorization_servers": ["https://auth.example.com"],
1852                                "scopes_supported": ["read", "write"]
1853                            }"#,
1854                        )
1855                    } else {
1856                        json_response(404, "{}")
1857                    }
1858                })
1859            });
1860
1861            let server_url = Url::parse("https://mcp.example.com").unwrap();
1862            let www_auth = WwwAuthenticate {
1863                resource_metadata: None,
1864                scope: None,
1865                error: None,
1866                error_description: None,
1867            };
1868
1869            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
1870                .await
1871                .unwrap();
1872
1873            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
1874            assert_eq!(metadata.authorization_servers.len(), 1);
1875            assert_eq!(
1876                metadata.authorization_servers[0].as_str(),
1877                "https://auth.example.com/"
1878            );
1879            assert_eq!(
1880                metadata.scopes_supported,
1881                Some(vec!["read".to_string(), "write".to_string()])
1882            );
1883        });
1884    }
1885
1886    #[test]
1887    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
1888        smol::block_on(async {
1889            let client = make_fake_http_client(|req| {
1890                Box::pin(async move {
1891                    let uri = req.uri().to_string();
1892                    if uri == "https://mcp.example.com/custom-resource-metadata" {
1893                        json_response(
1894                            200,
1895                            r#"{
1896                                "resource": "https://mcp.example.com",
1897                                "authorization_servers": ["https://auth.example.com"]
1898                            }"#,
1899                        )
1900                    } else {
1901                        json_response(500, r#"{"error": "should not be called"}"#)
1902                    }
1903                })
1904            });
1905
1906            let server_url = Url::parse("https://mcp.example.com").unwrap();
1907            let www_auth = WwwAuthenticate {
1908                resource_metadata: Some(
1909                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
1910                ),
1911                scope: None,
1912                error: None,
1913                error_description: None,
1914            };
1915
1916            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
1917                .await
1918                .unwrap();
1919
1920            assert_eq!(metadata.authorization_servers.len(), 1);
1921        });
1922    }
1923
1924    #[test]
1925    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
1926        smol::block_on(async {
1927            let client = make_fake_http_client(|req| {
1928                Box::pin(async move {
1929                    let uri = req.uri().to_string();
1930                    // The cross-origin URL should NOT be fetched; only the
1931                    // well-known fallback at the server's own origin should be.
1932                    if uri.contains("attacker.example.com") {
1933                        panic!("should not fetch cross-origin resource_metadata URL");
1934                    } else if uri.contains(".well-known/oauth-protected-resource") {
1935                        json_response(
1936                            200,
1937                            r#"{
1938                                "resource": "https://mcp.example.com",
1939                                "authorization_servers": ["https://auth.example.com"]
1940                            }"#,
1941                        )
1942                    } else {
1943                        json_response(404, "{}")
1944                    }
1945                })
1946            });
1947
1948            let server_url = Url::parse("https://mcp.example.com").unwrap();
1949            let www_auth = WwwAuthenticate {
1950                resource_metadata: Some(
1951                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
1952                ),
1953                scope: None,
1954                error: None,
1955                error_description: None,
1956            };
1957
1958            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
1959                .await
1960                .unwrap();
1961
1962            // Should have used the fallback well-known URL, not the attacker's.
1963            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
1964        });
1965    }
1966
1967    #[test]
1968    fn test_fetch_auth_server_metadata() {
1969        smol::block_on(async {
1970            let client = make_fake_http_client(|req| {
1971                Box::pin(async move {
1972                    let uri = req.uri().to_string();
1973                    if uri.contains(".well-known/oauth-authorization-server") {
1974                        json_response(
1975                            200,
1976                            r#"{
1977                                "issuer": "https://auth.example.com",
1978                                "authorization_endpoint": "https://auth.example.com/authorize",
1979                                "token_endpoint": "https://auth.example.com/token",
1980                                "registration_endpoint": "https://auth.example.com/register",
1981                                "code_challenge_methods_supported": ["S256"],
1982                                "client_id_metadata_document_supported": true
1983                            }"#,
1984                        )
1985                    } else {
1986                        json_response(404, "{}")
1987                    }
1988                })
1989            });
1990
1991            let issuer = Url::parse("https://auth.example.com").unwrap();
1992            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
1993
1994            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
1995            assert_eq!(
1996                metadata.authorization_endpoint.as_str(),
1997                "https://auth.example.com/authorize"
1998            );
1999            assert_eq!(
2000                metadata.token_endpoint.as_str(),
2001                "https://auth.example.com/token"
2002            );
2003            assert!(metadata.registration_endpoint.is_some());
2004            assert!(metadata.client_id_metadata_document_supported);
2005            assert_eq!(
2006                metadata.code_challenge_methods_supported,
2007                Some(vec!["S256".to_string()])
2008            );
2009        });
2010    }
2011
2012    #[test]
2013    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2014        smol::block_on(async {
2015            let client = make_fake_http_client(|req| {
2016                Box::pin(async move {
2017                    let uri = req.uri().to_string();
2018                    if uri.contains("openid-configuration") {
2019                        json_response(
2020                            200,
2021                            r#"{
2022                                "issuer": "https://auth.example.com",
2023                                "authorization_endpoint": "https://auth.example.com/authorize",
2024                                "token_endpoint": "https://auth.example.com/token",
2025                                "code_challenge_methods_supported": ["S256"]
2026                            }"#,
2027                        )
2028                    } else {
2029                        json_response(404, "{}")
2030                    }
2031                })
2032            });
2033
2034            let issuer = Url::parse("https://auth.example.com").unwrap();
2035            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2036
2037            assert_eq!(
2038                metadata.authorization_endpoint.as_str(),
2039                "https://auth.example.com/authorize"
2040            );
2041            assert!(!metadata.client_id_metadata_document_supported);
2042        });
2043    }
2044
2045    #[test]
2046    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2047        smol::block_on(async {
2048            let client = make_fake_http_client(|req| {
2049                Box::pin(async move {
2050                    let uri = req.uri().to_string();
2051                    if uri.contains(".well-known/oauth-authorization-server") {
2052                        // Response claims to be a different issuer.
2053                        json_response(
2054                            200,
2055                            r#"{
2056                                "issuer": "https://evil.example.com",
2057                                "authorization_endpoint": "https://evil.example.com/authorize",
2058                                "token_endpoint": "https://evil.example.com/token",
2059                                "code_challenge_methods_supported": ["S256"]
2060                            }"#,
2061                        )
2062                    } else {
2063                        json_response(404, "{}")
2064                    }
2065                })
2066            });
2067
2068            let issuer = Url::parse("https://auth.example.com").unwrap();
2069            let result = fetch_auth_server_metadata(&client, &issuer).await;
2070
2071            assert!(result.is_err());
2072            let err_msg = result.unwrap_err().to_string();
2073            assert!(
2074                err_msg.contains("issuer mismatch"),
2075                "unexpected error: {}",
2076                err_msg
2077            );
2078        });
2079    }
2080
2081    // -- Full discover integration tests -------------------------------------
2082
2083    #[test]
2084    fn test_full_discover_with_cimd() {
2085        smol::block_on(async {
2086            let client = make_fake_http_client(|req| {
2087                Box::pin(async move {
2088                    let uri = req.uri().to_string();
2089                    if uri.contains("oauth-protected-resource") {
2090                        json_response(
2091                            200,
2092                            r#"{
2093                                "resource": "https://mcp.example.com",
2094                                "authorization_servers": ["https://auth.example.com"],
2095                                "scopes_supported": ["mcp:read"]
2096                            }"#,
2097                        )
2098                    } else if uri.contains("oauth-authorization-server") {
2099                        json_response(
2100                            200,
2101                            r#"{
2102                                "issuer": "https://auth.example.com",
2103                                "authorization_endpoint": "https://auth.example.com/authorize",
2104                                "token_endpoint": "https://auth.example.com/token",
2105                                "code_challenge_methods_supported": ["S256"],
2106                                "client_id_metadata_document_supported": true
2107                            }"#,
2108                        )
2109                    } else {
2110                        json_response(404, "{}")
2111                    }
2112                })
2113            });
2114
2115            let server_url = Url::parse("https://mcp.example.com").unwrap();
2116            let www_auth = WwwAuthenticate {
2117                resource_metadata: None,
2118                scope: None,
2119                error: None,
2120                error_description: None,
2121            };
2122
2123            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2124            let registration =
2125                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2126                    .await
2127                    .unwrap();
2128
2129            assert_eq!(registration.client_id, CIMD_URL);
2130            assert_eq!(registration.client_secret, None);
2131            assert_eq!(discovery.scopes, vec!["mcp:read"]);
2132        });
2133    }
2134
2135    #[test]
2136    fn test_full_discover_with_dcr_fallback() {
2137        smol::block_on(async {
2138            let client = make_fake_http_client(|req| {
2139                Box::pin(async move {
2140                    let uri = req.uri().to_string();
2141                    if uri.contains("oauth-protected-resource") {
2142                        json_response(
2143                            200,
2144                            r#"{
2145                                "resource": "https://mcp.example.com",
2146                                "authorization_servers": ["https://auth.example.com"]
2147                            }"#,
2148                        )
2149                    } else if uri.contains("oauth-authorization-server") {
2150                        json_response(
2151                            200,
2152                            r#"{
2153                                "issuer": "https://auth.example.com",
2154                                "authorization_endpoint": "https://auth.example.com/authorize",
2155                                "token_endpoint": "https://auth.example.com/token",
2156                                "registration_endpoint": "https://auth.example.com/register",
2157                                "code_challenge_methods_supported": ["S256"],
2158                                "client_id_metadata_document_supported": false
2159                            }"#,
2160                        )
2161                    } else if uri.contains("/register") {
2162                        json_response(
2163                            201,
2164                            r#"{
2165                                "client_id": "dcr-minted-id-123",
2166                                "client_secret": "dcr-secret-456"
2167                            }"#,
2168                        )
2169                    } else {
2170                        json_response(404, "{}")
2171                    }
2172                })
2173            });
2174
2175            let server_url = Url::parse("https://mcp.example.com").unwrap();
2176            let www_auth = WwwAuthenticate {
2177                resource_metadata: None,
2178                scope: Some(vec!["files:read".into()]),
2179                error: None,
2180                error_description: None,
2181            };
2182
2183            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2184            let registration =
2185                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2186                    .await
2187                    .unwrap();
2188
2189            assert_eq!(registration.client_id, "dcr-minted-id-123");
2190            assert_eq!(
2191                registration.client_secret.as_deref(),
2192                Some("dcr-secret-456")
2193            );
2194            assert_eq!(discovery.scopes, vec!["files:read"]);
2195        });
2196    }
2197
2198    #[test]
2199    fn test_discover_fails_without_pkce_support() {
2200        smol::block_on(async {
2201            let client = make_fake_http_client(|req| {
2202                Box::pin(async move {
2203                    let uri = req.uri().to_string();
2204                    if uri.contains("oauth-protected-resource") {
2205                        json_response(
2206                            200,
2207                            r#"{
2208                                "resource": "https://mcp.example.com",
2209                                "authorization_servers": ["https://auth.example.com"]
2210                            }"#,
2211                        )
2212                    } else if uri.contains("oauth-authorization-server") {
2213                        json_response(
2214                            200,
2215                            r#"{
2216                                "issuer": "https://auth.example.com",
2217                                "authorization_endpoint": "https://auth.example.com/authorize",
2218                                "token_endpoint": "https://auth.example.com/token"
2219                            }"#,
2220                        )
2221                    } else {
2222                        json_response(404, "{}")
2223                    }
2224                })
2225            });
2226
2227            let server_url = Url::parse("https://mcp.example.com").unwrap();
2228            let www_auth = WwwAuthenticate {
2229                resource_metadata: None,
2230                scope: None,
2231                error: None,
2232                error_description: None,
2233            };
2234
2235            let result = discover(&client, &server_url, &www_auth).await;
2236            assert!(result.is_err());
2237            let err_msg = result.unwrap_err().to_string();
2238            assert!(
2239                err_msg.contains("code_challenge_methods_supported"),
2240                "unexpected error: {}",
2241                err_msg
2242            );
2243        });
2244    }
2245
2246    // -- Token exchange integration tests ------------------------------------
2247
2248    #[test]
2249    fn test_exchange_code_success() {
2250        smol::block_on(async {
2251            let client = make_fake_http_client(|req| {
2252                Box::pin(async move {
2253                    let uri = req.uri().to_string();
2254                    if uri.contains("/token") {
2255                        json_response(
2256                            200,
2257                            r#"{
2258                                "access_token": "new_access_token",
2259                                "refresh_token": "new_refresh_token",
2260                                "expires_in": 3600,
2261                                "token_type": "Bearer"
2262                            }"#,
2263                        )
2264                    } else {
2265                        json_response(404, "{}")
2266                    }
2267                })
2268            });
2269
2270            let metadata = AuthServerMetadata {
2271                issuer: Url::parse("https://auth.example.com").unwrap(),
2272                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2273                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2274                registration_endpoint: None,
2275                scopes_supported: None,
2276                code_challenge_methods_supported: Some(vec!["S256".into()]),
2277                client_id_metadata_document_supported: true,
2278            };
2279
2280            let tokens = exchange_code(
2281                &client,
2282                &metadata,
2283                "auth_code_123",
2284                CIMD_URL,
2285                "http://127.0.0.1:9999/callback",
2286                "verifier_abc",
2287                "https://mcp.example.com",
2288            )
2289            .await
2290            .unwrap();
2291
2292            assert_eq!(tokens.access_token, "new_access_token");
2293            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2294            assert!(tokens.expires_at.is_some());
2295        });
2296    }
2297
2298    #[test]
2299    fn test_refresh_tokens_success() {
2300        smol::block_on(async {
2301            let client = make_fake_http_client(|req| {
2302                Box::pin(async move {
2303                    let uri = req.uri().to_string();
2304                    if uri.contains("/token") {
2305                        json_response(
2306                            200,
2307                            r#"{
2308                                "access_token": "refreshed_token",
2309                                "expires_in": 1800,
2310                                "token_type": "Bearer"
2311                            }"#,
2312                        )
2313                    } else {
2314                        json_response(404, "{}")
2315                    }
2316                })
2317            });
2318
2319            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2320
2321            let tokens = refresh_tokens(
2322                &client,
2323                &token_endpoint,
2324                "old_refresh_token",
2325                CIMD_URL,
2326                "https://mcp.example.com",
2327            )
2328            .await
2329            .unwrap();
2330
2331            assert_eq!(tokens.access_token, "refreshed_token");
2332            assert_eq!(tokens.refresh_token, None);
2333            assert!(tokens.expires_at.is_some());
2334        });
2335    }
2336
2337    #[test]
2338    fn test_exchange_code_failure() {
2339        smol::block_on(async {
2340            let client = make_fake_http_client(|_req| {
2341                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2342            });
2343
2344            let metadata = AuthServerMetadata {
2345                issuer: Url::parse("https://auth.example.com").unwrap(),
2346                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2347                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2348                registration_endpoint: None,
2349                scopes_supported: None,
2350                code_challenge_methods_supported: Some(vec!["S256".into()]),
2351                client_id_metadata_document_supported: true,
2352            };
2353
2354            let result = exchange_code(
2355                &client,
2356                &metadata,
2357                "bad_code",
2358                "client",
2359                "http://127.0.0.1:1/callback",
2360                "verifier",
2361                "https://mcp.example.com",
2362            )
2363            .await;
2364
2365            assert!(result.is_err());
2366            assert!(result.unwrap_err().to_string().contains("400"));
2367        });
2368    }
2369
2370    // -- DCR integration tests -----------------------------------------------
2371
2372    #[test]
2373    fn test_perform_dcr() {
2374        smol::block_on(async {
2375            let client = make_fake_http_client(|_req| {
2376                Box::pin(async move {
2377                    json_response(
2378                        201,
2379                        r#"{
2380                            "client_id": "dynamic-client-001",
2381                            "client_secret": "dynamic-secret-001"
2382                        }"#,
2383                    )
2384                })
2385            });
2386
2387            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2388            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2389                .await
2390                .unwrap();
2391
2392            assert_eq!(registration.client_id, "dynamic-client-001");
2393            assert_eq!(
2394                registration.client_secret.as_deref(),
2395                Some("dynamic-secret-001")
2396            );
2397        });
2398    }
2399
2400    #[test]
2401    fn test_perform_dcr_failure() {
2402        smol::block_on(async {
2403            let client = make_fake_http_client(|_req| {
2404                Box::pin(
2405                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2406                )
2407            });
2408
2409            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2410            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2411
2412            assert!(result.is_err());
2413            assert!(result.unwrap_err().to_string().contains("403"));
2414        });
2415    }
2416
2417    // -- OAuthCallback parse tests -------------------------------------------
2418
2419    #[test]
2420    fn test_oauth_callback_parse_query() {
2421        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2422        assert_eq!(callback.code, "test_auth_code");
2423        assert_eq!(callback.state, "test_state");
2424    }
2425
2426    #[test]
2427    fn test_oauth_callback_parse_query_reversed_order() {
2428        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2429        assert_eq!(callback.code, "test_auth_code");
2430        assert_eq!(callback.state, "test_state");
2431    }
2432
2433    #[test]
2434    fn test_oauth_callback_parse_query_with_extra_params() {
2435        let callback =
2436            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2437                .unwrap();
2438        assert_eq!(callback.code, "test_auth_code");
2439        assert_eq!(callback.state, "test_state");
2440    }
2441
2442    #[test]
2443    fn test_oauth_callback_parse_query_missing_code() {
2444        let result = OAuthCallback::parse_query("state=test_state");
2445        assert!(result.is_err());
2446        assert!(result.unwrap_err().to_string().contains("code"));
2447    }
2448
2449    #[test]
2450    fn test_oauth_callback_parse_query_missing_state() {
2451        let result = OAuthCallback::parse_query("code=test_auth_code");
2452        assert!(result.is_err());
2453        assert!(result.unwrap_err().to_string().contains("state"));
2454    }
2455
2456    #[test]
2457    fn test_oauth_callback_parse_query_empty_code() {
2458        let result = OAuthCallback::parse_query("code=&state=test_state");
2459        assert!(result.is_err());
2460    }
2461
2462    #[test]
2463    fn test_oauth_callback_parse_query_empty_state() {
2464        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2465        assert!(result.is_err());
2466    }
2467
2468    #[test]
2469    fn test_oauth_callback_parse_query_url_encoded_values() {
2470        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2471        assert_eq!(callback.code, "abc def");
2472        assert_eq!(callback.state, "test=state");
2473    }
2474
2475    #[test]
2476    fn test_oauth_callback_parse_query_error_response() {
2477        let result = OAuthCallback::parse_query(
2478            "error=access_denied&error_description=User%20denied%20access&state=abc",
2479        );
2480        assert!(result.is_err());
2481        let err_msg = result.unwrap_err().to_string();
2482        assert!(
2483            err_msg.contains("access_denied"),
2484            "unexpected error: {}",
2485            err_msg
2486        );
2487        assert!(
2488            err_msg.contains("User denied access"),
2489            "unexpected error: {}",
2490            err_msg
2491        );
2492    }
2493
2494    #[test]
2495    fn test_oauth_callback_parse_query_error_without_description() {
2496        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2497        assert!(result.is_err());
2498        let err_msg = result.unwrap_err().to_string();
2499        assert!(
2500            err_msg.contains("server_error"),
2501            "unexpected error: {}",
2502            err_msg
2503        );
2504        assert!(
2505            err_msg.contains("no description"),
2506            "unexpected error: {}",
2507            err_msg
2508        );
2509    }
2510
2511    // -- McpOAuthTokenProvider tests -----------------------------------------
2512
2513    fn make_test_session(
2514        access_token: &str,
2515        refresh_token: Option<&str>,
2516        expires_at: Option<SystemTime>,
2517    ) -> OAuthSession {
2518        OAuthSession {
2519            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2520            resource: Url::parse("https://mcp.example.com").unwrap(),
2521            client_registration: OAuthClientRegistration {
2522                client_id: "test-client".into(),
2523                client_secret: None,
2524            },
2525            tokens: OAuthTokens {
2526                access_token: access_token.into(),
2527                refresh_token: refresh_token.map(String::from),
2528                expires_at,
2529            },
2530        }
2531    }
2532
2533    #[test]
2534    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2535        let expired = SystemTime::now() - Duration::from_secs(60);
2536        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2537        let provider = McpOAuthTokenProvider::new(
2538            session,
2539            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2540            None,
2541        );
2542
2543        assert_eq!(provider.access_token(), None);
2544    }
2545
2546    #[test]
2547    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2548        let far_future = SystemTime::now() + Duration::from_secs(3600);
2549        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2550        let provider = McpOAuthTokenProvider::new(
2551            session,
2552            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2553            None,
2554        );
2555
2556        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2557    }
2558
2559    #[test]
2560    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2561        let session = make_test_session("no-expiry-token", Some("rt"), None);
2562        let provider = McpOAuthTokenProvider::new(
2563            session,
2564            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2565            None,
2566        );
2567
2568        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2569    }
2570
2571    #[test]
2572    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2573        smol::block_on(async {
2574            let session = make_test_session("token", None, None);
2575            let provider = McpOAuthTokenProvider::new(
2576                session,
2577                make_fake_http_client(|_| {
2578                    Box::pin(async { unreachable!("no HTTP call expected") })
2579                }),
2580                None,
2581            );
2582
2583            let refreshed = provider.try_refresh().await.unwrap();
2584            assert!(!refreshed);
2585        });
2586    }
2587
2588    #[test]
2589    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2590        smol::block_on(async {
2591            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2592            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2593
2594            let http_client = make_fake_http_client(|_req| {
2595                Box::pin(async {
2596                    json_response(
2597                        200,
2598                        r#"{
2599                            "access_token": "new-access",
2600                            "refresh_token": "new-refresh",
2601                            "expires_in": 1800
2602                        }"#,
2603                    )
2604                })
2605            });
2606
2607            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2608
2609            let refreshed = provider.try_refresh().await.unwrap();
2610            assert!(refreshed);
2611            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2612
2613            let notified_session = rx.try_recv().expect("channel should have a session");
2614            assert_eq!(notified_session.tokens.access_token, "new-access");
2615            assert_eq!(
2616                notified_session.tokens.refresh_token.as_deref(),
2617                Some("new-refresh")
2618            );
2619        });
2620    }
2621
2622    #[test]
2623    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2624        smol::block_on(async {
2625            let session = make_test_session("old-access", Some("original-refresh"), None);
2626            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2627
2628            let http_client = make_fake_http_client(|_req| {
2629                Box::pin(async {
2630                    json_response(
2631                        200,
2632                        r#"{
2633                            "access_token": "new-access",
2634                            "expires_in": 900
2635                        }"#,
2636                    )
2637                })
2638            });
2639
2640            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2641
2642            let refreshed = provider.try_refresh().await.unwrap();
2643            assert!(refreshed);
2644
2645            let notified_session = rx.try_recv().expect("channel should have a session");
2646            assert_eq!(notified_session.tokens.access_token, "new-access");
2647            assert_eq!(
2648                notified_session.tokens.refresh_token.as_deref(),
2649                Some("original-refresh"),
2650            );
2651        });
2652    }
2653
2654    #[test]
2655    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2656        smol::block_on(async {
2657            let session = make_test_session("old-access", Some("my-refresh"), None);
2658
2659            let http_client = make_fake_http_client(|_req| {
2660                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2661            });
2662
2663            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2664
2665            let refreshed = provider.try_refresh().await.unwrap();
2666            assert!(!refreshed);
2667            // The old token should still be in place.
2668            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2669        });
2670    }
2671}