oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// When the authorization server advertises Client ID Metadata Document
/// support, this URL is used directly as the OAuth `client_id`
/// (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
/// OAuth 2.0 Protected Resource Metadata (RFC 9728) for an MCP server.
///
/// Fetched from the `resource_metadata` URL advertised in the server's
/// `WWW-Authenticate` header, or from a well-known URL built by
/// [`protected_resource_metadata_urls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource identifier. When the fetched document omits
    /// `resource`, [`fetch_protected_resource_metadata`] substitutes the MCP
    /// server URL.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource.
    /// [`fetch_protected_resource_metadata`] rejects documents where this is
    /// empty, and [`discover`] uses only the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource advertises; [`select_scopes`] falls back to these
    /// when the `WWW-Authenticate` challenge carries no usable `scope`.
    pub scopes_supported: Option<Vec<String>>,
}
 139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata), or from an OpenID
/// Connect Discovery document at one of the URLs built by
/// [`auth_server_metadata_urls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier. [`fetch_auth_server_metadata`] requires it to equal
    /// the issuer that was looked up, and uses that issuer when the document
    /// omits the field.
    pub issuer: Url,
    /// Base URL of the authorization request built by
    /// [`build_authorization_url`].
    pub authorization_endpoint: Url,
    /// Token endpoint; persisted in [`OAuthSession`] so later refreshes can
    /// reach it without re-running discovery.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, if the server offers one.
    pub registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods the server supports. [`discover`] refuses to proceed unless
    /// this is present and contains `"S256"`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// `client_id`; treated as `false` when the document omits the field.
    pub client_id_metadata_document_supported: bool,
}
 152
 153/// The result of client registration — either CIMD or DCR.
 154#[derive(Clone, Serialize, Deserialize)]
 155pub struct OAuthClientRegistration {
 156    pub client_id: String,
 157    /// Only present for DCR-minted registrations.
 158    pub client_secret: Option<String>,
 159}
 160
 161impl std::fmt::Debug for OAuthClientRegistration {
 162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 163        f.debug_struct("OAuthClientRegistration")
 164            .field("client_id", &self.client_id)
 165            .field(
 166                "client_secret",
 167                &self.client_secret.as_ref().map(|_| "[redacted]"),
 168            )
 169            .finish()
 170    }
 171}
 172
 173/// Access and refresh tokens obtained from the token endpoint.
 174#[derive(Clone, Serialize, Deserialize)]
 175pub struct OAuthTokens {
 176    pub access_token: String,
 177    pub refresh_token: Option<String>,
 178    pub expires_at: Option<SystemTime>,
 179}
 180
 181impl std::fmt::Debug for OAuthTokens {
 182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 183        f.debug_struct("OAuthTokens")
 184            .field("access_token", &"[redacted]")
 185            .field(
 186                "refresh_token",
 187                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 188            )
 189            .field("expires_at", &self.expires_at)
 190            .finish()
 191    }
 192}
 193
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
///
/// Produced by [`discover`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched from the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in
    /// `resource_metadata`; guaranteed by [`discover`] to support S256 PKCE.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes to request, chosen by [`select_scopes`] (may be empty).
    pub scopes: Vec<String>,
}
 202
 203/// The persisted OAuth session for a context server.
 204///
 205/// Stored in the keychain so startup can restore a refresh-capable provider
 206/// without another browser flow. Deliberately excludes the full discovery
 207/// metadata to keep the serialized size well within keychain item limits.
 208#[derive(Clone, Serialize, Deserialize)]
 209pub struct OAuthSession {
 210    pub token_endpoint: Url,
 211    pub resource: Url,
 212    pub client_registration: OAuthClientRegistration,
 213    pub tokens: OAuthTokens,
 214}
 215
 216impl std::fmt::Debug for OAuthSession {
 217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 218        f.debug_struct("OAuthSession")
 219            .field("token_endpoint", &self.token_endpoint)
 220            .field("resource", &self.resource)
 221            .field("client_registration", &self.client_registration)
 222            .field("tokens", &self.tokens)
 223            .finish()
 224    }
 225}
 226
 227/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 228#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 229pub enum BearerError {
 230    /// The request is missing a required parameter, includes an unsupported
 231    /// parameter or parameter value, or is otherwise malformed.
 232    InvalidRequest,
 233    /// The access token provided is expired, revoked, malformed, or invalid.
 234    InvalidToken,
 235    /// The request requires higher privileges than provided by the access token.
 236    InsufficientScope,
 237    /// An unrecognized error code (extension or future spec addition).
 238    Other,
 239}
 240
 241impl BearerError {
 242    fn parse(value: &str) -> Self {
 243        match value {
 244            "invalid_request" => BearerError::InvalidRequest,
 245            "invalid_token" => BearerError::InvalidToken,
 246            "insufficient_scope" => BearerError::InsufficientScope,
 247            _ => BearerError::Other,
 248        }
 249    }
 250}
 251
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
///
/// Produced by [`parse_www_authenticate`]; every field is `None` when the
/// corresponding parameter is absent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// The `resource_metadata` parameter, parsed as a URL.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter, split on whitespace.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, copied verbatim.
    pub error_description: Option<String>,
}
 265
 266/// Parse a `WWW-Authenticate` header value.
 267///
 268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 270/// - `resource_metadata` — URL of the Protected Resource Metadata document
 271/// - `scope` — space-separated list of required scopes
 272/// - `error` — error code (e.g. "insufficient_scope")
 273/// - `error_description` — human-readable error description
 274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 275    let header = header.trim();
 276
 277    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 278        header[6..].trim()
 279    } else {
 280        bail!("WWW-Authenticate header does not use Bearer scheme");
 281    };
 282
 283    if params_str.is_empty() {
 284        return Ok(WwwAuthenticate {
 285            resource_metadata: None,
 286            scope: None,
 287            error: None,
 288            error_description: None,
 289        });
 290    }
 291
 292    let params = parse_auth_params(params_str);
 293
 294    let resource_metadata = params
 295        .get("resource_metadata")
 296        .map(|v| Url::parse(v))
 297        .transpose()
 298        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 299
 300    let scope = params
 301        .get("scope")
 302        .map(|v| v.split_whitespace().map(String::from).collect());
 303
 304    let error = params.get("error").map(|v| BearerError::parse(v));
 305    let error_description = params.get("error_description").cloned();
 306
 307    Ok(WwwAuthenticate {
 308        resource_metadata,
 309        scope,
 310        error,
 311        error_description,
 312    })
 313}
 314
/// Parse comma-separated `key="value"` or `key=token` parameters from an
/// auth-param list (RFC 7235 Section 2.1).
///
/// This is a lenient parser, not a strict grammar check:
/// - Keys are trimmed and lowercased.
/// - Quoted values treat a backslash as escaping the next character; an
///   unterminated quoted value consumes the rest of the input.
/// - Unquoted values run until the next comma or whitespace.
/// - Parsing stops once no `=` remains in the input.
/// - When a key repeats, the last occurrence wins; empty keys are dropped.
fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
    let mut params = collections::HashMap::default();
    let mut remaining = input.trim();

    while !remaining.is_empty() {
        // Skip leading whitespace and commas.
        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
        if remaining.is_empty() {
            break;
        }

        // Find the key (everything before '=').
        // NOTE(review): this searches all remaining input, so a bare token with
        // no `=` (e.g. `foo, bar=baz`) is folded into the next key as "foo, bar".
        let eq_pos = match remaining.find('=') {
            Some(pos) => pos,
            // No further `key=value` pairs; trailing text is ignored.
            None => break,
        };

        // Auth-param names are case-insensitive, so normalize them.
        let key = remaining[..eq_pos].trim().to_lowercase();
        remaining = &remaining[eq_pos + 1..];
        remaining = remaining.trim_start();

        // Parse the value: either quoted or unquoted (token).
        let value;
        if remaining.starts_with('"') {
            // Quoted string: find the closing quote, handling escaped chars.
            remaining = &remaining[1..]; // skip opening quote
            let mut val = String::new();
            let mut chars = remaining.char_indices();
            loop {
                match chars.next() {
                    Some((_, '\\')) => {
                        // Escaped character — take the next char literally.
                        // A backslash at the very end of input is dropped.
                        if let Some((_, c)) = chars.next() {
                            val.push(c);
                        }
                    }
                    Some((i, '"')) => {
                        // `i` is the byte offset of the closing quote within
                        // `remaining`; resume parsing just past it.
                        remaining = &remaining[i + 1..];
                        break;
                    }
                    Some((_, c)) => val.push(c),
                    None => {
                        // Unterminated quoted string: keep what was collected
                        // and treat the input as fully consumed.
                        remaining = "";
                        break;
                    }
                }
            }
            value = val;
        } else {
            // Unquoted token: read until comma or whitespace.
            let end = remaining
                .find(|c: char| c == ',' || c.is_whitespace())
                .unwrap_or(remaining.len());
            value = remaining[..end].to_string();
            remaining = &remaining[end..];
        }

        // Later duplicates overwrite earlier ones; empty keys are ignored.
        if !key.is_empty() {
            params.insert(key, value);
        }
    }

    params
}
 381
 382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 383/// server URL, per RFC 9728 Section 3.
 384///
 385/// Returns URIs in priority order:
 386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 389    let mut urls = Vec::new();
 390    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 391
 392    let path = server_url.path().trim_start_matches('/');
 393    if !path.is_empty() {
 394        if let Ok(url) = Url::parse(&format!(
 395            "{}/.well-known/oauth-protected-resource/{}",
 396            base, path
 397        )) {
 398            urls.push(url);
 399        }
 400    }
 401
 402    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 403        urls.push(url);
 404    }
 405
 406    urls
 407}
 408
 409/// Construct the well-known Authorization Server Metadata URIs for a given
 410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 411///
 412/// Returns URIs in priority order, which differs depending on whether the
 413/// issuer URL has a path component.
 414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 415    let mut urls = Vec::new();
 416    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 417    let path = issuer.path().trim_matches('/');
 418
 419    if !path.is_empty() {
 420        // Issuer with path: try path-inserted variants first.
 421        if let Ok(url) = Url::parse(&format!(
 422            "{}/.well-known/oauth-authorization-server/{}",
 423            base, path
 424        )) {
 425            urls.push(url);
 426        }
 427        if let Ok(url) = Url::parse(&format!(
 428            "{}/.well-known/openid-configuration/{}",
 429            base, path
 430        )) {
 431            urls.push(url);
 432        }
 433        if let Ok(url) = Url::parse(&format!(
 434            "{}/{}/.well-known/openid-configuration",
 435            base, path
 436        )) {
 437            urls.push(url);
 438        }
 439    } else {
 440        // No path: standard well-known locations.
 441        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 442            urls.push(url);
 443        }
 444        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 445            urls.push(url);
 446        }
 447    }
 448
 449    urls
 450}
 451
 452// -- Canonical server URI (RFC 8707) -----------------------------------------
 453
 454/// Derive the canonical resource URI for an MCP server URL, suitable for the
 455/// `resource` parameter in authorization and token requests per RFC 8707.
 456///
 457/// Lowercases the scheme and host, preserves the path (without trailing slash),
 458/// strips fragments and query strings.
 459pub fn canonical_server_uri(server_url: &Url) -> String {
 460    let mut uri = format!(
 461        "{}://{}",
 462        server_url.scheme().to_ascii_lowercase(),
 463        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 464    );
 465    if let Some(port) = server_url.port() {
 466        uri.push_str(&format!(":{}", port));
 467    }
 468    let path = server_url.path();
 469    if path != "/" {
 470        uri.push_str(path.trim_end_matches('/'));
 471    }
 472    uri
 473}
 474
 475// -- Scope selection ---------------------------------------------------------
 476
 477/// Select scopes following the MCP spec's Scope Selection Strategy:
 478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 480/// 3. Return empty if neither is available.
 481pub fn select_scopes(
 482    www_authenticate: &WwwAuthenticate,
 483    resource_metadata: &ProtectedResourceMetadata,
 484) -> Vec<String> {
 485    if let Some(ref scopes) = www_authenticate.scope {
 486        if !scopes.is_empty() {
 487            return scopes.clone();
 488        }
 489    }
 490    resource_metadata
 491        .scopes_supported
 492        .clone()
 493        .unwrap_or_default()
 494}
 495
 496// -- Client registration strategy --------------------------------------------
 497
