oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// When an authorization server advertises
/// `client_id_metadata_document_supported`, this URL itself is used as the
/// OAuth `client_id` (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
/// OAuth 2.0 Protected Resource Metadata (RFC 9728), fetched either from the
/// URL advertised in the MCP server's `WWW-Authenticate` header or from a
/// well-known endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource identifier. When the fetched document omits it,
    /// [`fetch_protected_resource_metadata`] falls back to the MCP server URL.
    pub resource: Url,
    /// Authorization servers for this resource. Non-empty when produced by
    /// [`fetch_protected_resource_metadata`]; [`discover`] uses the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource supports. [`select_scopes`] falls back to these when
    /// the `WWW-Authenticate` challenge carries no usable `scope`.
    pub scopes_supported: Option<Vec<String>>,
}
 139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier. [`fetch_auth_server_metadata`] rejects documents whose
    /// reported issuer differs from the authorization server URL it was asked
    /// to look up.
    pub issuer: Url,
    /// Where the user's browser is sent (see [`build_authorization_url`]).
    pub authorization_endpoint: Url,
    /// The authorization server's token endpoint.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, if offered. Selects
    /// [`ClientRegistrationStrategy::Dcr`] when CIMD is not supported.
    pub registration_endpoint: Option<Url>,
    /// Scopes advertised by the authorization server.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods advertised by the server; [`discover`] requires this to be
    /// present and to include `S256`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether a Client ID Metadata Document URL may be used as `client_id`.
    /// Treated as `false` when the fetched metadata omits it.
    pub client_id_metadata_document_supported: bool,
}
 152
 153/// The result of client registration — either CIMD or DCR.
 154#[derive(Clone, Serialize, Deserialize)]
 155pub struct OAuthClientRegistration {
 156    pub client_id: String,
 157    /// Only present for DCR-minted registrations.
 158    pub client_secret: Option<String>,
 159}
 160
 161impl std::fmt::Debug for OAuthClientRegistration {
 162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 163        f.debug_struct("OAuthClientRegistration")
 164            .field("client_id", &self.client_id)
 165            .field(
 166                "client_secret",
 167                &self.client_secret.as_ref().map(|_| "[redacted]"),
 168            )
 169            .finish()
 170    }
 171}
 172
 173/// Access and refresh tokens obtained from the token endpoint.
 174#[derive(Clone, Serialize, Deserialize)]
 175pub struct OAuthTokens {
 176    pub access_token: String,
 177    pub refresh_token: Option<String>,
 178    pub expires_at: Option<SystemTime>,
 179}
 180
 181impl std::fmt::Debug for OAuthTokens {
 182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 183        f.debug_struct("OAuthTokens")
 184            .field("access_token", &"[redacted]")
 185            .field(
 186                "refresh_token",
 187                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 188            )
 189            .field("expires_at", &self.expires_at)
 190            .finish()
 191    }
 192}
 193
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata for the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes chosen by [`select_scopes`]; may be empty.
    pub scopes: Vec<String>,
}
 202
 203/// The persisted OAuth session for a context server.
 204///
 205/// Stored in the keychain so startup can restore a refresh-capable provider
 206/// without another browser flow. Deliberately excludes the full discovery
 207/// metadata to keep the serialized size well within keychain item limits.
 208#[derive(Clone, Serialize, Deserialize)]
 209pub struct OAuthSession {
 210    pub token_endpoint: Url,
 211    pub resource: Url,
 212    pub client_registration: OAuthClientRegistration,
 213    pub tokens: OAuthTokens,
 214}
 215
 216impl std::fmt::Debug for OAuthSession {
 217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 218        f.debug_struct("OAuthSession")
 219            .field("token_endpoint", &self.token_endpoint)
 220            .field("resource", &self.resource)
 221            .field("client_registration", &self.client_registration)
 222            .field("tokens", &self.tokens)
 223            .finish()
 224    }
 225}
 226
 227/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 228#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 229pub enum BearerError {
 230    /// The request is missing a required parameter, includes an unsupported
 231    /// parameter or parameter value, or is otherwise malformed.
 232    InvalidRequest,
 233    /// The access token provided is expired, revoked, malformed, or invalid.
 234    InvalidToken,
 235    /// The request requires higher privileges than provided by the access token.
 236    InsufficientScope,
 237    /// An unrecognized error code (extension or future spec addition).
 238    Other,
 239}
 240
 241impl BearerError {
 242    fn parse(value: &str) -> Self {
 243        match value {
 244            "invalid_request" => BearerError::InvalidRequest,
 245            "invalid_token" => BearerError::InvalidToken,
 246            "insufficient_scope" => BearerError::InsufficientScope,
 247            _ => BearerError::Other,
 248        }
 249    }
 250}
 251
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if advertised.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace, if present (may be empty).
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, with quoted-string escapes resolved.
    pub error_description: Option<String>,
}
 265
 266/// Parse a `WWW-Authenticate` header value.
 267///
 268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 270/// - `resource_metadata` — URL of the Protected Resource Metadata document
 271/// - `scope` — space-separated list of required scopes
 272/// - `error` — error code (e.g. "insufficient_scope")
 273/// - `error_description` — human-readable error description
 274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 275    let header = header.trim();
 276
 277    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 278        header[6..].trim()
 279    } else {
 280        bail!("WWW-Authenticate header does not use Bearer scheme");
 281    };
 282
 283    if params_str.is_empty() {
 284        return Ok(WwwAuthenticate {
 285            resource_metadata: None,
 286            scope: None,
 287            error: None,
 288            error_description: None,
 289        });
 290    }
 291
 292    let params = parse_auth_params(params_str);
 293
 294    let resource_metadata = params
 295        .get("resource_metadata")
 296        .map(|v| Url::parse(v))
 297        .transpose()
 298        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 299
 300    let scope = params
 301        .get("scope")
 302        .map(|v| v.split_whitespace().map(String::from).collect());
 303
 304    let error = params.get("error").map(|v| BearerError::parse(v));
 305    let error_description = params.get("error_description").cloned();
 306
 307    Ok(WwwAuthenticate {
 308        resource_metadata,
 309        scope,
 310        error,
 311        error_description,
 312    })
 313}
 314
 315/// Parse comma-separated `key="value"` or `key=token` parameters from an
 316/// auth-param list (RFC 7235 Section 2.1).
 317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
 318    let mut params = collections::HashMap::default();
 319    let mut remaining = input.trim();
 320
 321    while !remaining.is_empty() {
 322        // Skip leading whitespace and commas.
 323        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
 324        if remaining.is_empty() {
 325            break;
 326        }
 327
 328        // Find the key (everything before '=').
 329        let eq_pos = match remaining.find('=') {
 330            Some(pos) => pos,
 331            None => break,
 332        };
 333
 334        let key = remaining[..eq_pos].trim().to_lowercase();
 335        remaining = &remaining[eq_pos + 1..];
 336        remaining = remaining.trim_start();
 337
 338        // Parse the value: either quoted or unquoted (token).
 339        let value;
 340        if remaining.starts_with('"') {
 341            // Quoted string: find the closing quote, handling escaped chars.
 342            remaining = &remaining[1..]; // skip opening quote
 343            let mut val = String::new();
 344            let mut chars = remaining.char_indices();
 345            loop {
 346                match chars.next() {
 347                    Some((_, '\\')) => {
 348                        // Escaped character — take the next char literally.
 349                        if let Some((_, c)) = chars.next() {
 350                            val.push(c);
 351                        }
 352                    }
 353                    Some((i, '"')) => {
 354                        remaining = &remaining[i + 1..];
 355                        break;
 356                    }
 357                    Some((_, c)) => val.push(c),
 358                    None => {
 359                        remaining = "";
 360                        break;
 361                    }
 362                }
 363            }
 364            value = val;
 365        } else {
 366            // Unquoted token: read until comma or whitespace.
 367            let end = remaining
 368                .find(|c: char| c == ',' || c.is_whitespace())
 369                .unwrap_or(remaining.len());
 370            value = remaining[..end].to_string();
 371            remaining = &remaining[end..];
 372        }
 373
 374        if !key.is_empty() {
 375            params.insert(key, value);
 376        }
 377    }
 378
 379    params
 380}
 381
 382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 383/// server URL, per RFC 9728 Section 3.
 384///
 385/// Returns URIs in priority order:
 386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 389    let mut urls = Vec::new();
 390    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 391
 392    let path = server_url.path().trim_start_matches('/');
 393    if !path.is_empty() {
 394        if let Ok(url) = Url::parse(&format!(
 395            "{}/.well-known/oauth-protected-resource/{}",
 396            base, path
 397        )) {
 398            urls.push(url);
 399        }
 400    }
 401
 402    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 403        urls.push(url);
 404    }
 405
 406    urls
 407}
 408
 409/// Construct the well-known Authorization Server Metadata URIs for a given
 410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 411///
 412/// Returns URIs in priority order, which differs depending on whether the
 413/// issuer URL has a path component.
 414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 415    let mut urls = Vec::new();
 416    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 417    let path = issuer.path().trim_matches('/');
 418
 419    if !path.is_empty() {
 420        // Issuer with path: try path-inserted variants first.
 421        if let Ok(url) = Url::parse(&format!(
 422            "{}/.well-known/oauth-authorization-server/{}",
 423            base, path
 424        )) {
 425            urls.push(url);
 426        }
 427        if let Ok(url) = Url::parse(&format!(
 428            "{}/.well-known/openid-configuration/{}",
 429            base, path
 430        )) {
 431            urls.push(url);
 432        }
 433        if let Ok(url) = Url::parse(&format!(
 434            "{}/{}/.well-known/openid-configuration",
 435            base, path
 436        )) {
 437            urls.push(url);
 438        }
 439    } else {
 440        // No path: standard well-known locations.
 441        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 442            urls.push(url);
 443        }
 444        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 445            urls.push(url);
 446        }
 447    }
 448
 449    urls
 450}
 451
 452// -- Canonical server URI (RFC 8707) -----------------------------------------
 453
 454/// Derive the canonical resource URI for an MCP server URL, suitable for the
 455/// `resource` parameter in authorization and token requests per RFC 8707.
 456///
 457/// Lowercases the scheme and host, preserves the path (without trailing slash),
 458/// strips fragments and query strings.
 459pub fn canonical_server_uri(server_url: &Url) -> String {
 460    let mut uri = format!(
 461        "{}://{}",
 462        server_url.scheme().to_ascii_lowercase(),
 463        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 464    );
 465    if let Some(port) = server_url.port() {
 466        uri.push_str(&format!(":{}", port));
 467    }
 468    let path = server_url.path();
 469    if path != "/" {
 470        uri.push_str(path.trim_end_matches('/'));
 471    }
 472    uri
 473}
 474
 475// -- Scope selection ---------------------------------------------------------
 476
 477/// Select scopes following the MCP spec's Scope Selection Strategy:
 478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 480/// 3. Return empty if neither is available.
 481pub fn select_scopes(
 482    www_authenticate: &WwwAuthenticate,
 483    resource_metadata: &ProtectedResourceMetadata,
 484) -> Vec<String> {
 485    if let Some(ref scopes) = www_authenticate.scope {
 486        if !scopes.is_empty() {
 487            return scopes.clone();
 488        }
 489    }
 490    resource_metadata
 491        .scopes_supported
 492        .clone()
 493        .unwrap_or_default()
 494}
 495
 496// -- Client registration strategy --------------------------------------------
 497
