1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
/// The CIMD (Client ID Metadata Document) URL where Zed's OAuth client
/// metadata document is hosted.
///
/// When an authorization server advertises
/// `client_id_metadata_document_supported`, this URL is used verbatim as the
/// OAuth `client_id` (see `determine_registration_strategy`).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource identifier. `fetch_protected_resource_metadata`
    /// falls back to the MCP server URL when the document omits it.
    pub resource: Url,
    /// Authorization servers for this resource. `fetch_protected_resource_metadata`
    /// rejects documents where this is empty, and `discover` uses the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource advertises. `select_scopes` uses this as a fallback
    /// when the `WWW-Authenticate` challenge carries no `scope`.
    pub scopes_supported: Option<Vec<String>>,
}
139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier; `fetch_auth_server_metadata` requires it to equal
    /// the issuer URL that was queried.
    pub issuer: Url,
    /// Where the user is sent to authorize (see `build_authorization_url`).
    pub authorization_endpoint: Url,
    /// Where authorization codes and refresh tokens are exchanged for tokens.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, used when CIMD is unsupported.
    pub registration_endpoint: Option<Url>,
    /// Scopes advertised by the authorization server.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods; `discover` rejects servers that don't list `S256`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as
    /// `client_id`. Treated as `false` when the metadata omits it.
    pub client_id_metadata_document_supported: bool,
}
152
153/// The result of client registration — either CIMD or DCR.
154#[derive(Clone, Serialize, Deserialize)]
155pub struct OAuthClientRegistration {
156 pub client_id: String,
157 /// Only present for DCR-minted registrations.
158 pub client_secret: Option<String>,
159}
160
161impl std::fmt::Debug for OAuthClientRegistration {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("OAuthClientRegistration")
164 .field("client_id", &self.client_id)
165 .field(
166 "client_secret",
167 &self.client_secret.as_ref().map(|_| "[redacted]"),
168 )
169 .finish()
170 }
171}
172
173/// Access and refresh tokens obtained from the token endpoint.
174#[derive(Clone, Serialize, Deserialize)]
175pub struct OAuthTokens {
176 pub access_token: String,
177 pub refresh_token: Option<String>,
178 pub expires_at: Option<SystemTime>,
179}
180
181impl std::fmt::Debug for OAuthTokens {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("OAuthTokens")
184 .field("access_token", &"[redacted]")
185 .field(
186 "refresh_token",
187 &self.refresh_token.as_ref().map(|_| "[redacted]"),
188 )
189 .field("expires_at", &self.expires_at)
190 .finish()
191 }
192}
193
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched from the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes chosen by `select_scopes` (challenge scopes first, then
    /// `scopes_supported`); may be empty.
    pub scopes: Vec<String>,
}
202
203/// The persisted OAuth session for a context server.
204///
205/// Stored in the keychain so startup can restore a refresh-capable provider
206/// without another browser flow. Deliberately excludes the full discovery
207/// metadata to keep the serialized size well within keychain item limits.
208#[derive(Clone, Serialize, Deserialize)]
209pub struct OAuthSession {
210 pub token_endpoint: Url,
211 pub resource: Url,
212 pub client_registration: OAuthClientRegistration,
213 pub tokens: OAuthTokens,
214}
215
216impl std::fmt::Debug for OAuthSession {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("OAuthSession")
219 .field("token_endpoint", &self.token_endpoint)
220 .field("resource", &self.resource)
221 .field("client_registration", &self.client_registration)
222 .field("tokens", &self.tokens)
223 .finish()
224 }
225}
226
/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}

impl BearerError {
    /// Map the raw `error` attribute of a Bearer challenge onto a variant.
    ///
    /// Matching is exact and case-sensitive; any unknown code becomes
    /// [`BearerError::Other`].
    fn parse(code: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];
        KNOWN_CODES
            .iter()
            .find(|(name, _)| *name == code)
            .map_or(Self::Other, |&(_, variant)| variant)
    }
}
251
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// Parsed `resource_metadata` URL, if the header carried one.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace, if present.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The raw `error_description` parameter, if present.
    pub error_description: Option<String>,
}
265
266/// Parse a `WWW-Authenticate` header value.
267///
268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
270/// - `resource_metadata` — URL of the Protected Resource Metadata document
271/// - `scope` — space-separated list of required scopes
272/// - `error` — error code (e.g. "insufficient_scope")
273/// - `error_description` — human-readable error description
274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
275 let header = header.trim();
276
277 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
278 header[6..].trim()
279 } else {
280 bail!("WWW-Authenticate header does not use Bearer scheme");
281 };
282
283 if params_str.is_empty() {
284 return Ok(WwwAuthenticate {
285 resource_metadata: None,
286 scope: None,
287 error: None,
288 error_description: None,
289 });
290 }
291
292 let params = parse_auth_params(params_str);
293
294 let resource_metadata = params
295 .get("resource_metadata")
296 .map(|v| Url::parse(v))
297 .transpose()
298 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
299
300 let scope = params
301 .get("scope")
302 .map(|v| v.split_whitespace().map(String::from).collect());
303
304 let error = params.get("error").map(|v| BearerError::parse(v));
305 let error_description = params.get("error_description").cloned();
306
307 Ok(WwwAuthenticate {
308 resource_metadata,
309 scope,
310 error,
311 error_description,
312 })
313}
314
315/// Parse comma-separated `key="value"` or `key=token` parameters from an
316/// auth-param list (RFC 7235 Section 2.1).
317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
318 let mut params = collections::HashMap::default();
319 let mut remaining = input.trim();
320
321 while !remaining.is_empty() {
322 // Skip leading whitespace and commas.
323 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
324 if remaining.is_empty() {
325 break;
326 }
327
328 // Find the key (everything before '=').
329 let eq_pos = match remaining.find('=') {
330 Some(pos) => pos,
331 None => break,
332 };
333
334 let key = remaining[..eq_pos].trim().to_lowercase();
335 remaining = &remaining[eq_pos + 1..];
336 remaining = remaining.trim_start();
337
338 // Parse the value: either quoted or unquoted (token).
339 let value;
340 if remaining.starts_with('"') {
341 // Quoted string: find the closing quote, handling escaped chars.
342 remaining = &remaining[1..]; // skip opening quote
343 let mut val = String::new();
344 let mut chars = remaining.char_indices();
345 loop {
346 match chars.next() {
347 Some((_, '\\')) => {
348 // Escaped character — take the next char literally.
349 if let Some((_, c)) = chars.next() {
350 val.push(c);
351 }
352 }
353 Some((i, '"')) => {
354 remaining = &remaining[i + 1..];
355 break;
356 }
357 Some((_, c)) => val.push(c),
358 None => {
359 remaining = "";
360 break;
361 }
362 }
363 }
364 value = val;
365 } else {
366 // Unquoted token: read until comma or whitespace.
367 let end = remaining
368 .find(|c: char| c == ',' || c.is_whitespace())
369 .unwrap_or(remaining.len());
370 value = remaining[..end].to_string();
371 remaining = &remaining[end..];
372 }
373
374 if !key.is_empty() {
375 params.insert(key, value);
376 }
377 }
378
379 params
380}
381
382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
383/// server URL, per RFC 9728 Section 3.
384///
385/// Returns URIs in priority order:
386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
389 let mut urls = Vec::new();
390 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
391
392 let path = server_url.path().trim_start_matches('/');
393 if !path.is_empty() {
394 if let Ok(url) = Url::parse(&format!(
395 "{}/.well-known/oauth-protected-resource/{}",
396 base, path
397 )) {
398 urls.push(url);
399 }
400 }
401
402 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
403 urls.push(url);
404 }
405
406 urls
407}
408
409/// Construct the well-known Authorization Server Metadata URIs for a given
410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
411///
412/// Returns URIs in priority order, which differs depending on whether the
413/// issuer URL has a path component.
414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
415 let mut urls = Vec::new();
416 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
417 let path = issuer.path().trim_matches('/');
418
419 if !path.is_empty() {
420 // Issuer with path: try path-inserted variants first.
421 if let Ok(url) = Url::parse(&format!(
422 "{}/.well-known/oauth-authorization-server/{}",
423 base, path
424 )) {
425 urls.push(url);
426 }
427 if let Ok(url) = Url::parse(&format!(
428 "{}/.well-known/openid-configuration/{}",
429 base, path
430 )) {
431 urls.push(url);
432 }
433 if let Ok(url) = Url::parse(&format!(
434 "{}/{}/.well-known/openid-configuration",
435 base, path
436 )) {
437 urls.push(url);
438 }
439 } else {
440 // No path: standard well-known locations.
441 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
442 urls.push(url);
443 }
444 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
445 urls.push(url);
446 }
447 }
448
449 urls
450}
451
452// -- Canonical server URI (RFC 8707) -----------------------------------------
453
454/// Derive the canonical resource URI for an MCP server URL, suitable for the
455/// `resource` parameter in authorization and token requests per RFC 8707.
456///
457/// Lowercases the scheme and host, preserves the path (without trailing slash),
458/// strips fragments and query strings.
459pub fn canonical_server_uri(server_url: &Url) -> String {
460 let mut uri = format!(
461 "{}://{}",
462 server_url.scheme().to_ascii_lowercase(),
463 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
464 );
465 if let Some(port) = server_url.port() {
466 uri.push_str(&format!(":{}", port));
467 }
468 let path = server_url.path();
469 if path != "/" {
470 uri.push_str(path.trim_end_matches('/'));
471 }
472 uri
473}
474
475// -- Scope selection ---------------------------------------------------------
476
477/// Select scopes following the MCP spec's Scope Selection Strategy:
478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
480/// 3. Return empty if neither is available.
481pub fn select_scopes(
482 www_authenticate: &WwwAuthenticate,
483 resource_metadata: &ProtectedResourceMetadata,
484) -> Vec<String> {
485 if let Some(ref scopes) = www_authenticate.scope {
486 if !scopes.is_empty() {
487 return scopes.clone();
488 }
489 }
490 resource_metadata
491 .scopes_supported
492 .clone()
493 .unwrap_or_default()
494}
495
496// -- Client registration strategy --------------------------------------------
497
/// The registration approach to use, determined from auth server metadata by
/// `determine_registration_strategy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly;
    /// no registration request is made.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it
    /// (see `perform_dcr`) to obtain a client_id.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism; `discover` rejects such servers.
    Unavailable,
}
508
509/// Determine how to register with the authorization server, following the
510/// spec's recommended priority: CIMD first, DCR fallback.
511pub fn determine_registration_strategy(
512 auth_server_metadata: &AuthServerMetadata,
513) -> ClientRegistrationStrategy {
514 if auth_server_metadata.client_id_metadata_document_supported {
515 ClientRegistrationStrategy::Cimd {
516 client_id: CIMD_URL.to_string(),
517 }
518 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
519 ClientRegistrationStrategy::Dcr {
520 registration_endpoint: endpoint.clone(),
521 }
522 } else {
523 ClientRegistrationStrategy::Unavailable
524 }
525}
526
527// -- PKCE (RFC 7636) ---------------------------------------------------------
528
/// A PKCE code verifier together with its S256 challenge.
///
/// The `verifier` is secret until the token exchange; `Debug` hides it.
#[derive(Clone)]
pub struct PkceChallenge {
    pub verifier: String,
    pub challenge: String,
}

impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PkceChallenge");
        out.field("verifier", &"[redacted]");
        // The challenge is already public (it is sent in the authorization URL).
        out.field("challenge", &self.challenge);
        out.finish()
    }
}
544
545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
546///
547/// The verifier is 43 base64url characters derived from 32 random bytes.
548/// The challenge is `BASE64URL(SHA256(verifier))`.
549pub fn generate_pkce_challenge() -> PkceChallenge {
550 let mut random_bytes = [0u8; 32];
551 rand::rng().fill(&mut random_bytes);
552 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
553 let verifier = engine.encode(&random_bytes);
554
555 let digest = Sha256::digest(verifier.as_bytes());
556 let challenge = engine.encode(digest);
557
558 PkceChallenge {
559 verifier,
560 challenge,
561 }
562}
563
564// -- Authorization URL construction ------------------------------------------
565
566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
567pub fn build_authorization_url(
568 auth_server_metadata: &AuthServerMetadata,
569 client_id: &str,
570 redirect_uri: &str,
571 scopes: &[String],
572 resource: &str,
573 pkce: &PkceChallenge,
574 state: &str,
575) -> Url {
576 let mut url = auth_server_metadata.authorization_endpoint.clone();
577 {
578 let mut query = url.query_pairs_mut();
579 query.append_pair("response_type", "code");
580 query.append_pair("client_id", client_id);
581 query.append_pair("redirect_uri", redirect_uri);
582 if !scopes.is_empty() {
583 query.append_pair("scope", &scopes.join(" "));
584 }
585 query.append_pair("resource", resource);
586 query.append_pair("code_challenge", &pkce.challenge);
587 query.append_pair("code_challenge_method", "S256");
588 query.append_pair("state", state);
589 }
590 url
591}
592
593// -- Token endpoint request bodies -------------------------------------------
594
595/// The JSON body returned by the token endpoint on success.
596#[derive(Deserialize)]
597pub struct TokenResponse {
598 pub access_token: String,
599 #[serde(default)]
600 pub refresh_token: Option<String>,
601 #[serde(default)]
602 pub expires_in: Option<u64>,
603 #[serde(default)]
604 pub token_type: Option<String>,
605}
606
607impl std::fmt::Debug for TokenResponse {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("TokenResponse")
610 .field("access_token", &"[redacted]")
611 .field(
612 "refresh_token",
613 &self.refresh_token.as_ref().map(|_| "[redacted]"),
614 )
615 .field("expires_in", &self.expires_in)
616 .field("token_type", &self.token_type)
617 .finish()
618 }
619}
620
621impl TokenResponse {
622 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
623 pub fn into_tokens(self) -> OAuthTokens {
624 let expires_at = self
625 .expires_in
626 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
627 OAuthTokens {
628 access_token: self.access_token,
629 refresh_token: self.refresh_token,
630 expires_at,
631 }
632 }
633}
634
/// Build the form parameters for an authorization-code token exchange
/// (`grant_type=authorization_code`) with PKCE and an RFC 8707 `resource`.
///
/// Pairs are returned in a fixed order: `grant_type`, `code`, `redirect_uri`,
/// `client_id`, `code_verifier`, `resource`.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ]
    .iter()
    .map(|&(name, value)| (name, value.to_owned()))
    .collect()
}
652
/// Build the form parameters for a refresh-token grant.
///
/// Pairs are returned in a fixed order: `grant_type`, `refresh_token`,
/// `client_id`, `resource`.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let mut params = Vec::with_capacity(4);
    params.push(("grant_type", String::from("refresh_token")));
    params.push(("refresh_token", refresh_token.to_owned()));
    params.push(("client_id", client_id.to_owned()));
    params.push(("resource", resource.to_owned()));
    params
}
666
667// -- DCR request body (RFC 7591) ---------------------------------------------
668
669/// Build the JSON body for a Dynamic Client Registration request.
670///
671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
673/// redirect URI matching even for loopback addresses, so we register the
674/// exact URI we intend to use.
675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
676 serde_json::json!({
677 "client_name": "Zed",
678 "redirect_uris": [redirect_uri],
679 "grant_types": ["authorization_code"],
680 "response_types": ["code"],
681 "token_endpoint_auth_method": "none"
682 })
683}
684
685// -- Discovery (async, hits real endpoints) ----------------------------------
686
687/// Fetch Protected Resource Metadata from the MCP server.
688///
689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
690/// then falls back to well-known URIs constructed from `server_url`.
691pub async fn fetch_protected_resource_metadata(
692 http_client: &Arc<dyn HttpClient>,
693 server_url: &Url,
694 www_authenticate: &WwwAuthenticate,
695) -> Result<ProtectedResourceMetadata> {
696 let candidate_urls = match &www_authenticate.resource_metadata {
697 Some(url) if url.origin() == server_url.origin() => {
698 // Try the header-provided URL first (per MCP spec: "use the resource
699 // metadata URL from the parsed WWW-Authenticate headers when present"),
700 // then fall back to RFC 9728 well-known URIs in case the header URL is
701 // wrong (e.g. a buggy server that doubles the path component).
702 let mut urls = vec![url.clone()];
703 for fallback in protected_resource_metadata_urls(server_url) {
704 if !urls.contains(&fallback) {
705 urls.push(fallback);
706 }
707 }
708 urls
709 }
710 Some(url) => {
711 log::warn!(
712 "Ignoring cross-origin resource_metadata URL {} \
713 (server origin: {})",
714 url,
715 server_url.origin().unicode_serialization()
716 );
717 protected_resource_metadata_urls(server_url)
718 }
719 None => protected_resource_metadata_urls(server_url),
720 };
721
722 for url in &candidate_urls {
723 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
724 Ok(response) => {
725 if response.authorization_servers.is_empty() {
726 bail!(
727 "Protected Resource Metadata at {} has no authorization_servers",
728 url
729 );
730 }
731 return Ok(ProtectedResourceMetadata {
732 resource: response.resource.unwrap_or_else(|| server_url.clone()),
733 authorization_servers: response.authorization_servers,
734 scopes_supported: response.scopes_supported,
735 });
736 }
737 Err(err) => {
738 log::debug!(
739 "Failed to fetch Protected Resource Metadata from {}: {}",
740 url,
741 err
742 );
743 }
744 }
745 }
746
747 bail!(
748 "Could not fetch Protected Resource Metadata for {}",
749 server_url
750 )
751}
752
753/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
754/// endpoints in the priority order specified by the MCP spec.
755pub async fn fetch_auth_server_metadata(
756 http_client: &Arc<dyn HttpClient>,
757 issuer: &Url,
758) -> Result<AuthServerMetadata> {
759 let candidate_urls = auth_server_metadata_urls(issuer);
760
761 for url in &candidate_urls {
762 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
763 Ok(response) => {
764 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
765 if reported_issuer != *issuer {
766 bail!(
767 "Auth server metadata issuer mismatch: expected {}, got {}",
768 issuer,
769 reported_issuer
770 );
771 }
772
773 return Ok(AuthServerMetadata {
774 issuer: reported_issuer,
775 authorization_endpoint: response
776 .authorization_endpoint
777 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
778 token_endpoint: response
779 .token_endpoint
780 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
781 registration_endpoint: response.registration_endpoint,
782 scopes_supported: response.scopes_supported,
783 code_challenge_methods_supported: response.code_challenge_methods_supported,
784 client_id_metadata_document_supported: response
785 .client_id_metadata_document_supported
786 .unwrap_or(false),
787 });
788 }
789 Err(err) => {
790 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
791 }
792 }
793 }
794
795 bail!(
796 "Could not fetch Authorization Server Metadata for {}",
797 issuer
798 )
799}
800
801/// Run the full discovery flow: fetch resource metadata, then auth server
802/// metadata, then select scopes. Client registration is resolved separately,
803/// once the real redirect URI is known.
804pub async fn discover(
805 http_client: &Arc<dyn HttpClient>,
806 server_url: &Url,
807 www_authenticate: &WwwAuthenticate,
808) -> Result<OAuthDiscovery> {
809 let resource_metadata =
810 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
811
812 let auth_server_url = resource_metadata
813 .authorization_servers
814 .first()
815 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
816
817 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
818
819 // Verify PKCE S256 support (spec requirement).
820 match &auth_server_metadata.code_challenge_methods_supported {
821 Some(methods) if methods.iter().any(|m| m == "S256") => {}
822 Some(_) => bail!("authorization server does not support S256 PKCE"),
823 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
824 }
825
826 // Verify there is at least one supported registration strategy before we
827 // present the server as ready to authenticate.
828 match determine_registration_strategy(&auth_server_metadata) {
829 ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
830 ClientRegistrationStrategy::Unavailable => {
831 bail!("authorization server supports neither CIMD nor DCR")
832 }
833 }
834
835 let scopes = select_scopes(www_authenticate, &resource_metadata);
836
837 Ok(OAuthDiscovery {
838 resource_metadata,
839 auth_server_metadata,
840 scopes,
841 })
842}
843
844/// Resolve the OAuth client registration for an authorization flow.
845///
846/// CIMD uses the static client metadata document directly. For DCR, a fresh
847/// registration is performed each time because the loopback redirect URI
848/// includes an ephemeral port that changes every flow.
849pub async fn resolve_client_registration(
850 http_client: &Arc<dyn HttpClient>,
851 discovery: &OAuthDiscovery,
852 redirect_uri: &str,
853) -> Result<OAuthClientRegistration> {
854 match determine_registration_strategy(&discovery.auth_server_metadata) {
855 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
856 client_id,
857 client_secret: None,
858 }),
859 ClientRegistrationStrategy::Dcr {
860 registration_endpoint,
861 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
862 ClientRegistrationStrategy::Unavailable => {
863 bail!("authorization server supports neither CIMD nor DCR")
864 }
865 }
866}
867
868// -- Dynamic Client Registration (RFC 7591) ----------------------------------
869
870/// Perform Dynamic Client Registration with the authorization server.
871pub async fn perform_dcr(
872 http_client: &Arc<dyn HttpClient>,
873 registration_endpoint: &Url,
874 redirect_uri: &str,
875) -> Result<OAuthClientRegistration> {
876 validate_oauth_url(registration_endpoint)?;
877
878 let body = dcr_registration_body(redirect_uri);
879 let body_bytes = serde_json::to_vec(&body)?;
880
881 let request = Request::builder()
882 .method(http_client::http::Method::POST)
883 .uri(registration_endpoint.as_str())
884 .header("Content-Type", "application/json")
885 .header("Accept", "application/json")
886 .body(AsyncBody::from(body_bytes))?;
887
888 let mut response = http_client.send(request).await?;
889
890 if !response.status().is_success() {
891 let mut error_body = String::new();
892 response.body_mut().read_to_string(&mut error_body).await?;
893 bail!(
894 "DCR failed with status {}: {}",
895 response.status(),
896 error_body
897 );
898 }
899
900 let mut response_body = String::new();
901 response
902 .body_mut()
903 .read_to_string(&mut response_body)
904 .await?;
905
906 let dcr_response: DcrResponse =
907 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
908
909 Ok(OAuthClientRegistration {
910 client_id: dcr_response.client_id,
911 client_secret: dcr_response.client_secret,
912 })
913}
914
915// -- Token exchange and refresh (async) --------------------------------------
916
917/// Exchange an authorization code for tokens at the token endpoint.
918pub async fn exchange_code(
919 http_client: &Arc<dyn HttpClient>,
920 auth_server_metadata: &AuthServerMetadata,
921 code: &str,
922 client_id: &str,
923 redirect_uri: &str,
924 code_verifier: &str,
925 resource: &str,
926) -> Result<OAuthTokens> {
927 let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
928 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
929}
930
931/// Refresh tokens using a refresh token.
932pub async fn refresh_tokens(
933 http_client: &Arc<dyn HttpClient>,
934 token_endpoint: &Url,
935 refresh_token: &str,
936 client_id: &str,
937 resource: &str,
938) -> Result<OAuthTokens> {
939 let params = token_refresh_params(refresh_token, client_id, resource);
940 post_token_request(http_client, token_endpoint, ¶ms).await
941}
942
943/// POST form-encoded parameters to a token endpoint and parse the response.
944async fn post_token_request(
945 http_client: &Arc<dyn HttpClient>,
946 token_endpoint: &Url,
947 params: &[(&str, String)],
948) -> Result<OAuthTokens> {
949 validate_oauth_url(token_endpoint)?;
950
951 let body = url::form_urlencoded::Serializer::new(String::new())
952 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
953 .finish();
954
955 let request = Request::builder()
956 .method(http_client::http::Method::POST)
957 .uri(token_endpoint.as_str())
958 .header("Content-Type", "application/x-www-form-urlencoded")
959 .header("Accept", "application/json")
960 .body(AsyncBody::from(body.into_bytes()))?;
961
962 let mut response = http_client.send(request).await?;
963
964 if !response.status().is_success() {
965 let mut error_body = String::new();
966 response.body_mut().read_to_string(&mut error_body).await?;
967 bail!(
968 "token request failed with status {}: {}",
969 response.status(),
970 error_body
971 );
972 }
973
974 let mut response_body = String::new();
975 response
976 .body_mut()
977 .read_to_string(&mut response_body)
978 .await?;
979
980 let token_response: TokenResponse =
981 serde_json::from_str(&response_body).context("failed to parse token response")?;
982
983 Ok(token_response.into_tokens())
984}
985
986// -- Loopback HTTP callback server -------------------------------------------
987
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// `Debug` is implemented by hand so that neither field's value is printed.
pub struct OAuthCallback {
    /// The authorization code returned by the authorization server, to be
    /// exchanged for tokens at the token endpoint.
    pub code: String,
    /// The `state` value echoed back in the redirect. It is not compared
    /// against anything when parsed; presumably the caller checks it against
    /// the value it sent in the authorization request (TODO confirm).
    pub state: String,
}
993
994impl std::fmt::Debug for OAuthCallback {
995 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
996 f.debug_struct("OAuthCallback")
997 .field("code", &"[redacted]")
998 .field("state", &"[redacted]")
999 .finish()
1000 }
1001}
1002
1003impl OAuthCallback {
1004 /// Parse the query string from a callback URL like
1005 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1006 pub fn parse_query(query: &str) -> Result<Self> {
1007 let mut code: Option<String> = None;
1008 let mut state: Option<String> = None;
1009 let mut error: Option<String> = None;
1010 let mut error_description: Option<String> = None;
1011
1012 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1013 match key.as_ref() {
1014 "code" => {
1015 if !value.is_empty() {
1016 code = Some(value.into_owned());
1017 }
1018 }
1019 "state" => {
1020 if !value.is_empty() {
1021 state = Some(value.into_owned());
1022 }
1023 }
1024 "error" => {
1025 if !value.is_empty() {
1026 error = Some(value.into_owned());
1027 }
1028 }
1029 "error_description" => {
1030 if !value.is_empty() {
1031 error_description = Some(value.into_owned());
1032 }
1033 }
1034 _ => {}
1035 }
1036 }
1037
1038 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1039 // checking for missing code/state.
1040 if let Some(error_code) = error {
1041 bail!(
1042 "OAuth authorization failed: {} ({})",
1043 error_code,
1044 error_description.as_deref().unwrap_or("no description")
1045 );
1046 }
1047
1048 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1049 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1050
1051 Ok(Self { code, state })
1052 }
1053}
1054
1055/// How long to wait for the browser to complete the OAuth flow before giving
1056/// up and releasing the loopback port.
1057const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
1058
1059/// Start a loopback HTTP server to receive the OAuth authorization callback.
1060///
1061/// Binds to an ephemeral loopback port for each flow.
1062///
1063/// Returns `(redirect_uri, callback_future)`. The caller should use the
1064/// redirect URI in the authorization request, open the browser, then await
1065/// the future to receive the callback.
1066///
1067/// The server accepts exactly one request on `/callback`, validates that it
1068/// contains `code` and `state` query parameters, responds with a minimal
1069/// HTML page telling the user they can close the tab, and shuts down.
1070///
1071/// The callback server shuts down when the returned oneshot receiver is dropped
1072/// (e.g. because the authentication task was cancelled), or after a timeout
1073/// ([CALLBACK_TIMEOUT]).
1074pub async fn start_callback_server() -> Result<(
1075 String,
1076 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1077)> {
1078 let server = tiny_http::Server::http("127.0.0.1:0")
1079 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1080 let port = server
1081 .server_addr()
1082 .to_ip()
1083 .context("server not bound to a TCP address")?
1084 .port();
1085
1086 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1087
1088 let (tx, rx) = futures::channel::oneshot::channel();
1089
1090 // `tiny_http` is blocking, so we run it on a background thread.
1091 // The `recv_timeout` loop lets us check for cancellation (the receiver
1092 // being dropped) and enforce an overall timeout.
1093 std::thread::spawn(move || {
1094 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1095
1096 loop {
1097 if tx.is_canceled() {
1098 return;
1099 }
1100 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1101 if remaining.is_zero() {
1102 return;
1103 }
1104
1105 let timeout = remaining.min(Duration::from_millis(500));
1106 let Some(request) = (match server.recv_timeout(timeout) {
1107 Ok(req) => req,
1108 Err(_) => {
1109 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1110 return;
1111 }
1112 }) else {
1113 // Timeout with no request — loop back and check cancellation.
1114 continue;
1115 };
1116
1117 let result = handle_callback_request(&request);
1118
1119 let (status_code, body) = match &result {
1120 Ok(_) => (
1121 200,
1122 "<html><body><h1>Authorization successful</h1>\
1123 <p>You can close this tab and return to Zed.</p></body></html>",
1124 ),
1125 Err(err) => {
1126 log::error!("OAuth callback error: {}", err);
1127 (
1128 400,
1129 "<html><body><h1>Authorization failed</h1>\
1130 <p>Something went wrong. Please try again from Zed.</p></body></html>",
1131 )
1132 }
1133 };
1134
1135 let response = tiny_http::Response::from_string(body)
1136 .with_status_code(status_code)
1137 .with_header(
1138 tiny_http::Header::from_str("Content-Type: text/html")
1139 .expect("failed to construct response header"),
1140 )
1141 .with_header(
1142 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1143 .expect("failed to construct response header"),
1144 );
1145 request.respond(response).log_err();
1146
1147 let _ = tx.send(result);
1148 return;
1149 }
1150 });
1151
1152 Ok((redirect_uri, rx))
1153}
1154
1155/// Extract the `code` and `state` query parameters from an OAuth callback
1156/// request to `/callback`.
1157fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1158 let url = Url::parse(&format!("http://localhost{}", request.url()))
1159 .context("malformed callback request URL")?;
1160
1161 if url.path() != "/callback" {
1162 bail!("unexpected path in OAuth callback: {}", url.path());
1163 }
1164
1165 let query = url
1166 .query()
1167 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1168 OAuthCallback::parse_query(query)
1169}
1170
1171// -- JSON fetch helper -------------------------------------------------------
1172
1173async fn fetch_json<T: serde::de::DeserializeOwned>(
1174 http_client: &Arc<dyn HttpClient>,
1175 url: &Url,
1176) -> Result<T> {
1177 validate_oauth_url(url)?;
1178
1179 let request = Request::builder()
1180 .method(http_client::http::Method::GET)
1181 .uri(url.as_str())
1182 .header("Accept", "application/json")
1183 .body(AsyncBody::default())?;
1184
1185 let mut response = http_client.send(request).await?;
1186
1187 if !response.status().is_success() {
1188 bail!("HTTP {} fetching {}", response.status(), url);
1189 }
1190
1191 let mut body = String::new();
1192 response.body_mut().read_to_string(&mut body).await?;
1193 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1194}
1195
1196// -- Serde response types for discovery --------------------------------------
1197
1198#[derive(Debug, Deserialize)]
1199struct ProtectedResourceMetadataResponse {
1200 #[serde(default)]
1201 resource: Option<Url>,
1202 #[serde(default)]
1203 authorization_servers: Vec<Url>,
1204 #[serde(default)]
1205 scopes_supported: Option<Vec<String>>,
1206}
1207
1208#[derive(Debug, Deserialize)]
1209struct AuthServerMetadataResponse {
1210 #[serde(default)]
1211 issuer: Option<Url>,
1212 #[serde(default)]
1213 authorization_endpoint: Option<Url>,
1214 #[serde(default)]
1215 token_endpoint: Option<Url>,
1216 #[serde(default)]
1217 registration_endpoint: Option<Url>,
1218 #[serde(default)]
1219 scopes_supported: Option<Vec<String>>,
1220 #[serde(default)]
1221 code_challenge_methods_supported: Option<Vec<String>>,
1222 #[serde(default)]
1223 client_id_metadata_document_supported: Option<bool>,
1224}
1225
1226#[derive(Debug, Deserialize)]
1227struct DcrResponse {
1228 client_id: String,
1229 #[serde(default)]
1230 client_secret: Option<String>,
1231}
1232
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementations must be `Send + Sync`; `try_refresh` is async (via
/// `async_trait`) while `access_token` is a plain synchronous call.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// `None` means no usable token is held right now. For example,
    /// `McpOAuthTokenProvider` returns `None` once its token is (nearly)
    /// expired.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token was obtained and the request should not
    /// be retried. `McpOAuthTokenProvider` also reports a failed refresh
    /// request as `Ok(false)` after logging it, rather than as `Err`.
    async fn try_refresh(&self) -> Result<bool>;
}
1246
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// The current session: tokens, token endpoint, resource, and client
    /// registration. It is guarded by a synchronous (`parking_lot`) mutex, so
    /// it must only be locked for short sections that are never held across
    /// an `.await`. `try_refresh` copies what it needs and releases the lock
    /// before the network call.
    session: SyncMutex<OAuthSession>,
    /// Client used to send refresh requests to the session's token endpoint.
    http_client: Arc<dyn HttpClient>,
    /// If set, receives a clone of the whole session after every successful
    /// refresh (presumably so the owner can re-persist it to the keychain;
    /// TODO confirm against the receiver).
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1256
1257impl McpOAuthTokenProvider {
1258 pub fn new(
1259 session: OAuthSession,
1260 http_client: Arc<dyn HttpClient>,
1261 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1262 ) -> Self {
1263 Self {
1264 session: SyncMutex::new(session),
1265 http_client,
1266 token_refresh_tx,
1267 }
1268 }
1269
1270 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1271 tokens.expires_at.is_some_and(|expires_at| {
1272 SystemTime::now()
1273 .checked_add(Duration::from_secs(30))
1274 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1275 })
1276 }
1277}
1278
1279#[async_trait]
1280impl OAuthTokenProvider for McpOAuthTokenProvider {
1281 fn access_token(&self) -> Option<String> {
1282 let session = self.session.lock();
1283 if Self::access_token_is_expired(&session.tokens) {
1284 return None;
1285 }
1286 Some(session.tokens.access_token.clone())
1287 }
1288
1289 async fn try_refresh(&self) -> Result<bool> {
1290 let (refresh_token, token_endpoint, resource, client_id) = {
1291 let session = self.session.lock();
1292 match session.tokens.refresh_token.clone() {
1293 Some(refresh_token) => (
1294 refresh_token,
1295 session.token_endpoint.clone(),
1296 session.resource.clone(),
1297 session.client_registration.client_id.clone(),
1298 ),
1299 None => return Ok(false),
1300 }
1301 };
1302
1303 let resource_str = canonical_server_uri(&resource);
1304
1305 match refresh_tokens(
1306 &self.http_client,
1307 &token_endpoint,
1308 &refresh_token,
1309 &client_id,
1310 &resource_str,
1311 )
1312 .await
1313 {
1314 Ok(mut new_tokens) => {
1315 if new_tokens.refresh_token.is_none() {
1316 new_tokens.refresh_token = Some(refresh_token);
1317 }
1318
1319 {
1320 let mut session = self.session.lock();
1321 session.tokens = new_tokens;
1322
1323 if let Some(ref tx) = self.token_refresh_tx {
1324 tx.unbounded_send(session.clone()).ok();
1325 }
1326 }
1327
1328 Ok(true)
1329 }
1330 Err(err) => {
1331 log::warn!("OAuth token refresh failed: {}", err);
1332 Ok(false)
1333 }
1334 }
1335 }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340 use super::*;
1341 use http_client::Response;
1342
1343 // -- require_https_or_loopback tests ------------------------------------
1344
1345 #[test]
1346 fn test_require_https_or_loopback_accepts_https() {
1347 let url = Url::parse("https://auth.example.com/token").unwrap();
1348 assert!(require_https_or_loopback(&url).is_ok());
1349 }
1350
1351 #[test]
1352 fn test_require_https_or_loopback_rejects_http_remote() {
1353 let url = Url::parse("http://auth.example.com/token").unwrap();
1354 assert!(require_https_or_loopback(&url).is_err());
1355 }
1356
1357 #[test]
1358 fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1359 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1360 assert!(require_https_or_loopback(&url).is_ok());
1361 }
1362
1363 #[test]
1364 fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1365 let url = Url::parse("http://[::1]:8080/callback").unwrap();
1366 assert!(require_https_or_loopback(&url).is_ok());
1367 }
1368
1369 #[test]
1370 fn test_require_https_or_loopback_accepts_http_localhost() {
1371 let url = Url::parse("http://localhost:8080/callback").unwrap();
1372 assert!(require_https_or_loopback(&url).is_ok());
1373 }
1374
1375 #[test]
1376 fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1377 let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1378 assert!(require_https_or_loopback(&url).is_ok());
1379 }
1380
1381 #[test]
1382 fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1383 let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1384 assert!(require_https_or_loopback(&url).is_err());
1385 }
1386
1387 #[test]
1388 fn test_require_https_or_loopback_rejects_ftp() {
1389 let url = Url::parse("ftp://auth.example.com/token").unwrap();
1390 assert!(require_https_or_loopback(&url).is_err());
1391 }
1392
1393 // -- validate_oauth_url (SSRF) tests ------------------------------------
1394
1395 #[test]
1396 fn test_validate_oauth_url_accepts_https_public() {
1397 let url = Url::parse("https://auth.example.com/token").unwrap();
1398 assert!(validate_oauth_url(&url).is_ok());
1399 }
1400
1401 #[test]
1402 fn test_validate_oauth_url_rejects_private_ipv4_10() {
1403 let url = Url::parse("https://10.0.0.1/token").unwrap();
1404 assert!(validate_oauth_url(&url).is_err());
1405 }
1406
1407 #[test]
1408 fn test_validate_oauth_url_rejects_private_ipv4_172() {
1409 let url = Url::parse("https://172.16.0.1/token").unwrap();
1410 assert!(validate_oauth_url(&url).is_err());
1411 }
1412
1413 #[test]
1414 fn test_validate_oauth_url_rejects_private_ipv4_192() {
1415 let url = Url::parse("https://192.168.1.1/token").unwrap();
1416 assert!(validate_oauth_url(&url).is_err());
1417 }
1418
1419 #[test]
1420 fn test_validate_oauth_url_rejects_link_local() {
1421 let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1422 assert!(validate_oauth_url(&url).is_err());
1423 }
1424
1425 #[test]
1426 fn test_validate_oauth_url_rejects_ipv6_ula() {
1427 let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1428 assert!(validate_oauth_url(&url).is_err());
1429 }
1430
1431 #[test]
1432 fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1433 let url = Url::parse("https://[::]/token").unwrap();
1434 assert!(validate_oauth_url(&url).is_err());
1435 }
1436
1437 #[test]
1438 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1439 let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1440 assert!(validate_oauth_url(&url).is_err());
1441 }
1442
1443 #[test]
1444 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1445 let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1446 assert!(validate_oauth_url(&url).is_err());
1447 }
1448
1449 #[test]
1450 fn test_validate_oauth_url_allows_http_loopback() {
1451 // Loopback is permitted (it's our callback server).
1452 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1453 assert!(validate_oauth_url(&url).is_ok());
1454 }
1455
1456 #[test]
1457 fn test_validate_oauth_url_allows_https_public_ip() {
1458 let url = Url::parse("https://93.184.216.34/token").unwrap();
1459 assert!(validate_oauth_url(&url).is_ok());
1460 }
1461
1462 // -- parse_www_authenticate tests ----------------------------------------
1463
1464 #[test]
1465 fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1466 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1467 let result = parse_www_authenticate(header).unwrap();
1468
1469 assert_eq!(
1470 result.resource_metadata.as_ref().map(|u| u.as_str()),
1471 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1472 );
1473 assert_eq!(
1474 result.scope,
1475 Some(vec!["files:read".to_string(), "user:profile".to_string()])
1476 );
1477 assert_eq!(result.error, None);
1478 }
1479
1480 #[test]
1481 fn test_parse_www_authenticate_resource_metadata_only() {
1482 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1483 let result = parse_www_authenticate(header).unwrap();
1484
1485 assert_eq!(
1486 result.resource_metadata.as_ref().map(|u| u.as_str()),
1487 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1488 );
1489 assert_eq!(result.scope, None);
1490 }
1491
1492 #[test]
1493 fn test_parse_www_authenticate_bare_bearer() {
1494 let result = parse_www_authenticate("Bearer").unwrap();
1495 assert_eq!(result.resource_metadata, None);
1496 assert_eq!(result.scope, None);
1497 }
1498
1499 #[test]
1500 fn test_parse_www_authenticate_with_error() {
1501 let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1502 let result = parse_www_authenticate(header).unwrap();
1503
1504 assert_eq!(result.error, Some(BearerError::InsufficientScope));
1505 assert_eq!(
1506 result.error_description.as_deref(),
1507 Some("Additional file write permission required")
1508 );
1509 assert_eq!(
1510 result.scope,
1511 Some(vec!["files:read".to_string(), "files:write".to_string()])
1512 );
1513 assert!(result.resource_metadata.is_some());
1514 }
1515
1516 #[test]
1517 fn test_parse_www_authenticate_invalid_token_error() {
1518 let header =
1519 r#"Bearer error="invalid_token", error_description="The access token expired""#;
1520 let result = parse_www_authenticate(header).unwrap();
1521 assert_eq!(result.error, Some(BearerError::InvalidToken));
1522 }
1523
1524 #[test]
1525 fn test_parse_www_authenticate_invalid_request_error() {
1526 let header = r#"Bearer error="invalid_request""#;
1527 let result = parse_www_authenticate(header).unwrap();
1528 assert_eq!(result.error, Some(BearerError::InvalidRequest));
1529 }
1530
1531 #[test]
1532 fn test_parse_www_authenticate_unknown_error() {
1533 let header = r#"Bearer error="some_future_error""#;
1534 let result = parse_www_authenticate(header).unwrap();
1535 assert_eq!(result.error, Some(BearerError::Other));
1536 }
1537
1538 #[test]
1539 fn test_parse_www_authenticate_rejects_non_bearer() {
1540 let result = parse_www_authenticate("Basic realm=\"example\"");
1541 assert!(result.is_err());
1542 }
1543
1544 #[test]
1545 fn test_parse_www_authenticate_case_insensitive_scheme() {
1546 let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1547 let result = parse_www_authenticate(header).unwrap();
1548 assert!(result.resource_metadata.is_some());
1549 }
1550
1551 #[test]
1552 fn test_parse_www_authenticate_multiline_style() {
1553 // Some servers emit the header spread across multiple lines joined by
1554 // whitespace, as shown in the spec examples.
1555 let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
1556 let result = parse_www_authenticate(header).unwrap();
1557 assert!(result.resource_metadata.is_some());
1558 assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1559 }
1560
1561 #[test]
1562 fn test_protected_resource_metadata_urls_with_path() {
1563 let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1564 let urls = protected_resource_metadata_urls(&server_url);
1565
1566 assert_eq!(urls.len(), 2);
1567 assert_eq!(
1568 urls[0].as_str(),
1569 "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1570 );
1571 assert_eq!(
1572 urls[1].as_str(),
1573 "https://api.example.com/.well-known/oauth-protected-resource"
1574 );
1575 }
1576
1577 #[test]
1578 fn test_protected_resource_metadata_urls_without_path() {
1579 let server_url = Url::parse("https://mcp.example.com").unwrap();
1580 let urls = protected_resource_metadata_urls(&server_url);
1581
1582 assert_eq!(urls.len(), 1);
1583 assert_eq!(
1584 urls[0].as_str(),
1585 "https://mcp.example.com/.well-known/oauth-protected-resource"
1586 );
1587 }
1588
1589 #[test]
1590 fn test_auth_server_metadata_urls_with_path() {
1591 let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1592 let urls = auth_server_metadata_urls(&issuer);
1593
1594 assert_eq!(urls.len(), 3);
1595 assert_eq!(
1596 urls[0].as_str(),
1597 "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1598 );
1599 assert_eq!(
1600 urls[1].as_str(),
1601 "https://auth.example.com/.well-known/openid-configuration/tenant1"
1602 );
1603 assert_eq!(
1604 urls[2].as_str(),
1605 "https://auth.example.com/tenant1/.well-known/openid-configuration"
1606 );
1607 }
1608
1609 #[test]
1610 fn test_auth_server_metadata_urls_without_path() {
1611 let issuer = Url::parse("https://auth.example.com").unwrap();
1612 let urls = auth_server_metadata_urls(&issuer);
1613
1614 assert_eq!(urls.len(), 2);
1615 assert_eq!(
1616 urls[0].as_str(),
1617 "https://auth.example.com/.well-known/oauth-authorization-server"
1618 );
1619 assert_eq!(
1620 urls[1].as_str(),
1621 "https://auth.example.com/.well-known/openid-configuration"
1622 );
1623 }
1624
1625 // -- Canonical server URI tests ------------------------------------------
1626
1627 #[test]
1628 fn test_canonical_server_uri_simple() {
1629 let url = Url::parse("https://mcp.example.com").unwrap();
1630 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1631 }
1632
1633 #[test]
1634 fn test_canonical_server_uri_with_path() {
1635 let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1636 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1637 }
1638
1639 #[test]
1640 fn test_canonical_server_uri_strips_trailing_slash() {
1641 let url = Url::parse("https://mcp.example.com/").unwrap();
1642 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1643 }
1644
1645 #[test]
1646 fn test_canonical_server_uri_preserves_port() {
1647 let url = Url::parse("https://mcp.example.com:8443").unwrap();
1648 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1649 }
1650
1651 #[test]
1652 fn test_canonical_server_uri_lowercases() {
1653 let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1654 assert_eq!(
1655 canonical_server_uri(&url),
1656 "https://mcp.example.com/Server/MCP"
1657 );
1658 }
1659
1660 // -- Scope selection tests -----------------------------------------------
1661
1662 #[test]
1663 fn test_select_scopes_prefers_www_authenticate() {
1664 let www_auth = WwwAuthenticate {
1665 resource_metadata: None,
1666 scope: Some(vec!["files:read".into()]),
1667 error: None,
1668 error_description: None,
1669 };
1670 let resource_meta = ProtectedResourceMetadata {
1671 resource: Url::parse("https://example.com").unwrap(),
1672 authorization_servers: vec![],
1673 scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1674 };
1675 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1676 }
1677
1678 #[test]
1679 fn test_select_scopes_falls_back_to_resource_metadata() {
1680 let www_auth = WwwAuthenticate {
1681 resource_metadata: None,
1682 scope: None,
1683 error: None,
1684 error_description: None,
1685 };
1686 let resource_meta = ProtectedResourceMetadata {
1687 resource: Url::parse("https://example.com").unwrap(),
1688 authorization_servers: vec![],
1689 scopes_supported: Some(vec!["admin".into()]),
1690 };
1691 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1692 }
1693
1694 #[test]
1695 fn test_select_scopes_empty_when_nothing_available() {
1696 let www_auth = WwwAuthenticate {
1697 resource_metadata: None,
1698 scope: None,
1699 error: None,
1700 error_description: None,
1701 };
1702 let resource_meta = ProtectedResourceMetadata {
1703 resource: Url::parse("https://example.com").unwrap(),
1704 authorization_servers: vec![],
1705 scopes_supported: None,
1706 };
1707 assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1708 }
1709
1710 // -- Client registration strategy tests ----------------------------------
1711
1712 #[test]
1713 fn test_registration_strategy_prefers_cimd() {
1714 let metadata = AuthServerMetadata {
1715 issuer: Url::parse("https://auth.example.com").unwrap(),
1716 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1717 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1718 registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1719 scopes_supported: None,
1720 code_challenge_methods_supported: Some(vec!["S256".into()]),
1721 client_id_metadata_document_supported: true,
1722 };
1723 assert_eq!(
1724 determine_registration_strategy(&metadata),
1725 ClientRegistrationStrategy::Cimd {
1726 client_id: CIMD_URL.to_string(),
1727 }
1728 );
1729 }
1730
1731 #[test]
1732 fn test_registration_strategy_falls_back_to_dcr() {
1733 let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1734 let metadata = AuthServerMetadata {
1735 issuer: Url::parse("https://auth.example.com").unwrap(),
1736 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1737 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1738 registration_endpoint: Some(reg_endpoint.clone()),
1739 scopes_supported: None,
1740 code_challenge_methods_supported: Some(vec!["S256".into()]),
1741 client_id_metadata_document_supported: false,
1742 };
1743 assert_eq!(
1744 determine_registration_strategy(&metadata),
1745 ClientRegistrationStrategy::Dcr {
1746 registration_endpoint: reg_endpoint,
1747 }
1748 );
1749 }
1750
1751 #[test]
1752 fn test_registration_strategy_unavailable() {
1753 let metadata = AuthServerMetadata {
1754 issuer: Url::parse("https://auth.example.com").unwrap(),
1755 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1756 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1757 registration_endpoint: None,
1758 scopes_supported: None,
1759 code_challenge_methods_supported: Some(vec!["S256".into()]),
1760 client_id_metadata_document_supported: false,
1761 };
1762 assert_eq!(
1763 determine_registration_strategy(&metadata),
1764 ClientRegistrationStrategy::Unavailable,
1765 );
1766 }
1767
1768 // -- PKCE tests ----------------------------------------------------------
1769
1770 #[test]
1771 fn test_pkce_challenge_verifier_length() {
1772 let pkce = generate_pkce_challenge();
1773 // 32 random bytes → 43 base64url chars (no padding).
1774 assert_eq!(pkce.verifier.len(), 43);
1775 }
1776
1777 #[test]
1778 fn test_pkce_challenge_is_valid_base64url() {
1779 let pkce = generate_pkce_challenge();
1780 for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1781 assert!(
1782 c.is_ascii_alphanumeric() || c == '-' || c == '_',
1783 "invalid base64url character: {}",
1784 c
1785 );
1786 }
1787 }
1788
1789 #[test]
1790 fn test_pkce_challenge_is_s256_of_verifier() {
1791 let pkce = generate_pkce_challenge();
1792 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1793 let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1794 let expected_challenge = engine.encode(expected_digest);
1795 assert_eq!(pkce.challenge, expected_challenge);
1796 }
1797
1798 #[test]
1799 fn test_pkce_challenges_are_unique() {
1800 let a = generate_pkce_challenge();
1801 let b = generate_pkce_challenge();
1802 assert_ne!(a.verifier, b.verifier);
1803 }
1804
1805 // -- Authorization URL tests ---------------------------------------------
1806
1807 #[test]
1808 fn test_build_authorization_url() {
1809 let metadata = AuthServerMetadata {
1810 issuer: Url::parse("https://auth.example.com").unwrap(),
1811 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1812 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1813 registration_endpoint: None,
1814 scopes_supported: None,
1815 code_challenge_methods_supported: Some(vec!["S256".into()]),
1816 client_id_metadata_document_supported: true,
1817 };
1818 let pkce = PkceChallenge {
1819 verifier: "test_verifier".into(),
1820 challenge: "test_challenge".into(),
1821 };
1822 let url = build_authorization_url(
1823 &metadata,
1824 "https://zed.dev/oauth/client-metadata.json",
1825 "http://127.0.0.1:12345/callback",
1826 &["files:read".into(), "files:write".into()],
1827 "https://mcp.example.com",
1828 &pkce,
1829 "random_state_123",
1830 );
1831
1832 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1833 assert_eq!(pairs.get("response_type").unwrap(), "code");
1834 assert_eq!(
1835 pairs.get("client_id").unwrap(),
1836 "https://zed.dev/oauth/client-metadata.json"
1837 );
1838 assert_eq!(
1839 pairs.get("redirect_uri").unwrap(),
1840 "http://127.0.0.1:12345/callback"
1841 );
1842 assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1843 assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1844 assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1845 assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1846 assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1847 }
1848
1849 #[test]
1850 fn test_build_authorization_url_omits_empty_scope() {
1851 let metadata = AuthServerMetadata {
1852 issuer: Url::parse("https://auth.example.com").unwrap(),
1853 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1854 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1855 registration_endpoint: None,
1856 scopes_supported: None,
1857 code_challenge_methods_supported: Some(vec!["S256".into()]),
1858 client_id_metadata_document_supported: false,
1859 };
1860 let pkce = PkceChallenge {
1861 verifier: "v".into(),
1862 challenge: "c".into(),
1863 };
1864 let url = build_authorization_url(
1865 &metadata,
1866 "client_123",
1867 "http://127.0.0.1:9999/callback",
1868 &[],
1869 "https://mcp.example.com",
1870 &pkce,
1871 "state",
1872 );
1873
1874 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1875 assert!(!pairs.contains_key("scope"));
1876 }
1877
1878 // -- Token exchange / refresh param tests --------------------------------
1879
1880 #[test]
1881 fn test_token_exchange_params() {
1882 let params = token_exchange_params(
1883 "auth_code_abc",
1884 "client_xyz",
1885 "http://127.0.0.1:5555/callback",
1886 "verifier_123",
1887 "https://mcp.example.com",
1888 );
1889 let map: std::collections::HashMap<&str, &str> =
1890 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1891
1892 assert_eq!(map["grant_type"], "authorization_code");
1893 assert_eq!(map["code"], "auth_code_abc");
1894 assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1895 assert_eq!(map["client_id"], "client_xyz");
1896 assert_eq!(map["code_verifier"], "verifier_123");
1897 assert_eq!(map["resource"], "https://mcp.example.com");
1898 }
1899
1900 #[test]
1901 fn test_token_refresh_params() {
1902 let params =
1903 token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1904 let map: std::collections::HashMap<&str, &str> =
1905 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1906
1907 assert_eq!(map["grant_type"], "refresh_token");
1908 assert_eq!(map["refresh_token"], "refresh_token_abc");
1909 assert_eq!(map["client_id"], "client_xyz");
1910 assert_eq!(map["resource"], "https://mcp.example.com");
1911 }
1912
1913 // -- Token response tests ------------------------------------------------
1914
1915 #[test]
1916 fn test_token_response_into_tokens_with_expiry() {
1917 let response: TokenResponse = serde_json::from_str(
1918 r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1919 )
1920 .unwrap();
1921
1922 let tokens = response.into_tokens();
1923 assert_eq!(tokens.access_token, "at_123");
1924 assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1925 assert!(tokens.expires_at.is_some());
1926 }
1927
1928 #[test]
1929 fn test_token_response_into_tokens_minimal() {
1930 let response: TokenResponse =
1931 serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1932
1933 let tokens = response.into_tokens();
1934 assert_eq!(tokens.access_token, "at_789");
1935 assert_eq!(tokens.refresh_token, None);
1936 assert_eq!(tokens.expires_at, None);
1937 }
1938
1939 // -- DCR body test -------------------------------------------------------
1940
1941 #[test]
1942 fn test_dcr_registration_body_shape() {
1943 let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1944 assert_eq!(body["client_name"], "Zed");
1945 assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1946 assert_eq!(body["grant_types"][0], "authorization_code");
1947 assert_eq!(body["response_types"][0], "code");
1948 assert_eq!(body["token_endpoint_auth_method"], "none");
1949 }
1950
1951 // -- Test helpers for async/HTTP tests -----------------------------------
1952
1953 fn make_fake_http_client(
1954 handler: impl Fn(
1955 http_client::Request<AsyncBody>,
1956 ) -> std::pin::Pin<
1957 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1958 > + Send
1959 + Sync
1960 + 'static,
1961 ) -> Arc<dyn HttpClient> {
1962 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1963 }
1964
1965 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1966 Ok(Response::builder()
1967 .status(status)
1968 .header("Content-Type", "application/json")
1969 .body(AsyncBody::from(body.as_bytes().to_vec()))
1970 .unwrap())
1971 }
1972
1973 // -- Discovery integration tests -----------------------------------------
1974
1975 #[test]
1976 fn test_fetch_protected_resource_metadata() {
1977 smol::block_on(async {
1978 let client = make_fake_http_client(|req| {
1979 Box::pin(async move {
1980 let uri = req.uri().to_string();
1981 if uri.contains(".well-known/oauth-protected-resource") {
1982 json_response(
1983 200,
1984 r#"{
1985 "resource": "https://mcp.example.com",
1986 "authorization_servers": ["https://auth.example.com"],
1987 "scopes_supported": ["read", "write"]
1988 }"#,
1989 )
1990 } else {
1991 json_response(404, "{}")
1992 }
1993 })
1994 });
1995
1996 let server_url = Url::parse("https://mcp.example.com").unwrap();
1997 let www_auth = WwwAuthenticate {
1998 resource_metadata: None,
1999 scope: None,
2000 error: None,
2001 error_description: None,
2002 };
2003
2004 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2005 .await
2006 .unwrap();
2007
2008 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2009 assert_eq!(metadata.authorization_servers.len(), 1);
2010 assert_eq!(
2011 metadata.authorization_servers[0].as_str(),
2012 "https://auth.example.com/"
2013 );
2014 assert_eq!(
2015 metadata.scopes_supported,
2016 Some(vec!["read".to_string(), "write".to_string()])
2017 );
2018 });
2019 }
2020
2021 #[test]
2022 fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2023 smol::block_on(async {
2024 let client = make_fake_http_client(|req| {
2025 Box::pin(async move {
2026 let uri = req.uri().to_string();
2027 if uri == "https://mcp.example.com/custom-resource-metadata" {
2028 json_response(
2029 200,
2030 r#"{
2031 "resource": "https://mcp.example.com",
2032 "authorization_servers": ["https://auth.example.com"]
2033 }"#,
2034 )
2035 } else {
2036 json_response(500, r#"{"error": "should not be called"}"#)
2037 }
2038 })
2039 });
2040
2041 let server_url = Url::parse("https://mcp.example.com").unwrap();
2042 let www_auth = WwwAuthenticate {
2043 resource_metadata: Some(
2044 Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2045 ),
2046 scope: None,
2047 error: None,
2048 error_description: None,
2049 };
2050
2051 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2052 .await
2053 .unwrap();
2054
2055 assert_eq!(metadata.authorization_servers.len(), 1);
2056 });
2057 }
2058
2059 #[test]
2060 fn test_fetch_protected_resource_metadata_falls_back_when_header_url_fails() {
2061 // Reproduces the Pydantic Logfire case: the server's WWW-Authenticate
2062 // header contains a resource_metadata URL with a doubled path (e.g.
2063 // /mcp/mcp), which returns HTML instead of JSON. The client should
2064 // fall back to the RFC 9728 well-known URL, which works correctly.
2065 smol::block_on(async {
2066 let client = make_fake_http_client(|req| {
2067 Box::pin(async move {
2068 let uri = req.uri().to_string();
2069 if uri
2070 == "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp/mcp"
2071 {
2072 // Buggy header URL returns HTML (like a SPA catch-all).
2073 Ok(Response::builder()
2074 .status(200)
2075 .header("Content-Type", "text/html")
2076 .body(AsyncBody::from(b"<!doctype html><html></html>".to_vec()))
2077 .unwrap())
2078 } else if uri
2079 == "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp"
2080 {
2081 // Correct well-known URL returns valid metadata.
2082 json_response(
2083 200,
2084 r#"{
2085 "resource": "https://mcp.example.com/api/mcp",
2086 "authorization_servers": ["https://auth.example.com"]
2087 }"#,
2088 )
2089 } else {
2090 json_response(404, "{}")
2091 }
2092 })
2093 });
2094
2095 let server_url = Url::parse("https://mcp.example.com/api/mcp").unwrap();
2096 let www_auth = WwwAuthenticate {
2097 resource_metadata: Some(
2098 // Buggy URL with doubled path component.
2099 Url::parse(
2100 "https://mcp.example.com/.well-known/oauth-protected-resource/api/mcp/mcp",
2101 )
2102 .unwrap(),
2103 ),
2104 scope: None,
2105 error: None,
2106 error_description: None,
2107 };
2108
2109 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2110 .await
2111 .unwrap();
2112
2113 assert_eq!(
2114 metadata.resource.as_str(),
2115 "https://mcp.example.com/api/mcp"
2116 );
2117 assert_eq!(
2118 metadata.authorization_servers[0].as_str(),
2119 "https://auth.example.com/"
2120 );
2121 });
2122 }
2123
2124 #[test]
2125 fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2126 smol::block_on(async {
2127 let client = make_fake_http_client(|req| {
2128 Box::pin(async move {
2129 let uri = req.uri().to_string();
2130 // The cross-origin URL should NOT be fetched; only the
2131 // well-known fallback at the server's own origin should be.
2132 if uri.contains("attacker.example.com") {
2133 panic!("should not fetch cross-origin resource_metadata URL");
2134 } else if uri.contains(".well-known/oauth-protected-resource") {
2135 json_response(
2136 200,
2137 r#"{
2138 "resource": "https://mcp.example.com",
2139 "authorization_servers": ["https://auth.example.com"]
2140 }"#,
2141 )
2142 } else {
2143 json_response(404, "{}")
2144 }
2145 })
2146 });
2147
2148 let server_url = Url::parse("https://mcp.example.com").unwrap();
2149 let www_auth = WwwAuthenticate {
2150 resource_metadata: Some(
2151 Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2152 ),
2153 scope: None,
2154 error: None,
2155 error_description: None,
2156 };
2157
2158 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2159 .await
2160 .unwrap();
2161
2162 // Should have used the fallback well-known URL, not the attacker's.
2163 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2164 });
2165 }
2166
2167 #[test]
2168 fn test_fetch_auth_server_metadata() {
2169 smol::block_on(async {
2170 let client = make_fake_http_client(|req| {
2171 Box::pin(async move {
2172 let uri = req.uri().to_string();
2173 if uri.contains(".well-known/oauth-authorization-server") {
2174 json_response(
2175 200,
2176 r#"{
2177 "issuer": "https://auth.example.com",
2178 "authorization_endpoint": "https://auth.example.com/authorize",
2179 "token_endpoint": "https://auth.example.com/token",
2180 "registration_endpoint": "https://auth.example.com/register",
2181 "code_challenge_methods_supported": ["S256"],
2182 "client_id_metadata_document_supported": true
2183 }"#,
2184 )
2185 } else {
2186 json_response(404, "{}")
2187 }
2188 })
2189 });
2190
2191 let issuer = Url::parse("https://auth.example.com").unwrap();
2192 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2193
2194 assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2195 assert_eq!(
2196 metadata.authorization_endpoint.as_str(),
2197 "https://auth.example.com/authorize"
2198 );
2199 assert_eq!(
2200 metadata.token_endpoint.as_str(),
2201 "https://auth.example.com/token"
2202 );
2203 assert!(metadata.registration_endpoint.is_some());
2204 assert!(metadata.client_id_metadata_document_supported);
2205 assert_eq!(
2206 metadata.code_challenge_methods_supported,
2207 Some(vec!["S256".to_string()])
2208 );
2209 });
2210 }
2211
2212 #[test]
2213 fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2214 smol::block_on(async {
2215 let client = make_fake_http_client(|req| {
2216 Box::pin(async move {
2217 let uri = req.uri().to_string();
2218 if uri.contains("openid-configuration") {
2219 json_response(
2220 200,
2221 r#"{
2222 "issuer": "https://auth.example.com",
2223 "authorization_endpoint": "https://auth.example.com/authorize",
2224 "token_endpoint": "https://auth.example.com/token",
2225 "code_challenge_methods_supported": ["S256"]
2226 }"#,
2227 )
2228 } else {
2229 json_response(404, "{}")
2230 }
2231 })
2232 });
2233
2234 let issuer = Url::parse("https://auth.example.com").unwrap();
2235 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2236
2237 assert_eq!(
2238 metadata.authorization_endpoint.as_str(),
2239 "https://auth.example.com/authorize"
2240 );
2241 assert!(!metadata.client_id_metadata_document_supported);
2242 });
2243 }
2244
2245 #[test]
2246 fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2247 smol::block_on(async {
2248 let client = make_fake_http_client(|req| {
2249 Box::pin(async move {
2250 let uri = req.uri().to_string();
2251 if uri.contains(".well-known/oauth-authorization-server") {
2252 // Response claims to be a different issuer.
2253 json_response(
2254 200,
2255 r#"{
2256 "issuer": "https://evil.example.com",
2257 "authorization_endpoint": "https://evil.example.com/authorize",
2258 "token_endpoint": "https://evil.example.com/token",
2259 "code_challenge_methods_supported": ["S256"]
2260 }"#,
2261 )
2262 } else {
2263 json_response(404, "{}")
2264 }
2265 })
2266 });
2267
2268 let issuer = Url::parse("https://auth.example.com").unwrap();
2269 let result = fetch_auth_server_metadata(&client, &issuer).await;
2270
2271 assert!(result.is_err());
2272 let err_msg = result.unwrap_err().to_string();
2273 assert!(
2274 err_msg.contains("issuer mismatch"),
2275 "unexpected error: {}",
2276 err_msg
2277 );
2278 });
2279 }
2280
2281 // -- Full discover integration tests -------------------------------------
2282
2283 #[test]
2284 fn test_full_discover_with_cimd() {
2285 smol::block_on(async {
2286 let client = make_fake_http_client(|req| {
2287 Box::pin(async move {
2288 let uri = req.uri().to_string();
2289 if uri.contains("oauth-protected-resource") {
2290 json_response(
2291 200,
2292 r#"{
2293 "resource": "https://mcp.example.com",
2294 "authorization_servers": ["https://auth.example.com"],
2295 "scopes_supported": ["mcp:read"]
2296 }"#,
2297 )
2298 } else if uri.contains("oauth-authorization-server") {
2299 json_response(
2300 200,
2301 r#"{
2302 "issuer": "https://auth.example.com",
2303 "authorization_endpoint": "https://auth.example.com/authorize",
2304 "token_endpoint": "https://auth.example.com/token",
2305 "code_challenge_methods_supported": ["S256"],
2306 "client_id_metadata_document_supported": true
2307 }"#,
2308 )
2309 } else {
2310 json_response(404, "{}")
2311 }
2312 })
2313 });
2314
2315 let server_url = Url::parse("https://mcp.example.com").unwrap();
2316 let www_auth = WwwAuthenticate {
2317 resource_metadata: None,
2318 scope: None,
2319 error: None,
2320 error_description: None,
2321 };
2322
2323 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2324 let registration =
2325 resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2326 .await
2327 .unwrap();
2328
2329 assert_eq!(registration.client_id, CIMD_URL);
2330 assert_eq!(registration.client_secret, None);
2331 assert_eq!(discovery.scopes, vec!["mcp:read"]);
2332 });
2333 }
2334
2335 #[test]
2336 fn test_full_discover_with_dcr_fallback() {
2337 smol::block_on(async {
2338 let client = make_fake_http_client(|req| {
2339 Box::pin(async move {
2340 let uri = req.uri().to_string();
2341 if uri.contains("oauth-protected-resource") {
2342 json_response(
2343 200,
2344 r#"{
2345 "resource": "https://mcp.example.com",
2346 "authorization_servers": ["https://auth.example.com"]
2347 }"#,
2348 )
2349 } else if uri.contains("oauth-authorization-server") {
2350 json_response(
2351 200,
2352 r#"{
2353 "issuer": "https://auth.example.com",
2354 "authorization_endpoint": "https://auth.example.com/authorize",
2355 "token_endpoint": "https://auth.example.com/token",
2356 "registration_endpoint": "https://auth.example.com/register",
2357 "code_challenge_methods_supported": ["S256"],
2358 "client_id_metadata_document_supported": false
2359 }"#,
2360 )
2361 } else if uri.contains("/register") {
2362 json_response(
2363 201,
2364 r#"{
2365 "client_id": "dcr-minted-id-123",
2366 "client_secret": "dcr-secret-456"
2367 }"#,
2368 )
2369 } else {
2370 json_response(404, "{}")
2371 }
2372 })
2373 });
2374
2375 let server_url = Url::parse("https://mcp.example.com").unwrap();
2376 let www_auth = WwwAuthenticate {
2377 resource_metadata: None,
2378 scope: Some(vec!["files:read".into()]),
2379 error: None,
2380 error_description: None,
2381 };
2382
2383 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2384 let registration =
2385 resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2386 .await
2387 .unwrap();
2388
2389 assert_eq!(registration.client_id, "dcr-minted-id-123");
2390 assert_eq!(
2391 registration.client_secret.as_deref(),
2392 Some("dcr-secret-456")
2393 );
2394 assert_eq!(discovery.scopes, vec!["files:read"]);
2395 });
2396 }
2397
2398 #[test]
2399 fn test_discover_fails_without_pkce_support() {
2400 smol::block_on(async {
2401 let client = make_fake_http_client(|req| {
2402 Box::pin(async move {
2403 let uri = req.uri().to_string();
2404 if uri.contains("oauth-protected-resource") {
2405 json_response(
2406 200,
2407 r#"{
2408 "resource": "https://mcp.example.com",
2409 "authorization_servers": ["https://auth.example.com"]
2410 }"#,
2411 )
2412 } else if uri.contains("oauth-authorization-server") {
2413 json_response(
2414 200,
2415 r#"{
2416 "issuer": "https://auth.example.com",
2417 "authorization_endpoint": "https://auth.example.com/authorize",
2418 "token_endpoint": "https://auth.example.com/token"
2419 }"#,
2420 )
2421 } else {
2422 json_response(404, "{}")
2423 }
2424 })
2425 });
2426
2427 let server_url = Url::parse("https://mcp.example.com").unwrap();
2428 let www_auth = WwwAuthenticate {
2429 resource_metadata: None,
2430 scope: None,
2431 error: None,
2432 error_description: None,
2433 };
2434
2435 let result = discover(&client, &server_url, &www_auth).await;
2436 assert!(result.is_err());
2437 let err_msg = result.unwrap_err().to_string();
2438 assert!(
2439 err_msg.contains("code_challenge_methods_supported"),
2440 "unexpected error: {}",
2441 err_msg
2442 );
2443 });
2444 }
2445
2446 // -- Token exchange integration tests ------------------------------------
2447
2448 #[test]
2449 fn test_exchange_code_success() {
2450 smol::block_on(async {
2451 let client = make_fake_http_client(|req| {
2452 Box::pin(async move {
2453 let uri = req.uri().to_string();
2454 if uri.contains("/token") {
2455 json_response(
2456 200,
2457 r#"{
2458 "access_token": "new_access_token",
2459 "refresh_token": "new_refresh_token",
2460 "expires_in": 3600,
2461 "token_type": "Bearer"
2462 }"#,
2463 )
2464 } else {
2465 json_response(404, "{}")
2466 }
2467 })
2468 });
2469
2470 let metadata = AuthServerMetadata {
2471 issuer: Url::parse("https://auth.example.com").unwrap(),
2472 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2473 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2474 registration_endpoint: None,
2475 scopes_supported: None,
2476 code_challenge_methods_supported: Some(vec!["S256".into()]),
2477 client_id_metadata_document_supported: true,
2478 };
2479
2480 let tokens = exchange_code(
2481 &client,
2482 &metadata,
2483 "auth_code_123",
2484 CIMD_URL,
2485 "http://127.0.0.1:9999/callback",
2486 "verifier_abc",
2487 "https://mcp.example.com",
2488 )
2489 .await
2490 .unwrap();
2491
2492 assert_eq!(tokens.access_token, "new_access_token");
2493 assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2494 assert!(tokens.expires_at.is_some());
2495 });
2496 }
2497
2498 #[test]
2499 fn test_refresh_tokens_success() {
2500 smol::block_on(async {
2501 let client = make_fake_http_client(|req| {
2502 Box::pin(async move {
2503 let uri = req.uri().to_string();
2504 if uri.contains("/token") {
2505 json_response(
2506 200,
2507 r#"{
2508 "access_token": "refreshed_token",
2509 "expires_in": 1800,
2510 "token_type": "Bearer"
2511 }"#,
2512 )
2513 } else {
2514 json_response(404, "{}")
2515 }
2516 })
2517 });
2518
2519 let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2520
2521 let tokens = refresh_tokens(
2522 &client,
2523 &token_endpoint,
2524 "old_refresh_token",
2525 CIMD_URL,
2526 "https://mcp.example.com",
2527 )
2528 .await
2529 .unwrap();
2530
2531 assert_eq!(tokens.access_token, "refreshed_token");
2532 assert_eq!(tokens.refresh_token, None);
2533 assert!(tokens.expires_at.is_some());
2534 });
2535 }
2536
2537 #[test]
2538 fn test_exchange_code_failure() {
2539 smol::block_on(async {
2540 let client = make_fake_http_client(|_req| {
2541 Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2542 });
2543
2544 let metadata = AuthServerMetadata {
2545 issuer: Url::parse("https://auth.example.com").unwrap(),
2546 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2547 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2548 registration_endpoint: None,
2549 scopes_supported: None,
2550 code_challenge_methods_supported: Some(vec!["S256".into()]),
2551 client_id_metadata_document_supported: true,
2552 };
2553
2554 let result = exchange_code(
2555 &client,
2556 &metadata,
2557 "bad_code",
2558 "client",
2559 "http://127.0.0.1:1/callback",
2560 "verifier",
2561 "https://mcp.example.com",
2562 )
2563 .await;
2564
2565 assert!(result.is_err());
2566 assert!(result.unwrap_err().to_string().contains("400"));
2567 });
2568 }
2569
2570 // -- DCR integration tests -----------------------------------------------
2571
2572 #[test]
2573 fn test_perform_dcr() {
2574 smol::block_on(async {
2575 let client = make_fake_http_client(|_req| {
2576 Box::pin(async move {
2577 json_response(
2578 201,
2579 r#"{
2580 "client_id": "dynamic-client-001",
2581 "client_secret": "dynamic-secret-001"
2582 }"#,
2583 )
2584 })
2585 });
2586
2587 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2588 let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2589 .await
2590 .unwrap();
2591
2592 assert_eq!(registration.client_id, "dynamic-client-001");
2593 assert_eq!(
2594 registration.client_secret.as_deref(),
2595 Some("dynamic-secret-001")
2596 );
2597 });
2598 }
2599
2600 #[test]
2601 fn test_perform_dcr_failure() {
2602 smol::block_on(async {
2603 let client = make_fake_http_client(|_req| {
2604 Box::pin(
2605 async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2606 )
2607 });
2608
2609 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2610 let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2611
2612 assert!(result.is_err());
2613 assert!(result.unwrap_err().to_string().contains("403"));
2614 });
2615 }
2616
2617 // -- OAuthCallback parse tests -------------------------------------------
2618
2619 #[test]
2620 fn test_oauth_callback_parse_query() {
2621 let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2622 assert_eq!(callback.code, "test_auth_code");
2623 assert_eq!(callback.state, "test_state");
2624 }
2625
2626 #[test]
2627 fn test_oauth_callback_parse_query_reversed_order() {
2628 let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2629 assert_eq!(callback.code, "test_auth_code");
2630 assert_eq!(callback.state, "test_state");
2631 }
2632
2633 #[test]
2634 fn test_oauth_callback_parse_query_with_extra_params() {
2635 let callback =
2636 OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2637 .unwrap();
2638 assert_eq!(callback.code, "test_auth_code");
2639 assert_eq!(callback.state, "test_state");
2640 }
2641
2642 #[test]
2643 fn test_oauth_callback_parse_query_missing_code() {
2644 let result = OAuthCallback::parse_query("state=test_state");
2645 assert!(result.is_err());
2646 assert!(result.unwrap_err().to_string().contains("code"));
2647 }
2648
2649 #[test]
2650 fn test_oauth_callback_parse_query_missing_state() {
2651 let result = OAuthCallback::parse_query("code=test_auth_code");
2652 assert!(result.is_err());
2653 assert!(result.unwrap_err().to_string().contains("state"));
2654 }
2655
2656 #[test]
2657 fn test_oauth_callback_parse_query_empty_code() {
2658 let result = OAuthCallback::parse_query("code=&state=test_state");
2659 assert!(result.is_err());
2660 }
2661
2662 #[test]
2663 fn test_oauth_callback_parse_query_empty_state() {
2664 let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2665 assert!(result.is_err());
2666 }
2667
2668 #[test]
2669 fn test_oauth_callback_parse_query_url_encoded_values() {
2670 let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2671 assert_eq!(callback.code, "abc def");
2672 assert_eq!(callback.state, "test=state");
2673 }
2674
2675 #[test]
2676 fn test_oauth_callback_parse_query_error_response() {
2677 let result = OAuthCallback::parse_query(
2678 "error=access_denied&error_description=User%20denied%20access&state=abc",
2679 );
2680 assert!(result.is_err());
2681 let err_msg = result.unwrap_err().to_string();
2682 assert!(
2683 err_msg.contains("access_denied"),
2684 "unexpected error: {}",
2685 err_msg
2686 );
2687 assert!(
2688 err_msg.contains("User denied access"),
2689 "unexpected error: {}",
2690 err_msg
2691 );
2692 }
2693
2694 #[test]
2695 fn test_oauth_callback_parse_query_error_without_description() {
2696 let result = OAuthCallback::parse_query("error=server_error&state=abc");
2697 assert!(result.is_err());
2698 let err_msg = result.unwrap_err().to_string();
2699 assert!(
2700 err_msg.contains("server_error"),
2701 "unexpected error: {}",
2702 err_msg
2703 );
2704 assert!(
2705 err_msg.contains("no description"),
2706 "unexpected error: {}",
2707 err_msg
2708 );
2709 }
2710
2711 // -- McpOAuthTokenProvider tests -----------------------------------------
2712
2713 fn make_test_session(
2714 access_token: &str,
2715 refresh_token: Option<&str>,
2716 expires_at: Option<SystemTime>,
2717 ) -> OAuthSession {
2718 OAuthSession {
2719 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2720 resource: Url::parse("https://mcp.example.com").unwrap(),
2721 client_registration: OAuthClientRegistration {
2722 client_id: "test-client".into(),
2723 client_secret: None,
2724 },
2725 tokens: OAuthTokens {
2726 access_token: access_token.into(),
2727 refresh_token: refresh_token.map(String::from),
2728 expires_at,
2729 },
2730 }
2731 }
2732
2733 #[test]
2734 fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2735 let expired = SystemTime::now() - Duration::from_secs(60);
2736 let session = make_test_session("stale-token", Some("rt"), Some(expired));
2737 let provider = McpOAuthTokenProvider::new(
2738 session,
2739 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2740 None,
2741 );
2742
2743 assert_eq!(provider.access_token(), None);
2744 }
2745
2746 #[test]
2747 fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2748 let far_future = SystemTime::now() + Duration::from_secs(3600);
2749 let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2750 let provider = McpOAuthTokenProvider::new(
2751 session,
2752 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2753 None,
2754 );
2755
2756 assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2757 }
2758
2759 #[test]
2760 fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2761 let session = make_test_session("no-expiry-token", Some("rt"), None);
2762 let provider = McpOAuthTokenProvider::new(
2763 session,
2764 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2765 None,
2766 );
2767
2768 assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2769 }
2770
2771 #[test]
2772 fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2773 smol::block_on(async {
2774 let session = make_test_session("token", None, None);
2775 let provider = McpOAuthTokenProvider::new(
2776 session,
2777 make_fake_http_client(|_| {
2778 Box::pin(async { unreachable!("no HTTP call expected") })
2779 }),
2780 None,
2781 );
2782
2783 let refreshed = provider.try_refresh().await.unwrap();
2784 assert!(!refreshed);
2785 });
2786 }
2787
2788 #[test]
2789 fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2790 smol::block_on(async {
2791 let session = make_test_session("old-access", Some("my-refresh-token"), None);
2792 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2793
2794 let http_client = make_fake_http_client(|_req| {
2795 Box::pin(async {
2796 json_response(
2797 200,
2798 r#"{
2799 "access_token": "new-access",
2800 "refresh_token": "new-refresh",
2801 "expires_in": 1800
2802 }"#,
2803 )
2804 })
2805 });
2806
2807 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2808
2809 let refreshed = provider.try_refresh().await.unwrap();
2810 assert!(refreshed);
2811 assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2812
2813 let notified_session = rx.try_recv().expect("channel should have a session");
2814 assert_eq!(notified_session.tokens.access_token, "new-access");
2815 assert_eq!(
2816 notified_session.tokens.refresh_token.as_deref(),
2817 Some("new-refresh")
2818 );
2819 });
2820 }
2821
2822 #[test]
2823 fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2824 smol::block_on(async {
2825 let session = make_test_session("old-access", Some("original-refresh"), None);
2826 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2827
2828 let http_client = make_fake_http_client(|_req| {
2829 Box::pin(async {
2830 json_response(
2831 200,
2832 r#"{
2833 "access_token": "new-access",
2834 "expires_in": 900
2835 }"#,
2836 )
2837 })
2838 });
2839
2840 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2841
2842 let refreshed = provider.try_refresh().await.unwrap();
2843 assert!(refreshed);
2844
2845 let notified_session = rx.try_recv().expect("channel should have a session");
2846 assert_eq!(notified_session.tokens.access_token, "new-access");
2847 assert_eq!(
2848 notified_session.tokens.refresh_token.as_deref(),
2849 Some("original-refresh"),
2850 );
2851 });
2852 }
2853
2854 #[test]
2855 fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2856 smol::block_on(async {
2857 let session = make_test_session("old-access", Some("my-refresh"), None);
2858
2859 let http_client = make_fake_http_client(|_req| {
2860 Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2861 });
2862
2863 let provider = McpOAuthTokenProvider::new(session, http_client, None);
2864
2865 let refreshed = provider.try_refresh().await.unwrap();
2866 assert!(!refreshed);
2867 // The old token should still be in place.
2868 assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2869 });
2870 }
2871}