1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
/// Location of Zed's OAuth Client ID Metadata Document (CIMD).
///
/// When an authorization server advertises CIMD support, this URL is sent
/// verbatim as the `client_id` instead of registering a client.
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
/// Protected Resource Metadata for an MCP server, per RFC 9728 (OAuth 2.0
/// Protected Resource Metadata).
///
/// Located via the `resource_metadata` parameter of the server's
/// `WWW-Authenticate` header or via the well-known endpoint, and built by
/// [`fetch_protected_resource_metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource identifier. Falls back to the MCP server URL when
    /// the fetched document omits it.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource. Never
    /// empty once fetched; [`discover`] uses the first entry.
    pub authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource. [`select_scopes`] falls back to this
    /// list when the `WWW-Authenticate` challenge names no scopes.
    pub scopes_supported: Option<Vec<String>>,
}
139
/// Authorization Server Metadata, parsed from the authorization server's
/// `.well-known` endpoint per RFC 8414 (OAuth 2.0 Authorization Server
/// Metadata) and built by [`fetch_auth_server_metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// Issuer identifier. When fetched, it must equal the issuer URL the
    /// metadata was requested for.
    pub issuer: Url,
    /// Base URL for the authorization request (see [`build_authorization_url`]).
    pub authorization_endpoint: Url,
    /// Endpoint used for the authorization-code exchange and token refresh.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, when the server supports DCR.
    pub registration_endpoint: Option<Url>,
    /// Scopes advertised by the authorization server, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE methods the server accepts. [`discover`] requires this list to be
    /// present and to contain `"S256"`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether CIMD client IDs are accepted. When true, CIMD is preferred over
    /// DCR. Treated as `false` when the fetched metadata omits it.
    pub client_id_metadata_document_supported: bool,
}
152
153/// The result of client registration — either CIMD or DCR.
154#[derive(Clone, Serialize, Deserialize)]
155pub struct OAuthClientRegistration {
156 pub client_id: String,
157 /// Only present for DCR-minted registrations.
158 pub client_secret: Option<String>,
159}
160
161impl std::fmt::Debug for OAuthClientRegistration {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("OAuthClientRegistration")
164 .field("client_id", &self.client_id)
165 .field(
166 "client_secret",
167 &self.client_secret.as_ref().map(|_| "[redacted]"),
168 )
169 .finish()
170 }
171}
172
173/// Access and refresh tokens obtained from the token endpoint.
174#[derive(Clone, Serialize, Deserialize)]
175pub struct OAuthTokens {
176 pub access_token: String,
177 pub refresh_token: Option<String>,
178 pub expires_at: Option<SystemTime>,
179}
180
181impl std::fmt::Debug for OAuthTokens {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("OAuthTokens")
184 .field("access_token", &"[redacted]")
185 .field(
186 "refresh_token",
187 &self.refresh_token.as_ref().map(|_| "[redacted]"),
188 )
189 .field("expires_at", &self.expires_at)
190 .finish()
191 }
192}
193
/// Everything discovered before the browser flow starts, as produced by
/// [`discover`]. Client registration is resolved separately, once the real
/// redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched from the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in
    /// `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes to request, chosen by [`select_scopes`].
    pub scopes: Vec<String>,
}
202
203/// The persisted OAuth session for a context server.
204///
205/// Stored in the keychain so startup can restore a refresh-capable provider
206/// without another browser flow. Deliberately excludes the full discovery
207/// metadata to keep the serialized size well within keychain item limits.
208#[derive(Clone, Serialize, Deserialize)]
209pub struct OAuthSession {
210 pub token_endpoint: Url,
211 pub resource: Url,
212 pub client_registration: OAuthClientRegistration,
213 pub tokens: OAuthTokens,
214}
215
216impl std::fmt::Debug for OAuthSession {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("OAuthSession")
219 .field("token_endpoint", &self.token_endpoint)
220 .field("resource", &self.resource)
221 .field("client_registration", &self.client_registration)
222 .field("tokens", &self.tokens)
223 .finish()
224 }
225}
226
/// Bearer-token error codes from RFC 6750 Section 3.1.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}

impl BearerError {
    /// Map an `error` parameter value to a variant. Matching is exact and
    /// case-sensitive; anything unrecognized becomes [`BearerError::Other`].
    fn parse(value: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];

        KNOWN_CODES
            .iter()
            .find(|(code, _)| *code == value)
            .map_or(BearerError::Other, |&(_, variant)| variant)
    }
}
251
/// Fields extracted from a `WWW-Authenticate: Bearer` header by
/// [`parse_www_authenticate`].
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if advertised.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace. When non-empty,
    /// [`select_scopes`] prefers it over the resource's `scopes_supported`.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The human-readable `error_description` parameter, if present.
    pub error_description: Option<String>,
}
265
266/// Parse a `WWW-Authenticate` header value.
267///
268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
270/// - `resource_metadata` — URL of the Protected Resource Metadata document
271/// - `scope` — space-separated list of required scopes
272/// - `error` — error code (e.g. "insufficient_scope")
273/// - `error_description` — human-readable error description
274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
275 let header = header.trim();
276
277 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
278 header[6..].trim()
279 } else {
280 bail!("WWW-Authenticate header does not use Bearer scheme");
281 };
282
283 if params_str.is_empty() {
284 return Ok(WwwAuthenticate {
285 resource_metadata: None,
286 scope: None,
287 error: None,
288 error_description: None,
289 });
290 }
291
292 let params = parse_auth_params(params_str);
293
294 let resource_metadata = params
295 .get("resource_metadata")
296 .map(|v| Url::parse(v))
297 .transpose()
298 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
299
300 let scope = params
301 .get("scope")
302 .map(|v| v.split_whitespace().map(String::from).collect());
303
304 let error = params.get("error").map(|v| BearerError::parse(v));
305 let error_description = params.get("error_description").cloned();
306
307 Ok(WwwAuthenticate {
308 resource_metadata,
309 scope,
310 error,
311 error_description,
312 })
313}
314
315/// Parse comma-separated `key="value"` or `key=token` parameters from an
316/// auth-param list (RFC 7235 Section 2.1).
317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
318 let mut params = collections::HashMap::default();
319 let mut remaining = input.trim();
320
321 while !remaining.is_empty() {
322 // Skip leading whitespace and commas.
323 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
324 if remaining.is_empty() {
325 break;
326 }
327
328 // Find the key (everything before '=').
329 let eq_pos = match remaining.find('=') {
330 Some(pos) => pos,
331 None => break,
332 };
333
334 let key = remaining[..eq_pos].trim().to_lowercase();
335 remaining = &remaining[eq_pos + 1..];
336 remaining = remaining.trim_start();
337
338 // Parse the value: either quoted or unquoted (token).
339 let value;
340 if remaining.starts_with('"') {
341 // Quoted string: find the closing quote, handling escaped chars.
342 remaining = &remaining[1..]; // skip opening quote
343 let mut val = String::new();
344 let mut chars = remaining.char_indices();
345 loop {
346 match chars.next() {
347 Some((_, '\\')) => {
348 // Escaped character — take the next char literally.
349 if let Some((_, c)) = chars.next() {
350 val.push(c);
351 }
352 }
353 Some((i, '"')) => {
354 remaining = &remaining[i + 1..];
355 break;
356 }
357 Some((_, c)) => val.push(c),
358 None => {
359 remaining = "";
360 break;
361 }
362 }
363 }
364 value = val;
365 } else {
366 // Unquoted token: read until comma or whitespace.
367 let end = remaining
368 .find(|c: char| c == ',' || c.is_whitespace())
369 .unwrap_or(remaining.len());
370 value = remaining[..end].to_string();
371 remaining = &remaining[end..];
372 }
373
374 if !key.is_empty() {
375 params.insert(key, value);
376 }
377 }
378
379 params
380}
381
382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
383/// server URL, per RFC 9728 Section 3.
384///
385/// Returns URIs in priority order:
386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
389 let mut urls = Vec::new();
390 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
391
392 let path = server_url.path().trim_start_matches('/');
393 if !path.is_empty() {
394 if let Ok(url) = Url::parse(&format!(
395 "{}/.well-known/oauth-protected-resource/{}",
396 base, path
397 )) {
398 urls.push(url);
399 }
400 }
401
402 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
403 urls.push(url);
404 }
405
406 urls
407}
408
409/// Construct the well-known Authorization Server Metadata URIs for a given
410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
411///
412/// Returns URIs in priority order, which differs depending on whether the
413/// issuer URL has a path component.
414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
415 let mut urls = Vec::new();
416 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
417 let path = issuer.path().trim_matches('/');
418
419 if !path.is_empty() {
420 // Issuer with path: try path-inserted variants first.
421 if let Ok(url) = Url::parse(&format!(
422 "{}/.well-known/oauth-authorization-server/{}",
423 base, path
424 )) {
425 urls.push(url);
426 }
427 if let Ok(url) = Url::parse(&format!(
428 "{}/.well-known/openid-configuration/{}",
429 base, path
430 )) {
431 urls.push(url);
432 }
433 if let Ok(url) = Url::parse(&format!(
434 "{}/{}/.well-known/openid-configuration",
435 base, path
436 )) {
437 urls.push(url);
438 }
439 } else {
440 // No path: standard well-known locations.
441 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
442 urls.push(url);
443 }
444 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
445 urls.push(url);
446 }
447 }
448
449 urls
450}
451
452// -- Canonical server URI (RFC 8707) -----------------------------------------
453
454/// Derive the canonical resource URI for an MCP server URL, suitable for the
455/// `resource` parameter in authorization and token requests per RFC 8707.
456///
457/// Lowercases the scheme and host, preserves the path (without trailing slash),
458/// strips fragments and query strings.
459pub fn canonical_server_uri(server_url: &Url) -> String {
460 let mut uri = format!(
461 "{}://{}",
462 server_url.scheme().to_ascii_lowercase(),
463 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
464 );
465 if let Some(port) = server_url.port() {
466 uri.push_str(&format!(":{}", port));
467 }
468 let path = server_url.path();
469 if path != "/" {
470 uri.push_str(path.trim_end_matches('/'));
471 }
472 uri
473}
474
475// -- Scope selection ---------------------------------------------------------
476
477/// Select scopes following the MCP spec's Scope Selection Strategy:
478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
480/// 3. Return empty if neither is available.
481pub fn select_scopes(
482 www_authenticate: &WwwAuthenticate,
483 resource_metadata: &ProtectedResourceMetadata,
484) -> Vec<String> {
485 if let Some(ref scopes) = www_authenticate.scope {
486 if !scopes.is_empty() {
487 return scopes.clone();
488 }
489 }
490 resource_metadata
491 .scopes_supported
492 .clone()
493 .unwrap_or_default()
494}
495
496// -- Client registration strategy --------------------------------------------
497
/// The client registration approach, chosen from auth server metadata by
/// [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it
    /// (see [`perform_dcr`]).
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
508
509/// Determine how to register with the authorization server, following the
510/// spec's recommended priority: CIMD first, DCR fallback.
511pub fn determine_registration_strategy(
512 auth_server_metadata: &AuthServerMetadata,
513) -> ClientRegistrationStrategy {
514 if auth_server_metadata.client_id_metadata_document_supported {
515 ClientRegistrationStrategy::Cimd {
516 client_id: CIMD_URL.to_string(),
517 }
518 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
519 ClientRegistrationStrategy::Dcr {
520 registration_endpoint: endpoint.clone(),
521 }
522 } else {
523 ClientRegistrationStrategy::Unavailable
524 }
525}
526
527// -- PKCE (RFC 7636) ---------------------------------------------------------
528
/// PKCE (RFC 7636) material for a single authorization flow: the secret
/// `verifier` and its S256-derived `challenge`.
#[derive(Clone)]
pub struct PkceChallenge {
    pub verifier: String,
    pub challenge: String,
}

