1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
//! After authentication, the full state is captured in [`OAuthSession`], which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::sync::Arc;
31use std::time::{Duration, SystemTime};
32use url::Url;
33
/// The CIMD (Client ID Metadata Document) URL where Zed's OAuth client
/// metadata document is hosted.
///
/// When an authorization server advertises CIMD support, this URL itself is
/// used as the OAuth `client_id` (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
36
37/// Validate that a URL is safe to use as an OAuth endpoint.
38///
39/// OAuth endpoints carry sensitive material (authorization codes, PKCE
40/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
41/// loopback addresses, per RFC 8252 Section 8.3.
42fn require_https_or_loopback(url: &Url) -> Result<()> {
43 if url.scheme() == "https" {
44 return Ok(());
45 }
46 if url.scheme() == "http" {
47 if let Some(host) = url.host() {
48 match host {
49 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
50 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
51 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
52 _ => {}
53 }
54 }
55 }
56 bail!(
57 "OAuth endpoint must use HTTPS (got {}://{})",
58 url.scheme(),
59 url.host_str().unwrap_or("?")
60 )
61}
62
63/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
64/// protections against private/reserved IP ranges.
65///
66/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
67/// an attacker-controlled MCP server from directing Zed to fetch internal
68/// network resources via metadata URLs.
69///
70/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
71/// blocked here — full mitigation requires resolver-level validation (e.g. a
72/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
73fn validate_oauth_url(url: &Url) -> Result<()> {
74 require_https_or_loopback(url)?;
75
76 if let Some(host) = url.host() {
77 match host {
78 url::Host::Ipv4(ip) => {
79 // Loopback is already allowed by require_https_or_loopback.
80 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
81 {
82 bail!(
83 "OAuth endpoint must not point to private/reserved IP: {}",
84 ip
85 );
86 }
87 }
88 url::Host::Ipv6(ip) => {
89 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
90 // could bypass the IPv4 checks above.
91 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
92 if mapped_v4.is_private()
93 || mapped_v4.is_link_local()
94 || mapped_v4.is_broadcast()
95 || mapped_v4.is_unspecified()
96 {
97 bail!(
98 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
99 mapped_v4
100 );
101 }
102 }
103
104 if ip.is_unspecified() || ip.is_multicast() {
105 bail!(
106 "OAuth endpoint must not point to reserved IPv6 address: {}",
107 ip
108 );
109 }
110 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
111 // nightly-only, so check the prefix manually.
112 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
113 bail!(
114 "OAuth endpoint must not point to IPv6 unique-local address: {}",
115 ip
116 );
117 }
118 }
119 url::Host::Domain(_) => {
120 // Domain-based SSRF prevention requires resolver-level checks.
121 // See known limitation in the doc comment above.
122 }
123 }
124 }
125
126 Ok(())
127}
128
/// Protected Resource Metadata for an MCP server, per RFC 9728 (OAuth 2.0
/// Protected Resource Metadata).
///
/// Built by [`fetch_protected_resource_metadata`] from the document located
/// via the `resource_metadata` parameter of the server's `WWW-Authenticate`
/// header, or via the well-known URIs from [`protected_resource_metadata_urls`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The resource identifier. When the metadata document omits it,
    /// [`fetch_protected_resource_metadata`] falls back to the MCP server URL.
    pub resource: Url,
    /// Issuer URLs of authorization servers for this resource. Never empty
    /// when produced by [`fetch_protected_resource_metadata`]; [`discover`]
    /// only uses the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource; the fallback used by
    /// [`select_scopes`] when the challenge carries no `scope`.
    pub scopes_supported: Option<Vec<String>>,
}
137
/// Authorization Server Metadata per RFC 8414 (OAuth 2.0 Authorization
/// Server Metadata), fetched by [`fetch_auth_server_metadata`] from the
/// RFC 8414 or OIDC Discovery well-known endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier. [`fetch_auth_server_metadata`] rejects documents
    /// whose reported issuer differs from the one that was requested.
    pub issuer: Url,
    /// Base URL for the authorization request (see [`build_authorization_url`]).
    pub authorization_endpoint: Url,
    /// Endpoint that authorization codes and refresh tokens are POSTed to.
    pub token_endpoint: Url,
    /// RFC 7591 Dynamic Client Registration endpoint, if the server offers DCR.
    pub registration_endpoint: Option<Url>,
    /// Scopes advertised by the authorization server.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE challenge methods advertised; [`discover`] requires `S256`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// `client_id` (CIMD). Treated as `false` when the document omits it.
    pub client_id_metadata_document_supported: bool,
}
150
151/// The result of client registration — either CIMD or DCR.
152#[derive(Clone, Serialize, Deserialize)]
153pub struct OAuthClientRegistration {
154 pub client_id: String,
155 /// Only present for DCR-minted registrations.
156 pub client_secret: Option<String>,
157}
158
159impl std::fmt::Debug for OAuthClientRegistration {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct("OAuthClientRegistration")
162 .field("client_id", &self.client_id)
163 .field(
164 "client_secret",
165 &self.client_secret.as_ref().map(|_| "[redacted]"),
166 )
167 .finish()
168 }
169}
170
171/// Access and refresh tokens obtained from the token endpoint.
172#[derive(Clone, Serialize, Deserialize)]
173pub struct OAuthTokens {
174 pub access_token: String,
175 pub refresh_token: Option<String>,
176 pub expires_at: Option<SystemTime>,
177}
178
179impl std::fmt::Debug for OAuthTokens {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.debug_struct("OAuthTokens")
182 .field("access_token", &"[redacted]")
183 .field(
184 "refresh_token",
185 &self.refresh_token.as_ref().map(|_| "[redacted]"),
186 )
187 .field("expires_at", &self.expires_at)
188 .finish()
189 }
190}
191
/// Everything discovered before the browser flow starts, as produced by
/// [`discover`].
///
/// Client registration is not part of this: it is resolved separately by
/// [`resolve_client_registration`], once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// The MCP server's Protected Resource Metadata.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in
    /// `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes chosen by [`select_scopes`] for the authorization request.
    pub scopes: Vec<String>,
}
200
201/// The persisted OAuth session for a context server.
202///
203/// Stored in the keychain so startup can restore a refresh-capable provider
204/// without another browser flow. Deliberately excludes the full discovery
205/// metadata to keep the serialized size well within keychain item limits.
206#[derive(Clone, Serialize, Deserialize)]
207pub struct OAuthSession {
208 pub token_endpoint: Url,
209 pub resource: Url,
210 pub client_registration: OAuthClientRegistration,
211 pub tokens: OAuthTokens,
212}
213
214impl std::fmt::Debug for OAuthSession {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("OAuthSession")
217 .field("token_endpoint", &self.token_endpoint)
218 .field("resource", &self.resource)
219 .field("client_registration", &self.client_registration)
220 .field("tokens", &self.tokens)
221 .finish()
222 }
223}
224
/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}

impl BearerError {
    /// Map an `error` parameter value to its variant.
    ///
    /// Matching is exact (case-sensitive); any unrecognized code, including
    /// the empty string, becomes [`BearerError::Other`].
    fn parse(value: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];

