1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
/// The URL where Zed's OAuth Client ID Metadata Document (CIMD) is hosted.
///
/// When the authorization server advertises
/// `client_id_metadata_document_supported`, this URL itself is used as the
/// `client_id` (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
/// OAuth 2.0 Protected Resource Metadata (RFC 9728) for an MCP server.
///
/// Fetched either from the `resource_metadata` URL named in the server's
/// `WWW-Authenticate` header or from a well-known endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource's identifier. Falls back to the MCP server URL
    /// when the metadata document omits it.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource.
    /// Guaranteed non-empty when produced by
    /// [`fetch_protected_resource_metadata`].
    pub authorization_servers: Vec<Url>,
    /// Scopes the resource advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
}
139
/// Parsed from the authorization server's .well-known endpoint
/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// The issuer identifier; [`fetch_auth_server_metadata`] rejects documents
    /// whose issuer differs from the one that was requested.
    pub issuer: Url,
    /// Where the browser is sent to start the authorization code flow.
    pub authorization_endpoint: Url,
    /// Where authorization codes and refresh tokens are exchanged for tokens.
    pub token_endpoint: Url,
    /// Dynamic Client Registration endpoint, when the server offers DCR.
    pub registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// Advertised PKCE challenge methods; [`discover`] requires `"S256"`.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// `client_id`. Treated as `false` when the document omits the field.
    pub client_id_metadata_document_supported: bool,
}
152
153/// The result of client registration — either CIMD or DCR.
154#[derive(Clone, Serialize, Deserialize)]
155pub struct OAuthClientRegistration {
156 pub client_id: String,
157 /// Only present for DCR-minted registrations.
158 pub client_secret: Option<String>,
159}
160
161impl std::fmt::Debug for OAuthClientRegistration {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("OAuthClientRegistration")
164 .field("client_id", &self.client_id)
165 .field(
166 "client_secret",
167 &self.client_secret.as_ref().map(|_| "[redacted]"),
168 )
169 .finish()
170 }
171}
172
173/// Access and refresh tokens obtained from the token endpoint.
174#[derive(Clone, Serialize, Deserialize)]
175pub struct OAuthTokens {
176 pub access_token: String,
177 pub refresh_token: Option<String>,
178 pub expires_at: Option<SystemTime>,
179}
180
181impl std::fmt::Debug for OAuthTokens {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("OAuthTokens")
184 .field("access_token", &"[redacted]")
185 .field(
186 "refresh_token",
187 &self.refresh_token.as_ref().map(|_| "[redacted]"),
188 )
189 .field("expires_at", &self.expires_at)
190 .finish()
191 }
192}
193
/// Everything discovered before the browser flow starts. Client registration is
/// resolved separately, once the real redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Protected Resource Metadata fetched for the MCP server.
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed by the resource.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes chosen by [`select_scopes`]; may be empty.
    pub scopes: Vec<String>,
}
202
/// The persisted OAuth session for a context server.
///
/// Stored in the keychain so startup can restore a refresh-capable provider
/// without another browser flow. Deliberately excludes the full discovery
/// metadata to keep the serialized size well within keychain item limits.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthSession {
    /// Token endpoint to use for refresh requests.
    pub token_endpoint: Url,
    /// The protected resource the tokens were issued for.
    pub resource: Url,
    /// Client credentials used to obtain the tokens.
    pub client_registration: OAuthClientRegistration,
    /// The current token set.
    pub tokens: OAuthTokens,
}

// Secrets stay out of the output because the `Debug` impls of
// `OAuthClientRegistration` and `OAuthTokens` redact them.
impl std::fmt::Debug for OAuthSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthSession")
            .field("token_endpoint", &self.token_endpoint)
            .field("resource", &self.resource)
            .field("client_registration", &self.client_registration)
            .field("tokens", &self.tokens)
            .finish()
    }
}
226
/// Bearer-token error codes from RFC 6750 Section 3.1.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}

impl BearerError {
    /// Map the value of an `error` auth-param to its variant.
    ///
    /// The comparison is exact (case-sensitive); anything that is not one of
    /// the three known codes becomes [`BearerError::Other`].
    fn parse(value: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];

        KNOWN_CODES
            .iter()
            .find(|(code, _)| *code == value)
            .map(|&(_, error)| error)
            .unwrap_or(BearerError::Other)
    }
}
251
/// Fields extracted from a `WWW-Authenticate: Bearer` header.
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if the header named one.
    pub resource_metadata: Option<Url>,
    /// The `scope` parameter split on whitespace; `None` when the parameter
    /// was absent.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, verbatim.
    pub error_description: Option<String>,
}
265
266/// Parse a `WWW-Authenticate` header value.
267///
268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
270/// - `resource_metadata` — URL of the Protected Resource Metadata document
271/// - `scope` — space-separated list of required scopes
272/// - `error` — error code (e.g. "insufficient_scope")
273/// - `error_description` — human-readable error description
274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
275 let header = header.trim();
276
277 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
278 header[6..].trim()
279 } else {
280 bail!("WWW-Authenticate header does not use Bearer scheme");
281 };
282
283 if params_str.is_empty() {
284 return Ok(WwwAuthenticate {
285 resource_metadata: None,
286 scope: None,
287 error: None,
288 error_description: None,
289 });
290 }
291
292 let params = parse_auth_params(params_str);
293
294 let resource_metadata = params
295 .get("resource_metadata")
296 .map(|v| Url::parse(v))
297 .transpose()
298 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
299
300 let scope = params
301 .get("scope")
302 .map(|v| v.split_whitespace().map(String::from).collect());
303
304 let error = params.get("error").map(|v| BearerError::parse(v));
305 let error_description = params.get("error_description").cloned();
306
307 Ok(WwwAuthenticate {
308 resource_metadata,
309 scope,
310 error,
311 error_description,
312 })
313}
314
315/// Parse comma-separated `key="value"` or `key=token` parameters from an
316/// auth-param list (RFC 7235 Section 2.1).
317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
318 let mut params = collections::HashMap::default();
319 let mut remaining = input.trim();
320
321 while !remaining.is_empty() {
322 // Skip leading whitespace and commas.
323 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
324 if remaining.is_empty() {
325 break;
326 }
327
328 // Find the key (everything before '=').
329 let eq_pos = match remaining.find('=') {
330 Some(pos) => pos,
331 None => break,
332 };
333
334 let key = remaining[..eq_pos].trim().to_lowercase();
335 remaining = &remaining[eq_pos + 1..];
336 remaining = remaining.trim_start();
337
338 // Parse the value: either quoted or unquoted (token).
339 let value;
340 if remaining.starts_with('"') {
341 // Quoted string: find the closing quote, handling escaped chars.
342 remaining = &remaining[1..]; // skip opening quote
343 let mut val = String::new();
344 let mut chars = remaining.char_indices();
345 loop {
346 match chars.next() {
347 Some((_, '\\')) => {
348 // Escaped character — take the next char literally.
349 if let Some((_, c)) = chars.next() {
350 val.push(c);
351 }
352 }
353 Some((i, '"')) => {
354 remaining = &remaining[i + 1..];
355 break;
356 }
357 Some((_, c)) => val.push(c),
358 None => {
359 remaining = "";
360 break;
361 }
362 }
363 }
364 value = val;
365 } else {
366 // Unquoted token: read until comma or whitespace.
367 let end = remaining
368 .find(|c: char| c == ',' || c.is_whitespace())
369 .unwrap_or(remaining.len());
370 value = remaining[..end].to_string();
371 remaining = &remaining[end..];
372 }
373
374 if !key.is_empty() {
375 params.insert(key, value);
376 }
377 }
378
379 params
380}
381
382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
383/// server URL, per RFC 9728 Section 3.
384///
385/// Returns URIs in priority order:
386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
389 let mut urls = Vec::new();
390 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
391
392 let path = server_url.path().trim_start_matches('/');
393 if !path.is_empty() {
394 if let Ok(url) = Url::parse(&format!(
395 "{}/.well-known/oauth-protected-resource/{}",
396 base, path
397 )) {
398 urls.push(url);
399 }
400 }
401
402 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
403 urls.push(url);
404 }
405
406 urls
407}
408
409/// Construct the well-known Authorization Server Metadata URIs for a given
410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
411///
412/// Returns URIs in priority order, which differs depending on whether the
413/// issuer URL has a path component.
414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
415 let mut urls = Vec::new();
416 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
417 let path = issuer.path().trim_matches('/');
418
419 if !path.is_empty() {
420 // Issuer with path: try path-inserted variants first.
421 if let Ok(url) = Url::parse(&format!(
422 "{}/.well-known/oauth-authorization-server/{}",
423 base, path
424 )) {
425 urls.push(url);
426 }
427 if let Ok(url) = Url::parse(&format!(
428 "{}/.well-known/openid-configuration/{}",
429 base, path
430 )) {
431 urls.push(url);
432 }
433 if let Ok(url) = Url::parse(&format!(
434 "{}/{}/.well-known/openid-configuration",
435 base, path
436 )) {
437 urls.push(url);
438 }
439 } else {
440 // No path: standard well-known locations.
441 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
442 urls.push(url);
443 }
444 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
445 urls.push(url);
446 }
447 }
448
449 urls
450}
451
452// -- Canonical server URI (RFC 8707) -----------------------------------------
453
454/// Derive the canonical resource URI for an MCP server URL, suitable for the
455/// `resource` parameter in authorization and token requests per RFC 8707.
456///
457/// Lowercases the scheme and host, preserves the path (without trailing slash),
458/// strips fragments and query strings.
459pub fn canonical_server_uri(server_url: &Url) -> String {
460 let mut uri = format!(
461 "{}://{}",
462 server_url.scheme().to_ascii_lowercase(),
463 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
464 );
465 if let Some(port) = server_url.port() {
466 uri.push_str(&format!(":{}", port));
467 }
468 let path = server_url.path();
469 if path != "/" {
470 uri.push_str(path.trim_end_matches('/'));
471 }
472 uri
473}
474
475// -- Scope selection ---------------------------------------------------------
476
477/// Select scopes following the MCP spec's Scope Selection Strategy:
478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
480/// 3. Return empty if neither is available.
481pub fn select_scopes(
482 www_authenticate: &WwwAuthenticate,
483 resource_metadata: &ProtectedResourceMetadata,
484) -> Vec<String> {
485 if let Some(ref scopes) = www_authenticate.scope {
486 if !scopes.is_empty() {
487 return scopes.clone();
488 }
489 }
490 resource_metadata
491 .scopes_supported
492 .clone()
493 .unwrap_or_default()
494}
495
496// -- Client registration strategy --------------------------------------------
497
/// The registration approach to use, determined from auth server metadata.
///
/// Produced by [`determine_registration_strategy`], which prefers CIMD over
/// DCR when both are available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it.
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
508
509/// Determine how to register with the authorization server, following the
510/// spec's recommended priority: CIMD first, DCR fallback.
511pub fn determine_registration_strategy(
512 auth_server_metadata: &AuthServerMetadata,
513) -> ClientRegistrationStrategy {
514 if auth_server_metadata.client_id_metadata_document_supported {
515 ClientRegistrationStrategy::Cimd {
516 client_id: CIMD_URL.to_string(),
517 }
518 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
519 ClientRegistrationStrategy::Dcr {
520 registration_endpoint: endpoint.clone(),
521 }
522 } else {
523 ClientRegistrationStrategy::Unavailable
524 }
525}
526
527// -- PKCE (RFC 7636) ---------------------------------------------------------
528
/// A PKCE code verifier paired with its S256 challenge.
#[derive(Clone)]
pub struct PkceChallenge {
    /// The secret verifier, sent only in the token exchange.
    pub verifier: String,
    /// The challenge derived from the verifier, sent in the authorization URL.
    pub challenge: String,
}