/// The registration approach to use, determined from auth server metadata
/// by [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly
    /// (this is always [`CIMD_URL`]).
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
 508
 509/// Determine how to register with the authorization server, following the
 510/// spec's recommended priority: CIMD first, DCR fallback.
 511pub fn determine_registration_strategy(
 512    auth_server_metadata: &AuthServerMetadata,
 513) -> ClientRegistrationStrategy {
 514    if auth_server_metadata.client_id_metadata_document_supported {
 515        ClientRegistrationStrategy::Cimd {
 516            client_id: CIMD_URL.to_string(),
 517        }
 518    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 519        ClientRegistrationStrategy::Dcr {
 520            registration_endpoint: endpoint.clone(),
 521        }
 522    } else {
 523        ClientRegistrationStrategy::Unavailable
 524    }
 525}
 526
 527// -- PKCE (RFC 7636) ---------------------------------------------------------
 528
 529/// A PKCE code verifier and its S256 challenge.
 530#[derive(Clone)]
 531pub struct PkceChallenge {
 532    pub verifier: String,
 533    pub challenge: String,
 534}
 535
 536impl std::fmt::Debug for PkceChallenge {
 537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 538        f.debug_struct("PkceChallenge")
 539            .field("verifier", &"[redacted]")
 540            .field("challenge", &self.challenge)
 541            .finish()
 542    }
 543}
 544
 545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 546///
 547/// The verifier is 43 base64url characters derived from 32 random bytes.
 548/// The challenge is `BASE64URL(SHA256(verifier))`.
 549pub fn generate_pkce_challenge() -> PkceChallenge {
 550    let mut random_bytes = [0u8; 32];
 551    rand::rng().fill(&mut random_bytes);
 552    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 553    let verifier = engine.encode(&random_bytes);
 554
 555    let digest = Sha256::digest(verifier.as_bytes());
 556    let challenge = engine.encode(digest);
 557
 558    PkceChallenge {
 559        verifier,
 560        challenge,
 561    }
 562}
 563
 564// -- Authorization URL construction ------------------------------------------
 565
 566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 567pub fn build_authorization_url(
 568    auth_server_metadata: &AuthServerMetadata,
 569    client_id: &str,
 570    redirect_uri: &str,
 571    scopes: &[String],
 572    resource: &str,
 573    pkce: &PkceChallenge,
 574    state: &str,
 575) -> Url {
 576    let mut url = auth_server_metadata.authorization_endpoint.clone();
 577    {
 578        let mut query = url.query_pairs_mut();
 579        query.append_pair("response_type", "code");
 580        query.append_pair("client_id", client_id);
 581        query.append_pair("redirect_uri", redirect_uri);
 582        if !scopes.is_empty() {
 583            query.append_pair("scope", &scopes.join(" "));
 584        }
 585        query.append_pair("resource", resource);
 586        query.append_pair("code_challenge", &pkce.challenge);
 587        query.append_pair("code_challenge_method", "S256");
 588        query.append_pair("state", state);
 589    }
 590    url
 591}
 592
 593// -- Token endpoint request bodies -------------------------------------------
 594
 595/// The JSON body returned by the token endpoint on success.
 596#[derive(Deserialize)]
 597pub struct TokenResponse {
 598    pub access_token: String,
 599    #[serde(default)]
 600    pub refresh_token: Option<String>,
 601    #[serde(default)]
 602    pub expires_in: Option<u64>,
 603    #[serde(default)]
 604    pub token_type: Option<String>,
 605}
 606
 607impl std::fmt::Debug for TokenResponse {
 608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 609        f.debug_struct("TokenResponse")
 610            .field("access_token", &"[redacted]")
 611            .field(
 612                "refresh_token",
 613                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 614            )
 615            .field("expires_in", &self.expires_in)
 616            .field("token_type", &self.token_type)
 617            .finish()
 618    }
 619}
 620
 621impl TokenResponse {
 622    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 623    pub fn into_tokens(self) -> OAuthTokens {
 624        let expires_at = self
 625            .expires_in
 626            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 627        OAuthTokens {
 628            access_token: self.access_token,
 629            refresh_token: self.refresh_token,
 630            expires_at,
 631        }
 632    }
 633}
 634
 635/// An OAuth token error response (RFC 6749 Section 5.2).
 636#[derive(Debug, Deserialize, PartialEq)]
 637pub struct OAuthTokenError {
 638    pub error: String,
 639    #[serde(default)]
 640    pub error_description: Option<String>,
 641}
 642
 643impl std::fmt::Display for OAuthTokenError {
 644    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 645        write!(f, "OAuth token error: {}", self.error)?;
 646        if let Some(description) = &self.error_description {
 647            write!(f, " ({description})")?;
 648        }
 649        Ok(())
 650    }
 651}
 652
 653impl std::error::Error for OAuthTokenError {}
 654
 655/// Build the form-encoded body for an authorization code token exchange.
 656pub fn token_exchange_params(
 657    code: &str,
 658    client_id: &str,
 659    redirect_uri: &str,
 660    code_verifier: &str,
 661    resource: &str,
 662    client_secret: Option<&str>,
 663) -> Vec<(&'static str, String)> {
 664    let mut params = vec![
 665        ("grant_type", "authorization_code".to_string()),
 666        ("code", code.to_string()),
 667        ("redirect_uri", redirect_uri.to_string()),
 668        ("client_id", client_id.to_string()),
 669        ("code_verifier", code_verifier.to_string()),
 670        ("resource", resource.to_string()),
 671    ];
 672    if let Some(secret) = client_secret {
 673        params.push(("client_secret", secret.to_string()));
 674    }
 675    params
 676}
 677
 678/// Build the form-encoded body for a token refresh request.
 679pub fn token_refresh_params(
 680    refresh_token: &str,
 681    client_id: &str,
 682    resource: &str,
 683    client_secret: Option<&str>,
 684) -> Vec<(&'static str, String)> {
 685    let mut params = vec![
 686        ("grant_type", "refresh_token".to_string()),
 687        ("refresh_token", refresh_token.to_string()),
 688        ("client_id", client_id.to_string()),
 689        ("resource", resource.to_string()),
 690    ];
 691    if let Some(secret) = client_secret {
 692        params.push(("client_secret", secret.to_string()));
 693    }
 694    params
 695}
 696
 697// -- DCR request body (RFC 7591) ---------------------------------------------
 698
 699/// Build the JSON body for a Dynamic Client Registration request.
 700///
 701/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 702/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 703/// redirect URI matching even for loopback addresses, so we register the
 704/// exact URI we intend to use.
 705pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 706    serde_json::json!({
 707        "client_name": "Zed",
 708        "redirect_uris": [redirect_uri],
 709        "grant_types": ["authorization_code"],
 710        "response_types": ["code"],
 711        "token_endpoint_auth_method": "none"
 712    })
 713}
 714
 715// -- Discovery (async, hits real endpoints) ----------------------------------
 716
 717/// Fetch Protected Resource Metadata from the MCP server.
 718///
 719/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 720/// then falls back to well-known URIs constructed from `server_url`.
 721pub async fn fetch_protected_resource_metadata(
 722    http_client: &Arc<dyn HttpClient>,
 723    server_url: &Url,
 724    www_authenticate: &WwwAuthenticate,
 725) -> Result<ProtectedResourceMetadata> {
 726    let candidate_urls = match &www_authenticate.resource_metadata {
 727        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 728        Some(url) => {
 729            log::warn!(
 730                "Ignoring cross-origin resource_metadata URL {} \
 731                 (server origin: {})",
 732                url,
 733                server_url.origin().unicode_serialization()
 734            );
 735            protected_resource_metadata_urls(server_url)
 736        }
 737        None => protected_resource_metadata_urls(server_url),
 738    };
 739
 740    for url in &candidate_urls {
 741        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 742            Ok(response) => {
 743                if response.authorization_servers.is_empty() {
 744                    bail!(
 745                        "Protected Resource Metadata at {} has no authorization_servers",
 746                        url
 747                    );
 748                }
 749                return Ok(ProtectedResourceMetadata {
 750                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 751                    authorization_servers: response.authorization_servers,
 752                    scopes_supported: response.scopes_supported,
 753                });
 754            }
 755            Err(err) => {
 756                log::debug!(
 757                    "Failed to fetch Protected Resource Metadata from {}: {}",
 758                    url,
 759                    err
 760                );
 761            }
 762        }
 763    }
 764
 765    bail!(
 766        "Could not fetch Protected Resource Metadata for {}",
 767        server_url
 768    )
 769}
 770
 771/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 772/// endpoints in the priority order specified by the MCP spec.
 773pub async fn fetch_auth_server_metadata(
 774    http_client: &Arc<dyn HttpClient>,
 775    issuer: &Url,
 776) -> Result<AuthServerMetadata> {
 777    let candidate_urls = auth_server_metadata_urls(issuer);
 778
 779    for url in &candidate_urls {
 780        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 781            Ok(response) => {
 782                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 783
 784                if reported_issuer != *issuer {
 785                    bail!(
 786                        "Auth server metadata issuer mismatch: expected {}, got {}",
 787                        issuer,
 788                        reported_issuer
 789                    );
 790                }
 791
 792                return Ok(AuthServerMetadata {
 793                    issuer: reported_issuer,
 794                    authorization_endpoint: response
 795                        .authorization_endpoint
 796                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 797                    token_endpoint: response
 798                        .token_endpoint
 799                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 800                    registration_endpoint: response.registration_endpoint,
 801                    scopes_supported: response.scopes_supported,
 802                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 803                    client_id_metadata_document_supported: response
 804                        .client_id_metadata_document_supported
 805                        .unwrap_or(false),
 806                });
 807            }
 808            Err(err) => {
 809                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 810            }
 811        }
 812    }
 813
 814    bail!(
 815        "Could not fetch Authorization Server Metadata for {}",
 816        issuer
 817    )
 818}
 819
 820/// Run the full discovery flow: fetch resource metadata, then auth server
 821/// metadata, then select scopes. Client registration is resolved separately,
 822/// once the real redirect URI is known.
 823pub async fn discover(
 824    http_client: &Arc<dyn HttpClient>,
 825    server_url: &Url,
 826    www_authenticate: &WwwAuthenticate,
 827) -> Result<OAuthDiscovery> {
 828    let resource_metadata =
 829        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 830
 831    let auth_server_url = resource_metadata
 832        .authorization_servers
 833        .first()
 834        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 835
 836    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 837
 838    // Verify PKCE S256 support (spec requirement).
 839    match &auth_server_metadata.code_challenge_methods_supported {
 840        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 841        Some(_) => bail!("authorization server does not support S256 PKCE"),
 842        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 843    }
 844
 845    let scopes = select_scopes(www_authenticate, &resource_metadata);
 846
 847    Ok(OAuthDiscovery {
 848        resource_metadata,
 849        auth_server_metadata,
 850        scopes,
 851    })
 852}
 853
 854/// Resolve the OAuth client registration for an authorization flow.
 855///
 856/// CIMD uses the static client metadata document directly. For DCR, a fresh
 857/// registration is performed each time because the loopback redirect URI
 858/// includes an ephemeral port that changes every flow.
 859pub async fn resolve_client_registration(
 860    http_client: &Arc<dyn HttpClient>,
 861    discovery: &OAuthDiscovery,
 862    redirect_uri: &str,
 863) -> Result<OAuthClientRegistration> {
 864    match determine_registration_strategy(&discovery.auth_server_metadata) {
 865        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 866            client_id,
 867            client_secret: None,
 868        }),
 869        ClientRegistrationStrategy::Dcr {
 870            registration_endpoint,
 871        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 872        ClientRegistrationStrategy::Unavailable => {
 873            bail!("authorization server supports neither CIMD nor DCR")
 874        }
 875    }
 876}
 877
 878// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 879
 880/// Perform Dynamic Client Registration with the authorization server.
 881pub async fn perform_dcr(
 882    http_client: &Arc<dyn HttpClient>,
 883    registration_endpoint: &Url,
 884    redirect_uri: &str,
 885) -> Result<OAuthClientRegistration> {
 886    validate_oauth_url(registration_endpoint)?;
 887
 888    let body = dcr_registration_body(redirect_uri);
 889    let body_bytes = serde_json::to_vec(&body)?;
 890
 891    let request = Request::builder()
 892        .method(http_client::http::Method::POST)
 893        .uri(registration_endpoint.as_str())
 894        .header("Content-Type", "application/json")
 895        .header("Accept", "application/json")
 896        .body(AsyncBody::from(body_bytes))?;
 897
 898    let mut response = http_client.send(request).await?;
 899
 900    if !response.status().is_success() {
 901        let mut error_body = String::new();
 902        response.body_mut().read_to_string(&mut error_body).await?;
 903        bail!(
 904            "DCR failed with status {}: {}",
 905            response.status(),
 906            error_body
 907        );
 908    }
 909
 910    let mut response_body = String::new();
 911    response
 912        .body_mut()
 913        .read_to_string(&mut response_body)
 914        .await?;
 915
 916    let dcr_response: DcrResponse =
 917        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 918
 919    Ok(OAuthClientRegistration {
 920        client_id: dcr_response.client_id,
 921        client_secret: dcr_response.client_secret,
 922    })
 923}
 924
 925// -- Token exchange and refresh (async) --------------------------------------
 926
 927/// Exchange an authorization code for tokens at the token endpoint.
 928pub async fn exchange_code(
 929    http_client: &Arc<dyn HttpClient>,
 930    auth_server_metadata: &AuthServerMetadata,
 931    code: &str,
 932    client_id: &str,
 933    redirect_uri: &str,
 934    code_verifier: &str,
 935    resource: &str,
 936    client_secret: Option<&str>,
 937) -> Result<OAuthTokens> {
 938    let params = token_exchange_params(
 939        code,
 940        client_id,
 941        redirect_uri,
 942        code_verifier,
 943        resource,
 944        client_secret,
 945    );
 946    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 947}
 948
 949/// Refresh tokens using a refresh token.
 950pub async fn refresh_tokens(
 951    http_client: &Arc<dyn HttpClient>,
 952    token_endpoint: &Url,
 953    refresh_token: &str,
 954    client_id: &str,
 955    resource: &str,
 956    client_secret: Option<&str>,
 957) -> Result<OAuthTokens> {
 958    let params = token_refresh_params(refresh_token, client_id, resource, client_secret);
 959    post_token_request(http_client, token_endpoint, &params).await
 960}
 961
 962/// POST form-encoded parameters to a token endpoint and parse the response.
 963async fn post_token_request(
 964    http_client: &Arc<dyn HttpClient>,
 965    token_endpoint: &Url,
 966    params: &[(&str, String)],
 967) -> Result<OAuthTokens> {
 968    validate_oauth_url(token_endpoint)?;
 969
 970    let body = url::form_urlencoded::Serializer::new(String::new())
 971        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 972        .finish();
 973
 974    let request = Request::builder()
 975        .method(http_client::http::Method::POST)
 976        .uri(token_endpoint.as_str())
 977        .header("Content-Type", "application/x-www-form-urlencoded")
 978        .header("Accept", "application/json")
 979        .body(AsyncBody::from(body.into_bytes()))?;
 980
 981    let mut response = http_client.send(request).await?;
 982
 983    if !response.status().is_success() {
 984        let mut error_body = String::new();
 985        response.body_mut().read_to_string(&mut error_body).await?;
 986        let status = response.status();
 987        // Try to parse as an OAuth error response (RFC 6749 Section 5.2).
 988        if let Ok(token_error) = serde_json::from_str::<OAuthTokenError>(&error_body) {
 989            return Err(token_error.into());
 990        }
 991        bail!("token request failed with status {status}: {error_body}");
 992    }
 993
 994    let mut response_body = String::new();
 995    response
 996        .body_mut()
 997        .read_to_string(&mut response_body)
 998        .await?;
 999
1000    let token_response: TokenResponse =
1001        serde_json::from_str(&response_body).context("failed to parse token response")?;
1002
1003    Ok(token_response.into_tokens())
1004}
1005
1006// -- Loopback HTTP callback server -------------------------------------------
1007
/// Authorization response delivered to the loopback callback server.
///
/// `Debug` is implemented by hand so that neither value can leak into logs.
pub struct OAuthCallback {
    /// Authorization code to be exchanged at the token endpoint.
    pub code: String,
    /// `state` value echoed back by the authorization server. Presumably the
    /// caller compares it with the value it sent; confirm at the call site.
    pub state: String,
}

