1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// [`determine_registration_strategy`] hands this value out verbatim as the
/// OAuth `client_id` whenever the authorization server reports
/// `client_id_metadata_document_supported`.
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
/// OAuth 2.0 Protected Resource Metadata (RFC 9728) for an MCP server, as
/// located via the server's `WWW-Authenticate` header or a well-known endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource's identifier.
    /// [`fetch_protected_resource_metadata`] substitutes the MCP server URL
    /// when the fetched document omits it.
    pub resource: Url,
    /// Authorization servers listed for this resource. [`discover`] uses the
    /// first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource, if any. [`select_scopes`] falls back
    /// to these when the `WWW-Authenticate` challenge carries no scopes.
    pub scopes_supported: Option<Vec<String>>,
}
139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// The issuer identifier. [`fetch_auth_server_metadata`] rejects documents
    /// whose reported issuer differs from the one that was requested.
    pub issuer: Url,
    /// Endpoint that [`build_authorization_url`] extends with query parameters.
    pub authorization_endpoint: Url,
    /// Endpoint that [`exchange_code`] posts the authorization code to.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, when the server offers one.
    pub registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// Grant types the server advertises; [`dcr_registration_body`] intersects
    /// them with the grant types Zed supports.
    pub grant_types_supported: Option<Vec<String>>,
    /// PKCE challenge methods the server advertises. [`discover`] requires
    /// this list to be present and to contain `S256`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a client metadata document URL as the
    /// `client_id`. [`fetch_auth_server_metadata`] treats a missing value as
    /// `false`.
    pub client_id_metadata_document_supported: bool,
}
153
/// The result of client registration — either CIMD or DCR.
///
/// `Debug` is implemented by hand so the client secret is never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthClientRegistration {
    /// The client identifier: the CIMD URL for CIMD, or the identifier
    /// returned by the registration endpoint for DCR.
    pub client_id: String,
    /// Only present for DCR-minted registrations.
    pub client_secret: Option<String>,
}
161
162impl std::fmt::Debug for OAuthClientRegistration {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("OAuthClientRegistration")
165 .field("client_id", &self.client_id)
166 .field(
167 "client_secret",
168 &self.client_secret.as_ref().map(|_| "[redacted]"),
169 )
170 .finish()
171 }
172}
173
/// Access and refresh tokens obtained from the token endpoint.
///
/// `Debug` is implemented by hand so token values are never printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
    /// The access token returned by the token endpoint.
    pub access_token: String,
    /// The refresh token, when the token endpoint returned one.
    pub refresh_token: Option<String>,
    /// When the access token expires, derived from the response's
    /// `expires_in`; `None` when the response did not include it.
    pub expires_at: Option<SystemTime>,
}
181
182impl std::fmt::Debug for OAuthTokens {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("OAuthTokens")
185 .field("access_token", &"[redacted]")
186 .field(
187 "refresh_token",
188 &self.refresh_token.as_ref().map(|_| "[redacted]"),
189 )
190 .field("expires_at", &self.expires_at)
191 .finish()
192 }
193}
194
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched for the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the authorization server chosen from `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes to request, as chosen by [`select_scopes`].
    pub scopes: Vec<String>,
}
203
/// The persisted OAuth session for a context server.
///
/// Stored in the keychain so startup can restore a refresh-capable provider
/// without another browser flow. Deliberately excludes the full discovery
/// metadata to keep the serialized size well within keychain item limits.
///
/// `Debug` is implemented by hand and defers to the redacting `Debug` impls of
/// the nested registration and token types.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthSession {
    /// The authorization server's token endpoint.
    pub token_endpoint: Url,
    /// The resource the tokens were requested for.
    pub resource: Url,
    /// The client registration used to obtain the tokens.
    pub client_registration: OAuthClientRegistration,
    /// The current access/refresh tokens.
    pub tokens: OAuthTokens,
}
216
217impl std::fmt::Debug for OAuthSession {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct("OAuthSession")
220 .field("token_endpoint", &self.token_endpoint)
221 .field("resource", &self.resource)
222 .field("client_registration", &self.client_registration)
223 .field("tokens", &self.tokens)
224 .finish()
225 }
226}
227
/// The `error` codes RFC 6750 Section 3.1 defines for Bearer token
/// authentication, plus a catch-all for anything else.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// `invalid_request`: the request is malformed — a required parameter is
    /// missing, or a parameter or parameter value is unsupported.
    InvalidRequest,
    /// `invalid_token`: the access token is expired, revoked, malformed, or
    /// otherwise invalid.
    InvalidToken,
    /// `insufficient_scope`: the request needs more privileges than the
    /// access token grants.
    InsufficientScope,
    /// Any other code (an extension or a future spec addition).
    Other,
}

impl BearerError {
    /// Translate the raw value of an `error` parameter into a [`BearerError`].
    ///
    /// Matching is exact and case-sensitive; every unrecognized value maps to
    /// [`BearerError::Other`].
    fn parse(value: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];

        KNOWN_CODES
            .iter()
            .find(|entry| entry.0 == value)
            .map_or(BearerError::Other, |entry| entry.1)
    }
}
252
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if the header named one.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace; `None` when the parameter
    /// was absent.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The human-readable `error_description` parameter, if present.
    pub error_description: Option<String>,
}
266
267/// Parse a `WWW-Authenticate` header value.
268///
269/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
270/// Per RFC 6750 and RFC 9728, the relevant parameters are:
271/// - `resource_metadata` — URL of the Protected Resource Metadata document
272/// - `scope` — space-separated list of required scopes
273/// - `error` — error code (e.g. "insufficient_scope")
274/// - `error_description` — human-readable error description
275pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
276 let header = header.trim();
277
278 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
279 header[6..].trim()
280 } else {
281 bail!("WWW-Authenticate header does not use Bearer scheme");
282 };
283
284 if params_str.is_empty() {
285 return Ok(WwwAuthenticate {
286 resource_metadata: None,
287 scope: None,
288 error: None,
289 error_description: None,
290 });
291 }
292
293 let params = parse_auth_params(params_str);
294
295 let resource_metadata = params
296 .get("resource_metadata")
297 .map(|v| Url::parse(v))
298 .transpose()
299 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
300
301 let scope = params
302 .get("scope")
303 .map(|v| v.split_whitespace().map(String::from).collect());
304
305 let error = params.get("error").map(|v| BearerError::parse(v));
306 let error_description = params.get("error_description").cloned();
307
308 Ok(WwwAuthenticate {
309 resource_metadata,
310 scope,
311 error,
312 error_description,
313 })
314}
315
316/// Parse comma-separated `key="value"` or `key=token` parameters from an
317/// auth-param list (RFC 7235 Section 2.1).
318fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
319 let mut params = collections::HashMap::default();
320 let mut remaining = input.trim();
321
322 while !remaining.is_empty() {
323 // Skip leading whitespace and commas.
324 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
325 if remaining.is_empty() {
326 break;
327 }
328
329 // Find the key (everything before '=').
330 let eq_pos = match remaining.find('=') {
331 Some(pos) => pos,
332 None => break,
333 };
334
335 let key = remaining[..eq_pos].trim().to_lowercase();
336 remaining = &remaining[eq_pos + 1..];
337 remaining = remaining.trim_start();
338
339 // Parse the value: either quoted or unquoted (token).
340 let value;
341 if remaining.starts_with('"') {
342 // Quoted string: find the closing quote, handling escaped chars.
343 remaining = &remaining[1..]; // skip opening quote
344 let mut val = String::new();
345 let mut chars = remaining.char_indices();
346 loop {
347 match chars.next() {
348 Some((_, '\\')) => {
349 // Escaped character — take the next char literally.
350 if let Some((_, c)) = chars.next() {
351 val.push(c);
352 }
353 }
354 Some((i, '"')) => {
355 remaining = &remaining[i + 1..];
356 break;
357 }
358 Some((_, c)) => val.push(c),
359 None => {
360 remaining = "";
361 break;
362 }
363 }
364 }
365 value = val;
366 } else {
367 // Unquoted token: read until comma or whitespace.
368 let end = remaining
369 .find(|c: char| c == ',' || c.is_whitespace())
370 .unwrap_or(remaining.len());
371 value = remaining[..end].to_string();
372 remaining = &remaining[end..];
373 }
374
375 if !key.is_empty() {
376 params.insert(key, value);
377 }
378 }
379
380 params
381}
382
383/// Construct the well-known Protected Resource Metadata URIs for a given MCP
384/// server URL, per RFC 9728 Section 3.
385///
386/// Returns URIs in priority order:
387/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
388/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
389pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
390 let mut urls = Vec::new();
391 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
392
393 let path = server_url.path().trim_start_matches('/');
394 if !path.is_empty() {
395 if let Ok(url) = Url::parse(&format!(
396 "{}/.well-known/oauth-protected-resource/{}",
397 base, path
398 )) {
399 urls.push(url);
400 }
401 }
402
403 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
404 urls.push(url);
405 }
406
407 urls
408}
409
410/// Construct the well-known Authorization Server Metadata URIs for a given
411/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
412///
413/// Returns URIs in priority order, which differs depending on whether the
414/// issuer URL has a path component.
415pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
416 let mut urls = Vec::new();
417 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
418 let path = issuer.path().trim_matches('/');
419
420 if !path.is_empty() {
421 // Issuer with path: try path-inserted variants first.
422 if let Ok(url) = Url::parse(&format!(
423 "{}/.well-known/oauth-authorization-server/{}",
424 base, path
425 )) {
426 urls.push(url);
427 }
428 if let Ok(url) = Url::parse(&format!(
429 "{}/.well-known/openid-configuration/{}",
430 base, path
431 )) {
432 urls.push(url);
433 }
434 if let Ok(url) = Url::parse(&format!(
435 "{}/{}/.well-known/openid-configuration",
436 base, path
437 )) {
438 urls.push(url);
439 }
440 } else {
441 // No path: standard well-known locations.
442 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
443 urls.push(url);
444 }
445 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
446 urls.push(url);
447 }
448 }
449
450 urls
451}
452
453// -- Canonical server URI (RFC 8707) -----------------------------------------
454
455/// Derive the canonical resource URI for an MCP server URL, suitable for the
456/// `resource` parameter in authorization and token requests per RFC 8707.
457///
458/// Lowercases the scheme and host, preserves the path (without trailing slash),
459/// strips fragments and query strings.
460pub fn canonical_server_uri(server_url: &Url) -> String {
461 let mut uri = format!(
462 "{}://{}",
463 server_url.scheme().to_ascii_lowercase(),
464 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
465 );
466 if let Some(port) = server_url.port() {
467 uri.push_str(&format!(":{}", port));
468 }
469 let path = server_url.path();
470 if path != "/" {
471 uri.push_str(path.trim_end_matches('/'));
472 }
473 uri
474}
475
476// -- Scope selection ---------------------------------------------------------
477
478/// Select scopes following the MCP spec's Scope Selection Strategy:
479/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
480/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
481/// 3. Return empty if neither is available.
482pub fn select_scopes(
483 www_authenticate: &WwwAuthenticate,
484 resource_metadata: &ProtectedResourceMetadata,
485) -> Vec<String> {
486 if let Some(ref scopes) = www_authenticate.scope {
487 if !scopes.is_empty() {
488 return scopes.clone();
489 }
490 }
491 resource_metadata
492 .scopes_supported
493 .clone()
494 .unwrap_or_default()
495}
496
497// -- Client registration strategy --------------------------------------------
498
/// The registration approach to use, determined from auth server metadata
/// by [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism: the server neither supports CIMD
    /// nor exposes a registration endpoint.
    Unavailable,
}
509
510/// Determine how to register with the authorization server, following the
511/// spec's recommended priority: CIMD first, DCR fallback.
512pub fn determine_registration_strategy(
513 auth_server_metadata: &AuthServerMetadata,
514) -> ClientRegistrationStrategy {
515 if auth_server_metadata.client_id_metadata_document_supported {
516 ClientRegistrationStrategy::Cimd {
517 client_id: CIMD_URL.to_string(),
518 }
519 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
520 ClientRegistrationStrategy::Dcr {
521 registration_endpoint: endpoint.clone(),
522 }
523 } else {
524 ClientRegistrationStrategy::Unavailable
525 }
526}
527
528// -- PKCE (RFC 7636) ---------------------------------------------------------
529
/// A PKCE code verifier paired with its S256 challenge.
///
/// `Debug` is implemented by hand so the verifier is never printed.
#[derive(Clone)]
pub struct PkceChallenge {
    /// The secret code verifier, sent only in the token exchange.
    pub verifier: String,
    /// The challenge derived from the verifier, sent in the authorization URL.
    pub challenge: String,
}

impl std::fmt::Debug for PkceChallenge {
    /// Render the pair for diagnostics: the challenge is shown, the verifier
    /// is replaced with a `[redacted]` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PkceChallenge");
        out.field("verifier", &"[redacted]");
        out.field("challenge", &self.challenge);
        out.finish()
    }
}
545
546/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
547///
548/// The verifier is 43 base64url characters derived from 32 random bytes.
549/// The challenge is `BASE64URL(SHA256(verifier))`.
550pub fn generate_pkce_challenge() -> PkceChallenge {
551 let mut random_bytes = [0u8; 32];
552 rand::rng().fill(&mut random_bytes);
553 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
554 let verifier = engine.encode(&random_bytes);
555
556 let digest = Sha256::digest(verifier.as_bytes());
557 let challenge = engine.encode(digest);
558
559 PkceChallenge {
560 verifier,
561 challenge,
562 }
563}
564
565// -- Authorization URL construction ------------------------------------------
566
567/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
568pub fn build_authorization_url(
569 auth_server_metadata: &AuthServerMetadata,
570 client_id: &str,
571 redirect_uri: &str,
572 scopes: &[String],
573 resource: &str,
574 pkce: &PkceChallenge,
575 state: &str,
576) -> Url {
577 let mut url = auth_server_metadata.authorization_endpoint.clone();
578 {
579 let mut query = url.query_pairs_mut();
580 query.append_pair("response_type", "code");
581 query.append_pair("client_id", client_id);
582 query.append_pair("redirect_uri", redirect_uri);
583 if !scopes.is_empty() {
584 query.append_pair("scope", &scopes.join(" "));
585 }
586 query.append_pair("resource", resource);
587 query.append_pair("code_challenge", &pkce.challenge);
588 query.append_pair("code_challenge_method", "S256");
589 query.append_pair("state", state);
590 }
591 url
592}
593
594// -- Token endpoint request bodies -------------------------------------------
595
/// The JSON body returned by the token endpoint on success.
///
/// Only `access_token` is required; the other fields become `None` when the
/// response omits them. `Debug` is implemented by hand so token values are
/// never printed.
#[derive(Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// The refresh token, if the server issued one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Lifetime of the access token in seconds, if reported.
    #[serde(default)]
    pub expires_in: Option<u64>,
    /// The token type reported by the server, if any.
    #[serde(default)]
    pub token_type: Option<String>,
}
607
608impl std::fmt::Debug for TokenResponse {
609 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
610 f.debug_struct("TokenResponse")
611 .field("access_token", &"[redacted]")
612 .field(
613 "refresh_token",
614 &self.refresh_token.as_ref().map(|_| "[redacted]"),
615 )
616 .field("expires_in", &self.expires_in)
617 .field("token_type", &self.token_type)
618 .finish()
619 }
620}
621
622impl TokenResponse {
623 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
624 pub fn into_tokens(self) -> OAuthTokens {
625 let expires_at = self
626 .expires_in
627 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
628 OAuthTokens {
629 access_token: self.access_token,
630 refresh_token: self.refresh_token,
631 expires_at,
632 }
633 }
634}
635
/// Produce the form parameters for an authorization code token exchange.
///
/// The pairs are returned in a fixed order: `grant_type`
/// (`authorization_code`), `code`, `redirect_uri`, `client_id`,
/// `code_verifier`, `resource`. Encoding them is left to the caller.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let pairs: [(&'static str, &str); 6] = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ];

    pairs
        .iter()
        .map(|&(name, value)| (name, value.to_string()))
        .collect()
}
653
/// Produce the form parameters for a token refresh request.
///
/// The pairs are returned in a fixed order: `grant_type` (`refresh_token`),
/// `refresh_token`, `client_id`, `resource`. Encoding them is left to the
/// caller.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let pairs: [(&'static str, &str); 4] = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ];

    pairs
        .iter()
        .map(|&(name, value)| (name, value.to_string()))
        .collect()
}
667
668// -- DCR request body (RFC 7591) ---------------------------------------------
669
670/// Build the JSON body for a Dynamic Client Registration request.
671///
672/// The `redirect_uri` should be the actual loopback URI with the ephemeral
673/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
674/// redirect URI matching even for loopback addresses, so we register the
675/// exact URI we intend to use.
676/// The grant types Zed can use. Intersected with the server's
677/// `grant_types_supported` to build the DCR request.
678const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"];
679
680pub fn dcr_registration_body(
681 redirect_uri: &str,
682 server_grant_types: Option<&[String]>,
683) -> serde_json::Value {
684 // Use the intersection of what we support and what the server advertises.
685 // When the server doesn't advertise grant_types_supported, send all of
686 // ours — the server will reject what it doesn't like.
687 let grant_types: Vec<&str> = match server_grant_types {
688 Some(server) => SUPPORTED_GRANT_TYPES
689 .iter()
690 .copied()
691 .filter(|gt| server.iter().any(|s| s == *gt))
692 .collect(),
693 None => SUPPORTED_GRANT_TYPES.to_vec(),
694 };
695
696 serde_json::json!({
697 "client_name": "Zed",
698 "redirect_uris": [redirect_uri],
699 "grant_types": grant_types,
700 "response_types": ["code"],
701 "token_endpoint_auth_method": "none"
702 })
703}
704
705// -- Discovery (async, hits real endpoints) ----------------------------------
706
707/// Fetch Protected Resource Metadata from the MCP server.
708///
709/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
710/// then falls back to well-known URIs constructed from `server_url`.
711pub async fn fetch_protected_resource_metadata(
712 http_client: &Arc<dyn HttpClient>,
713 server_url: &Url,
714 www_authenticate: &WwwAuthenticate,
715) -> Result<ProtectedResourceMetadata> {
716 let candidate_urls = match &www_authenticate.resource_metadata {
717 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
718 Some(url) => {
719 log::warn!(
720 "Ignoring cross-origin resource_metadata URL {} \
721 (server origin: {})",
722 url,
723 server_url.origin().unicode_serialization()
724 );
725 protected_resource_metadata_urls(server_url)
726 }
727 None => protected_resource_metadata_urls(server_url),
728 };
729
730 for url in &candidate_urls {
731 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
732 Ok(response) => {
733 if response.authorization_servers.is_empty() {
734 bail!(
735 "Protected Resource Metadata at {} has no authorization_servers",
736 url
737 );
738 }
739 return Ok(ProtectedResourceMetadata {
740 resource: response.resource.unwrap_or_else(|| server_url.clone()),
741 authorization_servers: response.authorization_servers,
742 scopes_supported: response.scopes_supported,
743 });
744 }
745 Err(err) => {
746 log::debug!(
747 "Failed to fetch Protected Resource Metadata from {}: {}",
748 url,
749 err
750 );
751 }
752 }
753 }
754
755 bail!(
756 "Could not fetch Protected Resource Metadata for {}",
757 server_url
758 )
759}
760
761/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
762/// endpoints in the priority order specified by the MCP spec.
763pub async fn fetch_auth_server_metadata(
764 http_client: &Arc<dyn HttpClient>,
765 issuer: &Url,
766) -> Result<AuthServerMetadata> {
767 let candidate_urls = auth_server_metadata_urls(issuer);
768
769 for url in &candidate_urls {
770 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
771 Ok(response) => {
772 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
773 if reported_issuer != *issuer {
774 bail!(
775 "Auth server metadata issuer mismatch: expected {}, got {}",
776 issuer,
777 reported_issuer
778 );
779 }
780
781 return Ok(AuthServerMetadata {
782 issuer: reported_issuer,
783 grant_types_supported: response.grant_types_supported,
784 authorization_endpoint: response
785 .authorization_endpoint
786 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
787 token_endpoint: response
788 .token_endpoint
789 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
790 registration_endpoint: response.registration_endpoint,
791 scopes_supported: response.scopes_supported,
792 code_challenge_methods_supported: response.code_challenge_methods_supported,
793 client_id_metadata_document_supported: response
794 .client_id_metadata_document_supported
795 .unwrap_or(false),
796 });
797 }
798 Err(err) => {
799 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
800 }
801 }
802 }
803
804 bail!(
805 "Could not fetch Authorization Server Metadata for {}",
806 issuer
807 )
808}
809
810/// Run the full discovery flow: fetch resource metadata, then auth server
811/// metadata, then select scopes. Client registration is resolved separately,
812/// once the real redirect URI is known.
813pub async fn discover(
814 http_client: &Arc<dyn HttpClient>,
815 server_url: &Url,
816 www_authenticate: &WwwAuthenticate,
817) -> Result<OAuthDiscovery> {
818 let resource_metadata =
819 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
820
821 let auth_server_url = resource_metadata
822 .authorization_servers
823 .first()
824 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
825
826 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
827
828 // Verify PKCE S256 support (spec requirement).
829 match &auth_server_metadata.code_challenge_methods_supported {
830 Some(methods) if methods.iter().any(|m| m == "S256") => {}
831 Some(_) => bail!("authorization server does not support S256 PKCE"),
832 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
833 }
834
835 // Verify there is at least one supported registration strategy before we
836 // present the server as ready to authenticate.
837 match determine_registration_strategy(&auth_server_metadata) {
838 ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
839 ClientRegistrationStrategy::Unavailable => {
840 bail!("authorization server supports neither CIMD nor DCR")
841 }
842 }
843
844 let scopes = select_scopes(www_authenticate, &resource_metadata);
845
846 Ok(OAuthDiscovery {
847 resource_metadata,
848 auth_server_metadata,
849 scopes,
850 })
851}
852
853/// Resolve the OAuth client registration for an authorization flow.
854///
855/// CIMD uses the static client metadata document directly. For DCR, a fresh
856/// registration is performed each time because the loopback redirect URI
857/// includes an ephemeral port that changes every flow.
858pub async fn resolve_client_registration(
859 http_client: &Arc<dyn HttpClient>,
860 discovery: &OAuthDiscovery,
861 redirect_uri: &str,
862) -> Result<OAuthClientRegistration> {
863 match determine_registration_strategy(&discovery.auth_server_metadata) {
864 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
865 client_id,
866 client_secret: None,
867 }),
868 ClientRegistrationStrategy::Dcr {
869 registration_endpoint,
870 } => {
871 perform_dcr(
872 http_client,
873 ®istration_endpoint,
874 redirect_uri,
875 discovery
876 .auth_server_metadata
877 .grant_types_supported
878 .as_deref(),
879 )
880 .await
881 }
882 ClientRegistrationStrategy::Unavailable => {
883 bail!("authorization server supports neither CIMD nor DCR")
884 }
885 }
886}
887
888// -- Dynamic Client Registration (RFC 7591) ----------------------------------
889
890/// Perform Dynamic Client Registration with the authorization server.
891pub async fn perform_dcr(
892 http_client: &Arc<dyn HttpClient>,
893 registration_endpoint: &Url,
894 redirect_uri: &str,
895 server_grant_types: Option<&[String]>,
896) -> Result<OAuthClientRegistration> {
897 validate_oauth_url(registration_endpoint)?;
898
899 let body = dcr_registration_body(redirect_uri, server_grant_types);
900 let body_bytes = serde_json::to_vec(&body)?;
901
902 let request = Request::builder()
903 .method(http_client::http::Method::POST)
904 .uri(registration_endpoint.as_str())
905 .header("Content-Type", "application/json")
906 .header("Accept", "application/json")
907 .body(AsyncBody::from(body_bytes))?;
908
909 let mut response = http_client.send(request).await?;
910
911 if !response.status().is_success() {
912 let mut error_body = String::new();
913 response.body_mut().read_to_string(&mut error_body).await?;
914 bail!(
915 "DCR failed with status {}: {}",
916 response.status(),
917 error_body
918 );
919 }
920
921 let mut response_body = String::new();
922 response
923 .body_mut()
924 .read_to_string(&mut response_body)
925 .await?;
926
927 let dcr_response: DcrResponse =
928 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
929
930 Ok(OAuthClientRegistration {
931 client_id: dcr_response.client_id,
932 client_secret: dcr_response.client_secret,
933 })
934}
935
936// -- Token exchange and refresh (async) --------------------------------------
937
938/// Exchange an authorization code for tokens at the token endpoint.
939pub async fn exchange_code(
940 http_client: &Arc<dyn HttpClient>,
941 auth_server_metadata: &AuthServerMetadata,
942 code: &str,
943 client_id: &str,
944 redirect_uri: &str,
945 code_verifier: &str,
946 resource: &str,
947) -> Result<OAuthTokens> {
948 let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
949 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
950}
951
952/// Refresh tokens using a refresh token.
953pub async fn refresh_tokens(
954 http_client: &Arc<dyn HttpClient>,
955 token_endpoint: &Url,
956 refresh_token: &str,
957 client_id: &str,
958 resource: &str,
959) -> Result<OAuthTokens> {
960 let params = token_refresh_params(refresh_token, client_id, resource);
961 post_token_request(http_client, token_endpoint, ¶ms).await
962}
963
964/// POST form-encoded parameters to a token endpoint and parse the response.
965async fn post_token_request(
966 http_client: &Arc<dyn HttpClient>,
967 token_endpoint: &Url,
968 params: &[(&str, String)],
969) -> Result<OAuthTokens> {
970 validate_oauth_url(token_endpoint)?;
971
972 let body = url::form_urlencoded::Serializer::new(String::new())
973 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
974 .finish();
975
976 let request = Request::builder()
977 .method(http_client::http::Method::POST)
978 .uri(token_endpoint.as_str())
979 .header("Content-Type", "application/x-www-form-urlencoded")
980 .header("Accept", "application/json")
981 .body(AsyncBody::from(body.into_bytes()))?;
982
983 let mut response = http_client.send(request).await?;
984
985 if !response.status().is_success() {
986 let mut error_body = String::new();
987 response.body_mut().read_to_string(&mut error_body).await?;
988 bail!(
989 "token request failed with status {}: {}",
990 response.status(),
991 error_body
992 );
993 }
994
995 let mut response_body = String::new();
996 response
997 .body_mut()
998 .read_to_string(&mut response_body)
999 .await?;
1000
1001 let token_response: TokenResponse =
1002 serde_json::from_str(&response_body).context("failed to parse token response")?;
1003
1004 Ok(token_response.into_tokens())
1005}
1006
1007// -- Loopback HTTP callback server -------------------------------------------
1008
/// The `code` and `state` values delivered to the loopback HTTP server by an
/// OAuth authorization callback.
pub struct OAuthCallback {
    /// Value of the `code` query parameter.
    pub code: String,
    /// Value of the `state` query parameter.
    pub state: String,
}

/// Hand-written so that neither field's contents ever appear in debug output.
impl std::fmt::Debug for OAuthCallback {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";

        let mut out = f.debug_struct("OAuthCallback");
        for field_name in ["code", "state"] {
            out.field(field_name, &REDACTED);
        }
        out.finish()
    }
}
1023
1024impl OAuthCallback {
1025 /// Parse the query string from a callback URL like
1026 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1027 pub fn parse_query(query: &str) -> Result<Self> {
1028 let mut code: Option<String> = None;
1029 let mut state: Option<String> = None;
1030 let mut error: Option<String> = None;
1031 let mut error_description: Option<String> = None;
1032
1033 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1034 match key.as_ref() {
1035 "code" => {
1036 if !value.is_empty() {
1037 code = Some(value.into_owned());
1038 }
1039 }
1040 "state" => {
1041 if !value.is_empty() {
1042 state = Some(value.into_owned());
1043 }
1044 }
1045 "error" => {
1046 if !value.is_empty() {
1047 error = Some(value.into_owned());
1048 }
1049 }
1050 "error_description" => {
1051 if !value.is_empty() {
1052 error_description = Some(value.into_owned());
1053 }
1054 }
1055 _ => {}
1056 }
1057 }
1058
1059 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1060 // checking for missing code/state.
1061 if let Some(error_code) = error {
1062 bail!(
1063 "OAuth authorization failed: {} ({})",
1064 error_code,
1065 error_description.as_deref().unwrap_or("no description")
1066 );
1067 }
1068
1069 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1070 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1071
1072 Ok(Self { code, state })
1073 }
1074}
1075
/// Upper bound (two minutes) on how long the loopback callback server waits
/// for the browser to finish the OAuth flow before it gives up and the
/// loopback port is released.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1079
1080/// Start a loopback HTTP server to receive the OAuth authorization callback.
1081///
1082/// Binds to an ephemeral loopback port for each flow.
1083///
1084/// Returns `(redirect_uri, callback_future)`. The caller should use the
1085/// redirect URI in the authorization request, open the browser, then await
1086/// the future to receive the callback.
1087///
1088/// The server accepts exactly one request on `/callback`, validates that it
1089/// contains `code` and `state` query parameters, responds with a minimal
1090/// HTML page telling the user they can close the tab, and shuts down.
1091///
1092/// The callback server shuts down when the returned oneshot receiver is dropped
1093/// (e.g. because the authentication task was cancelled), or after a timeout
1094/// ([CALLBACK_TIMEOUT]).
1095pub async fn start_callback_server() -> Result<(
1096 String,
1097 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1098)> {
1099 let server = tiny_http::Server::http("127.0.0.1:0")
1100 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1101 let port = server
1102 .server_addr()
1103 .to_ip()
1104 .context("server not bound to a TCP address")?
1105 .port();
1106
1107 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1108
1109 let (tx, rx) = futures::channel::oneshot::channel();
1110
1111 // `tiny_http` is blocking, so we run it on a background thread.
1112 // The `recv_timeout` loop lets us check for cancellation (the receiver
1113 // being dropped) and enforce an overall timeout.
1114 std::thread::spawn(move || {
1115 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1116
1117 loop {
1118 if tx.is_canceled() {
1119 return;
1120 }
1121 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1122 if remaining.is_zero() {
1123 return;
1124 }
1125
1126 let timeout = remaining.min(Duration::from_millis(500));
1127 let Some(request) = (match server.recv_timeout(timeout) {
1128 Ok(req) => req,
1129 Err(_) => {
1130 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1131 return;
1132 }
1133 }) else {
1134 // Timeout with no request — loop back and check cancellation.
1135 continue;
1136 };
1137
1138 let result = handle_callback_request(&request);
1139
1140 let (status_code, body) = match &result {
1141 Ok(_) => (
1142 200,
1143 "<html><body><h1>Authorization successful</h1>\
1144 <p>You can close this tab and return to Zed.</p></body></html>",
1145 ),
1146 Err(err) => {
1147 log::error!("OAuth callback error: {}", err);
1148 (
1149 400,
1150 "<html><body><h1>Authorization failed</h1>\
1151 <p>Something went wrong. Please try again from Zed.</p></body></html>",
1152 )
1153 }
1154 };
1155
1156 let response = tiny_http::Response::from_string(body)
1157 .with_status_code(status_code)
1158 .with_header(
1159 tiny_http::Header::from_str("Content-Type: text/html")
1160 .expect("failed to construct response header"),
1161 )
1162 .with_header(
1163 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1164 .expect("failed to construct response header"),
1165 );
1166 request.respond(response).log_err();
1167
1168 let _ = tx.send(result);
1169 return;
1170 }
1171 });
1172
1173 Ok((redirect_uri, rx))
1174}
1175
1176/// Extract the `code` and `state` query parameters from an OAuth callback
1177/// request to `/callback`.
1178fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1179 let url = Url::parse(&format!("http://localhost{}", request.url()))
1180 .context("malformed callback request URL")?;
1181
1182 if url.path() != "/callback" {
1183 bail!("unexpected path in OAuth callback: {}", url.path());
1184 }
1185
1186 let query = url
1187 .query()
1188 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1189 OAuthCallback::parse_query(query)
1190}
1191
1192// -- JSON fetch helper -------------------------------------------------------
1193
1194async fn fetch_json<T: serde::de::DeserializeOwned>(
1195 http_client: &Arc<dyn HttpClient>,
1196 url: &Url,
1197) -> Result<T> {
1198 validate_oauth_url(url)?;
1199
1200 let request = Request::builder()
1201 .method(http_client::http::Method::GET)
1202 .uri(url.as_str())
1203 .header("Accept", "application/json")
1204 .body(AsyncBody::default())?;
1205
1206 let mut response = http_client.send(request).await?;
1207
1208 if !response.status().is_success() {
1209 bail!("HTTP {} fetching {}", response.status(), url);
1210 }
1211
1212 let mut body = String::new();
1213 response.body_mut().read_to_string(&mut body).await?;
1214 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1215}
1216
1217// -- Serde response types for discovery --------------------------------------
1218
/// Raw JSON shape of a Protected Resource Metadata document as received
/// during discovery.
///
/// Every field is tolerant of being absent: a missing key deserializes to
/// `None` (or an empty `Vec` for `authorization_servers`) instead of failing
/// the parse.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// The `resource` URL from the document, if present.
    #[serde(default)]
    resource: Option<Url>,
    /// The `authorization_servers` URLs; empty when the key is missing.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// The `scopes_supported` list, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1228
/// Raw JSON shape of an Authorization Server Metadata document as received
/// during discovery.
///
/// All fields are optional at this layer: a missing key deserializes to
/// `None` rather than failing the parse.
// NOTE(review): presumably the required fields are enforced when this is
// converted into `AuthServerMetadata` — confirm against the conversion code.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// The `issuer` URL, if present.
    #[serde(default)]
    issuer: Option<Url>,
    /// The `authorization_endpoint` URL, if present.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// The `token_endpoint` URL, if present.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// The `registration_endpoint` URL, if present.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// The `scopes_supported` list, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// The `grant_types_supported` list, if present.
    #[serde(default)]
    grant_types_supported: Option<Vec<String>>,
    /// The `code_challenge_methods_supported` list, if present.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// The `client_id_metadata_document_supported` flag, if present.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1248
/// The subset of a Dynamic Client Registration (DCR) response body that is
/// read after a successful registration request.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// The `client_id` from the response. Required: parsing fails if the
    /// key is missing.
    client_id: String,
    /// The `client_secret` from the response; `None` when the key is missing.
    #[serde(default)]
    client_secret: Option<String>,
}
1255
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementations must be `Send + Sync`, so a provider can be shared across
/// threads.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// `None` means there is no token the caller can attach to a request.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token was obtained, so the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1269
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// The current session, including its tokens. Guarded by a synchronous
    /// mutex; the tokens are replaced in place after a successful refresh.
    session: SyncMutex<OAuthSession>,
    /// HTTP client used to send refresh requests to the token endpoint.
    http_client: Arc<dyn HttpClient>,
    /// When set, receives a clone of the updated session after every
    /// successful token refresh.
    // NOTE(review): presumably the receiver persists the session — confirm
    // against the code that creates this channel.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1279
1280impl McpOAuthTokenProvider {
1281 pub fn new(
1282 session: OAuthSession,
1283 http_client: Arc<dyn HttpClient>,
1284 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1285 ) -> Self {
1286 Self {
1287 session: SyncMutex::new(session),
1288 http_client,
1289 token_refresh_tx,
1290 }
1291 }
1292
1293 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1294 tokens.expires_at.is_some_and(|expires_at| {
1295 SystemTime::now()
1296 .checked_add(Duration::from_secs(30))
1297 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1298 })
1299 }
1300}
1301
1302#[async_trait]
1303impl OAuthTokenProvider for McpOAuthTokenProvider {
1304 fn access_token(&self) -> Option<String> {
1305 let session = self.session.lock();
1306 if Self::access_token_is_expired(&session.tokens) {
1307 return None;
1308 }
1309 Some(session.tokens.access_token.clone())
1310 }
1311
1312 async fn try_refresh(&self) -> Result<bool> {
1313 let (refresh_token, token_endpoint, resource, client_id) = {
1314 let session = self.session.lock();
1315 match session.tokens.refresh_token.clone() {
1316 Some(refresh_token) => (
1317 refresh_token,
1318 session.token_endpoint.clone(),
1319 session.resource.clone(),
1320 session.client_registration.client_id.clone(),
1321 ),
1322 None => return Ok(false),
1323 }
1324 };
1325
1326 let resource_str = canonical_server_uri(&resource);
1327
1328 match refresh_tokens(
1329 &self.http_client,
1330 &token_endpoint,
1331 &refresh_token,
1332 &client_id,
1333 &resource_str,
1334 )
1335 .await
1336 {
1337 Ok(mut new_tokens) => {
1338 if new_tokens.refresh_token.is_none() {
1339 new_tokens.refresh_token = Some(refresh_token);
1340 }
1341
1342 {
1343 let mut session = self.session.lock();
1344 session.tokens = new_tokens;
1345
1346 if let Some(ref tx) = self.token_refresh_tx {
1347 tx.unbounded_send(session.clone()).ok();
1348 }
1349 }
1350
1351 Ok(true)
1352 }
1353 Err(err) => {
1354 log::warn!("OAuth token refresh failed: {}", err);
1355 Ok(false)
1356 }
1357 }
1358 }
1359}
1360
1361#[cfg(test)]
1362mod tests {
1363 use super::*;
1364 use http_client::Response;
1365
1366 // -- require_https_or_loopback tests ------------------------------------
1367
1368 #[test]
1369 fn test_require_https_or_loopback_accepts_https() {
1370 let url = Url::parse("https://auth.example.com/token").unwrap();
1371 assert!(require_https_or_loopback(&url).is_ok());
1372 }
1373
1374 #[test]
1375 fn test_require_https_or_loopback_rejects_http_remote() {
1376 let url = Url::parse("http://auth.example.com/token").unwrap();
1377 assert!(require_https_or_loopback(&url).is_err());
1378 }
1379
1380 #[test]
1381 fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1382 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1383 assert!(require_https_or_loopback(&url).is_ok());
1384 }
1385
1386 #[test]
1387 fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1388 let url = Url::parse("http://[::1]:8080/callback").unwrap();
1389 assert!(require_https_or_loopback(&url).is_ok());
1390 }
1391
1392 #[test]
1393 fn test_require_https_or_loopback_accepts_http_localhost() {
1394 let url = Url::parse("http://localhost:8080/callback").unwrap();
1395 assert!(require_https_or_loopback(&url).is_ok());
1396 }
1397
1398 #[test]
1399 fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1400 let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1401 assert!(require_https_or_loopback(&url).is_ok());
1402 }
1403
1404 #[test]
1405 fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1406 let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1407 assert!(require_https_or_loopback(&url).is_err());
1408 }
1409
1410 #[test]
1411 fn test_require_https_or_loopback_rejects_ftp() {
1412 let url = Url::parse("ftp://auth.example.com/token").unwrap();
1413 assert!(require_https_or_loopback(&url).is_err());
1414 }
1415
1416 // -- validate_oauth_url (SSRF) tests ------------------------------------
1417
1418 #[test]
1419 fn test_validate_oauth_url_accepts_https_public() {
1420 let url = Url::parse("https://auth.example.com/token").unwrap();
1421 assert!(validate_oauth_url(&url).is_ok());
1422 }
1423
1424 #[test]
1425 fn test_validate_oauth_url_rejects_private_ipv4_10() {
1426 let url = Url::parse("https://10.0.0.1/token").unwrap();
1427 assert!(validate_oauth_url(&url).is_err());
1428 }
1429
1430 #[test]
1431 fn test_validate_oauth_url_rejects_private_ipv4_172() {
1432 let url = Url::parse("https://172.16.0.1/token").unwrap();
1433 assert!(validate_oauth_url(&url).is_err());
1434 }
1435
1436 #[test]
1437 fn test_validate_oauth_url_rejects_private_ipv4_192() {
1438 let url = Url::parse("https://192.168.1.1/token").unwrap();
1439 assert!(validate_oauth_url(&url).is_err());
1440 }
1441
1442 #[test]
1443 fn test_validate_oauth_url_rejects_link_local() {
1444 let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1445 assert!(validate_oauth_url(&url).is_err());
1446 }
1447
1448 #[test]
1449 fn test_validate_oauth_url_rejects_ipv6_ula() {
1450 let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1451 assert!(validate_oauth_url(&url).is_err());
1452 }
1453
1454 #[test]
1455 fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1456 let url = Url::parse("https://[::]/token").unwrap();
1457 assert!(validate_oauth_url(&url).is_err());
1458 }
1459
1460 #[test]
1461 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1462 let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1463 assert!(validate_oauth_url(&url).is_err());
1464 }
1465
1466 #[test]
1467 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1468 let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1469 assert!(validate_oauth_url(&url).is_err());
1470 }
1471
1472 #[test]
1473 fn test_validate_oauth_url_allows_http_loopback() {
1474 // Loopback is permitted (it's our callback server).
1475 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1476 assert!(validate_oauth_url(&url).is_ok());
1477 }
1478
1479 #[test]
1480 fn test_validate_oauth_url_allows_https_public_ip() {
1481 let url = Url::parse("https://93.184.216.34/token").unwrap();
1482 assert!(validate_oauth_url(&url).is_ok());
1483 }
1484
1485 // -- parse_www_authenticate tests ----------------------------------------
1486
1487 #[test]
1488 fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1489 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1490 let result = parse_www_authenticate(header).unwrap();
1491
1492 assert_eq!(
1493 result.resource_metadata.as_ref().map(|u| u.as_str()),
1494 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1495 );
1496 assert_eq!(
1497 result.scope,
1498 Some(vec!["files:read".to_string(), "user:profile".to_string()])
1499 );
1500 assert_eq!(result.error, None);
1501 }
1502
1503 #[test]
1504 fn test_parse_www_authenticate_resource_metadata_only() {
1505 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1506 let result = parse_www_authenticate(header).unwrap();
1507
1508 assert_eq!(
1509 result.resource_metadata.as_ref().map(|u| u.as_str()),
1510 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1511 );
1512 assert_eq!(result.scope, None);
1513 }
1514
1515 #[test]
1516 fn test_parse_www_authenticate_bare_bearer() {
1517 let result = parse_www_authenticate("Bearer").unwrap();
1518 assert_eq!(result.resource_metadata, None);
1519 assert_eq!(result.scope, None);
1520 }
1521
1522 #[test]
1523 fn test_parse_www_authenticate_with_error() {
1524 let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1525 let result = parse_www_authenticate(header).unwrap();
1526
1527 assert_eq!(result.error, Some(BearerError::InsufficientScope));
1528 assert_eq!(
1529 result.error_description.as_deref(),
1530 Some("Additional file write permission required")
1531 );
1532 assert_eq!(
1533 result.scope,
1534 Some(vec!["files:read".to_string(), "files:write".to_string()])
1535 );
1536 assert!(result.resource_metadata.is_some());
1537 }
1538
1539 #[test]
1540 fn test_parse_www_authenticate_invalid_token_error() {
1541 let header =
1542 r#"Bearer error="invalid_token", error_description="The access token expired""#;
1543 let result = parse_www_authenticate(header).unwrap();
1544 assert_eq!(result.error, Some(BearerError::InvalidToken));
1545 }
1546
1547 #[test]
1548 fn test_parse_www_authenticate_invalid_request_error() {
1549 let header = r#"Bearer error="invalid_request""#;
1550 let result = parse_www_authenticate(header).unwrap();
1551 assert_eq!(result.error, Some(BearerError::InvalidRequest));
1552 }
1553
1554 #[test]
1555 fn test_parse_www_authenticate_unknown_error() {
1556 let header = r#"Bearer error="some_future_error""#;
1557 let result = parse_www_authenticate(header).unwrap();
1558 assert_eq!(result.error, Some(BearerError::Other));
1559 }
1560
1561 #[test]
1562 fn test_parse_www_authenticate_rejects_non_bearer() {
1563 let result = parse_www_authenticate("Basic realm=\"example\"");
1564 assert!(result.is_err());
1565 }
1566
1567 #[test]
1568 fn test_parse_www_authenticate_case_insensitive_scheme() {
1569 let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1570 let result = parse_www_authenticate(header).unwrap();
1571 assert!(result.resource_metadata.is_some());
1572 }
1573
1574 #[test]
1575 fn test_parse_www_authenticate_multiline_style() {
1576 // Some servers emit the header spread across multiple lines joined by
1577 // whitespace, as shown in the spec examples.
1578 let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
1579 let result = parse_www_authenticate(header).unwrap();
1580 assert!(result.resource_metadata.is_some());
1581 assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1582 }
1583
1584 #[test]
1585 fn test_protected_resource_metadata_urls_with_path() {
1586 let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1587 let urls = protected_resource_metadata_urls(&server_url);
1588
1589 assert_eq!(urls.len(), 2);
1590 assert_eq!(
1591 urls[0].as_str(),
1592 "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1593 );
1594 assert_eq!(
1595 urls[1].as_str(),
1596 "https://api.example.com/.well-known/oauth-protected-resource"
1597 );
1598 }
1599
1600 #[test]
1601 fn test_protected_resource_metadata_urls_without_path() {
1602 let server_url = Url::parse("https://mcp.example.com").unwrap();
1603 let urls = protected_resource_metadata_urls(&server_url);
1604
1605 assert_eq!(urls.len(), 1);
1606 assert_eq!(
1607 urls[0].as_str(),
1608 "https://mcp.example.com/.well-known/oauth-protected-resource"
1609 );
1610 }
1611
1612 #[test]
1613 fn test_auth_server_metadata_urls_with_path() {
1614 let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1615 let urls = auth_server_metadata_urls(&issuer);
1616
1617 assert_eq!(urls.len(), 3);
1618 assert_eq!(
1619 urls[0].as_str(),
1620 "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1621 );
1622 assert_eq!(
1623 urls[1].as_str(),
1624 "https://auth.example.com/.well-known/openid-configuration/tenant1"
1625 );
1626 assert_eq!(
1627 urls[2].as_str(),
1628 "https://auth.example.com/tenant1/.well-known/openid-configuration"
1629 );
1630 }
1631
1632 #[test]
1633 fn test_auth_server_metadata_urls_without_path() {
1634 let issuer = Url::parse("https://auth.example.com").unwrap();
1635 let urls = auth_server_metadata_urls(&issuer);
1636
1637 assert_eq!(urls.len(), 2);
1638 assert_eq!(
1639 urls[0].as_str(),
1640 "https://auth.example.com/.well-known/oauth-authorization-server"
1641 );
1642 assert_eq!(
1643 urls[1].as_str(),
1644 "https://auth.example.com/.well-known/openid-configuration"
1645 );
1646 }
1647
1648 // -- Canonical server URI tests ------------------------------------------
1649
1650 #[test]
1651 fn test_canonical_server_uri_simple() {
1652 let url = Url::parse("https://mcp.example.com").unwrap();
1653 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1654 }
1655
1656 #[test]
1657 fn test_canonical_server_uri_with_path() {
1658 let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1659 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1660 }
1661
1662 #[test]
1663 fn test_canonical_server_uri_strips_trailing_slash() {
1664 let url = Url::parse("https://mcp.example.com/").unwrap();
1665 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1666 }
1667
1668 #[test]
1669 fn test_canonical_server_uri_preserves_port() {
1670 let url = Url::parse("https://mcp.example.com:8443").unwrap();
1671 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1672 }
1673
1674 #[test]
1675 fn test_canonical_server_uri_lowercases() {
1676 let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1677 assert_eq!(
1678 canonical_server_uri(&url),
1679 "https://mcp.example.com/Server/MCP"
1680 );
1681 }
1682
1683 // -- Scope selection tests -----------------------------------------------
1684
1685 #[test]
1686 fn test_select_scopes_prefers_www_authenticate() {
1687 let www_auth = WwwAuthenticate {
1688 resource_metadata: None,
1689 scope: Some(vec!["files:read".into()]),
1690 error: None,
1691 error_description: None,
1692 };
1693 let resource_meta = ProtectedResourceMetadata {
1694 resource: Url::parse("https://example.com").unwrap(),
1695 authorization_servers: vec![],
1696 scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1697 };
1698 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1699 }
1700
1701 #[test]
1702 fn test_select_scopes_falls_back_to_resource_metadata() {
1703 let www_auth = WwwAuthenticate {
1704 resource_metadata: None,
1705 scope: None,
1706 error: None,
1707 error_description: None,
1708 };
1709 let resource_meta = ProtectedResourceMetadata {
1710 resource: Url::parse("https://example.com").unwrap(),
1711 authorization_servers: vec![],
1712 scopes_supported: Some(vec!["admin".into()]),
1713 };
1714 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1715 }
1716
1717 #[test]
1718 fn test_select_scopes_empty_when_nothing_available() {
1719 let www_auth = WwwAuthenticate {
1720 resource_metadata: None,
1721 scope: None,
1722 error: None,
1723 error_description: None,
1724 };
1725 let resource_meta = ProtectedResourceMetadata {
1726 resource: Url::parse("https://example.com").unwrap(),
1727 authorization_servers: vec![],
1728 scopes_supported: None,
1729 };
1730 assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1731 }
1732
1733 // -- Client registration strategy tests ----------------------------------
1734
1735 #[test]
1736 fn test_registration_strategy_prefers_cimd() {
1737 let metadata = AuthServerMetadata {
1738 issuer: Url::parse("https://auth.example.com").unwrap(),
1739 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1740 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1741 registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1742 scopes_supported: None,
1743 code_challenge_methods_supported: Some(vec!["S256".into()]),
1744 client_id_metadata_document_supported: true,
1745 grant_types_supported: None,
1746 };
1747 assert_eq!(
1748 determine_registration_strategy(&metadata),
1749 ClientRegistrationStrategy::Cimd {
1750 client_id: CIMD_URL.to_string(),
1751 }
1752 );
1753 }
1754
1755 #[test]
1756 fn test_registration_strategy_falls_back_to_dcr() {
1757 let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1758 let metadata = AuthServerMetadata {
1759 issuer: Url::parse("https://auth.example.com").unwrap(),
1760 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1761 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1762 registration_endpoint: Some(reg_endpoint.clone()),
1763 scopes_supported: None,
1764 code_challenge_methods_supported: Some(vec!["S256".into()]),
1765 client_id_metadata_document_supported: false,
1766 grant_types_supported: None,
1767 };
1768 assert_eq!(
1769 determine_registration_strategy(&metadata),
1770 ClientRegistrationStrategy::Dcr {
1771 registration_endpoint: reg_endpoint,
1772 }
1773 );
1774 }
1775
1776 #[test]
1777 fn test_registration_strategy_unavailable() {
1778 let metadata = AuthServerMetadata {
1779 issuer: Url::parse("https://auth.example.com").unwrap(),
1780 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1781 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1782 registration_endpoint: None,
1783 scopes_supported: None,
1784 code_challenge_methods_supported: Some(vec!["S256".into()]),
1785 client_id_metadata_document_supported: false,
1786 grant_types_supported: None,
1787 };
1788 assert_eq!(
1789 determine_registration_strategy(&metadata),
1790 ClientRegistrationStrategy::Unavailable,
1791 );
1792 }
1793
1794 // -- PKCE tests ----------------------------------------------------------
1795
1796 #[test]
1797 fn test_pkce_challenge_verifier_length() {
1798 let pkce = generate_pkce_challenge();
1799 // 32 random bytes → 43 base64url chars (no padding).
1800 assert_eq!(pkce.verifier.len(), 43);
1801 }
1802
1803 #[test]
1804 fn test_pkce_challenge_is_valid_base64url() {
1805 let pkce = generate_pkce_challenge();
1806 for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1807 assert!(
1808 c.is_ascii_alphanumeric() || c == '-' || c == '_',
1809 "invalid base64url character: {}",
1810 c
1811 );
1812 }
1813 }
1814
1815 #[test]
1816 fn test_pkce_challenge_is_s256_of_verifier() {
1817 let pkce = generate_pkce_challenge();
1818 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1819 let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1820 let expected_challenge = engine.encode(expected_digest);
1821 assert_eq!(pkce.challenge, expected_challenge);
1822 }
1823
1824 #[test]
1825 fn test_pkce_challenges_are_unique() {
1826 let a = generate_pkce_challenge();
1827 let b = generate_pkce_challenge();
1828 assert_ne!(a.verifier, b.verifier);
1829 }
1830
1831 // -- Authorization URL tests ---------------------------------------------
1832
1833 #[test]
1834 fn test_build_authorization_url() {
1835 let metadata = AuthServerMetadata {
1836 issuer: Url::parse("https://auth.example.com").unwrap(),
1837 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1838 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1839 registration_endpoint: None,
1840 scopes_supported: None,
1841 code_challenge_methods_supported: Some(vec!["S256".into()]),
1842 client_id_metadata_document_supported: true,
1843 grant_types_supported: None,
1844 };
1845 let pkce = PkceChallenge {
1846 verifier: "test_verifier".into(),
1847 challenge: "test_challenge".into(),
1848 };
1849 let url = build_authorization_url(
1850 &metadata,
1851 "https://zed.dev/oauth/client-metadata.json",
1852 "http://127.0.0.1:12345/callback",
1853 &["files:read".into(), "files:write".into()],
1854 "https://mcp.example.com",
1855 &pkce,
1856 "random_state_123",
1857 );
1858
1859 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1860 assert_eq!(pairs.get("response_type").unwrap(), "code");
1861 assert_eq!(
1862 pairs.get("client_id").unwrap(),
1863 "https://zed.dev/oauth/client-metadata.json"
1864 );
1865 assert_eq!(
1866 pairs.get("redirect_uri").unwrap(),
1867 "http://127.0.0.1:12345/callback"
1868 );
1869 assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1870 assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1871 assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1872 assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1873 assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1874 }
1875
1876 #[test]
1877 fn test_build_authorization_url_omits_empty_scope() {
1878 let metadata = AuthServerMetadata {
1879 issuer: Url::parse("https://auth.example.com").unwrap(),
1880 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1881 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1882 registration_endpoint: None,
1883 scopes_supported: None,
1884 code_challenge_methods_supported: Some(vec!["S256".into()]),
1885 client_id_metadata_document_supported: false,
1886 grant_types_supported: None,
1887 };
1888 let pkce = PkceChallenge {
1889 verifier: "v".into(),
1890 challenge: "c".into(),
1891 };
1892 let url = build_authorization_url(
1893 &metadata,
1894 "client_123",
1895 "http://127.0.0.1:9999/callback",
1896 &[],
1897 "https://mcp.example.com",
1898 &pkce,
1899 "state",
1900 );
1901
1902 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1903 assert!(!pairs.contains_key("scope"));
1904 }
1905
1906 // -- Token exchange / refresh param tests --------------------------------
1907
/// The authorization-code exchange form must carry the grant type, code,
/// redirect URI, client id, PKCE verifier and resource indicator.
#[test]
fn test_token_exchange_params() {
    let params = token_exchange_params(
        "auth_code_abc",
        "client_xyz",
        "http://127.0.0.1:5555/callback",
        "verifier_123",
        "https://mcp.example.com",
    );

    // Index the form parameters by name for direct lookup.
    let mut by_name = std::collections::HashMap::<&str, &str>::new();
    for (name, value) in params.iter() {
        by_name.insert(*name, value.as_str());
    }

    let expected = [
        ("grant_type", "authorization_code"),
        ("code", "auth_code_abc"),
        ("redirect_uri", "http://127.0.0.1:5555/callback"),
        ("client_id", "client_xyz"),
        ("code_verifier", "verifier_123"),
        ("resource", "https://mcp.example.com"),
    ];
    for (name, value) in expected {
        assert_eq!(by_name[name], value);
    }
}
1927
/// The refresh form must carry the grant type, refresh token, client id and
/// resource indicator.
#[test]
fn test_token_refresh_params() {
    let params =
        token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");

    // Index the form parameters by name for direct lookup.
    let mut by_name = std::collections::HashMap::<&str, &str>::new();
    for (name, value) in params.iter() {
        by_name.insert(*name, value.as_str());
    }

    let expected = [
        ("grant_type", "refresh_token"),
        ("refresh_token", "refresh_token_abc"),
        ("client_id", "client_xyz"),
        ("resource", "https://mcp.example.com"),
    ];
    for (name, value) in expected {
        assert_eq!(by_name[name], value);
    }
}
1940
1941 // -- Token response tests ------------------------------------------------
1942
/// A token response carrying `refresh_token` and `expires_in` converts into
/// tokens that keep the refresh token and have an expiry set.
#[test]
fn test_token_response_into_tokens_with_expiry() {
    let payload = r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#;
    let parsed: TokenResponse = serde_json::from_str(payload).unwrap();

    let tokens = parsed.into_tokens();
    assert_eq!(tokens.access_token, "at_123");
    assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
    assert!(tokens.expires_at.is_some());
}
1955
/// A response containing only `access_token` converts into tokens with no
/// refresh token and no expiry.
#[test]
fn test_token_response_into_tokens_minimal() {
    let payload = r#"{"access_token": "at_789"}"#;
    let tokens = serde_json::from_str::<TokenResponse>(payload)
        .unwrap()
        .into_tokens();

    assert_eq!(tokens.access_token, "at_789");
    assert!(tokens.refresh_token.is_none());
    assert!(tokens.expires_at.is_none());
}
1966
1967 // -- DCR body test -------------------------------------------------------
1968
/// Without server-advertised grant types, the DCR body lists both
/// `authorization_code` and `refresh_token`.
#[test]
fn test_dcr_registration_body_without_server_metadata() {
    let redirect_uri = "http://127.0.0.1:12345/callback";

    // When server metadata is unavailable, include all supported grant types.
    let registration = dcr_registration_body(redirect_uri, None);

    assert_eq!(registration["client_name"], "Zed");
    assert_eq!(registration["token_endpoint_auth_method"], "none");
    assert_eq!(registration["redirect_uris"][0], redirect_uri);
    assert_eq!(registration["response_types"][0], "code");
    assert_eq!(registration["grant_types"][0], "authorization_code");
    assert_eq!(registration["grant_types"][1], "refresh_token");
}
1980
/// The DCR body must advertise only the grant types the server supports.
///
/// Array lengths are checked with `assert_eq!` in both scenarios so that a
/// failure prints the actual length and unexpected extra grant types are
/// caught, not just the expected leading entries.
#[test]
fn test_dcr_registration_body_mirrors_server_grant_types() {
    // When the server only supports authorization_code, omit refresh_token.
    let server_types = vec!["authorization_code".to_string()];
    let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
    let grant_types = body["grant_types"].as_array().unwrap();
    assert_eq!(grant_types.len(), 1);
    assert_eq!(grant_types[0], "authorization_code");

    // When the server supports both, include both (and nothing else).
    let server_types = vec![
        "authorization_code".to_string(),
        "refresh_token".to_string(),
    ];
    let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
    let grant_types = body["grant_types"].as_array().unwrap();
    assert_eq!(grant_types.len(), 2);
    assert_eq!(grant_types[0], "authorization_code");
    assert_eq!(grant_types[1], "refresh_token");
}
1998
1999 // -- Test helpers for async/HTTP tests -----------------------------------
2000
2001 fn make_fake_http_client(
2002 handler: impl Fn(
2003 http_client::Request<AsyncBody>,
2004 ) -> std::pin::Pin<
2005 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
2006 > + Send
2007 + Sync
2008 + 'static,
2009 ) -> Arc<dyn HttpClient> {
2010 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
2011 }
2012
2013 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
2014 Ok(Response::builder()
2015 .status(status)
2016 .header("Content-Type", "application/json")
2017 .body(AsyncBody::from(body.as_bytes().to_vec()))
2018 .unwrap())
2019 }
2020
2021 // -- Discovery integration tests -----------------------------------------
2022
/// With no `resource_metadata` hint in the challenge, the metadata served
/// from the well-known protected-resource path is fetched and parsed.
#[test]
fn test_fetch_protected_resource_metadata() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"],
        "scopes_supported": ["read", "write"]
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-protected-resource") {
                        (200, RESOURCE_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &challenge)
            .await
            .unwrap();

        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");

        let servers: Vec<&str> = metadata
            .authorization_servers
            .iter()
            .map(|server| server.as_str())
            .collect();
        assert_eq!(servers, ["https://auth.example.com/"]);

        let expected_scopes = vec![String::from("read"), String::from("write")];
        assert_eq!(metadata.scopes_supported, Some(expected_scopes));
    });
}
2068
/// When the challenge carries a same-origin `resource_metadata` URL, that
/// exact URL is fetched; the fake server answers 500 for anything else.
#[test]
fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
    const CUSTOM_METADATA_URL: &str = "https://mcp.example.com/custom-resource-metadata";
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let (status, body) = if request.uri().to_string() == CUSTOM_METADATA_URL {
                    (200, RESOURCE_METADATA)
                } else {
                    (500, r#"{"error": "should not be called"}"#)
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: Some(Url::parse(CUSTOM_METADATA_URL).unwrap()),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &challenge)
            .await
            .unwrap();

        assert_eq!(metadata.authorization_servers.len(), 1);
    });
}
2106
/// A `resource_metadata` URL on a different origin than the server must be
/// ignored in favour of the well-known path on the server's own origin.
#[test]
fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                // The cross-origin URL should NOT be fetched; only the
                // well-known fallback at the server's own origin should be.
                if requested.contains("attacker.example.com") {
                    panic!("should not fetch cross-origin resource_metadata URL");
                }
                let (status, body) =
                    if requested.contains(".well-known/oauth-protected-resource") {
                        (200, RESOURCE_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let attacker_url = Url::parse("https://attacker.example.com/fake-metadata").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: Some(attacker_url),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &challenge)
            .await
            .unwrap();

        // Should have used the fallback well-known URL, not the attacker's.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
    });
}
2149
/// Authorization server metadata served from the RFC 8414 well-known path
/// is fetched and every advertised field is parsed.
#[test]
fn test_fetch_auth_server_metadata() {
    const AUTH_SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "registration_endpoint": "https://auth.example.com/register",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": true
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-authorization-server") {
                        (200, AUTH_SERVER_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
        assert_eq!(
            metadata.authorization_endpoint.as_str(),
            "https://auth.example.com/authorize"
        );
        assert_eq!(
            metadata.token_endpoint.as_str(),
            "https://auth.example.com/token"
        );
        assert!(metadata.registration_endpoint.is_some());
        assert!(metadata.client_id_metadata_document_supported);

        let expected_methods = vec![String::from("S256")];
        assert_eq!(
            metadata.code_challenge_methods_supported,
            Some(expected_methods)
        );
    });
}
2194
/// When only the OpenID Connect discovery document is available (every
/// other path returns 404), metadata is still obtained from it.
#[test]
fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
    const OIDC_CONFIGURATION: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "code_challenge_methods_supported": ["S256"]
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested.contains("openid-configuration") {
                    (200, OIDC_CONFIGURATION)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        // The document omits the CIMD flag, so it must read as unsupported.
        assert!(!metadata.client_id_metadata_document_supported);
        assert_eq!(
            metadata.authorization_endpoint.as_str(),
            "https://auth.example.com/authorize"
        );
    });
}
2227
/// Metadata whose `issuer` differs from the issuer that was queried must be
/// rejected with an "issuer mismatch" error.
#[test]
fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
    // Response claims to be a different issuer.
    const FOREIGN_ISSUER_METADATA: &str = r#"{
        "issuer": "https://evil.example.com",
        "authorization_endpoint": "https://evil.example.com/authorize",
        "token_endpoint": "https://evil.example.com/token",
        "code_challenge_methods_supported": ["S256"]
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-authorization-server") {
                        (200, FOREIGN_ISSUER_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let outcome = fetch_auth_server_metadata(&client, &issuer).await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("issuer mismatch"),
            "unexpected error: {}",
            message
        );
    });
}
2263
2264 // -- Full discover integration tests -------------------------------------
2265
/// End-to-end discovery against a server that advertises CIMD support:
/// registration resolves to the CIMD URL with no client secret, and the
/// scopes come from the protected resource metadata.
#[test]
fn test_full_discover_with_cimd() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"],
        "scopes_supported": ["mcp:read"]
    }"#;
    const AUTH_SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": true
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested.contains("oauth-protected-resource") {
                    (200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    (200, AUTH_SERVER_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let discovery = discover(&client, &server_url, &challenge).await.unwrap();
        let registration =
            resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
                .await
                .unwrap();

        assert_eq!(registration.client_id, CIMD_URL);
        assert_eq!(registration.client_secret, None);
        assert_eq!(discovery.scopes, vec!["mcp:read"]);
    });
}
2317
/// End-to-end discovery against a server without CIMD support but with a
/// registration endpoint: the client id and secret come from the DCR
/// response, and the scopes come from the `WWW-Authenticate` challenge.
#[test]
fn test_full_discover_with_dcr_fallback() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;
    const AUTH_SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "registration_endpoint": "https://auth.example.com/register",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": false
    }"#;
    const REGISTRATION_RESPONSE: &str = r#"{
        "client_id": "dcr-minted-id-123",
        "client_secret": "dcr-secret-456"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested.contains("oauth-protected-resource") {
                    (200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    (200, AUTH_SERVER_METADATA)
                } else if requested.contains("/register") {
                    (201, REGISTRATION_RESPONSE)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: Some(vec!["files:read".into()]),
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let discovery = discover(&client, &server_url, &challenge).await.unwrap();
        let registration =
            resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
                .await
                .unwrap();

        assert_eq!(registration.client_id, "dcr-minted-id-123");
        assert_eq!(
            registration.client_secret.as_deref(),
            Some("dcr-secret-456")
        );
        assert_eq!(discovery.scopes, vec!["files:read"]);
    });
}
2380
/// Discovery must fail when the authorization server metadata omits
/// `code_challenge_methods_supported`, and the error must name that field.
#[test]
fn test_discover_fails_without_pkce_support() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;
    // Deliberately lacks `code_challenge_methods_supported`.
    const AUTH_SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let (status, body) = if requested.contains("oauth-protected-resource") {
                    (200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    (200, AUTH_SERVER_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let outcome = discover(&client, &server_url, &challenge).await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("code_challenge_methods_supported"),
            "unexpected error: {}",
            message
        );
    });
}
2428
2429 // -- Token exchange integration tests ------------------------------------
2430
/// A 200 response from the token endpoint yields the access token, refresh
/// token and an expiry.
#[test]
fn test_exchange_code_success() {
    const TOKEN_RESPONSE: &str = r#"{
        "access_token": "new_access_token",
        "refresh_token": "new_refresh_token",
        "expires_in": 3600,
        "token_type": "Bearer"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let (status, body) = if request.uri().to_string().contains("/token") {
                    (200, TOKEN_RESPONSE)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let server = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
            grant_types_supported: None,
        };

        let exchange = exchange_code(
            &client,
            &server,
            "auth_code_123",
            CIMD_URL,
            "http://127.0.0.1:9999/callback",
            "verifier_abc",
            "https://mcp.example.com",
        );
        let tokens = exchange.await.unwrap();

        assert_eq!(tokens.access_token, "new_access_token");
        assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
        assert!(tokens.expires_at.is_some());
    });
}
2481
/// A successful refresh whose response omits `refresh_token` yields the new
/// access token, no refresh token, and an expiry.
#[test]
fn test_refresh_tokens_success() {
    const TOKEN_RESPONSE: &str = r#"{
        "access_token": "refreshed_token",
        "expires_in": 1800,
        "token_type": "Bearer"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|request| {
            Box::pin(async move {
                let (status, body) = if request.uri().to_string().contains("/token") {
                    (200, TOKEN_RESPONSE)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let refresh = refresh_tokens(
            &client,
            &token_endpoint,
            "old_refresh_token",
            CIMD_URL,
            "https://mcp.example.com",
        );
        let tokens = refresh.await.unwrap();

        assert_eq!(tokens.access_token, "refreshed_token");
        assert!(tokens.refresh_token.is_none());
        assert!(tokens.expires_at.is_some());
    });
}
2520
/// A 400 response from the token endpoint must surface as an error whose
/// message mentions the status code.
///
/// The final assertion prints the actual error text on failure, matching the
/// `unexpected error: {}` convention used by the other error-path tests in
/// this module.
#[test]
fn test_exchange_code_failure() {
    smol::block_on(async {
        // Every request is answered with an OAuth error payload.
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
        });

        let metadata = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
            grant_types_supported: None,
        };

        let result = exchange_code(
            &client,
            &metadata,
            "bad_code",
            "client",
            "http://127.0.0.1:1/callback",
            "verifier",
            "https://mcp.example.com",
        )
        .await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("400"), "unexpected error: {}", err_msg);
    });
}
2554
2555 // -- DCR integration tests -----------------------------------------------
2556
/// A 201 response from the registration endpoint yields the client id and
/// secret it contains.
#[test]
fn test_perform_dcr() {
    const REGISTRATION_RESPONSE: &str = r#"{
        "client_id": "dynamic-client-001",
        "client_secret": "dynamic-secret-001"
    }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|_request| {
            Box::pin(async move { json_response(201, REGISTRATION_RESPONSE) })
        });

        let registration_endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let registration = perform_dcr(
            &client,
            &registration_endpoint,
            "http://127.0.0.1:9999/callback",
            None,
        )
        .await
        .unwrap();

        assert_eq!(registration.client_id, "dynamic-client-001");
        assert_eq!(
            registration.client_secret.as_deref(),
            Some("dynamic-secret-001")
        );
    });
}
2585
/// A 403 response from the registration endpoint must surface as an error
/// whose message mentions the status code.
///
/// The final assertion prints the actual error text on failure, matching the
/// `unexpected error: {}` convention used by the other error-path tests in
/// this module.
#[test]
fn test_perform_dcr_failure() {
    smol::block_on(async {
        // Every request is rejected by the fake registration endpoint.
        let client = make_fake_http_client(|_req| {
            Box::pin(
                async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
            )
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let result =
            perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await;

        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("403"), "unexpected error: {}", err_msg);
    });
}
2603
2604 // -- OAuthCallback parse tests -------------------------------------------
2605
/// A query string with `code` followed by `state` parses into both fields.
#[test]
fn test_oauth_callback_parse_query() {
    let query = "code=test_auth_code&state=test_state";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2612
/// Parameter order must not matter: `state` before `code` parses the same.
#[test]
fn test_oauth_callback_parse_query_reversed_order() {
    let query = "state=test_state&code=test_auth_code";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2619
/// Query parameters other than `code` and `state` do not prevent parsing.
#[test]
fn test_oauth_callback_parse_query_with_extra_params() {
    let query = "code=test_auth_code&state=test_state&extra=ignored";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2628
/// A callback without a `code` parameter is rejected, and the error text
/// mentions `code`.
///
/// The final assertion prints the actual error text on failure, matching the
/// `unexpected error: {}` convention used by the other error-path tests in
/// this module.
#[test]
fn test_oauth_callback_parse_query_missing_code() {
    let result = OAuthCallback::parse_query("state=test_state");
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    assert!(err_msg.contains("code"), "unexpected error: {}", err_msg);
}
2635
/// A callback without a `state` parameter is rejected, and the error text
/// mentions `state`.
///
/// The final assertion prints the actual error text on failure, matching the
/// `unexpected error: {}` convention used by the other error-path tests in
/// this module.
#[test]
fn test_oauth_callback_parse_query_missing_state() {
    let result = OAuthCallback::parse_query("code=test_auth_code");
    assert!(result.is_err());
    let err_msg = result.unwrap_err().to_string();
    assert!(err_msg.contains("state"), "unexpected error: {}", err_msg);
}
2642
/// A `code` parameter that is present but empty is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_code() {
    let query = "code=&state=test_state";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2648
/// A `state` parameter that is present but empty is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_state() {
    let query = "code=test_auth_code&state=";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2654
/// Percent-encoded values are decoded: `%20` becomes a space and `%3D`
/// becomes `=`.
#[test]
fn test_oauth_callback_parse_query_url_encoded_values() {
    let query = "code=abc%20def&state=test%3Dstate";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test=state");
    assert_eq!(parsed.code, "abc def");
}
2661
/// An OAuth error callback is reported as an error that includes both the
/// error code and the decoded description.
#[test]
fn test_oauth_callback_parse_query_error_response() {
    let query = "error=access_denied&error_description=User%20denied%20access&state=abc";
    let outcome = OAuthCallback::parse_query(query);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("access_denied"),
        "unexpected error: {}",
        message
    );
    assert!(
        message.contains("User denied access"),
        "unexpected error: {}",
        message
    );
}
2680
/// An OAuth error callback lacking `error_description` still reports the
/// error code, together with a "no description" placeholder.
#[test]
fn test_oauth_callback_parse_query_error_without_description() {
    let query = "error=server_error&state=abc";
    let outcome = OAuthCallback::parse_query(query);

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("server_error"),
        "unexpected error: {}",
        message
    );
    assert!(
        message.contains("no description"),
        "unexpected error: {}",
        message
    );
}
2697
2698 // -- McpOAuthTokenProvider tests -----------------------------------------
2699
2700 fn make_test_session(
2701 access_token: &str,
2702 refresh_token: Option<&str>,
2703 expires_at: Option<SystemTime>,
2704 ) -> OAuthSession {
2705 OAuthSession {
2706 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2707 resource: Url::parse("https://mcp.example.com").unwrap(),
2708 client_registration: OAuthClientRegistration {
2709 client_id: "test-client".into(),
2710 client_secret: None,
2711 },
2712 tokens: OAuthTokens {
2713 access_token: access_token.into(),
2714 refresh_token: refresh_token.map(String::from),
2715 expires_at,
2716 },
2717 }
2718 }
2719
/// An access token whose expiry lies in the past is not handed out.
#[test]
fn test_mcp_oauth_provider_returns_none_when_token_expired() {
    let one_minute = Duration::from_secs(60);
    let already_expired = SystemTime::now() - one_minute;

    // The HTTP client must never be touched by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("stale-token", Some("rt"), Some(already_expired)),
        http_client,
        None,
    );

    assert!(provider.access_token().is_none());
}
2732
/// An access token whose expiry is an hour away is returned as-is.
#[test]
fn test_mcp_oauth_provider_returns_token_when_not_expired() {
    let one_hour = Duration::from_secs(3600);
    let still_valid_until = SystemTime::now() + one_hour;

    // The HTTP client must never be touched by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("valid-token", Some("rt"), Some(still_valid_until)),
        http_client,
        None,
    );

    assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
}
2745
/// A session with no recorded expiry still yields its access token.
#[test]
fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
    // The HTTP client must never be touched by a plain token lookup.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("no-expiry-token", Some("rt"), None),
        http_client,
        None,
    );

    assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
}
2757
/// Without a stored refresh token, `try_refresh` reports `false` and makes
/// no HTTP request (the fake client panics if it is called).
#[test]
fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
    smol::block_on(async {
        let http_client = make_fake_http_client(|_| {
            Box::pin(async { unreachable!("no HTTP call expected") })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("token", None, None),
            http_client,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);
    });
}
2774
/// A successful refresh swaps in the new access token and publishes the
/// updated session, including the rotated refresh token, on the channel.
#[test]
fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
    const REFRESH_RESPONSE: &str = r#"{
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 1800
    }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http_client = make_fake_http_client(|_request| {
            Box::pin(async { json_response(200, REFRESH_RESPONSE) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh-token"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);
        assert_eq!(provider.access_token().as_deref(), Some("new-access"));

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        assert_eq!(published.tokens.access_token, "new-access");
        assert_eq!(
            published.tokens.refresh_token.as_deref(),
            Some("new-refresh")
        );
    });
}
2808
/// When the refresh response carries no `refresh_token`, the published
/// session keeps the refresh token it already had.
#[test]
fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
    const REFRESH_RESPONSE: &str = r#"{
        "access_token": "new-access",
        "expires_in": 900
    }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http_client = make_fake_http_client(|_request| {
            Box::pin(async { json_response(200, REFRESH_RESPONSE) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("original-refresh"), None),
            http_client,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        assert_eq!(published.tokens.access_token, "new-access");
        assert_eq!(
            published.tokens.refresh_token.as_deref(),
            Some("original-refresh"),
        );
    });
}
2840
/// A 401 from the token endpoint makes `try_refresh` report `false` and
/// leaves the existing access token untouched.
#[test]
fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
    smol::block_on(async {
        let http_client = make_fake_http_client(|_request| {
            Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh"), None),
            http_client,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(!did_refresh);

        // The old token should still be in place.
        assert_eq!(provider.access_token().as_deref(), Some("old-access"));
    });
}
2858}