impl std::fmt::Debug for PkceChallenge {
    /// Shows the public challenge but hides the verifier, which is later sent
    /// to the token endpoint as `code_verifier`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = f.debug_struct("PkceChallenge");
        debug.field("verifier", &"[redacted]");
        debug.field("challenge", &self.challenge);
        debug.finish()
    }
}
544
545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
546///
547/// The verifier is 43 base64url characters derived from 32 random bytes.
548/// The challenge is `BASE64URL(SHA256(verifier))`.
549pub fn generate_pkce_challenge() -> PkceChallenge {
550 let mut random_bytes = [0u8; 32];
551 rand::rng().fill(&mut random_bytes);
552 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
553 let verifier = engine.encode(&random_bytes);
554
555 let digest = Sha256::digest(verifier.as_bytes());
556 let challenge = engine.encode(digest);
557
558 PkceChallenge {
559 verifier,
560 challenge,
561 }
562}
563
564// -- Authorization URL construction ------------------------------------------
565
566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
567pub fn build_authorization_url(
568 auth_server_metadata: &AuthServerMetadata,
569 client_id: &str,
570 redirect_uri: &str,
571 scopes: &[String],
572 resource: &str,
573 pkce: &PkceChallenge,
574 state: &str,
575) -> Url {
576 let mut url = auth_server_metadata.authorization_endpoint.clone();
577 {
578 let mut query = url.query_pairs_mut();
579 query.append_pair("response_type", "code");
580 query.append_pair("client_id", client_id);
581 query.append_pair("redirect_uri", redirect_uri);
582 if !scopes.is_empty() {
583 query.append_pair("scope", &scopes.join(" "));
584 }
585 query.append_pair("resource", resource);
586 query.append_pair("code_challenge", &pkce.challenge);
587 query.append_pair("code_challenge_method", "S256");
588 query.append_pair("state", state);
589 }
590 url
591}
592
593// -- Token endpoint request bodies -------------------------------------------
594
595/// The JSON body returned by the token endpoint on success.
596#[derive(Deserialize)]
597pub struct TokenResponse {
598 pub access_token: String,
599 #[serde(default)]
600 pub refresh_token: Option<String>,
601 #[serde(default)]
602 pub expires_in: Option<u64>,
603 #[serde(default)]
604 pub token_type: Option<String>,
605}
606
607impl std::fmt::Debug for TokenResponse {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("TokenResponse")
610 .field("access_token", &"[redacted]")
611 .field(
612 "refresh_token",
613 &self.refresh_token.as_ref().map(|_| "[redacted]"),
614 )
615 .field("expires_in", &self.expires_in)
616 .field("token_type", &self.token_type)
617 .finish()
618 }
619}
620
621impl TokenResponse {
622 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
623 pub fn into_tokens(self) -> OAuthTokens {
624 let expires_at = self
625 .expires_in
626 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
627 OAuthTokens {
628 access_token: self.access_token,
629 refresh_token: self.refresh_token,
630 expires_at,
631 }
632 }
633}
634
635/// An OAuth token error response (RFC 6749 Section 5.2).
636#[derive(Debug, Deserialize, PartialEq)]
637pub struct OAuthTokenError {
638 pub error: String,
639 #[serde(default)]
640 pub error_description: Option<String>,
641}
642
643impl std::fmt::Display for OAuthTokenError {
644 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
645 write!(f, "OAuth token error: {}", self.error)?;
646 if let Some(description) = &self.error_description {
647 write!(f, " ({description})")?;
648 }
649 Ok(())
650 }
651}
652
653impl std::error::Error for OAuthTokenError {}
654
/// Build the form parameters for exchanging an authorization code for tokens.
///
/// The parameters are returned in request order: `grant_type`, `code`,
/// `redirect_uri`, `client_id`, `code_verifier`, `resource`, and finally
/// `client_secret` when one is given.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
    client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
    let mut params: Vec<(&'static str, String)> = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ]
    .iter()
    .map(|&(name, value)| (name, value.to_owned()))
    .collect();

    params.extend(client_secret.map(|secret| ("client_secret", secret.to_owned())));
    params
}
677
/// Build the form parameters for a refresh-token grant.
///
/// The parameters are returned in request order: `grant_type`,
/// `refresh_token`, `client_id`, `resource`, and finally `client_secret` when
/// one is given.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
    client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
    let required = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ];
    let optional = client_secret.map(|secret| ("client_secret", secret));

    required
        .iter()
        .copied()
        .chain(optional)
        .map(|(name, value)| (name, value.to_string()))
        .collect()
}
696
697// -- DCR request body (RFC 7591) ---------------------------------------------
698
699/// Build the JSON body for a Dynamic Client Registration request.
700///
701/// The `redirect_uri` should be the actual loopback URI with the ephemeral
702/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
703/// redirect URI matching even for loopback addresses, so we register the
704/// exact URI we intend to use.
705pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
706 serde_json::json!({
707 "client_name": "Zed",
708 "redirect_uris": [redirect_uri],
709 "grant_types": ["authorization_code"],
710 "response_types": ["code"],
711 "token_endpoint_auth_method": "none"
712 })
713}
714
715// -- Discovery (async, hits real endpoints) ----------------------------------
716
717/// Fetch Protected Resource Metadata from the MCP server.
718///
719/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
720/// then falls back to well-known URIs constructed from `server_url`.
721pub async fn fetch_protected_resource_metadata(
722 http_client: &Arc<dyn HttpClient>,
723 server_url: &Url,
724 www_authenticate: &WwwAuthenticate,
725) -> Result<ProtectedResourceMetadata> {
726 let candidate_urls = match &www_authenticate.resource_metadata {
727 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
728 Some(url) => {
729 log::warn!(
730 "Ignoring cross-origin resource_metadata URL {} \
731 (server origin: {})",
732 url,
733 server_url.origin().unicode_serialization()
734 );
735 protected_resource_metadata_urls(server_url)
736 }
737 None => protected_resource_metadata_urls(server_url),
738 };
739
740 for url in &candidate_urls {
741 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
742 Ok(response) => {
743 if response.authorization_servers.is_empty() {
744 bail!(
745 "Protected Resource Metadata at {} has no authorization_servers",
746 url
747 );
748 }
749 return Ok(ProtectedResourceMetadata {
750 resource: response.resource.unwrap_or_else(|| server_url.clone()),
751 authorization_servers: response.authorization_servers,
752 scopes_supported: response.scopes_supported,
753 });
754 }
755 Err(err) => {
756 log::debug!(
757 "Failed to fetch Protected Resource Metadata from {}: {}",
758 url,
759 err
760 );
761 }
762 }
763 }
764
765 bail!(
766 "Could not fetch Protected Resource Metadata for {}",
767 server_url
768 )
769}
770
771/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
772/// endpoints in the priority order specified by the MCP spec.
773pub async fn fetch_auth_server_metadata(
774 http_client: &Arc<dyn HttpClient>,
775 issuer: &Url,
776) -> Result<AuthServerMetadata> {
777 let candidate_urls = auth_server_metadata_urls(issuer);
778
779 for url in &candidate_urls {
780 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
781 Ok(response) => {
782 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
783
784 if reported_issuer != *issuer {
785 bail!(
786 "Auth server metadata issuer mismatch: expected {}, got {}",
787 issuer,
788 reported_issuer
789 );
790 }
791
792 return Ok(AuthServerMetadata {
793 issuer: reported_issuer,
794 authorization_endpoint: response
795 .authorization_endpoint
796 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
797 token_endpoint: response
798 .token_endpoint
799 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
800 registration_endpoint: response.registration_endpoint,
801 scopes_supported: response.scopes_supported,
802 code_challenge_methods_supported: response.code_challenge_methods_supported,
803 client_id_metadata_document_supported: response
804 .client_id_metadata_document_supported
805 .unwrap_or(false),
806 });
807 }
808 Err(err) => {
809 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
810 }
811 }
812 }
813
814 bail!(
815 "Could not fetch Authorization Server Metadata for {}",
816 issuer
817 )
818}
819
820/// Run the full discovery flow: fetch resource metadata, then auth server
821/// metadata, then select scopes. Client registration is resolved separately,
822/// once the real redirect URI is known.
823pub async fn discover(
824 http_client: &Arc<dyn HttpClient>,
825 server_url: &Url,
826 www_authenticate: &WwwAuthenticate,
827) -> Result<OAuthDiscovery> {
828 let resource_metadata =
829 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
830
831 let auth_server_url = resource_metadata
832 .authorization_servers
833 .first()
834 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
835
836 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
837
838 // Verify PKCE S256 support (spec requirement).
839 match &auth_server_metadata.code_challenge_methods_supported {
840 Some(methods) if methods.iter().any(|m| m == "S256") => {}
841 Some(_) => bail!("authorization server does not support S256 PKCE"),
842 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
843 }
844
845 let scopes = select_scopes(www_authenticate, &resource_metadata);
846
847 Ok(OAuthDiscovery {
848 resource_metadata,
849 auth_server_metadata,
850 scopes,
851 })
852}
853
854/// Resolve the OAuth client registration for an authorization flow.
855///
856/// CIMD uses the static client metadata document directly. For DCR, a fresh
857/// registration is performed each time because the loopback redirect URI
858/// includes an ephemeral port that changes every flow.
859pub async fn resolve_client_registration(
860 http_client: &Arc<dyn HttpClient>,
861 discovery: &OAuthDiscovery,
862 redirect_uri: &str,
863) -> Result<OAuthClientRegistration> {
864 match determine_registration_strategy(&discovery.auth_server_metadata) {
865 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
866 client_id,
867 client_secret: None,
868 }),
869 ClientRegistrationStrategy::Dcr {
870 registration_endpoint,
871 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
872 ClientRegistrationStrategy::Unavailable => {
873 bail!("authorization server supports neither CIMD nor DCR")
874 }
875 }
876}
877
878// -- Dynamic Client Registration (RFC 7591) ----------------------------------
879
880/// Perform Dynamic Client Registration with the authorization server.
881pub async fn perform_dcr(
882 http_client: &Arc<dyn HttpClient>,
883 registration_endpoint: &Url,
884 redirect_uri: &str,
885) -> Result<OAuthClientRegistration> {
886 validate_oauth_url(registration_endpoint)?;
887
888 let body = dcr_registration_body(redirect_uri);
889 let body_bytes = serde_json::to_vec(&body)?;
890
891 let request = Request::builder()
892 .method(http_client::http::Method::POST)
893 .uri(registration_endpoint.as_str())
894 .header("Content-Type", "application/json")
895 .header("Accept", "application/json")
896 .body(AsyncBody::from(body_bytes))?;
897
898 let mut response = http_client.send(request).await?;
899
900 if !response.status().is_success() {
901 let mut error_body = String::new();
902 response.body_mut().read_to_string(&mut error_body).await?;
903 bail!(
904 "DCR failed with status {}: {}",
905 response.status(),
906 error_body
907 );
908 }
909
910 let mut response_body = String::new();
911 response
912 .body_mut()
913 .read_to_string(&mut response_body)
914 .await?;
915
916 let dcr_response: DcrResponse =
917 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
918
919 Ok(OAuthClientRegistration {
920 client_id: dcr_response.client_id,
921 client_secret: dcr_response.client_secret,
922 })
923}
924
925// -- Token exchange and refresh (async) --------------------------------------
926
927/// Exchange an authorization code for tokens at the token endpoint.
928pub async fn exchange_code(
929 http_client: &Arc<dyn HttpClient>,
930 auth_server_metadata: &AuthServerMetadata,
931 code: &str,
932 client_id: &str,
933 redirect_uri: &str,
934 code_verifier: &str,
935 resource: &str,
936 client_secret: Option<&str>,
937) -> Result<OAuthTokens> {
938 let params = token_exchange_params(
939 code,
940 client_id,
941 redirect_uri,
942 code_verifier,
943 resource,
944 client_secret,
945 );
946 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
947}
948
949/// Refresh tokens using a refresh token.
950pub async fn refresh_tokens(
951 http_client: &Arc<dyn HttpClient>,
952 token_endpoint: &Url,
953 refresh_token: &str,
954 client_id: &str,
955 resource: &str,
956 client_secret: Option<&str>,
957) -> Result<OAuthTokens> {
958 let params = token_refresh_params(refresh_token, client_id, resource, client_secret);
959 post_token_request(http_client, token_endpoint, ¶ms).await
960}
961
962/// POST form-encoded parameters to a token endpoint and parse the response.
963async fn post_token_request(
964 http_client: &Arc<dyn HttpClient>,
965 token_endpoint: &Url,
966 params: &[(&str, String)],
967) -> Result<OAuthTokens> {
968 validate_oauth_url(token_endpoint)?;
969
970 let body = url::form_urlencoded::Serializer::new(String::new())
971 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
972 .finish();
973
974 let request = Request::builder()
975 .method(http_client::http::Method::POST)
976 .uri(token_endpoint.as_str())
977 .header("Content-Type", "application/x-www-form-urlencoded")
978 .header("Accept", "application/json")
979 .body(AsyncBody::from(body.into_bytes()))?;
980
981 let mut response = http_client.send(request).await?;
982
983 if !response.status().is_success() {
984 let mut error_body = String::new();
985 response.body_mut().read_to_string(&mut error_body).await?;
986 let status = response.status();
987 // Try to parse as an OAuth error response (RFC 6749 Section 5.2).
988 if let Ok(token_error) = serde_json::from_str::<OAuthTokenError>(&error_body) {
989 return Err(token_error.into());
990 }
991 bail!("token request failed with status {status}: {error_body}");
992 }
993
994 let mut response_body = String::new();
995 response
996 .body_mut()
997 .read_to_string(&mut response_body)
998 .await?;
999
1000 let token_response: TokenResponse =
1001 serde_json::from_str(&response_body).context("failed to parse token response")?;
1002
1003 Ok(token_response.into_tokens())
1004}
1005
1006// -- Loopback HTTP callback server -------------------------------------------
1007
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Both fields are sensitive, so `Debug` is implemented by hand and never
/// prints their values.
pub struct OAuthCallback {
    /// The authorization code, to be exchanged for tokens at the token endpoint.
    pub code: String,
    /// The `state` value echoed back by the authorization server.
    /// NOTE(review): comparison against the value originally sent happens
    /// outside this type — confirm at the call site.
    pub state: String,
}
1013
1014impl std::fmt::Debug for OAuthCallback {
1015 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1016 f.debug_struct("OAuthCallback")
1017 .field("code", &"[redacted]")
1018 .field("state", &"[redacted]")
1019 .finish()
1020 }
1021}
1022
1023impl OAuthCallback {
1024 /// Parse the query string from a callback URL like
1025 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1026 pub fn parse_query(query: &str) -> Result<Self> {
1027 let mut code: Option<String> = None;
1028 let mut state: Option<String> = None;
1029 let mut error: Option<String> = None;
1030 let mut error_description: Option<String> = None;
1031
1032 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1033 match key.as_ref() {
1034 "code" => {
1035 if !value.is_empty() {
1036 code = Some(value.into_owned());
1037 }
1038 }
1039 "state" => {
1040 if !value.is_empty() {
1041 state = Some(value.into_owned());
1042 }
1043 }
1044 "error" => {
1045 if !value.is_empty() {
1046 error = Some(value.into_owned());
1047 }
1048 }
1049 "error_description" => {
1050 if !value.is_empty() {
1051 error_description = Some(value.into_owned());
1052 }
1053 }
1054 _ => {}
1055 }
1056 }
1057
1058 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1059 // checking for missing code/state.
1060 if let Some(error_code) = error {
1061 bail!(
1062 "OAuth authorization failed: {} ({})",
1063 error_code,
1064 error_description.as_deref().unwrap_or("no description")
1065 );
1066 }
1067
1068 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1069 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1070
1071 Ok(Self { code, state })
1072 }
1073}
1074
/// Upper bound on how long the loopback callback server waits for the browser
/// to finish the OAuth flow. Once it elapses, the server thread exits and the
/// loopback port is released.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1078
1079/// Start a loopback HTTP server to receive the OAuth authorization callback.
1080///
1081/// Binds to an ephemeral loopback port for each flow.
1082///
1083/// Returns `(redirect_uri, callback_future)`. The caller should use the
1084/// redirect URI in the authorization request, open the browser, then await
1085/// the future to receive the callback.
1086///
1087/// The server accepts exactly one request on `/callback`, validates that it
1088/// contains `code` and `state` query parameters, responds with a minimal
1089/// HTML page telling the user they can close the tab, and shuts down.
1090///
1091/// The callback server shuts down when the returned oneshot receiver is dropped
1092/// (e.g. because the authentication task was cancelled), or after a timeout
1093/// ([CALLBACK_TIMEOUT]).
1094pub async fn start_callback_server() -> Result<(
1095 String,
1096 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1097)> {
1098 let server = tiny_http::Server::http("127.0.0.1:0")
1099 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1100 let port = server
1101 .server_addr()
1102 .to_ip()
1103 .context("server not bound to a TCP address")?
1104 .port();
1105
1106 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1107
1108 let (tx, rx) = futures::channel::oneshot::channel();
1109
1110 // `tiny_http` is blocking, so we run it on a background thread.
1111 // The `recv_timeout` loop lets us check for cancellation (the receiver
1112 // being dropped) and enforce an overall timeout.
1113 std::thread::spawn(move || {
1114 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1115
1116 loop {
1117 if tx.is_canceled() {
1118 return;
1119 }
1120 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1121 if remaining.is_zero() {
1122 return;
1123 }
1124
1125 let timeout = remaining.min(Duration::from_millis(500));
1126 let Some(request) = (match server.recv_timeout(timeout) {
1127 Ok(req) => req,
1128 Err(_) => {
1129 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1130 return;
1131 }
1132 }) else {
1133 // Timeout with no request — loop back and check cancellation.
1134 continue;
1135 };
1136
1137 let result = handle_callback_request(&request);
1138
1139 let (status_code, body) = match &result {
1140 Ok(_) => (
1141 200,
1142 "<html><body><h1>Authorization successful</h1>\
1143 <p>You can close this tab and return to Zed.</p></body></html>",
1144 ),
1145 Err(err) => {
1146 log::error!("OAuth callback error: {}", err);
1147 (
1148 400,
1149 "<html><body><h1>Authorization failed</h1>\
1150 <p>Something went wrong. Please try again from Zed.</p></body></html>",
1151 )
1152 }
1153 };
1154
1155 let response = tiny_http::Response::from_string(body)
1156 .with_status_code(status_code)
1157 .with_header(
1158 tiny_http::Header::from_str("Content-Type: text/html")
1159 .expect("failed to construct response header"),
1160 )
1161 .with_header(
1162 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1163 .expect("failed to construct response header"),
1164 );
1165 request.respond(response).log_err();
1166
1167 let _ = tx.send(result);
1168 return;
1169 }
1170 });
1171
1172 Ok((redirect_uri, rx))
1173}
1174
1175/// Extract the `code` and `state` query parameters from an OAuth callback
1176/// request to `/callback`.
1177fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1178 let url = Url::parse(&format!("http://localhost{}", request.url()))
1179 .context("malformed callback request URL")?;
1180
1181 if url.path() != "/callback" {
1182 bail!("unexpected path in OAuth callback: {}", url.path());
1183 }
1184
1185 let query = url
1186 .query()
1187 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1188 OAuthCallback::parse_query(query)
1189}
1190
1191// -- JSON fetch helper -------------------------------------------------------
1192
1193async fn fetch_json<T: serde::de::DeserializeOwned>(
1194 http_client: &Arc<dyn HttpClient>,
1195 url: &Url,
1196) -> Result<T> {
1197 validate_oauth_url(url)?;
1198
1199 let request = Request::builder()
1200 .method(http_client::http::Method::GET)
1201 .uri(url.as_str())
1202 .header("Accept", "application/json")
1203 .body(AsyncBody::default())?;
1204
1205 let mut response = http_client.send(request).await?;
1206
1207 if !response.status().is_success() {
1208 bail!("HTTP {} fetching {}", response.status(), url);
1209 }
1210
1211 let mut body = String::new();
1212 response.body_mut().read_to_string(&mut body).await?;
1213 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1214}
1215
1216// -- Serde response types for discovery --------------------------------------
1217
/// Wire form of a Protected Resource Metadata document, as fetched during
/// discovery.
///
/// Every field is optional or defaulted, so a sparse document still
/// deserializes. NOTE(review): validation of required fields presumably happens
/// when this is converted into `ProtectedResourceMetadata` — confirm.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// Identifier of the protected resource, if the document states one.
    #[serde(default)]
    resource: Option<Url>,
    /// Authorization servers that can issue tokens for this resource. Empty
    /// when the field is absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1227
