oauth.rs

   1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
   2//! PKCE flow, per the MCP spec's OAuth profile.
   3//!
   4//! The flow is split into two phases:
   5//!
   6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
   7//!    Authorization Server Metadata. This can happen early (e.g. on a 401
   8//!    during server startup) because it doesn't need the redirect URI yet.
   9//!
  10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
  11//!    because DCR requires the actual loopback redirect URI, which includes an
  12//!    ephemeral port that only exists once the callback server has started.
  13//!
  14//! After authentication, the full state is captured in [`OAuthSession`] which
  15//! is persisted to the keychain. On next startup, the stored session feeds
  16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
  17//! without requiring another browser flow.
  18
  19use anyhow::{Context as _, Result, anyhow, bail};
  20use async_trait::async_trait;
  21use base64::Engine as _;
  22use futures::AsyncReadExt as _;
  23use futures::channel::mpsc;
  24use http_client::{AsyncBody, HttpClient, Request};
  25use parking_lot::Mutex as SyncMutex;
  26use rand::Rng as _;
  27use serde::{Deserialize, Serialize};
  28use sha2::{Digest, Sha256};
  29
  30use std::str::FromStr;
  31use std::sync::Arc;
  32use std::time::{Duration, SystemTime};
  33use url::Url;
  34use util::ResultExt as _;
  35
  36/// The CIMD URL where Zed's OAuth client metadata document is hosted.
  37pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
  38
  39/// Validate that a URL is safe to use as an OAuth endpoint.
  40///
  41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
  42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
  43/// loopback addresses, per RFC 8252 Section 8.3.
  44fn require_https_or_loopback(url: &Url) -> Result<()> {
  45    if url.scheme() == "https" {
  46        return Ok(());
  47    }
  48    if url.scheme() == "http" {
  49        if let Some(host) = url.host() {
  50            match host {
  51                url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
  52                url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
  53                url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
  54                _ => {}
  55            }
  56        }
  57    }
  58    bail!(
  59        "OAuth endpoint must use HTTPS (got {}://{})",
  60        url.scheme(),
  61        url.host_str().unwrap_or("?")
  62    )
  63}
  64
  65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
  66/// protections against private/reserved IP ranges.
  67///
  68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
  69/// an attacker-controlled MCP server from directing Zed to fetch internal
  70/// network resources via metadata URLs.
  71///
  72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
  73/// blocked here — full mitigation requires resolver-level validation (e.g. a
  74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
  75fn validate_oauth_url(url: &Url) -> Result<()> {
  76    require_https_or_loopback(url)?;
  77
  78    if let Some(host) = url.host() {
  79        match host {
  80            url::Host::Ipv4(ip) => {
  81                // Loopback is already allowed by require_https_or_loopback.
  82                if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
  83                {
  84                    bail!(
  85                        "OAuth endpoint must not point to private/reserved IP: {}",
  86                        ip
  87                    );
  88                }
  89            }
  90            url::Host::Ipv6(ip) => {
  91                // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
  92                // could bypass the IPv4 checks above.
  93                if let Some(mapped_v4) = ip.to_ipv4_mapped() {
  94                    if mapped_v4.is_private()
  95                        || mapped_v4.is_link_local()
  96                        || mapped_v4.is_broadcast()
  97                        || mapped_v4.is_unspecified()
  98                    {
  99                        bail!(
 100                            "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
 101                            mapped_v4
 102                        );
 103                    }
 104                }
 105
 106                if ip.is_unspecified() || ip.is_multicast() {
 107                    bail!(
 108                        "OAuth endpoint must not point to reserved IPv6 address: {}",
 109                        ip
 110                    );
 111                }
 112                // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
 113                // nightly-only, so check the prefix manually.
 114                if (ip.segments()[0] & 0xfe00) == 0xfc00 {
 115                    bail!(
 116                        "OAuth endpoint must not point to IPv6 unique-local address: {}",
 117                        ip
 118                    );
 119                }
 120            }
 121            url::Host::Domain(_) => {
 122                // Domain-based SSRF prevention requires resolver-level checks.
 123                // See known limitation in the doc comment above.
 124            }
 125        }
 126    }
 127
 128    Ok(())
 129}
 130
 131/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
 132/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
 133#[derive(Debug, Clone, Serialize, Deserialize)]
 134pub struct ProtectedResourceMetadata {
 135    pub resource: Url,
 136    pub authorization_servers: Vec<Url>,
 137    pub scopes_supported: Option<Vec<String>>,
 138}
 139
 140/// Parsed from the authorization server's .well-known endpoint
 141/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
 142#[derive(Debug, Clone, Serialize, Deserialize)]
 143pub struct AuthServerMetadata {
 144    pub issuer: Url,
 145    pub authorization_endpoint: Url,
 146    pub token_endpoint: Url,
 147    pub registration_endpoint: Option<Url>,
 148    pub scopes_supported: Option<Vec<String>>,
 149    pub grant_types_supported: Option<Vec<String>>,
 150    pub code_challenge_methods_supported: Option<Vec<String>>,
 151    pub client_id_metadata_document_supported: bool,
 152}
 153
 154/// The result of client registration — either CIMD or DCR.
 155#[derive(Clone, Serialize, Deserialize)]
 156pub struct OAuthClientRegistration {
 157    pub client_id: String,
 158    /// Only present for DCR-minted registrations.
 159    pub client_secret: Option<String>,
 160}
 161
 162impl std::fmt::Debug for OAuthClientRegistration {
 163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 164        f.debug_struct("OAuthClientRegistration")
 165            .field("client_id", &self.client_id)
 166            .field(
 167                "client_secret",
 168                &self.client_secret.as_ref().map(|_| "[redacted]"),
 169            )
 170            .finish()
 171    }
 172}
 173
 174/// Access and refresh tokens obtained from the token endpoint.
 175#[derive(Clone, Serialize, Deserialize)]
 176pub struct OAuthTokens {
 177    pub access_token: String,
 178    pub refresh_token: Option<String>,
 179    pub expires_at: Option<SystemTime>,
 180}
 181
 182impl std::fmt::Debug for OAuthTokens {
 183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 184        f.debug_struct("OAuthTokens")
 185            .field("access_token", &"[redacted]")
 186            .field(
 187                "refresh_token",
 188                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 189            )
 190            .field("expires_at", &self.expires_at)
 191            .finish()
 192    }
 193}
 194
 195/// Everything discovered before the browser flow starts. Client registration is
 196/// resolved separately, once the real redirect URI is known.
 197#[derive(Debug, Clone, Serialize, Deserialize)]
 198pub struct OAuthDiscovery {
 199    pub resource_metadata: ProtectedResourceMetadata,
 200    pub auth_server_metadata: AuthServerMetadata,
 201    pub scopes: Vec<String>,
 202}
 203
 204/// The persisted OAuth session for a context server.
 205///
 206/// Stored in the keychain so startup can restore a refresh-capable provider
 207/// without another browser flow. Deliberately excludes the full discovery
 208/// metadata to keep the serialized size well within keychain item limits.
 209#[derive(Clone, Serialize, Deserialize)]
 210pub struct OAuthSession {
 211    pub token_endpoint: Url,
 212    pub resource: Url,
 213    pub client_registration: OAuthClientRegistration,
 214    pub tokens: OAuthTokens,
 215}
 216
 217impl std::fmt::Debug for OAuthSession {
 218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 219        f.debug_struct("OAuthSession")
 220            .field("token_endpoint", &self.token_endpoint)
 221            .field("resource", &self.resource)
 222            .field("client_registration", &self.client_registration)
 223            .field("tokens", &self.tokens)
 224            .finish()
 225    }
 226}
 227
 228/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
 229#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 230pub enum BearerError {
 231    /// The request is missing a required parameter, includes an unsupported
 232    /// parameter or parameter value, or is otherwise malformed.
 233    InvalidRequest,
 234    /// The access token provided is expired, revoked, malformed, or invalid.
 235    InvalidToken,
 236    /// The request requires higher privileges than provided by the access token.
 237    InsufficientScope,
 238    /// An unrecognized error code (extension or future spec addition).
 239    Other,
 240}
 241
 242impl BearerError {
 243    fn parse(value: &str) -> Self {
 244        match value {
 245            "invalid_request" => BearerError::InvalidRequest,
 246            "invalid_token" => BearerError::InvalidToken,
 247            "insufficient_scope" => BearerError::InsufficientScope,
 248            _ => BearerError::Other,
 249        }
 250    }
 251}
 252
 253/// Fields extracted from a `WWW-Authenticate: Bearer` header.
 254///
 255/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
 256/// at the Protected Resource Metadata document. The optional `scope` parameter
 257/// (RFC 6750 Section 3) indicates scopes required for the request.
 258#[derive(Debug, Clone, PartialEq, Eq)]
 259pub struct WwwAuthenticate {
 260    pub resource_metadata: Option<Url>,
 261    pub scope: Option<Vec<String>>,
 262    /// The parsed `error` parameter per RFC 6750 Section 3.1.
 263    pub error: Option<BearerError>,
 264    pub error_description: Option<String>,
 265}
 266
 267/// Parse a `WWW-Authenticate` header value.
 268///
 269/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
 270/// Per RFC 6750 and RFC 9728, the relevant parameters are:
 271/// - `resource_metadata` — URL of the Protected Resource Metadata document
 272/// - `scope` — space-separated list of required scopes
 273/// - `error` — error code (e.g. "insufficient_scope")
 274/// - `error_description` — human-readable error description
 275pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
 276    let header = header.trim();
 277
 278    let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
 279        header[6..].trim()
 280    } else {
 281        bail!("WWW-Authenticate header does not use Bearer scheme");
 282    };
 283
 284    if params_str.is_empty() {
 285        return Ok(WwwAuthenticate {
 286            resource_metadata: None,
 287            scope: None,
 288            error: None,
 289            error_description: None,
 290        });
 291    }
 292
 293    let params = parse_auth_params(params_str);
 294
 295    let resource_metadata = params
 296        .get("resource_metadata")
 297        .map(|v| Url::parse(v))
 298        .transpose()
 299        .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
 300
 301    let scope = params
 302        .get("scope")
 303        .map(|v| v.split_whitespace().map(String::from).collect());
 304
 305    let error = params.get("error").map(|v| BearerError::parse(v));
 306    let error_description = params.get("error_description").cloned();
 307
 308    Ok(WwwAuthenticate {
 309        resource_metadata,
 310        scope,
 311        error,
 312        error_description,
 313    })
 314}
 315
 316/// Parse comma-separated `key="value"` or `key=token` parameters from an
 317/// auth-param list (RFC 7235 Section 2.1).
 318fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
 319    let mut params = collections::HashMap::default();
 320    let mut remaining = input.trim();
 321
 322    while !remaining.is_empty() {
 323        // Skip leading whitespace and commas.
 324        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
 325        if remaining.is_empty() {
 326            break;
 327        }
 328
 329        // Find the key (everything before '=').
 330        let eq_pos = match remaining.find('=') {
 331            Some(pos) => pos,
 332            None => break,
 333        };
 334
 335        let key = remaining[..eq_pos].trim().to_lowercase();
 336        remaining = &remaining[eq_pos + 1..];
 337        remaining = remaining.trim_start();
 338
 339        // Parse the value: either quoted or unquoted (token).
 340        let value;
 341        if remaining.starts_with('"') {
 342            // Quoted string: find the closing quote, handling escaped chars.
 343            remaining = &remaining[1..]; // skip opening quote
 344            let mut val = String::new();
 345            let mut chars = remaining.char_indices();
 346            loop {
 347                match chars.next() {
 348                    Some((_, '\\')) => {
 349                        // Escaped character — take the next char literally.
 350                        if let Some((_, c)) = chars.next() {
 351                            val.push(c);
 352                        }
 353                    }
 354                    Some((i, '"')) => {
 355                        remaining = &remaining[i + 1..];
 356                        break;
 357                    }
 358                    Some((_, c)) => val.push(c),
 359                    None => {
 360                        remaining = "";
 361                        break;
 362                    }
 363                }
 364            }
 365            value = val;
 366        } else {
 367            // Unquoted token: read until comma or whitespace.
 368            let end = remaining
 369                .find(|c: char| c == ',' || c.is_whitespace())
 370                .unwrap_or(remaining.len());
 371            value = remaining[..end].to_string();
 372            remaining = &remaining[end..];
 373        }
 374
 375        if !key.is_empty() {
 376            params.insert(key, value);
 377        }
 378    }
 379
 380    params
 381}
 382
 383/// Construct the well-known Protected Resource Metadata URIs for a given MCP
 384/// server URL, per RFC 9728 Section 3.
 385///
 386/// Returns URIs in priority order:
 387/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
 388/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
 389pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
 390    let mut urls = Vec::new();
 391    let base = format!("{}://{}", server_url.scheme(), server_url.authority());
 392
 393    let path = server_url.path().trim_start_matches('/');
 394    if !path.is_empty() {
 395        if let Ok(url) = Url::parse(&format!(
 396            "{}/.well-known/oauth-protected-resource/{}",
 397            base, path
 398        )) {
 399            urls.push(url);
 400        }
 401    }
 402
 403    if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
 404        urls.push(url);
 405    }
 406
 407    urls
 408}
 409
 410/// Construct the well-known Authorization Server Metadata URIs for a given
 411/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
 412///
 413/// Returns URIs in priority order, which differs depending on whether the
 414/// issuer URL has a path component.
 415pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
 416    let mut urls = Vec::new();
 417    let base = format!("{}://{}", issuer.scheme(), issuer.authority());
 418    let path = issuer.path().trim_matches('/');
 419
 420    if !path.is_empty() {
 421        // Issuer with path: try path-inserted variants first.
 422        if let Ok(url) = Url::parse(&format!(
 423            "{}/.well-known/oauth-authorization-server/{}",
 424            base, path
 425        )) {
 426            urls.push(url);
 427        }
 428        if let Ok(url) = Url::parse(&format!(
 429            "{}/.well-known/openid-configuration/{}",
 430            base, path
 431        )) {
 432            urls.push(url);
 433        }
 434        if let Ok(url) = Url::parse(&format!(
 435            "{}/{}/.well-known/openid-configuration",
 436            base, path
 437        )) {
 438            urls.push(url);
 439        }
 440    } else {
 441        // No path: standard well-known locations.
 442        if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
 443            urls.push(url);
 444        }
 445        if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
 446            urls.push(url);
 447        }
 448    }
 449
 450    urls
 451}
 452
 453// -- Canonical server URI (RFC 8707) -----------------------------------------
 454
 455/// Derive the canonical resource URI for an MCP server URL, suitable for the
 456/// `resource` parameter in authorization and token requests per RFC 8707.
 457///
 458/// Lowercases the scheme and host, preserves the path (without trailing slash),
 459/// strips fragments and query strings.
 460pub fn canonical_server_uri(server_url: &Url) -> String {
 461    let mut uri = format!(
 462        "{}://{}",
 463        server_url.scheme().to_ascii_lowercase(),
 464        server_url.host_str().unwrap_or("").to_ascii_lowercase(),
 465    );
 466    if let Some(port) = server_url.port() {
 467        uri.push_str(&format!(":{}", port));
 468    }
 469    let path = server_url.path();
 470    if path != "/" {
 471        uri.push_str(path.trim_end_matches('/'));
 472    }
 473    uri
 474}
 475
 476// -- Scope selection ---------------------------------------------------------
 477
 478/// Select scopes following the MCP spec's Scope Selection Strategy:
 479/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
 480/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
 481/// 3. Return empty if neither is available.
 482pub fn select_scopes(
 483    www_authenticate: &WwwAuthenticate,
 484    resource_metadata: &ProtectedResourceMetadata,
 485) -> Vec<String> {
 486    if let Some(ref scopes) = www_authenticate.scope {
 487        if !scopes.is_empty() {
 488            return scopes.clone();
 489        }
 490    }
 491    resource_metadata
 492        .scopes_supported
 493        .clone()
 494        .unwrap_or_default()
 495}
 496
 497// -- Client registration strategy --------------------------------------------
 498
 499/// The registration approach to use, determined from auth server metadata.
 500#[derive(Debug, Clone, PartialEq, Eq)]
 501pub enum ClientRegistrationStrategy {
 502    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
 503    Cimd { client_id: String },
 504    /// The auth server has a registration endpoint. Caller must POST to it.
 505    Dcr { registration_endpoint: Url },
 506    /// No supported registration mechanism.
 507    Unavailable,
 508}
 509
 510/// Determine how to register with the authorization server, following the
 511/// spec's recommended priority: CIMD first, DCR fallback.
 512pub fn determine_registration_strategy(
 513    auth_server_metadata: &AuthServerMetadata,
 514) -> ClientRegistrationStrategy {
 515    if auth_server_metadata.client_id_metadata_document_supported {
 516        ClientRegistrationStrategy::Cimd {
 517            client_id: CIMD_URL.to_string(),
 518        }
 519    } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
 520        ClientRegistrationStrategy::Dcr {
 521            registration_endpoint: endpoint.clone(),
 522        }
 523    } else {
 524        ClientRegistrationStrategy::Unavailable
 525    }
 526}
 527
 528// -- PKCE (RFC 7636) ---------------------------------------------------------
 529
 530/// A PKCE code verifier and its S256 challenge.
 531#[derive(Clone)]
 532pub struct PkceChallenge {
 533    pub verifier: String,
 534    pub challenge: String,
 535}
 536
 537impl std::fmt::Debug for PkceChallenge {
 538    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 539        f.debug_struct("PkceChallenge")
 540            .field("verifier", &"[redacted]")
 541            .field("challenge", &self.challenge)
 542            .finish()
 543    }
 544}
 545
 546/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
 547///
 548/// The verifier is 43 base64url characters derived from 32 random bytes.
 549/// The challenge is `BASE64URL(SHA256(verifier))`.
 550pub fn generate_pkce_challenge() -> PkceChallenge {
 551    let mut random_bytes = [0u8; 32];
 552    rand::rng().fill(&mut random_bytes);
 553    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
 554    let verifier = engine.encode(&random_bytes);
 555
 556    let digest = Sha256::digest(verifier.as_bytes());
 557    let challenge = engine.encode(digest);
 558
 559    PkceChallenge {
 560        verifier,
 561        challenge,
 562    }
 563}
 564
 565// -- Authorization URL construction ------------------------------------------
 566
 567/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
 568pub fn build_authorization_url(
 569    auth_server_metadata: &AuthServerMetadata,
 570    client_id: &str,
 571    redirect_uri: &str,
 572    scopes: &[String],
 573    resource: &str,
 574    pkce: &PkceChallenge,
 575    state: &str,
 576) -> Url {
 577    let mut url = auth_server_metadata.authorization_endpoint.clone();
 578    {
 579        let mut query = url.query_pairs_mut();
 580        query.append_pair("response_type", "code");
 581        query.append_pair("client_id", client_id);
 582        query.append_pair("redirect_uri", redirect_uri);
 583        if !scopes.is_empty() {
 584            query.append_pair("scope", &scopes.join(" "));
 585        }
 586        query.append_pair("resource", resource);
 587        query.append_pair("code_challenge", &pkce.challenge);
 588        query.append_pair("code_challenge_method", "S256");
 589        query.append_pair("state", state);
 590    }
 591    url
 592}
 593
 594// -- Token endpoint request bodies -------------------------------------------
 595
 596/// The JSON body returned by the token endpoint on success.
 597#[derive(Deserialize)]
 598pub struct TokenResponse {
 599    pub access_token: String,
 600    #[serde(default)]
 601    pub refresh_token: Option<String>,
 602    #[serde(default)]
 603    pub expires_in: Option<u64>,
 604    #[serde(default)]
 605    pub token_type: Option<String>,
 606}
 607
 608impl std::fmt::Debug for TokenResponse {
 609    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 610        f.debug_struct("TokenResponse")
 611            .field("access_token", &"[redacted]")
 612            .field(
 613                "refresh_token",
 614                &self.refresh_token.as_ref().map(|_| "[redacted]"),
 615            )
 616            .field("expires_in", &self.expires_in)
 617            .field("token_type", &self.token_type)
 618            .finish()
 619    }
 620}
 621
 622impl TokenResponse {
 623    /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
 624    pub fn into_tokens(self) -> OAuthTokens {
 625        let expires_at = self
 626            .expires_in
 627            .map(|secs| SystemTime::now() + Duration::from_secs(secs));
 628        OAuthTokens {
 629            access_token: self.access_token,
 630            refresh_token: self.refresh_token,
 631            expires_at,
 632        }
 633    }
 634}
 635
 636/// Build the form-encoded body for an authorization code token exchange.
 637pub fn token_exchange_params(
 638    code: &str,
 639    client_id: &str,
 640    redirect_uri: &str,
 641    code_verifier: &str,
 642    resource: &str,
 643) -> Vec<(&'static str, String)> {
 644    vec![
 645        ("grant_type", "authorization_code".to_string()),
 646        ("code", code.to_string()),
 647        ("redirect_uri", redirect_uri.to_string()),
 648        ("client_id", client_id.to_string()),
 649        ("code_verifier", code_verifier.to_string()),
 650        ("resource", resource.to_string()),
 651    ]
 652}
 653
 654/// Build the form-encoded body for a token refresh request.
 655pub fn token_refresh_params(
 656    refresh_token: &str,
 657    client_id: &str,
 658    resource: &str,
 659) -> Vec<(&'static str, String)> {
 660    vec![
 661        ("grant_type", "refresh_token".to_string()),
 662        ("refresh_token", refresh_token.to_string()),
 663        ("client_id", client_id.to_string()),
 664        ("resource", resource.to_string()),
 665    ]
 666}
 667
 668// -- DCR request body (RFC 7591) ---------------------------------------------
 669
 670/// Build the JSON body for a Dynamic Client Registration request.
 671///
 672/// The `redirect_uri` should be the actual loopback URI with the ephemeral
 673/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 674/// redirect URI matching even for loopback addresses, so we register the
 675/// exact URI we intend to use.
 676/// The grant types Zed can use. Intersected with the server's
 677/// `grant_types_supported` to build the DCR request.
 678const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"];
 679
 680pub fn dcr_registration_body(
 681    redirect_uri: &str,
 682    server_grant_types: Option<&[String]>,
 683) -> serde_json::Value {
 684    // Use the intersection of what we support and what the server advertises.
 685    // When the server doesn't advertise grant_types_supported, send all of
 686    // ours — the server will reject what it doesn't like.
 687    let grant_types: Vec<&str> = match server_grant_types {
 688        Some(server) => SUPPORTED_GRANT_TYPES
 689            .iter()
 690            .copied()
 691            .filter(|gt| server.iter().any(|s| s == *gt))
 692            .collect(),
 693        None => SUPPORTED_GRANT_TYPES.to_vec(),
 694    };
 695
 696    serde_json::json!({
 697        "client_name": "Zed",
 698        "redirect_uris": [redirect_uri],
 699        "grant_types": grant_types,
 700        "response_types": ["code"],
 701        "token_endpoint_auth_method": "none"
 702    })
 703}
 704
 705// -- Discovery (async, hits real endpoints) ----------------------------------
 706
 707/// Fetch Protected Resource Metadata from the MCP server.
 708///
 709/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
 710/// then falls back to well-known URIs constructed from `server_url`.
 711pub async fn fetch_protected_resource_metadata(
 712    http_client: &Arc<dyn HttpClient>,
 713    server_url: &Url,
 714    www_authenticate: &WwwAuthenticate,
 715) -> Result<ProtectedResourceMetadata> {
 716    let candidate_urls = match &www_authenticate.resource_metadata {
 717        Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
 718        Some(url) => {
 719            log::warn!(
 720                "Ignoring cross-origin resource_metadata URL {} \
 721                 (server origin: {})",
 722                url,
 723                server_url.origin().unicode_serialization()
 724            );
 725            protected_resource_metadata_urls(server_url)
 726        }
 727        None => protected_resource_metadata_urls(server_url),
 728    };
 729
 730    for url in &candidate_urls {
 731        match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
 732            Ok(response) => {
 733                if response.authorization_servers.is_empty() {
 734                    bail!(
 735                        "Protected Resource Metadata at {} has no authorization_servers",
 736                        url
 737                    );
 738                }
 739                return Ok(ProtectedResourceMetadata {
 740                    resource: response.resource.unwrap_or_else(|| server_url.clone()),
 741                    authorization_servers: response.authorization_servers,
 742                    scopes_supported: response.scopes_supported,
 743                });
 744            }
 745            Err(err) => {
 746                log::debug!(
 747                    "Failed to fetch Protected Resource Metadata from {}: {}",
 748                    url,
 749                    err
 750                );
 751            }
 752        }
 753    }
 754
 755    bail!(
 756        "Could not fetch Protected Resource Metadata for {}",
 757        server_url
 758    )
 759}
 760
 761/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
 762/// endpoints in the priority order specified by the MCP spec.
 763pub async fn fetch_auth_server_metadata(
 764    http_client: &Arc<dyn HttpClient>,
 765    issuer: &Url,
 766) -> Result<AuthServerMetadata> {
 767    let candidate_urls = auth_server_metadata_urls(issuer);
 768
 769    for url in &candidate_urls {
 770        match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
 771            Ok(response) => {
 772                let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
 773                if reported_issuer != *issuer {
 774                    bail!(
 775                        "Auth server metadata issuer mismatch: expected {}, got {}",
 776                        issuer,
 777                        reported_issuer
 778                    );
 779                }
 780
 781                return Ok(AuthServerMetadata {
 782                    issuer: reported_issuer,
 783                    grant_types_supported: response.grant_types_supported,
 784                    authorization_endpoint: response
 785                        .authorization_endpoint
 786                        .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
 787                    token_endpoint: response
 788                        .token_endpoint
 789                        .ok_or_else(|| anyhow!("missing token_endpoint"))?,
 790                    registration_endpoint: response.registration_endpoint,
 791                    scopes_supported: response.scopes_supported,
 792                    code_challenge_methods_supported: response.code_challenge_methods_supported,
 793                    client_id_metadata_document_supported: response
 794                        .client_id_metadata_document_supported
 795                        .unwrap_or(false),
 796                });
 797            }
 798            Err(err) => {
 799                log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
 800            }
 801        }
 802    }
 803
 804    bail!(
 805        "Could not fetch Authorization Server Metadata for {}",
 806        issuer
 807    )
 808}
 809
 810/// Run the full discovery flow: fetch resource metadata, then auth server
 811/// metadata, then select scopes. Client registration is resolved separately,
 812/// once the real redirect URI is known.
 813pub async fn discover(
 814    http_client: &Arc<dyn HttpClient>,
 815    server_url: &Url,
 816    www_authenticate: &WwwAuthenticate,
 817) -> Result<OAuthDiscovery> {
 818    let resource_metadata =
 819        fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
 820
 821    let auth_server_url = resource_metadata
 822        .authorization_servers
 823        .first()
 824        .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
 825
 826    let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
 827
 828    // Verify PKCE S256 support (spec requirement).
 829    match &auth_server_metadata.code_challenge_methods_supported {
 830        Some(methods) if methods.iter().any(|m| m == "S256") => {}
 831        Some(_) => bail!("authorization server does not support S256 PKCE"),
 832        None => bail!("authorization server does not advertise code_challenge_methods_supported"),
 833    }
 834
 835    // Verify there is at least one supported registration strategy before we
 836    // present the server as ready to authenticate.
 837    match determine_registration_strategy(&auth_server_metadata) {
 838        ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
 839        ClientRegistrationStrategy::Unavailable => {
 840            bail!("authorization server supports neither CIMD nor DCR")
 841        }
 842    }
 843
 844    let scopes = select_scopes(www_authenticate, &resource_metadata);
 845
 846    Ok(OAuthDiscovery {
 847        resource_metadata,
 848        auth_server_metadata,
 849        scopes,
 850    })
 851}
 852
 853/// Resolve the OAuth client registration for an authorization flow.
 854///
 855/// CIMD uses the static client metadata document directly. For DCR, a fresh
 856/// registration is performed each time because the loopback redirect URI
 857/// includes an ephemeral port that changes every flow.
 858pub async fn resolve_client_registration(
 859    http_client: &Arc<dyn HttpClient>,
 860    discovery: &OAuthDiscovery,
 861    redirect_uri: &str,
 862) -> Result<OAuthClientRegistration> {
 863    match determine_registration_strategy(&discovery.auth_server_metadata) {
 864        ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
 865            client_id,
 866            client_secret: None,
 867        }),
 868        ClientRegistrationStrategy::Dcr {
 869            registration_endpoint,
 870        } => {
 871            perform_dcr(
 872                http_client,
 873                &registration_endpoint,
 874                redirect_uri,
 875                discovery
 876                    .auth_server_metadata
 877                    .grant_types_supported
 878                    .as_deref(),
 879            )
 880            .await
 881        }
 882        ClientRegistrationStrategy::Unavailable => {
 883            bail!("authorization server supports neither CIMD nor DCR")
 884        }
 885    }
 886}
 887
 888// -- Dynamic Client Registration (RFC 7591) ----------------------------------
 889
 890/// Perform Dynamic Client Registration with the authorization server.
 891pub async fn perform_dcr(
 892    http_client: &Arc<dyn HttpClient>,
 893    registration_endpoint: &Url,
 894    redirect_uri: &str,
 895    server_grant_types: Option<&[String]>,
 896) -> Result<OAuthClientRegistration> {
 897    validate_oauth_url(registration_endpoint)?;
 898
 899    let body = dcr_registration_body(redirect_uri, server_grant_types);
 900    let body_bytes = serde_json::to_vec(&body)?;
 901
 902    let request = Request::builder()
 903        .method(http_client::http::Method::POST)
 904        .uri(registration_endpoint.as_str())
 905        .header("Content-Type", "application/json")
 906        .header("Accept", "application/json")
 907        .body(AsyncBody::from(body_bytes))?;
 908
 909    let mut response = http_client.send(request).await?;
 910
 911    if !response.status().is_success() {
 912        let mut error_body = String::new();
 913        response.body_mut().read_to_string(&mut error_body).await?;
 914        bail!(
 915            "DCR failed with status {}: {}",
 916            response.status(),
 917            error_body
 918        );
 919    }
 920
 921    let mut response_body = String::new();
 922    response
 923        .body_mut()
 924        .read_to_string(&mut response_body)
 925        .await?;
 926
 927    let dcr_response: DcrResponse =
 928        serde_json::from_str(&response_body).context("failed to parse DCR response")?;
 929
 930    Ok(OAuthClientRegistration {
 931        client_id: dcr_response.client_id,
 932        client_secret: dcr_response.client_secret,
 933    })
 934}
 935
 936// -- Token exchange and refresh (async) --------------------------------------
 937
 938/// Exchange an authorization code for tokens at the token endpoint.
 939pub async fn exchange_code(
 940    http_client: &Arc<dyn HttpClient>,
 941    auth_server_metadata: &AuthServerMetadata,
 942    code: &str,
 943    client_id: &str,
 944    redirect_uri: &str,
 945    code_verifier: &str,
 946    resource: &str,
 947) -> Result<OAuthTokens> {
 948    let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
 949    post_token_request(http_client, &auth_server_metadata.token_endpoint, &params).await
 950}
 951
 952/// Refresh tokens using a refresh token.
 953pub async fn refresh_tokens(
 954    http_client: &Arc<dyn HttpClient>,
 955    token_endpoint: &Url,
 956    refresh_token: &str,
 957    client_id: &str,
 958    resource: &str,
 959) -> Result<OAuthTokens> {
 960    let params = token_refresh_params(refresh_token, client_id, resource);
 961    post_token_request(http_client, token_endpoint, &params).await
 962}
 963
 964/// POST form-encoded parameters to a token endpoint and parse the response.
 965async fn post_token_request(
 966    http_client: &Arc<dyn HttpClient>,
 967    token_endpoint: &Url,
 968    params: &[(&str, String)],
 969) -> Result<OAuthTokens> {
 970    validate_oauth_url(token_endpoint)?;
 971
 972    let body = url::form_urlencoded::Serializer::new(String::new())
 973        .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
 974        .finish();
 975
 976    let request = Request::builder()
 977        .method(http_client::http::Method::POST)
 978        .uri(token_endpoint.as_str())
 979        .header("Content-Type", "application/x-www-form-urlencoded")
 980        .header("Accept", "application/json")
 981        .body(AsyncBody::from(body.into_bytes()))?;
 982
 983    let mut response = http_client.send(request).await?;
 984
 985    if !response.status().is_success() {
 986        let mut error_body = String::new();
 987        response.body_mut().read_to_string(&mut error_body).await?;
 988        bail!(
 989            "token request failed with status {}: {}",
 990            response.status(),
 991            error_body
 992        );
 993    }
 994
 995    let mut response_body = String::new();
 996    response
 997        .body_mut()
 998        .read_to_string(&mut response_body)
 999        .await?;
1000
1001    let token_response: TokenResponse =
1002        serde_json::from_str(&response_body).context("failed to parse token response")?;
1003
1004    Ok(token_response.into_tokens())
1005}
1006
1007// -- Loopback HTTP callback server -------------------------------------------
1008
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Produced by `OAuthCallback::parse_query`; its `Debug` output redacts both
/// fields so they do not leak into logs.
pub struct OAuthCallback {
    /// The authorization code returned by the authorization server, to be
    /// exchanged for tokens (see `exchange_code`).
    pub code: String,
    /// The `state` value echoed back by the authorization server. This type
    /// does not compare it against anything; the caller is responsible for
    /// checking it against the value sent in the authorization request.
    pub state: String,
}
1014
1015impl std::fmt::Debug for OAuthCallback {
1016    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1017        f.debug_struct("OAuthCallback")
1018            .field("code", &"[redacted]")
1019            .field("state", &"[redacted]")
1020            .finish()
1021    }
1022}
1023
1024impl OAuthCallback {
1025    /// Parse the query string from a callback URL like
1026    /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1027    pub fn parse_query(query: &str) -> Result<Self> {
1028        let mut code: Option<String> = None;
1029        let mut state: Option<String> = None;
1030        let mut error: Option<String> = None;
1031        let mut error_description: Option<String> = None;
1032
1033        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1034            match key.as_ref() {
1035                "code" => {
1036                    if !value.is_empty() {
1037                        code = Some(value.into_owned());
1038                    }
1039                }
1040                "state" => {
1041                    if !value.is_empty() {
1042                        state = Some(value.into_owned());
1043                    }
1044                }
1045                "error" => {
1046                    if !value.is_empty() {
1047                        error = Some(value.into_owned());
1048                    }
1049                }
1050                "error_description" => {
1051                    if !value.is_empty() {
1052                        error_description = Some(value.into_owned());
1053                    }
1054                }
1055                _ => {}
1056            }
1057        }
1058
1059        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1060        // checking for missing code/state.
1061        if let Some(error_code) = error {
1062            bail!(
1063                "OAuth authorization failed: {} ({})",
1064                error_code,
1065                error_description.as_deref().unwrap_or("no description")
1066            );
1067        }
1068
1069        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1070        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1071
1072        Ok(Self { code, state })
1073    }
1074}
1075
/// Upper bound on how long the loopback callback server waits for the browser
/// to deliver the OAuth redirect. Once it elapses, the server stops and
/// releases its port.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1079
1080/// Start a loopback HTTP server to receive the OAuth authorization callback.
1081///
1082/// Binds to an ephemeral loopback port for each flow.
1083///
1084/// Returns `(redirect_uri, callback_future)`. The caller should use the
1085/// redirect URI in the authorization request, open the browser, then await
1086/// the future to receive the callback.
1087///
1088/// The server accepts exactly one request on `/callback`, validates that it
1089/// contains `code` and `state` query parameters, responds with a minimal
1090/// HTML page telling the user they can close the tab, and shuts down.
1091///
1092/// The callback server shuts down when the returned oneshot receiver is dropped
1093/// (e.g. because the authentication task was cancelled), or after a timeout
1094/// ([CALLBACK_TIMEOUT]).
1095pub async fn start_callback_server() -> Result<(
1096    String,
1097    futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1098)> {
1099    let server = tiny_http::Server::http("127.0.0.1:0")
1100        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1101    let port = server
1102        .server_addr()
1103        .to_ip()
1104        .context("server not bound to a TCP address")?
1105        .port();
1106
1107    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1108
1109    let (tx, rx) = futures::channel::oneshot::channel();
1110
1111    // `tiny_http` is blocking, so we run it on a background thread.
1112    // The `recv_timeout` loop lets us check for cancellation (the receiver
1113    // being dropped) and enforce an overall timeout.
1114    std::thread::spawn(move || {
1115        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1116
1117        loop {
1118            if tx.is_canceled() {
1119                return;
1120            }
1121            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1122            if remaining.is_zero() {
1123                return;
1124            }
1125
1126            let timeout = remaining.min(Duration::from_millis(500));
1127            let Some(request) = (match server.recv_timeout(timeout) {
1128                Ok(req) => req,
1129                Err(_) => {
1130                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1131                    return;
1132                }
1133            }) else {
1134                // Timeout with no request — loop back and check cancellation.
1135                continue;
1136            };
1137
1138            let result = handle_callback_request(&request);
1139
1140            let (status_code, body) = match &result {
1141                Ok(_) => (
1142                    200,
1143                    "<html><body><h1>Authorization successful</h1>\
1144                     <p>You can close this tab and return to Zed.</p></body></html>",
1145                ),
1146                Err(err) => {
1147                    log::error!("OAuth callback error: {}", err);
1148                    (
1149                        400,
1150                        "<html><body><h1>Authorization failed</h1>\
1151                         <p>Something went wrong. Please try again from Zed.</p></body></html>",
1152                    )
1153                }
1154            };
1155
1156            let response = tiny_http::Response::from_string(body)
1157                .with_status_code(status_code)
1158                .with_header(
1159                    tiny_http::Header::from_str("Content-Type: text/html")
1160                        .expect("failed to construct response header"),
1161                )
1162                .with_header(
1163                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1164                        .expect("failed to construct response header"),
1165                );
1166            request.respond(response).log_err();
1167
1168            let _ = tx.send(result);
1169            return;
1170        }
1171    });
1172
1173    Ok((redirect_uri, rx))
1174}
1175
1176/// Extract the `code` and `state` query parameters from an OAuth callback
1177/// request to `/callback`.
1178fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1179    let url = Url::parse(&format!("http://localhost{}", request.url()))
1180        .context("malformed callback request URL")?;
1181
1182    if url.path() != "/callback" {
1183        bail!("unexpected path in OAuth callback: {}", url.path());
1184    }
1185
1186    let query = url
1187        .query()
1188        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1189    OAuthCallback::parse_query(query)
1190}
1191
1192// -- JSON fetch helper -------------------------------------------------------
1193
1194async fn fetch_json<T: serde::de::DeserializeOwned>(
1195    http_client: &Arc<dyn HttpClient>,
1196    url: &Url,
1197) -> Result<T> {
1198    validate_oauth_url(url)?;
1199
1200    let request = Request::builder()
1201        .method(http_client::http::Method::GET)
1202        .uri(url.as_str())
1203        .header("Accept", "application/json")
1204        .body(AsyncBody::default())?;
1205
1206    let mut response = http_client.send(request).await?;
1207
1208    if !response.status().is_success() {
1209        bail!("HTTP {} fetching {}", response.status(), url);
1210    }
1211
1212    let mut body = String::new();
1213    response.body_mut().read_to_string(&mut body).await?;
1214    serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1215}
1216
// -- Serde response types for discovery and registration ---------------------
1218
/// Raw deserialization target for a Protected Resource Metadata document.
/// NOTE(review): presumably consumed by `fetch_protected_resource_metadata`,
/// which is outside this view — confirm.
///
/// Every field carries `#[serde(default)]`, so a document missing any of
/// these keys still deserializes. Unknown keys are ignored.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// The resource identifier the document describes, if reported.
    #[serde(default)]
    resource: Option<Url>,
    /// Authorization server URLs; an empty `Vec` when the key is absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1228