/// The registration approach to use, determined from auth server metadata by
/// [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly
    /// (`client_id` is [`CIMD_URL`] when produced by
    /// [`determine_registration_strategy`]).
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
 508
 509/// Determine how to register with the authorization server, following the
 510/// spec's recommended priority: CIMD first, DCR fallback.
 511pub fn determine_registration_strategy(
 512    auth_server_metadata: &AuthServerMetadata,
 513) -> ClientRegistrationStrategy {
 514    if auth_server_metadata.client_id_metadata_document_supported {
 515        ClientRegistrationStrategy::Cimd {
 516            client_id: CIMD_URL.to_string(),
 517        }
 518    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 519        ClientRegistrationStrategy::Dcr {
 520            registration_endpoint: endpoint.clone(),
 521        }
 522    } else {
 523        ClientRegistrationStrategy::Unavailable
 524    }
 525}
 526
 527// -- PKCE (RFC 7636) ---------------------------------------------------------
 528
 529/// A PKCE code verifier and its S256 challenge.
 530#[derive(Clone)]
 531pub struct PkceChallenge {
 532    pub verifier: String,
 533    pub challenge: String,
 534}
 535
 536impl std::fmt::Debug for PkceChallenge {
 537    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 538        f.debug_struct("PkceChallenge")
 539            .field("verifier", &"[redacted]")
 540            .field("challenge", &self.challenge)
 541            .finish()
 542    }
 543}
 544
 545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 546///
 547/// The verifier is 43 base64url characters derived from 32 random bytes.
 548/// The challenge is `BASE64URL(SHA256(verifier))`.
 549pub fn generate_pkce_challenge() -> PkceChallenge {
 550    let mut random_bytes = [0u8; 32];
 551    rand::rng().fill(&mut random_bytes);
 552    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 553    let verifier = engine.encode(&random_bytes);
 554
 555    let digest = Sha256::digest(verifier.as_bytes());
 556    let challenge = engine.encode(digest);
 557
 558    PkceChallenge {
 559        verifier,
 560        challenge,
 561    }
 562}
 563
 564// -- Authorization URL construction ------------------------------------------
 565
 566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 567pub fn build_authorization_url(
 568    auth_server_metadata: &AuthServerMetadata,
 569    client_id: &str,
 570    redirect_uri: &str,
 571    scopes: &[String],
 572    resource: &str,
 573    pkce: &PkceChallenge,
 574    state: &str,
 575) -> Url {
 576    let mut url = auth_server_metadata.authorization_endpoint.clone();
 577    {
 578        let mut query = url.query_pairs_mut();
 579        query.append_pair("response_type", "code");
 580        query.append_pair("client_id", client_id);
 581        query.append_pair("redirect_uri", redirect_uri);
 582        if !scopes.is_empty() {
 583            query.append_pair("scope", &scopes.join(" "));
 584        }
 585        query.append_pair("resource", resource);
 586        query.append_pair("code_challenge", &pkce.challenge);
 587        query.append_pair("code_challenge_method", "S256");
 588        query.append_pair("state", state);
 589    }
 590    url
 591}
 592
 593// -- Token endpoint request bodies -------------------------------------------
 594
 595/// The JSON body returned by the token endpoint on success.
 596#[derive(Deserialize)]
 597pub struct TokenResponse {
 598    pub access_token: String,
 599    #[serde(default)]
 600    pub refresh_token: Option<String>,
 601    #[serde(default)]
 602    pub expires_in: Option<u64>,
 603    #[serde(default)]
 604    pub token_type: Option<String>,
 605}
 606
 607impl std::fmt::Debug for TokenResponse {
 608    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 609        f.debug_struct("TokenResponse")
 610            .field("access_token", &"[redacted]")
 611            .field(
 612                "refresh_token",
 613                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 614            )
 615            .field("expires_in", &self.expires_in)
 616            .field("token_type", &self.token_type)
 617            .finish()
 618    }
 619}
 620
 621impl TokenResponse {
 622    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 623    pub fn into_tokens(self) -> OAuthTokens {
 624        let expires_at = self
 625            .expires_in
 626            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 627        OAuthTokens {
 628            access_token: self.access_token,
 629            refresh_token: self.refresh_token,
 630            expires_at,
 631        }
 632    }
 633}
 634
 635/// Build the form-encoded body for an authorization code token exchange.
 636pub fn token_exchange_params(
 637    code: &str,
 638    client_id: &str,
 639    redirect_uri: &str,
 640    code_verifier: &str,
 641    resource: &str,
 642) -> Vec<(&'static str, String)> {
 643    vec![
 644        ("grant_type", "authorization_code".to_string()),
 645        ("code", code.to_string()),
 646        ("redirect_uri", redirect_uri.to_string()),
 647        ("client_id", client_id.to_string()),
 648        ("code_verifier", code_verifier.to_string()),
 649        ("resource", resource.to_string()),
 650    ]
 651}
 652
 653/// Build the form-encoded body for a token refresh request.
 654pub fn token_refresh_params(
 655    refresh_token: &str,
 656    client_id: &str,
 657    resource: &str,
 658) -> Vec<(&'static str, String)> {
 659    vec![
 660        ("grant_type", "refresh_token".to_string()),
 661        ("refresh_token", refresh_token.to_string()),
 662        ("client_id", client_id.to_string()),
 663        ("resource", resource.to_string()),
 664    ]
 665}
 666
 667// -- DCR request body (RFC 7591) ---------------------------------------------
 668
 669/// Build the JSON body for a Dynamic Client Registration request.
 670///
 671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 673/// redirect URI matching even for loopback addresses, so we register the
 674/// exact URI we intend to use.
 675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
 676    serde_json::json!({
 677        "client_name": "Zed",
 678        "redirect_uris": [redirect_uri],
 679        "grant_types": ["authorization_code"],
 680        "response_types": ["code"],
 681        "token_endpoint_auth_method": "none"
 682    })
 683}
 684
 685// -- Discovery (async, hits real endpoints) ----------------------------------
 686
 687/// Fetch Protected Resource Metadata from the MCP server.
 688///
 689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 690/// then falls back to well-known URIs constructed from `server_url`.
 691pub async fn fetch_protected_resource_metadata(
 692    http_client: &Arc<dyn HttpClient>,
 693    server_url: &Url,
 694    www_authenticate: &WwwAuthenticate,
 695) -> Result<ProtectedResourceMetadata> {
 696    let candidate_urls = match &www_authenticate.resource_metadata {
 697        Some(url) if url.origin() == server_url.origin() => {
 698            // Try the header-provided URL first (per MCP spec: "use the resource
 699            // metadata URL from the parsed WWW-Authenticate headers when present"),
 700            // then fall back to RFC 9728 well-known URIs in case the header URL is
 701            // wrong (e.g. a buggy server that doubles the path component).
 702            let mut urls = vec![url.clone()];
 703            for fallback in protected_resource_metadata_urls(server_url) {
 704                if !urls.contains(&fallback) {
 705                    urls.push(fallback);
 706                }
 707            }
 708            urls
 709        }
 710        Some(url) => {
 711            log::warn!(
 712                "Ignoring cross-origin resource_metadata URL {} \
 713                 (server origin: {})",
 714                url,
 715                server_url.origin().unicode_serialization()
 716            );
 717            protected_resource_metadata_urls(server_url)
 718        }
 719        None => protected_resource_metadata_urls(server_url),
 720    };
 721
 722    for url in &candidate_urls {
 723        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 724            Ok(response) => {
 725                if response.authorization_servers.is_empty() {
 726                    bail!(
 727                        "Protected Resource Metadata at {} has no authorization_servers",
 728                        url
 729                    );
 730                }
 731                return Ok(ProtectedResourceMetadata {
 732                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 733                    authorization_servers: response.authorization_servers,
 734                    scopes_supported: response.scopes_supported,
 735                });
 736            }
 737            Err(err) => {
 738                log::debug!(
 739                    "Failed to fetch Protected Resource Metadata from {}: {}",
 740                    url,
 741                    err
 742                );
 743            }
 744        }
 745    }
 746
 747    bail!(
 748        "Could not fetch Protected Resource Metadata for {}",
 749        server_url
 750    )
 751}
 752
 753/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 754/// endpoints in the priority order specified by the MCP spec.
 755pub async fn fetch_auth_server_metadata(
 756    http_client: &Arc<dyn HttpClient>,
 757    issuer: &Url,
 758) -> Result<AuthServerMetadata> {
 759    let candidate_urls = auth_server_metadata_urls(issuer);
 760
 761    for url in &candidate_urls {
 762        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 763            Ok(response) => {
 764                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 765                if reported_issuer != *issuer {
 766                    bail!(
 767                        "Auth server metadata issuer mismatch: expected {}, got {}",
 768                        issuer,
 769                        reported_issuer
 770                    );
 771                }
 772
 773                return Ok(AuthServerMetadata {
 774                    issuer: reported_issuer,
 775                    authorization_endpoint: response
 776                        .authorization_endpoint
 777                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 778                    token_endpoint: response
 779                        .token_endpoint
 780                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 781                    registration_endpoint: response.registration_endpoint,
 782                    scopes_supported: response.scopes_supported,
 783                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 784                    client_id_metadata_document_supported: response
 785                        .client_id_metadata_document_supported
 786                        .unwrap_or(false),
 787                });
 788            }
 789            Err(err) => {
 790                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 791            }
 792        }
 793    }
 794
 795    bail!(
 796        "Could not fetch Authorization Server Metadata for {}",
 797        issuer
 798    )
 799}
 800
 801/// Run the full discovery flow: fetch resource metadata, then auth server
 802/// metadata, then select scopes. Client registration is resolved separately,
 803/// once the real redirect URI is known.
 804pub async fn discover(
 805    http_client: &Arc<dyn HttpClient>,
 806    server_url: &Url,
 807    www_authenticate: &WwwAuthenticate,
 808) -> Result<OAuthDiscovery> {
 809    let resource_metadata =
 810        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 811
 812    let auth_server_url = resource_metadata
 813        .authorization_servers
 814        .first()
 815        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 816
 817    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 818
 819    // Verify PKCE S256 support (spec requirement).
 820    match &auth_server_metadata.code_challenge_methods_supported {
 821        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 822        Some(_) => bail!("authorization server does not support S256 PKCE"),
 823        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 824    }
 825
 826    // Verify there is at least one supported registration strategy before we
 827    // present the server as ready to authenticate.
 828    match determine_registration_strategy(&auth_server_metadata) {
 829        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
 830        ClientRegistrationStrategy::Unavailable => {
 831            bail!("authorization server supports neither CIMD nor DCR")
 832        }
 833    }
 834
 835    let scopes = select_scopes(www_authenticate, &resource_metadata);
 836
 837    Ok(OAuthDiscovery {
 838        resource_metadata,
 839        auth_server_metadata,
 840        scopes,
 841    })
 842}
 843
 844/// Resolve the OAuth client registration for an authorization flow.
 845///
 846/// CIMD uses the static client metadata document directly. For DCR, a fresh
 847/// registration is performed each time because the loopback redirect URI
 848/// includes an ephemeral port that changes every flow.
 849pub async fn resolve_client_registration(
 850    http_client: &Arc<dyn HttpClient>,
 851    discovery: &OAuthDiscovery,
 852    redirect_uri: &str,
 853) -> Result<OAuthClientRegistration> {
 854    match determine_registration_strategy(&discovery.auth_server_metadata) {
 855        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 856            client_id,
 857            client_secret: None,
 858        }),
 859        ClientRegistrationStrategy::Dcr {
 860            registration_endpoint,
 861        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
 862        ClientRegistrationStrategy::Unavailable => {
 863            bail!("authorization server supports neither CIMD nor DCR")
 864        }
 865    }
 866}
 867
 868// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 869
 870/// Perform Dynamic Client Registration with the authorization server.
 871pub async fn perform_dcr(
 872    http_client: &Arc<dyn HttpClient>,
 873    registration_endpoint: &Url,
 874    redirect_uri: &str,
 875) -> Result<OAuthClientRegistration> {
 876    validate_oauth_url(registration_endpoint)?;
 877
 878    let body = dcr_registration_body(redirect_uri);
 879    let body_bytes = serde_json::to_vec(&body)?;
 880
 881    let request = Request::builder()
 882        .method(http_client::http::Method::POST)
 883        .uri(registration_endpoint.as_str())
 884        .header("Content-Type", "application/json")
 885        .header("Accept", "application/json")
 886        .body(AsyncBody::from(body_bytes))?;
 887
 888    let mut response = http_client.send(request).await?;
 889
 890    if !response.status().is_success() {
 891        let mut error_body = String::new();
 892        response.body_mut().read_to_string(&mut error_body).await?;
 893        bail!(
 894            "DCR failed with status {}: {}",
 895            response.status(),
 896            error_body
 897        );
 898    }
 899
 900    let mut response_body = String::new();
 901    response
 902        .body_mut()
 903        .read_to_string(&mut response_body)
 904        .await?;
 905
 906    let dcr_response: DcrResponse =
 907        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 908
 909    Ok(OAuthClientRegistration {
 910        client_id: dcr_response.client_id,
 911        client_secret: dcr_response.client_secret,
 912    })
 913}
 914
 915// -- Token exchange and refresh (async) --------------------------------------
 916
 917/// Exchange an authorization code for tokens at the token endpoint.
 918pub async fn exchange_code(
 919    http_client: &Arc<dyn HttpClient>,
 920    auth_server_metadata: &AuthServerMetadata,
 921    code: &str,
 922    client_id: &str,
 923    redirect_uri: &str,
 924    code_verifier: &str,
 925    resource: &str,
 926) -> Result<OAuthTokens> {
 927    let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
 928    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 929}
 930
 931/// Refresh tokens using a refresh token.
 932pub async fn refresh_tokens(
 933    http_client: &Arc<dyn HttpClient>,
 934    token_endpoint: &Url,
 935    refresh_token: &str,
 936    client_id: &str,
 937    resource: &str,
 938) -> Result<OAuthTokens> {
 939    let params = token_refresh_params(refresh_token, client_id, resource);
 940    post_token_request(http_client, token_endpoint, &params).await
 941}
 942
 943/// POST form-encoded parameters to a token endpoint and parse the response.
 944async fn post_token_request(
 945    http_client: &Arc<dyn HttpClient>,
 946    token_endpoint: &Url,
 947    params: &[(&str, String)],
 948) -> Result<OAuthTokens> {
 949    validate_oauth_url(token_endpoint)?;
 950
 951    let body = url::form_urlencoded::Serializer::new(String::new())
 952        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 953        .finish();
 954
 955    let request = Request::builder()
 956        .method(http_client::http::Method::POST)
 957        .uri(token_endpoint.as_str())
 958        .header("Content-Type", "application/x-www-form-urlencoded")
 959        .header("Accept", "application/json")
 960        .body(AsyncBody::from(body.into_bytes()))?;
 961
 962    let mut response = http_client.send(request).await?;
 963
 964    if !response.status().is_success() {
 965        let mut error_body = String::new();
 966        response.body_mut().read_to_string(&mut error_body).await?;
 967        bail!(
 968            "token request failed with status {}: {}",
 969            response.status(),
 970            error_body
 971        );
 972    }
 973
 974    let mut response_body = String::new();
 975    response
 976        .body_mut()
 977        .read_to_string(&mut response_body)
 978        .await?;
 979
 980    let token_response: TokenResponse =
 981        serde_json::from_str(&response_body).context("failed to parse token response")?;
 982
 983    Ok(token_response.into_tokens())
 984}
 985
 986// -- Loopback HTTP callback server -------------------------------------------
 987
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Both fields are taken from the callback's query string. This type's
/// `Debug` impl redacts them so they don't end up in logs.
pub struct OAuthCallback {
    /// The authorization code, to be redeemed at the token endpoint.
    pub code: String,
    /// The `state` value echoed back by the authorization server. It is not
    /// checked here; presumably the caller compares it against the value it
    /// sent in the authorization request (verify at the call site).
    pub state: String,
}
 993
 994impl std::fmt::Debug for OAuthCallback {
 995    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 996        f.debug_struct("OAuthCallback")
 997            .field("code", &"[redacted]")
 998            .field("state", &"[redacted]")
 999            .finish()
1000    }
1001}
1002
1003impl OAuthCallback {
1004    /// Parse the query string from a callback URL like
1005    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1006    pub fn parse_query(query: &str) -> Result<Self> {
1007        let mut code: Option<String> = None;
1008        let mut state: Option<String> = None;
1009        let mut error: Option<String> = None;
1010        let mut error_description: Option<String> = None;
1011
1012        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1013            match key.as_ref() {
1014                "code" => {
1015                    if !value.is_empty() {
1016                        code = Some(value.into_owned());
1017                    }
1018                }
1019                "state" => {
1020                    if !value.is_empty() {
1021                        state = Some(value.into_owned());
1022                    }
1023                }
1024                "error" => {
1025                    if !value.is_empty() {
1026                        error = Some(value.into_owned());
1027                    }
1028                }
1029                "error_description" => {
1030                    if !value.is_empty() {
1031                        error_description = Some(value.into_owned());
1032                    }
1033                }
1034                _ => {}
1035            }
1036        }
1037
1038        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1039        // checking for missing code/state.
1040        if let Some(error_code) = error {
1041            bail!(
1042                "OAuth authorization failed: {} ({})",
1043                error_code,
1044                error_description.as_deref().unwrap_or("no description")
1045            );
1046        }
1047
1048        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1049        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1050
1051        Ok(Self { code, state })
1052    }
1053}
1054
/// Upper bound on how long the loopback server waits for the browser to
/// finish the OAuth flow. After this the server stops listening and the
/// ephemeral port is released.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1058
1059/// Start a loopback HTTP server to receive the OAuth authorization callback.
1060///
1061/// Binds to an ephemeral loopback port for each flow.
1062///
1063/// Returns `(redirect_uri, callback_future)`. The caller should use the
1064/// redirect URI in the authorization request, open the browser, then await
1065/// the future to receive the callback.
1066///
1067/// The server accepts exactly one request on `/callback`, validates that it
1068/// contains `code` and `state` query parameters, responds with a minimal
1069/// HTML page telling the user they can close the tab, and shuts down.
1070///
1071/// The callback server shuts down when the returned oneshot receiver is dropped
1072/// (e.g. because the authentication task was cancelled), or after a timeout
1073/// ([CALLBACK_TIMEOUT]).
1074pub async fn start_callback_server() -> Result<(
1075    String,
1076    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1077)> {
1078    let server = tiny_http::Server::http("127.0.0.1:0")
1079        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1080    let port = server
1081        .server_addr()
1082        .to_ip()
1083        .context("server not bound to a TCP address")?
1084        .port();
1085
1086    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1087
1088    let (tx, rx) = futures::channel::oneshot::channel();
1089
1090    // `tiny_http` is blocking, so we run it on a background thread.
1091    // The `recv_timeout` loop lets us check for cancellation (the receiver
1092    // being dropped) and enforce an overall timeout.
1093    std::thread::spawn(move || {
1094        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1095
1096        loop {
1097            if tx.is_canceled() {
1098                return;
1099            }
1100            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1101            if remaining.is_zero() {
1102                return;
1103            }
1104
1105            let timeout = remaining.min(Duration::from_millis(500));
1106            let Some(request) = (match server.recv_timeout(timeout) {
1107                Ok(req) => req,
1108                Err(_) => {
1109                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1110                    return;
1111                }
1112            }) else {
1113                // Timeout with no request — loop back and check cancellation.
1114                continue;
1115            };
1116
1117            let result = handle_callback_request(&request);
1118
1119            let (status_code, body) = match &result {
1120                Ok(_) => (
1121                    200,
1122                    "<html><body><h1>Authorization successful</h1>\
1123                     <p>You can close this tab and return to Zed.</p></body></html>",
1124                ),
1125                Err(err) => {
1126                    log::error!("OAuth callback error: {}", err);
1127                    (
1128                        400,
1129                        "<html><body><h1>Authorization failed</h1>\
1130                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
1131                    )
1132                }
1133            };
1134
1135            let response = tiny_http::Response::from_string(body)
1136                .with_status_code(status_code)
1137                .with_header(
1138                    tiny_http::Header::from_str("Content-Type: text/html")
1139                        .expect("failed to construct response header"),
1140                )
1141                .with_header(
1142                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1143                        .expect("failed to construct response header"),
1144                );
1145            request.respond(response).log_err();
1146
1147            let _ = tx.send(result);
1148            return;
1149        }
1150    });
1151
1152    Ok((redirect_uri, rx))
1153}
1154
1155/// Extract the `code` and `state` query parameters from an OAuth callback
1156/// request to `/callback`.
1157fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1158    let url = Url::parse(&format!("http://localhost{}", request.url()))
1159        .context("malformed callback request URL")?;
1160
1161    if url.path() != "/callback" {
1162        bail!("unexpected path in OAuth callback: {}", url.path());
1163    }
1164
1165    let query = url
1166        .query()
1167        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1168    OAuthCallback::parse_query(query)
1169}
1170
1171// -- JSON fetch helper -------------------------------------------------------
1172
1173async fn fetch_json<T: serde::de::DeserializeOwned>(
1174    http_client: &Arc<dyn HttpClient>,
1175    url: &Url,
1176) -> Result<T> {
1177    validate_oauth_url(url)?;
1178
1179    let request = Request::builder()
1180        .method(http_client::http::Method::GET)
1181        .uri(url.as_str())
1182        .header("Accept", "application/json")
1183        .body(AsyncBody::default())?;
1184
1185    let mut response = http_client.send(request).await?;
1186
1187    if !response.status().is_success() {
1188        bail!("HTTP {} fetching {}", response.status(), url);
1189    }
1190
1191    let mut body = String::new();
1192    response.body_mut().read_to_string(&mut body).await?;
1193    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1194}
1195
1196// -- Serde response types for discovery --------------------------------------
1197
/// Wire format of a Protected Resource Metadata JSON document.
///
/// Every field tolerates being absent (`#[serde(default)]`), so sparse
/// documents still deserialize. Checking that required values are present
/// must happen after parsing. Unknown fields are ignored.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// Identifier of the resource this metadata describes, if provided.
    #[serde(default)]
    resource: Option<Url>,
    /// Authorization server URLs for this resource; empty when absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// Scopes the resource advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1207
