oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
  36/// The CIMD URL where Zed's OAuth client metadata document is hosted.
  37pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
 131/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
 132/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
 133#[derive(Debug, Clone, Serialize, Deserialize)]
 134pub struct ProtectedResourceMetadata {
 135    pub resource: Url,
 136    pub authorization_servers: Vec<Url>,
 137    pub scopes_supported: Option<Vec<String>>,
 138}
 139
 140/// Parsed from the authorization server's .well-known endpoint
 141/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
 142#[derive(Debug, Clone, Serialize, Deserialize)]
 143pub struct AuthServerMetadata {
 144    pub issuer: Url,
 145    pub authorization_endpoint: Url,
 146    pub token_endpoint: Url,
 147    pub registration_endpoint: Option<Url>,
 148    pub scopes_supported: Option<Vec<String>>,
 149    pub code_challenge_methods_supported: Option<Vec<String>>,
 150    pub client_id_metadata_document_supported: bool,
 151}
 152
 153/// The result of client registration — either CIMD or DCR.
 154#[derive(Clone, Serialize, Deserialize)]
 155pub struct OAuthClientRegistration {
 156    pub client_id: String,
 157    /// Only present for DCR-minted registrations.
 158    pub client_secret: Option<String>,
 159}
 160
 161impl std::fmt::Debug for OAuthClientRegistration {
 162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 163        f.debug_struct("OAuthClientRegistration")
 164            .field("client_id", &self.client_id)
 165            .field(
 166                "client_secret",
 167                &self.client_secret.as_ref().map(|_| "[redacted]"),
 168            )
 169            .finish()
 170    }
 171}
 172
 173/// Access and refresh tokens obtained from the token endpoint.
 174#[derive(Clone, Serialize, Deserialize)]
 175pub struct OAuthTokens {
 176    pub access_token: String,
 177    pub refresh_token: Option<String>,
 178    pub expires_at: Option<SystemTime>,
 179}
 180
 181impl std::fmt::Debug for OAuthTokens {
 182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 183        f.debug_struct("OAuthTokens")
 184            .field("access_token", &"[redacted]")
 185            .field(
 186                "refresh_token",
 187                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 188            )
 189            .field("expires_at", &self.expires_at)
 190            .finish()
 191    }
 192}
 193
 194/// Everything discovered before the browser flow starts. Client registration is
 195/// resolved separately, once the real redirect URI is known.
 196#[derive(Debug, Clone, Serialize, Deserialize)]
 197pub struct OAuthDiscovery {
 198    pub resource_metadata: ProtectedResourceMetadata,
 199    pub auth_server_metadata: AuthServerMetadata,
 200    pub scopes: Vec<String>,
 201}
 202
 203/// The persisted OAuth session for a context server.
 204///
 205/// Stored in the keychain so startup can restore a refresh-capable provider
 206/// without another browser flow. Deliberately excludes the full discovery
 207/// metadata to keep the serialized size well within keychain item limits.
 208#[derive(Clone, Serialize, Deserialize)]
 209pub struct OAuthSession {
 210    pub token_endpoint: Url,
 211    pub resource: Url,
 212    pub client_registration: OAuthClientRegistration,
 213    pub tokens: OAuthTokens,
 214}
 215
 216impl std::fmt::Debug for OAuthSession {
 217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 218        f.debug_struct("OAuthSession")
 219            .field("token_endpoint", &self.token_endpoint)
 220            .field("resource", &self.resource)
 221            .field("client_registration", &self.client_registration)
 222            .field("tokens", &self.tokens)
 223            .finish()
 224    }
 225}
 226
 227/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 228#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 229pub enum BearerError {
 230    /// The request is missing a required parameter, includes an unsupported
 231    /// parameter or parameter value, or is otherwise malformed.
 232    InvalidRequest,
 233    /// The access token provided is expired, revoked, malformed, or invalid.
 234    InvalidToken,
 235    /// The request requires higher privileges than provided by the access token.
 236    InsufficientScope,
 237    /// An unrecognized error code (extension or future spec addition).
 238    Other,
 239}
 240
 241impl BearerError {
 242    fn parse(value: &str) -> Self {
 243        match value {
 244            "invalid_request" => BearerError::InvalidRequest,
 245            "invalid_token" => BearerError::InvalidToken,
 246            "insufficient_scope" => BearerError::InsufficientScope,
 247            _ => BearerError::Other,
 248        }
 249    }
 250}
 251
 252/// Fields extracted from a `WWW-Authenticate: Bearer` header.
 253///
 254/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
 255/// at the Protected Resource Metadata document. The optional `scope` parameter
 256/// (RFC 6750 Section 3) indicates scopes required for the request.
 257#[derive(Debug, Clone, PartialEq, Eq)]
 258pub struct WwwAuthenticate {
 259    pub resource_metadata: Option<Url>,
 260    pub scope: Option<Vec<String>>,
 261    /// The parsed `error` parameter per RFC 6750 Section 3.1.
 262    pub error: Option<BearerError>,
 263    pub error_description: Option<String>,
 264}
 265
 266/// Parse a `WWW-Authenticate` header value.
 267///
 268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 270/// - `resource_metadata` — URL of the Protected Resource Metadata document
 271/// - `scope` — space-separated list of required scopes
 272/// - `error` — error code (e.g. "insufficient_scope")
 273/// - `error_description` — human-readable error description
 274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 275    let header = header.trim();
 276
 277    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 278        header[6..].trim()
 279    } else {
 280        bail!("WWW-Authenticate header does not use Bearer scheme");
 281    };
 282
 283    if params_str.is_empty() {
 284        return Ok(WwwAuthenticate {
 285            resource_metadata: None,
 286            scope: None,
 287            error: None,
 288            error_description: None,
 289        });
 290    }
 291
 292    let params = parse_auth_params(params_str);
 293
 294    let resource_metadata = params
 295        .get("resource_metadata")
 296        .map(|v| Url::parse(v))
 297        .transpose()
 298        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 299
 300    let scope = params
 301        .get("scope")
 302        .map(|v| v.split_whitespace().map(String::from).collect());
 303
 304    let error = params.get("error").map(|v| BearerError::parse(v));
 305    let error_description = params.get("error_description").cloned();
 306
 307    Ok(WwwAuthenticate {
 308        resource_metadata,
 309        scope,
 310        error,
 311        error_description,
 312    })
 313}
 314
 315/// Parse comma-separated `key="value"` or `key=token` parameters from an
 316/// auth-param list (RFC 7235 Section 2.1).
 317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
 318    let mut params = collections::HashMap::default();
 319    let mut remaining = input.trim();
 320
 321    while !remaining.is_empty() {
 322        // Skip leading whitespace and commas.
 323        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
 324        if remaining.is_empty() {
 325            break;
 326        }
 327
 328        // Find the key (everything before '=').
 329        let eq_pos = match remaining.find('=') {
 330            Some(pos) => pos,
 331            None => break,
 332        };
 333
 334        let key = remaining[..eq_pos].trim().to_lowercase();
 335        remaining = &remaining[eq_pos + 1..];
 336        remaining = remaining.trim_start();
 337
 338        // Parse the value: either quoted or unquoted (token).
 339        let value;
 340        if remaining.starts_with('"') {
 341            // Quoted string: find the closing quote, handling escaped chars.
 342            remaining = &remaining[1..]; // skip opening quote
 343            let mut val = String::new();
 344            let mut chars = remaining.char_indices();
 345            loop {
 346                match chars.next() {
 347                    Some((_, '\\')) => {
 348                        // Escaped character — take the next char literally.
 349                        if let Some((_, c)) = chars.next() {
 350                            val.push(c);
 351                        }
 352                    }
 353                    Some((i, '"')) => {
 354                        remaining = &remaining[i + 1..];
 355                        break;
 356                    }
 357                    Some((_, c)) => val.push(c),
 358                    None => {
 359                        remaining = "";
 360                        break;
 361                    }
 362                }
 363            }
 364            value = val;
 365        } else {
 366            // Unquoted token: read until comma or whitespace.
 367            let end = remaining
 368                .find(|c: char| c == ',' || c.is_whitespace())
 369                .unwrap_or(remaining.len());
 370            value = remaining[..end].to_string();
 371            remaining = &remaining[end..];
 372        }
 373
 374        if !key.is_empty() {
 375            params.insert(key, value);
 376        }
 377    }
 378
 379    params
 380}
 381
 382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 383/// server URL, per RFC 9728 Section 3.
 384///
 385/// Returns URIs in priority order:
 386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 389    let mut urls = Vec::new();
 390    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 391
 392    let path = server_url.path().trim_start_matches('/');
 393    if !path.is_empty() {
 394        if let Ok(url) = Url::parse(&format!(
 395            "{}/.well-known/oauth-protected-resource/{}",
 396            base, path
 397        )) {
 398            urls.push(url);
 399        }
 400    }
 401
 402    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 403        urls.push(url);
 404    }
 405
 406    urls
 407}
 408
 409/// Construct the well-known Authorization Server Metadata URIs for a given
 410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 411///
 412/// Returns URIs in priority order, which differs depending on whether the
 413/// issuer URL has a path component.
 414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 415    let mut urls = Vec::new();
 416    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 417    let path = issuer.path().trim_matches('/');
 418
 419    if !path.is_empty() {
 420        // Issuer with path: try path-inserted variants first.
 421        if let Ok(url) = Url::parse(&format!(
 422            "{}/.well-known/oauth-authorization-server/{}",
 423            base, path
 424        )) {
 425            urls.push(url);
 426        }
 427        if let Ok(url) = Url::parse(&format!(
 428            "{}/.well-known/openid-configuration/{}",
 429            base, path
 430        )) {
 431            urls.push(url);
 432        }
 433        if let Ok(url) = Url::parse(&format!(
 434            "{}/{}/.well-known/openid-configuration",
 435            base, path
 436        )) {
 437            urls.push(url);
 438        }
 439    } else {
 440        // No path: standard well-known locations.
 441        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 442            urls.push(url);
 443        }
 444        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 445            urls.push(url);
 446        }
 447    }
 448
 449    urls
 450}
 451
 452// -- Canonical server URI (RFC 8707) -----------------------------------------
 453
 454/// Derive the canonical resource URI for an MCP server URL, suitable for the
 455/// `resource` parameter in authorization and token requests per RFC 8707.
 456///
 457/// Lowercases the scheme and host, preserves the path (without trailing slash),
 458/// strips fragments and query strings.
 459pub fn canonical_server_uri(server_url: &Url) -> String {
 460    let mut uri = format!(
 461        "{}://{}",
 462        server_url.scheme().to_ascii_lowercase(),
 463        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 464    );
 465    if let Some(port) = server_url.port() {
 466        uri.push_str(&format!(":{}", port));
 467    }
 468    let path = server_url.path();
 469    if path != "/" {
 470        uri.push_str(path.trim_end_matches('/'));
 471    }
 472    uri
 473}
 474
 475// -- Scope selection ---------------------------------------------------------
 476
 477/// Select scopes following the MCP spec's Scope Selection Strategy:
 478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 480/// 3. Return empty if neither is available.
 481pub fn select_scopes(
 482    www_authenticate: &WwwAuthenticate,
 483    resource_metadata: &ProtectedResourceMetadata,
 484) -> Vec<String> {
 485    if let Some(ref scopes) = www_authenticate.scope {
 486        if !scopes.is_empty() {
 487            return scopes.clone();
 488        }
 489    }
 490    resource_metadata
 491        .scopes_supported
 492        .clone()
 493        .unwrap_or_default()
 494}
 495
 496// -- Client registration strategy --------------------------------------------
 497
 498/// The registration approach to use, determined from auth server metadata.
 499#[derive(Debug, Clone, PartialEq, Eq)]
 500pub enum ClientRegistrationStrategy {
 501    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
 502    Cimd { client_id: String },
 503    /// The auth server has a registration endpoint. Caller must POST to it.
 504    Dcr { registration_endpoint: Url },
 505    /// No supported registration mechanism.
 506    Unavailable,
 507}
 508
 509/// Determine how to register with the authorization server, following the
 510/// spec's recommended priority: CIMD first, DCR fallback.
 511pub fn determine_registration_strategy(
 512    auth_server_metadata: &AuthServerMetadata,
 513) -> ClientRegistrationStrategy {
 514    if auth_server_metadata.client_id_metadata_document_supported {
 515        ClientRegistrationStrategy::Cimd {
 516            client_id: CIMD_URL.to_string(),
 517        }
 518    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 519        ClientRegistrationStrategy::Dcr {
 520            registration_endpoint: endpoint.clone(),
 521        }
 522    } else {
 523        ClientRegistrationStrategy::Unavailable
 524    }
 525}
 526
 527// -- PKCE (RFC 7636) ---------------------------------------------------------
 528
 529/// A PKCE code verifier and its S256 challenge.
 530#[derive(Clone)]
 531pub struct PkceChallenge {
 532    pub verifier: String,
 533    pub challenge: String,
 534}
 535
 536impl std::fmt::Debug for PkceChallenge {
 537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 538        f.debug_struct("PkceChallenge")
 539            .field("verifier", &"[redacted]")
 540            .field("challenge", &self.challenge)
 541            .finish()
 542    }
 543}
 544
 545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 546///
 547/// The verifier is 43 base64url characters derived from 32 random bytes.
 548/// The challenge is `BASE64URL(SHA256(verifier))`.
 549pub fn generate_pkce_challenge() -> PkceChallenge {
 550    let mut random_bytes = [0u8; 32];
 551    rand::rng().fill(&mut random_bytes);
 552    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 553    let verifier = engine.encode(&random_bytes);
 554
 555    let digest = Sha256::digest(verifier.as_bytes());
 556    let challenge = engine.encode(digest);
 557
 558    PkceChallenge {
 559        verifier,
 560        challenge,
 561    }
 562}
 563
 564// -- Authorization URL construction ------------------------------------------
 565
 566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 567pub fn build_authorization_url(
 568    auth_server_metadata: &AuthServerMetadata,
 569    client_id: &str,
 570    redirect_uri: &str,
 571    scopes: &[String],
 572    resource: &str,
 573    pkce: &PkceChallenge,
 574    state: &str,
 575) -> Url {
 576    let mut url = auth_server_metadata.authorization_endpoint.clone();
 577    {
 578        let mut query = url.query_pairs_mut();
 579        query.append_pair("response_type", "code");
 580        query.append_pair("client_id", client_id);
 581        query.append_pair("redirect_uri", redirect_uri);
 582        if !scopes.is_empty() {
 583            query.append_pair("scope", &scopes.join(" "));
 584        }
 585        query.append_pair("resource", resource);
 586        query.append_pair("code_challenge", &pkce.challenge);
 587        query.append_pair("code_challenge_method", "S256");
 588        query.append_pair("state", state);
 589    }
 590    url
 591}
 592
 593// -- Token endpoint request bodies -------------------------------------------
 594
 595/// The JSON body returned by the token endpoint on success.
 596#[derive(Deserialize)]
 597pub struct TokenResponse {
 598    pub access_token: String,
 599    #[serde(default)]
 600    pub refresh_token: Option<String>,
 601    #[serde(default)]
 602    pub expires_in: Option<u64>,
 603    #[serde(default)]
 604    pub token_type: Option<String>,
 605}
 606
 607impl std::fmt::Debug for TokenResponse {
 608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 609        f.debug_struct("TokenResponse")
 610            .field("access_token", &"[redacted]")
 611            .field(
 612                "refresh_token",
 613                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 614            )
 615            .field("expires_in", &self.expires_in)
 616            .field("token_type", &self.token_type)
 617            .finish()
 618    }
 619}
 620
 621impl TokenResponse {
 622    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 623    pub fn into_tokens(self) -> OAuthTokens {
 624        let expires_at = self
 625            .expires_in
 626            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 627        OAuthTokens {
 628            access_token: self.access_token,
 629            refresh_token: self.refresh_token,
 630            expires_at,
 631        }
 632    }
 633}
 634
 635/// Build the form-encoded body for an authorization code token exchange.
 636pub fn token_exchange_params(
 637    code: &str,
 638    client_id: &str,
 639    redirect_uri: &str,
 640    code_verifier: &str,
 641    resource: &str,
 642) -> Vec<(&'static str, String)> {
 643    vec![
 644        ("grant_type", "authorization_code".to_string()),
 645        ("code", code.to_string()),
 646        ("redirect_uri", redirect_uri.to_string()),
 647        ("client_id", client_id.to_string()),
 648        ("code_verifier", code_verifier.to_string()),
 649        ("resource", resource.to_string()),
 650    ]
 651}
 652
 653/// Build the form-encoded body for a token refresh request.
 654pub fn token_refresh_params(
 655    refresh_token: &str,
 656    client_id: &str,
 657    resource: &str,
 658) -> Vec<(&'static str, String)> {
 659    vec![
 660        ("grant_type", "refresh_token".to_string()),
 661        ("refresh_token", refresh_token.to_string()),
 662        ("client_id", client_id.to_string()),
 663        ("resource", resource.to_string()),
 664    ]
 665}
 666
 667// -- DCR request body (RFC 7591) ---------------------------------------------
 668
 669/// Build the JSON body for a Dynamic Client Registration request.
 670///
 671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 673/// redirect URI matching even for loopback addresses, so we register the
 674/// exact URI we intend to use.
 675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 676    serde_json::json!({
 677        "client_name": "Zed",
 678        "redirect_uris": [redirect_uri],
 679        "grant_types": ["authorization_code"],
 680        "response_types": ["code"],
 681        "token_endpoint_auth_method": "none"
 682    })
 683}
 684
 685// -- Discovery (async, hits real endpoints) ----------------------------------
 686
 687/// Fetch Protected Resource Metadata from the MCP server.
 688///
 689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 690/// then falls back to well-known URIs constructed from `server_url`.
 691pub async fn fetch_protected_resource_metadata(
 692    http_client: &Arc<dyn HttpClient>,
 693    server_url: &Url,
 694    www_authenticate: &WwwAuthenticate,
 695) -> Result<ProtectedResourceMetadata> {
 696    let candidate_urls = match &www_authenticate.resource_metadata {
 697        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 698        Some(url) => {
 699            log::warn!(
 700                "Ignoring cross-origin resource_metadata URL {} \
 701                 (server origin: {})",
 702                url,
 703                server_url.origin().unicode_serialization()
 704            );
 705            protected_resource_metadata_urls(server_url)
 706        }
 707        None => protected_resource_metadata_urls(server_url),
 708    };
 709
 710    for url in &candidate_urls {
 711        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 712            Ok(response) => {
 713                if response.authorization_servers.is_empty() {
 714                    bail!(
 715                        "Protected Resource Metadata at {} has no authorization_servers",
 716                        url
 717                    );
 718                }
 719                return Ok(ProtectedResourceMetadata {
 720                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 721                    authorization_servers: response.authorization_servers,
 722                    scopes_supported: response.scopes_supported,
 723                });
 724            }
 725            Err(err) => {
 726                log::debug!(
 727                    "Failed to fetch Protected Resource Metadata from {}: {}",
 728                    url,
 729                    err
 730                );
 731            }
 732        }
 733    }
 734
 735    bail!(
 736        "Could not fetch Protected Resource Metadata for {}",
 737        server_url
 738    )
 739}
 740
 741/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 742/// endpoints in the priority order specified by the MCP spec.
 743pub async fn fetch_auth_server_metadata(
 744    http_client: &Arc<dyn HttpClient>,
 745    issuer: &Url,
 746) -> Result<AuthServerMetadata> {
 747    let candidate_urls = auth_server_metadata_urls(issuer);
 748
 749    for url in &candidate_urls {
 750        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 751            Ok(response) => {
 752                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 753                if reported_issuer != *issuer {
 754                    bail!(
 755                        "Auth server metadata issuer mismatch: expected {}, got {}",
 756                        issuer,
 757                        reported_issuer
 758                    );
 759                }
 760
 761                return Ok(AuthServerMetadata {
 762                    issuer: reported_issuer,
 763                    authorization_endpoint: response
 764                        .authorization_endpoint
 765                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 766                    token_endpoint: response
 767                        .token_endpoint
 768                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 769                    registration_endpoint: response.registration_endpoint,
 770                    scopes_supported: response.scopes_supported,
 771                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 772                    client_id_metadata_document_supported: response
 773                        .client_id_metadata_document_supported
 774                        .unwrap_or(false),
 775                });
 776            }
 777            Err(err) => {
 778                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 779            }
 780        }
 781    }
 782
 783    bail!(
 784        "Could not fetch Authorization Server Metadata for {}",
 785        issuer
 786    )
 787}
 788
 789/// Run the full discovery flow: fetch resource metadata, then auth server
 790/// metadata, then select scopes. Client registration is resolved separately,
 791/// once the real redirect URI is known.
 792pub async fn discover(
 793    http_client: &Arc<dyn HttpClient>,
 794    server_url: &Url,
 795    www_authenticate: &WwwAuthenticate,
 796) -> Result<OAuthDiscovery> {
 797    let resource_metadata =
 798        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 799
 800    let auth_server_url = resource_metadata
 801        .authorization_servers
 802        .first()
 803        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 804
 805    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 806
 807    // Verify PKCE S256 support (spec requirement).
 808    match &auth_server_metadata.code_challenge_methods_supported {
 809        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 810        Some(_) => bail!("authorization server does not support S256 PKCE"),
 811        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 812    }
 813
 814    // Verify there is at least one supported registration strategy before we
 815    // present the server as ready to authenticate.
 816    match determine_registration_strategy(&auth_server_metadata) {
 817        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
 818        ClientRegistrationStrategy::Unavailable => {
 819            bail!("authorization server supports neither CIMD nor DCR")
 820        }
 821    }
 822
 823    let scopes = select_scopes(www_authenticate, &resource_metadata);
 824
 825    Ok(OAuthDiscovery {
 826        resource_metadata,
 827        auth_server_metadata,
 828        scopes,
 829    })
 830}
 831
 832/// Resolve the OAuth client registration for an authorization flow.
 833///
 834/// CIMD uses the static client metadata document directly. For DCR, a fresh
 835/// registration is performed each time because the loopback redirect URI
 836/// includes an ephemeral port that changes every flow.
 837pub async fn resolve_client_registration(
 838    http_client: &Arc<dyn HttpClient>,
 839    discovery: &OAuthDiscovery,
 840    redirect_uri: &str,
 841) -> Result<OAuthClientRegistration> {
 842    match determine_registration_strategy(&discovery.auth_server_metadata) {
 843        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 844            client_id,
 845            client_secret: None,
 846        }),
 847        ClientRegistrationStrategy::Dcr {
 848            registration_endpoint,
 849        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 850        ClientRegistrationStrategy::Unavailable => {
 851            bail!("authorization server supports neither CIMD nor DCR")
 852        }
 853    }
 854}
 855
 856// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 857
 858/// Perform Dynamic Client Registration with the authorization server.
 859pub async fn perform_dcr(
 860    http_client: &Arc<dyn HttpClient>,
 861    registration_endpoint: &Url,
 862    redirect_uri: &str,
 863) -> Result<OAuthClientRegistration> {
 864    validate_oauth_url(registration_endpoint)?;
 865
 866    let body = dcr_registration_body(redirect_uri);
 867    let body_bytes = serde_json::to_vec(&body)?;
 868
 869    let request = Request::builder()
 870        .method(http_client::http::Method::POST)
 871        .uri(registration_endpoint.as_str())
 872        .header("Content-Type", "application/json")
 873        .header("Accept", "application/json")
 874        .body(AsyncBody::from(body_bytes))?;
 875
 876    let mut response = http_client.send(request).await?;
 877
 878    if !response.status().is_success() {
 879        let mut error_body = String::new();
 880        response.body_mut().read_to_string(&mut error_body).await?;
 881        bail!(
 882            "DCR failed with status {}: {}",
 883            response.status(),
 884            error_body
 885        );
 886    }
 887
 888    let mut response_body = String::new();
 889    response
 890        .body_mut()
 891        .read_to_string(&mut response_body)
 892        .await?;
 893
 894    let dcr_response: DcrResponse =
 895        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 896
 897    Ok(OAuthClientRegistration {
 898        client_id: dcr_response.client_id,
 899        client_secret: dcr_response.client_secret,
 900    })
 901}
 902
 903// -- Token exchange and refresh (async) --------------------------------------
 904
 905/// Exchange an authorization code for tokens at the token endpoint.
 906pub async fn exchange_code(
 907    http_client: &Arc<dyn HttpClient>,
 908    auth_server_metadata: &AuthServerMetadata,
 909    code: &str,
 910    client_id: &str,
 911    redirect_uri: &str,
 912    code_verifier: &str,
 913    resource: &str,
 914) -> Result<OAuthTokens> {
 915    let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
 916    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 917}
 918
 919/// Refresh tokens using a refresh token.
 920pub async fn refresh_tokens(
 921    http_client: &Arc<dyn HttpClient>,
 922    token_endpoint: &Url,
 923    refresh_token: &str,
 924    client_id: &str,
 925    resource: &str,
 926) -> Result<OAuthTokens> {
 927    let params = token_refresh_params(refresh_token, client_id, resource);
 928    post_token_request(http_client, token_endpoint, &params).await
 929}
 930
 931/// POST form-encoded parameters to a token endpoint and parse the response.
 932async fn post_token_request(
 933    http_client: &Arc<dyn HttpClient>,
 934    token_endpoint: &Url,
 935    params: &[(&str, String)],
 936) -> Result<OAuthTokens> {
 937    validate_oauth_url(token_endpoint)?;
 938
 939    let body = url::form_urlencoded::Serializer::new(String::new())
 940        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 941        .finish();
 942
 943    let request = Request::builder()
 944        .method(http_client::http::Method::POST)
 945        .uri(token_endpoint.as_str())
 946        .header("Content-Type", "application/x-www-form-urlencoded")
 947        .header("Accept", "application/json")
 948        .body(AsyncBody::from(body.into_bytes()))?;
 949
 950    let mut response = http_client.send(request).await?;
 951
 952    if !response.status().is_success() {
 953        let mut error_body = String::new();
 954        response.body_mut().read_to_string(&mut error_body).await?;
 955        bail!(
 956            "token request failed with status {}: {}",
 957            response.status(),
 958            error_body
 959        );
 960    }
 961
 962    let mut response_body = String::new();
 963    response
 964        .body_mut()
 965        .read_to_string(&mut response_body)
 966        .await?;
 967
 968    let token_response: TokenResponse =
 969        serde_json::from_str(&response_body).context("failed to parse token response")?;
 970
 971    Ok(token_response.into_tokens())
 972}
 973
 974// -- Loopback HTTP callback server -------------------------------------------
 975
 976/// An OAuth authorization callback received via the loopback HTTP server.
 977pub struct OAuthCallback {
 978    pub code: String,
 979    pub state: String,
 980}
 981
 982impl std::fmt::Debug for OAuthCallback {
 983    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 984        f.debug_struct("OAuthCallback")
 985            .field("code", &"[redacted]")
 986            .field("state", &"[redacted]")
 987            .finish()
 988    }
 989}
 990
 991impl OAuthCallback {
 992    /// Parse the query string from a callback URL like
 993    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
 994    pub fn parse_query(query: &str) -> Result<Self> {
 995        let mut code: Option<String> = None;
 996        let mut state: Option<String> = None;
 997        let mut error: Option<String> = None;
 998        let mut error_description: Option<String> = None;
 999
1000        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1001            match key.as_ref() {
1002                "code" => {
1003                    if !value.is_empty() {
1004                        code = Some(value.into_owned());
1005                    }
1006                }
1007                "state" => {
1008                    if !value.is_empty() {
1009                        state = Some(value.into_owned());
1010                    }
1011                }
1012                "error" => {
1013                    if !value.is_empty() {
1014                        error = Some(value.into_owned());
1015                    }
1016                }
1017                "error_description" => {
1018                    if !value.is_empty() {
1019                        error_description = Some(value.into_owned());
1020                    }
1021                }
1022                _ => {}
1023            }
1024        }
1025
1026        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1027        // checking for missing code/state.
1028        if let Some(error_code) = error {
1029            bail!(
1030                "OAuth authorization failed: {} ({})",
1031                error_code,
1032                error_description.as_deref().unwrap_or("no description")
1033            );
1034        }
1035
1036        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1037        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1038
1039        Ok(Self { code, state })
1040    }
1041}
1042
1043/// How long to wait for the browser to complete the OAuth flow before giving
1044/// up and releasing the loopback port.
1045const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
1046
1047/// Start a loopback HTTP server to receive the OAuth authorization callback.
1048///
1049/// Binds to an ephemeral loopback port for each flow.
1050///
1051/// Returns `(redirect_uri, callback_future)`. The caller should use the
1052/// redirect URI in the authorization request, open the browser, then await
1053/// the future to receive the callback.
1054///
1055/// The server accepts exactly one request on `/callback`, validates that it
1056/// contains `code` and `state` query parameters, responds with a minimal
1057/// HTML page telling the user they can close the tab, and shuts down.
1058///
1059/// The callback server shuts down when the returned oneshot receiver is dropped
1060/// (e.g. because the authentication task was cancelled), or after a timeout
1061/// ([CALLBACK_TIMEOUT]).
1062pub async fn start_callback_server() -> Result<(
1063    String,
1064    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1065)> {
1066    let server = tiny_http::Server::http("127.0.0.1:0")
1067        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1068    let port = server
1069        .server_addr()
1070        .to_ip()
1071        .context("server not bound to a TCP address")?
1072        .port();
1073
1074    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1075
1076    let (tx, rx) = futures::channel::oneshot::channel();
1077
1078    // `tiny_http` is blocking, so we run it on a background thread.
1079    // The `recv_timeout` loop lets us check for cancellation (the receiver
1080    // being dropped) and enforce an overall timeout.
1081    std::thread::spawn(move || {
1082        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1083
1084        loop {
1085            if tx.is_canceled() {
1086                return;
1087            }
1088            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1089            if remaining.is_zero() {
1090                return;
1091            }
1092
1093            let timeout = remaining.min(Duration::from_millis(500));
1094            let Some(request) = (match server.recv_timeout(timeout) {
1095                Ok(req) => req,
1096                Err(_) => {
1097                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1098                    return;
1099                }
1100            }) else {
1101                // Timeout with no request — loop back and check cancellation.
1102                continue;
1103            };
1104
1105            let result = handle_callback_request(&request);
1106
1107            let (status_code, body) = match &result {
1108                Ok(_) => (
1109                    200,
1110                    "<html><body><h1>Authorization successful</h1>\
1111                     <p>You can close this tab and return to Zed.</p></body></html>",
1112                ),
1113                Err(err) => {
1114                    log::error!("OAuth callback error: {}", err);
1115                    (
1116                        400,
1117                        "<html><body><h1>Authorization failed</h1>\
1118                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
1119                    )
1120                }
1121            };
1122
1123            let response = tiny_http::Response::from_string(body)
1124                .with_status_code(status_code)
1125                .with_header(
1126                    tiny_http::Header::from_str("Content-Type: text/html")
1127                        .expect("failed to construct response header"),
1128                )
1129                .with_header(
1130                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1131                        .expect("failed to construct response header"),
1132                );
1133            request.respond(response).log_err();
1134
1135            let _ = tx.send(result);
1136            return;
1137        }
1138    });
1139
1140    Ok((redirect_uri, rx))
1141}
1142
1143/// Extract the `code` and `state` query parameters from an OAuth callback
1144/// request to `/callback`.
1145fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1146    let url = Url::parse(&format!("http://localhost{}", request.url()))
1147        .context("malformed callback request URL")?;
1148
1149    if url.path() != "/callback" {
1150        bail!("unexpected path in OAuth callback: {}", url.path());
1151    }
1152
1153    let query = url
1154        .query()
1155        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1156    OAuthCallback::parse_query(query)
1157}
1158
1159// -- JSON fetch helper -------------------------------------------------------
1160
1161async fn fetch_json<T: serde::de::DeserializeOwned>(
1162    http_client: &Arc<dyn HttpClient>,
1163    url: &Url,
1164) -> Result<T> {
1165    validate_oauth_url(url)?;
1166
1167    let request = Request::builder()
1168        .method(http_client::http::Method::GET)
1169        .uri(url.as_str())
1170        .header("Accept", "application/json")
1171        .body(AsyncBody::default())?;
1172
1173    let mut response = http_client.send(request).await?;
1174
1175    if !response.status().is_success() {
1176        bail!("HTTP {} fetching {}", response.status(), url);
1177    }
1178
1179    let mut body = String::new();
1180    response.body_mut().read_to_string(&mut body).await?;
1181    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1182}
1183
1184// -- Serde response types for discovery --------------------------------------
1185
1186#[derive(Debug, Deserialize)]
1187struct ProtectedResourceMetadataResponse {
1188    #[serde(default)]
1189    resource: Option<Url>,
1190    #[serde(default)]
1191    authorization_servers: Vec<Url>,
1192    #[serde(default)]
1193    scopes_supported: Option<Vec<String>>,
1194}
1195
1196#[derive(Debug, Deserialize)]
1197struct AuthServerMetadataResponse {
1198    #[serde(default)]
1199    issuer: Option<Url>,
1200    #[serde(default)]
1201    authorization_endpoint: Option<Url>,
1202    #[serde(default)]
1203    token_endpoint: Option<Url>,
1204    #[serde(default)]
1205    registration_endpoint: Option<Url>,
1206    #[serde(default)]
1207    scopes_supported: Option<Vec<String>>,
1208    #[serde(default)]
1209    code_challenge_methods_supported: Option<Vec<String>>,
1210    #[serde(default)]
1211    client_id_metadata_document_supported: Option<bool>,
1212}
1213
1214#[derive(Debug, Deserialize)]
1215struct DcrResponse {
1216    client_id: String,
1217    #[serde(default)]
1218    client_secret: Option<String>,
1219}
1220
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementors must be `Send + Sync` (supertraits of this trait).
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// `None` means the provider has no token it considers usable right now.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token was obtained and the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1234
1235/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
1236/// an HTTP client for token refresh. The same provider type is used both after
1237/// an interactive authentication flow and when restoring a saved session from
1238/// the keychain on startup.
1239pub struct McpOAuthTokenProvider {
1240    session: SyncMutex<OAuthSession>,
1241    http_client: Arc<dyn HttpClient>,
1242    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1243}
1244
1245impl McpOAuthTokenProvider {
1246    pub fn new(
1247        session: OAuthSession,
1248        http_client: Arc<dyn HttpClient>,
1249        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1250    ) -> Self {
1251        Self {
1252            session: SyncMutex::new(session),
1253            http_client,
1254            token_refresh_tx,
1255        }
1256    }
1257
1258    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1259        tokens.expires_at.is_some_and(|expires_at| {
1260            SystemTime::now()
1261                .checked_add(Duration::from_secs(30))
1262                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1263        })
1264    }
1265}
1266
1267#[async_trait]
1268impl OAuthTokenProvider for McpOAuthTokenProvider {
1269    fn access_token(&self) -> Option<String> {
1270        let session = self.session.lock();
1271        if Self::access_token_is_expired(&session.tokens) {
1272            return None;
1273        }
1274        Some(session.tokens.access_token.clone())
1275    }
1276
1277    async fn try_refresh(&self) -> Result<bool> {
1278        let (refresh_token, token_endpoint, resource, client_id) = {
1279            let session = self.session.lock();
1280            match session.tokens.refresh_token.clone() {
1281                Some(refresh_token) => (
1282                    refresh_token,
1283                    session.token_endpoint.clone(),
1284                    session.resource.clone(),
1285                    session.client_registration.client_id.clone(),
1286                ),
1287                None => return Ok(false),
1288            }
1289        };
1290
1291        let resource_str = canonical_server_uri(&resource);
1292
1293        match refresh_tokens(
1294            &self.http_client,
1295            &token_endpoint,
1296            &refresh_token,
1297            &client_id,
1298            &resource_str,
1299        )
1300        .await
1301        {
1302            Ok(mut new_tokens) => {
1303                if new_tokens.refresh_token.is_none() {
1304                    new_tokens.refresh_token = Some(refresh_token);
1305                }
1306
1307                {
1308                    let mut session = self.session.lock();
1309                    session.tokens = new_tokens;
1310
1311                    if let Some(ref tx) = self.token_refresh_tx {
1312                        tx.unbounded_send(session.clone()).ok();
1313                    }
1314                }
1315
1316                Ok(true)
1317            }
1318            Err(err) => {
1319                log::warn!("OAuth token refresh failed: {}", err);
1320                Ok(false)
1321            }
1322        }
1323    }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328    use super::*;
1329    use http_client::Response;
1330
1331    // -- require_https_or_loopback tests ------------------------------------
1332
1333    #[test]
1334    fn test_require_https_or_loopback_accepts_https() {
1335        let url = Url::parse("https://auth.example.com/token").unwrap();
1336        assert!(require_https_or_loopback(&url).is_ok());
1337    }
1338
1339    #[test]
1340    fn test_require_https_or_loopback_rejects_http_remote() {
1341        let url = Url::parse("http://auth.example.com/token").unwrap();
1342        assert!(require_https_or_loopback(&url).is_err());
1343    }
1344
1345    #[test]
1346    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1347        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1348        assert!(require_https_or_loopback(&url).is_ok());
1349    }
1350
1351    #[test]
1352    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1353        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1354        assert!(require_https_or_loopback(&url).is_ok());
1355    }
1356
1357    #[test]
1358    fn test_require_https_or_loopback_accepts_http_localhost() {
1359        let url = Url::parse("http://localhost:8080/callback").unwrap();
1360        assert!(require_https_or_loopback(&url).is_ok());
1361    }
1362
1363    #[test]
1364    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1365        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1366        assert!(require_https_or_loopback(&url).is_ok());
1367    }
1368
1369    #[test]
1370    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1371        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1372        assert!(require_https_or_loopback(&url).is_err());
1373    }
1374
1375    #[test]
1376    fn test_require_https_or_loopback_rejects_ftp() {
1377        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1378        assert!(require_https_or_loopback(&url).is_err());
1379    }
1380
1381    // -- validate_oauth_url (SSRF) tests ------------------------------------
1382
1383    #[test]
1384    fn test_validate_oauth_url_accepts_https_public() {
1385        let url = Url::parse("https://auth.example.com/token").unwrap();
1386        assert!(validate_oauth_url(&url).is_ok());
1387    }
1388
1389    #[test]
1390    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1391        let url = Url::parse("https://10.0.0.1/token").unwrap();
1392        assert!(validate_oauth_url(&url).is_err());
1393    }
1394
1395    #[test]
1396    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1397        let url = Url::parse("https://172.16.0.1/token").unwrap();
1398        assert!(validate_oauth_url(&url).is_err());
1399    }
1400
1401    #[test]
1402    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1403        let url = Url::parse("https://192.168.1.1/token").unwrap();
1404        assert!(validate_oauth_url(&url).is_err());
1405    }
1406
1407    #[test]
1408    fn test_validate_oauth_url_rejects_link_local() {
1409        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1410        assert!(validate_oauth_url(&url).is_err());
1411    }
1412
1413    #[test]
1414    fn test_validate_oauth_url_rejects_ipv6_ula() {
1415        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1416        assert!(validate_oauth_url(&url).is_err());
1417    }
1418
1419    #[test]
1420    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1421        let url = Url::parse("https://[::]/token").unwrap();
1422        assert!(validate_oauth_url(&url).is_err());
1423    }
1424
1425    #[test]
1426    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1427        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1428        assert!(validate_oauth_url(&url).is_err());
1429    }
1430
1431    #[test]
1432    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1433        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1434        assert!(validate_oauth_url(&url).is_err());
1435    }
1436
1437    #[test]
1438    fn test_validate_oauth_url_allows_http_loopback() {
1439        // Loopback is permitted (it's our callback server).
1440        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1441        assert!(validate_oauth_url(&url).is_ok());
1442    }
1443
1444    #[test]
1445    fn test_validate_oauth_url_allows_https_public_ip() {
1446        let url = Url::parse("https://93.184.216.34/token").unwrap();
1447        assert!(validate_oauth_url(&url).is_ok());
1448    }
1449
1450    // -- parse_www_authenticate tests ----------------------------------------
1451
1452    #[test]
1453    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1454        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1455        let result = parse_www_authenticate(header).unwrap();
1456
1457        assert_eq!(
1458            result.resource_metadata.as_ref().map(|u| u.as_str()),
1459            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1460        );
1461        assert_eq!(
1462            result.scope,
1463            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1464        );
1465        assert_eq!(result.error, None);
1466    }
1467
1468    #[test]
1469    fn test_parse_www_authenticate_resource_metadata_only() {
1470        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1471        let result = parse_www_authenticate(header).unwrap();
1472
1473        assert_eq!(
1474            result.resource_metadata.as_ref().map(|u| u.as_str()),
1475            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1476        );
1477        assert_eq!(result.scope, None);
1478    }
1479
1480    #[test]
1481    fn test_parse_www_authenticate_bare_bearer() {
1482        let result = parse_www_authenticate("Bearer").unwrap();
1483        assert_eq!(result.resource_metadata, None);
1484        assert_eq!(result.scope, None);
1485    }
1486
1487    #[test]
1488    fn test_parse_www_authenticate_with_error() {
1489        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1490        let result = parse_www_authenticate(header).unwrap();
1491
1492        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1493        assert_eq!(
1494            result.error_description.as_deref(),
1495            Some("Additional file write permission required")
1496        );
1497        assert_eq!(
1498            result.scope,
1499            Some(vec!["files:read".to_string(), "files:write".to_string()])
1500        );
1501        assert!(result.resource_metadata.is_some());
1502    }
1503
1504    #[test]
1505    fn test_parse_www_authenticate_invalid_token_error() {
1506        let header =
1507            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1508        let result = parse_www_authenticate(header).unwrap();
1509        assert_eq!(result.error, Some(BearerError::InvalidToken));
1510    }
1511
1512    #[test]
1513    fn test_parse_www_authenticate_invalid_request_error() {
1514        let header = r#"Bearer error="invalid_request""#;
1515        let result = parse_www_authenticate(header).unwrap();
1516        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1517    }
1518
1519    #[test]
1520    fn test_parse_www_authenticate_unknown_error() {
1521        let header = r#"Bearer error="some_future_error""#;
1522        let result = parse_www_authenticate(header).unwrap();
1523        assert_eq!(result.error, Some(BearerError::Other));
1524    }
1525
1526    #[test]
1527    fn test_parse_www_authenticate_rejects_non_bearer() {
1528        let result = parse_www_authenticate("Basic realm=\"example\"");
1529        assert!(result.is_err());
1530    }
1531
1532    #[test]
1533    fn test_parse_www_authenticate_case_insensitive_scheme() {
1534        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1535        let result = parse_www_authenticate(header).unwrap();
1536        assert!(result.resource_metadata.is_some());
1537    }
1538
1539    #[test]
1540    fn test_parse_www_authenticate_multiline_style() {
1541        // Some servers emit the header spread across multiple lines joined by
1542        // whitespace, as shown in the spec examples.
1543        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1544        let result = parse_www_authenticate(header).unwrap();
1545        assert!(result.resource_metadata.is_some());
1546        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1547    }
1548
1549    #[test]
1550    fn test_protected_resource_metadata_urls_with_path() {
1551        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1552        let urls = protected_resource_metadata_urls(&server_url);
1553
1554        assert_eq!(urls.len(), 2);
1555        assert_eq!(
1556            urls[0].as_str(),
1557            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1558        );
1559        assert_eq!(
1560            urls[1].as_str(),
1561            "https://api.example.com/.well-known/oauth-protected-resource"
1562        );
1563    }
1564
1565    #[test]
1566    fn test_protected_resource_metadata_urls_without_path() {
1567        let server_url = Url::parse("https://mcp.example.com").unwrap();
1568        let urls = protected_resource_metadata_urls(&server_url);
1569
1570        assert_eq!(urls.len(), 1);
1571        assert_eq!(
1572            urls[0].as_str(),
1573            "https://mcp.example.com/.well-known/oauth-protected-resource"
1574        );
1575    }
1576
1577    #[test]
1578    fn test_auth_server_metadata_urls_with_path() {
1579        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1580        let urls = auth_server_metadata_urls(&issuer);
1581
1582        assert_eq!(urls.len(), 3);
1583        assert_eq!(
1584            urls[0].as_str(),
1585            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1586        );
1587        assert_eq!(
1588            urls[1].as_str(),
1589            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1590        );
1591        assert_eq!(
1592            urls[2].as_str(),
1593            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1594        );
1595    }
1596
1597    #[test]
1598    fn test_auth_server_metadata_urls_without_path() {
1599        let issuer = Url::parse("https://auth.example.com").unwrap();
1600        let urls = auth_server_metadata_urls(&issuer);
1601
1602        assert_eq!(urls.len(), 2);
1603        assert_eq!(
1604            urls[0].as_str(),
1605            "https://auth.example.com/.well-known/oauth-authorization-server"
1606        );
1607        assert_eq!(
1608            urls[1].as_str(),
1609            "https://auth.example.com/.well-known/openid-configuration"
1610        );
1611    }
1612
1613    // -- Canonical server URI tests ------------------------------------------
1614
1615    #[test]
1616    fn test_canonical_server_uri_simple() {
1617        let url = Url::parse("https://mcp.example.com").unwrap();
1618        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1619    }
1620
1621    #[test]
1622    fn test_canonical_server_uri_with_path() {
1623        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1624        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1625    }
1626
1627    #[test]
1628    fn test_canonical_server_uri_strips_trailing_slash() {
1629        let url = Url::parse("https://mcp.example.com/").unwrap();
1630        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1631    }
1632
1633    #[test]
1634    fn test_canonical_server_uri_preserves_port() {
1635        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1636        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1637    }
1638
1639    #[test]
1640    fn test_canonical_server_uri_lowercases() {
1641        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1642        assert_eq!(
1643            canonical_server_uri(&url),
1644            "https://mcp.example.com/Server/MCP"
1645        );
1646    }
1647
1648    // -- Scope selection tests -----------------------------------------------
1649
1650    #[test]
1651    fn test_select_scopes_prefers_www_authenticate() {
1652        let www_auth = WwwAuthenticate {
1653            resource_metadata: None,
1654            scope: Some(vec!["files:read".into()]),
1655            error: None,
1656            error_description: None,
1657        };
1658        let resource_meta = ProtectedResourceMetadata {
1659            resource: Url::parse("https://example.com").unwrap(),
1660            authorization_servers: vec![],
1661            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1662        };
1663        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1664    }
1665
1666    #[test]
1667    fn test_select_scopes_falls_back_to_resource_metadata() {
1668        let www_auth = WwwAuthenticate {
1669            resource_metadata: None,
1670            scope: None,
1671            error: None,
1672            error_description: None,
1673        };
1674        let resource_meta = ProtectedResourceMetadata {
1675            resource: Url::parse("https://example.com").unwrap(),
1676            authorization_servers: vec![],
1677            scopes_supported: Some(vec!["admin".into()]),
1678        };
1679        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1680    }
1681
1682    #[test]
1683    fn test_select_scopes_empty_when_nothing_available() {
1684        let www_auth = WwwAuthenticate {
1685            resource_metadata: None,
1686            scope: None,
1687            error: None,
1688            error_description: None,
1689        };
1690        let resource_meta = ProtectedResourceMetadata {
1691            resource: Url::parse("https://example.com").unwrap(),
1692            authorization_servers: vec![],
1693            scopes_supported: None,
1694        };
1695        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1696    }
1697
1698    // -- Client registration strategy tests ----------------------------------
1699
1700    #[test]
1701    fn test_registration_strategy_prefers_cimd() {
1702        let metadata = AuthServerMetadata {
1703            issuer: Url::parse("https://auth.example.com").unwrap(),
1704            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1705            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1706            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1707            scopes_supported: None,
1708            code_challenge_methods_supported: Some(vec!["S256".into()]),
1709            client_id_metadata_document_supported: true,
1710        };
1711        assert_eq!(
1712            determine_registration_strategy(&metadata),
1713            ClientRegistrationStrategy::Cimd {
1714                client_id: CIMD_URL.to_string(),
1715            }
1716        );
1717    }
1718
1719    #[test]
1720    fn test_registration_strategy_falls_back_to_dcr() {
1721        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1722        let metadata = AuthServerMetadata {
1723            issuer: Url::parse("https://auth.example.com").unwrap(),
1724            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1725            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1726            registration_endpoint: Some(reg_endpoint.clone()),
1727            scopes_supported: None,
1728            code_challenge_methods_supported: Some(vec!["S256".into()]),
1729            client_id_metadata_document_supported: false,
1730        };
1731        assert_eq!(
1732            determine_registration_strategy(&metadata),
1733            ClientRegistrationStrategy::Dcr {
1734                registration_endpoint: reg_endpoint,
1735            }
1736        );
1737    }
1738
1739    #[test]
1740    fn test_registration_strategy_unavailable() {
1741        let metadata = AuthServerMetadata {
1742            issuer: Url::parse("https://auth.example.com").unwrap(),
1743            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1744            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1745            registration_endpoint: None,
1746            scopes_supported: None,
1747            code_challenge_methods_supported: Some(vec!["S256".into()]),
1748            client_id_metadata_document_supported: false,
1749        };
1750        assert_eq!(
1751            determine_registration_strategy(&metadata),
1752            ClientRegistrationStrategy::Unavailable,
1753        );
1754    }
1755
1756    // -- PKCE tests ----------------------------------------------------------
1757
1758    #[test]
1759    fn test_pkce_challenge_verifier_length() {
1760        let pkce = generate_pkce_challenge();
1761        // 32 random bytes → 43 base64url chars (no padding).
1762        assert_eq!(pkce.verifier.len(), 43);
1763    }
1764
1765    #[test]
1766    fn test_pkce_challenge_is_valid_base64url() {
1767        let pkce = generate_pkce_challenge();
1768        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1769            assert!(
1770                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1771                "invalid base64url character: {}",
1772                c
1773            );
1774        }
1775    }
1776
1777    #[test]
1778    fn test_pkce_challenge_is_s256_of_verifier() {
1779        let pkce = generate_pkce_challenge();
1780        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1781        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1782        let expected_challenge = engine.encode(expected_digest);
1783        assert_eq!(pkce.challenge, expected_challenge);
1784    }
1785
1786    #[test]
1787    fn test_pkce_challenges_are_unique() {
1788        let a = generate_pkce_challenge();
1789        let b = generate_pkce_challenge();
1790        assert_ne!(a.verifier, b.verifier);
1791    }
1792
1793    // -- Authorization URL tests ---------------------------------------------
1794
1795    #[test]
1796    fn test_build_authorization_url() {
1797        let metadata = AuthServerMetadata {
1798            issuer: Url::parse("https://auth.example.com").unwrap(),
1799            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1800            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1801            registration_endpoint: None,
1802            scopes_supported: None,
1803            code_challenge_methods_supported: Some(vec!["S256".into()]),
1804            client_id_metadata_document_supported: true,
1805        };
1806        let pkce = PkceChallenge {
1807            verifier: "test_verifier".into(),
1808            challenge: "test_challenge".into(),
1809        };
1810        let url = build_authorization_url(
1811            &metadata,
1812            "https://zed.dev/oauth/client-metadata.json",
1813            "http://127.0.0.1:12345/callback",
1814            &["files:read".into(), "files:write".into()],
1815            "https://mcp.example.com",
1816            &pkce,
1817            "random_state_123",
1818        );
1819
1820        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1821        assert_eq!(pairs.get("response_type").unwrap(), "code");
1822        assert_eq!(
1823            pairs.get("client_id").unwrap(),
1824            "https://zed.dev/oauth/client-metadata.json"
1825        );
1826        assert_eq!(
1827            pairs.get("redirect_uri").unwrap(),
1828            "http://127.0.0.1:12345/callback"
1829        );
1830        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1831        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1832        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1833        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1834        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1835    }
1836
1837    #[test]
1838    fn test_build_authorization_url_omits_empty_scope() {
1839        let metadata = AuthServerMetadata {
1840            issuer: Url::parse("https://auth.example.com").unwrap(),
1841            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1842            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1843            registration_endpoint: None,
1844            scopes_supported: None,
1845            code_challenge_methods_supported: Some(vec!["S256".into()]),
1846            client_id_metadata_document_supported: false,
1847        };
1848        let pkce = PkceChallenge {
1849            verifier: "v".into(),
1850            challenge: "c".into(),
1851        };
1852        let url = build_authorization_url(
1853            &metadata,
1854            "client_123",
1855            "http://127.0.0.1:9999/callback",
1856            &[],
1857            "https://mcp.example.com",
1858            &pkce,
1859            "state",
1860        );
1861
1862        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1863        assert!(!pairs.contains_key("scope"));
1864    }
1865
1866    // -- Token exchange / refresh param tests --------------------------------
1867
1868    #[test]
1869    fn test_token_exchange_params() {
1870        let params = token_exchange_params(
1871            "auth_code_abc",
1872            "client_xyz",
1873            "http://127.0.0.1:5555/callback",
1874            "verifier_123",
1875            "https://mcp.example.com",
1876        );
1877        let map: std::collections::HashMap<&str, &str> =
1878            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1879
1880        assert_eq!(map["grant_type"], "authorization_code");
1881        assert_eq!(map["code"], "auth_code_abc");
1882        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1883        assert_eq!(map["client_id"], "client_xyz");
1884        assert_eq!(map["code_verifier"], "verifier_123");
1885        assert_eq!(map["resource"], "https://mcp.example.com");
1886    }
1887
1888    #[test]
1889    fn test_token_refresh_params() {
1890        let params =
1891            token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1892        let map: std::collections::HashMap<&str, &str> =
1893            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1894
1895        assert_eq!(map["grant_type"], "refresh_token");
1896        assert_eq!(map["refresh_token"], "refresh_token_abc");
1897        assert_eq!(map["client_id"], "client_xyz");
1898        assert_eq!(map["resource"], "https://mcp.example.com");
1899    }
1900
1901    // -- Token response tests ------------------------------------------------
1902
1903    #[test]
1904    fn test_token_response_into_tokens_with_expiry() {
1905        let response: TokenResponse = serde_json::from_str(
1906            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1907        )
1908        .unwrap();
1909
1910        let tokens = response.into_tokens();
1911        assert_eq!(tokens.access_token, "at_123");
1912        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1913        assert!(tokens.expires_at.is_some());
1914    }
1915
1916    #[test]
1917    fn test_token_response_into_tokens_minimal() {
1918        let response: TokenResponse =
1919            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1920
1921        let tokens = response.into_tokens();
1922        assert_eq!(tokens.access_token, "at_789");
1923        assert_eq!(tokens.refresh_token, None);
1924        assert_eq!(tokens.expires_at, None);
1925    }
1926
1927    // -- DCR body test -------------------------------------------------------
1928
1929    #[test]
1930    fn test_dcr_registration_body_shape() {
1931        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1932        assert_eq!(body["client_name"], "Zed");
1933        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1934        assert_eq!(body["grant_types"][0], "authorization_code");
1935        assert_eq!(body["response_types"][0], "code");
1936        assert_eq!(body["token_endpoint_auth_method"], "none");
1937    }
1938
1939    // -- Test helpers for async/HTTP tests -----------------------------------
1940
1941    fn make_fake_http_client(
1942        handler: impl Fn(
1943            http_client::Request<AsyncBody>,
1944        ) -> std::pin::Pin<
1945            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1946        > + Send
1947        + Sync
1948        + 'static,
1949    ) -> Arc<dyn HttpClient> {
1950        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1951    }
1952
1953    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1954        Ok(Response::builder()
1955            .status(status)
1956            .header("Content-Type", "application/json")
1957            .body(AsyncBody::from(body.as_bytes().to_vec()))
1958            .unwrap())
1959    }
1960
1961    // -- Discovery integration tests -----------------------------------------
1962
    #[test]
    fn test_fetch_protected_resource_metadata() {
        // With no `resource_metadata` hint in the challenge, the only document
        // the fake serves is `.well-known/oauth-protected-resource`; every other
        // URL returns 404. The fetch must therefore use the well-known URL.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains(".well-known/oauth-protected-resource") {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["read", "write"]
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: None,
                scope: None,
                error: None,
                error_description: None,
            };

            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
                .await
                .unwrap();

            // Parsed `Url`s serialize origin-only URLs with a trailing slash.
            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
            assert_eq!(metadata.authorization_servers.len(), 1);
            assert_eq!(
                metadata.authorization_servers[0].as_str(),
                "https://auth.example.com/"
            );
            assert_eq!(
                metadata.scopes_supported,
                Some(vec!["read".to_string(), "write".to_string()])
            );
        });
    }
