oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
  36/// The CIMD URL where Zed's OAuth client metadata document is hosted.
  37pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource's identifier. [`fetch_protected_resource_metadata`]
    /// substitutes the server URL when the fetched document omits it.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource.
    /// [`fetch_protected_resource_metadata`] rejects documents where this is
    /// empty, and [`discover`] uses the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource advertises, if any. Used by [`select_scopes`] as
    /// the fallback when the challenge carries no `scope`.
    pub scopes_supported: Option<Vec<String>>,
}
 139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// The issuer identifier. [`fetch_auth_server_metadata`] rejects documents
    /// whose reported issuer differs from the one that was requested.
    pub issuer: Url,
    /// Endpoint the browser is sent to; [`build_authorization_url`] appends
    /// the request parameters to it.
    pub authorization_endpoint: Url,
    /// Endpoint used for token requests.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, if the server offers one.
    pub registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE challenge methods the server advertises. [`discover`] requires
    /// this to be present and to contain `"S256"`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// client id. [`fetch_auth_server_metadata`] treats an absent value as
    /// `false`.
    pub client_id_metadata_document_supported: bool,
}
 152
 153/// The result of client registration — either CIMD or DCR.
 154#[derive(Clone, Serialize, Deserialize)]
 155pub struct OAuthClientRegistration {
 156    pub client_id: String,
 157    /// Only present for DCR-minted registrations.
 158    pub client_secret: Option<String>,
 159}
 160
 161impl std::fmt::Debug for OAuthClientRegistration {
 162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 163        f.debug_struct("OAuthClientRegistration")
 164            .field("client_id", &self.client_id)
 165            .field(
 166                "client_secret",
 167                &self.client_secret.as_ref().map(|_| "[redacted]"),
 168            )
 169            .finish()
 170    }
 171}
 172
 173/// Access and refresh tokens obtained from the token endpoint.
 174#[derive(Clone, Serialize, Deserialize)]
 175pub struct OAuthTokens {
 176    pub access_token: String,
 177    pub refresh_token: Option<String>,
 178    pub expires_at: Option<SystemTime>,
 179}
 180
 181impl std::fmt::Debug for OAuthTokens {
 182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 183        f.debug_struct("OAuthTokens")
 184            .field("access_token", &"[redacted]")
 185            .field(
 186                "refresh_token",
 187                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 188            )
 189            .field("expires_at", &self.expires_at)
 190            .finish()
 191    }
 192}
 193
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
///
/// Produced by [`discover`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// The Protected Resource Metadata fetched for the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in
    /// `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// The scopes chosen by [`select_scopes`]; may be empty.
    pub scopes: Vec<String>,
}
 202
 203/// The persisted OAuth session for a context server.
 204///
 205/// Stored in the keychain so startup can restore a refresh-capable provider
 206/// without another browser flow. Deliberately excludes the full discovery
 207/// metadata to keep the serialized size well within keychain item limits.
 208#[derive(Clone, Serialize, Deserialize)]
 209pub struct OAuthSession {
 210    pub token_endpoint: Url,
 211    pub resource: Url,
 212    pub client_registration: OAuthClientRegistration,
 213    pub tokens: OAuthTokens,
 214}
 215
 216impl std::fmt::Debug for OAuthSession {
 217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 218        f.debug_struct("OAuthSession")
 219            .field("token_endpoint", &self.token_endpoint)
 220            .field("resource", &self.resource)
 221            .field("client_registration", &self.client_registration)
 222            .field("tokens", &self.tokens)
 223            .finish()
 224    }
 225}
 226
 227/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 228#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 229pub enum BearerError {
 230    /// The request is missing a required parameter, includes an unsupported
 231    /// parameter or parameter value, or is otherwise malformed.
 232    InvalidRequest,
 233    /// The access token provided is expired, revoked, malformed, or invalid.
 234    InvalidToken,
 235    /// The request requires higher privileges than provided by the access token.
 236    InsufficientScope,
 237    /// An unrecognized error code (extension or future spec addition).
 238    Other,
 239}
 240
 241impl BearerError {
 242    fn parse(value: &str) -> Self {
 243        match value {
 244            "invalid_request" => BearerError::InvalidRequest,
 245            "invalid_token" => BearerError::InvalidToken,
 246            "insufficient_scope" => BearerError::InsufficientScope,
 247            _ => BearerError::Other,
 248        }
 249    }
 250}
 251
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
///
/// Every field is `None` when the corresponding parameter was absent from the
/// header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// The parsed `resource_metadata` parameter.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, verbatim.
    pub error_description: Option<String>,
}
 265
 266/// Parse a `WWW-Authenticate` header value.
 267///
 268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 270/// - `resource_metadata` — URL of the Protected Resource Metadata document
 271/// - `scope` — space-separated list of required scopes
 272/// - `error` — error code (e.g. "insufficient_scope")
 273/// - `error_description` — human-readable error description
 274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 275    let header = header.trim();
 276
 277    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 278        header[6..].trim()
 279    } else {
 280        bail!("WWW-Authenticate header does not use Bearer scheme");
 281    };
 282
 283    if params_str.is_empty() {
 284        return Ok(WwwAuthenticate {
 285            resource_metadata: None,
 286            scope: None,
 287            error: None,
 288            error_description: None,
 289        });
 290    }
 291
 292    let params = parse_auth_params(params_str);
 293
 294    let resource_metadata = params
 295        .get("resource_metadata")
 296        .map(|v| Url::parse(v))
 297        .transpose()
 298        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 299
 300    let scope = params
 301        .get("scope")
 302        .map(|v| v.split_whitespace().map(String::from).collect());
 303
 304    let error = params.get("error").map(|v| BearerError::parse(v));
 305    let error_description = params.get("error_description").cloned();
 306
 307    Ok(WwwAuthenticate {
 308        resource_metadata,
 309        scope,
 310        error,
 311        error_description,
 312    })
 313}
 314
 315/// Parse comma-separated `key="value"` or `key=token` parameters from an
 316/// auth-param list (RFC 7235 Section 2.1).
 317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
 318    let mut params = collections::HashMap::default();
 319    let mut remaining = input.trim();
 320
 321    while !remaining.is_empty() {
 322        // Skip leading whitespace and commas.
 323        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
 324        if remaining.is_empty() {
 325            break;
 326        }
 327
 328        // Find the key (everything before '=').
 329        let eq_pos = match remaining.find('=') {
 330            Some(pos) => pos,
 331            None => break,
 332        };
 333
 334        let key = remaining[..eq_pos].trim().to_lowercase();
 335        remaining = &remaining[eq_pos + 1..];
 336        remaining = remaining.trim_start();
 337
 338        // Parse the value: either quoted or unquoted (token).
 339        let value;
 340        if remaining.starts_with('"') {
 341            // Quoted string: find the closing quote, handling escaped chars.
 342            remaining = &remaining[1..]; // skip opening quote
 343            let mut val = String::new();
 344            let mut chars = remaining.char_indices();
 345            loop {
 346                match chars.next() {
 347                    Some((_, '\\')) => {
 348                        // Escaped character — take the next char literally.
 349                        if let Some((_, c)) = chars.next() {
 350                            val.push(c);
 351                        }
 352                    }
 353                    Some((i, '"')) => {
 354                        remaining = &remaining[i + 1..];
 355                        break;
 356                    }
 357                    Some((_, c)) => val.push(c),
 358                    None => {
 359                        remaining = "";
 360                        break;
 361                    }
 362                }
 363            }
 364            value = val;
 365        } else {
 366            // Unquoted token: read until comma or whitespace.
 367            let end = remaining
 368                .find(|c: char| c == ',' || c.is_whitespace())
 369                .unwrap_or(remaining.len());
 370            value = remaining[..end].to_string();
 371            remaining = &remaining[end..];
 372        }
 373
 374        if !key.is_empty() {
 375            params.insert(key, value);
 376        }
 377    }
 378
 379    params
 380}
 381
 382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 383/// server URL, per RFC 9728 Section 3.
 384///
 385/// Returns URIs in priority order:
 386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 389    let mut urls = Vec::new();
 390    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 391
 392    let path = server_url.path().trim_start_matches('/');
 393    if !path.is_empty() {
 394        if let Ok(url) = Url::parse(&format!(
 395            "{}/.well-known/oauth-protected-resource/{}",
 396            base, path
 397        )) {
 398            urls.push(url);
 399        }
 400    }
 401
 402    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 403        urls.push(url);
 404    }
 405
 406    urls
 407}
 408
 409/// Construct the well-known Authorization Server Metadata URIs for a given
 410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 411///
 412/// Returns URIs in priority order, which differs depending on whether the
 413/// issuer URL has a path component.
 414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 415    let mut urls = Vec::new();
 416    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 417    let path = issuer.path().trim_matches('/');
 418
 419    if !path.is_empty() {
 420        // Issuer with path: try path-inserted variants first.
 421        if let Ok(url) = Url::parse(&format!(
 422            "{}/.well-known/oauth-authorization-server/{}",
 423            base, path
 424        )) {
 425            urls.push(url);
 426        }
 427        if let Ok(url) = Url::parse(&format!(
 428            "{}/.well-known/openid-configuration/{}",
 429            base, path
 430        )) {
 431            urls.push(url);
 432        }
 433        if let Ok(url) = Url::parse(&format!(
 434            "{}/{}/.well-known/openid-configuration",
 435            base, path
 436        )) {
 437            urls.push(url);
 438        }
 439    } else {
 440        // No path: standard well-known locations.
 441        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 442            urls.push(url);
 443        }
 444        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 445            urls.push(url);
 446        }
 447    }
 448
 449    urls
 450}
 451
 452// -- Canonical server URI (RFC 8707) -----------------------------------------
 453
 454/// Derive the canonical resource URI for an MCP server URL, suitable for the
 455/// `resource` parameter in authorization and token requests per RFC 8707.
 456///
 457/// Lowercases the scheme and host, preserves the path (without trailing slash),
 458/// strips fragments and query strings.
 459pub fn canonical_server_uri(server_url: &Url) -> String {
 460    let mut uri = format!(
 461        "{}://{}",
 462        server_url.scheme().to_ascii_lowercase(),
 463        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 464    );
 465    if let Some(port) = server_url.port() {
 466        uri.push_str(&format!(":{}", port));
 467    }
 468    let path = server_url.path();
 469    if path != "/" {
 470        uri.push_str(path.trim_end_matches('/'));
 471    }
 472    uri
 473}
 474
 475// -- Scope selection ---------------------------------------------------------
 476
 477/// Select scopes following the MCP spec's Scope Selection Strategy:
 478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 480/// 3. Return empty if neither is available.
 481pub fn select_scopes(
 482    www_authenticate: &WwwAuthenticate,
 483    resource_metadata: &ProtectedResourceMetadata,
 484) -> Vec<String> {
 485    if let Some(ref scopes) = www_authenticate.scope {
 486        if !scopes.is_empty() {
 487            return scopes.clone();
 488        }
 489    }
 490    resource_metadata
 491        .scopes_supported
 492        .clone()
 493        .unwrap_or_default()
 494}
 495
 496// -- Client registration strategy --------------------------------------------
 497