/// Wire format of an Authorization Server Metadata JSON document.
///
/// Every field is optional at parse time so sparse documents still
/// deserialize. `fetch_auth_server_metadata` then enforces the requirements:
/// - it errors if `authorization_endpoint` or `token_endpoint` is missing;
/// - it substitutes the expected issuer when `issuer` is absent, and errors
///   if a reported issuer differs;
/// - it treats a missing `client_id_metadata_document_supported` as `false`.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// Issuer identifier reported by the server, if any.
    #[serde(default)]
    issuer: Option<Url>,
    /// Authorization endpoint; required after parsing.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// Token endpoint used for code exchange and refresh; required after parsing.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// Dynamic Client Registration endpoint, if the server advertises one.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// Scopes the server advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// Supported PKCE challenge methods; discovery requires `S256` to be listed.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server advertises Client ID Metadata Document (CIMD) support.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1225
/// The subset of a Dynamic Client Registration response that is used.
///
/// `client_id` is required, so a response without it fails to deserialize.
/// `client_secret` is `None` when the server does not return one. All other
/// response fields are ignored.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// Client identifier issued by the authorization server.
    client_id: String,
    /// Client secret, if one was issued.
    #[serde(default)]
    client_secret: Option<String>,
}
1232
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// The `Send + Sync` bound allows one provider to be shared across threads
/// and tasks.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// This method is synchronous (it cannot `.await`), so it can only answer
    /// from state the provider already holds.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token is available and the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1246
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// The current session, including its tokens, token endpoint, resource and
    /// client registration. It sits behind a synchronous `parking_lot` mutex
    /// that is only locked briefly and not held across `.await` points.
    session: SyncMutex<OAuthSession>,
    /// Client used to POST refresh requests to the token endpoint.
    http_client: Arc<dyn HttpClient>,
    /// Optional channel that receives a clone of the whole session after each
    /// successful refresh. Presumably the receiver re-persists it; verify at
    /// the receiving end.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1256