        KNOWN_CODES
            .iter()
            .find(|(code, _)| *code == value)
            .map_or(BearerError::Other, |&(_, variant)| variant)
    }
}
249
/// Fields extracted from a `WWW-Authenticate: Bearer` header by
/// [`parse_www_authenticate`].
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if advertised.
    /// [`fetch_protected_resource_metadata`] only honors it when it is on the
    /// MCP server's own origin.
    pub resource_metadata: Option<Url>,
    /// Required scopes, split on whitespace from the `scope` parameter.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, verbatim.
    pub error_description: Option<String>,
}
263
264/// Parse a `WWW-Authenticate` header value.
265///
266/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
267/// Per RFC 6750 and RFC 9728, the relevant parameters are:
268/// - `resource_metadata` — URL of the Protected Resource Metadata document
269/// - `scope` — space-separated list of required scopes
270/// - `error` — error code (e.g. "insufficient_scope")
271/// - `error_description` — human-readable error description
272pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
273 let header = header.trim();
274
275 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
276 header[6..].trim()
277 } else {
278 bail!("WWW-Authenticate header does not use Bearer scheme");
279 };
280
281 if params_str.is_empty() {
282 return Ok(WwwAuthenticate {
283 resource_metadata: None,
284 scope: None,
285 error: None,
286 error_description: None,
287 });
288 }
289
290 let params = parse_auth_params(params_str);
291
292 let resource_metadata = params
293 .get("resource_metadata")
294 .map(|v| Url::parse(v))
295 .transpose()
296 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
297
298 let scope = params
299 .get("scope")
300 .map(|v| v.split_whitespace().map(String::from).collect());
301
302 let error = params.get("error").map(|v| BearerError::parse(v));
303 let error_description = params.get("error_description").cloned();
304
305 Ok(WwwAuthenticate {
306 resource_metadata,
307 scope,
308 error,
309 error_description,
310 })
311}
312
/// Parse comma-separated `key="value"` or `key=token` parameters from an
/// auth-param list (RFC 7235 Section 2.1).
///
/// Keys are trimmed and lowercased. Quoted values have their surrounding quotes
/// removed and backslash escapes resolved; unquoted values run until the next
/// comma or whitespace. When a key appears more than once, the last occurrence
/// wins. Parsing is lenient and never fails:
/// - an unterminated quoted value consumes the rest of the input;
/// - parsing stops once no `=` remains in the input;
/// - entries with an empty key are dropped.
///
/// NOTE(review): a bare token without `=` is not skipped on its own. For
/// example, in `foo, error="x"` the key becomes `foo, error`, so that
/// parameter is lost. Confirm this leniency is acceptable for real servers.
fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
    let mut params = collections::HashMap::default();
    let mut remaining = input.trim();

    while !remaining.is_empty() {
        // Skip leading whitespace and commas.
        remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
        if remaining.is_empty() {
            break;
        }

        // Find the key (everything before '='). Without any '=' left there
        // are no more parameters to read.
        let eq_pos = match remaining.find('=') {
            Some(pos) => pos,
            None => break,
        };

        let key = remaining[..eq_pos].trim().to_lowercase();
        remaining = &remaining[eq_pos + 1..];
        remaining = remaining.trim_start();

        // Parse the value: either quoted or unquoted (token).
        let value;
        if remaining.starts_with('"') {
            // Quoted string: find the closing quote, handling escaped chars.
            remaining = &remaining[1..]; // skip opening quote
            let mut val = String::new();
            let mut chars = remaining.char_indices();
            loop {
                match chars.next() {
                    Some((_, '\\')) => {
                        // Escaped character — take the next char literally.
                        // A trailing backslash at end of input is dropped.
                        if let Some((_, c)) = chars.next() {
                            val.push(c);
                        }
                    }
                    Some((i, '"')) => {
                        // `i` is the byte offset of the closing quote, which
                        // is one byte wide, so `i + 1` is a char boundary.
                        remaining = &remaining[i + 1..];
                        break;
                    }
                    Some((_, c)) => val.push(c),
                    None => {
                        // Unterminated quote: the value runs to end of input.
                        remaining = "";
                        break;
                    }
                }
            }
            value = val;
        } else {
            // Unquoted token: read until comma or whitespace.
            let end = remaining
                .find(|c: char| c == ',' || c.is_whitespace())
                .unwrap_or(remaining.len());
            value = remaining[..end].to_string();
            remaining = &remaining[end..];
        }

        if !key.is_empty() {
            params.insert(key, value);
        }
    }

    params
}
379
380/// Construct the well-known Protected Resource Metadata URIs for a given MCP
381/// server URL, per RFC 9728 Section 3.
382///
383/// Returns URIs in priority order:
384/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
385/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
386pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
387 let mut urls = Vec::new();
388 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
389
390 let path = server_url.path().trim_start_matches('/');
391 if !path.is_empty() {
392 if let Ok(url) = Url::parse(&format!(
393 "{}/.well-known/oauth-protected-resource/{}",
394 base, path
395 )) {
396 urls.push(url);
397 }
398 }
399
400 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
401 urls.push(url);
402 }
403
404 urls
405}
406
407/// Construct the well-known Authorization Server Metadata URIs for a given
408/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
409///
410/// Returns URIs in priority order, which differs depending on whether the
411/// issuer URL has a path component.
412pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
413 let mut urls = Vec::new();
414 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
415 let path = issuer.path().trim_matches('/');
416
417 if !path.is_empty() {
418 // Issuer with path: try path-inserted variants first.
419 if let Ok(url) = Url::parse(&format!(
420 "{}/.well-known/oauth-authorization-server/{}",
421 base, path
422 )) {
423 urls.push(url);
424 }
425 if let Ok(url) = Url::parse(&format!(
426 "{}/.well-known/openid-configuration/{}",
427 base, path
428 )) {
429 urls.push(url);
430 }
431 if let Ok(url) = Url::parse(&format!(
432 "{}/{}/.well-known/openid-configuration",
433 base, path
434 )) {
435 urls.push(url);
436 }
437 } else {
438 // No path: standard well-known locations.
439 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
440 urls.push(url);
441 }
442 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
443 urls.push(url);
444 }
445 }
446
447 urls
448}
449
450// -- Canonical server URI (RFC 8707) -----------------------------------------
451
452/// Derive the canonical resource URI for an MCP server URL, suitable for the
453/// `resource` parameter in authorization and token requests per RFC 8707.
454///
455/// Lowercases the scheme and host, preserves the path (without trailing slash),
456/// strips fragments and query strings.
457pub fn canonical_server_uri(server_url: &Url) -> String {
458 let mut uri = format!(
459 "{}://{}",
460 server_url.scheme().to_ascii_lowercase(),
461 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
462 );
463 if let Some(port) = server_url.port() {
464 uri.push_str(&format!(":{}", port));
465 }
466 let path = server_url.path();
467 if path != "/" {
468 uri.push_str(path.trim_end_matches('/'));
469 }
470 uri
471}
472
473// -- Scope selection ---------------------------------------------------------
474
475/// Select scopes following the MCP spec's Scope Selection Strategy:
476/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
477/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
478/// 3. Return empty if neither is available.
479pub fn select_scopes(
480 www_authenticate: &WwwAuthenticate,
481 resource_metadata: &ProtectedResourceMetadata,
482) -> Vec<String> {
483 if let Some(ref scopes) = www_authenticate.scope {
484 if !scopes.is_empty() {
485 return scopes.clone();
486 }
487 }
488 resource_metadata
489 .scopes_supported
490 .clone()
491 .unwrap_or_default()
492}
493
494// -- Client registration strategy --------------------------------------------
495
/// The registration approach to use, as chosen from auth server metadata by
/// [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it
    /// (see [`perform_dcr`]).
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
506
507/// Determine how to register with the authorization server, following the
508/// spec's recommended priority: CIMD first, DCR fallback.
509pub fn determine_registration_strategy(
510 auth_server_metadata: &AuthServerMetadata,
511) -> ClientRegistrationStrategy {
512 if auth_server_metadata.client_id_metadata_document_supported {
513 ClientRegistrationStrategy::Cimd {
514 client_id: CIMD_URL.to_string(),
515 }
516 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
517 ClientRegistrationStrategy::Dcr {
518 registration_endpoint: endpoint.clone(),
519 }
520 } else {
521 ClientRegistrationStrategy::Unavailable
522 }
523}
524
525// -- PKCE (RFC 7636) ---------------------------------------------------------
526
/// A PKCE code verifier and its S256 challenge.
#[derive(Clone)]
pub struct PkceChallenge {
    /// Secret sent only to the token endpoint when exchanging the code.
    pub verifier: String,
    /// `BASE64URL(SHA256(verifier))`, sent in the authorization request.
    pub challenge: String,
}