/// Raw deserialization target for an Authorization Server Metadata document
/// (RFC 8414 / OIDC Discovery), as fetched by `fetch_auth_server_metadata`.
///
/// Every field is optional here so that a sparse document still parses.
/// `fetch_auth_server_metadata` then enforces what it needs; see the field
/// notes. Unknown keys are ignored.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// Reported issuer. `fetch_auth_server_metadata` substitutes the expected
    /// issuer when this is absent and rejects the document when it differs.
    #[serde(default)]
    issuer: Option<Url>,
    /// Required by `fetch_auth_server_metadata`; a document without it is
    /// rejected.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// Required by `fetch_auth_server_metadata`; a document without it is
    /// rejected.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// Dynamic Client Registration endpoint, if the server offers one.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// Forwarded to `perform_dcr` by `resolve_client_registration`.
    #[serde(default)]
    grant_types_supported: Option<Vec<String>>,
    /// PKCE methods. `discover` requires this to be present and to include
    /// `S256`.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// Client ID Metadata Document (CIMD) support flag; treated as `false`
    /// when absent.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1248
/// The subset of a Dynamic Client Registration response that `perform_dcr`
/// reads.
///
/// `client_id` has no default, so a response without it fails to
/// deserialize. `perform_dcr` reports that as "failed to parse DCR response".
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// Identifier issued to the newly registered client.
    client_id: String,
    /// Secret issued to the client, if the server issued one.
    #[serde(default)]
    client_secret: Option<String>,
}
1255
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementors must be `Send + Sync`. NOTE(review): presumably because one
/// provider is shared with the transport across tasks/threads — confirm.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// `None` means no usable token is held. For example,
    /// `McpOAuthTokenProvider` returns `None` once its token is within 30
    /// seconds of expiry.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token is available and the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1269
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// The current session: tokens plus the token endpoint, resource and
    /// client registration needed to refresh them. Its `tokens` are replaced
    /// after a successful refresh.
    session: SyncMutex<OAuthSession>,
    /// Client used to send refresh requests to the token endpoint.
    http_client: Arc<dyn HttpClient>,
    /// Optional sink that receives a clone of the session after every
    /// successful refresh. NOTE(review): presumably drained by the owner to
    /// re-persist the session (e.g. to the keychain) — confirm with callers.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1279