// Hand-written so the verifier never reaches logs; the challenge is shown
// as-is.
impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PkceChallenge");
        out.field("verifier", &"[redacted]");
        out.field("challenge", &self.challenge);
        out.finish()
    }
}
544
545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
546///
547/// The verifier is 43 base64url characters derived from 32 random bytes.
548/// The challenge is `BASE64URL(SHA256(verifier))`.
549pub fn generate_pkce_challenge() -> PkceChallenge {
550 let mut random_bytes = [0u8; 32];
551 rand::rng().fill(&mut random_bytes);
552 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
553 let verifier = engine.encode(&random_bytes);
554
555 let digest = Sha256::digest(verifier.as_bytes());
556 let challenge = engine.encode(digest);
557
558 PkceChallenge {
559 verifier,
560 challenge,
561 }
562}
563
564// -- Authorization URL construction ------------------------------------------
565
566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
567pub fn build_authorization_url(
568 auth_server_metadata: &AuthServerMetadata,
569 client_id: &str,
570 redirect_uri: &str,
571 scopes: &[String],
572 resource: &str,
573 pkce: &PkceChallenge,
574 state: &str,
575) -> Url {
576 let mut url = auth_server_metadata.authorization_endpoint.clone();
577 {
578 let mut query = url.query_pairs_mut();
579 query.append_pair("response_type", "code");
580 query.append_pair("client_id", client_id);
581 query.append_pair("redirect_uri", redirect_uri);
582 if !scopes.is_empty() {
583 query.append_pair("scope", &scopes.join(" "));
584 }
585 query.append_pair("resource", resource);
586 query.append_pair("code_challenge", &pkce.challenge);
587 query.append_pair("code_challenge_method", "S256");
588 query.append_pair("state", state);
589 }
590 url
591}
592
593// -- Token endpoint request bodies -------------------------------------------
594
595/// The JSON body returned by the token endpoint on success.
596#[derive(Deserialize)]
597pub struct TokenResponse {
598 pub access_token: String,
599 #[serde(default)]
600 pub refresh_token: Option<String>,
601 #[serde(default)]
602 pub expires_in: Option<u64>,
603 #[serde(default)]
604 pub token_type: Option<String>,
605}
606
607impl std::fmt::Debug for TokenResponse {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("TokenResponse")
610 .field("access_token", &"[redacted]")
611 .field(
612 "refresh_token",
613 &self.refresh_token.as_ref().map(|_| "[redacted]"),
614 )
615 .field("expires_in", &self.expires_in)
616 .field("token_type", &self.token_type)
617 .finish()
618 }
619}
620
621impl TokenResponse {
622 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
623 pub fn into_tokens(self) -> OAuthTokens {
624 let expires_at = self
625 .expires_in
626 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
627 OAuthTokens {
628 access_token: self.access_token,
629 refresh_token: self.refresh_token,
630 expires_at,
631 }
632 }
633}
634
/// Produce the form fields for exchanging an authorization code for tokens.
///
/// The pairs are returned in a fixed order: `grant_type` (always
/// `authorization_code`), `code`, `redirect_uri`, `client_id`,
/// `code_verifier`, `resource`. Values are not URL-encoded here.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let fields = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ];

    fields
        .iter()
        .map(|&(name, value)| (name, value.to_string()))
        .collect()
}
652
/// Produce the form fields for a refresh-token request.
///
/// The pairs are returned in a fixed order: `grant_type` (always
/// `refresh_token`), `refresh_token`, `client_id`, `resource`. Values are not
/// URL-encoded here.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let fields = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ];

    fields
        .iter()
        .map(|&(name, value)| (name, value.to_string()))
        .collect()
}
666
667// -- DCR request body (RFC 7591) ---------------------------------------------
668
669/// Build the JSON body for a Dynamic Client Registration request.
670///
671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
673/// redirect URI matching even for loopback addresses, so we register the
674/// exact URI we intend to use.
675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
676 serde_json::json!({
677 "client_name": "Zed",
678 "redirect_uris": [redirect_uri],
679 "grant_types": ["authorization_code"],
680 "response_types": ["code"],
681 "token_endpoint_auth_method": "none"
682 })
683}
684
685// -- Discovery (async, hits real endpoints) ----------------------------------
686
687/// Fetch Protected Resource Metadata from the MCP server.
688///
689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
690/// then falls back to well-known URIs constructed from `server_url`.
691pub async fn fetch_protected_resource_metadata(
692 http_client: &Arc<dyn HttpClient>,
693 server_url: &Url,
694 www_authenticate: &WwwAuthenticate,
695) -> Result<ProtectedResourceMetadata> {
696 let candidate_urls = match &www_authenticate.resource_metadata {
697 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
698 Some(url) => {
699 log::warn!(
700 "Ignoring cross-origin resource_metadata URL {} \
701 (server origin: {})",
702 url,
703 server_url.origin().unicode_serialization()
704 );
705 protected_resource_metadata_urls(server_url)
706 }
707 None => protected_resource_metadata_urls(server_url),
708 };
709
710 for url in &candidate_urls {
711 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
712 Ok(response) => {
713 if response.authorization_servers.is_empty() {
714 bail!(
715 "Protected Resource Metadata at {} has no authorization_servers",
716 url
717 );
718 }
719 return Ok(ProtectedResourceMetadata {
720 resource: response.resource.unwrap_or_else(|| server_url.clone()),
721 authorization_servers: response.authorization_servers,
722 scopes_supported: response.scopes_supported,
723 });
724 }
725 Err(err) => {
726 log::debug!(
727 "Failed to fetch Protected Resource Metadata from {}: {}",
728 url,
729 err
730 );
731 }
732 }
733 }
734
735 bail!(
736 "Could not fetch Protected Resource Metadata for {}",
737 server_url
738 )
739}
740
741/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
742/// endpoints in the priority order specified by the MCP spec.
743pub async fn fetch_auth_server_metadata(
744 http_client: &Arc<dyn HttpClient>,
745 issuer: &Url,
746) -> Result<AuthServerMetadata> {
747 let candidate_urls = auth_server_metadata_urls(issuer);
748
749 for url in &candidate_urls {
750 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
751 Ok(response) => {
752 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
753 if reported_issuer != *issuer {
754 bail!(
755 "Auth server metadata issuer mismatch: expected {}, got {}",
756 issuer,
757 reported_issuer
758 );
759 }
760
761 return Ok(AuthServerMetadata {
762 issuer: reported_issuer,
763 authorization_endpoint: response
764 .authorization_endpoint
765 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
766 token_endpoint: response
767 .token_endpoint
768 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
769 registration_endpoint: response.registration_endpoint,
770 scopes_supported: response.scopes_supported,
771 code_challenge_methods_supported: response.code_challenge_methods_supported,
772 client_id_metadata_document_supported: response
773 .client_id_metadata_document_supported
774 .unwrap_or(false),
775 });
776 }
777 Err(err) => {
778 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
779 }
780 }
781 }
782
783 bail!(
784 "Could not fetch Authorization Server Metadata for {}",
785 issuer
786 )
787}
788
789/// Run the full discovery flow: fetch resource metadata, then auth server
790/// metadata, then select scopes. Client registration is resolved separately,
791/// once the real redirect URI is known.
792pub async fn discover(
793 http_client: &Arc<dyn HttpClient>,
794 server_url: &Url,
795 www_authenticate: &WwwAuthenticate,
796) -> Result<OAuthDiscovery> {
797 let resource_metadata =
798 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
799
800 let auth_server_url = resource_metadata
801 .authorization_servers
802 .first()
803 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
804
805 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
806
807 // Verify PKCE S256 support (spec requirement).
808 match &auth_server_metadata.code_challenge_methods_supported {
809 Some(methods) if methods.iter().any(|m| m == "S256") => {}
810 Some(_) => bail!("authorization server does not support S256 PKCE"),
811 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
812 }
813
814 // Verify there is at least one supported registration strategy before we
815 // present the server as ready to authenticate.
816 match determine_registration_strategy(&auth_server_metadata) {
817 ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
818 ClientRegistrationStrategy::Unavailable => {
819 bail!("authorization server supports neither CIMD nor DCR")
820 }
821 }
822
823 let scopes = select_scopes(www_authenticate, &resource_metadata);
824
825 Ok(OAuthDiscovery {
826 resource_metadata,
827 auth_server_metadata,
828 scopes,
829 })
830}
831
832/// Resolve the OAuth client registration for an authorization flow.
833///
834/// CIMD uses the static client metadata document directly. For DCR, a fresh
835/// registration is performed each time because the loopback redirect URI
836/// includes an ephemeral port that changes every flow.
837pub async fn resolve_client_registration(
838 http_client: &Arc<dyn HttpClient>,
839 discovery: &OAuthDiscovery,
840 redirect_uri: &str,
841) -> Result<OAuthClientRegistration> {
842 match determine_registration_strategy(&discovery.auth_server_metadata) {
843 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
844 client_id,
845 client_secret: None,
846 }),
847 ClientRegistrationStrategy::Dcr {
848 registration_endpoint,
849 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
850 ClientRegistrationStrategy::Unavailable => {
851 bail!("authorization server supports neither CIMD nor DCR")
852 }
853 }
854}
855
856// -- Dynamic Client Registration (RFC 7591) ----------------------------------
857
858/// Perform Dynamic Client Registration with the authorization server.
859pub async fn perform_dcr(
860 http_client: &Arc<dyn HttpClient>,
861 registration_endpoint: &Url,
862 redirect_uri: &str,
863) -> Result<OAuthClientRegistration> {
864 validate_oauth_url(registration_endpoint)?;
865
866 let body = dcr_registration_body(redirect_uri);
867 let body_bytes = serde_json::to_vec(&body)?;
868
869 let request = Request::builder()
870 .method(http_client::http::Method::POST)
871 .uri(registration_endpoint.as_str())
872 .header("Content-Type", "application/json")
873 .header("Accept", "application/json")
874 .body(AsyncBody::from(body_bytes))?;
875
876 let mut response = http_client.send(request).await?;
877
878 if !response.status().is_success() {
879 let mut error_body = String::new();
880 response.body_mut().read_to_string(&mut error_body).await?;
881 bail!(
882 "DCR failed with status {}: {}",
883 response.status(),
884 error_body
885 );
886 }
887
888 let mut response_body = String::new();
889 response
890 .body_mut()
891 .read_to_string(&mut response_body)
892 .await?;
893
894 let dcr_response: DcrResponse =
895 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
896
897 Ok(OAuthClientRegistration {
898 client_id: dcr_response.client_id,
899 client_secret: dcr_response.client_secret,
900 })
901}
902
903// -- Token exchange and refresh (async) --------------------------------------
904
905/// Exchange an authorization code for tokens at the token endpoint.
906pub async fn exchange_code(
907 http_client: &Arc<dyn HttpClient>,
908 auth_server_metadata: &AuthServerMetadata,
909 code: &str,
910 client_id: &str,
911 redirect_uri: &str,
912 code_verifier: &str,
913 resource: &str,
914) -> Result<OAuthTokens> {
915 let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
916 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
917}
918
919/// Refresh tokens using a refresh token.
920pub async fn refresh_tokens(
921 http_client: &Arc<dyn HttpClient>,
922 token_endpoint: &Url,
923 refresh_token: &str,
924 client_id: &str,
925 resource: &str,
926) -> Result<OAuthTokens> {
927 let params = token_refresh_params(refresh_token, client_id, resource);
928 post_token_request(http_client, token_endpoint, ¶ms).await
929}
930
931/// POST form-encoded parameters to a token endpoint and parse the response.
932async fn post_token_request(
933 http_client: &Arc<dyn HttpClient>,
934 token_endpoint: &Url,
935 params: &[(&str, String)],
936) -> Result<OAuthTokens> {
937 validate_oauth_url(token_endpoint)?;
938
939 let body = url::form_urlencoded::Serializer::new(String::new())
940 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
941 .finish();
942
943 let request = Request::builder()
944 .method(http_client::http::Method::POST)
945 .uri(token_endpoint.as_str())
946 .header("Content-Type", "application/x-www-form-urlencoded")
947 .header("Accept", "application/json")
948 .body(AsyncBody::from(body.into_bytes()))?;
949
950 let mut response = http_client.send(request).await?;
951
952 if !response.status().is_success() {
953 let mut error_body = String::new();
954 response.body_mut().read_to_string(&mut error_body).await?;
955 bail!(
956 "token request failed with status {}: {}",
957 response.status(),
958 error_body
959 );
960 }
961
962 let mut response_body = String::new();
963 response
964 .body_mut()
965 .read_to_string(&mut response_body)
966 .await?;
967
968 let token_response: TokenResponse =
969 serde_json::from_str(&response_body).context("failed to parse token response")?;
970
971 Ok(token_response.into_tokens())
972}
973
974// -- Loopback HTTP callback server -------------------------------------------
975
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// `Debug` is implemented by hand for this type and prints both fields as
/// `[redacted]`.
pub struct OAuthCallback {
    /// Value of the `code` query parameter of the callback.
    pub code: String,
    /// Value of the `state` query parameter of the callback.
    pub state: String,
}
981
982impl std::fmt::Debug for OAuthCallback {
983 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
984 f.debug_struct("OAuthCallback")
985 .field("code", &"[redacted]")
986 .field("state", &"[redacted]")
987 .finish()
988 }
989}
990
991impl OAuthCallback {
992 /// Parse the query string from a callback URL like
993 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
994 pub fn parse_query(query: &str) -> Result<Self> {
995 let mut code: Option<String> = None;
996 let mut state: Option<String> = None;
997 let mut error: Option<String> = None;
998 let mut error_description: Option<String> = None;
999
1000 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1001 match key.as_ref() {
1002 "code" => {
1003 if !value.is_empty() {
1004 code = Some(value.into_owned());
1005 }
1006 }
1007 "state" => {
1008 if !value.is_empty() {
1009 state = Some(value.into_owned());
1010 }
1011 }
1012 "error" => {
1013 if !value.is_empty() {
1014 error = Some(value.into_owned());
1015 }
1016 }
1017 "error_description" => {
1018 if !value.is_empty() {
1019 error_description = Some(value.into_owned());
1020 }
1021 }
1022 _ => {}
1023 }
1024 }
1025
1026 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1027 // checking for missing code/state.
1028 if let Some(error_code) = error {
1029 bail!(
1030 "OAuth authorization failed: {} ({})",
1031 error_code,
1032 error_description.as_deref().unwrap_or("no description")
1033 );
1034 }
1035
1036 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1037 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1038
1039 Ok(Self { code, state })
1040 }
1041}
1042
/// How long to wait for the browser to complete the OAuth flow before giving
/// up and releasing the loopback port (two minutes).
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
1046
1047/// Start a loopback HTTP server to receive the OAuth authorization callback.
1048///
1049/// Binds to an ephemeral loopback port for each flow.
1050///
1051/// Returns `(redirect_uri, callback_future)`. The caller should use the
1052/// redirect URI in the authorization request, open the browser, then await
1053/// the future to receive the callback.
1054///
1055/// The server accepts exactly one request on `/callback`, validates that it
1056/// contains `code` and `state` query parameters, responds with a minimal
1057/// HTML page telling the user they can close the tab, and shuts down.
1058///
1059/// The callback server shuts down when the returned oneshot receiver is dropped
1060/// (e.g. because the authentication task was cancelled), or after a timeout
1061/// ([CALLBACK_TIMEOUT]).
1062pub async fn start_callback_server() -> Result<(
1063 String,
1064 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1065)> {
1066 let server = tiny_http::Server::http("127.0.0.1:0")
1067 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1068 let port = server
1069 .server_addr()
1070 .to_ip()
1071 .context("server not bound to a TCP address")?
1072 .port();
1073
1074 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1075
1076 let (tx, rx) = futures::channel::oneshot::channel();
1077
1078 // `tiny_http` is blocking, so we run it on a background thread.
1079 // The `recv_timeout` loop lets us check for cancellation (the receiver
1080 // being dropped) and enforce an overall timeout.
1081 std::thread::spawn(move || {
1082 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1083
1084 loop {
1085 if tx.is_canceled() {
1086 return;
1087 }
1088 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1089 if remaining.is_zero() {
1090 return;
1091 }
1092
1093 let timeout = remaining.min(Duration::from_millis(500));
1094 let Some(request) = (match server.recv_timeout(timeout) {
1095 Ok(req) => req,
1096 Err(_) => {
1097 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1098 return;
1099 }
1100 }) else {
1101 // Timeout with no request — loop back and check cancellation.
1102 continue;
1103 };
1104
1105 let result = handle_callback_request(&request);
1106
1107 let (status_code, body) = match &result {
1108 Ok(_) => (
1109 200,
1110 http_client::oauth_callback_page(
1111 "Authorization Successful",
1112 "You can close this tab and return to Zed.",
1113 ),
1114 ),
1115 Err(err) => {
1116 log::error!("OAuth callback error: {}", err);
1117 (
1118 400,
1119 http_client::oauth_callback_page(
1120 "Authorization Failed",
1121 "Something went wrong. Please try again from Zed.",
1122 ),
1123 )
1124 }
1125 };
1126
1127 let response = tiny_http::Response::from_string(body)
1128 .with_status_code(status_code)
1129 .with_header(
1130 tiny_http::Header::from_str("Content-Type: text/html")
1131 .expect("failed to construct response header"),
1132 )
1133 .with_header(
1134 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1135 .expect("failed to construct response header"),
1136 );
1137 request.respond(response).log_err();
1138
1139 let _ = tx.send(result);
1140 return;
1141 }
1142 });
1143
1144 Ok((redirect_uri, rx))
1145}
1146
1147/// Extract the `code` and `state` query parameters from an OAuth callback
1148/// request to `/callback`.
1149fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1150 let url = Url::parse(&format!("http://localhost{}", request.url()))
1151 .context("malformed callback request URL")?;
1152
1153 if url.path() != "/callback" {
1154 bail!("unexpected path in OAuth callback: {}", url.path());
1155 }
1156
1157 let query = url
1158 .query()
1159 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1160 OAuthCallback::parse_query(query)
1161}
1162
1163// -- JSON fetch helper -------------------------------------------------------
1164
1165async fn fetch_json<T: serde::de::DeserializeOwned>(
1166 http_client: &Arc<dyn HttpClient>,
1167 url: &Url,
1168) -> Result<T> {
1169 validate_oauth_url(url)?;
1170
1171 let request = Request::builder()
1172 .method(http_client::http::Method::GET)
1173 .uri(url.as_str())
1174 .header("Accept", "application/json")
1175 .body(AsyncBody::default())?;
1176
1177 let mut response = http_client.send(request).await?;
1178
1179 if !response.status().is_success() {
1180 bail!("HTTP {} fetching {}", response.status(), url);
1181 }
1182
1183 let mut body = String::new();
1184 response.body_mut().read_to_string(&mut body).await?;
1185 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1186}
1187
1188// -- Serde response types for discovery --------------------------------------
1189
/// Deserialization target for a Protected Resource Metadata document.
///
/// Every field carries `#[serde(default)]`, so a document that omits any of
/// them still deserializes.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// The `resource` member, if present.
    #[serde(default)]
    resource: Option<Url>,
    /// The `authorization_servers` member; empty when absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// The `scopes_supported` member, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1199