/// Wire form of an Authorization Server Metadata document, as fetched during
/// discovery.
///
/// All fields are optional so partial documents still deserialize.
/// NOTE(review): required-field checks and the default for
/// `client_id_metadata_document_supported` presumably live in the conversion
/// to `AuthServerMetadata` — confirm.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// The authorization server's issuer identifier.
    #[serde(default)]
    issuer: Option<Url>,
    /// Endpoint the browser is sent to for the authorization request.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// Endpoint used for code exchange and token refresh.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// Dynamic Client Registration endpoint, when DCR is supported.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// PKCE challenge methods the server accepts (e.g. `S256`).
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL (CIMD) as
    /// the `client_id`.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1245
/// The subset of a Dynamic Client Registration (DCR) response that is used
/// here. Other fields in the response are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// The client identifier assigned by the authorization server.
    client_id: String,
    /// Client secret, present only when the server issued one.
    #[serde(default)]
    client_secret: Option<String>,
}
1252
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// The trait requires `Send + Sync`. NOTE(review): presumably so one provider
/// can be shared by the transport across tasks/threads — confirm.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// This is a plain (non-async) method, so it cannot await. Implementations
    /// answer from state they already hold.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token is available and the request should not
    /// be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1266
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// The current session: tokens, token endpoint, resource and client
    /// registration.
    /// Guarded by a synchronous (`parking_lot`) mutex, so the lock must never
    /// be held across an `.await`.
    session: SyncMutex<OAuthSession>,
    /// Client used to send refresh requests to the token endpoint.
    http_client: Arc<dyn HttpClient>,
    /// Optional listener that receives a copy of the whole session after each
    /// successful refresh. NOTE(review): presumably used to re-persist the
    /// session to the keychain — confirm at the construction site.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1276