1280impl McpOAuthTokenProvider {
1281    pub fn new(
1282        session: OAuthSession,
1283        http_client: Arc<dyn HttpClient>,
1284        token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1285    ) -> Self {
1286        Self {
1287            session: SyncMutex::new(session),
1288            http_client,
1289            token_refresh_tx,
1290        }
1291    }
1292
1293    fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1294        tokens.expires_at.is_some_and(|expires_at| {
1295            SystemTime::now()
1296                .checked_add(Duration::from_secs(30))
1297                .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1298        })
1299    }
1300}
1301
1302#[async_trait]
1303impl OAuthTokenProvider for McpOAuthTokenProvider {
1304    fn access_token(&self) -> Option<String> {
1305        let session = self.session.lock();
1306        if Self::access_token_is_expired(&session.tokens) {
1307            return None;
1308        }
1309        Some(session.tokens.access_token.clone())
1310    }
1311
1312    async fn try_refresh(&self) -> Result<bool> {
1313        let (refresh_token, token_endpoint, resource, client_id) = {
1314            let session = self.session.lock();
1315            match session.tokens.refresh_token.clone() {
1316                Some(refresh_token) => (
1317                    refresh_token,
1318                    session.token_endpoint.clone(),
1319                    session.resource.clone(),
1320                    session.client_registration.client_id.clone(),
1321                ),
1322                None => return Ok(false),
1323            }
1324        };
1325
1326        let resource_str = canonical_server_uri(&resource);
1327
1328        match refresh_tokens(
1329            &self.http_client,
1330            &token_endpoint,
1331            &refresh_token,
1332            &client_id,
1333            &resource_str,
1334        )
1335        .await
1336        {
1337            Ok(mut new_tokens) => {
1338                if new_tokens.refresh_token.is_none() {
1339                    new_tokens.refresh_token = Some(refresh_token);
1340                }
1341
1342                {
1343                    let mut session = self.session.lock();
1344                    session.tokens = new_tokens;
1345
1346                    if let Some(ref tx) = self.token_refresh_tx {
1347                        tx.unbounded_send(session.clone()).ok();
1348                    }
1349                }
1350
1351                Ok(true)
1352            }
1353            Err(err) => {
1354                log::warn!("OAuth token refresh failed: {}", err);
1355                Ok(false)
1356            }
1357        }
1358    }
1359}
1360
1361#[cfg(test)]
1362mod tests {
1363    use super::*;
1364    use http_client::Response;
1365
1366    // -- require_https_or_loopback tests ------------------------------------
1367
1368    #[test]
1369    fn test_require_https_or_loopback_accepts_https() {
1370        let url = Url::parse("https://auth.example.com/token").unwrap();
1371        assert!(require_https_or_loopback(&url).is_ok());
1372    }
1373
1374    #[test]
1375    fn test_require_https_or_loopback_rejects_http_remote() {
1376        let url = Url::parse("http://auth.example.com/token").unwrap();
1377        assert!(require_https_or_loopback(&url).is_err());
1378    }
1379
1380    #[test]
1381    fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1382        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1383        assert!(require_https_or_loopback(&url).is_ok());
1384    }
1385
1386    #[test]
1387    fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1388        let url = Url::parse("http://[::1]:8080/callback").unwrap();
1389        assert!(require_https_or_loopback(&url).is_ok());
1390    }
1391
1392    #[test]
1393    fn test_require_https_or_loopback_accepts_http_localhost() {
1394        let url = Url::parse("http://localhost:8080/callback").unwrap();
1395        assert!(require_https_or_loopback(&url).is_ok());
1396    }
1397
1398    #[test]
1399    fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1400        let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1401        assert!(require_https_or_loopback(&url).is_ok());
1402    }
1403
1404    #[test]
1405    fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1406        let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1407        assert!(require_https_or_loopback(&url).is_err());
1408    }
1409
1410    #[test]
1411    fn test_require_https_or_loopback_rejects_ftp() {
1412        let url = Url::parse("ftp://auth.example.com/token").unwrap();
1413        assert!(require_https_or_loopback(&url).is_err());
1414    }
1415
1416    // -- validate_oauth_url (SSRF) tests ------------------------------------
1417
1418    #[test]
1419    fn test_validate_oauth_url_accepts_https_public() {
1420        let url = Url::parse("https://auth.example.com/token").unwrap();
1421        assert!(validate_oauth_url(&url).is_ok());
1422    }
1423
1424    #[test]
1425    fn test_validate_oauth_url_rejects_private_ipv4_10() {
1426        let url = Url::parse("https://10.0.0.1/token").unwrap();
1427        assert!(validate_oauth_url(&url).is_err());
1428    }
1429
1430    #[test]
1431    fn test_validate_oauth_url_rejects_private_ipv4_172() {
1432        let url = Url::parse("https://172.16.0.1/token").unwrap();
1433        assert!(validate_oauth_url(&url).is_err());
1434    }
1435
1436    #[test]
1437    fn test_validate_oauth_url_rejects_private_ipv4_192() {
1438        let url = Url::parse("https://192.168.1.1/token").unwrap();
1439        assert!(validate_oauth_url(&url).is_err());
1440    }
1441
1442    #[test]
1443    fn test_validate_oauth_url_rejects_link_local() {
1444        let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1445        assert!(validate_oauth_url(&url).is_err());
1446    }
1447
1448    #[test]
1449    fn test_validate_oauth_url_rejects_ipv6_ula() {
1450        let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1451        assert!(validate_oauth_url(&url).is_err());
1452    }
1453
1454    #[test]
1455    fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1456        let url = Url::parse("https://[::]/token").unwrap();
1457        assert!(validate_oauth_url(&url).is_err());
1458    }
1459
1460    #[test]
1461    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1462        let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1463        assert!(validate_oauth_url(&url).is_err());
1464    }
1465
1466    #[test]
1467    fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1468        let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1469        assert!(validate_oauth_url(&url).is_err());
1470    }
1471
1472    #[test]
1473    fn test_validate_oauth_url_allows_http_loopback() {
1474        // Loopback is permitted (it's our callback server).
1475        let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1476        assert!(validate_oauth_url(&url).is_ok());
1477    }
1478
1479    #[test]
1480    fn test_validate_oauth_url_allows_https_public_ip() {
1481        let url = Url::parse("https://93.184.216.34/token").unwrap();
1482        assert!(validate_oauth_url(&url).is_ok());
1483    }
1484
1485    // -- parse_www_authenticate tests ----------------------------------------
1486
1487    #[test]
1488    fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1489        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1490        let result = parse_www_authenticate(header).unwrap();
1491
1492        assert_eq!(
1493            result.resource_metadata.as_ref().map(|u| u.as_str()),
1494            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1495        );
1496        assert_eq!(
1497            result.scope,
1498            Some(vec!["files:read".to_string(), "user:profile".to_string()])
1499        );
1500        assert_eq!(result.error, None);
1501    }
1502
1503    #[test]
1504    fn test_parse_www_authenticate_resource_metadata_only() {
1505        let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1506        let result = parse_www_authenticate(header).unwrap();
1507
1508        assert_eq!(
1509            result.resource_metadata.as_ref().map(|u| u.as_str()),
1510            Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1511        );
1512        assert_eq!(result.scope, None);
1513    }
1514
1515    #[test]
1516    fn test_parse_www_authenticate_bare_bearer() {
1517        let result = parse_www_authenticate("Bearer").unwrap();
1518        assert_eq!(result.resource_metadata, None);
1519        assert_eq!(result.scope, None);
1520    }
1521
1522    #[test]
1523    fn test_parse_www_authenticate_with_error() {
1524        let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1525        let result = parse_www_authenticate(header).unwrap();
1526
1527        assert_eq!(result.error, Some(BearerError::InsufficientScope));
1528        assert_eq!(
1529            result.error_description.as_deref(),
1530            Some("Additional file write permission required")
1531        );
1532        assert_eq!(
1533            result.scope,
1534            Some(vec!["files:read".to_string(), "files:write".to_string()])
1535        );
1536        assert!(result.resource_metadata.is_some());
1537    }
1538
1539    #[test]
1540    fn test_parse_www_authenticate_invalid_token_error() {
1541        let header =
1542            r#"Bearer error="invalid_token", error_description="The access token expired""#;
1543        let result = parse_www_authenticate(header).unwrap();
1544        assert_eq!(result.error, Some(BearerError::InvalidToken));
1545    }
1546
1547    #[test]
1548    fn test_parse_www_authenticate_invalid_request_error() {
1549        let header = r#"Bearer error="invalid_request""#;
1550        let result = parse_www_authenticate(header).unwrap();
1551        assert_eq!(result.error, Some(BearerError::InvalidRequest));
1552    }
1553
1554    #[test]
1555    fn test_parse_www_authenticate_unknown_error() {
1556        let header = r#"Bearer error="some_future_error""#;
1557        let result = parse_www_authenticate(header).unwrap();
1558        assert_eq!(result.error, Some(BearerError::Other));
1559    }
1560
1561    #[test]
1562    fn test_parse_www_authenticate_rejects_non_bearer() {
1563        let result = parse_www_authenticate("Basic realm=\"example\"");
1564        assert!(result.is_err());
1565    }
1566
1567    #[test]
1568    fn test_parse_www_authenticate_case_insensitive_scheme() {
1569        let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1570        let result = parse_www_authenticate(header).unwrap();
1571        assert!(result.resource_metadata.is_some());
1572    }
1573
1574    #[test]
1575    fn test_parse_www_authenticate_multiline_style() {
1576        // Some servers emit the header spread across multiple lines joined by
1577        // whitespace, as shown in the spec examples.
1578        let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n                         scope=\"files:read\"";
1579        let result = parse_www_authenticate(header).unwrap();
1580        assert!(result.resource_metadata.is_some());
1581        assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1582    }
1583
1584    #[test]
1585    fn test_protected_resource_metadata_urls_with_path() {
1586        let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1587        let urls = protected_resource_metadata_urls(&server_url);
1588
1589        assert_eq!(urls.len(), 2);
1590        assert_eq!(
1591            urls[0].as_str(),
1592            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1593        );
1594        assert_eq!(
1595            urls[1].as_str(),
1596            "https://api.example.com/.well-known/oauth-protected-resource"
1597        );
1598    }
1599
1600    #[test]
1601    fn test_protected_resource_metadata_urls_without_path() {
1602        let server_url = Url::parse("https://mcp.example.com").unwrap();
1603        let urls = protected_resource_metadata_urls(&server_url);
1604
1605        assert_eq!(urls.len(), 1);
1606        assert_eq!(
1607            urls[0].as_str(),
1608            "https://mcp.example.com/.well-known/oauth-protected-resource"
1609        );
1610    }
1611
1612    #[test]
1613    fn test_auth_server_metadata_urls_with_path() {
1614        let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1615        let urls = auth_server_metadata_urls(&issuer);
1616
1617        assert_eq!(urls.len(), 3);
1618        assert_eq!(
1619            urls[0].as_str(),
1620            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1621        );
1622        assert_eq!(
1623            urls[1].as_str(),
1624            "https://auth.example.com/.well-known/openid-configuration/tenant1"
1625        );
1626        assert_eq!(
1627            urls[2].as_str(),
1628            "https://auth.example.com/tenant1/.well-known/openid-configuration"
1629        );
1630    }
1631
1632    #[test]
1633    fn test_auth_server_metadata_urls_without_path() {
1634        let issuer = Url::parse("https://auth.example.com").unwrap();
1635        let urls = auth_server_metadata_urls(&issuer);
1636
1637        assert_eq!(urls.len(), 2);
1638        assert_eq!(
1639            urls[0].as_str(),
1640            "https://auth.example.com/.well-known/oauth-authorization-server"
1641        );
1642        assert_eq!(
1643            urls[1].as_str(),
1644            "https://auth.example.com/.well-known/openid-configuration"
1645        );
1646    }
1647
1648    // -- Canonical server URI tests ------------------------------------------
1649
1650    #[test]
1651    fn test_canonical_server_uri_simple() {
1652        let url = Url::parse("https://mcp.example.com").unwrap();
1653        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1654    }
1655
1656    #[test]
1657    fn test_canonical_server_uri_with_path() {
1658        let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1659        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1660    }
1661
1662    #[test]
1663    fn test_canonical_server_uri_strips_trailing_slash() {
1664        let url = Url::parse("https://mcp.example.com/").unwrap();
1665        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1666    }
1667
1668    #[test]
1669    fn test_canonical_server_uri_preserves_port() {
1670        let url = Url::parse("https://mcp.example.com:8443").unwrap();
1671        assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1672    }
1673
1674    #[test]
1675    fn test_canonical_server_uri_lowercases() {
1676        let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1677        assert_eq!(
1678            canonical_server_uri(&url),
1679            "https://mcp.example.com/Server/MCP"
1680        );
1681    }
1682
1683    // -- Scope selection tests -----------------------------------------------
1684
1685    #[test]
1686    fn test_select_scopes_prefers_www_authenticate() {
1687        let www_auth = WwwAuthenticate {
1688            resource_metadata: None,
1689            scope: Some(vec!["files:read".into()]),
1690            error: None,
1691            error_description: None,
1692        };
1693        let resource_meta = ProtectedResourceMetadata {
1694            resource: Url::parse("https://example.com").unwrap(),
1695            authorization_servers: vec![],
1696            scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1697        };
1698        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1699    }
1700
1701    #[test]
1702    fn test_select_scopes_falls_back_to_resource_metadata() {
1703        let www_auth = WwwAuthenticate {
1704            resource_metadata: None,
1705            scope: None,
1706            error: None,
1707            error_description: None,
1708        };
1709        let resource_meta = ProtectedResourceMetadata {
1710            resource: Url::parse("https://example.com").unwrap(),
1711            authorization_servers: vec![],
1712            scopes_supported: Some(vec!["admin".into()]),
1713        };
1714        assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1715    }
1716
1717    #[test]
1718    fn test_select_scopes_empty_when_nothing_available() {
1719        let www_auth = WwwAuthenticate {
1720            resource_metadata: None,
1721            scope: None,
1722            error: None,
1723            error_description: None,
1724        };
1725        let resource_meta = ProtectedResourceMetadata {
1726            resource: Url::parse("https://example.com").unwrap(),
1727            authorization_servers: vec![],
1728            scopes_supported: None,
1729        };
1730        assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1731    }
1732
1733    // -- Client registration strategy tests ----------------------------------
1734
1735    #[test]
1736    fn test_registration_strategy_prefers_cimd() {
1737        let metadata = AuthServerMetadata {
1738            issuer: Url::parse("https://auth.example.com").unwrap(),
1739            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1740            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1741            registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1742            scopes_supported: None,
1743            code_challenge_methods_supported: Some(vec!["S256".into()]),
1744            client_id_metadata_document_supported: true,
1745            grant_types_supported: None,
1746        };
1747        assert_eq!(
1748            determine_registration_strategy(&metadata),
1749            ClientRegistrationStrategy::Cimd {
1750                client_id: CIMD_URL.to_string(),
1751            }
1752        );
1753    }
1754
1755    #[test]
1756    fn test_registration_strategy_falls_back_to_dcr() {
1757        let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1758        let metadata = AuthServerMetadata {
1759            issuer: Url::parse("https://auth.example.com").unwrap(),
1760            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1761            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1762            registration_endpoint: Some(reg_endpoint.clone()),
1763            scopes_supported: None,
1764            code_challenge_methods_supported: Some(vec!["S256".into()]),
1765            client_id_metadata_document_supported: false,
1766            grant_types_supported: None,
1767        };
1768        assert_eq!(
1769            determine_registration_strategy(&metadata),
1770            ClientRegistrationStrategy::Dcr {
1771                registration_endpoint: reg_endpoint,
1772            }
1773        );
1774    }
1775
1776    #[test]
1777    fn test_registration_strategy_unavailable() {
1778        let metadata = AuthServerMetadata {
1779            issuer: Url::parse("https://auth.example.com").unwrap(),
1780            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1781            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1782            registration_endpoint: None,
1783            scopes_supported: None,
1784            code_challenge_methods_supported: Some(vec!["S256".into()]),
1785            client_id_metadata_document_supported: false,
1786            grant_types_supported: None,
1787        };
1788        assert_eq!(
1789            determine_registration_strategy(&metadata),
1790            ClientRegistrationStrategy::Unavailable,
1791        );
1792    }
1793
1794    // -- PKCE tests ----------------------------------------------------------
1795
1796    #[test]
1797    fn test_pkce_challenge_verifier_length() {
1798        let pkce = generate_pkce_challenge();
1799        // 32 random bytes → 43 base64url chars (no padding).
1800        assert_eq!(pkce.verifier.len(), 43);
1801    }
1802
1803    #[test]
1804    fn test_pkce_challenge_is_valid_base64url() {
1805        let pkce = generate_pkce_challenge();
1806        for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1807            assert!(
1808                c.is_ascii_alphanumeric() || c == '-' || c == '_',
1809                "invalid base64url character: {}",
1810                c
1811            );
1812        }
1813    }
1814
1815    #[test]
1816    fn test_pkce_challenge_is_s256_of_verifier() {
1817        let pkce = generate_pkce_challenge();
1818        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1819        let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1820        let expected_challenge = engine.encode(expected_digest);
1821        assert_eq!(pkce.challenge, expected_challenge);
1822    }
1823
1824    #[test]
1825    fn test_pkce_challenges_are_unique() {
1826        let a = generate_pkce_challenge();
1827        let b = generate_pkce_challenge();
1828        assert_ne!(a.verifier, b.verifier);
1829    }
1830
1831    // -- Authorization URL tests ---------------------------------------------
1832
1833    #[test]
1834    fn test_build_authorization_url() {
1835        let metadata = AuthServerMetadata {
1836            issuer: Url::parse("https://auth.example.com").unwrap(),
1837            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1838            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1839            registration_endpoint: None,
1840            scopes_supported: None,
1841            code_challenge_methods_supported: Some(vec!["S256".into()]),
1842            client_id_metadata_document_supported: true,
1843            grant_types_supported: None,
1844        };
1845        let pkce = PkceChallenge {
1846            verifier: "test_verifier".into(),
1847            challenge: "test_challenge".into(),
1848        };
1849        let url = build_authorization_url(
1850            &metadata,
1851            "https://zed.dev/oauth/client-metadata.json",
1852            "http://127.0.0.1:12345/callback",
1853            &["files:read".into(), "files:write".into()],
1854            "https://mcp.example.com",
1855            &pkce,
1856            "random_state_123",
1857        );
1858
1859        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1860        assert_eq!(pairs.get("response_type").unwrap(), "code");
1861        assert_eq!(
1862            pairs.get("client_id").unwrap(),
1863            "https://zed.dev/oauth/client-metadata.json"
1864        );
1865        assert_eq!(
1866            pairs.get("redirect_uri").unwrap(),
1867            "http://127.0.0.1:12345/callback"
1868        );
1869        assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1870        assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1871        assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1872        assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1873        assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1874    }
1875
1876    #[test]
1877    fn test_build_authorization_url_omits_empty_scope() {
1878        let metadata = AuthServerMetadata {
1879            issuer: Url::parse("https://auth.example.com").unwrap(),
1880            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1881            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1882            registration_endpoint: None,
1883            scopes_supported: None,
1884            code_challenge_methods_supported: Some(vec!["S256".into()]),
1885            client_id_metadata_document_supported: false,
1886            grant_types_supported: None,
1887        };
1888        let pkce = PkceChallenge {
1889            verifier: "v".into(),
1890            challenge: "c".into(),
1891        };
1892        let url = build_authorization_url(
1893            &metadata,
1894            "client_123",
1895            "http://127.0.0.1:9999/callback",
1896            &[],
1897            "https://mcp.example.com",
1898            &pkce,
1899            "state",
1900        );
1901
1902        let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1903        assert!(!pairs.contains_key("scope"));
1904    }
1905
1906    // -- Token exchange / refresh param tests --------------------------------
1907
1908    #[test]
1909    fn test_token_exchange_params() {
1910        let params = token_exchange_params(
1911            "auth_code_abc",
1912            "client_xyz",
1913            "http://127.0.0.1:5555/callback",
1914            "verifier_123",
1915            "https://mcp.example.com",
1916        );
1917        let map: std::collections::HashMap<&str, &str> =
1918            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1919
1920        assert_eq!(map["grant_type"], "authorization_code");
1921        assert_eq!(map["code"], "auth_code_abc");
1922        assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1923        assert_eq!(map["client_id"], "client_xyz");
1924        assert_eq!(map["code_verifier"], "verifier_123");
1925        assert_eq!(map["resource"], "https://mcp.example.com");
1926    }
1927
1928    #[test]
1929    fn test_token_refresh_params() {
1930        let params =
1931            token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1932        let map: std::collections::HashMap<&str, &str> =
1933            params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1934
1935        assert_eq!(map["grant_type"], "refresh_token");
1936        assert_eq!(map["refresh_token"], "refresh_token_abc");
1937        assert_eq!(map["client_id"], "client_xyz");
1938        assert_eq!(map["resource"], "https://mcp.example.com");
1939    }
1940
1941    // -- Token response tests ------------------------------------------------
1942
1943    #[test]
1944    fn test_token_response_into_tokens_with_expiry() {
1945        let response: TokenResponse = serde_json::from_str(
1946            r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1947        )
1948        .unwrap();
1949
1950        let tokens = response.into_tokens();
1951        assert_eq!(tokens.access_token, "at_123");
1952        assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1953        assert!(tokens.expires_at.is_some());
1954    }
1955
1956    #[test]
1957    fn test_token_response_into_tokens_minimal() {
1958        let response: TokenResponse =
1959            serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1960
1961        let tokens = response.into_tokens();
1962        assert_eq!(tokens.access_token, "at_789");
1963        assert_eq!(tokens.refresh_token, None);
1964        assert_eq!(tokens.expires_at, None);
1965    }
1966
1967    // -- DCR body test -------------------------------------------------------
1968
1969    #[test]
1970    fn test_dcr_registration_body_without_server_metadata() {
1971        // When server metadata is unavailable, include all supported grant types.
1972        let body = dcr_registration_body("http://127.0.0.1:12345/callback", None);
1973        assert_eq!(body["client_name"], "Zed");
1974        assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1975        assert_eq!(body["grant_types"][0], "authorization_code");
1976        assert_eq!(body["grant_types"][1], "refresh_token");
1977        assert_eq!(body["response_types"][0], "code");
1978        assert_eq!(body["token_endpoint_auth_method"], "none");
1979    }
1980
1981    #[test]
1982    fn test_dcr_registration_body_mirrors_server_grant_types() {
1983        // When the server only supports authorization_code, omit refresh_token.
1984        let server_types = vec!["authorization_code".to_string()];
1985        let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
1986        assert_eq!(body["grant_types"][0], "authorization_code");
1987        assert!(body["grant_types"].as_array().unwrap().len() == 1);
1988
1989        // When the server supports both, include both.
1990        let server_types = vec![
1991            "authorization_code".to_string(),
1992            "refresh_token".to_string(),
1993        ];
1994        let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
1995        assert_eq!(body["grant_types"][0], "authorization_code");
1996        assert_eq!(body["grant_types"][1], "refresh_token");
1997    }
1998
1999    // -- Test helpers for async/HTTP tests -----------------------------------
2000
2001    fn make_fake_http_client(
2002        handler: impl Fn(
2003            http_client::Request<AsyncBody>,
2004        ) -> std::pin::Pin<
2005            Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
2006        > + Send
2007        + Sync
2008        + 'static,
2009    ) -> Arc<dyn HttpClient> {
2010        http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
2011    }
2012
2013    fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
2014        Ok(Response::builder()
2015            .status(status)
2016            .header("Content-Type", "application/json")
2017            .body(AsyncBody::from(body.as_bytes().to_vec()))
2018            .unwrap())
2019    }
2020
2021    // -- Discovery integration tests -----------------------------------------
2022
2023    #[test]
2024    fn test_fetch_protected_resource_metadata() {
2025        smol::block_on(async {
2026            let client = make_fake_http_client(|req| {
2027                Box::pin(async move {
2028                    let uri = req.uri().to_string();
2029                    if uri.contains(".well-known/oauth-protected-resource") {
2030                        json_response(
2031                            200,
2032                            r#"{
2033                                "resource": "https://mcp.example.com",
2034                                "authorization_servers": ["https://auth.example.com"],
2035                                "scopes_supported": ["read", "write"]
2036                            }"#,
2037                        )
2038                    } else {
2039                        json_response(404, "{}")
2040                    }
2041                })
2042            });
2043
2044            let server_url = Url::parse("https://mcp.example.com").unwrap();
2045            let www_auth = WwwAuthenticate {
2046                resource_metadata: None,
2047                scope: None,
2048                error: None,
2049                error_description: None,
2050            };
2051
2052            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2053                .await
2054                .unwrap();
2055
2056            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2057            assert_eq!(metadata.authorization_servers.len(), 1);
2058            assert_eq!(
2059                metadata.authorization_servers[0].as_str(),
2060                "https://auth.example.com/"
2061            );
2062            assert_eq!(
2063                metadata.scopes_supported,
2064                Some(vec!["read".to_string(), "write".to_string()])
2065            );
2066        });
2067    }
2068
2069    #[test]
2070    fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2071        smol::block_on(async {
2072            let client = make_fake_http_client(|req| {
2073                Box::pin(async move {
2074                    let uri = req.uri().to_string();
2075                    if uri == "https://mcp.example.com/custom-resource-metadata" {
2076                        json_response(
2077                            200,
2078                            r#"{
2079                                "resource": "https://mcp.example.com",
2080                                "authorization_servers": ["https://auth.example.com"]
2081                            }"#,
2082                        )
2083                    } else {
2084                        json_response(500, r#"{"error": "should not be called"}"#)
2085                    }
2086                })
2087            });
2088
2089            let server_url = Url::parse("https://mcp.example.com").unwrap();
2090            let www_auth = WwwAuthenticate {
2091                resource_metadata: Some(
2092                    Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2093                ),
2094                scope: None,
2095                error: None,
2096                error_description: None,
2097            };
2098
2099            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2100                .await
2101                .unwrap();
2102
2103            assert_eq!(metadata.authorization_servers.len(), 1);
2104        });
2105    }
2106
2107    #[test]
2108    fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2109        smol::block_on(async {
2110            let client = make_fake_http_client(|req| {
2111                Box::pin(async move {
2112                    let uri = req.uri().to_string();
2113                    // The cross-origin URL should NOT be fetched; only the
2114                    // well-known fallback at the server's own origin should be.
2115                    if uri.contains("attacker.example.com") {
2116                        panic!("should not fetch cross-origin resource_metadata URL");
2117                    } else if uri.contains(".well-known/oauth-protected-resource") {
2118                        json_response(
2119                            200,
2120                            r#"{
2121                                "resource": "https://mcp.example.com",
2122                                "authorization_servers": ["https://auth.example.com"]
2123                            }"#,
2124                        )
2125                    } else {
2126                        json_response(404, "{}")
2127                    }
2128                })
2129            });
2130
2131            let server_url = Url::parse("https://mcp.example.com").unwrap();
2132            let www_auth = WwwAuthenticate {
2133                resource_metadata: Some(
2134                    Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2135                ),
2136                scope: None,
2137                error: None,
2138                error_description: None,
2139            };
2140
2141            let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2142                .await
2143                .unwrap();
2144
2145            // Should have used the fallback well-known URL, not the attacker's.
2146            assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2147        });
2148    }
2149
2150    #[test]
2151    fn test_fetch_auth_server_metadata() {
2152        smol::block_on(async {
2153            let client = make_fake_http_client(|req| {
2154                Box::pin(async move {
2155                    let uri = req.uri().to_string();
2156                    if uri.contains(".well-known/oauth-authorization-server") {
2157                        json_response(
2158                            200,
2159                            r#"{
2160                                "issuer": "https://auth.example.com",
2161                                "authorization_endpoint": "https://auth.example.com/authorize",
2162                                "token_endpoint": "https://auth.example.com/token",
2163                                "registration_endpoint": "https://auth.example.com/register",
2164                                "code_challenge_methods_supported": ["S256"],
2165                                "client_id_metadata_document_supported": true
2166                            }"#,
2167                        )
2168                    } else {
2169                        json_response(404, "{}")
2170                    }
2171                })
2172            });
2173
2174            let issuer = Url::parse("https://auth.example.com").unwrap();
2175            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2176
2177            assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2178            assert_eq!(
2179                metadata.authorization_endpoint.as_str(),
2180                "https://auth.example.com/authorize"
2181            );
2182            assert_eq!(
2183                metadata.token_endpoint.as_str(),
2184                "https://auth.example.com/token"
2185            );
2186            assert!(metadata.registration_endpoint.is_some());
2187            assert!(metadata.client_id_metadata_document_supported);
2188            assert_eq!(
2189                metadata.code_challenge_methods_supported,
2190                Some(vec!["S256".to_string()])
2191            );
2192        });
2193    }
2194
2195    #[test]
2196    fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2197        smol::block_on(async {
2198            let client = make_fake_http_client(|req| {
2199                Box::pin(async move {
2200                    let uri = req.uri().to_string();
2201                    if uri.contains("openid-configuration") {
2202                        json_response(
2203                            200,
2204                            r#"{
2205                                "issuer": "https://auth.example.com",
2206                                "authorization_endpoint": "https://auth.example.com/authorize",
2207                                "token_endpoint": "https://auth.example.com/token",
2208                                "code_challenge_methods_supported": ["S256"]
2209                            }"#,
2210                        )
2211                    } else {
2212                        json_response(404, "{}")
2213                    }
2214                })
2215            });
2216
2217            let issuer = Url::parse("https://auth.example.com").unwrap();
2218            let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2219
2220            assert_eq!(
2221                metadata.authorization_endpoint.as_str(),
2222                "https://auth.example.com/authorize"
2223            );
2224            assert!(!metadata.client_id_metadata_document_supported);
2225        });
2226    }
2227
2228    #[test]
2229    fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2230        smol::block_on(async {
2231            let client = make_fake_http_client(|req| {
2232                Box::pin(async move {
2233                    let uri = req.uri().to_string();
2234                    if uri.contains(".well-known/oauth-authorization-server") {
2235                        // Response claims to be a different issuer.
2236                        json_response(
2237                            200,
2238                            r#"{
2239                                "issuer": "https://evil.example.com",
2240                                "authorization_endpoint": "https://evil.example.com/authorize",
2241                                "token_endpoint": "https://evil.example.com/token",
2242                                "code_challenge_methods_supported": ["S256"]
2243                            }"#,
2244                        )
2245                    } else {
2246                        json_response(404, "{}")
2247                    }
2248                })
2249            });
2250
2251            let issuer = Url::parse("https://auth.example.com").unwrap();
2252            let result = fetch_auth_server_metadata(&client, &issuer).await;
2253
2254            assert!(result.is_err());
2255            let err_msg = result.unwrap_err().to_string();
2256            assert!(
2257                err_msg.contains("issuer mismatch"),
2258                "unexpected error: {}",
2259                err_msg
2260            );
2261        });
2262    }
2263
2264    // -- Full discover integration tests -------------------------------------
2265
2266    #[test]
2267    fn test_full_discover_with_cimd() {
2268        smol::block_on(async {
2269            let client = make_fake_http_client(|req| {
2270                Box::pin(async move {
2271                    let uri = req.uri().to_string();
2272                    if uri.contains("oauth-protected-resource") {
2273                        json_response(
2274                            200,
2275                            r#"{
2276                                "resource": "https://mcp.example.com",
2277                                "authorization_servers": ["https://auth.example.com"],
2278                                "scopes_supported": ["mcp:read"]
2279                            }"#,
2280                        )
2281                    } else if uri.contains("oauth-authorization-server") {
2282                        json_response(
2283                            200,
2284                            r#"{
2285                                "issuer": "https://auth.example.com",
2286                                "authorization_endpoint": "https://auth.example.com/authorize",
2287                                "token_endpoint": "https://auth.example.com/token",
2288                                "code_challenge_methods_supported": ["S256"],
2289                                "client_id_metadata_document_supported": true
2290                            }"#,
2291                        )
2292                    } else {
2293                        json_response(404, "{}")
2294                    }
2295                })
2296            });
2297
2298            let server_url = Url::parse("https://mcp.example.com").unwrap();
2299            let www_auth = WwwAuthenticate {
2300                resource_metadata: None,
2301                scope: None,
2302                error: None,
2303                error_description: None,
2304            };
2305
2306            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2307            let registration =
2308                resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2309                    .await
2310                    .unwrap();
2311
2312            assert_eq!(registration.client_id, CIMD_URL);
2313            assert_eq!(registration.client_secret, None);
2314            assert_eq!(discovery.scopes, vec!["mcp:read"]);
2315        });
2316    }
2317
2318    #[test]
2319    fn test_full_discover_with_dcr_fallback() {
2320        smol::block_on(async {
2321            let client = make_fake_http_client(|req| {
2322                Box::pin(async move {
2323                    let uri = req.uri().to_string();
2324                    if uri.contains("oauth-protected-resource") {
2325                        json_response(
2326                            200,
2327                            r#"{
2328                                "resource": "https://mcp.example.com",
2329                                "authorization_servers": ["https://auth.example.com"]
2330                            }"#,
2331                        )
2332                    } else if uri.contains("oauth-authorization-server") {
2333                        json_response(
2334                            200,
2335                            r#"{
2336                                "issuer": "https://auth.example.com",
2337                                "authorization_endpoint": "https://auth.example.com/authorize",
2338                                "token_endpoint": "https://auth.example.com/token",
2339                                "registration_endpoint": "https://auth.example.com/register",
2340                                "code_challenge_methods_supported": ["S256"],
2341                                "client_id_metadata_document_supported": false
2342                            }"#,
2343                        )
2344                    } else if uri.contains("/register") {
2345                        json_response(
2346                            201,
2347                            r#"{
2348                                "client_id": "dcr-minted-id-123",
2349                                "client_secret": "dcr-secret-456"
2350                            }"#,
2351                        )
2352                    } else {
2353                        json_response(404, "{}")
2354                    }
2355                })
2356            });
2357
2358            let server_url = Url::parse("https://mcp.example.com").unwrap();
2359            let www_auth = WwwAuthenticate {
2360                resource_metadata: None,
2361                scope: Some(vec!["files:read".into()]),
2362                error: None,
2363                error_description: None,
2364            };
2365
2366            let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2367            let registration =
2368                resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2369                    .await
2370                    .unwrap();
2371
2372            assert_eq!(registration.client_id, "dcr-minted-id-123");
2373            assert_eq!(
2374                registration.client_secret.as_deref(),
2375                Some("dcr-secret-456")
2376            );
2377            assert_eq!(discovery.scopes, vec!["files:read"]);
2378        });
2379    }
2380
2381    #[test]
2382    fn test_discover_fails_without_pkce_support() {
2383        smol::block_on(async {
2384            let client = make_fake_http_client(|req| {
2385                Box::pin(async move {
2386                    let uri = req.uri().to_string();
2387                    if uri.contains("oauth-protected-resource") {
2388                        json_response(
2389                            200,
2390                            r#"{
2391                                "resource": "https://mcp.example.com",
2392                                "authorization_servers": ["https://auth.example.com"]
2393                            }"#,
2394                        )
2395                    } else if uri.contains("oauth-authorization-server") {
2396                        json_response(
2397                            200,
2398                            r#"{
2399                                "issuer": "https://auth.example.com",
2400                                "authorization_endpoint": "https://auth.example.com/authorize",
2401                                "token_endpoint": "https://auth.example.com/token"
2402                            }"#,
2403                        )
2404                    } else {
2405                        json_response(404, "{}")
2406                    }
2407                })
2408            });
2409
2410            let server_url = Url::parse("https://mcp.example.com").unwrap();
2411            let www_auth = WwwAuthenticate {
2412                resource_metadata: None,
2413                scope: None,
2414                error: None,
2415                error_description: None,
2416            };
2417
2418            let result = discover(&client, &server_url, &www_auth).await;
2419            assert!(result.is_err());
2420            let err_msg = result.unwrap_err().to_string();
2421            assert!(
2422                err_msg.contains("code_challenge_methods_supported"),
2423                "unexpected error: {}",
2424                err_msg
2425            );
2426        });
2427    }
2428
2429    // -- Token exchange integration tests ------------------------------------
2430
2431    #[test]
2432    fn test_exchange_code_success() {
2433        smol::block_on(async {
2434            let client = make_fake_http_client(|req| {
2435                Box::pin(async move {
2436                    let uri = req.uri().to_string();
2437                    if uri.contains("/token") {
2438                        json_response(
2439                            200,
2440                            r#"{
2441                                "access_token": "new_access_token",
2442                                "refresh_token": "new_refresh_token",
2443                                "expires_in": 3600,
2444                                "token_type": "Bearer"
2445                            }"#,
2446                        )
2447                    } else {
2448                        json_response(404, "{}")
2449                    }
2450                })
2451            });
2452
2453            let metadata = AuthServerMetadata {
2454                issuer: Url::parse("https://auth.example.com").unwrap(),
2455                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2456                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2457                registration_endpoint: None,
2458                scopes_supported: None,
2459                code_challenge_methods_supported: Some(vec!["S256".into()]),
2460                client_id_metadata_document_supported: true,
2461                grant_types_supported: None,
2462            };
2463
2464            let tokens = exchange_code(
2465                &client,
2466                &metadata,
2467                "auth_code_123",
2468                CIMD_URL,
2469                "http://127.0.0.1:9999/callback",
2470                "verifier_abc",
2471                "https://mcp.example.com",
2472            )
2473            .await
2474            .unwrap();
2475
2476            assert_eq!(tokens.access_token, "new_access_token");
2477            assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2478            assert!(tokens.expires_at.is_some());
2479        });
2480    }
2481
2482    #[test]
2483    fn test_refresh_tokens_success() {
2484        smol::block_on(async {
2485            let client = make_fake_http_client(|req| {
2486                Box::pin(async move {
2487                    let uri = req.uri().to_string();
2488                    if uri.contains("/token") {
2489                        json_response(
2490                            200,
2491                            r#"{
2492                                "access_token": "refreshed_token",
2493                                "expires_in": 1800,
2494                                "token_type": "Bearer"
2495                            }"#,
2496                        )
2497                    } else {
2498                        json_response(404, "{}")
2499                    }
2500                })
2501            });
2502
2503            let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2504
2505            let tokens = refresh_tokens(
2506                &client,
2507                &token_endpoint,
2508                "old_refresh_token",
2509                CIMD_URL,
2510                "https://mcp.example.com",
2511            )
2512            .await
2513            .unwrap();
2514
2515            assert_eq!(tokens.access_token, "refreshed_token");
2516            assert_eq!(tokens.refresh_token, None);
2517            assert!(tokens.expires_at.is_some());
2518        });
2519    }
2520
2521    #[test]
2522    fn test_exchange_code_failure() {
2523        smol::block_on(async {
2524            let client = make_fake_http_client(|_req| {
2525                Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2526            });
2527
2528            let metadata = AuthServerMetadata {
2529                issuer: Url::parse("https://auth.example.com").unwrap(),
2530                authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2531                token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2532                registration_endpoint: None,
2533                scopes_supported: None,
2534                code_challenge_methods_supported: Some(vec!["S256".into()]),
2535                client_id_metadata_document_supported: true,
2536                grant_types_supported: None,
2537            };
2538
2539            let result = exchange_code(
2540                &client,
2541                &metadata,
2542                "bad_code",
2543                "client",
2544                "http://127.0.0.1:1/callback",
2545                "verifier",
2546                "https://mcp.example.com",
2547            )
2548            .await;
2549
2550            assert!(result.is_err());
2551            assert!(result.unwrap_err().to_string().contains("400"));
2552        });
2553    }
2554
2555    // -- DCR integration tests -----------------------------------------------
2556
2557    #[test]
2558    fn test_perform_dcr() {
2559        smol::block_on(async {
2560            let client = make_fake_http_client(|_req| {
2561                Box::pin(async move {
2562                    json_response(
2563                        201,
2564                        r#"{
2565                            "client_id": "dynamic-client-001",
2566                            "client_secret": "dynamic-secret-001"
2567                        }"#,
2568                    )
2569                })
2570            });
2571
2572            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2573            let registration =
2574                perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None)
2575                    .await
2576                    .unwrap();
2577
2578            assert_eq!(registration.client_id, "dynamic-client-001");
2579            assert_eq!(
2580                registration.client_secret.as_deref(),
2581                Some("dynamic-secret-001")
2582            );
2583        });
2584    }
2585
2586    #[test]
2587    fn test_perform_dcr_failure() {
2588        smol::block_on(async {
2589            let client = make_fake_http_client(|_req| {
2590                Box::pin(
2591                    async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2592                )
2593            });
2594
2595            let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2596            let result =
2597                perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await;
2598
2599            assert!(result.is_err());
2600            assert!(result.unwrap_err().to_string().contains("403"));
2601        });
2602    }
2603
2604    // -- OAuthCallback parse tests -------------------------------------------
2605
2606    #[test]
2607    fn test_oauth_callback_parse_query() {
2608        let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2609        assert_eq!(callback.code, "test_auth_code");
2610        assert_eq!(callback.state, "test_state");
2611    }
2612
2613    #[test]
2614    fn test_oauth_callback_parse_query_reversed_order() {
2615        let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2616        assert_eq!(callback.code, "test_auth_code");
2617        assert_eq!(callback.state, "test_state");
2618    }
2619
2620    #[test]
2621    fn test_oauth_callback_parse_query_with_extra_params() {
2622        let callback =
2623            OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2624                .unwrap();
2625        assert_eq!(callback.code, "test_auth_code");
2626        assert_eq!(callback.state, "test_state");
2627    }
2628
2629    #[test]
2630    fn test_oauth_callback_parse_query_missing_code() {
2631        let result = OAuthCallback::parse_query("state=test_state");
2632        assert!(result.is_err());
2633        assert!(result.unwrap_err().to_string().contains("code"));
2634    }
2635
2636    #[test]
2637    fn test_oauth_callback_parse_query_missing_state() {
2638        let result = OAuthCallback::parse_query("code=test_auth_code");
2639        assert!(result.is_err());
2640        assert!(result.unwrap_err().to_string().contains("state"));
2641    }
2642
2643    #[test]
2644    fn test_oauth_callback_parse_query_empty_code() {
2645        let result = OAuthCallback::parse_query("code=&state=test_state");
2646        assert!(result.is_err());
2647    }
2648
2649    #[test]
2650    fn test_oauth_callback_parse_query_empty_state() {
2651        let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2652        assert!(result.is_err());
2653    }
2654
2655    #[test]
2656    fn test_oauth_callback_parse_query_url_encoded_values() {
2657        let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2658        assert_eq!(callback.code, "abc def");
2659        assert_eq!(callback.state, "test=state");
2660    }
2661
2662    #[test]
2663    fn test_oauth_callback_parse_query_error_response() {
2664        let result = OAuthCallback::parse_query(
2665            "error=access_denied&error_description=User%20denied%20access&state=abc",
2666        );
2667        assert!(result.is_err());
2668        let err_msg = result.unwrap_err().to_string();
2669        assert!(
2670            err_msg.contains("access_denied"),
2671            "unexpected error: {}",
2672            err_msg
2673        );
2674        assert!(
2675            err_msg.contains("User denied access"),
2676            "unexpected error: {}",
2677            err_msg
2678        );
2679    }
2680
2681    #[test]
2682    fn test_oauth_callback_parse_query_error_without_description() {
2683        let result = OAuthCallback::parse_query("error=server_error&state=abc");
2684        assert!(result.is_err());
2685        let err_msg = result.unwrap_err().to_string();
2686        assert!(
2687            err_msg.contains("server_error"),
2688            "unexpected error: {}",
2689            err_msg
2690        );
2691        assert!(
2692            err_msg.contains("no description"),
2693            "unexpected error: {}",
2694            err_msg
2695        );
2696    }
2697
2698    // -- McpOAuthTokenProvider tests -----------------------------------------
2699
2700    fn make_test_session(
2701        access_token: &str,
2702        refresh_token: Option<&str>,
2703        expires_at: Option<SystemTime>,
2704    ) -> OAuthSession {
2705        OAuthSession {
2706            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2707            resource: Url::parse("https://mcp.example.com").unwrap(),
2708            client_registration: OAuthClientRegistration {
2709                client_id: "test-client".into(),
2710                client_secret: None,
2711            },
2712            tokens: OAuthTokens {
2713                access_token: access_token.into(),
2714                refresh_token: refresh_token.map(String::from),
2715                expires_at,
2716            },
2717        }
2718    }
2719
2720    #[test]
2721    fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2722        let expired = SystemTime::now() - Duration::from_secs(60);
2723        let session = make_test_session("stale-token", Some("rt"), Some(expired));
2724        let provider = McpOAuthTokenProvider::new(
2725            session,
2726            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2727            None,
2728        );
2729
2730        assert_eq!(provider.access_token(), None);
2731    }
2732
2733    #[test]
2734    fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2735        let far_future = SystemTime::now() + Duration::from_secs(3600);
2736        let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2737        let provider = McpOAuthTokenProvider::new(
2738            session,
2739            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2740            None,
2741        );
2742
2743        assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2744    }
2745
2746    #[test]
2747    fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2748        let session = make_test_session("no-expiry-token", Some("rt"), None);
2749        let provider = McpOAuthTokenProvider::new(
2750            session,
2751            make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2752            None,
2753        );
2754
2755        assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2756    }
2757
2758    #[test]
2759    fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2760        smol::block_on(async {
2761            let session = make_test_session("token", None, None);
2762            let provider = McpOAuthTokenProvider::new(
2763                session,
2764                make_fake_http_client(|_| {
2765                    Box::pin(async { unreachable!("no HTTP call expected") })
2766                }),
2767                None,
2768            );
2769
2770            let refreshed = provider.try_refresh().await.unwrap();
2771            assert!(!refreshed);
2772        });
2773    }
2774
2775    #[test]
2776    fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2777        smol::block_on(async {
2778            let session = make_test_session("old-access", Some("my-refresh-token"), None);
2779            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2780
2781            let http_client = make_fake_http_client(|_req| {
2782                Box::pin(async {
2783                    json_response(
2784                        200,
2785                        r#"{
2786                            "access_token": "new-access",
2787                            "refresh_token": "new-refresh",
2788                            "expires_in": 1800
2789                        }"#,
2790                    )
2791                })
2792            });
2793
2794            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2795
2796            let refreshed = provider.try_refresh().await.unwrap();
2797            assert!(refreshed);
2798            assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2799
2800            let notified_session = rx.try_recv().expect("channel should have a session");
2801            assert_eq!(notified_session.tokens.access_token, "new-access");
2802            assert_eq!(
2803                notified_session.tokens.refresh_token.as_deref(),
2804                Some("new-refresh")
2805            );
2806        });
2807    }
2808
2809    #[test]
2810    fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2811        smol::block_on(async {
2812            let session = make_test_session("old-access", Some("original-refresh"), None);
2813            let (tx, mut rx) = futures::channel::mpsc::unbounded();
2814
2815            let http_client = make_fake_http_client(|_req| {
2816                Box::pin(async {
2817                    json_response(
2818                        200,
2819                        r#"{
2820                            "access_token": "new-access",
2821                            "expires_in": 900
2822                        }"#,
2823                    )
2824                })
2825            });
2826
2827            let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2828
2829            let refreshed = provider.try_refresh().await.unwrap();
2830            assert!(refreshed);
2831
2832            let notified_session = rx.try_recv().expect("channel should have a session");
2833            assert_eq!(notified_session.tokens.access_token, "new-access");
2834            assert_eq!(
2835                notified_session.tokens.refresh_token.as_deref(),
2836                Some("original-refresh"),
2837            );
2838        });
2839    }
2840
2841    #[test]
2842    fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2843        smol::block_on(async {
2844            let session = make_test_session("old-access", Some("my-refresh"), None);
2845
2846            let http_client = make_fake_http_client(|_req| {
2847                Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2848            });
2849
2850            let provider = McpOAuthTokenProvider::new(session, http_client, None);
2851
2852            let refreshed = provider.try_refresh().await.unwrap();
2853            assert!(!refreshed);
2854            // The old token should still be in place.
2855            assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2856        });
2857    }
2858}