/// Hand-written so the verifier (a secret) never appears in logs; the
/// challenge is public and is shown as-is.
impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            verifier: _,
            challenge,
        } = self;

        f.debug_struct("PkceChallenge")
            .field("verifier", &"[redacted]")
            .field("challenge", challenge)
            .finish()
    }
}
542
543/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
544///
545/// The verifier is 43 base64url characters derived from 32 random bytes.
546/// The challenge is `BASE64URL(SHA256(verifier))`.
547pub fn generate_pkce_challenge() -> PkceChallenge {
548 let mut random_bytes = [0u8; 32];
549 rand::rng().fill(&mut random_bytes);
550 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
551 let verifier = engine.encode(&random_bytes);
552
553 let digest = Sha256::digest(verifier.as_bytes());
554 let challenge = engine.encode(digest);
555
556 PkceChallenge {
557 verifier,
558 challenge,
559 }
560}
561
562// -- Authorization URL construction ------------------------------------------
563
564/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
565pub fn build_authorization_url(
566 auth_server_metadata: &AuthServerMetadata,
567 client_id: &str,
568 redirect_uri: &str,
569 scopes: &[String],
570 resource: &str,
571 pkce: &PkceChallenge,
572 state: &str,
573) -> Url {
574 let mut url = auth_server_metadata.authorization_endpoint.clone();
575 {
576 let mut query = url.query_pairs_mut();
577 query.append_pair("response_type", "code");
578 query.append_pair("client_id", client_id);
579 query.append_pair("redirect_uri", redirect_uri);
580 if !scopes.is_empty() {
581 query.append_pair("scope", &scopes.join(" "));
582 }
583 query.append_pair("resource", resource);
584 query.append_pair("code_challenge", &pkce.challenge);
585 query.append_pair("code_challenge_method", "S256");
586 query.append_pair("state", state);
587 }
588 url
589}
590
591// -- Token endpoint request bodies -------------------------------------------
592
593/// The JSON body returned by the token endpoint on success.
594#[derive(Deserialize)]
595pub struct TokenResponse {
596 pub access_token: String,
597 #[serde(default)]
598 pub refresh_token: Option<String>,
599 #[serde(default)]
600 pub expires_in: Option<u64>,
601 #[serde(default)]
602 pub token_type: Option<String>,
603}
604
605impl std::fmt::Debug for TokenResponse {
606 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
607 f.debug_struct("TokenResponse")
608 .field("access_token", &"[redacted]")
609 .field(
610 "refresh_token",
611 &self.refresh_token.as_ref().map(|_| "[redacted]"),
612 )
613 .field("expires_in", &self.expires_in)
614 .field("token_type", &self.token_type)
615 .finish()
616 }
617}
618
619impl TokenResponse {
620 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
621 pub fn into_tokens(self) -> OAuthTokens {
622 let expires_at = self
623 .expires_in
624 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
625 OAuthTokens {
626 access_token: self.access_token,
627 refresh_token: self.refresh_token,
628 expires_at,
629 }
630 }
631}
632
/// Build the form parameters for an authorization code token exchange.
///
/// Produces, in order: `grant_type=authorization_code`, `code`,
/// `redirect_uri`, `client_id`, `code_verifier` (PKCE), and `resource`
/// (RFC 8707). The caller form-urlencodes them.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let fields: [(&'static str, &str); 6] = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ];

    fields
        .into_iter()
        .map(|(name, value)| (name, value.to_owned()))
        .collect()
}
650
/// Build the form parameters for a token refresh request.
///
/// Produces, in order: `grant_type=refresh_token`, `refresh_token`,
/// `client_id`, and `resource` (RFC 8707). The caller form-urlencodes them.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let mut params = Vec::with_capacity(4);
    params.push(("grant_type", String::from("refresh_token")));
    params.push(("refresh_token", refresh_token.to_owned()));
    params.push(("client_id", client_id.to_owned()));
    params.push(("resource", resource.to_owned()));
    params
}
664
665// -- DCR request body (RFC 7591) ---------------------------------------------
666
667/// Build the JSON body for a Dynamic Client Registration request.
668///
669/// The `redirect_uri` should be the actual loopback URI with the ephemeral
670/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
671/// redirect URI matching even for loopback addresses, so we register the
672/// exact URI we intend to use.
673pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
674 serde_json::json!({
675 "client_name": "Zed",
676 "redirect_uris": [redirect_uri],
677 "grant_types": ["authorization_code"],
678 "response_types": ["code"],
679 "token_endpoint_auth_method": "none"
680 })
681}
682
683// -- Discovery (async, hits real endpoints) ----------------------------------
684
685/// Fetch Protected Resource Metadata from the MCP server.
686///
687/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
688/// then falls back to well-known URIs constructed from `server_url`.
689pub async fn fetch_protected_resource_metadata(
690 http_client: &Arc<dyn HttpClient>,
691 server_url: &Url,
692 www_authenticate: &WwwAuthenticate,
693) -> Result<ProtectedResourceMetadata> {
694 let candidate_urls = match &www_authenticate.resource_metadata {
695 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
696 Some(url) => {
697 log::warn!(
698 "Ignoring cross-origin resource_metadata URL {} \
699 (server origin: {})",
700 url,
701 server_url.origin().unicode_serialization()
702 );
703 protected_resource_metadata_urls(server_url)
704 }
705 None => protected_resource_metadata_urls(server_url),
706 };
707
708 for url in &candidate_urls {
709 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
710 Ok(response) => {
711 if response.authorization_servers.is_empty() {
712 bail!(
713 "Protected Resource Metadata at {} has no authorization_servers",
714 url
715 );
716 }
717 return Ok(ProtectedResourceMetadata {
718 resource: response.resource.unwrap_or_else(|| server_url.clone()),
719 authorization_servers: response.authorization_servers,
720 scopes_supported: response.scopes_supported,
721 });
722 }
723 Err(err) => {
724 log::debug!(
725 "Failed to fetch Protected Resource Metadata from {}: {}",
726 url,
727 err
728 );
729 }
730 }
731 }
732
733 bail!(
734 "Could not fetch Protected Resource Metadata for {}",
735 server_url
736 )
737}
738
739/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
740/// endpoints in the priority order specified by the MCP spec.
741pub async fn fetch_auth_server_metadata(
742 http_client: &Arc<dyn HttpClient>,
743 issuer: &Url,
744) -> Result<AuthServerMetadata> {
745 let candidate_urls = auth_server_metadata_urls(issuer);
746
747 for url in &candidate_urls {
748 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
749 Ok(response) => {
750 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
751 if reported_issuer != *issuer {
752 bail!(
753 "Auth server metadata issuer mismatch: expected {}, got {}",
754 issuer,
755 reported_issuer
756 );
757 }
758
759 return Ok(AuthServerMetadata {
760 issuer: reported_issuer,
761 authorization_endpoint: response
762 .authorization_endpoint
763 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
764 token_endpoint: response
765 .token_endpoint
766 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
767 registration_endpoint: response.registration_endpoint,
768 scopes_supported: response.scopes_supported,
769 code_challenge_methods_supported: response.code_challenge_methods_supported,
770 client_id_metadata_document_supported: response
771 .client_id_metadata_document_supported
772 .unwrap_or(false),
773 });
774 }
775 Err(err) => {
776 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
777 }
778 }
779 }
780
781 bail!(
782 "Could not fetch Authorization Server Metadata for {}",
783 issuer
784 )
785}
786
787/// Run the full discovery flow: fetch resource metadata, then auth server
788/// metadata, then select scopes. Client registration is resolved separately,
789/// once the real redirect URI is known.
790pub async fn discover(
791 http_client: &Arc<dyn HttpClient>,
792 server_url: &Url,
793 www_authenticate: &WwwAuthenticate,
794) -> Result<OAuthDiscovery> {
795 let resource_metadata =
796 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
797
798 let auth_server_url = resource_metadata
799 .authorization_servers
800 .first()
801 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
802
803 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
804
805 // Verify PKCE S256 support (spec requirement).
806 match &auth_server_metadata.code_challenge_methods_supported {
807 Some(methods) if methods.iter().any(|m| m == "S256") => {}
808 Some(_) => bail!("authorization server does not support S256 PKCE"),
809 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
810 }
811
812 // Verify there is at least one supported registration strategy before we
813 // present the server as ready to authenticate.
814 match determine_registration_strategy(&auth_server_metadata) {
815 ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
816 ClientRegistrationStrategy::Unavailable => {
817 bail!("authorization server supports neither CIMD nor DCR")
818 }
819 }
820
821 let scopes = select_scopes(www_authenticate, &resource_metadata);
822
823 Ok(OAuthDiscovery {
824 resource_metadata,
825 auth_server_metadata,
826 scopes,
827 })
828}
829
830/// Resolve the OAuth client registration for an authorization flow.
831///
832/// CIMD uses the static client metadata document directly. For DCR, a fresh
833/// registration is performed each time because the loopback redirect URI
834/// includes an ephemeral port that changes every flow.
835pub async fn resolve_client_registration(
836 http_client: &Arc<dyn HttpClient>,
837 discovery: &OAuthDiscovery,
838 redirect_uri: &str,
839) -> Result<OAuthClientRegistration> {
840 match determine_registration_strategy(&discovery.auth_server_metadata) {
841 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
842 client_id,
843 client_secret: None,
844 }),
845 ClientRegistrationStrategy::Dcr {
846 registration_endpoint,
847 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
848 ClientRegistrationStrategy::Unavailable => {
849 bail!("authorization server supports neither CIMD nor DCR")
850 }
851 }
852}
853
854// -- Dynamic Client Registration (RFC 7591) ----------------------------------
855
856/// Perform Dynamic Client Registration with the authorization server.
857pub async fn perform_dcr(
858 http_client: &Arc<dyn HttpClient>,
859 registration_endpoint: &Url,
860 redirect_uri: &str,
861) -> Result<OAuthClientRegistration> {
862 validate_oauth_url(registration_endpoint)?;
863
864 let body = dcr_registration_body(redirect_uri);
865 let body_bytes = serde_json::to_vec(&body)?;
866
867 let request = Request::builder()
868 .method(http_client::http::Method::POST)
869 .uri(registration_endpoint.as_str())
870 .header("Content-Type", "application/json")
871 .header("Accept", "application/json")
872 .body(AsyncBody::from(body_bytes))?;
873
874 let mut response = http_client.send(request).await?;
875
876 if !response.status().is_success() {
877 let mut error_body = String::new();
878 response.body_mut().read_to_string(&mut error_body).await?;
879 bail!(
880 "DCR failed with status {}: {}",
881 response.status(),
882 error_body
883 );
884 }
885
886 let mut response_body = String::new();
887 response
888 .body_mut()
889 .read_to_string(&mut response_body)
890 .await?;
891
892 let dcr_response: DcrResponse =
893 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
894
895 Ok(OAuthClientRegistration {
896 client_id: dcr_response.client_id,
897 client_secret: dcr_response.client_secret,
898 })
899}
900
901// -- Token exchange and refresh (async) --------------------------------------
902
903/// Exchange an authorization code for tokens at the token endpoint.
904pub async fn exchange_code(
905 http_client: &Arc<dyn HttpClient>,
906 auth_server_metadata: &AuthServerMetadata,
907 code: &str,
908 client_id: &str,
909 redirect_uri: &str,
910 code_verifier: &str,
911 resource: &str,
912) -> Result<OAuthTokens> {
913 let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
914 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
915}
916
917/// Refresh tokens using a refresh token.
918pub async fn refresh_tokens(
919 http_client: &Arc<dyn HttpClient>,
920 token_endpoint: &Url,
921 refresh_token: &str,
922 client_id: &str,
923 resource: &str,
924) -> Result<OAuthTokens> {
925 let params = token_refresh_params(refresh_token, client_id, resource);
926 post_token_request(http_client, token_endpoint, ¶ms).await
927}
928
929/// POST form-encoded parameters to a token endpoint and parse the response.
930async fn post_token_request(
931 http_client: &Arc<dyn HttpClient>,
932 token_endpoint: &Url,
933 params: &[(&str, String)],
934) -> Result<OAuthTokens> {
935 validate_oauth_url(token_endpoint)?;
936
937 let body = url::form_urlencoded::Serializer::new(String::new())
938 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
939 .finish();
940
941 let request = Request::builder()
942 .method(http_client::http::Method::POST)
943 .uri(token_endpoint.as_str())
944 .header("Content-Type", "application/x-www-form-urlencoded")
945 .header("Accept", "application/json")
946 .body(AsyncBody::from(body.into_bytes()))?;
947
948 let mut response = http_client.send(request).await?;
949
950 if !response.status().is_success() {
951 let mut error_body = String::new();
952 response.body_mut().read_to_string(&mut error_body).await?;
953 bail!(
954 "token request failed with status {}: {}",
955 response.status(),
956 error_body
957 );
958 }
959
960 let mut response_body = String::new();
961 response
962 .body_mut()
963 .read_to_string(&mut response_body)
964 .await?;
965
966 let token_response: TokenResponse =
967 serde_json::from_str(&response_body).context("failed to parse token response")?;
968
969 Ok(token_response.into_tokens())
970}
971
972// -- Loopback HTTP callback server -------------------------------------------
973
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Both values are sensitive, so `Debug` is implemented by hand with the
/// fields redacted rather than derived.
pub struct OAuthCallback {
    /// The authorization code, to be redeemed at the token endpoint.
    pub code: String,
    /// The `state` value echoed back by the authorization server. Per
    /// RFC 6749 it must match the value sent in the authorization request.
    pub state: String,
}
979
980impl std::fmt::Debug for OAuthCallback {
981 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
982 f.debug_struct("OAuthCallback")
983 .field("code", &"[redacted]")
984 .field("state", &"[redacted]")
985 .finish()
986 }
987}
988
989impl OAuthCallback {
990 /// Parse the query string from a callback URL like
991 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
992 pub fn parse_query(query: &str) -> Result<Self> {
993 let params = http_client::OAuthCallbackParams::parse_query(query)?;
994 Ok(Self {
995 code: params.code,
996 state: params.state,
997 })
998 }
999}
1000
1001/// Start a loopback HTTP server to receive the OAuth authorization callback.
1002///
1003/// Binds to an ephemeral loopback port for each flow.
1004///
1005/// Returns `(redirect_uri, callback_future)`. The caller should use the
1006/// redirect URI in the authorization request, open the browser, then await
1007/// the future to receive the callback.
1008///
1009/// The server accepts exactly one request on `/callback`, validates that it
1010/// contains `code` and `state` query parameters, responds with a minimal
1011/// HTML page telling the user they can close the tab, and shuts down.
1012///
1013/// The callback server shuts down when the returned oneshot receiver is dropped
1014/// (e.g. because the authentication task was cancelled), or after a timeout.
1015pub async fn start_callback_server() -> Result<(
1016 String,
1017 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1018)> {
1019 let (redirect_uri, rx) = http_client::start_oauth_callback_server()?;
1020 let (tx, mapped_rx) = futures::channel::oneshot::channel();
1021 smol::spawn(async move {
1022 let result = match rx.await {
1023 Ok(Ok(params)) => Ok(OAuthCallback {
1024 code: params.code,
1025 state: params.state,
1026 }),
1027 Ok(Err(e)) => Err(e),
1028 Err(_) => Err(anyhow!("OAuth callback channel was cancelled")),
1029 };
1030 let _ = tx.send(result);
1031 })
1032 .detach();
1033 Ok((redirect_uri, mapped_rx))
1034}
1035
1036// -- JSON fetch helper -------------------------------------------------------
1037
1038async fn fetch_json<T: serde::de::DeserializeOwned>(
1039 http_client: &Arc<dyn HttpClient>,
1040 url: &Url,
1041) -> Result<T> {
1042 validate_oauth_url(url)?;
1043
1044 let request = Request::builder()
1045 .method(http_client::http::Method::GET)
1046 .uri(url.as_str())
1047 .header("Accept", "application/json")
1048 .body(AsyncBody::default())?;
1049
1050 let mut response = http_client.send(request).await?;
1051
1052 if !response.status().is_success() {
1053 bail!("HTTP {} fetching {}", response.status(), url);
1054 }
1055
1056 let mut body = String::new();
1057 response.body_mut().read_to_string(&mut body).await?;
1058 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1059}
1060
// -- Serde response types for discovery and registration ---------------------
1062
1063#[derive(Debug, Deserialize)]
1064struct ProtectedResourceMetadataResponse {
1065 #[serde(default)]
1066 resource: Option<Url>,
1067 #[serde(default)]
1068 authorization_servers: Vec<Url>,
1069 #[serde(default)]
1070 scopes_supported: Option<Vec<String>>,
1071}
1072
1073#[derive(Debug, Deserialize)]
1074struct AuthServerMetadataResponse {
1075 #[serde(default)]
1076 issuer: Option<Url>,
1077 #[serde(default)]
1078 authorization_endpoint: Option<Url>,
1079 #[serde(default)]
1080 token_endpoint: Option<Url>,
1081 #[serde(default)]
1082 registration_endpoint: Option<Url>,
1083 #[serde(default)]
1084 scopes_supported: Option<Vec<String>>,
1085 #[serde(default)]
1086 code_challenge_methods_supported: Option<Vec<String>>,
1087 #[serde(default)]
1088 client_id_metadata_document_supported: Option<bool>,
1089}
1090
1091#[derive(Debug, Deserialize)]
1092struct DcrResponse {
1093 client_id: String,
1094 #[serde(default)]
1095 client_secret: Option<String>,
1096}
1097
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// Implementations may return `None` for a token they already know to be
    /// expired (as `McpOAuthTokenProvider` does).
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token is available, so the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1111
1112/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
1113/// an HTTP client for token refresh. The same provider type is used both after
1114/// an interactive authentication flow and when restoring a saved session from
1115/// the keychain on startup.
1116pub struct McpOAuthTokenProvider {
1117 session: SyncMutex<OAuthSession>,
1118 http_client: Arc<dyn HttpClient>,
1119 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1120}
1121
1122impl McpOAuthTokenProvider {
1123 pub fn new(
1124 session: OAuthSession,
1125 http_client: Arc<dyn HttpClient>,
1126 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1127 ) -> Self {
1128 Self {
1129 session: SyncMutex::new(session),
1130 http_client,
1131 token_refresh_tx,
1132 }
1133 }
1134
1135 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1136 tokens.expires_at.is_some_and(|expires_at| {
1137 SystemTime::now()
1138 .checked_add(Duration::from_secs(30))
1139 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1140 })
1141 }
1142}
1143
1144#[async_trait]
1145impl OAuthTokenProvider for McpOAuthTokenProvider {
1146 fn access_token(&self) -> Option<String> {
1147 let session = self.session.lock();
1148 if Self::access_token_is_expired(&session.tokens) {
1149 return None;
1150 }
1151 Some(session.tokens.access_token.clone())
1152 }
1153
1154 async fn try_refresh(&self) -> Result<bool> {
1155 let (refresh_token, token_endpoint, resource, client_id) = {
1156 let session = self.session.lock();
1157 match session.tokens.refresh_token.clone() {
1158 Some(refresh_token) => (
1159 refresh_token,
1160 session.token_endpoint.clone(),
1161 session.resource.clone(),
1162 session.client_registration.client_id.clone(),
1163 ),
1164 None => return Ok(false),
1165 }
1166 };
1167
1168 let resource_str = canonical_server_uri(&resource);
1169
1170 match refresh_tokens(
1171 &self.http_client,
1172 &token_endpoint,
1173 &refresh_token,
1174 &client_id,
1175 &resource_str,
1176 )
1177 .await
1178 {
1179 Ok(mut new_tokens) => {
1180 if new_tokens.refresh_token.is_none() {
1181 new_tokens.refresh_token = Some(refresh_token);
1182 }
1183
1184 {
1185 let mut session = self.session.lock();
1186 session.tokens = new_tokens;
1187
1188 if let Some(ref tx) = self.token_refresh_tx {
1189 tx.unbounded_send(session.clone()).ok();
1190 }
1191 }
1192
1193 Ok(true)
1194 }
1195 Err(err) => {
1196 log::warn!("OAuth token refresh failed: {}", err);
1197 Ok(false)
1198 }
1199 }
1200 }
1201}
1202
1203#[cfg(test)]
1204mod tests {
1205 use super::*;
1206 use http_client::Response;
1207
1208 // -- require_https_or_loopback tests ------------------------------------
1209
1210 #[test]
1211 fn test_require_https_or_loopback_accepts_https() {
1212 let url = Url::parse("https://auth.example.com/token").unwrap();
1213 assert!(require_https_or_loopback(&url).is_ok());
1214 }
1215
1216 #[test]
1217 fn test_require_https_or_loopback_rejects_http_remote() {
1218 let url = Url::parse("http://auth.example.com/token").unwrap();
1219 assert!(require_https_or_loopback(&url).is_err());
1220 }
1221
1222 #[test]
1223 fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1224 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1225 assert!(require_https_or_loopback(&url).is_ok());
1226 }
1227
1228 #[test]
1229 fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1230 let url = Url::parse("http://[::1]:8080/callback").unwrap();
1231 assert!(require_https_or_loopback(&url).is_ok());
1232 }
1233
1234 #[test]
1235 fn test_require_https_or_loopback_accepts_http_localhost() {
1236 let url = Url::parse("http://localhost:8080/callback").unwrap();
1237 assert!(require_https_or_loopback(&url).is_ok());
1238 }
1239
1240 #[test]
1241 fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1242 let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1243 assert!(require_https_or_loopback(&url).is_ok());
1244 }
1245
1246 #[test]
1247 fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1248 let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1249 assert!(require_https_or_loopback(&url).is_err());
1250 }
1251
1252 #[test]
1253 fn test_require_https_or_loopback_rejects_ftp() {
1254 let url = Url::parse("ftp://auth.example.com/token").unwrap();
1255 assert!(require_https_or_loopback(&url).is_err());
1256 }
1257
1258 // -- validate_oauth_url (SSRF) tests ------------------------------------
1259
1260 #[test]
1261 fn test_validate_oauth_url_accepts_https_public() {
1262 let url = Url::parse("https://auth.example.com/token").unwrap();
1263 assert!(validate_oauth_url(&url).is_ok());
1264 }
1265
1266 #[test]
1267 fn test_validate_oauth_url_rejects_private_ipv4_10() {
1268 let url = Url::parse("https://10.0.0.1/token").unwrap();
1269 assert!(validate_oauth_url(&url).is_err());
1270 }
1271
1272 #[test]
1273 fn test_validate_oauth_url_rejects_private_ipv4_172() {
1274 let url = Url::parse("https://172.16.0.1/token").unwrap();
1275 assert!(validate_oauth_url(&url).is_err());
1276 }
1277
1278 #[test]
1279 fn test_validate_oauth_url_rejects_private_ipv4_192() {
1280 let url = Url::parse("https://192.168.1.1/token").unwrap();
1281 assert!(validate_oauth_url(&url).is_err());
1282 }
1283
1284 #[test]
1285 fn test_validate_oauth_url_rejects_link_local() {
1286 let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1287 assert!(validate_oauth_url(&url).is_err());
1288 }
1289
1290 #[test]
1291 fn test_validate_oauth_url_rejects_ipv6_ula() {
1292 let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1293 assert!(validate_oauth_url(&url).is_err());
1294 }
1295
1296 #[test]
1297 fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1298 let url = Url::parse("https://[::]/token").unwrap();
1299 assert!(validate_oauth_url(&url).is_err());
1300 }
1301
1302 #[test]
1303 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1304 let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1305 assert!(validate_oauth_url(&url).is_err());
1306 }
1307
1308 #[test]
1309 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1310 let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1311 assert!(validate_oauth_url(&url).is_err());
1312 }
1313
1314 #[test]
1315 fn test_validate_oauth_url_allows_http_loopback() {
1316 // Loopback is permitted (it's our callback server).
1317 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1318 assert!(validate_oauth_url(&url).is_ok());
1319 }
1320
1321 #[test]
1322 fn test_validate_oauth_url_allows_https_public_ip() {
1323 let url = Url::parse("https://93.184.216.34/token").unwrap();
1324 assert!(validate_oauth_url(&url).is_ok());
1325 }
1326
1327 // -- parse_www_authenticate tests ----------------------------------------
1328
1329 #[test]
1330 fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1331 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1332 let result = parse_www_authenticate(header).unwrap();
1333
1334 assert_eq!(
1335 result.resource_metadata.as_ref().map(|u| u.as_str()),
1336 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1337 );
1338 assert_eq!(
1339 result.scope,
1340 Some(vec!["files:read".to_string(), "user:profile".to_string()])
1341 );
1342 assert_eq!(result.error, None);
1343 }
1344
1345 #[test]
1346 fn test_parse_www_authenticate_resource_metadata_only() {
1347 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1348 let result = parse_www_authenticate(header).unwrap();
1349
1350 assert_eq!(
1351 result.resource_metadata.as_ref().map(|u| u.as_str()),
1352 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1353 );
1354 assert_eq!(result.scope, None);
1355 }
1356
1357 #[test]
1358 fn test_parse_www_authenticate_bare_bearer() {
1359 let result = parse_www_authenticate("Bearer").unwrap();
1360 assert_eq!(result.resource_metadata, None);
1361 assert_eq!(result.scope, None);
1362 }
1363
1364 #[test]
1365 fn test_parse_www_authenticate_with_error() {
1366 let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1367 let result = parse_www_authenticate(header).unwrap();
1368
1369 assert_eq!(result.error, Some(BearerError::InsufficientScope));
1370 assert_eq!(
1371 result.error_description.as_deref(),
1372 Some("Additional file write permission required")
1373 );
1374 assert_eq!(
1375 result.scope,
1376 Some(vec!["files:read".to_string(), "files:write".to_string()])
1377 );
1378 assert!(result.resource_metadata.is_some());
1379 }
1380
1381 #[test]
1382 fn test_parse_www_authenticate_invalid_token_error() {
1383 let header =
1384 r#"Bearer error="invalid_token", error_description="The access token expired""#;
1385 let result = parse_www_authenticate(header).unwrap();
1386 assert_eq!(result.error, Some(BearerError::InvalidToken));
1387 }
1388
1389 #[test]
1390 fn test_parse_www_authenticate_invalid_request_error() {
1391 let header = r#"Bearer error="invalid_request""#;
1392 let result = parse_www_authenticate(header).unwrap();
1393 assert_eq!(result.error, Some(BearerError::InvalidRequest));
1394 }
1395
1396 #[test]
1397 fn test_parse_www_authenticate_unknown_error() {
1398 let header = r#"Bearer error="some_future_error""#;
1399 let result = parse_www_authenticate(header).unwrap();
1400 assert_eq!(result.error, Some(BearerError::Other));
1401 }
1402
1403 #[test]
1404 fn test_parse_www_authenticate_rejects_non_bearer() {
1405 let result = parse_www_authenticate("Basic realm=\"example\"");
1406 assert!(result.is_err());
1407 }
1408
1409 #[test]
1410 fn test_parse_www_authenticate_case_insensitive_scheme() {
1411 let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1412 let result = parse_www_authenticate(header).unwrap();
1413 assert!(result.resource_metadata.is_some());
1414 }
1415
1416 #[test]
1417 fn test_parse_www_authenticate_multiline_style() {
1418 // Some servers emit the header spread across multiple lines joined by
1419 // whitespace, as shown in the spec examples.
1420 let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
1421 let result = parse_www_authenticate(header).unwrap();
1422 assert!(result.resource_metadata.is_some());
1423 assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1424 }
1425
1426 #[test]
1427 fn test_protected_resource_metadata_urls_with_path() {
1428 let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1429 let urls = protected_resource_metadata_urls(&server_url);
1430
1431 assert_eq!(urls.len(), 2);
1432 assert_eq!(
1433 urls[0].as_str(),
1434 "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1435 );
1436 assert_eq!(
1437 urls[1].as_str(),
1438 "https://api.example.com/.well-known/oauth-protected-resource"
1439 );
1440 }
1441
1442 #[test]
1443 fn test_protected_resource_metadata_urls_without_path() {
1444 let server_url = Url::parse("https://mcp.example.com").unwrap();
1445 let urls = protected_resource_metadata_urls(&server_url);
1446
1447 assert_eq!(urls.len(), 1);
1448 assert_eq!(
1449 urls[0].as_str(),
1450 "https://mcp.example.com/.well-known/oauth-protected-resource"
1451 );
1452 }
1453
1454 #[test]
1455 fn test_auth_server_metadata_urls_with_path() {
1456 let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1457 let urls = auth_server_metadata_urls(&issuer);
1458
1459 assert_eq!(urls.len(), 3);
1460 assert_eq!(
1461 urls[0].as_str(),
1462 "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1463 );
1464 assert_eq!(
1465 urls[1].as_str(),
1466 "https://auth.example.com/.well-known/openid-configuration/tenant1"
1467 );
1468 assert_eq!(
1469 urls[2].as_str(),
1470 "https://auth.example.com/tenant1/.well-known/openid-configuration"
1471 );
1472 }
1473
1474 #[test]
1475 fn test_auth_server_metadata_urls_without_path() {
1476 let issuer = Url::parse("https://auth.example.com").unwrap();
1477 let urls = auth_server_metadata_urls(&issuer);
1478
1479 assert_eq!(urls.len(), 2);
1480 assert_eq!(
1481 urls[0].as_str(),
1482 "https://auth.example.com/.well-known/oauth-authorization-server"
1483 );
1484 assert_eq!(
1485 urls[1].as_str(),
1486 "https://auth.example.com/.well-known/openid-configuration"
1487 );
1488 }
1489
1490 // -- Canonical server URI tests ------------------------------------------
1491
1492 #[test]
1493 fn test_canonical_server_uri_simple() {
1494 let url = Url::parse("https://mcp.example.com").unwrap();
1495 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1496 }
1497
1498 #[test]
1499 fn test_canonical_server_uri_with_path() {
1500 let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1501 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1502 }
1503
1504 #[test]
1505 fn test_canonical_server_uri_strips_trailing_slash() {
1506 let url = Url::parse("https://mcp.example.com/").unwrap();
1507 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1508 }
1509
1510 #[test]
1511 fn test_canonical_server_uri_preserves_port() {
1512 let url = Url::parse("https://mcp.example.com:8443").unwrap();
1513 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1514 }
1515
1516 #[test]
1517 fn test_canonical_server_uri_lowercases() {
1518 let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1519 assert_eq!(
1520 canonical_server_uri(&url),
1521 "https://mcp.example.com/Server/MCP"
1522 );
1523 }
1524
1525 // -- Scope selection tests -----------------------------------------------
1526
1527 #[test]
1528 fn test_select_scopes_prefers_www_authenticate() {
1529 let www_auth = WwwAuthenticate {
1530 resource_metadata: None,
1531 scope: Some(vec!["files:read".into()]),
1532 error: None,
1533 error_description: None,
1534 };
1535 let resource_meta = ProtectedResourceMetadata {
1536 resource: Url::parse("https://example.com").unwrap(),
1537 authorization_servers: vec![],
1538 scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1539 };
1540 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1541 }
1542
1543 #[test]
1544 fn test_select_scopes_falls_back_to_resource_metadata() {
1545 let www_auth = WwwAuthenticate {
1546 resource_metadata: None,
1547 scope: None,
1548 error: None,
1549 error_description: None,
1550 };
1551 let resource_meta = ProtectedResourceMetadata {
1552 resource: Url::parse("https://example.com").unwrap(),
1553 authorization_servers: vec![],
1554 scopes_supported: Some(vec!["admin".into()]),
1555 };
1556 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1557 }
1558
1559 #[test]
1560 fn test_select_scopes_empty_when_nothing_available() {
1561 let www_auth = WwwAuthenticate {
1562 resource_metadata: None,
1563 scope: None,
1564 error: None,
1565 error_description: None,
1566 };
1567 let resource_meta = ProtectedResourceMetadata {
1568 resource: Url::parse("https://example.com").unwrap(),
1569 authorization_servers: vec![],
1570 scopes_supported: None,
1571 };
1572 assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1573 }
1574
1575 // -- Client registration strategy tests ----------------------------------
1576
1577 #[test]
1578 fn test_registration_strategy_prefers_cimd() {
1579 let metadata = AuthServerMetadata {
1580 issuer: Url::parse("https://auth.example.com").unwrap(),
1581 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1582 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1583 registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1584 scopes_supported: None,
1585 code_challenge_methods_supported: Some(vec!["S256".into()]),
1586 client_id_metadata_document_supported: true,
1587 };
1588 assert_eq!(
1589 determine_registration_strategy(&metadata),
1590 ClientRegistrationStrategy::Cimd {
1591 client_id: CIMD_URL.to_string(),
1592 }
1593 );
1594 }
1595
1596 #[test]
1597 fn test_registration_strategy_falls_back_to_dcr() {
1598 let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1599 let metadata = AuthServerMetadata {
1600 issuer: Url::parse("https://auth.example.com").unwrap(),
1601 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1602 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1603 registration_endpoint: Some(reg_endpoint.clone()),
1604 scopes_supported: None,
1605 code_challenge_methods_supported: Some(vec!["S256".into()]),
1606 client_id_metadata_document_supported: false,
1607 };
1608 assert_eq!(
1609 determine_registration_strategy(&metadata),
1610 ClientRegistrationStrategy::Dcr {
1611 registration_endpoint: reg_endpoint,
1612 }
1613 );
1614 }
1615
1616 #[test]
1617 fn test_registration_strategy_unavailable() {
1618 let metadata = AuthServerMetadata {
1619 issuer: Url::parse("https://auth.example.com").unwrap(),
1620 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1621 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1622 registration_endpoint: None,
1623 scopes_supported: None,
1624 code_challenge_methods_supported: Some(vec!["S256".into()]),
1625 client_id_metadata_document_supported: false,
1626 };
1627 assert_eq!(
1628 determine_registration_strategy(&metadata),
1629 ClientRegistrationStrategy::Unavailable,
1630 );
1631 }
1632
1633 // -- PKCE tests ----------------------------------------------------------
1634
1635 #[test]
1636 fn test_pkce_challenge_verifier_length() {
1637 let pkce = generate_pkce_challenge();
1638 // 32 random bytes → 43 base64url chars (no padding).
1639 assert_eq!(pkce.verifier.len(), 43);
1640 }
1641
1642 #[test]
1643 fn test_pkce_challenge_is_valid_base64url() {
1644 let pkce = generate_pkce_challenge();
1645 for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1646 assert!(
1647 c.is_ascii_alphanumeric() || c == '-' || c == '_',
1648 "invalid base64url character: {}",
1649 c
1650 );
1651 }
1652 }
1653
1654 #[test]
1655 fn test_pkce_challenge_is_s256_of_verifier() {
1656 let pkce = generate_pkce_challenge();
1657 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1658 let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1659 let expected_challenge = engine.encode(expected_digest);
1660 assert_eq!(pkce.challenge, expected_challenge);
1661 }
1662
1663 #[test]
1664 fn test_pkce_challenges_are_unique() {
1665 let a = generate_pkce_challenge();
1666 let b = generate_pkce_challenge();
1667 assert_ne!(a.verifier, b.verifier);
1668 }
1669
1670 // -- Authorization URL tests ---------------------------------------------
1671
1672 #[test]
1673 fn test_build_authorization_url() {
1674 let metadata = AuthServerMetadata {
1675 issuer: Url::parse("https://auth.example.com").unwrap(),
1676 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1677 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1678 registration_endpoint: None,
1679 scopes_supported: None,
1680 code_challenge_methods_supported: Some(vec!["S256".into()]),
1681 client_id_metadata_document_supported: true,
1682 };
1683 let pkce = PkceChallenge {
1684 verifier: "test_verifier".into(),
1685 challenge: "test_challenge".into(),
1686 };
1687 let url = build_authorization_url(
1688 &metadata,
1689 "https://zed.dev/oauth/client-metadata.json",
1690 "http://127.0.0.1:12345/callback",
1691 &["files:read".into(), "files:write".into()],
1692 "https://mcp.example.com",
1693 &pkce,
1694 "random_state_123",
1695 );
1696
1697 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1698 assert_eq!(pairs.get("response_type").unwrap(), "code");
1699 assert_eq!(
1700 pairs.get("client_id").unwrap(),
1701 "https://zed.dev/oauth/client-metadata.json"
1702 );
1703 assert_eq!(
1704 pairs.get("redirect_uri").unwrap(),
1705 "http://127.0.0.1:12345/callback"
1706 );
1707 assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1708 assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1709 assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1710 assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1711 assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1712 }
1713
1714 #[test]
1715 fn test_build_authorization_url_omits_empty_scope() {
1716 let metadata = AuthServerMetadata {
1717 issuer: Url::parse("https://auth.example.com").unwrap(),
1718 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1719 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1720 registration_endpoint: None,
1721 scopes_supported: None,
1722 code_challenge_methods_supported: Some(vec!["S256".into()]),
1723 client_id_metadata_document_supported: false,
1724 };
1725 let pkce = PkceChallenge {
1726 verifier: "v".into(),
1727 challenge: "c".into(),
1728 };
1729 let url = build_authorization_url(
1730 &metadata,
1731 "client_123",
1732 "http://127.0.0.1:9999/callback",
1733 &[],
1734 "https://mcp.example.com",
1735 &pkce,
1736 "state",
1737 );
1738
1739 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1740 assert!(!pairs.contains_key("scope"));
1741 }
1742
1743 // -- Token exchange / refresh param tests --------------------------------
1744
1745 #[test]
1746 fn test_token_exchange_params() {
1747 let params = token_exchange_params(
1748 "auth_code_abc",
1749 "client_xyz",
1750 "http://127.0.0.1:5555/callback",
1751 "verifier_123",
1752 "https://mcp.example.com",
1753 );
1754 let map: std::collections::HashMap<&str, &str> =
1755 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1756
1757 assert_eq!(map["grant_type"], "authorization_code");
1758 assert_eq!(map["code"], "auth_code_abc");
1759 assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1760 assert_eq!(map["client_id"], "client_xyz");
1761 assert_eq!(map["code_verifier"], "verifier_123");
1762 assert_eq!(map["resource"], "https://mcp.example.com");
1763 }
1764
1765 #[test]
1766 fn test_token_refresh_params() {
1767 let params =
1768 token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1769 let map: std::collections::HashMap<&str, &str> =
1770 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1771
1772 assert_eq!(map["grant_type"], "refresh_token");
1773 assert_eq!(map["refresh_token"], "refresh_token_abc");
1774 assert_eq!(map["client_id"], "client_xyz");
1775 assert_eq!(map["resource"], "https://mcp.example.com");
1776 }
1777
1778 // -- Token response tests ------------------------------------------------
1779
1780 #[test]
1781 fn test_token_response_into_tokens_with_expiry() {
1782 let response: TokenResponse = serde_json::from_str(
1783 r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1784 )
1785 .unwrap();
1786
1787 let tokens = response.into_tokens();
1788 assert_eq!(tokens.access_token, "at_123");
1789 assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1790 assert!(tokens.expires_at.is_some());
1791 }
1792
1793 #[test]
1794 fn test_token_response_into_tokens_minimal() {
1795 let response: TokenResponse =
1796 serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1797
1798 let tokens = response.into_tokens();
1799 assert_eq!(tokens.access_token, "at_789");
1800 assert_eq!(tokens.refresh_token, None);
1801 assert_eq!(tokens.expires_at, None);
1802 }
1803
1804 // -- DCR body test -------------------------------------------------------
1805
1806 #[test]
1807 fn test_dcr_registration_body_shape() {
1808 let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1809 assert_eq!(body["client_name"], "Zed");
1810 assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1811 assert_eq!(body["grant_types"][0], "authorization_code");
1812 assert_eq!(body["response_types"][0], "code");
1813 assert_eq!(body["token_endpoint_auth_method"], "none");
1814 }
1815
1816 // -- Test helpers for async/HTTP tests -----------------------------------
1817
1818 fn make_fake_http_client(
1819 handler: impl Fn(
1820 http_client::Request<AsyncBody>,
1821 ) -> std::pin::Pin<
1822 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1823 > + Send
1824 + Sync
1825 + 'static,
1826 ) -> Arc<dyn HttpClient> {
1827 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1828 }
1829
1830 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1831 Ok(Response::builder()
1832 .status(status)
1833 .header("Content-Type", "application/json")
1834 .body(AsyncBody::from(body.as_bytes().to_vec()))
1835 .unwrap())
1836 }
1837
1838 // -- Discovery integration tests -----------------------------------------
1839
1840 #[test]
1841 fn test_fetch_protected_resource_metadata() {
1842 smol::block_on(async {
1843 let client = make_fake_http_client(|req| {
1844 Box::pin(async move {
1845 let uri = req.uri().to_string();
1846 if uri.contains(".well-known/oauth-protected-resource") {
1847 json_response(
1848 200,
1849 r#"{
1850 "resource": "https://mcp.example.com",
1851 "authorization_servers": ["https://auth.example.com"],
1852 "scopes_supported": ["read", "write"]
1853 }"#,
1854 )
1855 } else {
1856 json_response(404, "{}")
1857 }
1858 })
1859 });
1860
1861 let server_url = Url::parse("https://mcp.example.com").unwrap();
1862 let www_auth = WwwAuthenticate {
1863 resource_metadata: None,
1864 scope: None,
1865 error: None,
1866 error_description: None,
1867 };
1868
1869 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
1870 .await
1871 .unwrap();
1872
1873 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
1874 assert_eq!(metadata.authorization_servers.len(), 1);
1875 assert_eq!(
1876 metadata.authorization_servers[0].as_str(),
1877 "https://auth.example.com/"
1878 );
1879 assert_eq!(
1880 metadata.scopes_supported,
1881 Some(vec!["read".to_string(), "write".to_string()])
1882 );
1883 });
1884 }
1885
/// When `WWW-Authenticate` supplies a same-origin `resource_metadata` URL,
/// that document is the one fetched and parsed.
#[test]
fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                if uri == "https://mcp.example.com/custom-resource-metadata" {
                    json_response(
                        200,
                        r#"{
                            "resource": "https://mcp.example.com",
                            "authorization_servers": ["https://auth.example.com"]
                        }"#,
                    )
                } else {
                    // Every other URL (including the well-known fallback)
                    // fails, so a successful result can only come from the
                    // custom document.
                    json_response(500, r#"{"error": "should not be called"}"#)
                }
            })
        });

        let server_url = Url::parse("https://mcp.example.com").unwrap();
        let www_auth = WwwAuthenticate {
            resource_metadata: Some(
                Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
            ),
            scope: None,
            error: None,
            error_description: None,
        };

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
            .await
            .unwrap();

        // Pin the parsed contents, not just the number of servers, so the
        // test verifies which document was actually used.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
        assert_eq!(metadata.authorization_servers.len(), 1);
        assert_eq!(
            metadata.authorization_servers[0].as_str(),
            "https://auth.example.com/"
        );
    });
}
1923
/// A `resource_metadata` URL on a different origin than the MCP server is
/// ignored. Discovery falls back to the server's own well-known document.
#[test]
fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                // The cross-origin URL must never be fetched; only the
                // well-known fallback at the server's own origin may be.
                assert!(
                    !uri.contains("attacker.example.com"),
                    "should not fetch cross-origin resource_metadata URL"
                );
                if uri.contains(".well-known/oauth-protected-resource") {
                    json_response(
                        200,
                        r#"{
                            "resource": "https://mcp.example.com",
                            "authorization_servers": ["https://auth.example.com"]
                        }"#,
                    )
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let attacker_url = Url::parse("https://attacker.example.com/fake-metadata").unwrap();
        let www_auth = WwwAuthenticate {
            resource_metadata: Some(attacker_url),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
            .await
            .unwrap();

        // The fallback well-known document was used, not the attacker's.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
    });
}
1966
/// Authorization Server Metadata is parsed from the RFC 8414 well-known
/// document, including the optional registration and CIMD fields.
#[test]
fn test_fetch_auth_server_metadata() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                let (status, body) = if uri.contains(".well-known/oauth-authorization-server") {
                    (
                        200,
                        r#"{
                            "issuer": "https://auth.example.com",
                            "authorization_endpoint": "https://auth.example.com/authorize",
                            "token_endpoint": "https://auth.example.com/token",
                            "registration_endpoint": "https://auth.example.com/register",
                            "code_challenge_methods_supported": ["S256"],
                            "client_id_metadata_document_supported": true
                        }"#,
                    )
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
        for (actual, expected) in [
            (&metadata.authorization_endpoint, "https://auth.example.com/authorize"),
            (&metadata.token_endpoint, "https://auth.example.com/token"),
        ] {
            assert_eq!(actual.as_str(), expected);
        }
        assert!(metadata.registration_endpoint.is_some());
        assert!(metadata.client_id_metadata_document_supported);
        assert_eq!(
            metadata.code_challenge_methods_supported,
            Some(vec!["S256".to_string()])
        );
    });
}
2011
/// When the OAuth well-known document is missing, metadata is taken from
/// the OpenID Connect discovery document. An absent CIMD flag reads as
/// unsupported.
#[test]
fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let is_oidc = req.uri().to_string().contains("openid-configuration");
                if !is_oidc {
                    return json_response(404, "{}");
                }
                json_response(
                    200,
                    r#"{
                        "issuer": "https://auth.example.com",
                        "authorization_endpoint": "https://auth.example.com/authorize",
                        "token_endpoint": "https://auth.example.com/token",
                        "code_challenge_methods_supported": ["S256"]
                    }"#,
                )
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        let authorize_url = metadata.authorization_endpoint.as_str();
        assert_eq!(authorize_url, "https://auth.example.com/authorize");
        assert!(!metadata.client_id_metadata_document_supported);
    });
}
2044
/// Metadata whose `issuer` differs from the issuer it was fetched for is
/// rejected with an "issuer mismatch" error.
#[test]
fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                // The served document claims to be a different issuer.
                let (status, body) = if uri.contains(".well-known/oauth-authorization-server") {
                    (
                        200,
                        r#"{
                            "issuer": "https://evil.example.com",
                            "authorization_endpoint": "https://evil.example.com/authorize",
                            "token_endpoint": "https://evil.example.com/token",
                            "code_challenge_methods_supported": ["S256"]
                        }"#,
                    )
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let expected_issuer = Url::parse("https://auth.example.com").unwrap();
        let outcome = fetch_auth_server_metadata(&client, &expected_issuer).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("issuer mismatch"),
            "unexpected error: {}",
            message
        );
    });
}
2080
2081 // -- Full discover integration tests -------------------------------------
2082
/// End-to-end discovery against a server that supports Client ID Metadata
/// Documents. Registration resolves to `CIMD_URL` with no client secret, and
/// the advertised scopes are used.
#[test]
fn test_full_discover_with_cimd() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                let body = if uri.contains("oauth-protected-resource") {
                    Some(
                        r#"{
                            "resource": "https://mcp.example.com",
                            "authorization_servers": ["https://auth.example.com"],
                            "scopes_supported": ["mcp:read"]
                        }"#,
                    )
                } else if uri.contains("oauth-authorization-server") {
                    Some(
                        r#"{
                            "issuer": "https://auth.example.com",
                            "authorization_endpoint": "https://auth.example.com/authorize",
                            "token_endpoint": "https://auth.example.com/token",
                            "code_challenge_methods_supported": ["S256"],
                            "client_id_metadata_document_supported": true
                        }"#,
                    )
                } else {
                    None
                };
                match body {
                    Some(json) => json_response(200, json),
                    None => json_response(404, "{}"),
                }
            })
        });

        let server_url = Url::parse("https://mcp.example.com").unwrap();
        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let redirect_uri = "http://127.0.0.1:12345/callback";

        let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
        let registration = resolve_client_registration(&client, &discovery, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, CIMD_URL);
        assert_eq!(registration.client_secret, None);
        assert_eq!(discovery.scopes, vec!["mcp:read"]);
    });
}
2134
/// End-to-end discovery against a server without CIMD support. Registration
/// falls back to Dynamic Client Registration, and the scope from
/// `WWW-Authenticate` is used.
#[test]
fn test_full_discover_with_dcr_fallback() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                let (status, body) = if uri.contains("oauth-protected-resource") {
                    (
                        200,
                        r#"{
                            "resource": "https://mcp.example.com",
                            "authorization_servers": ["https://auth.example.com"]
                        }"#,
                    )
                } else if uri.contains("oauth-authorization-server") {
                    (
                        200,
                        r#"{
                            "issuer": "https://auth.example.com",
                            "authorization_endpoint": "https://auth.example.com/authorize",
                            "token_endpoint": "https://auth.example.com/token",
                            "registration_endpoint": "https://auth.example.com/register",
                            "code_challenge_methods_supported": ["S256"],
                            "client_id_metadata_document_supported": false
                        }"#,
                    )
                } else if uri.contains("/register") {
                    (
                        201,
                        r#"{
                            "client_id": "dcr-minted-id-123",
                            "client_secret": "dcr-secret-456"
                        }"#,
                    )
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: Some(vec!["files:read".into()]),
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();
        let redirect_uri = "http://127.0.0.1:9999/callback";

        let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
        let registration = resolve_client_registration(&client, &discovery, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, "dcr-minted-id-123");
        let secret = registration.client_secret.as_deref();
        assert_eq!(secret, Some("dcr-secret-456"));
        assert_eq!(discovery.scopes, vec!["files:read"]);
    });
}
2197
/// Discovery fails when the authorization server metadata omits
/// `code_challenge_methods_supported`. The error names that field.
#[test]
fn test_discover_fails_without_pkce_support() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                let body = if uri.contains("oauth-protected-resource") {
                    Some(
                        r#"{
                            "resource": "https://mcp.example.com",
                            "authorization_servers": ["https://auth.example.com"]
                        }"#,
                    )
                } else if uri.contains("oauth-authorization-server") {
                    Some(
                        r#"{
                            "issuer": "https://auth.example.com",
                            "authorization_endpoint": "https://auth.example.com/authorize",
                            "token_endpoint": "https://auth.example.com/token"
                        }"#,
                    )
                } else {
                    None
                };
                match body {
                    Some(json) => json_response(200, json),
                    None => json_response(404, "{}"),
                }
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let outcome = discover(&client, &server_url, &www_auth).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("code_challenge_methods_supported"),
            "unexpected error: {}",
            message
        );
    });
}
2245
2246 // -- Token exchange integration tests ------------------------------------
2247
/// A successful token-endpoint response is turned into tokens carrying the
/// access token, the refresh token and a computed expiry.
#[test]
fn test_exchange_code_success() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let hits_token_endpoint = req.uri().to_string().contains("/token");
                if !hits_token_endpoint {
                    return json_response(404, "{}");
                }
                json_response(
                    200,
                    r#"{
                        "access_token": "new_access_token",
                        "refresh_token": "new_refresh_token",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }"#,
                )
            })
        });

        let metadata = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };
        let redirect_uri = "http://127.0.0.1:9999/callback";
        let resource = "https://mcp.example.com";

        let tokens = exchange_code(
            &client,
            &metadata,
            "auth_code_123",
            CIMD_URL,
            redirect_uri,
            "verifier_abc",
            resource,
        )
        .await
        .unwrap();

        assert_eq!(tokens.access_token, "new_access_token");
        assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
        assert!(tokens.expires_at.is_some());
    });
}
2297
/// A refresh response that omits `refresh_token` yields tokens with no
/// refresh token and a computed expiry.
#[test]
fn test_refresh_tokens_success() {
    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let uri = req.uri().to_string();
                let (status, body) = if uri.contains("/token") {
                    (
                        200,
                        r#"{
                            "access_token": "refreshed_token",
                            "expires_in": 1800,
                            "token_type": "Bearer"
                        }"#,
                    )
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let resource = "https://mcp.example.com";

        let tokens = refresh_tokens(
            &client,
            &token_endpoint,
            "old_refresh_token",
            CIMD_URL,
            resource,
        )
        .await
        .unwrap();

        assert_eq!(tokens.access_token, "refreshed_token");
        assert!(tokens.refresh_token.is_none());
        assert!(tokens.expires_at.is_some());
    });
}
2336
/// A non-success response from the token endpoint makes `exchange_code` fail
/// with an error that mentions the HTTP status.
#[test]
fn test_exchange_code_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
        });

        let metadata = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let result = exchange_code(
            &client,
            &metadata,
            "bad_code",
            "client",
            "http://127.0.0.1:1/callback",
            "verifier",
            "https://mcp.example.com",
        )
        .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        // Report the actual error text on failure, as the other error tests do.
        assert!(err_msg.contains("400"), "unexpected error: {}", err_msg);
    });
}
2369
2370 // -- DCR integration tests -----------------------------------------------
2371
/// Dynamic Client Registration parses the minted `client_id` and
/// `client_secret` from a 201 response.
#[test]
fn test_perform_dcr() {
    const REGISTRATION_RESPONSE: &str = r#"{
        "client_id": "dynamic-client-001",
        "client_secret": "dynamic-secret-001"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(201, REGISTRATION_RESPONSE) })
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let redirect_uri = "http://127.0.0.1:9999/callback";
        let registration = perform_dcr(&client, &endpoint, redirect_uri).await.unwrap();

        assert_eq!(registration.client_id, "dynamic-client-001");
        let secret = registration.client_secret.as_deref();
        assert_eq!(secret, Some("dynamic-secret-001"));
    });
}
2399
/// A rejected registration request makes `perform_dcr` fail with an error
/// that mentions the HTTP status.
#[test]
fn test_perform_dcr_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(
                async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
            )
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        // Report the actual error text on failure, as the other error tests do.
        assert!(err_msg.contains("403"), "unexpected error: {}", err_msg);
    });
}
2416
2417 // -- OAuthCallback parse tests -------------------------------------------
2418
/// A plain `code` + `state` query is parsed into both fields.
#[test]
fn test_oauth_callback_parse_query() {
    let query = "code=test_auth_code&state=test_state";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2425
/// Parameter order in the callback query does not matter.
#[test]
fn test_oauth_callback_parse_query_reversed_order() {
    let query = "state=test_state&code=test_auth_code";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2432
/// Unrecognized query parameters are ignored rather than rejected.
#[test]
fn test_oauth_callback_parse_query_with_extra_params() {
    let query = "code=test_auth_code&state=test_state&extra=ignored";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2441
/// A callback query without `code` is rejected, and the error names the
/// missing parameter.
#[test]
fn test_oauth_callback_parse_query_missing_code() {
    let result = OAuthCallback::parse_query("state=test_state");
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    // Report the actual error text on failure.
    assert!(err_msg.contains("code"), "unexpected error: {}", err_msg);
}
2448
/// A callback query without `state` is rejected, and the error names the
/// missing parameter.
#[test]
fn test_oauth_callback_parse_query_missing_state() {
    let result = OAuthCallback::parse_query("code=test_auth_code");
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    // Report the actual error text on failure.
    assert!(err_msg.contains("state"), "unexpected error: {}", err_msg);
}
2455
/// A `code` parameter that is present but empty is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_code() {
    let query = "code=&state=test_state";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2461
/// A `state` parameter that is present but empty is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_state() {
    let query = "code=test_auth_code&state=";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2467
/// Percent-encoded values are decoded (`%20` becomes a space, `%3D` becomes `=`).
#[test]
fn test_oauth_callback_parse_query_url_encoded_values() {
    let query = "code=abc%20def&state=test%3Dstate";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "abc def");
    assert_eq!(parsed.state, "test=state");
}
2474
/// An authorization error callback is surfaced as an error. The message
/// includes both the error code and its decoded description.
#[test]
fn test_oauth_callback_parse_query_error_response() {
    let query = "error=access_denied&error_description=User%20denied%20access&state=abc";
    let outcome = OAuthCallback::parse_query(query);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    for fragment in ["access_denied", "User denied access"] {
        assert!(message.contains(fragment), "unexpected error: {}", message);
    }
}
2493
/// An error callback without `error_description` is still surfaced. The
/// message carries the error code and a "no description" placeholder.
#[test]
fn test_oauth_callback_parse_query_error_without_description() {
    let outcome = OAuthCallback::parse_query("error=server_error&state=abc");
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    for fragment in ["server_error", "no description"] {
        assert!(message.contains(fragment), "unexpected error: {}", message);
    }
}
2510
2511 // -- McpOAuthTokenProvider tests -----------------------------------------
2512
2513 fn make_test_session(
2514 access_token: &str,
2515 refresh_token: Option<&str>,
2516 expires_at: Option<SystemTime>,
2517 ) -> OAuthSession {
2518 OAuthSession {
2519 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2520 resource: Url::parse("https://mcp.example.com").unwrap(),
2521 client_registration: OAuthClientRegistration {
2522 client_id: "test-client".into(),
2523 client_secret: None,
2524 },
2525 tokens: OAuthTokens {
2526 access_token: access_token.into(),
2527 refresh_token: refresh_token.map(String::from),
2528 expires_at,
2529 },
2530 }
2531 }
2532
/// An access token whose expiry is already in the past is not returned.
#[test]
fn test_mcp_oauth_provider_returns_none_when_token_expired() {
    let one_minute_ago = SystemTime::now() - Duration::from_secs(60);
    let provider = McpOAuthTokenProvider::new(
        make_test_session("stale-token", Some("rt"), Some(one_minute_ago)),
        // Reading the token must not make any HTTP request.
        make_fake_http_client(|_| Box::pin(async { unreachable!() })),
        None,
    );

    assert!(provider.access_token().is_none());
}
2545
/// An access token that expires in the future is returned as-is.
#[test]
fn test_mcp_oauth_provider_returns_token_when_not_expired() {
    let in_one_hour = SystemTime::now() + Duration::from_secs(3600);
    let provider = McpOAuthTokenProvider::new(
        make_test_session("valid-token", Some("rt"), Some(in_one_hour)),
        // Reading the token must not make any HTTP request.
        make_fake_http_client(|_| Box::pin(async { unreachable!() })),
        None,
    );

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("valid-token"));
}
2558
/// An access token with no recorded expiry is returned.
#[test]
fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
    let provider = McpOAuthTokenProvider::new(
        make_test_session("no-expiry-token", Some("rt"), None),
        // Reading the token must not make any HTTP request.
        make_fake_http_client(|_| Box::pin(async { unreachable!() })),
        None,
    );

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("no-expiry-token"));
}
2570
/// Without a refresh token, `try_refresh` reports `false` and makes no HTTP
/// request.
#[test]
fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
    smol::block_on(async {
        let provider = McpOAuthTokenProvider::new(
            make_test_session("token", None, None),
            make_fake_http_client(|_| {
                Box::pin(async { unreachable!("no HTTP call expected") })
            }),
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);
    });
}
2587
/// A successful refresh swaps in the new access token and sends the updated
/// session, including the rotated refresh token, on the provided channel.
#[test]
fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();

        let http_client = make_fake_http_client(|_req| {
            Box::pin(async {
                json_response(
                    200,
                    r#"{
                        "access_token": "new-access",
                        "refresh_token": "new-refresh",
                        "expires_in": 1800
                    }"#,
                )
            })
        });

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh-token"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);
        assert_eq!(provider.access_token().as_deref(), Some("new-access"));

        let updated = session_rx
            .try_recv()
            .expect("channel should have a session");
        let tokens = &updated.tokens;
        assert_eq!(tokens.access_token, "new-access");
        assert_eq!(tokens.refresh_token.as_deref(), Some("new-refresh"));
    });
}
2621
/// If the refresh response omits `refresh_token`, the previously stored
/// refresh token is kept in the updated session.
#[test]
fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();

        let http_client = make_fake_http_client(|_req| {
            Box::pin(async {
                json_response(
                    200,
                    r#"{
                        "access_token": "new-access",
                        "expires_in": 900
                    }"#,
                )
            })
        });

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("original-refresh"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);

        let updated = session_rx
            .try_recv()
            .expect("channel should have a session");
        let tokens = &updated.tokens;
        assert_eq!(tokens.access_token, "new-access");
        assert_eq!(tokens.refresh_token.as_deref(), Some("original-refresh"));
    });
}
2653
/// A rejected refresh request makes `try_refresh` report `false` rather than
/// an error, and the existing access token stays in place.
#[test]
fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
    smol::block_on(async {
        let failing_client = make_fake_http_client(|_req| {
            Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
        });

        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh"), None),
            failing_client,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);
        // The old token is still served after the failed refresh.
        let token = provider.access_token();
        assert_eq!(token.as_deref(), Some("old-access"));
    });
}
2671}