1277impl McpOAuthTokenProvider {
1278 pub fn new(
1279 session: OAuthSession,
1280 http_client: Arc<dyn HttpClient>,
1281 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1282 ) -> Self {
1283 Self {
1284 session: SyncMutex::new(session),
1285 http_client,
1286 token_refresh_tx,
1287 }
1288 }
1289
1290 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1291 tokens.expires_at.is_some_and(|expires_at| {
1292 SystemTime::now()
1293 .checked_add(Duration::from_secs(30))
1294 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1295 })
1296 }
1297}
1298
1299#[async_trait]
1300impl OAuthTokenProvider for McpOAuthTokenProvider {
1301 fn access_token(&self) -> Option<String> {
1302 let session = self.session.lock();
1303 if Self::access_token_is_expired(&session.tokens) {
1304 return None;
1305 }
1306 Some(session.tokens.access_token.clone())
1307 }
1308
1309 async fn try_refresh(&self) -> Result<bool> {
1310 let (refresh_token, token_endpoint, resource, client_id, client_secret) = {
1311 let session = self.session.lock();
1312 match session.tokens.refresh_token.clone() {
1313 Some(refresh_token) => (
1314 refresh_token,
1315 session.token_endpoint.clone(),
1316 session.resource.clone(),
1317 session.client_registration.client_id.clone(),
1318 session.client_registration.client_secret.clone(),
1319 ),
1320 None => return Ok(false),
1321 }
1322 };
1323
1324 let resource_str = canonical_server_uri(&resource);
1325
1326 match refresh_tokens(
1327 &self.http_client,
1328 &token_endpoint,
1329 &refresh_token,
1330 &client_id,
1331 &resource_str,
1332 client_secret.as_deref(),
1333 )
1334 .await
1335 {
1336 Ok(mut new_tokens) => {
1337 if new_tokens.refresh_token.is_none() {
1338 new_tokens.refresh_token = Some(refresh_token);
1339 }
1340
1341 {
1342 let mut session = self.session.lock();
1343 session.tokens = new_tokens;
1344
1345 if let Some(ref tx) = self.token_refresh_tx {
1346 tx.unbounded_send(session.clone()).ok();
1347 }
1348 }
1349
1350 Ok(true)
1351 }
1352 Err(err) => {
1353 log::warn!("OAuth token refresh failed: {}", err);
1354 Ok(false)
1355 }
1356 }
1357 }
1358}
1359
1360#[cfg(test)]
1361mod tests {
1362 use super::*;
1363 use http_client::Response;
1364
1365 // -- require_https_or_loopback tests ------------------------------------
1366
1367 #[test]
1368 fn test_require_https_or_loopback_accepts_https() {
1369 let url = Url::parse("https://auth.example.com/token").unwrap();
1370 assert!(require_https_or_loopback(&url).is_ok());
1371 }
1372
1373 #[test]
1374 fn test_require_https_or_loopback_rejects_http_remote() {
1375 let url = Url::parse("http://auth.example.com/token").unwrap();
1376 assert!(require_https_or_loopback(&url).is_err());
1377 }
1378
1379 #[test]
1380 fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1381 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1382 assert!(require_https_or_loopback(&url).is_ok());
1383 }
1384
1385 #[test]
1386 fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1387 let url = Url::parse("http://[::1]:8080/callback").unwrap();
1388 assert!(require_https_or_loopback(&url).is_ok());
1389 }
1390
1391 #[test]
1392 fn test_require_https_or_loopback_accepts_http_localhost() {
1393 let url = Url::parse("http://localhost:8080/callback").unwrap();
1394 assert!(require_https_or_loopback(&url).is_ok());
1395 }
1396
1397 #[test]
1398 fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1399 let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1400 assert!(require_https_or_loopback(&url).is_ok());
1401 }
1402
1403 #[test]
1404 fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1405 let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1406 assert!(require_https_or_loopback(&url).is_err());
1407 }
1408
1409 #[test]
1410 fn test_require_https_or_loopback_rejects_ftp() {
1411 let url = Url::parse("ftp://auth.example.com/token").unwrap();
1412 assert!(require_https_or_loopback(&url).is_err());
1413 }
1414
1415 // -- validate_oauth_url (SSRF) tests ------------------------------------
1416
1417 #[test]
1418 fn test_validate_oauth_url_accepts_https_public() {
1419 let url = Url::parse("https://auth.example.com/token").unwrap();
1420 assert!(validate_oauth_url(&url).is_ok());
1421 }
1422
1423 #[test]
1424 fn test_validate_oauth_url_rejects_private_ipv4_10() {
1425 let url = Url::parse("https://10.0.0.1/token").unwrap();
1426 assert!(validate_oauth_url(&url).is_err());
1427 }
1428
1429 #[test]
1430 fn test_validate_oauth_url_rejects_private_ipv4_172() {
1431 let url = Url::parse("https://172.16.0.1/token").unwrap();
1432 assert!(validate_oauth_url(&url).is_err());
1433 }
1434
1435 #[test]
1436 fn test_validate_oauth_url_rejects_private_ipv4_192() {
1437 let url = Url::parse("https://192.168.1.1/token").unwrap();
1438 assert!(validate_oauth_url(&url).is_err());
1439 }
1440
1441 #[test]
1442 fn test_validate_oauth_url_rejects_link_local() {
1443 let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1444 assert!(validate_oauth_url(&url).is_err());
1445 }
1446
1447 #[test]
1448 fn test_validate_oauth_url_rejects_ipv6_ula() {
1449 let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1450 assert!(validate_oauth_url(&url).is_err());
1451 }
1452
1453 #[test]
1454 fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1455 let url = Url::parse("https://[::]/token").unwrap();
1456 assert!(validate_oauth_url(&url).is_err());
1457 }
1458
1459 #[test]
1460 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1461 let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1462 assert!(validate_oauth_url(&url).is_err());
1463 }
1464
1465 #[test]
1466 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1467 let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1468 assert!(validate_oauth_url(&url).is_err());
1469 }
1470
1471 #[test]
1472 fn test_validate_oauth_url_allows_http_loopback() {
1473 // Loopback is permitted (it's our callback server).
1474 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1475 assert!(validate_oauth_url(&url).is_ok());
1476 }
1477
1478 #[test]
1479 fn test_validate_oauth_url_allows_https_public_ip() {
1480 let url = Url::parse("https://93.184.216.34/token").unwrap();
1481 assert!(validate_oauth_url(&url).is_ok());
1482 }
1483
1484 // -- parse_www_authenticate tests ----------------------------------------
1485
1486 #[test]
1487 fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1488 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1489 let result = parse_www_authenticate(header).unwrap();
1490
1491 assert_eq!(
1492 result.resource_metadata.as_ref().map(|u| u.as_str()),
1493 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1494 );
1495 assert_eq!(
1496 result.scope,
1497 Some(vec!["files:read".to_string(), "user:profile".to_string()])
1498 );
1499 assert_eq!(result.error, None);
1500 }
1501
1502 #[test]
1503 fn test_parse_www_authenticate_resource_metadata_only() {
1504 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1505 let result = parse_www_authenticate(header).unwrap();
1506
1507 assert_eq!(
1508 result.resource_metadata.as_ref().map(|u| u.as_str()),
1509 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1510 );
1511 assert_eq!(result.scope, None);
1512 }
1513
1514 #[test]
1515 fn test_parse_www_authenticate_bare_bearer() {
1516 let result = parse_www_authenticate("Bearer").unwrap();
1517 assert_eq!(result.resource_metadata, None);
1518 assert_eq!(result.scope, None);
1519 }
1520
1521 #[test]
1522 fn test_parse_www_authenticate_with_error() {
1523 let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1524 let result = parse_www_authenticate(header).unwrap();
1525
1526 assert_eq!(result.error, Some(BearerError::InsufficientScope));
1527 assert_eq!(
1528 result.error_description.as_deref(),
1529 Some("Additional file write permission required")
1530 );
1531 assert_eq!(
1532 result.scope,
1533 Some(vec!["files:read".to_string(), "files:write".to_string()])
1534 );
1535 assert!(result.resource_metadata.is_some());
1536 }
1537
1538 #[test]
1539 fn test_parse_www_authenticate_invalid_token_error() {
1540 let header =
1541 r#"Bearer error="invalid_token", error_description="The access token expired""#;
1542 let result = parse_www_authenticate(header).unwrap();
1543 assert_eq!(result.error, Some(BearerError::InvalidToken));
1544 }
1545
1546 #[test]
1547 fn test_parse_www_authenticate_invalid_request_error() {
1548 let header = r#"Bearer error="invalid_request""#;
1549 let result = parse_www_authenticate(header).unwrap();
1550 assert_eq!(result.error, Some(BearerError::InvalidRequest));
1551 }
1552
1553 #[test]
1554 fn test_parse_www_authenticate_unknown_error() {
1555 let header = r#"Bearer error="some_future_error""#;
1556 let result = parse_www_authenticate(header).unwrap();
1557 assert_eq!(result.error, Some(BearerError::Other));
1558 }
1559
1560 #[test]
1561 fn test_parse_www_authenticate_rejects_non_bearer() {
1562 let result = parse_www_authenticate("Basic realm=\"example\"");
1563 assert!(result.is_err());
1564 }
1565
1566 #[test]
1567 fn test_parse_www_authenticate_case_insensitive_scheme() {
1568 let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1569 let result = parse_www_authenticate(header).unwrap();
1570 assert!(result.resource_metadata.is_some());
1571 }
1572
1573 #[test]
1574 fn test_parse_www_authenticate_multiline_style() {
1575 // Some servers emit the header spread across multiple lines joined by
1576 // whitespace, as shown in the spec examples.
1577 let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
1578 let result = parse_www_authenticate(header).unwrap();
1579 assert!(result.resource_metadata.is_some());
1580 assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1581 }
1582
1583 #[test]
1584 fn test_protected_resource_metadata_urls_with_path() {
1585 let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1586 let urls = protected_resource_metadata_urls(&server_url);
1587
1588 assert_eq!(urls.len(), 2);
1589 assert_eq!(
1590 urls[0].as_str(),
1591 "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1592 );
1593 assert_eq!(
1594 urls[1].as_str(),
1595 "https://api.example.com/.well-known/oauth-protected-resource"
1596 );
1597 }
1598
1599 #[test]
1600 fn test_protected_resource_metadata_urls_without_path() {
1601 let server_url = Url::parse("https://mcp.example.com").unwrap();
1602 let urls = protected_resource_metadata_urls(&server_url);
1603
1604 assert_eq!(urls.len(), 1);
1605 assert_eq!(
1606 urls[0].as_str(),
1607 "https://mcp.example.com/.well-known/oauth-protected-resource"
1608 );
1609 }
1610
1611 #[test]
1612 fn test_auth_server_metadata_urls_with_path() {
1613 let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1614 let urls = auth_server_metadata_urls(&issuer);
1615
1616 assert_eq!(urls.len(), 3);
1617 assert_eq!(
1618 urls[0].as_str(),
1619 "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1620 );
1621 assert_eq!(
1622 urls[1].as_str(),
1623 "https://auth.example.com/.well-known/openid-configuration/tenant1"
1624 );
1625 assert_eq!(
1626 urls[2].as_str(),
1627 "https://auth.example.com/tenant1/.well-known/openid-configuration"
1628 );
1629 }
1630
1631 #[test]
1632 fn test_auth_server_metadata_urls_without_path() {
1633 let issuer = Url::parse("https://auth.example.com").unwrap();
1634 let urls = auth_server_metadata_urls(&issuer);
1635
1636 assert_eq!(urls.len(), 2);
1637 assert_eq!(
1638 urls[0].as_str(),
1639 "https://auth.example.com/.well-known/oauth-authorization-server"
1640 );
1641 assert_eq!(
1642 urls[1].as_str(),
1643 "https://auth.example.com/.well-known/openid-configuration"
1644 );
1645 }
1646
1647 // -- Canonical server URI tests ------------------------------------------
1648
1649 #[test]
1650 fn test_canonical_server_uri_simple() {
1651 let url = Url::parse("https://mcp.example.com").unwrap();
1652 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1653 }
1654
1655 #[test]
1656 fn test_canonical_server_uri_with_path() {
1657 let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1658 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1659 }
1660
1661 #[test]
1662 fn test_canonical_server_uri_strips_trailing_slash() {
1663 let url = Url::parse("https://mcp.example.com/").unwrap();
1664 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1665 }
1666
1667 #[test]
1668 fn test_canonical_server_uri_preserves_port() {
1669 let url = Url::parse("https://mcp.example.com:8443").unwrap();
1670 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1671 }
1672
1673 #[test]
1674 fn test_canonical_server_uri_lowercases() {
1675 let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1676 assert_eq!(
1677 canonical_server_uri(&url),
1678 "https://mcp.example.com/Server/MCP"
1679 );
1680 }
1681
1682 // -- Scope selection tests -----------------------------------------------
1683
1684 #[test]
1685 fn test_select_scopes_prefers_www_authenticate() {
1686 let www_auth = WwwAuthenticate {
1687 resource_metadata: None,
1688 scope: Some(vec!["files:read".into()]),
1689 error: None,
1690 error_description: None,
1691 };
1692 let resource_meta = ProtectedResourceMetadata {
1693 resource: Url::parse("https://example.com").unwrap(),
1694 authorization_servers: vec![],
1695 scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1696 };
1697 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1698 }
1699
1700 #[test]
1701 fn test_select_scopes_falls_back_to_resource_metadata() {
1702 let www_auth = WwwAuthenticate {
1703 resource_metadata: None,
1704 scope: None,
1705 error: None,
1706 error_description: None,
1707 };
1708 let resource_meta = ProtectedResourceMetadata {
1709 resource: Url::parse("https://example.com").unwrap(),
1710 authorization_servers: vec![],
1711 scopes_supported: Some(vec!["admin".into()]),
1712 };
1713 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1714 }
1715
1716 #[test]
1717 fn test_select_scopes_empty_when_nothing_available() {
1718 let www_auth = WwwAuthenticate {
1719 resource_metadata: None,
1720 scope: None,
1721 error: None,
1722 error_description: None,
1723 };
1724 let resource_meta = ProtectedResourceMetadata {
1725 resource: Url::parse("https://example.com").unwrap(),
1726 authorization_servers: vec![],
1727 scopes_supported: None,
1728 };
1729 assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1730 }
1731
1732 // -- Client registration strategy tests ----------------------------------
1733
1734 #[test]
1735 fn test_registration_strategy_prefers_cimd() {
1736 let metadata = AuthServerMetadata {
1737 issuer: Url::parse("https://auth.example.com").unwrap(),
1738 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1739 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1740 registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1741 scopes_supported: None,
1742 code_challenge_methods_supported: Some(vec!["S256".into()]),
1743 client_id_metadata_document_supported: true,
1744 };
1745 assert_eq!(
1746 determine_registration_strategy(&metadata),
1747 ClientRegistrationStrategy::Cimd {
1748 client_id: CIMD_URL.to_string(),
1749 }
1750 );
1751 }
1752
1753 #[test]
1754 fn test_registration_strategy_falls_back_to_dcr() {
1755 let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1756 let metadata = AuthServerMetadata {
1757 issuer: Url::parse("https://auth.example.com").unwrap(),
1758 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1759 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1760 registration_endpoint: Some(reg_endpoint.clone()),
1761 scopes_supported: None,
1762 code_challenge_methods_supported: Some(vec!["S256".into()]),
1763 client_id_metadata_document_supported: false,
1764 };
1765 assert_eq!(
1766 determine_registration_strategy(&metadata),
1767 ClientRegistrationStrategy::Dcr {
1768 registration_endpoint: reg_endpoint,
1769 }
1770 );
1771 }
1772
1773 #[test]
1774 fn test_registration_strategy_unavailable() {
1775 let metadata = AuthServerMetadata {
1776 issuer: Url::parse("https://auth.example.com").unwrap(),
1777 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1778 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1779 registration_endpoint: None,
1780 scopes_supported: None,
1781 code_challenge_methods_supported: Some(vec!["S256".into()]),
1782 client_id_metadata_document_supported: false,
1783 };
1784 assert_eq!(
1785 determine_registration_strategy(&metadata),
1786 ClientRegistrationStrategy::Unavailable,
1787 );
1788 }
1789
1790 // -- PKCE tests ----------------------------------------------------------
1791
1792 #[test]
1793 fn test_pkce_challenge_verifier_length() {
1794 let pkce = generate_pkce_challenge();
1795 // 32 random bytes → 43 base64url chars (no padding).
1796 assert_eq!(pkce.verifier.len(), 43);
1797 }
1798
1799 #[test]
1800 fn test_pkce_challenge_is_valid_base64url() {
1801 let pkce = generate_pkce_challenge();
1802 for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1803 assert!(
1804 c.is_ascii_alphanumeric() || c == '-' || c == '_',
1805 "invalid base64url character: {}",
1806 c
1807 );
1808 }
1809 }
1810
1811 #[test]
1812 fn test_pkce_challenge_is_s256_of_verifier() {
1813 let pkce = generate_pkce_challenge();
1814 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1815 let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1816 let expected_challenge = engine.encode(expected_digest);
1817 assert_eq!(pkce.challenge, expected_challenge);
1818 }
1819
1820 #[test]
1821 fn test_pkce_challenges_are_unique() {
1822 let a = generate_pkce_challenge();
1823 let b = generate_pkce_challenge();
1824 assert_ne!(a.verifier, b.verifier);
1825 }
1826
1827 // -- Authorization URL tests ---------------------------------------------
1828
1829 #[test]
1830 fn test_build_authorization_url() {
1831 let metadata = AuthServerMetadata {
1832 issuer: Url::parse("https://auth.example.com").unwrap(),
1833 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1834 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1835 registration_endpoint: None,
1836 scopes_supported: None,
1837 code_challenge_methods_supported: Some(vec!["S256".into()]),
1838 client_id_metadata_document_supported: true,
1839 };
1840 let pkce = PkceChallenge {
1841 verifier: "test_verifier".into(),
1842 challenge: "test_challenge".into(),
1843 };
1844 let url = build_authorization_url(
1845 &metadata,
1846 "https://zed.dev/oauth/client-metadata.json",
1847 "http://127.0.0.1:12345/callback",
1848 &["files:read".into(), "files:write".into()],
1849 "https://mcp.example.com",
1850 &pkce,
1851 "random_state_123",
1852 );
1853
1854 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1855 assert_eq!(pairs.get("response_type").unwrap(), "code");
1856 assert_eq!(
1857 pairs.get("client_id").unwrap(),
1858 "https://zed.dev/oauth/client-metadata.json"
1859 );
1860 assert_eq!(
1861 pairs.get("redirect_uri").unwrap(),
1862 "http://127.0.0.1:12345/callback"
1863 );
1864 assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1865 assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1866 assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1867 assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1868 assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1869 }
1870
1871 #[test]
1872 fn test_build_authorization_url_omits_empty_scope() {
1873 let metadata = AuthServerMetadata {
1874 issuer: Url::parse("https://auth.example.com").unwrap(),
1875 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1876 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1877 registration_endpoint: None,
1878 scopes_supported: None,
1879 code_challenge_methods_supported: Some(vec!["S256".into()]),
1880 client_id_metadata_document_supported: false,
1881 };
1882 let pkce = PkceChallenge {
1883 verifier: "v".into(),
1884 challenge: "c".into(),
1885 };
1886 let url = build_authorization_url(
1887 &metadata,
1888 "client_123",
1889 "http://127.0.0.1:9999/callback",
1890 &[],
1891 "https://mcp.example.com",
1892 &pkce,
1893 "state",
1894 );
1895
1896 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1897 assert!(!pairs.contains_key("scope"));
1898 }
1899
1900 // -- Token exchange / refresh param tests --------------------------------
1901
1902 #[test]
1903 fn test_token_exchange_params() {
1904 let params = token_exchange_params(
1905 "auth_code_abc",
1906 "client_xyz",
1907 "http://127.0.0.1:5555/callback",
1908 "verifier_123",
1909 "https://mcp.example.com",
1910 None,
1911 );
1912 let map: std::collections::HashMap<&str, &str> =
1913 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1914
1915 assert_eq!(map["grant_type"], "authorization_code");
1916 assert_eq!(map["code"], "auth_code_abc");
1917 assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1918 assert_eq!(map["client_id"], "client_xyz");
1919 assert_eq!(map["code_verifier"], "verifier_123");
1920 assert_eq!(map["resource"], "https://mcp.example.com");
1921 }
1922
1923 #[test]
1924 fn test_token_refresh_params() {
1925 let params = token_refresh_params(
1926 "refresh_token_abc",
1927 "client_xyz",
1928 "https://mcp.example.com",
1929 None,
1930 );
1931 let map: std::collections::HashMap<&str, &str> =
1932 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1933
1934 assert_eq!(map["grant_type"], "refresh_token");
1935 assert_eq!(map["refresh_token"], "refresh_token_abc");
1936 assert_eq!(map["client_id"], "client_xyz");
1937 assert_eq!(map["resource"], "https://mcp.example.com");
1938 }
1939
1940 // -- Token response tests ------------------------------------------------
1941
1942 #[test]
1943 fn test_token_response_into_tokens_with_expiry() {
1944 let response: TokenResponse = serde_json::from_str(
1945 r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1946 )
1947 .unwrap();
1948
1949 let tokens = response.into_tokens();
1950 assert_eq!(tokens.access_token, "at_123");
1951 assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1952 assert!(tokens.expires_at.is_some());
1953 }
1954
1955 #[test]
1956 fn test_token_response_into_tokens_minimal() {
1957 let response: TokenResponse =
1958 serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1959
1960 let tokens = response.into_tokens();
1961 assert_eq!(tokens.access_token, "at_789");
1962 assert_eq!(tokens.refresh_token, None);
1963 assert_eq!(tokens.expires_at, None);
1964 }
1965
1966 // -- DCR body test -------------------------------------------------------
1967
1968 #[test]
1969 fn test_dcr_registration_body_shape() {
1970 let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1971 assert_eq!(body["client_name"], "Zed");
1972 assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1973 assert_eq!(body["grant_types"][0], "authorization_code");
1974 assert_eq!(body["response_types"][0], "code");
1975 assert_eq!(body["token_endpoint_auth_method"], "none");
1976 }
1977
1978 // -- Test helpers for async/HTTP tests -----------------------------------
1979
1980 fn make_fake_http_client(
1981 handler: impl Fn(
1982 http_client::Request<AsyncBody>,
1983 ) -> std::pin::Pin<
1984 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1985 > + Send
1986 + Sync
1987 + 'static,
1988 ) -> Arc<dyn HttpClient> {
1989 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1990 }
1991
1992 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1993 Ok(Response::builder()
1994 .status(status)
1995 .header("Content-Type", "application/json")
1996 .body(AsyncBody::from(body.as_bytes().to_vec()))
1997 .unwrap())
1998 }
1999
2000 // -- Discovery integration tests -----------------------------------------
2001
2002 #[test]
2003 fn test_fetch_protected_resource_metadata() {
2004 smol::block_on(async {
2005 let client = make_fake_http_client(|req| {
2006 Box::pin(async move {
2007 let uri = req.uri().to_string();
2008 if uri.contains(".well-known/oauth-protected-resource") {
2009 json_response(
2010 200,
2011 r#"{
2012 "resource": "https://mcp.example.com",
2013 "authorization_servers": ["https://auth.example.com"],
2014 "scopes_supported": ["read", "write"]
2015 }"#,
2016 )
2017 } else {
2018 json_response(404, "{}")
2019 }
2020 })
2021 });
2022
2023 let server_url = Url::parse("https://mcp.example.com").unwrap();
2024 let www_auth = WwwAuthenticate {
2025 resource_metadata: None,
2026 scope: None,
2027 error: None,
2028 error_description: None,
2029 };
2030
2031 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2032 .await
2033 .unwrap();
2034
2035 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2036 assert_eq!(metadata.authorization_servers.len(), 1);
2037 assert_eq!(
2038 metadata.authorization_servers[0].as_str(),
2039 "https://auth.example.com/"
2040 );
2041 assert_eq!(
2042 metadata.scopes_supported,
2043 Some(vec!["read".to_string(), "write".to_string()])
2044 );
2045 });
2046 }
2047
2048 #[test]
2049 fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2050 smol::block_on(async {
2051 let client = make_fake_http_client(|req| {
2052 Box::pin(async move {
2053 let uri = req.uri().to_string();
2054 if uri == "https://mcp.example.com/custom-resource-metadata" {
2055 json_response(
2056 200,
2057 r#"{
2058 "resource": "https://mcp.example.com",
2059 "authorization_servers": ["https://auth.example.com"]
2060 }"#,
2061 )
2062 } else {
2063 json_response(500, r#"{"error": "should not be called"}"#)
2064 }
2065 })
2066 });
2067
2068 let server_url = Url::parse("https://mcp.example.com").unwrap();
2069 let www_auth = WwwAuthenticate {
2070 resource_metadata: Some(
2071 Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2072 ),
2073 scope: None,
2074 error: None,
2075 error_description: None,
2076 };
2077
2078 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2079 .await
2080 .unwrap();
2081
2082 assert_eq!(metadata.authorization_servers.len(), 1);
2083 });
2084 }
2085
2086 #[test]
2087 fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2088 smol::block_on(async {
2089 let client = make_fake_http_client(|req| {
2090 Box::pin(async move {
2091 let uri = req.uri().to_string();
2092 // The cross-origin URL should NOT be fetched; only the
2093 // well-known fallback at the server's own origin should be.
2094 if uri.contains("attacker.example.com") {
2095 panic!("should not fetch cross-origin resource_metadata URL");
2096 } else if uri.contains(".well-known/oauth-protected-resource") {
2097 json_response(
2098 200,
2099 r#"{
2100 "resource": "https://mcp.example.com",
2101 "authorization_servers": ["https://auth.example.com"]
2102 }"#,
2103 )
2104 } else {
2105 json_response(404, "{}")
2106 }
2107 })
2108 });
2109
2110 let server_url = Url::parse("https://mcp.example.com").unwrap();
2111 let www_auth = WwwAuthenticate {
2112 resource_metadata: Some(
2113 Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2114 ),
2115 scope: None,
2116 error: None,
2117 error_description: None,
2118 };
2119
2120 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2121 .await
2122 .unwrap();
2123
2124 // Should have used the fallback well-known URL, not the attacker's.
2125 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2126 });
2127 }
2128
2129 #[test]
2130 fn test_fetch_auth_server_metadata() {
2131 smol::block_on(async {
2132 let client = make_fake_http_client(|req| {
2133 Box::pin(async move {
2134 let uri = req.uri().to_string();
2135 if uri.contains(".well-known/oauth-authorization-server") {
2136 json_response(
2137 200,
2138 r#"{
2139 "issuer": "https://auth.example.com",
2140 "authorization_endpoint": "https://auth.example.com/authorize",
2141 "token_endpoint": "https://auth.example.com/token",
2142 "registration_endpoint": "https://auth.example.com/register",
2143 "code_challenge_methods_supported": ["S256"],
2144 "client_id_metadata_document_supported": true
2145 }"#,
2146 )
2147 } else {
2148 json_response(404, "{}")
2149 }
2150 })
2151 });
2152
2153 let issuer = Url::parse("https://auth.example.com").unwrap();
2154 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2155
2156 assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2157 assert_eq!(
2158 metadata.authorization_endpoint.as_str(),
2159 "https://auth.example.com/authorize"
2160 );
2161 assert_eq!(
2162 metadata.token_endpoint.as_str(),
2163 "https://auth.example.com/token"
2164 );
2165 assert!(metadata.registration_endpoint.is_some());
2166 assert!(metadata.client_id_metadata_document_supported);
2167 assert_eq!(
2168 metadata.code_challenge_methods_supported,
2169 Some(vec!["S256".to_string()])
2170 );
2171 });
2172 }
2173
2174 #[test]
2175 fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2176 smol::block_on(async {
2177 let client = make_fake_http_client(|req| {
2178 Box::pin(async move {
2179 let uri = req.uri().to_string();
2180 if uri.contains("openid-configuration") {
2181 json_response(
2182 200,
2183 r#"{
2184 "issuer": "https://auth.example.com",
2185 "authorization_endpoint": "https://auth.example.com/authorize",
2186 "token_endpoint": "https://auth.example.com/token",
2187 "code_challenge_methods_supported": ["S256"]
2188 }"#,
2189 )
2190 } else {
2191 json_response(404, "{}")
2192 }
2193 })
2194 });
2195
2196 let issuer = Url::parse("https://auth.example.com").unwrap();
2197 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2198
2199 assert_eq!(
2200 metadata.authorization_endpoint.as_str(),
2201 "https://auth.example.com/authorize"
2202 );
2203 assert!(!metadata.client_id_metadata_document_supported);
2204 });
2205 }
2206
2207 #[test]
2208 fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2209 smol::block_on(async {
2210 let client = make_fake_http_client(|req| {
2211 Box::pin(async move {
2212 let uri = req.uri().to_string();
2213 if uri.contains(".well-known/oauth-authorization-server") {
2214 // Response claims to be a different issuer.
2215 json_response(
2216 200,
2217 r#"{
2218 "issuer": "https://evil.example.com",
2219 "authorization_endpoint": "https://evil.example.com/authorize",
2220 "token_endpoint": "https://evil.example.com/token",
2221 "code_challenge_methods_supported": ["S256"]
2222 }"#,
2223 )
2224 } else {
2225 json_response(404, "{}")
2226 }
2227 })
2228 });
2229
2230 let issuer = Url::parse("https://auth.example.com").unwrap();
2231 let result = fetch_auth_server_metadata(&client, &issuer).await;
2232
2233 assert!(result.is_err());
2234 let err_msg = result.unwrap_err().to_string();
2235 assert!(
2236 err_msg.contains("issuer mismatch"),
2237 "unexpected error: {}",
2238 err_msg
2239 );
2240 });
2241 }
2242
2243 // -- Full discover integration tests -------------------------------------
2244
2245 #[test]
2246 fn test_full_discover_with_cimd() {
2247 smol::block_on(async {
2248 let client = make_fake_http_client(|req| {
2249 Box::pin(async move {
2250 let uri = req.uri().to_string();
2251 if uri.contains("oauth-protected-resource") {
2252 json_response(
2253 200,
2254 r#"{
2255 "resource": "https://mcp.example.com",
2256 "authorization_servers": ["https://auth.example.com"],
2257 "scopes_supported": ["mcp:read"]
2258 }"#,
2259 )
2260 } else if uri.contains("oauth-authorization-server") {
2261 json_response(
2262 200,
2263 r#"{
2264 "issuer": "https://auth.example.com",
2265 "authorization_endpoint": "https://auth.example.com/authorize",
2266 "token_endpoint": "https://auth.example.com/token",
2267 "code_challenge_methods_supported": ["S256"],
2268 "client_id_metadata_document_supported": true
2269 }"#,
2270 )
2271 } else {
2272 json_response(404, "{}")
2273 }
2274 })
2275 });
2276
2277 let server_url = Url::parse("https://mcp.example.com").unwrap();
2278 let www_auth = WwwAuthenticate {
2279 resource_metadata: None,
2280 scope: None,
2281 error: None,
2282 error_description: None,
2283 };
2284
2285 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2286 let registration =
2287 resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2288 .await
2289 .unwrap();
2290
2291 assert_eq!(registration.client_id, CIMD_URL);
2292 assert_eq!(registration.client_secret, None);
2293 assert_eq!(discovery.scopes, vec!["mcp:read"]);
2294 });
2295 }
2296
2297 #[test]
2298 fn test_full_discover_with_dcr_fallback() {
2299 smol::block_on(async {
2300 let client = make_fake_http_client(|req| {
2301 Box::pin(async move {
2302 let uri = req.uri().to_string();
2303 if uri.contains("oauth-protected-resource") {
2304 json_response(
2305 200,
2306 r#"{
2307 "resource": "https://mcp.example.com",
2308 "authorization_servers": ["https://auth.example.com"]
2309 }"#,
2310 )
2311 } else if uri.contains("oauth-authorization-server") {
2312 json_response(
2313 200,
2314 r#"{
2315 "issuer": "https://auth.example.com",
2316 "authorization_endpoint": "https://auth.example.com/authorize",
2317 "token_endpoint": "https://auth.example.com/token",
2318 "registration_endpoint": "https://auth.example.com/register",
2319 "code_challenge_methods_supported": ["S256"],
2320 "client_id_metadata_document_supported": false
2321 }"#,
2322 )
2323 } else if uri.contains("/register") {
2324 json_response(
2325 201,
2326 r#"{
2327 "client_id": "dcr-minted-id-123",
2328 "client_secret": "dcr-secret-456"
2329 }"#,
2330 )
2331 } else {
2332 json_response(404, "{}")
2333 }
2334 })
2335 });
2336
2337 let server_url = Url::parse("https://mcp.example.com").unwrap();
2338 let www_auth = WwwAuthenticate {
2339 resource_metadata: None,
2340 scope: Some(vec!["files:read".into()]),
2341 error: None,
2342 error_description: None,
2343 };
2344
2345 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2346 let registration =
2347 resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2348 .await
2349 .unwrap();
2350
2351 assert_eq!(registration.client_id, "dcr-minted-id-123");
2352 assert_eq!(
2353 registration.client_secret.as_deref(),
2354 Some("dcr-secret-456")
2355 );
2356 assert_eq!(discovery.scopes, vec!["files:read"]);
2357 });
2358 }
2359
2360 #[test]
2361 fn test_discover_fails_without_pkce_support() {
2362 smol::block_on(async {
2363 let client = make_fake_http_client(|req| {
2364 Box::pin(async move {
2365 let uri = req.uri().to_string();
2366 if uri.contains("oauth-protected-resource") {
2367 json_response(
2368 200,
2369 r#"{
2370 "resource": "https://mcp.example.com",
2371 "authorization_servers": ["https://auth.example.com"]
2372 }"#,
2373 )
2374 } else if uri.contains("oauth-authorization-server") {
2375 json_response(
2376 200,
2377 r#"{
2378 "issuer": "https://auth.example.com",
2379 "authorization_endpoint": "https://auth.example.com/authorize",
2380 "token_endpoint": "https://auth.example.com/token"
2381 }"#,
2382 )
2383 } else {
2384 json_response(404, "{}")
2385 }
2386 })
2387 });
2388
2389 let server_url = Url::parse("https://mcp.example.com").unwrap();
2390 let www_auth = WwwAuthenticate {
2391 resource_metadata: None,
2392 scope: None,
2393 error: None,
2394 error_description: None,
2395 };
2396
2397 let result = discover(&client, &server_url, &www_auth).await;
2398 assert!(result.is_err());
2399 let err_msg = result.unwrap_err().to_string();
2400 assert!(
2401 err_msg.contains("code_challenge_methods_supported"),
2402 "unexpected error: {}",
2403 err_msg
2404 );
2405 });
2406 }
2407
2408 // -- Token exchange integration tests ------------------------------------
2409
2410 #[test]
2411 fn test_exchange_code_success() {
2412 smol::block_on(async {
2413 let client = make_fake_http_client(|req| {
2414 Box::pin(async move {
2415 let uri = req.uri().to_string();
2416 if uri.contains("/token") {
2417 json_response(
2418 200,
2419 r#"{
2420 "access_token": "new_access_token",
2421 "refresh_token": "new_refresh_token",
2422 "expires_in": 3600,
2423 "token_type": "Bearer"
2424 }"#,
2425 )
2426 } else {
2427 json_response(404, "{}")
2428 }
2429 })
2430 });
2431
2432 let metadata = AuthServerMetadata {
2433 issuer: Url::parse("https://auth.example.com").unwrap(),
2434 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2435 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2436 registration_endpoint: None,
2437 scopes_supported: None,
2438 code_challenge_methods_supported: Some(vec!["S256".into()]),
2439 client_id_metadata_document_supported: true,
2440 };
2441
2442 let tokens = exchange_code(
2443 &client,
2444 &metadata,
2445 "auth_code_123",
2446 CIMD_URL,
2447 "http://127.0.0.1:9999/callback",
2448 "verifier_abc",
2449 "https://mcp.example.com",
2450 None,
2451 )
2452 .await
2453 .unwrap();
2454
2455 assert_eq!(tokens.access_token, "new_access_token");
2456 assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2457 assert!(tokens.expires_at.is_some());
2458 });
2459 }
2460
2461 #[test]
2462 fn test_refresh_tokens_success() {
2463 smol::block_on(async {
2464 let client = make_fake_http_client(|req| {
2465 Box::pin(async move {
2466 let uri = req.uri().to_string();
2467 if uri.contains("/token") {
2468 json_response(
2469 200,
2470 r#"{
2471 "access_token": "refreshed_token",
2472 "expires_in": 1800,
2473 "token_type": "Bearer"
2474 }"#,
2475 )
2476 } else {
2477 json_response(404, "{}")
2478 }
2479 })
2480 });
2481
2482 let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2483
2484 let tokens = refresh_tokens(
2485 &client,
2486 &token_endpoint,
2487 "old_refresh_token",
2488 CIMD_URL,
2489 "https://mcp.example.com",
2490 None,
2491 )
2492 .await
2493 .unwrap();
2494
2495 assert_eq!(tokens.access_token, "refreshed_token");
2496 assert_eq!(tokens.refresh_token, None);
2497 assert!(tokens.expires_at.is_some());
2498 });
2499 }
2500
2501 #[test]
2502 fn test_exchange_code_failure() {
2503 smol::block_on(async {
2504 let client = make_fake_http_client(|_req| {
2505 Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2506 });
2507
2508 let metadata = AuthServerMetadata {
2509 issuer: Url::parse("https://auth.example.com").unwrap(),
2510 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2511 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2512 registration_endpoint: None,
2513 scopes_supported: None,
2514 code_challenge_methods_supported: Some(vec!["S256".into()]),
2515 client_id_metadata_document_supported: true,
2516 };
2517
2518 let result = exchange_code(
2519 &client,
2520 &metadata,
2521 "bad_code",
2522 "client",
2523 "http://127.0.0.1:1/callback",
2524 "verifier",
2525 "https://mcp.example.com",
2526 None,
2527 )
2528 .await;
2529
2530 let err = result.unwrap_err();
2531 let token_error = err
2532 .downcast_ref::<OAuthTokenError>()
2533 .expect("expected OAuthTokenError");
2534 assert_eq!(
2535 *token_error,
2536 OAuthTokenError {
2537 error: "invalid_grant".into(),
2538 error_description: None,
2539 }
2540 );
2541 });
2542 }
2543
2544 // -- DCR integration tests -----------------------------------------------
2545
2546 #[test]
2547 fn test_perform_dcr() {
2548 smol::block_on(async {
2549 let client = make_fake_http_client(|_req| {
2550 Box::pin(async move {
2551 json_response(
2552 201,
2553 r#"{
2554 "client_id": "dynamic-client-001",
2555 "client_secret": "dynamic-secret-001"
2556 }"#,
2557 )
2558 })
2559 });
2560
2561 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2562 let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2563 .await
2564 .unwrap();
2565
2566 assert_eq!(registration.client_id, "dynamic-client-001");
2567 assert_eq!(
2568 registration.client_secret.as_deref(),
2569 Some("dynamic-secret-001")
2570 );
2571 });
2572 }
2573
2574 #[test]
2575 fn test_perform_dcr_failure() {
2576 smol::block_on(async {
2577 let client = make_fake_http_client(|_req| {
2578 Box::pin(
2579 async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2580 )
2581 });
2582
2583 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2584 let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2585
2586 assert!(result.is_err());
2587 assert!(result.unwrap_err().to_string().contains("403"));
2588 });
2589 }
2590
2591 // -- OAuthCallback parse tests -------------------------------------------
2592
2593 #[test]
2594 fn test_oauth_callback_parse_query() {
2595 let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2596 assert_eq!(callback.code, "test_auth_code");
2597 assert_eq!(callback.state, "test_state");
2598 }
2599
2600 #[test]
2601 fn test_oauth_callback_parse_query_reversed_order() {
2602 let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2603 assert_eq!(callback.code, "test_auth_code");
2604 assert_eq!(callback.state, "test_state");
2605 }
2606
2607 #[test]
2608 fn test_oauth_callback_parse_query_with_extra_params() {
2609 let callback =
2610 OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2611 .unwrap();
2612 assert_eq!(callback.code, "test_auth_code");
2613 assert_eq!(callback.state, "test_state");
2614 }
2615
2616 #[test]
2617 fn test_oauth_callback_parse_query_missing_code() {
2618 let result = OAuthCallback::parse_query("state=test_state");
2619 assert!(result.is_err());
2620 assert!(result.unwrap_err().to_string().contains("code"));
2621 }
2622
2623 #[test]
2624 fn test_oauth_callback_parse_query_missing_state() {
2625 let result = OAuthCallback::parse_query("code=test_auth_code");
2626 assert!(result.is_err());
2627 assert!(result.unwrap_err().to_string().contains("state"));
2628 }
2629
2630 #[test]
2631 fn test_oauth_callback_parse_query_empty_code() {
2632 let result = OAuthCallback::parse_query("code=&state=test_state");
2633 assert!(result.is_err());
2634 }
2635
2636 #[test]
2637 fn test_oauth_callback_parse_query_empty_state() {
2638 let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2639 assert!(result.is_err());
2640 }
2641
2642 #[test]
2643 fn test_oauth_callback_parse_query_url_encoded_values() {
2644 let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2645 assert_eq!(callback.code, "abc def");
2646 assert_eq!(callback.state, "test=state");
2647 }
2648
2649 #[test]
2650 fn test_oauth_callback_parse_query_error_response() {
2651 let result = OAuthCallback::parse_query(
2652 "error=access_denied&error_description=User%20denied%20access&state=abc",
2653 );
2654 assert!(result.is_err());
2655 let err_msg = result.unwrap_err().to_string();
2656 assert!(
2657 err_msg.contains("access_denied"),
2658 "unexpected error: {}",
2659 err_msg
2660 );
2661 assert!(
2662 err_msg.contains("User denied access"),
2663 "unexpected error: {}",
2664 err_msg
2665 );
2666 }
2667
2668 #[test]
2669 fn test_oauth_callback_parse_query_error_without_description() {
2670 let result = OAuthCallback::parse_query("error=server_error&state=abc");
2671 assert!(result.is_err());
2672 let err_msg = result.unwrap_err().to_string();
2673 assert!(
2674 err_msg.contains("server_error"),
2675 "unexpected error: {}",
2676 err_msg
2677 );
2678 assert!(
2679 err_msg.contains("no description"),
2680 "unexpected error: {}",
2681 err_msg
2682 );
2683 }
2684
2685 // -- McpOAuthTokenProvider tests -----------------------------------------
2686
2687 fn make_test_session(
2688 access_token: &str,
2689 refresh_token: Option<&str>,
2690 expires_at: Option<SystemTime>,
2691 ) -> OAuthSession {
2692 OAuthSession {
2693 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2694 resource: Url::parse("https://mcp.example.com").unwrap(),
2695 client_registration: OAuthClientRegistration {
2696 client_id: "test-client".into(),
2697 client_secret: None,
2698 },
2699 tokens: OAuthTokens {
2700 access_token: access_token.into(),
2701 refresh_token: refresh_token.map(String::from),
2702 expires_at,
2703 },
2704 }
2705 }
2706
2707 #[test]
2708 fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2709 let expired = SystemTime::now() - Duration::from_secs(60);
2710 let session = make_test_session("stale-token", Some("rt"), Some(expired));
2711 let provider = McpOAuthTokenProvider::new(
2712 session,
2713 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2714 None,
2715 );
2716
2717 assert_eq!(provider.access_token(), None);
2718 }
2719
2720 #[test]
2721 fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2722 let far_future = SystemTime::now() + Duration::from_secs(3600);
2723 let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2724 let provider = McpOAuthTokenProvider::new(
2725 session,
2726 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2727 None,
2728 );
2729
2730 assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2731 }
2732
2733 #[test]
2734 fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2735 let session = make_test_session("no-expiry-token", Some("rt"), None);
2736 let provider = McpOAuthTokenProvider::new(
2737 session,
2738 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2739 None,
2740 );
2741
2742 assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2743 }
2744
2745 #[test]
2746 fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2747 smol::block_on(async {
2748 let session = make_test_session("token", None, None);
2749 let provider = McpOAuthTokenProvider::new(
2750 session,
2751 make_fake_http_client(|_| {
2752 Box::pin(async { unreachable!("no HTTP call expected") })
2753 }),
2754 None,
2755 );
2756
2757 let refreshed = provider.try_refresh().await.unwrap();
2758 assert!(!refreshed);
2759 });
2760 }
2761
2762 #[test]
2763 fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2764 smol::block_on(async {
2765 let session = make_test_session("old-access", Some("my-refresh-token"), None);
2766 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2767
2768 let http_client = make_fake_http_client(|_req| {
2769 Box::pin(async {
2770 json_response(
2771 200,
2772 r#"{
2773 "access_token": "new-access",
2774 "refresh_token": "new-refresh",
2775 "expires_in": 1800
2776 }"#,
2777 )
2778 })
2779 });
2780
2781 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2782
2783 let refreshed = provider.try_refresh().await.unwrap();
2784 assert!(refreshed);
2785 assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2786
2787 let notified_session = rx.try_recv().expect("channel should have a session");
2788 assert_eq!(notified_session.tokens.access_token, "new-access");
2789 assert_eq!(
2790 notified_session.tokens.refresh_token.as_deref(),
2791 Some("new-refresh")
2792 );
2793 });
2794 }
2795
2796 #[test]
2797 fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2798 smol::block_on(async {
2799 let session = make_test_session("old-access", Some("original-refresh"), None);
2800 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2801
2802 let http_client = make_fake_http_client(|_req| {
2803 Box::pin(async {
2804 json_response(
2805 200,
2806 r#"{
2807 "access_token": "new-access",
2808 "expires_in": 900
2809 }"#,
2810 )
2811 })
2812 });
2813
2814 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2815
2816 let refreshed = provider.try_refresh().await.unwrap();
2817 assert!(refreshed);
2818
2819 let notified_session = rx.try_recv().expect("channel should have a session");
2820 assert_eq!(notified_session.tokens.access_token, "new-access");
2821 assert_eq!(
2822 notified_session.tokens.refresh_token.as_deref(),
2823 Some("original-refresh"),
2824 );
2825 });
2826 }
2827
2828 #[test]
2829 fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2830 smol::block_on(async {
2831 let session = make_test_session("old-access", Some("my-refresh"), None);
2832
2833 let http_client = make_fake_http_client(|_req| {
2834 Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2835 });
2836
2837 let provider = McpOAuthTokenProvider::new(session, http_client, None);
2838
2839 let refreshed = provider.try_refresh().await.unwrap();
2840 assert!(!refreshed);
2841 // The old token should still be in place.
2842 assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2843 });
2844 }
2845}