2008
    #[test]
    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
        // A same-origin `resource_metadata` URL from the WWW-Authenticate
        // challenge should be used. The fake answers only that exact URL; any
        // other request (including the well-known fallback) gets a 500.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri == "https://mcp.example.com/custom-resource-metadata" {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#,
                        )
                    } else {
                        json_response(500, r#"{"error": "should not be called"}"#)
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: Some(
                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
                ),
                scope: None,
                error: None,
                error_description: None,
            };

            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
                .await
                .unwrap();

            assert_eq!(metadata.authorization_servers.len(), 1);
        });
    }
2046
    #[test]
    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
        // A `resource_metadata` hint pointing at a different origin must be
        // ignored. The fake panics if that URL is ever requested.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    // The cross-origin URL should NOT be fetched; only the
                    // well-known fallback at the server's own origin should be.
                    if uri.contains("attacker.example.com") {
                        panic!("should not fetch cross-origin resource_metadata URL");
                    } else if uri.contains(".well-known/oauth-protected-resource") {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: Some(
                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
                ),
                scope: None,
                error: None,
                error_description: None,
            };

            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
                .await
                .unwrap();

            // Should have used the fallback well-known URL, not the attacker's.
            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
        });
    }
2089
    #[test]
    fn test_fetch_auth_server_metadata() {
        // The fake serves only the `.well-known/oauth-authorization-server`
        // document. Every endpoint and flag in it must round-trip into
        // `AuthServerMetadata`.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains(".well-known/oauth-authorization-server") {
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let issuer = Url::parse("https://auth.example.com").unwrap();
            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
            assert_eq!(
                metadata.authorization_endpoint.as_str(),
                "https://auth.example.com/authorize"
            );
            assert_eq!(
                metadata.token_endpoint.as_str(),
                "https://auth.example.com/token"
            );
            assert!(metadata.registration_endpoint.is_some());
            assert!(metadata.client_id_metadata_document_supported);
            assert_eq!(
                metadata.code_challenge_methods_supported,
                Some(vec!["S256".to_string()])
            );
        });
    }