1257impl McpOAuthTokenProvider {
1258    pub fn new(
1259        session: OAuthSession,
1260        http_client: Arc<dyn HttpClient>,
1261        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1262    ) -> Self {
1263        Self {
1264            session: SyncMutex::new(session),
1265            http_client,
1266            token_refresh_tx,
1267        }
1268    }
1269
1270    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1271        tokens.expires_at.is_some_and(|expires_at| {
1272            SystemTime::now()
1273                .checked_add(Duration::from_secs(30))
1274                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1275        })
1276    }
1277}
1278
1279#[async_trait]
1280impl OAuthTokenProvider for McpOAuthTokenProvider {
1281    fn access_token(&self) -> Option<String> {
1282        let session = self.session.lock();
1283        if Self::access_token_is_expired(&session.tokens) {
1284            return None;
1285        }
1286        Some(session.tokens.access_token.clone())
1287    }
1288
1289    async fn try_refresh(&self) -> Result<bool> {
1290        let (refresh_token, token_endpoint, resource, client_id) = {
1291            let session = self.session.lock();
1292            match session.tokens.refresh_token.clone() {
1293                Some(refresh_token) => (
1294                    refresh_token,
1295                    session.token_endpoint.clone(),
1296                    session.resource.clone(),
1297                    session.client_registration.client_id.clone(),
1298                ),
1299                None => return Ok(false),
1300            }
1301        };
1302
1303        let resource_str = canonical_server_uri(&resource);
1304
1305        match refresh_tokens(
1306            &self.http_client,
1307            &token_endpoint,
1308            &refresh_token,
1309            &client_id,
1310            &resource_str,
1311        )
1312        .await
1313        {
1314            Ok(mut new_tokens) => {
1315                if new_tokens.refresh_token.is_none() {
1316                    new_tokens.refresh_token = Some(refresh_token);
1317                }
1318
1319                {
1320                    let mut session = self.session.lock();
1321                    session.tokens = new_tokens;
1322
1323                    if let Some(ref tx) = self.token_refresh_tx {
1324                        tx.unbounded_send(session.clone()).ok();
1325                    }
1326                }
1327
1328                Ok(true)
1329            }
1330            Err(err) => {
1331                log::warn!("OAuth token refresh failed: {}", err);
1332                Ok(false)
1333            }
1334        }
1335    }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340    use super::*;
1341    use http_client::Response;
1342
1343    // -- require_https_or_loopback tests ------------------------------------
1344
1345    #[test]
1346    fn test_require_https_or_loopback_accepts_https() {
1347        let url = Url::parse("https://auth.example.com/token").unwrap();
1348        assert!(require_https_or_loopback(&url).is_ok());
1349    }
1350
1351    #[test]
1352    fn test_require_https_or_loopback_rejects_http_remote() {
1353        let url = Url::parse("http://auth.example.com/token").unwrap();
1354        assert!(require_https_or_loopback(&url).is_err());
1355    }
1356
1357    #[test]
1358    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1359        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1360        assert!(require_https_or_loopback(&url).is_ok());
1361    }
1362
1363    #[test]
1364    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1365        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1366        assert!(require_https_or_loopback(&url).is_ok());
1367    }
1368
1369    #[test]
1370    fn test_require_https_or_loopback_accepts_http_localhost() {
1371        let url = Url::parse("http://localhost:8080/callback").unwrap();
1372        assert!(require_https_or_loopback(&url).is_ok());
1373    }
1374
1375    #[test]
1376    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1377        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1378        assert!(require_https_or_loopback(&url).is_ok());
1379    }
1380
1381    #[test]
1382    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1383        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1384        assert!(require_https_or_loopback(&url).is_err());
1385    }
1386
1387    #[test]
1388    fn test_require_https_or_loopback_rejects_ftp() {
1389        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1390        assert!(require_https_or_loopback(&url).is_err());
1391    }
1392
1393    // -- validate_oauth_url (SSRF) tests ------------------------------------
1394
1395    #[test]
1396    fn test_validate_oauth_url_accepts_https_public() {
1397        let url = Url::parse("https://auth.example.com/token").unwrap();
1398        assert!(validate_oauth_url(&url).is_ok());
1399    }
1400
1401    #[test]
1402    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1403        let url = Url::parse("https://10.0.0.1/token").unwrap();
1404        assert!(validate_oauth_url(&url).is_err());
1405    }
1406
1407    #[test]
1408    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1409        let url = Url::parse("https://172.16.0.1/token").unwrap();
1410        assert!(validate_oauth_url(&url).is_err());
1411    }
1412
1413    #[test]
1414    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1415        let url = Url::parse("https://192.168.1.1/token").unwrap();
1416        assert!(validate_oauth_url(&url).is_err());
1417    }
1418
1419    #[test]
1420    fn test_validate_oauth_url_rejects_link_local() {
1421        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1422        assert!(validate_oauth_url(&url).is_err());
1423    }
1424
1425    #[test]
1426    fn test_validate_oauth_url_rejects_ipv6_ula() {
1427        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1428        assert!(validate_oauth_url(&url).is_err());
1429    }
1430
1431    #[test]
1432    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1433        let url = Url::parse("https://[::]/token").unwrap();
1434        assert!(validate_oauth_url(&url).is_err());
1435    }
1436
1437    #[test]
1438    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1439        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1440        assert!(validate_oauth_url(&url).is_err());
1441    }
1442
1443    #[test]
1444    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1445        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1446        assert!(validate_oauth_url(&url).is_err());
1447    }
1448
1449    #[test]
1450    fn test_validate_oauth_url_allows_http_loopback() {
1451        // Loopback is permitted (it's our callback server).
1452        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1453        assert!(validate_oauth_url(&url).is_ok());
1454    }
1455
1456    #[test]
1457    fn test_validate_oauth_url_allows_https_public_ip() {
1458        let url = Url::parse("https://93.184.216.34/token").unwrap();
1459        assert!(validate_oauth_url(&url).is_ok());
1460    }
1461
1462    // -- parse_www_authenticate tests ----------------------------------------
1463
1464    #[test]
1465    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1466        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1467        let result = parse_www_authenticate(header).unwrap();
1468
1469        assert_eq!(
1470            result.resource_metadata.as_ref().map(|u| u.as_str()),
1471            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1472        );
1473        assert_eq!(
1474            result.scope,
1475            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1476        );
1477        assert_eq!(result.error, None);
1478    }
1479
1480    #[test]
1481    fn test_parse_www_authenticate_resource_metadata_only() {
1482        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1483        let result = parse_www_authenticate(header).unwrap();
1484
1485        assert_eq!(
1486            result.resource_metadata.as_ref().map(|u| u.as_str()),
1487            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1488        );
1489        assert_eq!(result.scope, None);
1490    }
1491
1492    #[test]
1493    fn test_parse_www_authenticate_bare_bearer() {
1494        let result = parse_www_authenticate("Bearer").unwrap();
1495        assert_eq!(result.resource_metadata, None);
1496        assert_eq!(result.scope, None);
1497    }
1498
1499    #[test]
1500    fn test_parse_www_authenticate_with_error() {
1501        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1502        let result = parse_www_authenticate(header).unwrap();
1503
1504        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1505        assert_eq!(
1506            result.error_description.as_deref(),
1507            Some("Additional file write permission required")
1508        );
1509        assert_eq!(
1510            result.scope,
1511            Some(vec!["files:read".to_string(), "files:write".to_string()])
1512        );
1513        assert!(result.resource_metadata.is_some());
1514    }
1515
1516    #[test]
1517    fn test_parse_www_authenticate_invalid_token_error() {
1518        let header =
1519            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1520        let result = parse_www_authenticate(header).unwrap();
1521        assert_eq!(result.error, Some(BearerError::InvalidToken));
1522    }
1523
1524    #[test]
1525    fn test_parse_www_authenticate_invalid_request_error() {
1526        let header = r#"Bearer error="invalid_request""#;
1527        let result = parse_www_authenticate(header).unwrap();
1528        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1529    }
1530
1531    #[test]
1532    fn test_parse_www_authenticate_unknown_error() {
1533        let header = r#"Bearer error="some_future_error""#;
1534        let result = parse_www_authenticate(header).unwrap();
1535        assert_eq!(result.error, Some(BearerError::Other));
1536    }
1537
1538    #[test]
1539    fn test_parse_www_authenticate_rejects_non_bearer() {
1540        let result = parse_www_authenticate("Basic realm=\"example\"");
1541        assert!(result.is_err());
1542    }
1543
1544    #[test]
1545    fn test_parse_www_authenticate_case_insensitive_scheme() {
1546        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1547        let result = parse_www_authenticate(header).unwrap();
1548        assert!(result.resource_metadata.is_some());
1549    }
1550
1551    #[test]
1552    fn test_parse_www_authenticate_multiline_style() {
1553        // Some servers emit the header spread across multiple lines joined by
1554        // whitespace, as shown in the spec examples.
1555        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1556        let result = parse_www_authenticate(header).unwrap();
1557        assert!(result.resource_metadata.is_some());
1558        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1559    }
1560
1561    #[test]
1562    fn test_protected_resource_metadata_urls_with_path() {
1563        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1564        let urls = protected_resource_metadata_urls(&server_url);
1565
1566        assert_eq!(urls.len(), 2);
1567        assert_eq!(
1568            urls[0].as_str(),
1569            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1570        );
1571        assert_eq!(
1572            urls[1].as_str(),
1573            "https://api.example.com/.well-known/oauth-protected-resource"
1574        );
1575    }
1576
1577    #[test]
1578    fn test_protected_resource_metadata_urls_without_path() {
1579        let server_url = Url::parse("https://mcp.example.com").unwrap();
1580        let urls = protected_resource_metadata_urls(&server_url);
1581
1582        assert_eq!(urls.len(), 1);
1583        assert_eq!(
1584            urls[0].as_str(),
1585            "https://mcp.example.com/.well-known/oauth-protected-resource"
1586        );
1587    }
1588
1589    #[test]
1590    fn test_auth_server_metadata_urls_with_path() {
1591        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1592        let urls = auth_server_metadata_urls(&issuer);
1593
1594        assert_eq!(urls.len(), 3);
1595        assert_eq!(
1596            urls[0].as_str(),
1597            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1598        );
1599        assert_eq!(
1600            urls[1].as_str(),
1601            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1602        );
1603        assert_eq!(
1604            urls[2].as_str(),
1605            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1606        );
1607    }
1608
1609    #[test]
1610    fn test_auth_server_metadata_urls_without_path() {
1611        let issuer = Url::parse("https://auth.example.com").unwrap();
1612        let urls = auth_server_metadata_urls(&issuer);
1613
1614        assert_eq!(urls.len(), 2);
1615        assert_eq!(
1616            urls[0].as_str(),
1617            "https://auth.example.com/.well-known/oauth-authorization-server"
1618        );
1619        assert_eq!(
1620            urls[1].as_str(),
1621            "https://auth.example.com/.well-known/openid-configuration"
1622        );
1623    }
1624
1625    // -- Canonical server URI tests ------------------------------------------
1626
1627    #[test]
1628    fn test_canonical_server_uri_simple() {
1629        let url = Url::parse("https://mcp.example.com").unwrap();
1630        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1631    }
1632
1633    #[test]
1634    fn test_canonical_server_uri_with_path() {
1635        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1636        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1637    }
1638
1639    #[test]
1640    fn test_canonical_server_uri_strips_trailing_slash() {
1641        let url = Url::parse("https://mcp.example.com/").unwrap();
1642        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1643    }
1644
1645    #[test]
1646    fn test_canonical_server_uri_preserves_port() {
1647        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1648        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1649    }
1650
1651    #[test]
1652    fn test_canonical_server_uri_lowercases() {
1653        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1654        assert_eq!(
1655            canonical_server_uri(&url),
1656            "https://mcp.example.com/Server/MCP"
1657        );
1658    }
1659
1660    // -- Scope selection tests -----------------------------------------------
1661
1662    #[test]
1663    fn test_select_scopes_prefers_www_authenticate() {
1664        let www_auth = WwwAuthenticate {
1665            resource_metadata: None,
1666            scope: Some(vec!["files:read".into()]),
1667            error: None,
1668            error_description: None,
1669        };
1670        let resource_meta = ProtectedResourceMetadata {
1671            resource: Url::parse("https://example.com").unwrap(),
1672            authorization_servers: vec![],
1673            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1674        };
1675        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1676    }
1677
1678    #[test]
1679    fn test_select_scopes_falls_back_to_resource_metadata() {
1680        let www_auth = WwwAuthenticate {
1681            resource_metadata: None,
1682            scope: None,
1683            error: None,
1684            error_description: None,
1685        };
1686        let resource_meta = ProtectedResourceMetadata {
1687            resource: Url::parse("https://example.com").unwrap(),
1688            authorization_servers: vec![],
1689            scopes_supported: Some(vec!["admin".into()]),
1690        };
1691        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1692    }
1693
1694    #[test]
1695    fn test_select_scopes_empty_when_nothing_available() {
1696        let www_auth = WwwAuthenticate {
1697            resource_metadata: None,
1698            scope: None,
1699            error: None,
1700            error_description: None,
1701        };
1702        let resource_meta = ProtectedResourceMetadata {
1703            resource: Url::parse("https://example.com").unwrap(),
1704            authorization_servers: vec![],
1705            scopes_supported: None,
1706        };
1707        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1708    }
1709
1710    // -- Client registration strategy tests ----------------------------------
1711
1712    #[test]
1713    fn test_registration_strategy_prefers_cimd() {
1714        let metadata = AuthServerMetadata {
1715            issuer: Url::parse("https://auth.example.com").unwrap(),
1716            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1717            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1718            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1719            scopes_supported: None,
1720            code_challenge_methods_supported: Some(vec!["S256".into()]),
1721            client_id_metadata_document_supported: true,
1722        };
1723        assert_eq!(
1724            determine_registration_strategy(&metadata),
1725            ClientRegistrationStrategy::Cimd {
1726                client_id: CIMD_URL.to_string(),
1727            }
1728        );
1729    }
1730
1731    #[test]
1732    fn test_registration_strategy_falls_back_to_dcr() {
1733        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1734        let metadata = AuthServerMetadata {
1735            issuer: Url::parse("https://auth.example.com").unwrap(),
1736            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1737            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1738            registration_endpoint: Some(reg_endpoint.clone()),
1739            scopes_supported: None,
1740            code_challenge_methods_supported: Some(vec!["S256".into()]),
1741            client_id_metadata_document_supported: false,
1742        };
1743        assert_eq!(
1744            determine_registration_strategy(&metadata),
1745            ClientRegistrationStrategy::Dcr {
1746                registration_endpoint: reg_endpoint,
1747            }
1748        );
1749    }
1750
1751    #[test]
1752    fn test_registration_strategy_unavailable() {
1753        let metadata = AuthServerMetadata {
1754            issuer: Url::parse("https://auth.example.com").unwrap(),
1755            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1756            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1757            registration_endpoint: None,
1758            scopes_supported: None,
1759            code_challenge_methods_supported: Some(vec!["S256".into()]),
1760            client_id_metadata_document_supported: false,
1761        };
1762        assert_eq!(
1763            determine_registration_strategy(&metadata),
1764            ClientRegistrationStrategy::Unavailable,
1765        );
1766    }
1767
1768    // -- PKCE tests ----------------------------------------------------------
1769
1770    #[test]
1771    fn test_pkce_challenge_verifier_length() {
1772        let pkce = generate_pkce_challenge();
1773        // 32 random bytes → 43 base64url chars (no padding).
1774        assert_eq!(pkce.verifier.len(), 43);
1775    }
1776
1777    #[test]
1778    fn test_pkce_challenge_is_valid_base64url() {
1779        let pkce = generate_pkce_challenge();
1780        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1781            assert!(
1782                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1783                "invalid base64url character: {}",
1784                c
1785            );
1786        }
1787    }
1788
1789    #[test]
1790    fn test_pkce_challenge_is_s256_of_verifier() {
1791        let pkce = generate_pkce_challenge();
1792        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1793        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1794        let expected_challenge = engine.encode(expected_digest);
1795        assert_eq!(pkce.challenge, expected_challenge);
1796    }
1797
1798    #[test]
1799    fn test_pkce_challenges_are_unique() {
1800        let a = generate_pkce_challenge();
1801        let b = generate_pkce_challenge();
1802        assert_ne!(a.verifier, b.verifier);
1803    }
1804
1805    // -- Authorization URL tests ---------------------------------------------
1806
1807    #[test]
1808    fn test_build_authorization_url() {
1809        let metadata = AuthServerMetadata {
1810            issuer: Url::parse("https://auth.example.com").unwrap(),
1811            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1812            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1813            registration_endpoint: None,
1814            scopes_supported: None,
1815            code_challenge_methods_supported: Some(vec!["S256".into()]),
1816            client_id_metadata_document_supported: true,
1817        };
1818        let pkce = PkceChallenge {
1819            verifier: "test_verifier".into(),
1820            challenge: "test_challenge".into(),
1821        };
1822        let url = build_authorization_url(
1823            &metadata,
1824            "https://zed.dev/oauth/client-metadata.json",
1825            "http://127.0.0.1:12345/callback",
1826            &["files:read".into(), "files:write".into()],
1827            "https://mcp.example.com",
1828            &pkce,
1829            "random_state_123",
1830        );
1831
1832        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1833        assert_eq!(pairs.get("response_type").unwrap(), "code");
1834        assert_eq!(
1835            pairs.get("client_id").unwrap(),
1836            "https://zed.dev/oauth/client-metadata.json"
1837        );
1838        assert_eq!(
1839            pairs.get("redirect_uri").unwrap(),
1840            "http://127.0.0.1:12345/callback"
1841        );
1842        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1843        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1844        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1845        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1846        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1847    }
1848
1849    #[test]
1850    fn test_build_authorization_url_omits_empty_scope() {
1851        let metadata = AuthServerMetadata {
1852            issuer: Url::parse("https://auth.example.com").unwrap(),
1853            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1854            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1855            registration_endpoint: None,
1856            scopes_supported: None,
1857            code_challenge_methods_supported: Some(vec!["S256".into()]),
1858            client_id_metadata_document_supported: false,
1859        };
1860        let pkce = PkceChallenge {
1861            verifier: "v".into(),
1862            challenge: "c".into(),
1863        };
1864        let url = build_authorization_url(
1865            &metadata,
1866            "client_123",
1867            "http://127.0.0.1:9999/callback",
1868            &[],
1869            "https://mcp.example.com",
1870            &pkce,
1871            "state",
1872        );
1873
1874        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1875        assert!(!pairs.contains_key("scope"));
1876    }
1877
1878    // -- Token exchange / refresh param tests --------------------------------
1879
1880    #[test]
1881    fn test_token_exchange_params() {
1882        let params = token_exchange_params(
1883            "auth_code_abc",
1884            "client_xyz",
1885            "http://127.0.0.1:5555/callback",
1886            "verifier_123",
1887            "https://mcp.example.com",
1888        );
1889        let map: std::collections::HashMap<&str, &str> =
1890            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1891
1892        assert_eq!(map["grant_type"], "authorization_code");
1893        assert_eq!(map["code"], "auth_code_abc");
1894        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1895        assert_eq!(map["client_id"], "client_xyz");
1896        assert_eq!(map["code_verifier"], "verifier_123");
1897        assert_eq!(map["resource"], "https://mcp.example.com");
1898    }
1899
1900    #[test]
1901    fn test_token_refresh_params() {
1902        let params =
1903            token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1904        let map: std::collections::HashMap<&str, &str> =
1905            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1906
1907        assert_eq!(map["grant_type"], "refresh_token");
1908        assert_eq!(map["refresh_token"], "refresh_token_abc");
1909        assert_eq!(map["client_id"], "client_xyz");
1910        assert_eq!(map["resource"], "https://mcp.example.com");
1911    }
1912
1913    // -- Token response tests ------------------------------------------------
1914
1915    #[test]
1916    fn test_token_response_into_tokens_with_expiry() {
1917        let response: TokenResponse = serde_json::from_str(
1918            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1919        )
1920        .unwrap();
1921
1922        let tokens = response.into_tokens();
1923        assert_eq!(tokens.access_token, "at_123");
1924        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1925        assert!(tokens.expires_at.is_some());
1926    }
1927
1928    #[test]
1929    fn test_token_response_into_tokens_minimal() {
1930        let response: TokenResponse =
1931            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1932
1933        let tokens = response.into_tokens();
1934        assert_eq!(tokens.access_token, "at_789");
1935        assert_eq!(tokens.refresh_token, None);
1936        assert_eq!(tokens.expires_at, None);
1937    }
1938
1939    // -- DCR body test -------------------------------------------------------
1940
1941    #[test]
1942    fn test_dcr_registration_body_shape() {
1943        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1944        assert_eq!(body["client_name"], "Zed");
1945        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1946        assert_eq!(body["grant_types"][0], "authorization_code");
1947        assert_eq!(body["response_types"][0], "code");
1948        assert_eq!(body["token_endpoint_auth_method"], "none");
1949    }
1950
1951    // -- Test helpers for async/HTTP tests -----------------------------------
1952
1953    fn make_fake_http_client(
1954        handler: impl Fn(
1955            http_client::Request<AsyncBody>,
1956        ) -> std::pin::Pin<
1957            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1958        > + Send
1959        + Sync
1960        + 'static,
1961    ) -> Arc<dyn HttpClient> {
1962        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1963    }
1964
1965    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1966        Ok(Response::builder()
1967            .status(status)
1968            .header("Content-Type", "application/json")
1969            .body(AsyncBody::from(body.as_bytes().to_vec()))
1970            .unwrap())
1971    }
1972
1973    // -- Discovery integration tests -----------------------------------------
1974
1975    #[test]
1976    fn test_fetch_protected_resource_metadata() {
1977        smol::block_on(async {
1978            let client = make_fake_http_client(|req| {
1979                Box::pin(async move {
1980                    let uri = req.uri().to_string();
1981                    if uri.contains(".well-known/oauth-protected-resource") {
1982                        json_response(
1983                            200,
1984                            r#"{
1985                                "resource": "https://mcp.example.com",
1986                                "authorization_servers": ["https://auth.example.com"],
1987                                "scopes_supported": ["read", "write"]
1988                            }"#,
1989                        )
1990                    } else {
1991                        json_response(404, "{}")
1992                    }
1993                })
1994            });
1995
1996            let server_url = Url::parse("https://mcp.example.com").unwrap();
1997            let www_auth = WwwAuthenticate {
1998                resource_metadata: None,
1999                scope: None,
2000                error: None,
2001                error_description: None,
2002            };
2003
2004            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2005                .await
2006                .unwrap();
2007
2008            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2009            assert_eq!(metadata.authorization_servers.len(), 1);
2010            assert_eq!(
2011                metadata.authorization_servers[0].as_str(),
2012                "https://auth.example.com/"
2013            );
2014            assert_eq!(
2015                metadata.scopes_supported,
2016                Some(vec!["read".to_string(), "write".to_string()])
2017            );
2018        });
2019    }
2020
2021    #[test]
2022    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2023        smol::block_on(async {
2024            let client = make_fake_http_client(|req| {
2025                Box::pin(async move {
2026                    let uri = req.uri().to_string();
2027                    if uri == "https://mcp.example.com/custom-resource-metadata" {
2028                        json_response(
2029                            200,
2030                            r#"{
2031                                "resource": "https://mcp.example.com",
2032                                "authorization_servers": ["https://auth.example.com"]
2033                            }"#,
2034                        )
2035                    } else {
2036                        json_response(500, r#"{"error": "should not be called"}"#)
2037                    }
2038                })
2039            });
2040
2041            let server_url = Url::parse("https://mcp.example.com").unwrap();
2042            let www_auth = WwwAuthenticate {
2043                resource_metadata: Some(
2044                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2045                ),
2046                scope: None,
2047                error: None,
2048                error_description: None,
2049            };
2050
2051            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2052                .await
2053                .unwrap();
2054
2055            assert_eq!(metadata.authorization_servers.len(), 1);
2056        });
2057    }
2058
2059    #[test]
2060    fn test_fetch_protected_resource_metadata_falls_back_when_header_url_fails() {
2061        // Reproduces the Pydantic Logfire case: the server's WWW-Authenticate
2062        // header contains a resource_metadata URL with a doubled path (e.g.
2063        // /mcp/mcp), which returns HTML instead of JSON. The client should
2064        // fall back to the RFC 9728 well-known URL, which works correctly.
2065        smol::block_on(async {
2066            let client = make_fake_http_client(|req| {
2067                Box::pin(async move {
2068                    let uri = req.uri().to_string();
2069                    if uri
2070                        == "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp/mcp"
2071                    {
2072                        // Buggy header URL returns HTML (like a SPA catch-all).
2073                        Ok(Response::builder()
2074                            .status(200)
2075                            .header("Content-Type", "text/html")
2076                            .body(AsyncBody::from(b"<!doctype html><html></html>".to_vec()))
2077                            .unwrap())
2078                    } else if uri
2079                        == "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp"
2080                    {
2081                        // Correct well-known URL returns valid metadata.
2082                        json_response(
2083                            200,
2084                            r#"{
2085                                "resource": "https://mcp.example.com/api/mcp",
2086                                "authorization_servers": ["https://auth.example.com"]
2087                            }"#,
2088                        )
2089                    } else {
2090                        json_response(404, "{}")
2091                    }
2092                })
2093            });
2094
2095            let server_url = Url::parse("https://mcp.example.com/api/mcp").unwrap();
2096            let www_auth = WwwAuthenticate {
2097                resource_metadata: Some(
2098                    // Buggy URL with doubled path component.
2099                    Url::parse(
2100                        "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp/mcp",
2101                    )
2102                    .unwrap(),
2103                ),
2104                scope: None,
2105                error: None,
2106                error_description: None,
2107            };
2108
2109            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2110                .await
2111                .unwrap();
2112
2113            assert_eq!(
2114                metadata.resource.as_str(),
2115                "https://mcp.example.com/api/mcp"
2116            );
2117            assert_eq!(
2118                metadata.authorization_servers[0].as_str(),
2119                "https://auth.example.com/"
2120            );
2121        });
2122    }
2123
2124    #[test]
2125    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2126        smol::block_on(async {
2127            let client = make_fake_http_client(|req| {
2128                Box::pin(async move {
2129                    let uri = req.uri().to_string();
2130                    // The cross-origin URL should NOT be fetched; only the
2131                    // well-known fallback at the server's own origin should be.
2132                    if uri.contains("attacker.example.com") {
2133                        panic!("should not fetch cross-origin resource_metadata URL");
2134                    } else if uri.contains(".well-known/oauth-protected-resource") {
2135                        json_response(
2136                            200,
2137                            r#"{
2138                                "resource": "https://mcp.example.com",
2139                                "authorization_servers": ["https://auth.example.com"]
2140                            }"#,
2141                        )
2142                    } else {
2143                        json_response(404, "{}")
2144                    }
2145                })
2146            });
2147
2148            let server_url = Url::parse("https://mcp.example.com").unwrap();
2149            let www_auth = WwwAuthenticate {
2150                resource_metadata: Some(
2151                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2152                ),
2153                scope: None,
2154                error: None,
2155                error_description: None,
2156            };
2157
2158            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2159                .await
2160                .unwrap();
2161
2162            // Should have used the fallback well-known URL, not the attacker's.
2163            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2164        });
2165    }
2166
2167    #[test]
2168    fn test_fetch_auth_server_metadata() {
2169        smol::block_on(async {
2170            let client = make_fake_http_client(|req| {
2171                Box::pin(async move {
2172                    let uri = req.uri().to_string();
2173                    if uri.contains(".well-known/oauth-authorization-server") {
2174                        json_response(
2175                            200,
2176                            r#"{
2177                                "issuer": "https://auth.example.com",
2178                                "authorization_endpoint": "https://auth.example.com/authorize",
2179                                "token_endpoint": "https://auth.example.com/token",
2180                                "registration_endpoint": "https://auth.example.com/register",
2181                                "code_challenge_methods_supported": ["S256"],
2182                                "client_id_metadata_document_supported": true
2183                            }"#,
2184                        )
2185                    } else {
2186                        json_response(404, "{}")
2187                    }
2188                })
2189            });
2190
2191            let issuer = Url::parse("https://auth.example.com").unwrap();
2192            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2193
2194            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2195            assert_eq!(
2196                metadata.authorization_endpoint.as_str(),
2197                "https://auth.example.com/authorize"
2198            );
2199            assert_eq!(
2200                metadata.token_endpoint.as_str(),
2201                "https://auth.example.com/token"
2202            );
2203            assert!(metadata.registration_endpoint.is_some());
2204            assert!(metadata.client_id_metadata_document_supported);
2205            assert_eq!(
2206                metadata.code_challenge_methods_supported,
2207                Some(vec!["S256".to_string()])
2208            );
2209        });
2210    }
2211
2212    #[test]
2213    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2214        smol::block_on(async {
2215            let client = make_fake_http_client(|req| {
2216                Box::pin(async move {
2217                    let uri = req.uri().to_string();
2218                    if uri.contains("openid-configuration") {
2219                        json_response(
2220                            200,
2221                            r#"{
2222                                "issuer": "https://auth.example.com",
2223                                "authorization_endpoint": "https://auth.example.com/authorize",
2224                                "token_endpoint": "https://auth.example.com/token",
2225                                "code_challenge_methods_supported": ["S256"]
2226                            }"#,
2227                        )
2228                    } else {
2229                        json_response(404, "{}")
2230                    }
2231                })
2232            });
2233
2234            let issuer = Url::parse("https://auth.example.com").unwrap();
2235            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2236
2237            assert_eq!(
2238                metadata.authorization_endpoint.as_str(),
2239                "https://auth.example.com/authorize"
2240            );
2241            assert!(!metadata.client_id_metadata_document_supported);
2242        });
2243    }
2244
2245    #[test]
2246    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2247        smol::block_on(async {
2248            let client = make_fake_http_client(|req| {
2249                Box::pin(async move {
2250                    let uri = req.uri().to_string();
2251                    if uri.contains(".well-known/oauth-authorization-server") {
2252                        // Response claims to be a different issuer.
2253                        json_response(
2254                            200,
2255                            r#"{
2256                                "issuer": "https://evil.example.com",
2257                                "authorization_endpoint": "https://evil.example.com/authorize",
2258                                "token_endpoint": "https://evil.example.com/token",
2259                                "code_challenge_methods_supported": ["S256"]
2260                            }"#,
2261                        )
2262                    } else {
2263                        json_response(404, "{}")
2264                    }
2265                })
2266            });
2267
2268            let issuer = Url::parse("https://auth.example.com").unwrap();
2269            let result = fetch_auth_server_metadata(&client, &issuer).await;
2270
2271            assert!(result.is_err());
2272            let err_msg = result.unwrap_err().to_string();
2273            assert!(
2274                err_msg.contains("issuer mismatch"),
2275                "unexpected error: {}",
2276                err_msg
2277            );
2278        });
2279    }
2280
2281    // -- Full discover integration tests -------------------------------------
2282
2283    #[test]
2284    fn test_full_discover_with_cimd() {
2285        smol::block_on(async {
2286            let client = make_fake_http_client(|req| {
2287                Box::pin(async move {
2288                    let uri = req.uri().to_string();
2289                    if uri.contains("oauth-protected-resource") {
2290                        json_response(
2291                            200,
2292                            r#"{
2293                                "resource": "https://mcp.example.com",
2294                                "authorization_servers": ["https://auth.example.com"],
2295                                "scopes_supported": ["mcp:read"]
2296                            }"#,
2297                        )
2298                    } else if uri.contains("oauth-authorization-server") {
2299                        json_response(
2300                            200,
2301                            r#"{
2302                                "issuer": "https://auth.example.com",
2303                                "authorization_endpoint": "https://auth.example.com/authorize",
2304                                "token_endpoint": "https://auth.example.com/token",
2305                                "code_challenge_methods_supported": ["S256"],
2306                                "client_id_metadata_document_supported": true
2307                            }"#,
2308                        )
2309                    } else {
2310                        json_response(404, "{}")
2311                    }
2312                })
2313            });
2314
2315            let server_url = Url::parse("https://mcp.example.com").unwrap();
2316            let www_auth = WwwAuthenticate {
2317                resource_metadata: None,
2318                scope: None,
2319                error: None,
2320                error_description: None,
2321            };
2322
2323            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2324            let registration =
2325                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2326                    .await
2327                    .unwrap();
2328
2329            assert_eq!(registration.client_id, CIMD_URL);
2330            assert_eq!(registration.client_secret, None);
2331            assert_eq!(discovery.scopes, vec!["mcp:read"]);
2332        });
2333    }
2334
2335    #[test]
2336    fn test_full_discover_with_dcr_fallback() {
2337        smol::block_on(async {
2338            let client = make_fake_http_client(|req| {
2339                Box::pin(async move {
2340                    let uri = req.uri().to_string();
2341                    if uri.contains("oauth-protected-resource") {
2342                        json_response(
2343                            200,
2344                            r#"{
2345                                "resource": "https://mcp.example.com",
2346                                "authorization_servers": ["https://auth.example.com"]
2347                            }"#,
2348                        )
2349                    } else if uri.contains("oauth-authorization-server") {
2350                        json_response(
2351                            200,
2352                            r#"{
2353                                "issuer": "https://auth.example.com",
2354                                "authorization_endpoint": "https://auth.example.com/authorize",
2355                                "token_endpoint": "https://auth.example.com/token",
2356                                "registration_endpoint": "https://auth.example.com/register",
2357                                "code_challenge_methods_supported": ["S256"],
2358                                "client_id_metadata_document_supported": false
2359                            }"#,
2360                        )
2361                    } else if uri.contains("/register") {
2362                        json_response(
2363                            201,
2364                            r#"{
2365                                "client_id": "dcr-minted-id-123",
2366                                "client_secret": "dcr-secret-456"
2367                            }"#,
2368                        )
2369                    } else {
2370                        json_response(404, "{}")
2371                    }
2372                })
2373            });
2374
2375            let server_url = Url::parse("https://mcp.example.com").unwrap();
2376            let www_auth = WwwAuthenticate {
2377                resource_metadata: None,
2378                scope: Some(vec!["files:read".into()]),
2379                error: None,
2380                error_description: None,
2381            };
2382
2383            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2384            let registration =
2385                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2386                    .await
2387                    .unwrap();
2388
2389            assert_eq!(registration.client_id, "dcr-minted-id-123");
2390            assert_eq!(
2391                registration.client_secret.as_deref(),
2392                Some("dcr-secret-456")
2393            );
2394            assert_eq!(discovery.scopes, vec!["files:read"]);
2395        });
2396    }
2397
2398    #[test]
2399    fn test_discover_fails_without_pkce_support() {
2400        smol::block_on(async {
2401            let client = make_fake_http_client(|req| {
2402                Box::pin(async move {
2403                    let uri = req.uri().to_string();
2404                    if uri.contains("oauth-protected-resource") {
2405                        json_response(
2406                            200,
2407                            r#"{
2408                                "resource": "https://mcp.example.com",
2409                                "authorization_servers": ["https://auth.example.com"]
2410                            }"#,
2411                        )
2412                    } else if uri.contains("oauth-authorization-server") {
2413                        json_response(
2414                            200,
2415                            r#"{
2416                                "issuer": "https://auth.example.com",
2417                                "authorization_endpoint": "https://auth.example.com/authorize",
2418                                "token_endpoint": "https://auth.example.com/token"
2419                            }"#,
2420                        )
2421                    } else {
2422                        json_response(404, "{}")
2423                    }
2424                })
2425            });
2426
2427            let server_url = Url::parse("https://mcp.example.com").unwrap();
2428            let www_auth = WwwAuthenticate {
2429                resource_metadata: None,
2430                scope: None,
2431                error: None,
2432                error_description: None,
2433            };
2434
2435            let result = discover(&client, &server_url, &www_auth).await;
2436            assert!(result.is_err());
2437            let err_msg = result.unwrap_err().to_string();
2438            assert!(
2439                err_msg.contains("code_challenge_methods_supported"),
2440                "unexpected error: {}",
2441                err_msg
2442            );
2443        });
2444    }
2445
2446    // -- Token exchange integration tests ------------------------------------
2447
2448    #[test]
2449    fn test_exchange_code_success() {
2450        smol::block_on(async {
2451            let client = make_fake_http_client(|req| {
2452                Box::pin(async move {
2453                    let uri = req.uri().to_string();
2454                    if uri.contains("/token") {
2455                        json_response(
2456                            200,
2457                            r#"{
2458                                "access_token": "new_access_token",
2459                                "refresh_token": "new_refresh_token",
2460                                "expires_in": 3600,
2461                                "token_type": "Bearer"
2462                            }"#,
2463                        )
2464                    } else {
2465                        json_response(404, "{}")
2466                    }
2467                })
2468            });
2469
2470            let metadata = AuthServerMetadata {
2471                issuer: Url::parse("https://auth.example.com").unwrap(),
2472                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2473                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2474                registration_endpoint: None,
2475                scopes_supported: None,
2476                code_challenge_methods_supported: Some(vec!["S256".into()]),
2477                client_id_metadata_document_supported: true,
2478            };
2479
2480            let tokens = exchange_code(
2481                &client,
2482                &metadata,
2483                "auth_code_123",
2484                CIMD_URL,
2485                "http://127.0.0.1:9999/callback",
2486                "verifier_abc",
2487                "https://mcp.example.com",
2488            )
2489            .await
2490            .unwrap();
2491
2492            assert_eq!(tokens.access_token, "new_access_token");
2493            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2494            assert!(tokens.expires_at.is_some());
2495        });
2496    }
2497
2498    #[test]
2499    fn test_refresh_tokens_success() {
2500        smol::block_on(async {
2501            let client = make_fake_http_client(|req| {
2502                Box::pin(async move {
2503                    let uri = req.uri().to_string();
2504                    if uri.contains("/token") {
2505                        json_response(
2506                            200,
2507                            r#"{
2508                                "access_token": "refreshed_token",
2509                                "expires_in": 1800,
2510                                "token_type": "Bearer"
2511                            }"#,
2512                        )
2513                    } else {
2514                        json_response(404, "{}")
2515                    }
2516                })
2517            });
2518
2519            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2520
2521            let tokens = refresh_tokens(
2522                &client,
2523                &token_endpoint,
2524                "old_refresh_token",
2525                CIMD_URL,
2526                "https://mcp.example.com",
2527            )
2528            .await
2529            .unwrap();
2530
2531            assert_eq!(tokens.access_token, "refreshed_token");
2532            assert_eq!(tokens.refresh_token, None);
2533            assert!(tokens.expires_at.is_some());
2534        });
2535    }
2536
2537    #[test]
2538    fn test_exchange_code_failure() {
2539        smol::block_on(async {
2540            let client = make_fake_http_client(|_req| {
2541                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2542            });
2543
2544            let metadata = AuthServerMetadata {
2545                issuer: Url::parse("https://auth.example.com").unwrap(),
2546                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2547                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2548                registration_endpoint: None,
2549                scopes_supported: None,
2550                code_challenge_methods_supported: Some(vec!["S256".into()]),
2551                client_id_metadata_document_supported: true,
2552            };
2553
2554            let result = exchange_code(
2555                &client,
2556                &metadata,
2557                "bad_code",
2558                "client",
2559                "http://127.0.0.1:1/callback",
2560                "verifier",
2561                "https://mcp.example.com",
2562            )
2563            .await;
2564
2565            assert!(result.is_err());
2566            assert!(result.unwrap_err().to_string().contains("400"));
2567        });
2568    }
2569
2570    // -- DCR integration tests -----------------------------------------------
2571
2572    #[test]
2573    fn test_perform_dcr() {
2574        smol::block_on(async {
2575            let client = make_fake_http_client(|_req| {
2576                Box::pin(async move {
2577                    json_response(
2578                        201,
2579                        r#"{
2580                            "client_id": "dynamic-client-001",
2581                            "client_secret": "dynamic-secret-001"
2582                        }"#,
2583                    )
2584                })
2585            });
2586
2587            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2588            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2589                .await
2590                .unwrap();
2591
2592            assert_eq!(registration.client_id, "dynamic-client-001");
2593            assert_eq!(
2594                registration.client_secret.as_deref(),
2595                Some("dynamic-secret-001")
2596            );
2597        });
2598    }
2599
2600    #[test]
2601    fn test_perform_dcr_failure() {
2602        smol::block_on(async {
2603            let client = make_fake_http_client(|_req| {
2604                Box::pin(
2605                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2606                )
2607            });
2608
2609            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2610            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2611
2612            assert!(result.is_err());
2613            assert!(result.unwrap_err().to_string().contains("403"));
2614        });
2615    }
2616
2617    // -- OAuthCallback parse tests -------------------------------------------
2618
2619    #[test]
2620    fn test_oauth_callback_parse_query() {
2621        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2622        assert_eq!(callback.code, "test_auth_code");
2623        assert_eq!(callback.state, "test_state");
2624    }
2625
2626    #[test]
2627    fn test_oauth_callback_parse_query_reversed_order() {
2628        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2629        assert_eq!(callback.code, "test_auth_code");
2630        assert_eq!(callback.state, "test_state");
2631    }
2632
2633    #[test]
2634    fn test_oauth_callback_parse_query_with_extra_params() {
2635        let callback =
2636            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2637                .unwrap();
2638        assert_eq!(callback.code, "test_auth_code");
2639        assert_eq!(callback.state, "test_state");
2640    }
2641
2642    #[test]
2643    fn test_oauth_callback_parse_query_missing_code() {
2644        let result = OAuthCallback::parse_query("state=test_state");
2645        assert!(result.is_err());
2646        assert!(result.unwrap_err().to_string().contains("code"));
2647    }
2648
2649    #[test]
2650    fn test_oauth_callback_parse_query_missing_state() {
2651        let result = OAuthCallback::parse_query("code=test_auth_code");
2652        assert!(result.is_err());
2653        assert!(result.unwrap_err().to_string().contains("state"));
2654    }
2655
2656    #[test]
2657    fn test_oauth_callback_parse_query_empty_code() {
2658        let result = OAuthCallback::parse_query("code=&state=test_state");
2659        assert!(result.is_err());
2660    }
2661
2662    #[test]
2663    fn test_oauth_callback_parse_query_empty_state() {
2664        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2665        assert!(result.is_err());
2666    }
2667
2668    #[test]
2669    fn test_oauth_callback_parse_query_url_encoded_values() {
2670        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2671        assert_eq!(callback.code, "abc def");
2672        assert_eq!(callback.state, "test=state");
2673    }
2674
2675    #[test]
2676    fn test_oauth_callback_parse_query_error_response() {
2677        let result = OAuthCallback::parse_query(
2678            "error=access_denied&error_description=User%20denied%20access&state=abc",
2679        );
2680        assert!(result.is_err());
2681        let err_msg = result.unwrap_err().to_string();
2682        assert!(
2683            err_msg.contains("access_denied"),
2684            "unexpected error: {}",
2685            err_msg
2686        );
2687        assert!(
2688            err_msg.contains("User denied access"),
2689            "unexpected error: {}",
2690            err_msg
2691        );
2692    }
2693
2694    #[test]
2695    fn test_oauth_callback_parse_query_error_without_description() {
2696        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2697        assert!(result.is_err());
2698        let err_msg = result.unwrap_err().to_string();
2699        assert!(
2700            err_msg.contains("server_error"),
2701            "unexpected error: {}",
2702            err_msg
2703        );
2704        assert!(
2705            err_msg.contains("no description"),
2706            "unexpected error: {}",
2707            err_msg
2708        );
2709    }
2710
2711    // -- McpOAuthTokenProvider tests -----------------------------------------
2712
2713    fn make_test_session(
2714        access_token: &str,
2715        refresh_token: Option<&str>,
2716        expires_at: Option<SystemTime>,
2717    ) -> OAuthSession {
2718        OAuthSession {
2719            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2720            resource: Url::parse("https://mcp.example.com").unwrap(),
2721            client_registration: OAuthClientRegistration {
2722                client_id: "test-client".into(),
2723                client_secret: None,
2724            },
2725            tokens: OAuthTokens {
2726                access_token: access_token.into(),
2727                refresh_token: refresh_token.map(String::from),
2728                expires_at,
2729            },
2730        }
2731    }
2732
2733    #[test]
2734    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2735        let expired = SystemTime::now() - Duration::from_secs(60);
2736        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2737        let provider = McpOAuthTokenProvider::new(
2738            session,
2739            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2740            None,
2741        );
2742
2743        assert_eq!(provider.access_token(), None);
2744    }
2745
2746    #[test]
2747    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2748        let far_future = SystemTime::now() + Duration::from_secs(3600);
2749        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2750        let provider = McpOAuthTokenProvider::new(
2751            session,
2752            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2753            None,
2754        );
2755
2756        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2757    }
2758
2759    #[test]
2760    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2761        let session = make_test_session("no-expiry-token", Some("rt"), None);
2762        let provider = McpOAuthTokenProvider::new(
2763            session,
2764            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2765            None,
2766        );
2767
2768        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2769    }
2770
2771    #[test]
2772    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2773        smol::block_on(async {
2774            let session = make_test_session("token", None, None);
2775            let provider = McpOAuthTokenProvider::new(
2776                session,
2777                make_fake_http_client(|_| {
2778                    Box::pin(async { unreachable!("no HTTP call expected") })
2779                }),
2780                None,
2781            );
2782
2783            let refreshed = provider.try_refresh().await.unwrap();
2784            assert!(!refreshed);
2785        });
2786    }
2787
2788    #[test]
2789    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2790        smol::block_on(async {
2791            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2792            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2793
2794            let http_client = make_fake_http_client(|_req| {
2795                Box::pin(async {
2796                    json_response(
2797                        200,
2798                        r#"{
2799                            "access_token": "new-access",
2800                            "refresh_token": "new-refresh",
2801                            "expires_in": 1800
2802                        }"#,
2803                    )
2804                })
2805            });
2806
2807            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2808
2809            let refreshed = provider.try_refresh().await.unwrap();
2810            assert!(refreshed);
2811            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2812
2813            let notified_session = rx.try_recv().expect("channel should have a session");
2814            assert_eq!(notified_session.tokens.access_token, "new-access");
2815            assert_eq!(
2816                notified_session.tokens.refresh_token.as_deref(),
2817                Some("new-refresh")
2818            );
2819        });
2820    }
2821
2822    #[test]
2823    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2824        smol::block_on(async {
2825            let session = make_test_session("old-access", Some("original-refresh"), None);
2826            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2827
2828            let http_client = make_fake_http_client(|_req| {
2829                Box::pin(async {
2830                    json_response(
2831                        200,
2832                        r#"{
2833                            "access_token": "new-access",
2834                            "expires_in": 900
2835                        }"#,
2836                    )
2837                })
2838            });
2839
2840            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2841
2842            let refreshed = provider.try_refresh().await.unwrap();
2843            assert!(refreshed);
2844
2845            let notified_session = rx.try_recv().expect("channel should have a session");
2846            assert_eq!(notified_session.tokens.access_token, "new-access");
2847            assert_eq!(
2848                notified_session.tokens.refresh_token.as_deref(),
2849                Some("original-refresh"),
2850            );
2851        });
2852    }
2853
2854    #[test]
2855    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2856        smol::block_on(async {
2857            let session = make_test_session("old-access", Some("my-refresh"), None);
2858
2859            let http_client = make_fake_http_client(|_req| {
2860                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2861            });
2862
2863            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2864
2865            let refreshed = provider.try_refresh().await.unwrap();
2866            assert!(!refreshed);
2867            // The old token should still be in place.
2868            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2869        });
2870    }
2871}