/// The registration approach to use, determined from auth server metadata.
///
/// Computed by [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
 508
 509/// Determine how to register with the authorization server, following the
 510/// spec's recommended priority: CIMD first, DCR fallback.
 511pub fn determine_registration_strategy(
 512    auth_server_metadata: &AuthServerMetadata,
 513) -> ClientRegistrationStrategy {
 514    if auth_server_metadata.client_id_metadata_document_supported {
 515        ClientRegistrationStrategy::Cimd {
 516            client_id: CIMD_URL.to_string(),
 517        }
 518    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 519        ClientRegistrationStrategy::Dcr {
 520            registration_endpoint: endpoint.clone(),
 521        }
 522    } else {
 523        ClientRegistrationStrategy::Unavailable
 524    }
 525}
 526
 527// -- PKCE (RFC 7636) ---------------------------------------------------------
 528
 529/// A PKCE code verifier and its S256 challenge.
 530#[derive(Clone)]
 531pub struct PkceChallenge {
 532    pub verifier: String,
 533    pub challenge: String,
 534}
 535
 536impl std::fmt::Debug for PkceChallenge {
 537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 538        f.debug_struct("PkceChallenge")
 539            .field("verifier", &"[redacted]")
 540            .field("challenge", &self.challenge)
 541            .finish()
 542    }
 543}
 544
 545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 546///
 547/// The verifier is 43 base64url characters derived from 32 random bytes.
 548/// The challenge is `BASE64URL(SHA256(verifier))`.
 549pub fn generate_pkce_challenge() -> PkceChallenge {
 550    let mut random_bytes = [0u8; 32];
 551    rand::rng().fill(&mut random_bytes);
 552    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 553    let verifier = engine.encode(&random_bytes);
 554
 555    let digest = Sha256::digest(verifier.as_bytes());
 556    let challenge = engine.encode(digest);
 557
 558    PkceChallenge {
 559        verifier,
 560        challenge,
 561    }
 562}
 563
 564// -- Authorization URL construction ------------------------------------------
 565
 566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 567pub fn build_authorization_url(
 568    auth_server_metadata: &AuthServerMetadata,
 569    client_id: &str,
 570    redirect_uri: &str,
 571    scopes: &[String],
 572    resource: &str,
 573    pkce: &PkceChallenge,
 574    state: &str,
 575) -> Url {
 576    let mut url = auth_server_metadata.authorization_endpoint.clone();
 577    {
 578        let mut query = url.query_pairs_mut();
 579        query.append_pair("response_type", "code");
 580        query.append_pair("client_id", client_id);
 581        query.append_pair("redirect_uri", redirect_uri);
 582        if !scopes.is_empty() {
 583            query.append_pair("scope", &scopes.join(" "));
 584        }
 585        query.append_pair("resource", resource);
 586        query.append_pair("code_challenge", &pkce.challenge);
 587        query.append_pair("code_challenge_method", "S256");
 588        query.append_pair("state", state);
 589    }
 590    url
 591}
 592
 593// -- Token endpoint request bodies -------------------------------------------
 594
 595/// The JSON body returned by the token endpoint on success.
 596#[derive(Deserialize)]
 597pub struct TokenResponse {
 598    pub access_token: String,
 599    #[serde(default)]
 600    pub refresh_token: Option<String>,
 601    #[serde(default)]
 602    pub expires_in: Option<u64>,
 603    #[serde(default)]
 604    pub token_type: Option<String>,
 605}
 606
 607impl std::fmt::Debug for TokenResponse {
 608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 609        f.debug_struct("TokenResponse")
 610            .field("access_token", &"[redacted]")
 611            .field(
 612                "refresh_token",
 613                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 614            )
 615            .field("expires_in", &self.expires_in)
 616            .field("token_type", &self.token_type)
 617            .finish()
 618    }
 619}
 620
 621impl TokenResponse {
 622    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 623    pub fn into_tokens(self) -> OAuthTokens {
 624        let expires_at = self
 625            .expires_in
 626            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 627        OAuthTokens {
 628            access_token: self.access_token,
 629            refresh_token: self.refresh_token,
 630            expires_at,
 631        }
 632    }
 633}
 634
 635/// Build the form-encoded body for an authorization code token exchange.
 636pub fn token_exchange_params(
 637    code: &str,
 638    client_id: &str,
 639    redirect_uri: &str,
 640    code_verifier: &str,
 641    resource: &str,
 642) -> Vec<(&'static str, String)> {
 643    vec![
 644        ("grant_type", "authorization_code".to_string()),
 645        ("code", code.to_string()),
 646        ("redirect_uri", redirect_uri.to_string()),
 647        ("client_id", client_id.to_string()),
 648        ("code_verifier", code_verifier.to_string()),
 649        ("resource", resource.to_string()),
 650    ]
 651}
 652
 653/// Build the form-encoded body for a token refresh request.
 654pub fn token_refresh_params(
 655    refresh_token: &str,
 656    client_id: &str,
 657    resource: &str,
 658) -> Vec<(&'static str, String)> {
 659    vec![
 660        ("grant_type", "refresh_token".to_string()),
 661        ("refresh_token", refresh_token.to_string()),
 662        ("client_id", client_id.to_string()),
 663        ("resource", resource.to_string()),
 664    ]
 665}
 666
 667// -- DCR request body (RFC 7591) ---------------------------------------------
 668
 669/// Build the JSON body for a Dynamic Client Registration request.
 670///
 671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 673/// redirect URI matching even for loopback addresses, so we register the
 674/// exact URI we intend to use.
 675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 676    serde_json::json!({
 677        "client_name": "Zed",
 678        "redirect_uris": [redirect_uri],
 679        "grant_types": ["authorization_code"],
 680        "response_types": ["code"],
 681        "token_endpoint_auth_method": "none"
 682    })
 683}
 684
 685// -- Discovery (async, hits real endpoints) ----------------------------------
 686
 687/// Fetch Protected Resource Metadata from the MCP server.
 688///
 689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 690/// then falls back to well-known URIs constructed from `server_url`.
 691pub async fn fetch_protected_resource_metadata(
 692    http_client: &Arc<dyn HttpClient>,
 693    server_url: &Url,
 694    www_authenticate: &WwwAuthenticate,
 695) -> Result<ProtectedResourceMetadata> {
 696    let candidate_urls = match &www_authenticate.resource_metadata {
 697        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 698        Some(url) => {
 699            log::warn!(
 700                "Ignoring cross-origin resource_metadata URL {} \
 701                 (server origin: {})",
 702                url,
 703                server_url.origin().unicode_serialization()
 704            );
 705            protected_resource_metadata_urls(server_url)
 706        }
 707        None => protected_resource_metadata_urls(server_url),
 708    };
 709
 710    for url in &candidate_urls {
 711        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 712            Ok(response) => {
 713                if response.authorization_servers.is_empty() {
 714                    bail!(
 715                        "Protected Resource Metadata at {} has no authorization_servers",
 716                        url
 717                    );
 718                }
 719                return Ok(ProtectedResourceMetadata {
 720                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 721                    authorization_servers: response.authorization_servers,
 722                    scopes_supported: response.scopes_supported,
 723                });
 724            }
 725            Err(err) => {
 726                log::debug!(
 727                    "Failed to fetch Protected Resource Metadata from {}: {}",
 728                    url,
 729                    err
 730                );
 731            }
 732        }
 733    }
 734
 735    bail!(
 736        "Could not fetch Protected Resource Metadata for {}",
 737        server_url
 738    )
 739}
 740
 741/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 742/// endpoints in the priority order specified by the MCP spec.
 743pub async fn fetch_auth_server_metadata(
 744    http_client: &Arc<dyn HttpClient>,
 745    issuer: &Url,
 746) -> Result<AuthServerMetadata> {
 747    let candidate_urls = auth_server_metadata_urls(issuer);
 748
 749    for url in &candidate_urls {
 750        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 751            Ok(response) => {
 752                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 753                if reported_issuer != *issuer {
 754                    bail!(
 755                        "Auth server metadata issuer mismatch: expected {}, got {}",
 756                        issuer,
 757                        reported_issuer
 758                    );
 759                }
 760
 761                return Ok(AuthServerMetadata {
 762                    issuer: reported_issuer,
 763                    authorization_endpoint: response
 764                        .authorization_endpoint
 765                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 766                    token_endpoint: response
 767                        .token_endpoint
 768                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 769                    registration_endpoint: response.registration_endpoint,
 770                    scopes_supported: response.scopes_supported,
 771                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 772                    client_id_metadata_document_supported: response
 773                        .client_id_metadata_document_supported
 774                        .unwrap_or(false),
 775                });
 776            }
 777            Err(err) => {
 778                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 779            }
 780        }
 781    }
 782
 783    bail!(
 784        "Could not fetch Authorization Server Metadata for {}",
 785        issuer
 786    )
 787}
 788
 789/// Run the full discovery flow: fetch resource metadata, then auth server
 790/// metadata, then select scopes. Client registration is resolved separately,
 791/// once the real redirect URI is known.
 792pub async fn discover(
 793    http_client: &Arc<dyn HttpClient>,
 794    server_url: &Url,
 795    www_authenticate: &WwwAuthenticate,
 796) -> Result<OAuthDiscovery> {
 797    let resource_metadata =
 798        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 799
 800    let auth_server_url = resource_metadata
 801        .authorization_servers
 802        .first()
 803        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 804
 805    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 806
 807    // Verify PKCE S256 support (spec requirement).
 808    match &auth_server_metadata.code_challenge_methods_supported {
 809        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 810        Some(_) => bail!("authorization server does not support S256 PKCE"),
 811        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 812    }
 813
 814    // Verify there is at least one supported registration strategy before we
 815    // present the server as ready to authenticate.
 816    match determine_registration_strategy(&auth_server_metadata) {
 817        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
 818        ClientRegistrationStrategy::Unavailable => {
 819            bail!("authorization server supports neither CIMD nor DCR")
 820        }
 821    }
 822
 823    let scopes = select_scopes(www_authenticate, &resource_metadata);
 824
 825    Ok(OAuthDiscovery {
 826        resource_metadata,
 827        auth_server_metadata,
 828        scopes,
 829    })
 830}
 831
 832/// Resolve the OAuth client registration for an authorization flow.
 833///
 834/// CIMD uses the static client metadata document directly. For DCR, a fresh
 835/// registration is performed each time because the loopback redirect URI
 836/// includes an ephemeral port that changes every flow.
 837pub async fn resolve_client_registration(
 838    http_client: &Arc<dyn HttpClient>,
 839    discovery: &OAuthDiscovery,
 840    redirect_uri: &str,
 841) -> Result<OAuthClientRegistration> {
 842    match determine_registration_strategy(&discovery.auth_server_metadata) {
 843        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 844            client_id,
 845            client_secret: None,
 846        }),
 847        ClientRegistrationStrategy::Dcr {
 848            registration_endpoint,
 849        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 850        ClientRegistrationStrategy::Unavailable => {
 851            bail!("authorization server supports neither CIMD nor DCR")
 852        }
 853    }
 854}
 855
 856// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 857
 858/// Perform Dynamic Client Registration with the authorization server.
 859pub async fn perform_dcr(
 860    http_client: &Arc<dyn HttpClient>,
 861    registration_endpoint: &Url,
 862    redirect_uri: &str,
 863) -> Result<OAuthClientRegistration> {
 864    validate_oauth_url(registration_endpoint)?;
 865
 866    let body = dcr_registration_body(redirect_uri);
 867    let body_bytes = serde_json::to_vec(&body)?;
 868
 869    let request = Request::builder()
 870        .method(http_client::http::Method::POST)
 871        .uri(registration_endpoint.as_str())
 872        .header("Content-Type", "application/json")
 873        .header("Accept", "application/json")
 874        .body(AsyncBody::from(body_bytes))?;
 875
 876    let mut response = http_client.send(request).await?;
 877
 878    if !response.status().is_success() {
 879        let mut error_body = String::new();
 880        response.body_mut().read_to_string(&mut error_body).await?;
 881        bail!(
 882            "DCR failed with status {}: {}",
 883            response.status(),
 884            error_body
 885        );
 886    }
 887
 888    let mut response_body = String::new();
 889    response
 890        .body_mut()
 891        .read_to_string(&mut response_body)
 892        .await?;
 893
 894    let dcr_response: DcrResponse =
 895        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 896
 897    Ok(OAuthClientRegistration {
 898        client_id: dcr_response.client_id,
 899        client_secret: dcr_response.client_secret,
 900    })
 901}
 902
 903// -- Token exchange and refresh (async) --------------------------------------
 904
 905/// Exchange an authorization code for tokens at the token endpoint.
 906pub async fn exchange_code(
 907    http_client: &Arc<dyn HttpClient>,
 908    auth_server_metadata: &AuthServerMetadata,
 909    code: &str,
 910    client_id: &str,
 911    redirect_uri: &str,
 912    code_verifier: &str,
 913    resource: &str,
 914) -> Result<OAuthTokens> {
 915    let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
 916    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 917}
 918
 919/// Refresh tokens using a refresh token.
 920pub async fn refresh_tokens(
 921    http_client: &Arc<dyn HttpClient>,
 922    token_endpoint: &Url,
 923    refresh_token: &str,
 924    client_id: &str,
 925    resource: &str,
 926) -> Result<OAuthTokens> {
 927    let params = token_refresh_params(refresh_token, client_id, resource);
 928    post_token_request(http_client, token_endpoint, &params).await
 929}
 930
 931/// POST form-encoded parameters to a token endpoint and parse the response.
 932async fn post_token_request(
 933    http_client: &Arc<dyn HttpClient>,
 934    token_endpoint: &Url,
 935    params: &[(&str, String)],
 936) -> Result<OAuthTokens> {
 937    validate_oauth_url(token_endpoint)?;
 938
 939    let body = url::form_urlencoded::Serializer::new(String::new())
 940        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 941        .finish();
 942
 943    let request = Request::builder()
 944        .method(http_client::http::Method::POST)
 945        .uri(token_endpoint.as_str())
 946        .header("Content-Type", "application/x-www-form-urlencoded")
 947        .header("Accept", "application/json")
 948        .body(AsyncBody::from(body.into_bytes()))?;
 949
 950    let mut response = http_client.send(request).await?;
 951
 952    if !response.status().is_success() {
 953        let mut error_body = String::new();
 954        response.body_mut().read_to_string(&mut error_body).await?;
 955        bail!(
 956            "token request failed with status {}: {}",
 957            response.status(),
 958            error_body
 959        );
 960    }
 961
 962    let mut response_body = String::new();
 963    response
 964        .body_mut()
 965        .read_to_string(&mut response_body)
 966        .await?;
 967
 968    let token_response: TokenResponse =
 969        serde_json::from_str(&response_body).context("failed to parse token response")?;
 970
 971    Ok(token_response.into_tokens())
 972}
 973
 974// -- Loopback HTTP callback server -------------------------------------------
 975
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Both fields come from the callback's query string. The hand-written
/// `Debug` impl redacts them, so this type deliberately does not derive
/// `Debug`.
pub struct OAuthCallback {
    /// Value of the `code` query parameter.
    pub code: String,
    /// Value of the `state` query parameter.
    pub state: String,
}
 981
 982impl std::fmt::Debug for OAuthCallback {
 983    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 984        f.debug_struct("OAuthCallback")
 985            .field("code", &"[redacted]")
 986            .field("state", &"[redacted]")
 987            .finish()
 988    }
 989}
 990
 991impl OAuthCallback {
 992    /// Parse the query string from a callback URL like
 993    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
 994    pub fn parse_query(query: &str) -> Result<Self> {
 995        let mut code: Option<String> = None;
 996        let mut state: Option<String> = None;
 997        let mut error: Option<String> = None;
 998        let mut error_description: Option<String> = None;
 999
1000        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1001            match key.as_ref() {
1002                "code" => {
1003                    if !value.is_empty() {
1004                        code = Some(value.into_owned());
1005                    }
1006                }
1007                "state" => {
1008                    if !value.is_empty() {
1009                        state = Some(value.into_owned());
1010                    }
1011                }
1012                "error" => {
1013                    if !value.is_empty() {
1014                        error = Some(value.into_owned());
1015                    }
1016                }
1017                "error_description" => {
1018                    if !value.is_empty() {
1019                        error_description = Some(value.into_owned());
1020                    }
1021                }
1022                _ => {}
1023            }
1024        }
1025
1026        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1027        // checking for missing code/state.
1028        if let Some(error_code) = error {
1029            bail!(
1030                "OAuth authorization failed: {} ({})",
1031                error_code,
1032                error_description.as_deref().unwrap_or("no description")
1033            );
1034        }
1035
1036        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1037        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1038
1039        Ok(Self { code, state })
1040    }
1041}
1042
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port.
///
/// Two minutes, counted from when the callback server's background thread
/// starts running.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
1046
1047/// Start a loopback HTTP server to receive the OAuth authorization callback.
1048///
1049/// Binds to an ephemeral loopback port for each flow.
1050///
1051/// Returns `(redirect_uri, callback_future)`. The caller should use the
1052/// redirect URI in the authorization request, open the browser, then await
1053/// the future to receive the callback.
1054///
1055/// The server accepts exactly one request on `/callback`, validates that it
1056/// contains `code` and `state` query parameters, responds with a minimal
1057/// HTML page telling the user they can close the tab, and shuts down.
1058///
1059/// The callback server shuts down when the returned oneshot receiver is dropped
1060/// (e.g. because the authentication task was cancelled), or after a timeout
1061/// ([CALLBACK_TIMEOUT]).
1062pub async fn start_callback_server() -> Result<(
1063    String,
1064    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1065)> {
1066    let server = tiny_http::Server::http("127.0.0.1:0")
1067        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1068    let port = server
1069        .server_addr()
1070        .to_ip()
1071        .context("server not bound to a TCP address")?
1072        .port();
1073
1074    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1075
1076    let (tx, rx) = futures::channel::oneshot::channel();
1077
1078    // `tiny_http` is blocking, so we run it on a background thread.
1079    // The `recv_timeout` loop lets us check for cancellation (the receiver
1080    // being dropped) and enforce an overall timeout.
1081    std::thread::spawn(move || {
1082        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1083
1084        loop {
1085            if tx.is_canceled() {
1086                return;
1087            }
1088            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1089            if remaining.is_zero() {
1090                return;
1091            }
1092
1093            let timeout = remaining.min(Duration::from_millis(500));
1094            let Some(request) = (match server.recv_timeout(timeout) {
1095                Ok(req) => req,
1096                Err(_) => {
1097                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1098                    return;
1099                }
1100            }) else {
1101                // Timeout with no request — loop back and check cancellation.
1102                continue;
1103            };
1104
1105            let result = handle_callback_request(&request);
1106
1107            let (status_code, body) = match &result {
1108                Ok(_) => (
1109                    200,
1110                    http_client::oauth_callback_page(
1111                        "Authorization Successful",
1112                        "You can close this tab and return to Zed.",
1113                    ),
1114                ),
1115                Err(err) => {
1116                    log::error!("OAuth callback error: {}", err);
1117                    (
1118                        400,
1119                        http_client::oauth_callback_page(
1120                            "Authorization Failed",
1121                            "Something went wrong. Please try again from Zed.",
1122                        ),
1123                    )
1124                }
1125            };
1126
1127            let response = tiny_http::Response::from_string(body)
1128                .with_status_code(status_code)
1129                .with_header(
1130                    tiny_http::Header::from_str("Content-Type: text/html")
1131                        .expect("failed to construct response header"),
1132                )
1133                .with_header(
1134                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1135                        .expect("failed to construct response header"),
1136                );
1137            request.respond(response).log_err();
1138
1139            let _ = tx.send(result);
1140            return;
1141        }
1142    });
1143
1144    Ok((redirect_uri, rx))
1145}
1146
1147/// Extract the `code` and `state` query parameters from an OAuth callback
1148/// request to `/callback`.
1149fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1150    let url = Url::parse(&format!("http://localhost{}", request.url()))
1151        .context("malformed callback request URL")?;
1152
1153    if url.path() != "/callback" {
1154        bail!("unexpected path in OAuth callback: {}", url.path());
1155    }
1156
1157    let query = url
1158        .query()
1159        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1160    OAuthCallback::parse_query(query)
1161}
1162
1163// -- JSON fetch helper -------------------------------------------------------
1164
1165async fn fetch_json<T: serde::de::DeserializeOwned>(
1166    http_client: &Arc<dyn HttpClient>,
1167    url: &Url,
1168) -> Result<T> {
1169    validate_oauth_url(url)?;
1170
1171    let request = Request::builder()
1172        .method(http_client::http::Method::GET)
1173        .uri(url.as_str())
1174        .header("Accept", "application/json")
1175        .body(AsyncBody::default())?;
1176
1177    let mut response = http_client.send(request).await?;
1178
1179    if !response.status().is_success() {
1180        bail!("HTTP {} fetching {}", response.status(), url);
1181    }
1182
1183    let mut body = String::new();
1184    response.body_mut().read_to_string(&mut body).await?;
1185    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1186}
1187
1188// -- Serde response types for discovery --------------------------------------
1189
/// Deserialization target for a Protected Resource Metadata JSON document.
///
/// Every field is marked `#[serde(default)]`, so a document missing any of
/// them still deserializes.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// `resource` value from the document; `None` when absent.
    #[serde(default)]
    resource: Option<Url>,
    /// `authorization_servers` list; empty when absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// `scopes_supported` list; `None` when absent.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1199