2134
    #[test]
    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
        // Only `openid-configuration` URLs are answered, so the OAuth metadata
        // URL 404s and the OpenID Connect document must be used instead. The
        // document omits `client_id_metadata_document_supported`, so the field
        // is expected to come out as `false`.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains("openid-configuration") {
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let issuer = Url::parse("https://auth.example.com").unwrap();
            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

            assert_eq!(
                metadata.authorization_endpoint.as_str(),
                "https://auth.example.com/authorize"
            );
            assert!(!metadata.client_id_metadata_document_supported);
        });
    }
2167
    #[test]
    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
        // Metadata whose `issuer` differs from the issuer it was fetched for
        // must be rejected (RFC 8414 §3.3). The error text must mention
        // "issuer mismatch".
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains(".well-known/oauth-authorization-server") {
                        // Response claims to be a different issuer.
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://evil.example.com",
                                "authorization_endpoint": "https://evil.example.com/authorize",
                                "token_endpoint": "https://evil.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let issuer = Url::parse("https://auth.example.com").unwrap();
            let result = fetch_auth_server_metadata(&client, &issuer).await;

            assert!(result.is_err());
            let err_msg = result.unwrap_err().to_string();
            assert!(
                err_msg.contains("issuer mismatch"),
                "unexpected error: {}",
                err_msg
            );
        });
    }