impl std::fmt::Debug for OAuthCallback {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both fields are sensitive, so print a placeholder instead.
        const REDACTED: &str = "[redacted]";
        let mut builder = formatter.debug_struct("OAuthCallback");
        builder.field("code", &REDACTED);
        builder.field("state", &REDACTED);
        builder.finish()
    }
}
1022
1023impl OAuthCallback {
1024    /// Parse the query string from a callback URL like
1025    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1026    pub fn parse_query(query: &str) -> Result<Self> {
1027        let mut code: Option<String> = None;
1028        let mut state: Option<String> = None;
1029        let mut error: Option<String> = None;
1030        let mut error_description: Option<String> = None;
1031
1032        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1033            match key.as_ref() {
1034                "code" => {
1035                    if !value.is_empty() {
1036                        code = Some(value.into_owned());
1037                    }
1038                }
1039                "state" => {
1040                    if !value.is_empty() {
1041                        state = Some(value.into_owned());
1042                    }
1043                }
1044                "error" => {
1045                    if !value.is_empty() {
1046                        error = Some(value.into_owned());
1047                    }
1048                }
1049                "error_description" => {
1050                    if !value.is_empty() {
1051                        error_description = Some(value.into_owned());
1052                    }
1053                }
1054                _ => {}
1055            }
1056        }
1057
1058        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1059        // checking for missing code/state.
1060        if let Some(error_code) = error {
1061            bail!(
1062                "OAuth authorization failed: {} ({})",
1063                error_code,
1064                error_description.as_deref().unwrap_or("no description")
1065            );
1066        }
1067
1068        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1069        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1070
1071        Ok(Self { code, state })
1072    }
1073}
1074
/// Upper bound on how long the loopback server waits for the browser to
/// deliver the OAuth callback. After that, the callback thread stops and the
/// listener (and its port) is released.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1078
1079/// Start a loopback HTTP server to receive the OAuth authorization callback.
1080///
1081/// Binds to an ephemeral loopback port for each flow.
1082///
1083/// Returns `(redirect_uri, callback_future)`. The caller should use the
1084/// redirect URI in the authorization request, open the browser, then await
1085/// the future to receive the callback.
1086///
1087/// The server accepts exactly one request on `/callback`, validates that it
1088/// contains `code` and `state` query parameters, responds with a minimal
1089/// HTML page telling the user they can close the tab, and shuts down.
1090///
1091/// The callback server shuts down when the returned oneshot receiver is dropped
1092/// (e.g. because the authentication task was cancelled), or after a timeout
1093/// ([CALLBACK_TIMEOUT]).
1094pub async fn start_callback_server() -> Result<(
1095    String,
1096    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1097)> {
1098    let server = tiny_http::Server::http("127.0.0.1:0")
1099        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1100    let port = server
1101        .server_addr()
1102        .to_ip()
1103        .context("server not bound to a TCP address")?
1104        .port();
1105
1106    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1107
1108    let (tx, rx) = futures::channel::oneshot::channel();
1109
1110    // `tiny_http` is blocking, so we run it on a background thread.
1111    // The `recv_timeout` loop lets us check for cancellation (the receiver
1112    // being dropped) and enforce an overall timeout.
1113    std::thread::spawn(move || {
1114        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1115
1116        loop {
1117            if tx.is_canceled() {
1118                return;
1119            }
1120            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1121            if remaining.is_zero() {
1122                return;
1123            }
1124
1125            let timeout = remaining.min(Duration::from_millis(500));
1126            let Some(request) = (match server.recv_timeout(timeout) {
1127                Ok(req) => req,
1128                Err(_) => {
1129                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1130                    return;
1131                }
1132            }) else {
1133                // Timeout with no request — loop back and check cancellation.
1134                continue;
1135            };
1136
1137            let result = handle_callback_request(&request);
1138
1139            let (status_code, body) = match &result {
1140                Ok(_) => (
1141                    200,
1142                    "<html><body><h1>Authorization successful</h1>\
1143                     <p>You can close this tab and return to Zed.</p></body></html>",
1144                ),
1145                Err(err) => {
1146                    log::error!("OAuth callback error: {}", err);
1147                    (
1148                        400,
1149                        "<html><body><h1>Authorization failed</h1>\
1150                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
1151                    )
1152                }
1153            };
1154
1155            let response = tiny_http::Response::from_string(body)
1156                .with_status_code(status_code)
1157                .with_header(
1158                    tiny_http::Header::from_str("Content-Type: text/html")
1159                        .expect("failed to construct response header"),
1160                )
1161                .with_header(
1162                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1163                        .expect("failed to construct response header"),
1164                );
1165            request.respond(response).log_err();
1166
1167            let _ = tx.send(result);
1168            return;
1169        }
1170    });
1171
1172    Ok((redirect_uri, rx))
1173}
1174
1175/// Extract the `code` and `state` query parameters from an OAuth callback
1176/// request to `/callback`.
1177fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1178    let url = Url::parse(&format!("http://localhost{}", request.url()))
1179        .context("malformed callback request URL")?;
1180
1181    if url.path() != "/callback" {
1182        bail!("unexpected path in OAuth callback: {}", url.path());
1183    }
1184
1185    let query = url
1186        .query()
1187        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1188    OAuthCallback::parse_query(query)
1189}
1190
1191// -- JSON fetch helper -------------------------------------------------------
1192
1193async fn fetch_json<T: serde::de::DeserializeOwned>(
1194    http_client: &Arc<dyn HttpClient>,
1195    url: &Url,
1196) -> Result<T> {
1197    validate_oauth_url(url)?;
1198
1199    let request = Request::builder()
1200        .method(http_client::http::Method::GET)
1201        .uri(url.as_str())
1202        .header("Accept", "application/json")
1203        .body(AsyncBody::default())?;
1204
1205    let mut response = http_client.send(request).await?;
1206
1207    if !response.status().is_success() {
1208        bail!("HTTP {} fetching {}", response.status(), url);
1209    }
1210
1211    let mut body = String::new();
1212    response.body_mut().read_to_string(&mut body).await?;
1213    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1214}
1215
// -- Serde response types for discovery and DCR ------------------------------
1217
1218#[derive(Debug, Deserialize)]
1219struct ProtectedResourceMetadataResponse {
1220    #[serde(default)]
1221    resource: Option<Url>,
1222    #[serde(default)]
1223    authorization_servers: Vec<Url>,
1224    #[serde(default)]
1225    scopes_supported: Option<Vec<String>>,
1226}
1227
1228#[derive(Debug, Deserialize)]
1229struct AuthServerMetadataResponse {
1230    #[serde(default)]
1231    issuer: Option<Url>,
1232    #[serde(default)]
1233    authorization_endpoint: Option<Url>,
1234    #[serde(default)]
1235    token_endpoint: Option<Url>,
1236    #[serde(default)]
1237    registration_endpoint: Option<Url>,
1238    #[serde(default)]
1239    scopes_supported: Option<Vec<String>>,
1240    #[serde(default)]
1241    code_challenge_methods_supported: Option<Vec<String>>,
1242    #[serde(default)]
1243    client_id_metadata_document_supported: Option<bool>,
1244}
1245
1246#[derive(Debug, Deserialize)]
1247struct DcrResponse {
1248    client_id: String,
1249    #[serde(default)]
1250    client_secret: Option<String>,
1251}
1252
1253/// Provides OAuth tokens to the HTTP transport layer.
1254///
1255/// The transport calls `access_token()` before each request. On a 401 response
1256/// it calls `try_refresh()` and retries once if the refresh succeeds.
1257#[async_trait]
1258pub trait OAuthTokenProvider: Send + Sync {
1259    /// Returns the current access token, if one is available.
1260    fn access_token(&self) -> Option<String>;
1261
1262    /// Attempts to refresh the access token. Returns `true` if a new token was
1263    /// obtained and the request should be retried.
1264    async fn try_refresh(&self) -> Result<bool>;
1265}
1266
1267/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
1268/// an HTTP client for token refresh. The same provider type is used both after
1269/// an interactive authentication flow and when restoring a saved session from
1270/// the keychain on startup.
1271pub struct McpOAuthTokenProvider {
1272    session: SyncMutex<OAuthSession>,
1273    http_client: Arc<dyn HttpClient>,
1274    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1275}
1276
1277impl McpOAuthTokenProvider {
1278    pub fn new(
1279        session: OAuthSession,
1280        http_client: Arc<dyn HttpClient>,
1281        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1282    ) -> Self {
1283        Self {
1284            session: SyncMutex::new(session),
1285            http_client,
1286            token_refresh_tx,
1287        }
1288    }
1289
1290    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1291        tokens.expires_at.is_some_and(|expires_at| {
1292            SystemTime::now()
1293                .checked_add(Duration::from_secs(30))
1294                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1295        })
1296    }
1297}
1298
1299#[async_trait]
1300impl OAuthTokenProvider for McpOAuthTokenProvider {
1301    fn access_token(&self) -> Option<String> {
1302        let session = self.session.lock();
1303        if Self::access_token_is_expired(&session.tokens) {
1304            return None;
1305        }
1306        Some(session.tokens.access_token.clone())
1307    }
1308
1309    async fn try_refresh(&self) -> Result<bool> {
1310        let (refresh_token, token_endpoint, resource, client_id, client_secret) = {
1311            let session = self.session.lock();
1312            match session.tokens.refresh_token.clone() {
1313                Some(refresh_token) => (
1314                    refresh_token,
1315                    session.token_endpoint.clone(),
1316                    session.resource.clone(),
1317                    session.client_registration.client_id.clone(),
1318                    session.client_registration.client_secret.clone(),
1319                ),
1320                None => return Ok(false),
1321            }
1322        };
1323
1324        let resource_str = canonical_server_uri(&resource);
1325
1326        match refresh_tokens(
1327            &self.http_client,
1328            &token_endpoint,
1329            &refresh_token,
1330            &client_id,
1331            &resource_str,
1332            client_secret.as_deref(),
1333        )
1334        .await
1335        {
1336            Ok(mut new_tokens) => {
1337                if new_tokens.refresh_token.is_none() {
1338                    new_tokens.refresh_token = Some(refresh_token);
1339                }
1340
1341                {
1342                    let mut session = self.session.lock();
1343                    session.tokens = new_tokens;
1344
1345                    if let Some(ref tx) = self.token_refresh_tx {
1346                        tx.unbounded_send(session.clone()).ok();
1347                    }
1348                }
1349
1350                Ok(true)
1351            }
1352            Err(err) => {
1353                log::warn!("OAuth token refresh failed: {}", err);
1354                Ok(false)
1355            }
1356        }
1357    }
1358}
1359
1360#[cfg(test)]
1361mod tests {
1362    use super::*;
1363    use http_client::Response;
1364
1365    // -- require_https_or_loopback tests ------------------------------------
1366
1367    #[test]
1368    fn test_require_https_or_loopback_accepts_https() {
1369        let url = Url::parse("https://auth.example.com/token").unwrap();
1370        assert!(require_https_or_loopback(&url).is_ok());
1371    }
1372
1373    #[test]
1374    fn test_require_https_or_loopback_rejects_http_remote() {
1375        let url = Url::parse("http://auth.example.com/token").unwrap();
1376        assert!(require_https_or_loopback(&url).is_err());
1377    }
1378
1379    #[test]
1380    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1381        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1382        assert!(require_https_or_loopback(&url).is_ok());
1383    }
1384
1385    #[test]
1386    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1387        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1388        assert!(require_https_or_loopback(&url).is_ok());
1389    }
1390
1391    #[test]
1392    fn test_require_https_or_loopback_accepts_http_localhost() {
1393        let url = Url::parse("http://localhost:8080/callback").unwrap();
1394        assert!(require_https_or_loopback(&url).is_ok());
1395    }
1396
1397    #[test]
1398    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1399        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1400        assert!(require_https_or_loopback(&url).is_ok());
1401    }
1402
1403    #[test]
1404    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1405        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1406        assert!(require_https_or_loopback(&url).is_err());
1407    }
1408
1409    #[test]
1410    fn test_require_https_or_loopback_rejects_ftp() {
1411        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1412        assert!(require_https_or_loopback(&url).is_err());
1413    }
1414
1415    // -- validate_oauth_url (SSRF) tests ------------------------------------
1416
1417    #[test]
1418    fn test_validate_oauth_url_accepts_https_public() {
1419        let url = Url::parse("https://auth.example.com/token").unwrap();
1420        assert!(validate_oauth_url(&url).is_ok());
1421    }
1422
1423    #[test]
1424    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1425        let url = Url::parse("https://10.0.0.1/token").unwrap();
1426        assert!(validate_oauth_url(&url).is_err());
1427    }
1428
1429    #[test]
1430    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1431        let url = Url::parse("https://172.16.0.1/token").unwrap();
1432        assert!(validate_oauth_url(&url).is_err());
1433    }
1434
1435    #[test]
1436    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1437        let url = Url::parse("https://192.168.1.1/token").unwrap();
1438        assert!(validate_oauth_url(&url).is_err());
1439    }
1440
1441    #[test]
1442    fn test_validate_oauth_url_rejects_link_local() {
1443        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1444        assert!(validate_oauth_url(&url).is_err());
1445    }
1446
1447    #[test]
1448    fn test_validate_oauth_url_rejects_ipv6_ula() {
1449        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1450        assert!(validate_oauth_url(&url).is_err());
1451    }
1452
1453    #[test]
1454    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1455        let url = Url::parse("https://[::]/token").unwrap();
1456        assert!(validate_oauth_url(&url).is_err());
1457    }
1458
1459    #[test]
1460    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1461        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1462        assert!(validate_oauth_url(&url).is_err());
1463    }
1464
1465    #[test]
1466    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1467        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1468        assert!(validate_oauth_url(&url).is_err());
1469    }
1470
1471    #[test]
1472    fn test_validate_oauth_url_allows_http_loopback() {
1473        // Loopback is permitted (it's our callback server).
1474        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1475        assert!(validate_oauth_url(&url).is_ok());
1476    }
1477
1478    #[test]
1479    fn test_validate_oauth_url_allows_https_public_ip() {
1480        let url = Url::parse("https://93.184.216.34/token").unwrap();
1481        assert!(validate_oauth_url(&url).is_ok());
1482    }
1483
1484    // -- parse_www_authenticate tests ----------------------------------------
1485
1486    #[test]
1487    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1488        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1489        let result = parse_www_authenticate(header).unwrap();
1490
1491        assert_eq!(
1492            result.resource_metadata.as_ref().map(|u| u.as_str()),
1493            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1494        );
1495        assert_eq!(
1496            result.scope,
1497            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1498        );
1499        assert_eq!(result.error, None);
1500    }
1501
1502    #[test]
1503    fn test_parse_www_authenticate_resource_metadata_only() {
1504        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1505        let result = parse_www_authenticate(header).unwrap();
1506
1507        assert_eq!(
1508            result.resource_metadata.as_ref().map(|u| u.as_str()),
1509            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1510        );
1511        assert_eq!(result.scope, None);
1512    }
1513
1514    #[test]
1515    fn test_parse_www_authenticate_bare_bearer() {
1516        let result = parse_www_authenticate("Bearer").unwrap();
1517        assert_eq!(result.resource_metadata, None);
1518        assert_eq!(result.scope, None);
1519    }
1520
1521    #[test]
1522    fn test_parse_www_authenticate_with_error() {
1523        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1524        let result = parse_www_authenticate(header).unwrap();
1525
1526        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1527        assert_eq!(
1528            result.error_description.as_deref(),
1529            Some("Additional file write permission required")
1530        );
1531        assert_eq!(
1532            result.scope,
1533            Some(vec!["files:read".to_string(), "files:write".to_string()])
1534        );
1535        assert!(result.resource_metadata.is_some());
1536    }
1537
1538    #[test]
1539    fn test_parse_www_authenticate_invalid_token_error() {
1540        let header =
1541            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1542        let result = parse_www_authenticate(header).unwrap();
1543        assert_eq!(result.error, Some(BearerError::InvalidToken));
1544    }
1545
1546    #[test]
1547    fn test_parse_www_authenticate_invalid_request_error() {
1548        let header = r#"Bearer error="invalid_request""#;
1549        let result = parse_www_authenticate(header).unwrap();
1550        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1551    }
1552
1553    #[test]
1554    fn test_parse_www_authenticate_unknown_error() {
1555        let header = r#"Bearer error="some_future_error""#;
1556        let result = parse_www_authenticate(header).unwrap();
1557        assert_eq!(result.error, Some(BearerError::Other));
1558    }
1559
1560    #[test]
1561    fn test_parse_www_authenticate_rejects_non_bearer() {
1562        let result = parse_www_authenticate("Basic realm=\"example\"");
1563        assert!(result.is_err());
1564    }
1565
1566    #[test]
1567    fn test_parse_www_authenticate_case_insensitive_scheme() {
1568        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1569        let result = parse_www_authenticate(header).unwrap();
1570        assert!(result.resource_metadata.is_some());
1571    }
1572
1573    #[test]
1574    fn test_parse_www_authenticate_multiline_style() {
1575        // Some servers emit the header spread across multiple lines joined by
1576        // whitespace, as shown in the spec examples.
1577        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1578        let result = parse_www_authenticate(header).unwrap();
1579        assert!(result.resource_metadata.is_some());
1580        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1581    }
1582
1583    #[test]
1584    fn test_protected_resource_metadata_urls_with_path() {
1585        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1586        let urls = protected_resource_metadata_urls(&server_url);
1587
1588        assert_eq!(urls.len(), 2);
1589        assert_eq!(
1590            urls[0].as_str(),
1591            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1592        );
1593        assert_eq!(
1594            urls[1].as_str(),
1595            "https://api.example.com/.well-known/oauth-protected-resource"
1596        );
1597    }
1598
1599    #[test]
1600    fn test_protected_resource_metadata_urls_without_path() {
1601        let server_url = Url::parse("https://mcp.example.com").unwrap();
1602        let urls = protected_resource_metadata_urls(&server_url);
1603
1604        assert_eq!(urls.len(), 1);
1605        assert_eq!(
1606            urls[0].as_str(),
1607            "https://mcp.example.com/.well-known/oauth-protected-resource"
1608        );
1609    }
1610
1611    #[test]
1612    fn test_auth_server_metadata_urls_with_path() {
1613        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1614        let urls = auth_server_metadata_urls(&issuer);
1615
1616        assert_eq!(urls.len(), 3);
1617        assert_eq!(
1618            urls[0].as_str(),
1619            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1620        );
1621        assert_eq!(
1622            urls[1].as_str(),
1623            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1624        );
1625        assert_eq!(
1626            urls[2].as_str(),
1627            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1628        );
1629    }
1630
1631    #[test]
1632    fn test_auth_server_metadata_urls_without_path() {
1633        let issuer = Url::parse("https://auth.example.com").unwrap();
1634        let urls = auth_server_metadata_urls(&issuer);
1635
1636        assert_eq!(urls.len(), 2);
1637        assert_eq!(
1638            urls[0].as_str(),
1639            "https://auth.example.com/.well-known/oauth-authorization-server"
1640        );
1641        assert_eq!(
1642            urls[1].as_str(),
1643            "https://auth.example.com/.well-known/openid-configuration"
1644        );
1645    }
1646
1647    // -- Canonical server URI tests ------------------------------------------
1648
1649    #[test]
1650    fn test_canonical_server_uri_simple() {
1651        let url = Url::parse("https://mcp.example.com").unwrap();
1652        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1653    }
1654
1655    #[test]
1656    fn test_canonical_server_uri_with_path() {
1657        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1658        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1659    }
1660
1661    #[test]
1662    fn test_canonical_server_uri_strips_trailing_slash() {
1663        let url = Url::parse("https://mcp.example.com/").unwrap();
1664        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1665    }
1666
1667    #[test]
1668    fn test_canonical_server_uri_preserves_port() {
1669        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1670        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1671    }
1672
1673    #[test]
1674    fn test_canonical_server_uri_lowercases() {
1675        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1676        assert_eq!(
1677            canonical_server_uri(&url),
1678            "https://mcp.example.com/Server/MCP"
1679        );
1680    }
1681
1682    // -- Scope selection tests -----------------------------------------------
1683
1684    #[test]
1685    fn test_select_scopes_prefers_www_authenticate() {
1686        let www_auth = WwwAuthenticate {
1687            resource_metadata: None,
1688            scope: Some(vec!["files:read".into()]),
1689            error: None,
1690            error_description: None,
1691        };
1692        let resource_meta = ProtectedResourceMetadata {
1693            resource: Url::parse("https://example.com").unwrap(),
1694            authorization_servers: vec![],
1695            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1696        };
1697        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1698    }
1699
1700    #[test]
1701    fn test_select_scopes_falls_back_to_resource_metadata() {
1702        let www_auth = WwwAuthenticate {
1703            resource_metadata: None,
1704            scope: None,
1705            error: None,
1706            error_description: None,
1707        };
1708        let resource_meta = ProtectedResourceMetadata {
1709            resource: Url::parse("https://example.com").unwrap(),
1710            authorization_servers: vec![],
1711            scopes_supported: Some(vec!["admin".into()]),
1712        };
1713        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1714    }
1715
1716    #[test]
1717    fn test_select_scopes_empty_when_nothing_available() {
1718        let www_auth = WwwAuthenticate {
1719            resource_metadata: None,
1720            scope: None,
1721            error: None,
1722            error_description: None,
1723        };
1724        let resource_meta = ProtectedResourceMetadata {
1725            resource: Url::parse("https://example.com").unwrap(),
1726            authorization_servers: vec![],
1727            scopes_supported: None,
1728        };
1729        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1730    }
1731
1732    // -- Client registration strategy tests ----------------------------------
1733
1734    #[test]
1735    fn test_registration_strategy_prefers_cimd() {
1736        let metadata = AuthServerMetadata {
1737            issuer: Url::parse("https://auth.example.com").unwrap(),
1738            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1739            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1740            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1741            scopes_supported: None,
1742            code_challenge_methods_supported: Some(vec!["S256".into()]),
1743            client_id_metadata_document_supported: true,
1744        };
1745        assert_eq!(
1746            determine_registration_strategy(&metadata),
1747            ClientRegistrationStrategy::Cimd {
1748                client_id: CIMD_URL.to_string(),
1749            }
1750        );
1751    }
1752
1753    #[test]
1754    fn test_registration_strategy_falls_back_to_dcr() {
1755        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1756        let metadata = AuthServerMetadata {
1757            issuer: Url::parse("https://auth.example.com").unwrap(),
1758            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1759            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1760            registration_endpoint: Some(reg_endpoint.clone()),
1761            scopes_supported: None,
1762            code_challenge_methods_supported: Some(vec!["S256".into()]),
1763            client_id_metadata_document_supported: false,
1764        };
1765        assert_eq!(
1766            determine_registration_strategy(&metadata),
1767            ClientRegistrationStrategy::Dcr {
1768                registration_endpoint: reg_endpoint,
1769            }
1770        );
1771    }
1772
1773    #[test]
1774    fn test_registration_strategy_unavailable() {
1775        let metadata = AuthServerMetadata {
1776            issuer: Url::parse("https://auth.example.com").unwrap(),
1777            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1778            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1779            registration_endpoint: None,
1780            scopes_supported: None,
1781            code_challenge_methods_supported: Some(vec!["S256".into()]),
1782            client_id_metadata_document_supported: false,
1783        };
1784        assert_eq!(
1785            determine_registration_strategy(&metadata),
1786            ClientRegistrationStrategy::Unavailable,
1787        );
1788    }
1789
1790    // -- PKCE tests ----------------------------------------------------------
1791
1792    #[test]
1793    fn test_pkce_challenge_verifier_length() {
1794        let pkce = generate_pkce_challenge();
1795        // 32 random bytes → 43 base64url chars (no padding).
1796        assert_eq!(pkce.verifier.len(), 43);
1797    }
1798
1799    #[test]
1800    fn test_pkce_challenge_is_valid_base64url() {
1801        let pkce = generate_pkce_challenge();
1802        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1803            assert!(
1804                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1805                "invalid base64url character: {}",
1806                c
1807            );
1808        }
1809    }
1810
1811    #[test]
1812    fn test_pkce_challenge_is_s256_of_verifier() {
1813        let pkce = generate_pkce_challenge();
1814        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1815        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1816        let expected_challenge = engine.encode(expected_digest);
1817        assert_eq!(pkce.challenge, expected_challenge);
1818    }
1819
1820    #[test]
1821    fn test_pkce_challenges_are_unique() {
1822        let a = generate_pkce_challenge();
1823        let b = generate_pkce_challenge();
1824        assert_ne!(a.verifier, b.verifier);
1825    }
1826
1827    // -- Authorization URL tests ---------------------------------------------
1828
1829    #[test]
1830    fn test_build_authorization_url() {
1831        let metadata = AuthServerMetadata {
1832            issuer: Url::parse("https://auth.example.com").unwrap(),
1833            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1834            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1835            registration_endpoint: None,
1836            scopes_supported: None,
1837            code_challenge_methods_supported: Some(vec!["S256".into()]),
1838            client_id_metadata_document_supported: true,
1839        };
1840        let pkce = PkceChallenge {
1841            verifier: "test_verifier".into(),
1842            challenge: "test_challenge".into(),
1843        };
1844        let url = build_authorization_url(
1845            &metadata,
1846            "https://zed.dev/oauth/client-metadata.json",
1847            "http://127.0.0.1:12345/callback",
1848            &["files:read".into(), "files:write".into()],
1849            "https://mcp.example.com",
1850            &pkce,
1851            "random_state_123",
1852        );
1853
1854        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1855        assert_eq!(pairs.get("response_type").unwrap(), "code");
1856        assert_eq!(
1857            pairs.get("client_id").unwrap(),
1858            "https://zed.dev/oauth/client-metadata.json"
1859        );
1860        assert_eq!(
1861            pairs.get("redirect_uri").unwrap(),
1862            "http://127.0.0.1:12345/callback"
1863        );
1864        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1865        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1866        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1867        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1868        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1869    }
1870
1871    #[test]
1872    fn test_build_authorization_url_omits_empty_scope() {
1873        let metadata = AuthServerMetadata {
1874            issuer: Url::parse("https://auth.example.com").unwrap(),
1875            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1876            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1877            registration_endpoint: None,
1878            scopes_supported: None,
1879            code_challenge_methods_supported: Some(vec!["S256".into()]),
1880            client_id_metadata_document_supported: false,
1881        };
1882        let pkce = PkceChallenge {
1883            verifier: "v".into(),
1884            challenge: "c".into(),
1885        };
1886        let url = build_authorization_url(
1887            &metadata,
1888            "client_123",
1889            "http://127.0.0.1:9999/callback",
1890            &[],
1891            "https://mcp.example.com",
1892            &pkce,
1893            "state",
1894        );
1895
1896        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1897        assert!(!pairs.contains_key("scope"));
1898    }
1899
1900    // -- Token exchange / refresh param tests --------------------------------
1901
1902    #[test]
1903    fn test_token_exchange_params() {
1904        let params = token_exchange_params(
1905            "auth_code_abc",
1906            "client_xyz",
1907            "http://127.0.0.1:5555/callback",
1908            "verifier_123",
1909            "https://mcp.example.com",
1910            None,
1911        );
1912        let map: std::collections::HashMap<&str, &str> =
1913            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1914
1915        assert_eq!(map["grant_type"], "authorization_code");
1916        assert_eq!(map["code"], "auth_code_abc");
1917        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1918        assert_eq!(map["client_id"], "client_xyz");
1919        assert_eq!(map["code_verifier"], "verifier_123");
1920        assert_eq!(map["resource"], "https://mcp.example.com");
1921    }
1922
1923    #[test]
1924    fn test_token_refresh_params() {
1925        let params = token_refresh_params(
1926            "refresh_token_abc",
1927            "client_xyz",
1928            "https://mcp.example.com",
1929            None,
1930        );
1931        let map: std::collections::HashMap<&str, &str> =
1932            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1933
1934        assert_eq!(map["grant_type"], "refresh_token");
1935        assert_eq!(map["refresh_token"], "refresh_token_abc");
1936        assert_eq!(map["client_id"], "client_xyz");
1937        assert_eq!(map["resource"], "https://mcp.example.com");
1938    }
1939
1940    // -- Token response tests ------------------------------------------------
1941
1942    #[test]
1943    fn test_token_response_into_tokens_with_expiry() {
1944        let response: TokenResponse = serde_json::from_str(
1945            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1946        )
1947        .unwrap();
1948
1949        let tokens = response.into_tokens();
1950        assert_eq!(tokens.access_token, "at_123");
1951        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1952        assert!(tokens.expires_at.is_some());
1953    }
1954
1955    #[test]
1956    fn test_token_response_into_tokens_minimal() {
1957        let response: TokenResponse =
1958            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1959
1960        let tokens = response.into_tokens();
1961        assert_eq!(tokens.access_token, "at_789");
1962        assert_eq!(tokens.refresh_token, None);
1963        assert_eq!(tokens.expires_at, None);
1964    }
1965
1966    // -- DCR body test -------------------------------------------------------
1967
1968    #[test]
1969    fn test_dcr_registration_body_shape() {
1970        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1971        assert_eq!(body["client_name"], "Zed");
1972        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1973        assert_eq!(body["grant_types"][0], "authorization_code");
1974        assert_eq!(body["response_types"][0], "code");
1975        assert_eq!(body["token_endpoint_auth_method"], "none");
1976    }
1977
1978    // -- Test helpers for async/HTTP tests -----------------------------------
1979
1980    fn make_fake_http_client(
1981        handler: impl Fn(
1982            http_client::Request<AsyncBody>,
1983        ) -> std::pin::Pin<
1984            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1985        > + Send
1986        + Sync
1987        + 'static,
1988    ) -> Arc<dyn HttpClient> {
1989        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1990    }
1991
1992    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1993        Ok(Response::builder()
1994            .status(status)
1995            .header("Content-Type", "application/json")
1996            .body(AsyncBody::from(body.as_bytes().to_vec()))
1997            .unwrap())
1998    }
1999
2000    // -- Discovery integration tests -----------------------------------------
2001
2002    #[test]
2003    fn test_fetch_protected_resource_metadata() {
2004        smol::block_on(async {
2005            let client = make_fake_http_client(|req| {
2006                Box::pin(async move {
2007                    let uri = req.uri().to_string();
2008                    if uri.contains(".well-known/oauth-protected-resource") {
2009                        json_response(
2010                            200,
2011                            r#"{
2012                                "resource": "https://mcp.example.com",
2013                                "authorization_servers": ["https://auth.example.com"],
2014                                "scopes_supported": ["read", "write"]
2015                            }"#,
2016                        )
2017                    } else {
2018                        json_response(404, "{}")
2019                    }
2020                })
2021            });
2022
2023            let server_url = Url::parse("https://mcp.example.com").unwrap();
2024            let www_auth = WwwAuthenticate {
2025                resource_metadata: None,
2026                scope: None,
2027                error: None,
2028                error_description: None,
2029            };
2030
2031            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2032                .await
2033                .unwrap();
2034
2035            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2036            assert_eq!(metadata.authorization_servers.len(), 1);
2037            assert_eq!(
2038                metadata.authorization_servers[0].as_str(),
2039                "https://auth.example.com/"
2040            );
2041            assert_eq!(
2042                metadata.scopes_supported,
2043                Some(vec!["read".to_string(), "write".to_string()])
2044            );
2045        });
2046    }
2047
2048    #[test]
2049    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2050        smol::block_on(async {
2051            let client = make_fake_http_client(|req| {
2052                Box::pin(async move {
2053                    let uri = req.uri().to_string();
2054                    if uri == "https://mcp.example.com/custom-resource-metadata" {
2055                        json_response(
2056                            200,
2057                            r#"{
2058                                "resource": "https://mcp.example.com",
2059                                "authorization_servers": ["https://auth.example.com"]
2060                            }"#,
2061                        )
2062                    } else {
2063                        json_response(500, r#"{"error": "should not be called"}"#)
2064                    }
2065                })
2066            });
2067
2068            let server_url = Url::parse("https://mcp.example.com").unwrap();
2069            let www_auth = WwwAuthenticate {
2070                resource_metadata: Some(
2071                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2072                ),
2073                scope: None,
2074                error: None,
2075                error_description: None,
2076            };
2077
2078            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2079                .await
2080                .unwrap();
2081
2082            assert_eq!(metadata.authorization_servers.len(), 1);
2083        });
2084    }
2085
2086    #[test]
2087    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2088        smol::block_on(async {
2089            let client = make_fake_http_client(|req| {
2090                Box::pin(async move {
2091                    let uri = req.uri().to_string();
2092                    // The cross-origin URL should NOT be fetched; only the
2093                    // well-known fallback at the server's own origin should be.
2094                    if uri.contains("attacker.example.com") {
2095                        panic!("should not fetch cross-origin resource_metadata URL");
2096                    } else if uri.contains(".well-known/oauth-protected-resource") {
2097                        json_response(
2098                            200,
2099                            r#"{
2100                                "resource": "https://mcp.example.com",
2101                                "authorization_servers": ["https://auth.example.com"]
2102                            }"#,
2103                        )
2104                    } else {
2105                        json_response(404, "{}")
2106                    }
2107                })
2108            });
2109
2110            let server_url = Url::parse("https://mcp.example.com").unwrap();
2111            let www_auth = WwwAuthenticate {
2112                resource_metadata: Some(
2113                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2114                ),
2115                scope: None,
2116                error: None,
2117                error_description: None,
2118            };
2119
2120            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2121                .await
2122                .unwrap();
2123
2124            // Should have used the fallback well-known URL, not the attacker's.
2125            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2126        });
2127    }
2128
2129    #[test]
2130    fn test_fetch_auth_server_metadata() {
2131        smol::block_on(async {
2132            let client = make_fake_http_client(|req| {
2133                Box::pin(async move {
2134                    let uri = req.uri().to_string();
2135                    if uri.contains(".well-known/oauth-authorization-server") {
2136                        json_response(
2137                            200,
2138                            r#"{
2139                                "issuer": "https://auth.example.com",
2140                                "authorization_endpoint": "https://auth.example.com/authorize",
2141                                "token_endpoint": "https://auth.example.com/token",
2142                                "registration_endpoint": "https://auth.example.com/register",
2143                                "code_challenge_methods_supported": ["S256"],
2144                                "client_id_metadata_document_supported": true
2145                            }"#,
2146                        )
2147                    } else {
2148                        json_response(404, "{}")
2149                    }
2150                })
2151            });
2152
2153            let issuer = Url::parse("https://auth.example.com").unwrap();
2154            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2155
2156            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2157            assert_eq!(
2158                metadata.authorization_endpoint.as_str(),
2159                "https://auth.example.com/authorize"
2160            );
2161            assert_eq!(
2162                metadata.token_endpoint.as_str(),
2163                "https://auth.example.com/token"
2164            );
2165            assert!(metadata.registration_endpoint.is_some());
2166            assert!(metadata.client_id_metadata_document_supported);
2167            assert_eq!(
2168                metadata.code_challenge_methods_supported,
2169                Some(vec!["S256".to_string()])
2170            );
2171        });
2172    }
2173
2174    #[test]
2175    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2176        smol::block_on(async {
2177            let client = make_fake_http_client(|req| {
2178                Box::pin(async move {
2179                    let uri = req.uri().to_string();
2180                    if uri.contains("openid-configuration") {
2181                        json_response(
2182                            200,
2183                            r#"{
2184                                "issuer": "https://auth.example.com",
2185                                "authorization_endpoint": "https://auth.example.com/authorize",
2186                                "token_endpoint": "https://auth.example.com/token",
2187                                "code_challenge_methods_supported": ["S256"]
2188                            }"#,
2189                        )
2190                    } else {
2191                        json_response(404, "{}")
2192                    }
2193                })
2194            });
2195
2196            let issuer = Url::parse("https://auth.example.com").unwrap();
2197            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2198
2199            assert_eq!(
2200                metadata.authorization_endpoint.as_str(),
2201                "https://auth.example.com/authorize"
2202            );
2203            assert!(!metadata.client_id_metadata_document_supported);
2204        });
2205    }
2206
2207    #[test]
2208    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2209        smol::block_on(async {
2210            let client = make_fake_http_client(|req| {
2211                Box::pin(async move {
2212                    let uri = req.uri().to_string();
2213                    if uri.contains(".well-known/oauth-authorization-server") {
2214                        // Response claims to be a different issuer.
2215                        json_response(
2216                            200,
2217                            r#"{
2218                                "issuer": "https://evil.example.com",
2219                                "authorization_endpoint": "https://evil.example.com/authorize",
2220                                "token_endpoint": "https://evil.example.com/token",
2221                                "code_challenge_methods_supported": ["S256"]
2222                            }"#,
2223                        )
2224                    } else {
2225                        json_response(404, "{}")
2226                    }
2227                })
2228            });
2229
2230            let issuer = Url::parse("https://auth.example.com").unwrap();
2231            let result = fetch_auth_server_metadata(&client, &issuer).await;
2232
2233            assert!(result.is_err());
2234            let err_msg = result.unwrap_err().to_string();
2235            assert!(
2236                err_msg.contains("issuer mismatch"),
2237                "unexpected error: {}",
2238                err_msg
2239            );
2240        });
2241    }
2242
2243    // -- Full discover integration tests -------------------------------------
2244
2245    #[test]
2246    fn test_full_discover_with_cimd() {
2247        smol::block_on(async {
2248            let client = make_fake_http_client(|req| {
2249                Box::pin(async move {
2250                    let uri = req.uri().to_string();
2251                    if uri.contains("oauth-protected-resource") {
2252                        json_response(
2253                            200,
2254                            r#"{
2255                                "resource": "https://mcp.example.com",
2256                                "authorization_servers": ["https://auth.example.com"],
2257                                "scopes_supported": ["mcp:read"]
2258                            }"#,
2259                        )
2260                    } else if uri.contains("oauth-authorization-server") {
2261                        json_response(
2262                            200,
2263                            r#"{
2264                                "issuer": "https://auth.example.com",
2265                                "authorization_endpoint": "https://auth.example.com/authorize",
2266                                "token_endpoint": "https://auth.example.com/token",
2267                                "code_challenge_methods_supported": ["S256"],
2268                                "client_id_metadata_document_supported": true
2269                            }"#,
2270                        )
2271                    } else {
2272                        json_response(404, "{}")
2273                    }
2274                })
2275            });
2276
2277            let server_url = Url::parse("https://mcp.example.com").unwrap();
2278            let www_auth = WwwAuthenticate {
2279                resource_metadata: None,
2280                scope: None,
2281                error: None,
2282                error_description: None,
2283            };
2284
2285            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2286            let registration =
2287                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2288                    .await
2289                    .unwrap();
2290
2291            assert_eq!(registration.client_id, CIMD_URL);
2292            assert_eq!(registration.client_secret, None);
2293            assert_eq!(discovery.scopes, vec!["mcp:read"]);
2294        });
2295    }
2296
2297    #[test]
2298    fn test_full_discover_with_dcr_fallback() {
2299        smol::block_on(async {
2300            let client = make_fake_http_client(|req| {
2301                Box::pin(async move {
2302                    let uri = req.uri().to_string();
2303                    if uri.contains("oauth-protected-resource") {
2304                        json_response(
2305                            200,
2306                            r#"{
2307                                "resource": "https://mcp.example.com",
2308                                "authorization_servers": ["https://auth.example.com"]
2309                            }"#,
2310                        )
2311                    } else if uri.contains("oauth-authorization-server") {
2312                        json_response(
2313                            200,
2314                            r#"{
2315                                "issuer": "https://auth.example.com",
2316                                "authorization_endpoint": "https://auth.example.com/authorize",
2317                                "token_endpoint": "https://auth.example.com/token",
2318                                "registration_endpoint": "https://auth.example.com/register",
2319                                "code_challenge_methods_supported": ["S256"],
2320                                "client_id_metadata_document_supported": false
2321                            }"#,
2322                        )
2323                    } else if uri.contains("/register") {
2324                        json_response(
2325                            201,
2326                            r#"{
2327                                "client_id": "dcr-minted-id-123",
2328                                "client_secret": "dcr-secret-456"
2329                            }"#,
2330                        )
2331                    } else {
2332                        json_response(404, "{}")
2333                    }
2334                })
2335            });
2336
2337            let server_url = Url::parse("https://mcp.example.com").unwrap();
2338            let www_auth = WwwAuthenticate {
2339                resource_metadata: None,
2340                scope: Some(vec!["files:read".into()]),
2341                error: None,
2342                error_description: None,
2343            };
2344
2345            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2346            let registration =
2347                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2348                    .await
2349                    .unwrap();
2350
2351            assert_eq!(registration.client_id, "dcr-minted-id-123");
2352            assert_eq!(
2353                registration.client_secret.as_deref(),
2354                Some("dcr-secret-456")
2355            );
2356            assert_eq!(discovery.scopes, vec!["files:read"]);
2357        });
2358    }
2359
2360    #[test]
2361    fn test_discover_fails_without_pkce_support() {
2362        smol::block_on(async {
2363            let client = make_fake_http_client(|req| {
2364                Box::pin(async move {
2365                    let uri = req.uri().to_string();
2366                    if uri.contains("oauth-protected-resource") {
2367                        json_response(
2368                            200,
2369                            r#"{
2370                                "resource": "https://mcp.example.com",
2371                                "authorization_servers": ["https://auth.example.com"]
2372                            }"#,
2373                        )
2374                    } else if uri.contains("oauth-authorization-server") {
2375                        json_response(
2376                            200,
2377                            r#"{
2378                                "issuer": "https://auth.example.com",
2379                                "authorization_endpoint": "https://auth.example.com/authorize",
2380                                "token_endpoint": "https://auth.example.com/token"
2381                            }"#,
2382                        )
2383                    } else {
2384                        json_response(404, "{}")
2385                    }
2386                })
2387            });
2388
2389            let server_url = Url::parse("https://mcp.example.com").unwrap();
2390            let www_auth = WwwAuthenticate {
2391                resource_metadata: None,
2392                scope: None,
2393                error: None,
2394                error_description: None,
2395            };
2396
2397            let result = discover(&client, &server_url, &www_auth).await;
2398            assert!(result.is_err());
2399            let err_msg = result.unwrap_err().to_string();
2400            assert!(
2401                err_msg.contains("code_challenge_methods_supported"),
2402                "unexpected error: {}",
2403                err_msg
2404            );
2405        });
2406    }
2407
2408    // -- Token exchange integration tests ------------------------------------
2409
2410    #[test]
2411    fn test_exchange_code_success() {
2412        smol::block_on(async {
2413            let client = make_fake_http_client(|req| {
2414                Box::pin(async move {
2415                    let uri = req.uri().to_string();
2416                    if uri.contains("/token") {
2417                        json_response(
2418                            200,
2419                            r#"{
2420                                "access_token": "new_access_token",
2421                                "refresh_token": "new_refresh_token",
2422                                "expires_in": 3600,
2423                                "token_type": "Bearer"
2424                            }"#,
2425                        )
2426                    } else {
2427                        json_response(404, "{}")
2428                    }
2429                })
2430            });
2431
2432            let metadata = AuthServerMetadata {
2433                issuer: Url::parse("https://auth.example.com").unwrap(),
2434                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2435                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2436                registration_endpoint: None,
2437                scopes_supported: None,
2438                code_challenge_methods_supported: Some(vec!["S256".into()]),
2439                client_id_metadata_document_supported: true,
2440            };
2441
2442            let tokens = exchange_code(
2443                &client,
2444                &metadata,
2445                "auth_code_123",
2446                CIMD_URL,
2447                "http://127.0.0.1:9999/callback",
2448                "verifier_abc",
2449                "https://mcp.example.com",
2450                None,
2451            )
2452            .await
2453            .unwrap();
2454
2455            assert_eq!(tokens.access_token, "new_access_token");
2456            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2457            assert!(tokens.expires_at.is_some());
2458        });
2459    }
2460
2461    #[test]
2462    fn test_refresh_tokens_success() {
2463        smol::block_on(async {
2464            let client = make_fake_http_client(|req| {
2465                Box::pin(async move {
2466                    let uri = req.uri().to_string();
2467                    if uri.contains("/token") {
2468                        json_response(
2469                            200,
2470                            r#"{
2471                                "access_token": "refreshed_token",
2472                                "expires_in": 1800,
2473                                "token_type": "Bearer"
2474                            }"#,
2475                        )
2476                    } else {
2477                        json_response(404, "{}")
2478                    }
2479                })
2480            });
2481
2482            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2483
2484            let tokens = refresh_tokens(
2485                &client,
2486                &token_endpoint,
2487                "old_refresh_token",
2488                CIMD_URL,
2489                "https://mcp.example.com",
2490                None,
2491            )
2492            .await
2493            .unwrap();
2494
2495            assert_eq!(tokens.access_token, "refreshed_token");
2496            assert_eq!(tokens.refresh_token, None);
2497            assert!(tokens.expires_at.is_some());
2498        });
2499    }
2500
2501    #[test]
2502    fn test_exchange_code_failure() {
2503        smol::block_on(async {
2504            let client = make_fake_http_client(|_req| {
2505                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2506            });
2507
2508            let metadata = AuthServerMetadata {
2509                issuer: Url::parse("https://auth.example.com").unwrap(),
2510                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2511                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2512                registration_endpoint: None,
2513                scopes_supported: None,
2514                code_challenge_methods_supported: Some(vec!["S256".into()]),
2515                client_id_metadata_document_supported: true,
2516            };
2517
2518            let result = exchange_code(
2519                &client,
2520                &metadata,
2521                "bad_code",
2522                "client",
2523                "http://127.0.0.1:1/callback",
2524                "verifier",
2525                "https://mcp.example.com",
2526                None,
2527            )
2528            .await;
2529
2530            let err = result.unwrap_err();
2531            let token_error = err
2532                .downcast_ref::<OAuthTokenError>()
2533                .expect("expected OAuthTokenError");
2534            assert_eq!(
2535                *token_error,
2536                OAuthTokenError {
2537                    error: "invalid_grant".into(),
2538                    error_description: None,
2539                }
2540            );
2541        });
2542    }
2543
2544    // -- DCR integration tests -----------------------------------------------
2545
2546    #[test]
2547    fn test_perform_dcr() {
2548        smol::block_on(async {
2549            let client = make_fake_http_client(|_req| {
2550                Box::pin(async move {
2551                    json_response(
2552                        201,
2553                        r#"{
2554                            "client_id": "dynamic-client-001",
2555                            "client_secret": "dynamic-secret-001"
2556                        }"#,
2557                    )
2558                })
2559            });
2560
2561            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2562            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2563                .await
2564                .unwrap();
2565
2566            assert_eq!(registration.client_id, "dynamic-client-001");
2567            assert_eq!(
2568                registration.client_secret.as_deref(),
2569                Some("dynamic-secret-001")
2570            );
2571        });
2572    }
2573
2574    #[test]
2575    fn test_perform_dcr_failure() {
2576        smol::block_on(async {
2577            let client = make_fake_http_client(|_req| {
2578                Box::pin(
2579                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2580                )
2581            });
2582
2583            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2584            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2585
2586            assert!(result.is_err());
2587            assert!(result.unwrap_err().to_string().contains("403"));
2588        });
2589    }
2590
2591    // -- OAuthCallback parse tests -------------------------------------------
2592
2593    #[test]
2594    fn test_oauth_callback_parse_query() {
2595        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2596        assert_eq!(callback.code, "test_auth_code");
2597        assert_eq!(callback.state, "test_state");
2598    }
2599
2600    #[test]
2601    fn test_oauth_callback_parse_query_reversed_order() {
2602        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2603        assert_eq!(callback.code, "test_auth_code");
2604        assert_eq!(callback.state, "test_state");
2605    }
2606
2607    #[test]
2608    fn test_oauth_callback_parse_query_with_extra_params() {
2609        let callback =
2610            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2611                .unwrap();
2612        assert_eq!(callback.code, "test_auth_code");
2613        assert_eq!(callback.state, "test_state");
2614    }
2615
2616    #[test]
2617    fn test_oauth_callback_parse_query_missing_code() {
2618        let result = OAuthCallback::parse_query("state=test_state");
2619        assert!(result.is_err());
2620        assert!(result.unwrap_err().to_string().contains("code"));
2621    }
2622
2623    #[test]
2624    fn test_oauth_callback_parse_query_missing_state() {
2625        let result = OAuthCallback::parse_query("code=test_auth_code");
2626        assert!(result.is_err());
2627        assert!(result.unwrap_err().to_string().contains("state"));
2628    }
2629
2630    #[test]
2631    fn test_oauth_callback_parse_query_empty_code() {
2632        let result = OAuthCallback::parse_query("code=&state=test_state");
2633        assert!(result.is_err());
2634    }
2635
2636    #[test]
2637    fn test_oauth_callback_parse_query_empty_state() {
2638        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2639        assert!(result.is_err());
2640    }
2641
2642    #[test]
2643    fn test_oauth_callback_parse_query_url_encoded_values() {
2644        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2645        assert_eq!(callback.code, "abc def");
2646        assert_eq!(callback.state, "test=state");
2647    }
2648
2649    #[test]
2650    fn test_oauth_callback_parse_query_error_response() {
2651        let result = OAuthCallback::parse_query(
2652            "error=access_denied&error_description=User%20denied%20access&state=abc",
2653        );
2654        assert!(result.is_err());
2655        let err_msg = result.unwrap_err().to_string();
2656        assert!(
2657            err_msg.contains("access_denied"),
2658            "unexpected error: {}",
2659            err_msg
2660        );
2661        assert!(
2662            err_msg.contains("User denied access"),
2663            "unexpected error: {}",
2664            err_msg
2665        );
2666    }
2667
2668    #[test]
2669    fn test_oauth_callback_parse_query_error_without_description() {
2670        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2671        assert!(result.is_err());
2672        let err_msg = result.unwrap_err().to_string();
2673        assert!(
2674            err_msg.contains("server_error"),
2675            "unexpected error: {}",
2676            err_msg
2677        );
2678        assert!(
2679            err_msg.contains("no description"),
2680            "unexpected error: {}",
2681            err_msg
2682        );
2683    }
2684
2685    // -- McpOAuthTokenProvider tests -----------------------------------------
2686
2687    fn make_test_session(
2688        access_token: &str,
2689        refresh_token: Option<&str>,
2690        expires_at: Option<SystemTime>,
2691    ) -> OAuthSession {
2692        OAuthSession {
2693            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2694            resource: Url::parse("https://mcp.example.com").unwrap(),
2695            client_registration: OAuthClientRegistration {
2696                client_id: "test-client".into(),
2697                client_secret: None,
2698            },
2699            tokens: OAuthTokens {
2700                access_token: access_token.into(),
2701                refresh_token: refresh_token.map(String::from),
2702                expires_at,
2703            },
2704        }
2705    }
2706
2707    #[test]
2708    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2709        let expired = SystemTime::now() - Duration::from_secs(60);
2710        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2711        let provider = McpOAuthTokenProvider::new(
2712            session,
2713            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2714            None,
2715        );
2716
2717        assert_eq!(provider.access_token(), None);
2718    }
2719
2720    #[test]
2721    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2722        let far_future = SystemTime::now() + Duration::from_secs(3600);
2723        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2724        let provider = McpOAuthTokenProvider::new(
2725            session,
2726            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2727            None,
2728        );
2729
2730        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2731    }
2732
2733    #[test]
2734    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2735        let session = make_test_session("no-expiry-token", Some("rt"), None);
2736        let provider = McpOAuthTokenProvider::new(
2737            session,
2738            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2739            None,
2740        );
2741
2742        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2743    }
2744
2745    #[test]
2746    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2747        smol::block_on(async {
2748            let session = make_test_session("token", None, None);
2749            let provider = McpOAuthTokenProvider::new(
2750                session,
2751                make_fake_http_client(|_| {
2752                    Box::pin(async { unreachable!("no HTTP call expected") })
2753                }),
2754                None,
2755            );
2756
2757            let refreshed = provider.try_refresh().await.unwrap();
2758            assert!(!refreshed);
2759        });
2760    }
2761
2762    #[test]
2763    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2764        smol::block_on(async {
2765            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2766            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2767
2768            let http_client = make_fake_http_client(|_req| {
2769                Box::pin(async {
2770                    json_response(
2771                        200,
2772                        r#"{
2773                            "access_token": "new-access",
2774                            "refresh_token": "new-refresh",
2775                            "expires_in": 1800
2776                        }"#,
2777                    )
2778                })
2779            });
2780
2781            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2782
2783            let refreshed = provider.try_refresh().await.unwrap();
2784            assert!(refreshed);
2785            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2786
2787            let notified_session = rx.try_recv().expect("channel should have a session");
2788            assert_eq!(notified_session.tokens.access_token, "new-access");
2789            assert_eq!(
2790                notified_session.tokens.refresh_token.as_deref(),
2791                Some("new-refresh")
2792            );
2793        });
2794    }
2795
2796    #[test]
2797    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2798        smol::block_on(async {
2799            let session = make_test_session("old-access", Some("original-refresh"), None);
2800            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2801
2802            let http_client = make_fake_http_client(|_req| {
2803                Box::pin(async {
2804                    json_response(
2805                        200,
2806                        r#"{
2807                            "access_token": "new-access",
2808                            "expires_in": 900
2809                        }"#,
2810                    )
2811                })
2812            });
2813
2814            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2815
2816            let refreshed = provider.try_refresh().await.unwrap();
2817            assert!(refreshed);
2818
2819            let notified_session = rx.try_recv().expect("channel should have a session");
2820            assert_eq!(notified_session.tokens.access_token, "new-access");
2821            assert_eq!(
2822                notified_session.tokens.refresh_token.as_deref(),
2823                Some("original-refresh"),
2824            );
2825        });
2826    }
2827
2828    #[test]
2829    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2830        smol::block_on(async {
2831            let session = make_test_session("old-access", Some("my-refresh"), None);
2832
2833            let http_client = make_fake_http_client(|_req| {
2834                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2835            });
2836
2837            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2838
2839            let refreshed = provider.try_refresh().await.unwrap();
2840            assert!(!refreshed);
2841            // The old token should still be in place.
2842            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2843        });
2844    }
2845}