/// Deserialization target for an Authorization Server Metadata JSON document.
///
/// Every field is optional at this level. The code that converts this into
/// the validated metadata type rejects documents lacking
/// `authorization_endpoint` or `token_endpoint`.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// Issuer identifier reported by the server, if any.
    #[serde(default)]
    issuer: Option<Url>,
    /// Authorization endpoint; required by the conversion step.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// Token endpoint; required by the conversion step.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// Dynamic Client Registration endpoint, if advertised.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// Scopes the server says it supports, if advertised.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// PKCE challenge methods the server says it supports, if advertised.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether client ID metadata documents are supported; treated as
    /// `false` by the conversion step when absent.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1217
/// Deserialization target for a Dynamic Client Registration response.
///
/// Only the fields this module consumes are declared.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// Client identifier issued by the server. Not optional: a response
    /// without it fails to deserialize.
    client_id: String,
    /// Client secret, when the server issues one; `None` when absent.
    #[serde(default)]
    client_secret: Option<String>,
}
1224
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// Synchronous: takes `&self` and returns an owned `String`.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token was obtained, so the caller should not
    /// retry.
    async fn try_refresh(&self) -> Result<bool>;
}
1238
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// Current session state behind a synchronous mutex. Its tokens are
    /// replaced when a refresh succeeds.
    session: SyncMutex<OAuthSession>,
    /// HTTP client used to call the token endpoint during a refresh.
    http_client: Arc<dyn HttpClient>,
    /// Optional channel that receives a clone of the session after each
    /// successful refresh. NOTE(review): presumably consumed to persist the
    /// updated session — confirm against the receiver.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1248