/// Deserialization target for an Authorization Server Metadata document.
///
/// Every field is optional at this layer (`#[serde(default)]` on `Option`s),
/// so missing members deserialize as `None` rather than failing the parse.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// The `issuer` member, if present.
    #[serde(default)]
    issuer: Option<Url>,
    /// The `authorization_endpoint` member, if present.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// The `token_endpoint` member, if present.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// The `registration_endpoint` member, if present.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// The `scopes_supported` member, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// The `code_challenge_methods_supported` member, if present.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// The `client_id_metadata_document_supported` member, if present.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1217
/// Deserialization target for a client registration response.
///
/// `client_id` is required for the parse to succeed; `client_secret` is
/// optional and defaults to `None`.
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// The `client_id` member (required).
    client_id: String,
    /// The `client_secret` member, if the server returned one.
    #[serde(default)]
    client_secret: Option<String>,
}
1224
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// This is a synchronous call; `None` means there is no token to attach.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1238
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// Current session; its `tokens` are replaced after a successful refresh.
    session: SyncMutex<OAuthSession>,
    /// Client used to send the token refresh request.
    http_client: Arc<dyn HttpClient>,
    /// When set, receives a clone of the session after each successful refresh.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1248
1249impl McpOAuthTokenProvider {
1250 pub fn new(
1251 session: OAuthSession,
1252 http_client: Arc<dyn HttpClient>,
1253 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1254 ) -> Self {
1255 Self {
1256 session: SyncMutex::new(session),
1257 http_client,
1258 token_refresh_tx,
1259 }
1260 }
1261
1262 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1263 tokens.expires_at.is_some_and(|expires_at| {
1264 SystemTime::now()
1265 .checked_add(Duration::from_secs(30))
1266 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1267 })
1268 }
1269}
1270
1271#[async_trait]
1272impl OAuthTokenProvider for McpOAuthTokenProvider {
1273 fn access_token(&self) -> Option<String> {
1274 let session = self.session.lock();
1275 if Self::access_token_is_expired(&session.tokens) {
1276 return None;
1277 }
1278 Some(session.tokens.access_token.clone())
1279 }
1280
1281 async fn try_refresh(&self) -> Result<bool> {
1282 let (refresh_token, token_endpoint, resource, client_id) = {
1283 let session = self.session.lock();
1284 match session.tokens.refresh_token.clone() {
1285 Some(refresh_token) => (
1286 refresh_token,
1287 session.token_endpoint.clone(),
1288 session.resource.clone(),
1289 session.client_registration.client_id.clone(),
1290 ),
1291 None => return Ok(false),
1292 }
1293 };
1294
1295 let resource_str = canonical_server_uri(&resource);
1296
1297 match refresh_tokens(
1298 &self.http_client,
1299 &token_endpoint,
1300 &refresh_token,
1301 &client_id,
1302 &resource_str,
1303 )
1304 .await
1305 {
1306 Ok(mut new_tokens) => {
1307 if new_tokens.refresh_token.is_none() {
1308 new_tokens.refresh_token = Some(refresh_token);
1309 }
1310
1311 {
1312 let mut session = self.session.lock();
1313 session.tokens = new_tokens;
1314
1315 if let Some(ref tx) = self.token_refresh_tx {
1316 tx.unbounded_send(session.clone()).ok();
1317 }
1318 }
1319
1320 Ok(true)
1321 }
1322 Err(err) => {
1323 log::warn!("OAuth token refresh failed: {}", err);
1324 Ok(false)
1325 }
1326 }
1327 }
1328}
1329
1330#[cfg(test)]
1331mod tests {
1332 use super::*;
1333 use http_client::Response;
1334
1335 // -- require_https_or_loopback tests ------------------------------------
1336
// A remote `https` endpoint passes the scheme check.
#[test]
fn test_require_https_or_loopback_accepts_https() {
    let endpoint = Url::parse("https://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1342
// Plain `http` to a remote host name is refused.
#[test]
fn test_require_https_or_loopback_rejects_http_remote() {
    let endpoint = Url::parse("http://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1348
// Plain `http` is allowed for the IPv4 loopback address.
#[test]
fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
    let endpoint = Url::parse("http://127.0.0.1:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1354
// Plain `http` is allowed for the IPv6 loopback address.
#[test]
fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
    let endpoint = Url::parse("http://[::1]:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1360
// Plain `http` is allowed for the `localhost` host name.
#[test]
fn test_require_https_or_loopback_accepts_http_localhost() {
    let endpoint = Url::parse("http://localhost:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1366
// An upper-case `LOCALHOST` is accepted just like the lower-case spelling.
#[test]
fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
    let endpoint = Url::parse("http://LOCALHOST:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1372
// Plain `http` to a non-loopback IP literal is refused.
#[test]
fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
    let endpoint = Url::parse("http://192.168.1.1:8080/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1378
// Schemes other than `http`/`https` are refused.
#[test]
fn test_require_https_or_loopback_rejects_ftp() {
    let endpoint = Url::parse("ftp://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1384
1385 // -- validate_oauth_url (SSRF) tests ------------------------------------
1386
// A public `https` host name is accepted.
#[test]
fn test_validate_oauth_url_accepts_https_public() {
    let target = Url::parse("https://auth.example.com/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_ok());
}
1392
// `10.0.0.1` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_10() {
    let target = Url::parse("https://10.0.0.1/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1398
// `172.16.0.1` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_172() {
    let target = Url::parse("https://172.16.0.1/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1404
// `192.168.1.1` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_192() {
    let target = Url::parse("https://192.168.1.1/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1410
// The link-local address `169.254.169.254` is rejected.
#[test]
fn test_validate_oauth_url_rejects_link_local() {
    let target = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1416
// An IPv6 unique-local address (`fd12:...`) is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv6_ula() {
    let target = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1422
// The IPv6 unspecified address `::` is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv6_unspecified() {
    let target = Url::parse("https://[::]/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1428
// A private IPv4 address written as an IPv4-mapped IPv6 literal is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
    let target = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1434
// A link-local IPv4 address written as an IPv4-mapped IPv6 literal is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
    let target = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_err());
}
1440
// Loopback over plain `http` is permitted (it's our callback server).
#[test]
fn test_validate_oauth_url_allows_http_loopback() {
    let target = Url::parse("http://127.0.0.1:8080/callback").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_ok());
}
1447
// A public IPv4 literal over `https` is accepted.
#[test]
fn test_validate_oauth_url_allows_https_public_ip() {
    let target = Url::parse("https://93.184.216.34/token").unwrap();
    let outcome = validate_oauth_url(&target);
    assert!(outcome.is_ok());
}
1453
1454 // -- parse_www_authenticate tests ----------------------------------------
1455
// Both `resource_metadata` and `scope` are extracted; no error is reported.
#[test]
fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
    let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
    let parsed = parse_www_authenticate(header).unwrap();

    let metadata_url = parsed.resource_metadata.as_ref().map(|link| link.as_str());
    assert_eq!(
        metadata_url,
        Some("https://mcp.example.com/.well-known/oauth-protected-resource")
    );

    let expected_scope = vec!["files:read".to_string(), "user:profile".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
    assert_eq!(parsed.error, None);
}
1471
// With only `resource_metadata` present, `scope` stays `None`.
#[test]
fn test_parse_www_authenticate_resource_metadata_only() {
    let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
    let parsed = parse_www_authenticate(header).unwrap();

    let metadata_url = parsed.resource_metadata.as_ref().map(|link| link.as_str());
    assert_eq!(
        metadata_url,
        Some("https://mcp.example.com/.well-known/oauth-protected-resource")
    );
    assert_eq!(parsed.scope, None);
}
1483
// A bare `Bearer` challenge parses and yields no metadata URL and no scope.
#[test]
fn test_parse_www_authenticate_bare_bearer() {
    let parsed = parse_www_authenticate("Bearer").unwrap();
    assert!(parsed.resource_metadata.is_none());
    assert!(parsed.scope.is_none());
}
1490
// `error`, `error_description`, `scope` and `resource_metadata` are all
// picked up from a single challenge.
#[test]
fn test_parse_www_authenticate_with_error() {
    let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
    let parsed = parse_www_authenticate(header).unwrap();

    assert_eq!(parsed.error, Some(BearerError::InsufficientScope));

    let description = parsed.error_description.as_deref();
    assert_eq!(
        description,
        Some("Additional file write permission required")
    );

    let expected_scope = vec!["files:read".to_string(), "files:write".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
    assert!(parsed.resource_metadata.is_some());
}
1507
// `error="invalid_token"` maps to `BearerError::InvalidToken`.
#[test]
fn test_parse_www_authenticate_invalid_token_error() {
    let header =
        r#"Bearer error="invalid_token", error_description="The access token expired""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::InvalidToken);
    assert_eq!(parsed.error, expected);
}
1515
// `error="invalid_request"` maps to `BearerError::InvalidRequest`.
#[test]
fn test_parse_www_authenticate_invalid_request_error() {
    let header = r#"Bearer error="invalid_request""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::InvalidRequest);
    assert_eq!(parsed.error, expected);
}
1522
// An unrecognized error code maps to `BearerError::Other`.
#[test]
fn test_parse_www_authenticate_unknown_error() {
    let header = r#"Bearer error="some_future_error""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::Other);
    assert_eq!(parsed.error, expected);
}
1529
// A `Basic` challenge is not accepted.
#[test]
fn test_parse_www_authenticate_rejects_non_bearer() {
    assert!(parse_www_authenticate("Basic realm=\"example\"").is_err());
}
1535
// A lower-case `bearer` scheme is parsed like `Bearer`.
#[test]
fn test_parse_www_authenticate_case_insensitive_scheme() {
    let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let has_metadata = parsed.resource_metadata.is_some();
    assert!(has_metadata);
}
1542
// Some servers emit the header spread across multiple lines joined by
// whitespace, as shown in the spec examples; parameters after the line break
// must still be found.
#[test]
fn test_parse_www_authenticate_multiline_style() {
    let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
    let parsed = parse_www_authenticate(header).unwrap();

    assert!(parsed.resource_metadata.is_some());

    let expected_scope = vec!["files:read".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
}
1552
// A server URL with a path yields two candidates: the path-suffixed
// well-known URL first, then the root well-known URL.
#[test]
fn test_protected_resource_metadata_urls_with_path() {
    let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
    let candidates = protected_resource_metadata_urls(&server_url);
    let rendered: Vec<&str> = candidates.iter().map(|candidate| candidate.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp",
            "https://api.example.com/.well-known/oauth-protected-resource",
        ]
    );
}
1568
// A server URL without a path yields only the root well-known URL.
#[test]
fn test_protected_resource_metadata_urls_without_path() {
    let server_url = Url::parse("https://mcp.example.com").unwrap();
    let candidates = protected_resource_metadata_urls(&server_url);
    let rendered: Vec<&str> = candidates.iter().map(|candidate| candidate.as_str()).collect();

    assert_eq!(
        rendered,
        ["https://mcp.example.com/.well-known/oauth-protected-resource"]
    );
}
1580
// An issuer with a path yields three candidates, in this order.
#[test]
fn test_auth_server_metadata_urls_with_path() {
    let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
    let candidates = auth_server_metadata_urls(&issuer);
    let rendered: Vec<&str> = candidates.iter().map(|candidate| candidate.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1",
            "https://auth.example.com/.well-known/openid-configuration/tenant1",
            "https://auth.example.com/tenant1/.well-known/openid-configuration",
        ]
    );
}
1600
// An issuer without a path yields two candidates, in this order.
#[test]
fn test_auth_server_metadata_urls_without_path() {
    let issuer = Url::parse("https://auth.example.com").unwrap();
    let candidates = auth_server_metadata_urls(&issuer);
    let rendered: Vec<&str> = candidates.iter().map(|candidate| candidate.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://auth.example.com/.well-known/oauth-authorization-server",
            "https://auth.example.com/.well-known/openid-configuration",
        ]
    );
}
1616
1617 // -- Canonical server URI tests ------------------------------------------
1618
// A bare origin is returned without a trailing slash.
#[test]
fn test_canonical_server_uri_simple() {
    let server = Url::parse("https://mcp.example.com").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com");
}
1624
// A path component is kept as-is.
#[test]
fn test_canonical_server_uri_with_path() {
    let server = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com/v1/mcp");
}
1630
// A lone trailing slash is removed.
#[test]
fn test_canonical_server_uri_strips_trailing_slash() {
    let server = Url::parse("https://mcp.example.com/").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com");
}
1636
// An explicit non-default port is kept.
#[test]
fn test_canonical_server_uri_preserves_port() {
    let server = Url::parse("https://mcp.example.com:8443").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com:8443");
}
1642
// Scheme and host come out lower-cased; the path keeps its case.
#[test]
fn test_canonical_server_uri_lowercases() {
    let server = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com/Server/MCP");
}
1651
1652 // -- Scope selection tests -----------------------------------------------
1653
// When the challenge carries a scope, it wins over the resource metadata.
#[test]
fn test_select_scopes_prefers_www_authenticate() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: Some(vec!["files:read".into()]),
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
    };

    let chosen = select_scopes(&challenge, &metadata);
    assert_eq!(chosen, vec!["files:read"]);
}
1669
// Without a scope in the challenge, `scopes_supported` is used.
#[test]
fn test_select_scopes_falls_back_to_resource_metadata() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: Some(vec!["admin".into()]),
    };

    let chosen = select_scopes(&challenge, &metadata);
    assert_eq!(chosen, vec!["admin"]);
}
1685
// With no scope from either source, the selection is empty.
#[test]
fn test_select_scopes_empty_when_nothing_available() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: None,
    };

    let chosen = select_scopes(&challenge, &metadata);
    assert!(chosen.is_empty());
}
1701
1702 // -- Client registration strategy tests ----------------------------------
1703
// CIMD is chosen when supported, even if a registration endpoint exists.
#[test]
fn test_registration_strategy_prefers_cimd() {
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };

    let expected = ClientRegistrationStrategy::Cimd {
        client_id: CIMD_URL.to_string(),
    };
    assert_eq!(determine_registration_strategy(&server), expected);
}
1722
// Without CIMD support, a registration endpoint selects DCR.
#[test]
fn test_registration_strategy_falls_back_to_dcr() {
    let register_url = Url::parse("https://auth.example.com/register").unwrap();
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: Some(register_url.clone()),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    let expected = ClientRegistrationStrategy::Dcr {
        registration_endpoint: register_url,
    };
    assert_eq!(determine_registration_strategy(&server), expected);
}
1742
// Neither CIMD support nor a registration endpoint: no strategy available.
#[test]
fn test_registration_strategy_unavailable() {
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    let strategy = determine_registration_strategy(&server);
    assert_eq!(strategy, ClientRegistrationStrategy::Unavailable);
}
1759
1760 // -- PKCE tests ----------------------------------------------------------
1761
// 32 random bytes → 43 base64url chars (no padding).
#[test]
fn test_pkce_challenge_verifier_length() {
    let generated = generate_pkce_challenge();
    let verifier_len = generated.verifier.len();
    assert_eq!(verifier_len, 43);
}
1768
// Verifier and challenge contain only ASCII alphanumerics, `-` and `_`.
#[test]
fn test_pkce_challenge_is_valid_base64url() {
    let generated = generate_pkce_challenge();
    let combined = generated.verifier.chars().chain(generated.challenge.chars());
    for ch in combined {
        let allowed = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');
        assert!(allowed, "invalid base64url character: {}", ch);
    }
}
1780
// The challenge equals base64url(SHA-256(verifier)) without padding.
#[test]
fn test_pkce_challenge_is_s256_of_verifier() {
    let generated = generate_pkce_challenge();
    let digest = Sha256::digest(generated.verifier.as_bytes());
    let recomputed = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
    assert_eq!(generated.challenge, recomputed);
}
1789
// Two consecutive generations produce different verifiers.
#[test]
fn test_pkce_challenges_are_unique() {
    let first = generate_pkce_challenge();
    let second = generate_pkce_challenge();
    assert_ne!(first.verifier, second.verifier);
}
1796
1797 // -- Authorization URL tests ---------------------------------------------
1798
// Every expected query parameter appears on the authorization URL with the
// value that was passed in.
#[test]
fn test_build_authorization_url() {
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };
    let pkce = PkceChallenge {
        verifier: "test_verifier".into(),
        challenge: "test_challenge".into(),
    };
    let authorize_url = build_authorization_url(
        &server,
        "https://zed.dev/oauth/client-metadata.json",
        "http://127.0.0.1:12345/callback",
        &["files:read".into(), "files:write".into()],
        "https://mcp.example.com",
        &pkce,
        "random_state_123",
    );

    let pairs: std::collections::HashMap<_, _> = authorize_url.query_pairs().collect();
    let expected = [
        ("response_type", "code"),
        ("client_id", "https://zed.dev/oauth/client-metadata.json"),
        ("redirect_uri", "http://127.0.0.1:12345/callback"),
        ("scope", "files:read files:write"),
        ("resource", "https://mcp.example.com"),
        ("code_challenge", "test_challenge"),
        ("code_challenge_method", "S256"),
        ("state", "random_state_123"),
    ];
    for (name, value) in expected {
        assert_eq!(pairs.get(name).unwrap(), value);
    }
}
1840
// With an empty scope list, no `scope` parameter is emitted at all.
#[test]
fn test_build_authorization_url_omits_empty_scope() {
    let server = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };
    let pkce = PkceChallenge {
        verifier: "v".into(),
        challenge: "c".into(),
    };
    let authorize_url = build_authorization_url(
        &server,
        "client_123",
        "http://127.0.0.1:9999/callback",
        &[],
        "https://mcp.example.com",
        &pkce,
        "state",
    );

    let pairs: std::collections::HashMap<_, _> = authorize_url.query_pairs().collect();
    let has_scope = pairs.contains_key("scope");
    assert!(!has_scope);
}
1869
1870 // -- Token exchange / refresh param tests --------------------------------
1871
// The authorization-code grant carries all six expected form fields.
#[test]
fn test_token_exchange_params() {
    let params = token_exchange_params(
        "auth_code_abc",
        "client_xyz",
        "http://127.0.0.1:5555/callback",
        "verifier_123",
        "https://mcp.example.com",
    );
    let by_name: std::collections::HashMap<&str, &str> =
        params.iter().map(|(k, v)| (*k, v.as_str())).collect();

    let expected = [
        ("grant_type", "authorization_code"),
        ("code", "auth_code_abc"),
        ("redirect_uri", "http://127.0.0.1:5555/callback"),
        ("client_id", "client_xyz"),
        ("code_verifier", "verifier_123"),
        ("resource", "https://mcp.example.com"),
    ];
    for (name, value) in expected {
        assert_eq!(by_name[name], value);
    }
}
1891
// The refresh grant carries all four expected form fields.
#[test]
fn test_token_refresh_params() {
    let params =
        token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
    let by_name: std::collections::HashMap<&str, &str> =
        params.iter().map(|(k, v)| (*k, v.as_str())).collect();

    let expected = [
        ("grant_type", "refresh_token"),
        ("refresh_token", "refresh_token_abc"),
        ("client_id", "client_xyz"),
        ("resource", "https://mcp.example.com"),
    ];
    for (name, value) in expected {
        assert_eq!(by_name[name], value);
    }
}
1904
1905 // -- Token response tests ------------------------------------------------
1906
// A full token response keeps both tokens and produces an expiry time.
#[test]
fn test_token_response_into_tokens_with_expiry() {
    let payload = r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#;
    let response: TokenResponse = serde_json::from_str(payload).unwrap();

    let converted = response.into_tokens();
    assert_eq!(converted.access_token, "at_123");
    assert_eq!(converted.refresh_token.as_deref(), Some("rt_456"));
    assert!(converted.expires_at.is_some());
}
1919
/// A token response containing only `access_token` converts into tokens
/// with no refresh token and no expiry.
#[test]
fn test_token_response_into_tokens_minimal() {
    const RESPONSE_JSON: &str = r#"{"access_token": "at_789"}"#;

    let parsed = serde_json::from_str::<TokenResponse>(RESPONSE_JSON).unwrap();
    let tokens = parsed.into_tokens();

    assert_eq!(tokens.access_token, "at_789");
    assert!(tokens.refresh_token.is_none());
    assert!(tokens.expires_at.is_none());
}
1930
1931 // -- DCR body test -------------------------------------------------------
1932
/// The dynamic client registration body must describe a public
/// authorization-code client named "Zed" with the given redirect URI.
#[test]
fn test_dcr_registration_body_shape() {
    let redirect_uri = "http://127.0.0.1:12345/callback";
    let registration = dcr_registration_body(redirect_uri);

    assert_eq!(registration["redirect_uris"][0], redirect_uri);
    assert_eq!(registration["client_name"], "Zed");
    assert_eq!(registration["token_endpoint_auth_method"], "none");
    assert_eq!(registration["grant_types"][0], "authorization_code");
    assert_eq!(registration["response_types"][0], "code");
}
1942
1943 // -- Test helpers for async/HTTP tests -----------------------------------
1944
1945 fn make_fake_http_client(
1946 handler: impl Fn(
1947 http_client::Request<AsyncBody>,
1948 ) -> std::pin::Pin<
1949 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1950 > + Send
1951 + Sync
1952 + 'static,
1953 ) -> Arc<dyn HttpClient> {
1954 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1955 }
1956
1957 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1958 Ok(Response::builder()
1959 .status(status)
1960 .header("Content-Type", "application/json")
1961 .body(AsyncBody::from(body.as_bytes().to_vec()))
1962 .unwrap())
1963 }
1964
1965 // -- Discovery integration tests -----------------------------------------
1966
/// With no `resource_metadata` hint in the challenge, the metadata is
/// fetched from the well-known protected-resource path and parsed.
#[test]
fn test_fetch_protected_resource_metadata() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["read", "write"]
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-protected-resource") {
                        (200, RESOURCE_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
            .await
            .unwrap();

        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");

        assert_eq!(metadata.authorization_servers.len(), 1);
        let first_server = metadata.authorization_servers[0].as_str();
        assert_eq!(first_server, "https://auth.example.com/");

        let expected_scopes = vec![String::from("read"), String::from("write")];
        assert_eq!(metadata.scopes_supported, Some(expected_scopes));
    });
}
2012
/// When the challenge names a same-origin `resource_metadata` URL, that
/// URL is fetched; the fake server answers 500 for anything else.
#[test]
fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
    const CUSTOM_METADATA_URL: &str = "https://mcp.example.com/custom-resource-metadata";
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = if requested == CUSTOM_METADATA_URL {
                    (200, RESOURCE_METADATA)
                } else {
                    (500, r#"{"error": "should not be called"}"#)
                };
                json_response(status, body)
            })
        });

        let hinted_url = Url::parse(CUSTOM_METADATA_URL).unwrap();
        let www_auth = WwwAuthenticate {
            resource_metadata: Some(hinted_url),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
            .await
            .unwrap();

        assert_eq!(metadata.authorization_servers.len(), 1);
    });
}
2050
/// A `resource_metadata` hint on a different origin must be ignored: the
/// fake server panics if the attacker URL is requested and only serves
/// metadata from the well-known path.
#[test]
fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();

                // Any request to the attacker's origin fails the test outright;
                // only the fallback at the server's own origin may be fetched.
                if requested.contains("attacker.example.com") {
                    panic!("should not fetch cross-origin resource_metadata URL");
                }

                let (status, body) =
                    if requested.contains(".well-known/oauth-protected-resource") {
                        (200, RESOURCE_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let cross_origin_hint = Url::parse("https://attacker.example.com/fake-metadata").unwrap();
        let www_auth = WwwAuthenticate {
            resource_metadata: Some(cross_origin_hint),
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
            .await
            .unwrap();

        // The result must come from the well-known fallback, not the attacker.
        assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
    });
}
2093
/// Authorization server metadata served from the OAuth well-known path is
/// parsed into its endpoints and capability flags.
#[test]
fn test_fetch_auth_server_metadata() {
    const SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-authorization-server") {
                        (200, SERVER_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        // Endpoints.
        assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
        let authorization_endpoint = metadata.authorization_endpoint.as_str();
        assert_eq!(authorization_endpoint, "https://auth.example.com/authorize");
        let token_endpoint = metadata.token_endpoint.as_str();
        assert_eq!(token_endpoint, "https://auth.example.com/token");
        assert!(metadata.registration_endpoint.is_some());

        // Capabilities.
        assert!(metadata.client_id_metadata_document_supported);
        let expected_methods = vec![String::from("S256")];
        assert_eq!(
            metadata.code_challenge_methods_supported,
            Some(expected_methods)
        );
    });
}
2138
/// When only an `openid-configuration` document is served, the metadata is
/// still fetched, and the absent CIMD flag reads as `false`.
#[test]
fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
    const OIDC_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = if requested.contains("openid-configuration") {
                    (200, OIDC_METADATA)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();

        let authorization_endpoint = metadata.authorization_endpoint.as_str();
        assert_eq!(authorization_endpoint, "https://auth.example.com/authorize");
        assert!(!metadata.client_id_metadata_document_supported);
    });
}
2171
/// Metadata whose `issuer` differs from the issuer that was queried must be
/// rejected with an "issuer mismatch" error.
#[test]
fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
    // The document claims to belong to a different issuer than the one asked.
    const MISMATCHED_METADATA: &str = r#"{
                                "issuer": "https://evil.example.com",
                                "authorization_endpoint": "https://evil.example.com/authorize",
                                "token_endpoint": "https://evil.example.com/token",
                                "code_challenge_methods_supported": ["S256"]
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) =
                    if requested.contains(".well-known/oauth-authorization-server") {
                        (200, MISMATCHED_METADATA)
                    } else {
                        (404, "{}")
                    };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let error = fetch_auth_server_metadata(&client, &issuer)
            .await
            .expect_err("mismatched issuer should be rejected");

        let message = error.to_string();
        assert!(
            message.contains("issuer mismatch"),
            "unexpected error: {}",
            message
        );
    });
}
2207
2208 // -- Full discover integration tests -------------------------------------
2209
/// When the authorization server advertises CIMD support, discovery plus
/// client registration resolves to the CIMD URL as client id, with no secret.
#[test]
fn test_full_discover_with_cimd() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"],
                                "scopes_supported": ["mcp:read"]
                            }"#;
    const SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": true
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => (200, SERVER_METADATA),
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
        let redirect_uri = "http://127.0.0.1:12345/callback";
        let registration = resolve_client_registration(&client, &discovery, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, CIMD_URL);
        assert!(registration.client_secret.is_none());
        assert_eq!(discovery.scopes, vec!["mcp:read"]);
    });
}
2261
/// When CIMD is not supported but a registration endpoint exists, client
/// registration uses the credentials returned by that endpoint, and the
/// scopes come from the `WWW-Authenticate` challenge.
#[test]
fn test_full_discover_with_dcr_fallback() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;
    const SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token",
                                "registration_endpoint": "https://auth.example.com/register",
                                "code_challenge_methods_supported": ["S256"],
                                "client_id_metadata_document_supported": false
                            }"#;
    const REGISTRATION_RESPONSE: &str = r#"{
                                "client_id": "dcr-minted-id-123",
                                "client_secret": "dcr-secret-456"
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => (200, SERVER_METADATA),
                    uri if uri.contains("/register") => (201, REGISTRATION_RESPONSE),
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: Some(vec!["files:read".into()]),
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
        let redirect_uri = "http://127.0.0.1:9999/callback";
        let registration = resolve_client_registration(&client, &discovery, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, "dcr-minted-id-123");
        let secret = registration.client_secret.as_deref();
        assert_eq!(secret, Some("dcr-secret-456"));
        assert_eq!(discovery.scopes, vec!["files:read"]);
    });
}
2324
/// Discovery must fail, mentioning `code_challenge_methods_supported`, when
/// the authorization server metadata omits that field.
#[test]
fn test_discover_fails_without_pkce_support() {
    const RESOURCE_METADATA: &str = r#"{
                                "resource": "https://mcp.example.com",
                                "authorization_servers": ["https://auth.example.com"]
                            }"#;
    // No `code_challenge_methods_supported` entry.
    const SERVER_METADATA: &str = r#"{
                                "issuer": "https://auth.example.com",
                                "authorization_endpoint": "https://auth.example.com/authorize",
                                "token_endpoint": "https://auth.example.com/token"
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = match requested.as_str() {
                    uri if uri.contains("oauth-protected-resource") => (200, RESOURCE_METADATA),
                    uri if uri.contains("oauth-authorization-server") => (200, SERVER_METADATA),
                    _ => (404, "{}"),
                };
                json_response(status, body)
            })
        });

        let www_auth = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };
        let server_url = Url::parse("https://mcp.example.com").unwrap();

        let error = discover(&client, &server_url, &www_auth)
            .await
            .expect_err("discovery should fail without PKCE support");

        let message = error.to_string();
        assert!(
            message.contains("code_challenge_methods_supported"),
            "unexpected error: {}",
            message
        );
    });
}
2372
2373 // -- Token exchange integration tests ------------------------------------
2374
/// A 200 response from the token endpoint is turned into tokens carrying the
/// access token, refresh token, and an expiry.
#[test]
fn test_exchange_code_success() {
    const TOKEN_RESPONSE: &str = r#"{
                                "access_token": "new_access_token",
                                "refresh_token": "new_refresh_token",
                                "expires_in": 3600,
                                "token_type": "Bearer"
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = if requested.contains("/token") {
                    (200, TOKEN_RESPONSE)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let metadata = AuthServerMetadata {
            issuer,
            authorization_endpoint,
            token_endpoint,
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let tokens = exchange_code(
            &client,
            &metadata,
            "auth_code_123",
            CIMD_URL,
            "http://127.0.0.1:9999/callback",
            "verifier_abc",
            "https://mcp.example.com",
        )
        .await
        .unwrap();

        assert_eq!(tokens.access_token, "new_access_token");
        let refresh_token = tokens.refresh_token.as_deref();
        assert_eq!(refresh_token, Some("new_refresh_token"));
        assert!(tokens.expires_at.is_some());
    });
}
2424
/// A 200 refresh response without a `refresh_token` yields tokens with the
/// new access token, an expiry, and no refresh token.
#[test]
fn test_refresh_tokens_success() {
    const TOKEN_RESPONSE: &str = r#"{
                                "access_token": "refreshed_token",
                                "expires_in": 1800,
                                "token_type": "Bearer"
                            }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|req| {
            Box::pin(async move {
                let requested = req.uri().to_string();
                let (status, body) = if requested.contains("/token") {
                    (200, TOKEN_RESPONSE)
                } else {
                    (404, "{}")
                };
                json_response(status, body)
            })
        });

        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let refreshed = refresh_tokens(
            &client,
            &token_endpoint,
            "old_refresh_token",
            CIMD_URL,
            "https://mcp.example.com",
        )
        .await;
        let tokens = refreshed.unwrap();

        assert_eq!(tokens.access_token, "refreshed_token");
        assert!(tokens.refresh_token.is_none());
        assert!(tokens.expires_at.is_some());
    });
}
2463
/// A 400 from the token endpoint surfaces as an error whose message
/// includes the status code.
#[test]
fn test_exchange_code_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
        });

        let issuer = Url::parse("https://auth.example.com").unwrap();
        let authorization_endpoint = Url::parse("https://auth.example.com/authorize").unwrap();
        let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
        let metadata = AuthServerMetadata {
            issuer,
            authorization_endpoint,
            token_endpoint,
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let error = exchange_code(
            &client,
            &metadata,
            "bad_code",
            "client",
            "http://127.0.0.1:1/callback",
            "verifier",
            "https://mcp.example.com",
        )
        .await
        .expect_err("a 400 response should fail the exchange");

        assert!(error.to_string().contains("400"));
    });
}
2496
2497 // -- DCR integration tests -----------------------------------------------
2498
/// A 201 registration response is parsed into the client id and secret.
#[test]
fn test_perform_dcr() {
    const REGISTRATION_RESPONSE: &str = r#"{
                            "client_id": "dynamic-client-001",
                            "client_secret": "dynamic-secret-001"
                        }"#;

    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(201, REGISTRATION_RESPONSE) })
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let redirect_uri = "http://127.0.0.1:9999/callback";
        let registration = perform_dcr(&client, &endpoint, redirect_uri)
            .await
            .unwrap();

        assert_eq!(registration.client_id, "dynamic-client-001");
        let secret = registration.client_secret.as_deref();
        assert_eq!(secret, Some("dynamic-secret-001"));
    });
}
2526
/// A 403 from the registration endpoint surfaces as an error whose message
/// includes the status code.
#[test]
fn test_perform_dcr_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) })
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let error = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
            .await
            .expect_err("a 403 response should fail registration");

        assert!(error.to_string().contains("403"));
    });
}
2543
2544 // -- OAuthCallback parse tests -------------------------------------------
2545
/// A query with `code` followed by `state` parses into both fields.
#[test]
fn test_oauth_callback_parse_query() {
    let query = "code=test_auth_code&state=test_state";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2552
/// Parameter order does not matter: `state` before `code` parses the same.
#[test]
fn test_oauth_callback_parse_query_reversed_order() {
    let query = "state=test_state&code=test_auth_code";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2559
/// Unrecognized query parameters are tolerated alongside `code` and `state`.
#[test]
fn test_oauth_callback_parse_query_with_extra_params() {
    let query = "code=test_auth_code&state=test_state&extra=ignored";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test_state");
    assert_eq!(parsed.code, "test_auth_code");
}
2568
/// A query without `code` is rejected, and the error mentions "code".
#[test]
fn test_oauth_callback_parse_query_missing_code() {
    let error = OAuthCallback::parse_query("state=test_state")
        .expect_err("a query without a code should be rejected");

    assert!(error.to_string().contains("code"));
}
2575
/// A query without `state` is rejected, and the error mentions "state".
#[test]
fn test_oauth_callback_parse_query_missing_state() {
    let error = OAuthCallback::parse_query("code=test_auth_code")
        .expect_err("a query without a state should be rejected");

    assert!(error.to_string().contains("state"));
}
2582
/// An empty `code` value is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_code() {
    let query = "code=&state=test_state";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2588
/// An empty `state` value is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_state() {
    let query = "code=test_auth_code&state=";
    assert!(OAuthCallback::parse_query(query).is_err());
}
2594
/// Percent-encoded values are decoded: `%20` becomes a space and `%3D`
/// becomes `=`.
#[test]
fn test_oauth_callback_parse_query_url_encoded_values() {
    let query = "code=abc%20def&state=test%3Dstate";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.state, "test=state");
    assert_eq!(parsed.code, "abc def");
}
2601
/// An error callback is rejected, and the message includes both the error
/// code and the decoded description.
#[test]
fn test_oauth_callback_parse_query_error_response() {
    let query = "error=access_denied&error_description=User%20denied%20access&state=abc";
    let error = OAuthCallback::parse_query(query)
        .expect_err("an error callback should be rejected");

    let message = error.to_string();
    for fragment in ["access_denied", "User denied access"] {
        assert!(
            message.contains(fragment),
            "unexpected error: {}",
            message
        );
    }
}
2620
/// An error callback without `error_description` is rejected, and the
/// message includes the error code and a "no description" placeholder.
#[test]
fn test_oauth_callback_parse_query_error_without_description() {
    let query = "error=server_error&state=abc";
    let error = OAuthCallback::parse_query(query)
        .expect_err("an error callback should be rejected");

    let message = error.to_string();
    for fragment in ["server_error", "no description"] {
        assert!(
            message.contains(fragment),
            "unexpected error: {}",
            message
        );
    }
}
2637
2638 // -- McpOAuthTokenProvider tests -----------------------------------------
2639
2640 fn make_test_session(
2641 access_token: &str,
2642 refresh_token: Option<&str>,
2643 expires_at: Option<SystemTime>,
2644 ) -> OAuthSession {
2645 OAuthSession {
2646 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2647 resource: Url::parse("https://mcp.example.com").unwrap(),
2648 client_registration: OAuthClientRegistration {
2649 client_id: "test-client".into(),
2650 client_secret: None,
2651 },
2652 tokens: OAuthTokens {
2653 access_token: access_token.into(),
2654 refresh_token: refresh_token.map(String::from),
2655 expires_at,
2656 },
2657 }
2658 }
2659
/// A session whose expiry lies in the past yields no access token.
#[test]
fn test_mcp_oauth_provider_returns_none_when_token_expired() {
    let one_minute_ago = SystemTime::now() - Duration::from_secs(60);
    let session = make_test_session("stale-token", Some("rt"), Some(one_minute_ago));

    // The client must never be used; any request hits `unreachable!`.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(session, http_client, None);

    assert!(provider.access_token().is_none());
}
2672
/// A session whose expiry lies in the future yields its access token.
#[test]
fn test_mcp_oauth_provider_returns_token_when_not_expired() {
    let one_hour_ahead = SystemTime::now() + Duration::from_secs(3600);
    let session = make_test_session("valid-token", Some("rt"), Some(one_hour_ahead));

    // The client must never be used; any request hits `unreachable!`.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(session, http_client, None);

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("valid-token"));
}
2685
/// A session with no recorded expiry yields its access token.
#[test]
fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
    let session = make_test_session("no-expiry-token", Some("rt"), None);

    // The client must never be used; any request hits `unreachable!`.
    let http_client = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(session, http_client, None);

    let token = provider.access_token();
    assert_eq!(token.as_deref(), Some("no-expiry-token"));
}
2697
/// Without a refresh token, `try_refresh` reports `false` and issues no
/// HTTP request (the fake client would hit `unreachable!` if called).
#[test]
fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
    smol::block_on(async {
        let http_client = make_fake_http_client(|_| {
            Box::pin(async { unreachable!("no HTTP call expected") })
        });
        let session = make_test_session("token", None, None);
        let provider = McpOAuthTokenProvider::new(session, http_client, None);

        let outcome = provider.try_refresh().await;
        assert!(!outcome.unwrap());
    });
}
2714
/// A successful refresh replaces the access token served by the provider
/// and sends the updated session over the notification channel.
#[test]
fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
    const REFRESH_RESPONSE: &str = r#"{
                            "access_token": "new-access",
                            "refresh_token": "new-refresh",
                            "expires_in": 1800
                        }"#;

    smol::block_on(async {
        let (tx, mut rx) = futures::channel::mpsc::unbounded();
        let http_client =
            make_fake_http_client(|_req| Box::pin(async { json_response(200, REFRESH_RESPONSE) }));
        let session = make_test_session("old-access", Some("my-refresh-token"), None);
        let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));

        let refreshed = provider.try_refresh().await.unwrap();
        assert!(refreshed);
        let current_token = provider.access_token();
        assert_eq!(current_token.as_deref(), Some("new-access"));

        let notified_session = rx.try_recv().expect("channel should have a session");
        assert_eq!(notified_session.tokens.access_token, "new-access");
        let notified_refresh = notified_session.tokens.refresh_token.as_deref();
        assert_eq!(notified_refresh, Some("new-refresh"));
    });
}
2748
/// When the refresh response omits `refresh_token`, the session sent over
/// the channel still carries the original refresh token.
#[test]
fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
    const REFRESH_RESPONSE: &str = r#"{
                            "access_token": "new-access",
                            "expires_in": 900
                        }"#;

    smol::block_on(async {
        let (tx, mut rx) = futures::channel::mpsc::unbounded();
        let http_client =
            make_fake_http_client(|_req| Box::pin(async { json_response(200, REFRESH_RESPONSE) }));
        let session = make_test_session("old-access", Some("original-refresh"), None);
        let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));

        let refreshed = provider.try_refresh().await.unwrap();
        assert!(refreshed);

        let notified_session = rx.try_recv().expect("channel should have a session");
        assert_eq!(notified_session.tokens.access_token, "new-access");
        let notified_refresh = notified_session.tokens.refresh_token.as_deref();
        assert_eq!(notified_refresh, Some("original-refresh"));
    });
}
2780
/// A 401 from the token endpoint makes `try_refresh` report `false` and
/// leaves the previous access token in place.
#[test]
fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
    smol::block_on(async {
        let http_client = make_fake_http_client(|_req| {
            Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
        });
        let session = make_test_session("old-access", Some("my-refresh"), None);
        let provider = McpOAuthTokenProvider::new(session, http_client, None);

        let outcome = provider.try_refresh().await;
        assert!(!outcome.unwrap());

        // The old token must still be served after the failed refresh.
        let current_token = provider.access_token();
        assert_eq!(current_token.as_deref(), Some("old-access"));
    });
}
2798}