2203
2204    // -- Full discover integration tests -------------------------------------
2205
    #[test]
    fn test_full_discover_with_cimd() {
        // End to end: discovery followed by client registration against an
        // authorization server that advertises CIMD support. The expected
        // client_id is `CIMD_URL` with no client secret. Scopes come from the
        // resource metadata because the challenge carries none.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains("oauth-protected-resource") {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["mcp:read"]
                            }"#,
                        )
                    } else if uri.contains("oauth-authorization-server") {
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: None,
                scope: None,
                error: None,
                error_description: None,
            };

            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
            // Registration is a separate step because it needs the loopback
            // redirect URI (see the module docs).
            let registration =
                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
                    .await
                    .unwrap();

            assert_eq!(registration.client_id, CIMD_URL);
            assert_eq!(registration.client_secret, None);
            assert_eq!(discovery.scopes, vec!["mcp:read"]);
        });
    }
2257
    #[test]
    fn test_full_discover_with_dcr_fallback() {
        // End to end with CIMD disabled: registration must go through the
        // advertised `/register` endpoint (answered with 201) and surface the
        // minted client_id and client_secret. Scopes come from the
        // WWW-Authenticate challenge, since the resource metadata lists none.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains("oauth-protected-resource") {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#,
                        )
                    } else if uri.contains("oauth-authorization-server") {
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": false
                            }"#,
                        )
                    } else if uri.contains("/register") {
                        json_response(
                            201,
                            r#"{
                                "client_id": "dcr-minted-id-123",
                                "client_secret": "dcr-secret-456"
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: None,
                scope: Some(vec!["files:read".into()]),
                error: None,
                error_description: None,
            };

            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
            let registration =
                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
                    .await
                    .unwrap();

            assert_eq!(registration.client_id, "dcr-minted-id-123");
            assert_eq!(
                registration.client_secret.as_deref(),
                Some("dcr-secret-456")
            );
            assert_eq!(discovery.scopes, vec!["files:read"]);
        });
    }