1249impl McpOAuthTokenProvider {
1250    pub fn new(
1251        session: OAuthSession,
1252        http_client: Arc<dyn HttpClient>,
1253        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1254    ) -> Self {
1255        Self {
1256            session: SyncMutex::new(session),
1257            http_client,
1258            token_refresh_tx,
1259        }
1260    }
1261
1262    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1263        tokens.expires_at.is_some_and(|expires_at| {
1264            SystemTime::now()
1265                .checked_add(Duration::from_secs(30))
1266                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1267        })
1268    }
1269}
1270
1271#[async_trait]
1272impl OAuthTokenProvider for McpOAuthTokenProvider {
1273    fn access_token(&self) -> Option<String> {
1274        let session = self.session.lock();
1275        if Self::access_token_is_expired(&session.tokens) {
1276            return None;
1277        }
1278        Some(session.tokens.access_token.clone())
1279    }
1280
1281    async fn try_refresh(&self) -> Result<bool> {
1282        let (refresh_token, token_endpoint, resource, client_id) = {
1283            let session = self.session.lock();
1284            match session.tokens.refresh_token.clone() {
1285                Some(refresh_token) => (
1286                    refresh_token,
1287                    session.token_endpoint.clone(),
1288                    session.resource.clone(),
1289                    session.client_registration.client_id.clone(),
1290                ),
1291                None => return Ok(false),
1292            }
1293        };
1294
1295        let resource_str = canonical_server_uri(&resource);
1296
1297        match refresh_tokens(
1298            &self.http_client,
1299            &token_endpoint,
1300            &refresh_token,
1301            &client_id,
1302            &resource_str,
1303        )
1304        .await
1305        {
1306            Ok(mut new_tokens) => {
1307                if new_tokens.refresh_token.is_none() {
1308                    new_tokens.refresh_token = Some(refresh_token);
1309                }
1310
1311                {
1312                    let mut session = self.session.lock();
1313                    session.tokens = new_tokens;
1314
1315                    if let Some(ref tx) = self.token_refresh_tx {
1316                        tx.unbounded_send(session.clone()).ok();
1317                    }
1318                }
1319
1320                Ok(true)
1321            }
1322            Err(err) => {
1323                log::warn!("OAuth token refresh failed: {}", err);
1324                Ok(false)
1325            }
1326        }
1327    }
1328}
1329
1330#[cfg(test)]
1331mod tests {
1332    use super::*;
1333    use http_client::Response;
1334
1335    // -- require_https_or_loopback tests ------------------------------------
1336
1337    #[test]
1338    fn test_require_https_or_loopback_accepts_https() {
1339        let url = Url::parse("https://auth.example.com/token").unwrap();
1340        assert!(require_https_or_loopback(&url).is_ok());
1341    }
1342
1343    #[test]
1344    fn test_require_https_or_loopback_rejects_http_remote() {
1345        let url = Url::parse("http://auth.example.com/token").unwrap();
1346        assert!(require_https_or_loopback(&url).is_err());
1347    }
1348
1349    #[test]
1350    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1351        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1352        assert!(require_https_or_loopback(&url).is_ok());
1353    }
1354
1355    #[test]
1356    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1357        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1358        assert!(require_https_or_loopback(&url).is_ok());
1359    }
1360
1361    #[test]
1362    fn test_require_https_or_loopback_accepts_http_localhost() {
1363        let url = Url::parse("http://localhost:8080/callback").unwrap();
1364        assert!(require_https_or_loopback(&url).is_ok());
1365    }
1366
1367    #[test]
1368    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1369        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1370        assert!(require_https_or_loopback(&url).is_ok());
1371    }
1372
1373    #[test]
1374    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1375        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1376        assert!(require_https_or_loopback(&url).is_err());
1377    }
1378
1379    #[test]
1380    fn test_require_https_or_loopback_rejects_ftp() {
1381        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1382        assert!(require_https_or_loopback(&url).is_err());
1383    }
1384
1385    // -- validate_oauth_url (SSRF) tests ------------------------------------
1386
1387    #[test]
1388    fn test_validate_oauth_url_accepts_https_public() {
1389        let url = Url::parse("https://auth.example.com/token").unwrap();
1390        assert!(validate_oauth_url(&url).is_ok());
1391    }
1392
1393    #[test]
1394    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1395        let url = Url::parse("https://10.0.0.1/token").unwrap();
1396        assert!(validate_oauth_url(&url).is_err());
1397    }
1398
1399    #[test]
1400    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1401        let url = Url::parse("https://172.16.0.1/token").unwrap();
1402        assert!(validate_oauth_url(&url).is_err());
1403    }
1404
1405    #[test]
1406    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1407        let url = Url::parse("https://192.168.1.1/token").unwrap();
1408        assert!(validate_oauth_url(&url).is_err());
1409    }
1410
1411    #[test]
1412    fn test_validate_oauth_url_rejects_link_local() {
1413        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1414        assert!(validate_oauth_url(&url).is_err());
1415    }
1416
1417    #[test]
1418    fn test_validate_oauth_url_rejects_ipv6_ula() {
1419        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1420        assert!(validate_oauth_url(&url).is_err());
1421    }
1422
1423    #[test]
1424    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1425        let url = Url::parse("https://[::]/token").unwrap();
1426        assert!(validate_oauth_url(&url).is_err());
1427    }
1428
1429    #[test]
1430    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1431        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1432        assert!(validate_oauth_url(&url).is_err());
1433    }
1434
1435    #[test]
1436    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1437        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1438        assert!(validate_oauth_url(&url).is_err());
1439    }
1440
1441    #[test]
1442    fn test_validate_oauth_url_allows_http_loopback() {
1443        // Loopback is permitted (it's our callback server).
1444        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1445        assert!(validate_oauth_url(&url).is_ok());
1446    }
1447
1448    #[test]
1449    fn test_validate_oauth_url_allows_https_public_ip() {
1450        let url = Url::parse("https://93.184.216.34/token").unwrap();
1451        assert!(validate_oauth_url(&url).is_ok());
1452    }
1453
1454    // -- parse_www_authenticate tests ----------------------------------------
1455
1456    #[test]
1457    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1458        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1459        let result = parse_www_authenticate(header).unwrap();
1460
1461        assert_eq!(
1462            result.resource_metadata.as_ref().map(|u| u.as_str()),
1463            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1464        );
1465        assert_eq!(
1466            result.scope,
1467            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1468        );
1469        assert_eq!(result.error, None);
1470    }
1471
1472    #[test]
1473    fn test_parse_www_authenticate_resource_metadata_only() {
1474        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1475        let result = parse_www_authenticate(header).unwrap();
1476
1477        assert_eq!(
1478            result.resource_metadata.as_ref().map(|u| u.as_str()),
1479            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1480        );
1481        assert_eq!(result.scope, None);
1482    }
1483
1484    #[test]
1485    fn test_parse_www_authenticate_bare_bearer() {
1486        let result = parse_www_authenticate("Bearer").unwrap();
1487        assert_eq!(result.resource_metadata, None);
1488        assert_eq!(result.scope, None);
1489    }
1490
1491    #[test]
1492    fn test_parse_www_authenticate_with_error() {
1493        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1494        let result = parse_www_authenticate(header).unwrap();
1495
1496        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1497        assert_eq!(
1498            result.error_description.as_deref(),
1499            Some("Additional file write permission required")
1500        );
1501        assert_eq!(
1502            result.scope,
1503            Some(vec!["files:read".to_string(), "files:write".to_string()])
1504        );
1505        assert!(result.resource_metadata.is_some());
1506    }
1507
1508    #[test]
1509    fn test_parse_www_authenticate_invalid_token_error() {
1510        let header =
1511            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1512        let result = parse_www_authenticate(header).unwrap();
1513        assert_eq!(result.error, Some(BearerError::InvalidToken));
1514    }
1515
1516    #[test]
1517    fn test_parse_www_authenticate_invalid_request_error() {
1518        let header = r#"Bearer error="invalid_request""#;
1519        let result = parse_www_authenticate(header).unwrap();
1520        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1521    }
1522
1523    #[test]
1524    fn test_parse_www_authenticate_unknown_error() {
1525        let header = r#"Bearer error="some_future_error""#;
1526        let result = parse_www_authenticate(header).unwrap();
1527        assert_eq!(result.error, Some(BearerError::Other));
1528    }
1529
1530    #[test]
1531    fn test_parse_www_authenticate_rejects_non_bearer() {
1532        let result = parse_www_authenticate("Basic realm=\"example\"");
1533        assert!(result.is_err());
1534    }
1535
1536    #[test]
1537    fn test_parse_www_authenticate_case_insensitive_scheme() {
1538        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1539        let result = parse_www_authenticate(header).unwrap();
1540        assert!(result.resource_metadata.is_some());
1541    }
1542
1543    #[test]
1544    fn test_parse_www_authenticate_multiline_style() {
1545        // Some servers emit the header spread across multiple lines joined by
1546        // whitespace, as shown in the spec examples.
1547        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1548        let result = parse_www_authenticate(header).unwrap();
1549        assert!(result.resource_metadata.is_some());
1550        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1551    }
1552
1553    #[test]
1554    fn test_protected_resource_metadata_urls_with_path() {
1555        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1556        let urls = protected_resource_metadata_urls(&server_url);
1557
1558        assert_eq!(urls.len(), 2);
1559        assert_eq!(
1560            urls[0].as_str(),
1561            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1562        );
1563        assert_eq!(
1564            urls[1].as_str(),
1565            "https://api.example.com/.well-known/oauth-protected-resource"
1566        );
1567    }
1568
1569    #[test]
1570    fn test_protected_resource_metadata_urls_without_path() {
1571        let server_url = Url::parse("https://mcp.example.com").unwrap();
1572        let urls = protected_resource_metadata_urls(&server_url);
1573
1574        assert_eq!(urls.len(), 1);
1575        assert_eq!(
1576            urls[0].as_str(),
1577            "https://mcp.example.com/.well-known/oauth-protected-resource"
1578        );
1579    }
1580
1581    #[test]
1582    fn test_auth_server_metadata_urls_with_path() {
1583        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1584        let urls = auth_server_metadata_urls(&issuer);
1585
1586        assert_eq!(urls.len(), 3);
1587        assert_eq!(
1588            urls[0].as_str(),
1589            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1590        );
1591        assert_eq!(
1592            urls[1].as_str(),
1593            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1594        );
1595        assert_eq!(
1596            urls[2].as_str(),
1597            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1598        );
1599    }
1600
1601    #[test]
1602    fn test_auth_server_metadata_urls_without_path() {
1603        let issuer = Url::parse("https://auth.example.com").unwrap();
1604        let urls = auth_server_metadata_urls(&issuer);
1605
1606        assert_eq!(urls.len(), 2);
1607        assert_eq!(
1608            urls[0].as_str(),
1609            "https://auth.example.com/.well-known/oauth-authorization-server"
1610        );
1611        assert_eq!(
1612            urls[1].as_str(),
1613            "https://auth.example.com/.well-known/openid-configuration"
1614        );
1615    }
1616
1617    // -- Canonical server URI tests ------------------------------------------
1618
1619    #[test]
1620    fn test_canonical_server_uri_simple() {
1621        let url = Url::parse("https://mcp.example.com").unwrap();
1622        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1623    }
1624
#[test]
fn test_canonical_server_uri_with_path() {
    // A non-root path is carried through unchanged.
    let server = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com/v1/mcp");
}
1630
#[test]
fn test_canonical_server_uri_strips_trailing_slash() {
    // The root path's trailing slash must not appear in the canonical form.
    let server = Url::parse("https://mcp.example.com/").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com");
}
1636
#[test]
fn test_canonical_server_uri_preserves_port() {
    // An explicit non-default port is part of the canonical URI.
    let server = Url::parse("https://mcp.example.com:8443").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com:8443");
}
1642
#[test]
fn test_canonical_server_uri_lowercases() {
    // Scheme and host are supplied in mixed case and must come back in lower
    // case, while the path keeps its original casing.
    let mixed_case = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
    let canonical = canonical_server_uri(&mixed_case);
    assert_eq!(canonical, "https://mcp.example.com/Server/MCP");
}
1651
1652    // -- Scope selection tests -----------------------------------------------
1653
#[test]
fn test_select_scopes_prefers_www_authenticate() {
    // The resource metadata advertises two scopes, but the challenge header
    // names only one; the header's list is the one that must be selected.
    let resource_metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: Vec::new(),
        scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
    };
    let challenge = WwwAuthenticate {
        scope: Some(vec!["files:read".into()]),
        resource_metadata: None,
        error: None,
        error_description: None,
    };

    let selected = select_scopes(&challenge, &resource_metadata);
    assert_eq!(selected, vec!["files:read"]);
}
1669
#[test]
fn test_select_scopes_falls_back_to_resource_metadata() {
    // With no scope in the challenge header, the scopes advertised by the
    // resource metadata are used instead.
    let resource_metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: Vec::new(),
        scopes_supported: Some(vec!["admin".into()]),
    };
    let challenge = WwwAuthenticate {
        scope: None,
        resource_metadata: None,
        error: None,
        error_description: None,
    };

    let selected = select_scopes(&challenge, &resource_metadata);
    assert_eq!(selected, vec!["admin"]);
}
1685
#[test]
fn test_select_scopes_empty_when_nothing_available() {
    // Neither the challenge header nor the resource metadata carries scopes,
    // so the selection must come back empty.
    let resource_metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: Vec::new(),
        scopes_supported: None,
    };
    let challenge = WwwAuthenticate {
        scope: None,
        resource_metadata: None,
        error: None,
        error_description: None,
    };

    let selected = select_scopes(&challenge, &resource_metadata);
    assert!(selected.is_empty());
}
1701
1702    // -- Client registration strategy tests ----------------------------------
1703
#[test]
fn test_registration_strategy_prefers_cimd() {
    // The server offers both a registration endpoint and CIMD support; the
    // CIMD strategy with Zed's hosted metadata URL must be chosen.
    let issuer = Url::parse("https://auth.example.com").unwrap();
    let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
    let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
    let registration_endpoint = Url::parse("https://auth.example.com/register").unwrap();

    let server = AuthServerMetadata {
        issuer,
        authorization_endpoint,
        token_endpoint,
        registration_endpoint: Some(registration_endpoint),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };

    let expected = ClientRegistrationStrategy::Cimd {
        client_id: CIMD_URL.to_owned(),
    };
    assert_eq!(determine_registration_strategy(&server), expected);
}
1722
#[test]
fn test_registration_strategy_falls_back_to_dcr() {
    // Without CIMD support, a present registration endpoint selects DCR and
    // the strategy carries that same endpoint.
    let registration_endpoint = Url::parse("https://auth.example.com/register").unwrap();
    let expected = ClientRegistrationStrategy::Dcr {
        registration_endpoint: registration_endpoint.clone(),
    };

    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: Some(registration_endpoint),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    assert_eq!(determine_registration_strategy(&server), expected);
}
1742
#[test]
fn test_registration_strategy_unavailable() {
    // No CIMD support and no registration endpoint: nothing can be used.
    let issuer = Url::parse("https://auth.example.com").unwrap();
    let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
    let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();

    let server = AuthServerMetadata {
        issuer,
        authorization_endpoint,
        token_endpoint,
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    let strategy = determine_registration_strategy(&server);
    assert_eq!(strategy, ClientRegistrationStrategy::Unavailable);
}
1759
1760    // -- PKCE tests ----------------------------------------------------------
1761
#[test]
fn test_pkce_challenge_verifier_length() {
    // 32 random bytes encode to 43 unpadded base64url characters.
    let verifier_len = generate_pkce_challenge().verifier.len();
    assert_eq!(verifier_len, 43);
}
1768
#[test]
fn test_pkce_challenge_is_valid_base64url() {
    // Every character of both the verifier and the challenge must belong to
    // the unpadded base64url alphabet: ASCII letters, digits, '-' and '_'.
    let pkce = generate_pkce_challenge();
    let every_char = pkce.verifier.chars().chain(pkce.challenge.chars());

    every_char.for_each(|ch| {
        let in_alphabet = matches!(ch, 'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_');
        assert!(in_alphabet, "invalid base64url character: {}", ch);
    });
}
1780
#[test]
fn test_pkce_challenge_is_s256_of_verifier() {
    // Recompute the S256 transform by hand: SHA-256 over the verifier bytes,
    // then unpadded base64url, and compare against the generated challenge.
    let pkce = generate_pkce_challenge();
    let verifier_digest = Sha256::digest(pkce.verifier.as_bytes());
    let recomputed = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifier_digest);

    assert_eq!(pkce.challenge, recomputed);
}
1789
#[test]
fn test_pkce_challenges_are_unique() {
    // Two independently generated challenges must not share a verifier.
    let (first, second) = (generate_pkce_challenge(), generate_pkce_challenge());
    assert_ne!(first.verifier, second.verifier);
}
1796
1797    // -- Authorization URL tests ---------------------------------------------
1798
#[test]
fn test_build_authorization_url() {
    // Inputs that are expected to reappear verbatim as query parameters.
    let client_id = "https://zed.dev/oauth/client-metadata.json";
    let redirect_uri = "http://127.0.0.1:12345/callback";
    let resource = "https://mcp.example.com";
    let state = "random_state_123";

    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };
    let pkce = PkceChallenge {
        challenge: "test_challenge".into(),
        verifier: "test_verifier".into(),
    };

    let authorization_url = build_authorization_url(
        &server,
        client_id,
        redirect_uri,
        &["files:read".into(), "files:write".into()],
        resource,
        &pkce,
        state,
    );

    // Each expected parameter must be present with exactly this decoded
    // value; the two scopes are joined with a single space.
    let query: std::collections::HashMap<_, _> = authorization_url.query_pairs().collect();
    let expected_params = [
        ("response_type", "code"),
        ("client_id", client_id),
        ("redirect_uri", redirect_uri),
        ("scope", "files:read files:write"),
        ("resource", resource),
        ("code_challenge", "test_challenge"),
        ("code_challenge_method", "S256"),
        ("state", state),
    ];
    for (name, expected) in expected_params {
        assert_eq!(
            query.get(name).unwrap(),
            expected,
            "query parameter `{}`",
            name
        );
    }
}
1840
#[test]
fn test_build_authorization_url_omits_empty_scope() {
    // With an empty scope list, the URL must carry no `scope` parameter at all.
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };
    let pkce = PkceChallenge {
        challenge: "c".into(),
        verifier: "v".into(),
    };

    let authorization_url = build_authorization_url(
        &server,
        "client_123",
        "http://127.0.0.1:9999/callback",
        &[],
        "https://mcp.example.com",
        &pkce,
        "state",
    );

    let has_scope = authorization_url
        .query_pairs()
        .any(|(name, _)| name == "scope");
    assert!(!has_scope);
}
1869
1870    // -- Token exchange / refresh param tests --------------------------------
1871
#[test]
fn test_token_exchange_params() {
    let code = "auth_code_abc";
    let client_id = "client_xyz";
    let redirect_uri = "http://127.0.0.1:5555/callback";
    let verifier = "verifier_123";
    let resource = "https://mcp.example.com";

    let params = token_exchange_params(code, client_id, redirect_uri, verifier, resource);
    let by_name: std::collections::HashMap<&str, &str> = params
        .iter()
        .map(|(name, value)| (*name, value.as_str()))
        .collect();

    // The authorization-code grant must echo every input under its form name.
    let expected_fields = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", verifier),
        ("resource", resource),
    ];
    for (name, expected) in expected_fields {
        assert_eq!(by_name[name], expected, "form field `{}`", name);
    }
}
1891
#[test]
fn test_token_refresh_params() {
    let refresh_token = "refresh_token_abc";
    let client_id = "client_xyz";
    let resource = "https://mcp.example.com";

    let params = token_refresh_params(refresh_token, client_id, resource);
    let by_name: std::collections::HashMap<&str, &str> = params
        .iter()
        .map(|(name, value)| (*name, value.as_str()))
        .collect();

    // The refresh grant must echo every input under its form name.
    let expected_fields = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ];
    for (name, expected) in expected_fields {
        assert_eq!(by_name[name], expected, "form field `{}`", name);
    }
}
1904
1905    // -- Token response tests ------------------------------------------------
1906
#[test]
fn test_token_response_into_tokens_with_expiry() {
    // A fully populated token response: access token, refresh token, lifetime.
    let payload = r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#;
    let parsed = serde_json::from_str::<TokenResponse>(payload).unwrap();

    let tokens = parsed.into_tokens();
    assert_eq!(tokens.access_token, "at_123");
    assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
    // `expires_in` was supplied, so an expiry must have been derived from it.
    assert!(tokens.expires_at.is_some());
}
1919
#[test]
fn test_token_response_into_tokens_minimal() {
    // Only the access token is present; the optional fields must stay unset.
    let payload = r#"{"access_token": "at_789"}"#;
    let parsed = serde_json::from_str::<TokenResponse>(payload).unwrap();

    let tokens = parsed.into_tokens();
    assert_eq!(tokens.access_token, "at_789");
    assert!(tokens.refresh_token.is_none());
    assert!(tokens.expires_at.is_none());
}
1930
1931    // -- DCR body test -------------------------------------------------------
1932
#[test]
fn test_dcr_registration_body_shape() {
    let redirect_uri = "http://127.0.0.1:12345/callback";
    let registration = dcr_registration_body(redirect_uri);

    // Fixed client description fields.
    assert_eq!(registration["client_name"], "Zed");
    assert_eq!(registration["token_endpoint_auth_method"], "none");

    // List-valued fields: only the first entry of each is pinned here.
    assert_eq!(registration["redirect_uris"][0], redirect_uri);
    assert_eq!(registration["grant_types"][0], "authorization_code");
    assert_eq!(registration["response_types"][0], "code");
}
1942
1943    // -- Test helpers for async/HTTP tests -----------------------------------
1944
1945    fn make_fake_http_client(
1946        handler: impl Fn(
1947            http_client::Request<AsyncBody>,
1948        ) -> std::pin::Pin<
1949            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1950        > + Send
1951        + Sync
1952        + 'static,
1953    ) -> Arc<dyn HttpClient> {
1954        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1955    }
1956
1957    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1958        Ok(Response::builder()
1959            .status(status)
1960            .header("Content-Type", "application/json")
1961            .body(AsyncBody::from(body.as_bytes().to_vec()))
1962            .unwrap())
1963    }
1964
1965    // -- Discovery integration tests -----------------------------------------
1966
#[test]
fn test_fetch_protected_resource_metadata() {
    // Protected Resource Metadata document served by the fake server.
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["read", "write"]
                            }"#;

    smol::block_on(async {
        // Only the well-known protected-resource path answers; every other
        // request gets a 404.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_well_known = requested.contains(".well-known/oauth-protected-resource");
                let (status, body) = if is_well_known {
                    (200, RESOURCE_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        // No `resource_metadata` hint in the challenge header.
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&http, &server_url, &challenge)
            .await
            .unwrap();

        // Parsed URLs come back normalized with a trailing slash.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
        let servers: Vec<&str> = metadata
            .authorization_servers
            .iter()
            .map(|server| server.as_str())
            .collect();
        assert_eq!(servers, ["https://auth.example.com/"]);

        let expected_scopes = vec![String::from("read"), String::from("write")];
        assert_eq!(metadata.scopes_supported, Some(expected_scopes));
    });
}
2012
#[test]
fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
    // Same-origin metadata location advertised through the challenge header.
    const CUSTOM_METADATA_URL: &str = "https://mcp.example.com/custom-resource-metadata";
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;

    smol::block_on(async {
        // Only the advertised URL succeeds; any other request gets a 500.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested == CUSTOM_METADATA_URL {
                    (200, RESOURCE_METADATA)
                } else {
                    (500, r#"{"error": "should not be called"}"#)
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: Some(Url::parse(CUSTOM_METADATA_URL).unwrap()),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&http, &server_url, &challenge)
            .await
            .unwrap();

        let server_count = metadata.authorization_servers.len();
        assert_eq!(server_count, 1);
    });
}
2050
#[test]
fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
    // Document served from the server's own well-known location.
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();

                // Any request to the foreign origin fails the test outright;
                // only the well-known fallback on the server's own origin
                // may be fetched.
                if requested.contains("attacker.example.com") {
                    panic!("should not fetch cross-origin resource_metadata URL");
                }

                let is_well_known = requested.contains(".well-known/oauth-protected-resource");
                let (status, body) = if is_well_known {
                    (200, RESOURCE_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        // The challenge header points at a different origin than the server.
        let foreign_metadata = Url::parse("https://attacker.example.com/fake-metadata").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: Some(foreign_metadata),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&http, &server_url, &challenge)
            .await
            .unwrap();

        // The result must come from the well-known fallback document.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
    });
}
2093
#[test]
fn test_fetch_auth_server_metadata() {
    // Authorization Server Metadata document served by the fake server.
    const AUTH_SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#;

    smol::block_on(async {
        // Only the OAuth well-known path answers; everything else is a 404.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_oauth = requested.contains(".well-known/oauth-authorization-server");
                let (status, body) = if is_oauth {
                    (200, AUTH_SERVER_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&http, &issuer).await.unwrap();

        // URL-valued fields, compared in their serialized form.
        let url_fields = [
            (&metadata.issuer, "https://auth.example.com/"),
            (
                &metadata.authorization_endpoint,
                "https://auth.example.com/authorize",
            ),
            (&metadata.token_endpoint, "https://auth.example.com/token"),
        ];
        for (actual, expected) in url_fields {
            assert_eq!(actual.as_str(), expected);
        }

        assert!(metadata.registration_endpoint.is_some());
        assert!(metadata.client_id_metadata_document_supported);

        let expected_methods = vec![String::from("S256")];
        assert_eq!(
            metadata.code_challenge_methods_supported,
            Some(expected_methods)
        );
    });
}
2138
#[test]
fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
    // Metadata document that is only reachable via the OIDC discovery path.
    const OIDC_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#;

    smol::block_on(async {
        // The OAuth well-known path gets a 404; only URLs containing
        // "openid-configuration" succeed.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested.contains("openid-configuration") {
                    (200, OIDC_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&http, &issuer).await.unwrap();

        let authorization_endpoint = metadata.authorization_endpoint.as_str();
        assert_eq!(authorization_endpoint, "https://auth.example.com/authorize");

        // The document omits the CIMD flag, which must then read as false.
        assert!(!metadata.client_id_metadata_document_supported);
    });
}
2171
#[test]
fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
    // Metadata whose `issuer` names a different origin than the one queried.
    const MISMATCHED_METADATA: &str = r#"{
                                "issuer": "https://evil.example.com",
                                "authorization_endpoint": "https://evil.example.com/authorize",
                                "token_endpoint": "https://evil.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_oauth = requested.contains(".well-known/oauth-authorization-server");
                let (status, body) = if is_oauth {
                    (200, MISMATCHED_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let outcome = fetch_auth_server_metadata(&http, &issuer).await;

        // The fetch must fail, and the error must name the mismatch.
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        let names_mismatch = message.contains("issuer mismatch");
        assert!(names_mismatch, "unexpected error: {}", message);
    });
}
2207
2208    // -- Full discover integration tests -------------------------------------
2209
#[test]
fn test_full_discover_with_cimd() {
    // Documents served by the fake server for the two discovery steps.
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["mcp:read"]
                            }"#;
    const AUTH_SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#;

    smol::block_on(async {
        // Route by URL substring; unmatched requests get a 404.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => {
                        (200, AUTH_SERVER_METADATA)
                    }
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();
        let redirect_uri = "http://127.0.0.1:12345/callback";

        let discovery = discover(&http, &server_url, &challenge).await.unwrap();
        let registration = resolve_client_registration(&http, &discovery, redirect_uri)
            .await
            .unwrap();

        // CIMD: the client id is Zed's metadata URL and there is no secret.
        assert_eq!(registration.client_id, CIMD_URL);
        assert!(registration.client_secret.is_none());
        // With no scope in the challenge, the resource metadata's scopes apply.
        assert_eq!(discovery.scopes, vec!["mcp:read"]);
    });
}
2261
#[test]
fn test_full_discover_with_dcr_fallback() {
    // Documents served by the fake server: the two discovery documents and
    // the response of the dynamic registration endpoint.
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;
    const AUTH_SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": false
                            }"#;
    const REGISTRATION_RESPONSE: &str = r#"{
                                "client_id": "dcr-minted-id-123",
                                "client_secret": "dcr-secret-456"
                            }"#;

    smol::block_on(async {
        // Route by URL substring, in this order; unmatched requests get a 404.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => {
                        (200, AUTH_SERVER_METADATA)
                    }
                    uri if uri.contains("/register") => (201, REGISTRATION_RESPONSE),
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        // The challenge header supplies the scope this time.
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: Some(vec!["files:read".into()]),
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();
        let redirect_uri = "http://127.0.0.1:9999/callback";

        let discovery = discover(&http, &server_url, &challenge).await.unwrap();
        let registration = resolve_client_registration(&http, &discovery, redirect_uri)
            .await
            .unwrap();

        // DCR: both credentials come from the registration response.
        assert_eq!(registration.client_id, "dcr-minted-id-123");
        let client_secret = registration.client_secret.as_deref();
        assert_eq!(client_secret, Some("dcr-secret-456"));
        assert_eq!(discovery.scopes, vec!["files:read"]);
    });
}
2324
#[test]
fn test_discover_fails_without_pkce_support() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;
    // Deliberately lacks `code_challenge_methods_supported`.
    const AUTH_SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token"
                            }"#;

    smol::block_on(async {
        // Route by URL substring; unmatched requests get a 404.
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => {
                        (200, AUTH_SERVER_METADATA)
                    }
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        // Discovery must fail, and the error must name the missing field.
        let outcome = discover(&http, &server_url, &challenge).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        let names_missing_field = message.contains("code_challenge_methods_supported");
        assert!(names_missing_field, "unexpected error: {}", message);
    });
}
2372
2373    // -- Token exchange integration tests ------------------------------------
2374
/// A 200 response from the token endpoint should be turned by `exchange_code`
/// into tokens carrying the access token, the refresh token and an expiry.
#[test]
fn test_exchange_code_success() {
    // Successful token response served for any URI containing "/token".
    const TOKEN_BODY: &str = r#"{
                                "access_token": "new_access_token",
                                "refresh_token": "new_refresh_token",
                                "expires_in": 3600,
                                "token_type": "Bearer"
                            }"#;

    smol::block_on(async {
        let issuer = Url::parse("https://auth.example.com").unwrap();
        let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let metadata = AuthServerMetadata {
            issuer,
            authorization_endpoint,
            token_endpoint,
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let hits_token_endpoint = req.uri().to_string().contains("/token");
                let (status, body) = if hits_token_endpoint {
                    (200, TOKEN_BODY)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let auth_code = "auth_code_123";
        let redirect_uri = "http://127.0.0.1:9999/callback";
        let verifier = "verifier_abc";
        let resource = "https://mcp.example.com";
        let tokens = exchange_code(
            &client,
            &metadata,
            auth_code,
            CIMD_URL,
            redirect_uri,
            verifier,
            resource,
        )
        .await
        .unwrap();

        assert_eq!(tokens.access_token, "new_access_token");
        assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
        assert!(tokens.expires_at.is_some());
    });
}
2424
/// `refresh_tokens` should return the new access token and an expiry, and
/// report no refresh token when the response body does not include one.
#[test]
fn test_refresh_tokens_success() {
    // Refresh response without a `refresh_token` member.
    const REFRESH_BODY: &str = r#"{
                                "access_token": "refreshed_token",
                                "expires_in": 1800,
                                "token_type": "Bearer"
                            }"#;

    smol::block_on(async {
        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();

        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let hits_token_endpoint = req.uri().to_string().contains("/token");
                let (status, body) = if hits_token_endpoint {
                    (200, REFRESH_BODY)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let previous_refresh_token = "old_refresh_token";
        let resource = "https://mcp.example.com";
        let tokens = refresh_tokens(
            &client,
            &token_endpoint,
            previous_refresh_token,
            CIMD_URL,
            resource,
        )
        .await
        .unwrap();

        assert_eq!(tokens.access_token, "refreshed_token");
        assert_eq!(tokens.refresh_token, None);
        assert!(tokens.expires_at.is_some());
    });
}
2463
/// When every request is answered with a 400, `exchange_code` must fail and
/// the error text must contain `400`.
#[test]
fn test_exchange_code_failure() {
    smol::block_on(async {
        let issuer = Url::parse("https://auth.example.com").unwrap();
        let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let metadata = AuthServerMetadata {
            issuer,
            authorization_endpoint,
            token_endpoint,
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        // The fake server rejects unconditionally, regardless of the request.
        let client = make_fake_http_client(|_req| {
            Box::pin(async move {
                let rejection = r#"{"error": "invalid_grant"}"#;
                json_response(400, rejection)
            })
        });

        let outcome = exchange_code(
            &client,
            &metadata,
            "bad_code",
            "client",
            "http://127.0.0.1:1/callback",
            "verifier",
            "https://mcp.example.com",
        )
        .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("400"));
    });
}
2496
2497    // -- DCR integration tests -----------------------------------------------
2498
/// A 201 registration response should be surfaced by `perform_dcr` as a
/// registration holding the returned client id and client secret.
#[test]
fn test_perform_dcr() {
    // Registration response returned for every request.
    const REGISTRATION_BODY: &str = r#"{
                            "client_id": "dynamic-client-001",
                            "client_secret": "dynamic-secret-001"
                        }"#;

    smol::block_on(async {
        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(201, REGISTRATION_BODY) })
        });

        let redirect_uri = "http://127.0.0.1:9999/callback";
        let registration = perform_dcr(&client, &endpoint, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, "dynamic-client-001");
        let secret = registration.client_secret.as_deref();
        assert_eq!(secret, Some("dynamic-secret-001"));
    });
}
2526
/// A 403 from the registration endpoint must make `perform_dcr` fail with an
/// error whose text contains `403`.
#[test]
fn test_perform_dcr_failure() {
    smol::block_on(async {
        let endpoint = Url::parse("https://auth.example.com/register").unwrap();

        // The fake server refuses every registration attempt.
        let client = make_fake_http_client(|_req| {
            Box::pin(async move {
                let rejection = r#"{"error": "registration_not_allowed"}"#;
                json_response(403, rejection)
            })
        });

        let outcome = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("403"));
    });
}
2543
2544    // -- OAuthCallback parse tests -------------------------------------------
2545
/// A query holding `code` followed by `state` parses into both values.
#[test]
fn test_oauth_callback_parse_query() {
    let query = "code=test_auth_code&state=test_state";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2552
/// Parameter order must not matter: `state` before `code` parses the same.
#[test]
fn test_oauth_callback_parse_query_reversed_order() {
    let query = "state=test_state&code=test_auth_code";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2559
/// An unrelated extra parameter in the query does not prevent `code` and
/// `state` from being extracted.
#[test]
fn test_oauth_callback_parse_query_with_extra_params() {
    let query = "code=test_auth_code&state=test_state&extra=ignored";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2568
/// A query without `code` is rejected, and the error text mentions "code".
#[test]
fn test_oauth_callback_parse_query_missing_code() {
    let outcome = OAuthCallback::parse_query("state=test_state");
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("code"));
}
2575
/// A query without `state` is rejected, and the error text mentions "state".
#[test]
fn test_oauth_callback_parse_query_missing_state() {
    let outcome = OAuthCallback::parse_query("code=test_auth_code");
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("state"));
}
2582
/// A present-but-empty `code` value is treated as a parse failure.
#[test]
fn test_oauth_callback_parse_query_empty_code() {
    let query = "code=&state=test_state";
    let outcome = OAuthCallback::parse_query(query);
    assert!(outcome.is_err());
}
2588
/// A present-but-empty `state` value is treated as a parse failure.
#[test]
fn test_oauth_callback_parse_query_empty_state() {
    let query = "code=test_auth_code&state=";
    let outcome = OAuthCallback::parse_query(query);
    assert!(outcome.is_err());
}
2594
/// Percent-encoded values are decoded: `%20` becomes a space and `%3D`
/// becomes `=`.
#[test]
fn test_oauth_callback_parse_query_url_encoded_values() {
    let query = "code=abc%20def&state=test%3Dstate";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "abc def");
    assert_eq!(parsed.state, "test=state");
}
2601
/// An error callback is reported as a failure whose message carries both the
/// error code and the percent-decoded description.
#[test]
fn test_oauth_callback_parse_query_error_response() {
    let query = "error=access_denied&error_description=User%20denied%20access&state=abc";
    let outcome = OAuthCallback::parse_query(query);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    // Checked in the same order as before: error code first, then description.
    for fragment in ["access_denied", "User denied access"] {
        assert!(message.contains(fragment), "unexpected error: {}", message);
    }
}
2620
/// An error callback lacking `error_description` still fails, with a message
/// containing the error code and the text "no description".
#[test]
fn test_oauth_callback_parse_query_error_without_description() {
    let query = "error=server_error&state=abc";
    let outcome = OAuthCallback::parse_query(query);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    // Checked in the same order as before: error code first, then placeholder.
    for fragment in ["server_error", "no description"] {
        assert!(message.contains(fragment), "unexpected error: {}", message);
    }
}
2637
2638    // -- McpOAuthTokenProvider tests -----------------------------------------
2639
2640    fn make_test_session(
2641        access_token: &str,
2642        refresh_token: Option<&str>,
2643        expires_at: Option<SystemTime>,
2644    ) -> OAuthSession {
2645        OAuthSession {
2646            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2647            resource: Url::parse("https://mcp.example.com").unwrap(),
2648            client_registration: OAuthClientRegistration {
2649                client_id: "test-client".into(),
2650                client_secret: None,
2651            },
2652            tokens: OAuthTokens {
2653                access_token: access_token.into(),
2654                refresh_token: refresh_token.map(String::from),
2655                expires_at,
2656            },
2657        }
2658    }
2659
/// With an expiry 60 seconds in the past, `access_token()` yields `None`.
#[test]
fn test_mcp_oauth_provider_returns_none_when_token_expired() {
    let one_minute = Duration::from_secs(60);
    let already_expired = SystemTime::now() - one_minute;

    // The HTTP client must never be invoked by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("stale-token", Some("rt"), Some(already_expired)),
        http_client,
        None,
    );

    assert_eq!(provider.access_token(), None);
}
2672
/// With an expiry one hour in the future, `access_token()` returns the
/// stored token.
#[test]
fn test_mcp_oauth_provider_returns_token_when_not_expired() {
    let one_hour = Duration::from_secs(3600);
    let still_valid_until = SystemTime::now() + one_hour;

    // The HTTP client must never be invoked by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("valid-token", Some("rt"), Some(still_valid_until)),
        http_client,
        None,
    );

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("valid-token"));
}
2685
/// A session with no recorded expiry still hands out its access token.
#[test]
fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
    // The HTTP client must never be invoked by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("no-expiry-token", Some("rt"), None),
        http_client,
        None,
    );

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("no-expiry-token"));
}
2697
/// Without a refresh token in the session, `try_refresh` resolves to
/// `Ok(false)`; the fake client panics if any HTTP request is made.
#[test]
fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
    smol::block_on(async {
        let http_client = make_fake_http_client(|_| {
            Box::pin(async { unreachable!("no HTTP call expected") })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("token", None, None),
            http_client,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);
    });
}
2714
/// A successful refresh replaces the access token visible through the
/// provider and sends the updated session (new access and refresh tokens)
/// over the notification channel.
#[test]
fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
    // Refresh response that rotates both tokens.
    const REFRESH_BODY: &str = r#"{
                            "access_token": "new-access",
                            "refresh_token": "new-refresh",
                            "expires_in": 1800
                        }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http_client =
            make_fake_http_client(|_req| Box::pin(async { json_response(200, REFRESH_BODY) }));

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh-token"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);
        assert_eq!(provider.access_token().as_deref(), Some("new-access"));

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        let published_tokens = &published.tokens;
        assert_eq!(published_tokens.access_token, "new-access");
        assert_eq!(
            published_tokens.refresh_token.as_deref(),
            Some("new-refresh")
        );
    });
}
2748
/// When the refresh response carries no `refresh_token`, the session sent
/// over the channel keeps the original refresh token alongside the new
/// access token.
#[test]
fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
    // Refresh response that omits `refresh_token`.
    const REFRESH_BODY: &str = r#"{
                            "access_token": "new-access",
                            "expires_in": 900
                        }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http_client =
            make_fake_http_client(|_req| Box::pin(async { json_response(200, REFRESH_BODY) }));

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("original-refresh"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        let published_tokens = &published.tokens;
        assert_eq!(published_tokens.access_token, "new-access");
        assert_eq!(
            published_tokens.refresh_token.as_deref(),
            Some("original-refresh"),
        );
    });
}
2780
/// A 401 from the token endpoint makes `try_refresh` resolve to `Ok(false)`
/// and leaves the previously stored access token untouched.
#[test]
fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
    smol::block_on(async {
        // Every refresh attempt is rejected by the fake server.
        let http_client = make_fake_http_client(|_req| {
            Box::pin(async {
                let rejection = r#"{"error": "invalid_grant"}"#;
                json_response(401, rejection)
            })
        });

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh"), None),
            http_client,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);

        // The failed refresh must not have displaced the existing token.
        let token = provider.access_token();
        assert_eq!(token.as_deref(), Some("old-access"));
    });
}
2798}