2320
    #[test]
    fn test_discover_fails_without_pkce_support() {
        // The authorization server metadata omits
        // `code_challenge_methods_supported`, so PKCE support cannot be
        // confirmed. Discovery must fail with an error naming that field.
        smol::block_on(async {
            let client = make_fake_http_client(|req| {
                Box::pin(async move {
                    let uri = req.uri().to_string();
                    if uri.contains("oauth-protected-resource") {
                        json_response(
                            200,
                            r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#,
                        )
                    } else if uri.contains("oauth-authorization-server") {
                        json_response(
                            200,
                            r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token"
                            }"#,
                        )
                    } else {
                        json_response(404, "{}")
                    }
                })
            });

            let server_url = Url::parse("https://mcp.example.com").unwrap();
            let www_auth = WwwAuthenticate {
                resource_metadata: None,
                scope: None,
                error: None,
                error_description: None,
            };

            let result = discover(&client, &server_url, &www_auth).await;
            assert!(result.is_err());
            let err_msg = result.unwrap_err().to_string();
            assert!(
                err_msg.contains("code_challenge_methods_supported"),
                "unexpected error: {}",
                err_msg
            );
        });
    }
2368
2369    // -- Token exchange integration tests ------------------------------------
2370
2371    #[test]
2372    fn test_exchange_code_success() {
2373        smol::block_on(async {
2374            let client = make_fake_http_client(|req| {
2375                Box::pin(async move {
2376                    let uri = req.uri().to_string();
2377                    if uri.contains("/token") {
2378                        json_response(
2379                            200,
2380                            r#"{
2381                                "access_token": "new_access_token",
2382                                "refresh_token": "new_refresh_token",
2383                                "expires_in": 3600,
2384                                "token_type": "Bearer"
2385                            }"#,
2386                        )
2387                    } else {
2388                        json_response(404, "{}")
2389                    }
2390                })
2391            });
2392
2393            let metadata = AuthServerMetadata {
2394                issuer: Url::parse("https://auth.example.com").unwrap(),
2395                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2396                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2397                registration_endpoint: None,
2398                scopes_supported: None,
2399                code_challenge_methods_supported: Some(vec!["S256".into()]),
2400                client_id_metadata_document_supported: true,
2401            };
2402
2403            let tokens = exchange_code(
2404                &client,
2405                &metadata,
2406                "auth_code_123",
2407                CIMD_URL,
2408                "http://127.0.0.1:9999/callback",
2409                "verifier_abc",
2410                "https://mcp.example.com",
2411            )
2412            .await
2413            .unwrap();
2414
2415            assert_eq!(tokens.access_token, "new_access_token");
2416            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2417            assert!(tokens.expires_at.is_some());
2418        });
2419    }
2420
2421    #[test]
2422    fn test_refresh_tokens_success() {
2423        smol::block_on(async {
2424            let client = make_fake_http_client(|req| {
2425                Box::pin(async move {
2426                    let uri = req.uri().to_string();
2427                    if uri.contains("/token") {
2428                        json_response(
2429                            200,
2430                            r#"{
2431                                "access_token": "refreshed_token",
2432                                "expires_in": 1800,
2433                                "token_type": "Bearer"
2434                            }"#,
2435                        )
2436                    } else {
2437                        json_response(404, "{}")
2438                    }
2439                })
2440            });
2441
2442            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2443
2444            let tokens = refresh_tokens(
2445                &client,
2446                &token_endpoint,
2447                "old_refresh_token",
2448                CIMD_URL,
2449                "https://mcp.example.com",
2450            )
2451            .await
2452            .unwrap();
2453
2454            assert_eq!(tokens.access_token, "refreshed_token");
2455            assert_eq!(tokens.refresh_token, None);
2456            assert!(tokens.expires_at.is_some());
2457        });
2458    }
2459
2460    #[test]
2461    fn test_exchange_code_failure() {
2462        smol::block_on(async {
2463            let client = make_fake_http_client(|_req| {
2464                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2465            });
2466
2467            let metadata = AuthServerMetadata {
2468                issuer: Url::parse("https://auth.example.com").unwrap(),
2469                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2470                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2471                registration_endpoint: None,
2472                scopes_supported: None,
2473                code_challenge_methods_supported: Some(vec!["S256".into()]),
2474                client_id_metadata_document_supported: true,
2475            };
2476
2477            let result = exchange_code(
2478                &client,
2479                &metadata,
2480                "bad_code",
2481                "client",
2482                "http://127.0.0.1:1/callback",
2483                "verifier",
2484                "https://mcp.example.com",
2485            )
2486            .await;
2487
2488            assert!(result.is_err());
2489            assert!(result.unwrap_err().to_string().contains("400"));
2490        });
2491    }
2492
2493    // -- DCR integration tests -----------------------------------------------
2494
2495    #[test]
2496    fn test_perform_dcr() {
2497        smol::block_on(async {
2498            let client = make_fake_http_client(|_req| {
2499                Box::pin(async move {
2500                    json_response(
2501                        201,
2502                        r#"{
2503                            "client_id": "dynamic-client-001",
2504                            "client_secret": "dynamic-secret-001"
2505                        }"#,
2506                    )
2507                })
2508            });
2509
2510            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2511            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2512                .await
2513                .unwrap();
2514
2515            assert_eq!(registration.client_id, "dynamic-client-001");
2516            assert_eq!(
2517                registration.client_secret.as_deref(),
2518                Some("dynamic-secret-001")
2519            );
2520        });
2521    }
2522
2523    #[test]
2524    fn test_perform_dcr_failure() {
2525        smol::block_on(async {
2526            let client = make_fake_http_client(|_req| {
2527                Box::pin(
2528                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2529                )
2530            });
2531
2532            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2533            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2534
2535            assert!(result.is_err());
2536            assert!(result.unwrap_err().to_string().contains("403"));
2537        });
2538    }
2539
2540    // -- OAuthCallback parse tests -------------------------------------------
2541
2542    #[test]
2543    fn test_oauth_callback_parse_query() {
2544        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2545        assert_eq!(callback.code, "test_auth_code");
2546        assert_eq!(callback.state, "test_state");
2547    }
2548
2549    #[test]
2550    fn test_oauth_callback_parse_query_reversed_order() {
2551        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2552        assert_eq!(callback.code, "test_auth_code");
2553        assert_eq!(callback.state, "test_state");
2554    }
2555
2556    #[test]
2557    fn test_oauth_callback_parse_query_with_extra_params() {
2558        let callback =
2559            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2560                .unwrap();
2561        assert_eq!(callback.code, "test_auth_code");
2562        assert_eq!(callback.state, "test_state");
2563    }
2564
2565    #[test]
2566    fn test_oauth_callback_parse_query_missing_code() {
2567        let result = OAuthCallback::parse_query("state=test_state");
2568        assert!(result.is_err());
2569        assert!(result.unwrap_err().to_string().contains("code"));
2570    }
2571
2572    #[test]
2573    fn test_oauth_callback_parse_query_missing_state() {
2574        let result = OAuthCallback::parse_query("code=test_auth_code");
2575        assert!(result.is_err());
2576        assert!(result.unwrap_err().to_string().contains("state"));
2577    }
2578
2579    #[test]
2580    fn test_oauth_callback_parse_query_empty_code() {
2581        let result = OAuthCallback::parse_query("code=&state=test_state");
2582        assert!(result.is_err());
2583    }
2584
2585    #[test]
2586    fn test_oauth_callback_parse_query_empty_state() {
2587        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2588        assert!(result.is_err());
2589    }
2590
2591    #[test]
2592    fn test_oauth_callback_parse_query_url_encoded_values() {
2593        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2594        assert_eq!(callback.code, "abc def");
2595        assert_eq!(callback.state, "test=state");
2596    }
2597
2598    #[test]
2599    fn test_oauth_callback_parse_query_error_response() {
2600        let result = OAuthCallback::parse_query(
2601            "error=access_denied&error_description=User%20denied%20access&state=abc",
2602        );
2603        assert!(result.is_err());
2604        let err_msg = result.unwrap_err().to_string();
2605        assert!(
2606            err_msg.contains("access_denied"),
2607            "unexpected error: {}",
2608            err_msg
2609        );
2610        assert!(
2611            err_msg.contains("User denied access"),
2612            "unexpected error: {}",
2613            err_msg
2614        );
2615    }
2616
2617    #[test]
2618    fn test_oauth_callback_parse_query_error_without_description() {
2619        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2620        assert!(result.is_err());
2621        let err_msg = result.unwrap_err().to_string();
2622        assert!(
2623            err_msg.contains("server_error"),
2624            "unexpected error: {}",
2625            err_msg
2626        );
2627        assert!(
2628            err_msg.contains("no description"),
2629            "unexpected error: {}",
2630            err_msg
2631        );
2632    }
2633
2634    // -- McpOAuthTokenProvider tests -----------------------------------------
2635
2636    fn make_test_session(
2637        access_token: &str,
2638        refresh_token: Option<&str>,
2639        expires_at: Option<SystemTime>,
2640    ) -> OAuthSession {
2641        OAuthSession {
2642            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2643            resource: Url::parse("https://mcp.example.com").unwrap(),
2644            client_registration: OAuthClientRegistration {
2645                client_id: "test-client".into(),
2646                client_secret: None,
2647            },
2648            tokens: OAuthTokens {
2649                access_token: access_token.into(),
2650                refresh_token: refresh_token.map(String::from),
2651                expires_at,
2652            },
2653        }
2654    }
2655
2656    #[test]
2657    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2658        let expired = SystemTime::now() - Duration::from_secs(60);
2659        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2660        let provider = McpOAuthTokenProvider::new(
2661            session,
2662            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2663            None,
2664        );
2665
2666        assert_eq!(provider.access_token(), None);
2667    }
2668
2669    #[test]
2670    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2671        let far_future = SystemTime::now() + Duration::from_secs(3600);
2672        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2673        let provider = McpOAuthTokenProvider::new(
2674            session,
2675            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2676            None,
2677        );
2678
2679        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2680    }
2681
2682    #[test]
2683    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2684        let session = make_test_session("no-expiry-token", Some("rt"), None);
2685        let provider = McpOAuthTokenProvider::new(
2686            session,
2687            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2688            None,
2689        );
2690
2691        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2692    }
2693
2694    #[test]
2695    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2696        smol::block_on(async {
2697            let session = make_test_session("token", None, None);
2698            let provider = McpOAuthTokenProvider::new(
2699                session,
2700                make_fake_http_client(|_| {
2701                    Box::pin(async { unreachable!("no HTTP call expected") })
2702                }),
2703                None,
2704            );
2705
2706            let refreshed = provider.try_refresh().await.unwrap();
2707            assert!(!refreshed);
2708        });
2709    }
2710
2711    #[test]
2712    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2713        smol::block_on(async {
2714            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2715            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2716
2717            let http_client = make_fake_http_client(|_req| {
2718                Box::pin(async {
2719                    json_response(
2720                        200,
2721                        r#"{
2722                            "access_token": "new-access",
2723                            "refresh_token": "new-refresh",
2724                            "expires_in": 1800
2725                        }"#,
2726                    )
2727                })
2728            });
2729
2730            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2731
2732            let refreshed = provider.try_refresh().await.unwrap();
2733            assert!(refreshed);
2734            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2735
2736            let notified_session = rx.try_recv().expect("channel should have a session");
2737            assert_eq!(notified_session.tokens.access_token, "new-access");
2738            assert_eq!(
2739                notified_session.tokens.refresh_token.as_deref(),
2740                Some("new-refresh")
2741            );
2742        });
2743    }
2744
2745    #[test]
2746    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2747        smol::block_on(async {
2748            let session = make_test_session("old-access", Some("original-refresh"), None);
2749            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2750
2751            let http_client = make_fake_http_client(|_req| {
2752                Box::pin(async {
2753                    json_response(
2754                        200,
2755                        r#"{
2756                            "access_token": "new-access",
2757                            "expires_in": 900
2758                        }"#,
2759                    )
2760                })
2761            });
2762
2763            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2764
2765            let refreshed = provider.try_refresh().await.unwrap();
2766            assert!(refreshed);
2767
2768            let notified_session = rx.try_recv().expect("channel should have a session");
2769            assert_eq!(notified_session.tokens.access_token, "new-access");
2770            assert_eq!(
2771                notified_session.tokens.refresh_token.as_deref(),
2772                Some("original-refresh"),
2773            );
2774        });
2775    }
2776
2777    #[test]
2778    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2779        smol::block_on(async {
2780            let session = make_test_session("old-access", Some("my-refresh"), None);
2781
2782            let http_client = make_fake_http_client(|_req| {
2783                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2784            });
2785
2786            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2787
2788            let refreshed = provider.try_refresh().await.unwrap();
2789            assert!(!refreshed);
2790            // The old token should still be in place.
2791            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2792        });
2793    }
2794}