1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
/// The CIMD URL where Zed's OAuth client metadata document is hosted.
///
/// When the authorization server advertises
/// `client_id_metadata_document_supported`, this URL is used verbatim as the
/// OAuth `client_id` (see [`determine_registration_strategy`]).
pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
/// Protected Resource Metadata for an MCP server, per RFC 9728 (OAuth 2.0
/// Protected Resource Metadata).
///
/// The document is located via the `resource_metadata` URL from the MCP
/// server's `WWW-Authenticate` header, or via the well-known URIs from
/// [`protected_resource_metadata_urls`]. It is built by
/// [`fetch_protected_resource_metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtectedResourceMetadata {
    /// The protected resource's identifier. Falls back to the MCP server URL
    /// when the metadata document omits it.
    pub resource: Url,
    /// Authorization servers that can issue tokens for this resource. Only the
    /// first entry is used by [`discover`].
    pub authorization_servers: Vec<Url>,
    /// Scopes advertised by the resource, if any. Used as the fallback in
    /// [`select_scopes`].
    pub scopes_supported: Option<Vec<String>>,
}
139
/// Authorization Server Metadata per RFC 8414 (OAuth 2.0 Authorization Server
/// Metadata). It is fetched from the authorization server's well-known
/// endpoints by [`fetch_auth_server_metadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthServerMetadata {
    /// The issuer identifier. [`fetch_auth_server_metadata`] requires it to
    /// equal the issuer URL that was queried.
    pub issuer: Url,
    /// Base URL for the authorization request built by
    /// [`build_authorization_url`].
    pub authorization_endpoint: Url,
    /// Endpoint used for the authorization-code exchange and for token refresh.
    pub token_endpoint: Url,
    /// Dynamic Client Registration (RFC 7591) endpoint, if the server offers one.
    pub registration_endpoint: Option<Url>,
    /// Scopes advertised by the authorization server, if any.
    pub scopes_supported: Option<Vec<String>>,
    /// PKCE challenge methods the server advertises. [`discover`] requires
    /// `S256` to be listed.
    pub code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a Client ID Metadata Document URL as the
    /// `client_id` (CIMD). Treated as `false` when the server omits it.
    pub client_id_metadata_document_supported: bool,
}
152
/// The result of client registration — either CIMD or DCR.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthClientRegistration {
    /// The OAuth `client_id`: [`CIMD_URL`] for CIMD, or the id issued by the
    /// server for DCR.
    pub client_id: String,
    /// Only present for DCR-minted registrations, and only when the server
    /// issued a secret. Redacted in `Debug` output.
    pub client_secret: Option<String>,
}
160
161impl std::fmt::Debug for OAuthClientRegistration {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("OAuthClientRegistration")
164 .field("client_id", &self.client_id)
165 .field(
166 "client_secret",
167 &self.client_secret.as_ref().map(|_| "[redacted]"),
168 )
169 .finish()
170 }
171}
172
/// Access and refresh tokens obtained from the token endpoint.
///
/// Both tokens are redacted in `Debug` output.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthTokens {
    /// The access token.
    pub access_token: String,
    /// Refresh token, if the server issued one.
    pub refresh_token: Option<String>,
    /// Absolute expiry derived from the token response's `expires_in`;
    /// `None` when the server gave no lifetime.
    pub expires_at: Option<SystemTime>,
}
180
181impl std::fmt::Debug for OAuthTokens {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("OAuthTokens")
184 .field("access_token", &"[redacted]")
185 .field(
186 "refresh_token",
187 &self.refresh_token.as_ref().map(|_| "[redacted]"),
188 )
189 .field("expires_at", &self.expires_at)
190 .finish()
191 }
192}
193
/// Everything discovered before the browser flow starts, as produced by
/// [`discover`]. Client registration is resolved separately, once the real
/// redirect URI is known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthDiscovery {
    /// Metadata of the MCP server (the protected resource).
    pub resource_metadata: ProtectedResourceMetadata,
    /// Metadata of the first authorization server listed in
    /// `resource_metadata`.
    pub auth_server_metadata: AuthServerMetadata,
    /// Scopes to request, chosen by [`select_scopes`].
    pub scopes: Vec<String>,
}
202
/// The persisted OAuth session for a context server.
///
/// Stored in the keychain so startup can restore a refresh-capable provider
/// without another browser flow. Deliberately excludes the full discovery
/// metadata to keep the serialized size well within keychain item limits.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthSession {
    /// Token endpoint of the authorization server, kept so tokens can be
    /// refreshed (see [`refresh_tokens`]) without re-running discovery.
    pub token_endpoint: Url,
    /// The protected resource (MCP server) this session authenticates to.
    pub resource: Url,
    /// Client credentials obtained via CIMD or DCR.
    pub client_registration: OAuthClientRegistration,
    /// The access token and optional refresh token.
    pub tokens: OAuthTokens,
}
215
216impl std::fmt::Debug for OAuthSession {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("OAuthSession")
219 .field("token_endpoint", &self.token_endpoint)
220 .field("resource", &self.resource)
221 .field("client_registration", &self.client_registration)
222 .field("tokens", &self.tokens)
223 .finish()
224 }
225}
226
/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
///
/// Produced from the `error` parameter of a `WWW-Authenticate` header by
/// [`parse_www_authenticate`]. Code matching is exact and case-sensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}
240
241impl BearerError {
242 fn parse(value: &str) -> Self {
243 match value {
244 "invalid_request" => BearerError::InvalidRequest,
245 "invalid_token" => BearerError::InvalidToken,
246 "insufficient_scope" => BearerError::InsufficientScope,
247 _ => BearerError::Other,
248 }
249 }
250}
251
/// Fields extracted from a `WWW-Authenticate: Bearer` header by
/// [`parse_www_authenticate`].
///
/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
/// at the Protected Resource Metadata document. The optional `scope` parameter
/// (RFC 6750 Section 3) indicates scopes required for the request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WwwAuthenticate {
    /// URL of the Protected Resource Metadata document, if advertised.
    pub resource_metadata: Option<Url>,
    /// Scopes from the `scope` parameter, split on whitespace.
    pub scope: Option<Vec<String>>,
    /// The parsed `error` parameter per RFC 6750 Section 3.1.
    pub error: Option<BearerError>,
    /// The `error_description` parameter, verbatim.
    pub error_description: Option<String>,
}
265
266/// Parse a `WWW-Authenticate` header value.
267///
268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
270/// - `resource_metadata` — URL of the Protected Resource Metadata document
271/// - `scope` — space-separated list of required scopes
272/// - `error` — error code (e.g. "insufficient_scope")
273/// - `error_description` — human-readable error description
274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
275 let header = header.trim();
276
277 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
278 header[6..].trim()
279 } else {
280 bail!("WWW-Authenticate header does not use Bearer scheme");
281 };
282
283 if params_str.is_empty() {
284 return Ok(WwwAuthenticate {
285 resource_metadata: None,
286 scope: None,
287 error: None,
288 error_description: None,
289 });
290 }
291
292 let params = parse_auth_params(params_str);
293
294 let resource_metadata = params
295 .get("resource_metadata")
296 .map(|v| Url::parse(v))
297 .transpose()
298 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
299
300 let scope = params
301 .get("scope")
302 .map(|v| v.split_whitespace().map(String::from).collect());
303
304 let error = params.get("error").map(|v| BearerError::parse(v));
305 let error_description = params.get("error_description").cloned();
306
307 Ok(WwwAuthenticate {
308 resource_metadata,
309 scope,
310 error,
311 error_description,
312 })
313}
314
315/// Parse comma-separated `key="value"` or `key=token` parameters from an
316/// auth-param list (RFC 7235 Section 2.1).
317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
318 let mut params = collections::HashMap::default();
319 let mut remaining = input.trim();
320
321 while !remaining.is_empty() {
322 // Skip leading whitespace and commas.
323 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
324 if remaining.is_empty() {
325 break;
326 }
327
328 // Find the key (everything before '=').
329 let eq_pos = match remaining.find('=') {
330 Some(pos) => pos,
331 None => break,
332 };
333
334 let key = remaining[..eq_pos].trim().to_lowercase();
335 remaining = &remaining[eq_pos + 1..];
336 remaining = remaining.trim_start();
337
338 // Parse the value: either quoted or unquoted (token).
339 let value;
340 if remaining.starts_with('"') {
341 // Quoted string: find the closing quote, handling escaped chars.
342 remaining = &remaining[1..]; // skip opening quote
343 let mut val = String::new();
344 let mut chars = remaining.char_indices();
345 loop {
346 match chars.next() {
347 Some((_, '\\')) => {
348 // Escaped character — take the next char literally.
349 if let Some((_, c)) = chars.next() {
350 val.push(c);
351 }
352 }
353 Some((i, '"')) => {
354 remaining = &remaining[i + 1..];
355 break;
356 }
357 Some((_, c)) => val.push(c),
358 None => {
359 remaining = "";
360 break;
361 }
362 }
363 }
364 value = val;
365 } else {
366 // Unquoted token: read until comma or whitespace.
367 let end = remaining
368 .find(|c: char| c == ',' || c.is_whitespace())
369 .unwrap_or(remaining.len());
370 value = remaining[..end].to_string();
371 remaining = &remaining[end..];
372 }
373
374 if !key.is_empty() {
375 params.insert(key, value);
376 }
377 }
378
379 params
380}
381
382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
383/// server URL, per RFC 9728 Section 3.
384///
385/// Returns URIs in priority order:
386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
389 let mut urls = Vec::new();
390 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
391
392 let path = server_url.path().trim_start_matches('/');
393 if !path.is_empty() {
394 if let Ok(url) = Url::parse(&format!(
395 "{}/.well-known/oauth-protected-resource/{}",
396 base, path
397 )) {
398 urls.push(url);
399 }
400 }
401
402 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
403 urls.push(url);
404 }
405
406 urls
407}
408
409/// Construct the well-known Authorization Server Metadata URIs for a given
410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
411///
412/// Returns URIs in priority order, which differs depending on whether the
413/// issuer URL has a path component.
414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
415 let mut urls = Vec::new();
416 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
417 let path = issuer.path().trim_matches('/');
418
419 if !path.is_empty() {
420 // Issuer with path: try path-inserted variants first.
421 if let Ok(url) = Url::parse(&format!(
422 "{}/.well-known/oauth-authorization-server/{}",
423 base, path
424 )) {
425 urls.push(url);
426 }
427 if let Ok(url) = Url::parse(&format!(
428 "{}/.well-known/openid-configuration/{}",
429 base, path
430 )) {
431 urls.push(url);
432 }
433 if let Ok(url) = Url::parse(&format!(
434 "{}/{}/.well-known/openid-configuration",
435 base, path
436 )) {
437 urls.push(url);
438 }
439 } else {
440 // No path: standard well-known locations.
441 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
442 urls.push(url);
443 }
444 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
445 urls.push(url);
446 }
447 }
448
449 urls
450}
451
452// -- Canonical server URI (RFC 8707) -----------------------------------------
453
454/// Derive the canonical resource URI for an MCP server URL, suitable for the
455/// `resource` parameter in authorization and token requests per RFC 8707.
456///
457/// Lowercases the scheme and host, preserves the path (without trailing slash),
458/// strips fragments and query strings.
459pub fn canonical_server_uri(server_url: &Url) -> String {
460 let mut uri = format!(
461 "{}://{}",
462 server_url.scheme().to_ascii_lowercase(),
463 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
464 );
465 if let Some(port) = server_url.port() {
466 uri.push_str(&format!(":{}", port));
467 }
468 let path = server_url.path();
469 if path != "/" {
470 uri.push_str(path.trim_end_matches('/'));
471 }
472 uri
473}
474
475// -- Scope selection ---------------------------------------------------------
476
477/// Select scopes following the MCP spec's Scope Selection Strategy:
478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
480/// 3. Return empty if neither is available.
481pub fn select_scopes(
482 www_authenticate: &WwwAuthenticate,
483 resource_metadata: &ProtectedResourceMetadata,
484) -> Vec<String> {
485 if let Some(ref scopes) = www_authenticate.scope {
486 if !scopes.is_empty() {
487 return scopes.clone();
488 }
489 }
490 resource_metadata
491 .scopes_supported
492 .clone()
493 .unwrap_or_default()
494}
495
496// -- Client registration strategy --------------------------------------------
497
/// The registration approach to use, determined from auth server metadata by
/// [`determine_registration_strategy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientRegistrationStrategy {
    /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
    Cimd { client_id: String },
    /// The auth server has a registration endpoint. Caller must POST to it
    /// (see [`perform_dcr`]).
    Dcr { registration_endpoint: Url },
    /// No supported registration mechanism.
    Unavailable,
}
508
509/// Determine how to register with the authorization server, following the
510/// spec's recommended priority: CIMD first, DCR fallback.
511pub fn determine_registration_strategy(
512 auth_server_metadata: &AuthServerMetadata,
513) -> ClientRegistrationStrategy {
514 if auth_server_metadata.client_id_metadata_document_supported {
515 ClientRegistrationStrategy::Cimd {
516 client_id: CIMD_URL.to_string(),
517 }
518 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
519 ClientRegistrationStrategy::Dcr {
520 registration_endpoint: endpoint.clone(),
521 }
522 } else {
523 ClientRegistrationStrategy::Unavailable
524 }
525}
526
527// -- PKCE (RFC 7636) ---------------------------------------------------------
528
/// A PKCE code verifier and its S256 challenge, created by
/// [`generate_pkce_challenge`].
#[derive(Clone)]
pub struct PkceChallenge {
    /// Secret verifier, passed to the token endpoint as `code_verifier`.
    /// Redacted in `Debug` output.
    pub verifier: String,
    /// `BASE64URL(SHA256(verifier))`, sent as `code_challenge` in the
    /// authorization URL.
    pub challenge: String,
}
535
536impl std::fmt::Debug for PkceChallenge {
537 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538 f.debug_struct("PkceChallenge")
539 .field("verifier", &"[redacted]")
540 .field("challenge", &self.challenge)
541 .finish()
542 }
543}
544
545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
546///
547/// The verifier is 43 base64url characters derived from 32 random bytes.
548/// The challenge is `BASE64URL(SHA256(verifier))`.
549pub fn generate_pkce_challenge() -> PkceChallenge {
550 let mut random_bytes = [0u8; 32];
551 rand::rng().fill(&mut random_bytes);
552 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
553 let verifier = engine.encode(&random_bytes);
554
555 let digest = Sha256::digest(verifier.as_bytes());
556 let challenge = engine.encode(digest);
557
558 PkceChallenge {
559 verifier,
560 challenge,
561 }
562}
563
564// -- Authorization URL construction ------------------------------------------
565
566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
567pub fn build_authorization_url(
568 auth_server_metadata: &AuthServerMetadata,
569 client_id: &str,
570 redirect_uri: &str,
571 scopes: &[String],
572 resource: &str,
573 pkce: &PkceChallenge,
574 state: &str,
575) -> Url {
576 let mut url = auth_server_metadata.authorization_endpoint.clone();
577 {
578 let mut query = url.query_pairs_mut();
579 query.append_pair("response_type", "code");
580 query.append_pair("client_id", client_id);
581 query.append_pair("redirect_uri", redirect_uri);
582 if !scopes.is_empty() {
583 query.append_pair("scope", &scopes.join(" "));
584 }
585 query.append_pair("resource", resource);
586 query.append_pair("code_challenge", &pkce.challenge);
587 query.append_pair("code_challenge_method", "S256");
588 query.append_pair("state", state);
589 }
590 url
591}
592
593// -- Token endpoint request bodies -------------------------------------------
594
/// The JSON body returned by the token endpoint on success.
///
/// Optional fields default to `None` when absent from the response.
#[derive(Deserialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Refresh token, if the server issued one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Access-token lifetime in seconds; converted to an absolute time by
    /// `into_tokens`.
    #[serde(default)]
    pub expires_in: Option<u64>,
    /// Token type reported by the server; not inspected by this module's
    /// conversion into `OAuthTokens`.
    #[serde(default)]
    pub token_type: Option<String>,
}
606
607impl std::fmt::Debug for TokenResponse {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("TokenResponse")
610 .field("access_token", &"[redacted]")
611 .field(
612 "refresh_token",
613 &self.refresh_token.as_ref().map(|_| "[redacted]"),
614 )
615 .field("expires_in", &self.expires_in)
616 .field("token_type", &self.token_type)
617 .finish()
618 }
619}
620
621impl TokenResponse {
622 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
623 pub fn into_tokens(self) -> OAuthTokens {
624 let expires_at = self
625 .expires_in
626 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
627 OAuthTokens {
628 access_token: self.access_token,
629 refresh_token: self.refresh_token,
630 expires_at,
631 }
632 }
633}
634
/// Build the form parameters for an authorization-code token exchange.
///
/// The pairs are returned in the order
/// `grant_type=authorization_code`, `code`, `redirect_uri`, `client_id`,
/// `code_verifier`, `resource`, ready to be form-urlencoded.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let fields = [
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ];
    fields
        .iter()
        .map(|&(name, value)| (name, value.to_owned()))
        .collect()
}
652
/// Build the form parameters for a refresh-token grant.
///
/// The pairs are returned in the order `grant_type=refresh_token`,
/// `refresh_token`, `client_id`, `resource`, ready to be form-urlencoded.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
) -> Vec<(&'static str, String)> {
    let fields = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
        ("resource", resource),
    ];
    fields
        .iter()
        .map(|&(name, value)| (name, value.to_owned()))
        .collect()
}
666
667// -- DCR request body (RFC 7591) ---------------------------------------------
668
669/// Build the JSON body for a Dynamic Client Registration request.
670///
671/// The `redirect_uri` should be the actual loopback URI with the ephemeral
672/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
673/// redirect URI matching even for loopback addresses, so we register the
674/// exact URI we intend to use.
675pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
676 serde_json::json!({
677 "client_name": "Zed",
678 "redirect_uris": [redirect_uri],
679 "grant_types": ["authorization_code"],
680 "response_types": ["code"],
681 "token_endpoint_auth_method": "none"
682 })
683}
684
685// -- Discovery (async, hits real endpoints) ----------------------------------
686
687/// Fetch Protected Resource Metadata from the MCP server.
688///
689/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
690/// then falls back to well-known URIs constructed from `server_url`.
691pub async fn fetch_protected_resource_metadata(
692 http_client: &Arc<dyn HttpClient>,
693 server_url: &Url,
694 www_authenticate: &WwwAuthenticate,
695) -> Result<ProtectedResourceMetadata> {
696 let candidate_urls = match &www_authenticate.resource_metadata {
697 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
698 Some(url) => {
699 log::warn!(
700 "Ignoring cross-origin resource_metadata URL {} \
701 (server origin: {})",
702 url,
703 server_url.origin().unicode_serialization()
704 );
705 protected_resource_metadata_urls(server_url)
706 }
707 None => protected_resource_metadata_urls(server_url),
708 };
709
710 for url in &candidate_urls {
711 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
712 Ok(response) => {
713 if response.authorization_servers.is_empty() {
714 bail!(
715 "Protected Resource Metadata at {} has no authorization_servers",
716 url
717 );
718 }
719 return Ok(ProtectedResourceMetadata {
720 resource: response.resource.unwrap_or_else(|| server_url.clone()),
721 authorization_servers: response.authorization_servers,
722 scopes_supported: response.scopes_supported,
723 });
724 }
725 Err(err) => {
726 log::debug!(
727 "Failed to fetch Protected Resource Metadata from {}: {}",
728 url,
729 err
730 );
731 }
732 }
733 }
734
735 bail!(
736 "Could not fetch Protected Resource Metadata for {}",
737 server_url
738 )
739}
740
741/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
742/// endpoints in the priority order specified by the MCP spec.
743pub async fn fetch_auth_server_metadata(
744 http_client: &Arc<dyn HttpClient>,
745 issuer: &Url,
746) -> Result<AuthServerMetadata> {
747 let candidate_urls = auth_server_metadata_urls(issuer);
748
749 for url in &candidate_urls {
750 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
751 Ok(response) => {
752 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
753 if reported_issuer != *issuer {
754 bail!(
755 "Auth server metadata issuer mismatch: expected {}, got {}",
756 issuer,
757 reported_issuer
758 );
759 }
760
761 return Ok(AuthServerMetadata {
762 issuer: reported_issuer,
763 authorization_endpoint: response
764 .authorization_endpoint
765 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
766 token_endpoint: response
767 .token_endpoint
768 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
769 registration_endpoint: response.registration_endpoint,
770 scopes_supported: response.scopes_supported,
771 code_challenge_methods_supported: response.code_challenge_methods_supported,
772 client_id_metadata_document_supported: response
773 .client_id_metadata_document_supported
774 .unwrap_or(false),
775 });
776 }
777 Err(err) => {
778 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
779 }
780 }
781 }
782
783 bail!(
784 "Could not fetch Authorization Server Metadata for {}",
785 issuer
786 )
787}
788
789/// Run the full discovery flow: fetch resource metadata, then auth server
790/// metadata, then select scopes. Client registration is resolved separately,
791/// once the real redirect URI is known.
792pub async fn discover(
793 http_client: &Arc<dyn HttpClient>,
794 server_url: &Url,
795 www_authenticate: &WwwAuthenticate,
796) -> Result<OAuthDiscovery> {
797 let resource_metadata =
798 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
799
800 let auth_server_url = resource_metadata
801 .authorization_servers
802 .first()
803 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
804
805 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
806
807 // Verify PKCE S256 support (spec requirement).
808 match &auth_server_metadata.code_challenge_methods_supported {
809 Some(methods) if methods.iter().any(|m| m == "S256") => {}
810 Some(_) => bail!("authorization server does not support S256 PKCE"),
811 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
812 }
813
814 // Verify there is at least one supported registration strategy before we
815 // present the server as ready to authenticate.
816 match determine_registration_strategy(&auth_server_metadata) {
817 ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
818 ClientRegistrationStrategy::Unavailable => {
819 bail!("authorization server supports neither CIMD nor DCR")
820 }
821 }
822
823 let scopes = select_scopes(www_authenticate, &resource_metadata);
824
825 Ok(OAuthDiscovery {
826 resource_metadata,
827 auth_server_metadata,
828 scopes,
829 })
830}
831
832/// Resolve the OAuth client registration for an authorization flow.
833///
834/// CIMD uses the static client metadata document directly. For DCR, a fresh
835/// registration is performed each time because the loopback redirect URI
836/// includes an ephemeral port that changes every flow.
837pub async fn resolve_client_registration(
838 http_client: &Arc<dyn HttpClient>,
839 discovery: &OAuthDiscovery,
840 redirect_uri: &str,
841) -> Result<OAuthClientRegistration> {
842 match determine_registration_strategy(&discovery.auth_server_metadata) {
843 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
844 client_id,
845 client_secret: None,
846 }),
847 ClientRegistrationStrategy::Dcr {
848 registration_endpoint,
849 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
850 ClientRegistrationStrategy::Unavailable => {
851 bail!("authorization server supports neither CIMD nor DCR")
852 }
853 }
854}
855
856// -- Dynamic Client Registration (RFC 7591) ----------------------------------
857
858/// Perform Dynamic Client Registration with the authorization server.
859pub async fn perform_dcr(
860 http_client: &Arc<dyn HttpClient>,
861 registration_endpoint: &Url,
862 redirect_uri: &str,
863) -> Result<OAuthClientRegistration> {
864 validate_oauth_url(registration_endpoint)?;
865
866 let body = dcr_registration_body(redirect_uri);
867 let body_bytes = serde_json::to_vec(&body)?;
868
869 let request = Request::builder()
870 .method(http_client::http::Method::POST)
871 .uri(registration_endpoint.as_str())
872 .header("Content-Type", "application/json")
873 .header("Accept", "application/json")
874 .body(AsyncBody::from(body_bytes))?;
875
876 let mut response = http_client.send(request).await?;
877
878 if !response.status().is_success() {
879 let mut error_body = String::new();
880 response.body_mut().read_to_string(&mut error_body).await?;
881 bail!(
882 "DCR failed with status {}: {}",
883 response.status(),
884 error_body
885 );
886 }
887
888 let mut response_body = String::new();
889 response
890 .body_mut()
891 .read_to_string(&mut response_body)
892 .await?;
893
894 let dcr_response: DcrResponse =
895 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
896
897 Ok(OAuthClientRegistration {
898 client_id: dcr_response.client_id,
899 client_secret: dcr_response.client_secret,
900 })
901}
902
903// -- Token exchange and refresh (async) --------------------------------------
904
905/// Exchange an authorization code for tokens at the token endpoint.
906pub async fn exchange_code(
907 http_client: &Arc<dyn HttpClient>,
908 auth_server_metadata: &AuthServerMetadata,
909 code: &str,
910 client_id: &str,
911 redirect_uri: &str,
912 code_verifier: &str,
913 resource: &str,
914) -> Result<OAuthTokens> {
915 let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
916 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
917}
918
919/// Refresh tokens using a refresh token.
920pub async fn refresh_tokens(
921 http_client: &Arc<dyn HttpClient>,
922 token_endpoint: &Url,
923 refresh_token: &str,
924 client_id: &str,
925 resource: &str,
926) -> Result<OAuthTokens> {
927 let params = token_refresh_params(refresh_token, client_id, resource);
928 post_token_request(http_client, token_endpoint, ¶ms).await
929}
930
931/// POST form-encoded parameters to a token endpoint and parse the response.
932async fn post_token_request(
933 http_client: &Arc<dyn HttpClient>,
934 token_endpoint: &Url,
935 params: &[(&str, String)],
936) -> Result<OAuthTokens> {
937 validate_oauth_url(token_endpoint)?;
938
939 let body = url::form_urlencoded::Serializer::new(String::new())
940 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
941 .finish();
942
943 let request = Request::builder()
944 .method(http_client::http::Method::POST)
945 .uri(token_endpoint.as_str())
946 .header("Content-Type", "application/x-www-form-urlencoded")
947 .header("Accept", "application/json")
948 .body(AsyncBody::from(body.into_bytes()))?;
949
950 let mut response = http_client.send(request).await?;
951
952 if !response.status().is_success() {
953 let mut error_body = String::new();
954 response.body_mut().read_to_string(&mut error_body).await?;
955 bail!(
956 "token request failed with status {}: {}",
957 response.status(),
958 error_body
959 );
960 }
961
962 let mut response_body = String::new();
963 response
964 .body_mut()
965 .read_to_string(&mut response_body)
966 .await?;
967
968 let token_response: TokenResponse =
969 serde_json::from_str(&response_body).context("failed to parse token response")?;
970
971 Ok(token_response.into_tokens())
972}
973
974// -- Loopback HTTP callback server -------------------------------------------
975
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// Holds the `code` and `state` query parameters from the authorization
/// redirect. Both values are sensitive, so the `Debug` impl redacts them.
/// Parsing does not compare `state` against the value sent in the
/// authorization request — NOTE(review): presumably the caller does; confirm.
pub struct OAuthCallback {
    /// Authorization code, to be exchanged at the token endpoint.
    pub code: String,
    /// The `state` value returned by the authorization server.
    pub state: String,
}
981
982impl std::fmt::Debug for OAuthCallback {
983 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
984 f.debug_struct("OAuthCallback")
985 .field("code", &"[redacted]")
986 .field("state", &"[redacted]")
987 .finish()
988 }
989}
990
991impl OAuthCallback {
992 /// Parse the query string from a callback URL like
993 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
994 pub fn parse_query(query: &str) -> Result<Self> {
995 let mut code: Option<String> = None;
996 let mut state: Option<String> = None;
997 let mut error: Option<String> = None;
998 let mut error_description: Option<String> = None;
999
1000 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1001 match key.as_ref() {
1002 "code" => {
1003 if !value.is_empty() {
1004 code = Some(value.into_owned());
1005 }
1006 }
1007 "state" => {
1008 if !value.is_empty() {
1009 state = Some(value.into_owned());
1010 }
1011 }
1012 "error" => {
1013 if !value.is_empty() {
1014 error = Some(value.into_owned());
1015 }
1016 }
1017 "error_description" => {
1018 if !value.is_empty() {
1019 error_description = Some(value.into_owned());
1020 }
1021 }
1022 _ => {}
1023 }
1024 }
1025
1026 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1027 // checking for missing code/state.
1028 if let Some(error_code) = error {
1029 bail!(
1030 "OAuth authorization failed: {} ({})",
1031 error_code,
1032 error_description.as_deref().unwrap_or("no description")
1033 );
1034 }
1035
1036 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1037 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1038
1039 Ok(Self { code, state })
1040 }
1041}
1042
/// Upper bound (two minutes) on how long the loopback callback server waits
/// for the browser to finish the OAuth flow before it gives up and frees its
/// port.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1046
1047/// Start a loopback HTTP server to receive the OAuth authorization callback.
1048///
1049/// Binds to an ephemeral loopback port for each flow.
1050///
1051/// Returns `(redirect_uri, callback_future)`. The caller should use the
1052/// redirect URI in the authorization request, open the browser, then await
1053/// the future to receive the callback.
1054///
1055/// The server accepts exactly one request on `/callback`, validates that it
1056/// contains `code` and `state` query parameters, responds with a minimal
1057/// HTML page telling the user they can close the tab, and shuts down.
1058///
1059/// The callback server shuts down when the returned oneshot receiver is dropped
1060/// (e.g. because the authentication task was cancelled), or after a timeout
1061/// ([CALLBACK_TIMEOUT]).
1062pub async fn start_callback_server() -> Result<(
1063 String,
1064 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1065)> {
1066 let server = tiny_http::Server::http("127.0.0.1:0")
1067 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1068 let port = server
1069 .server_addr()
1070 .to_ip()
1071 .context("server not bound to a TCP address")?
1072 .port();
1073
1074 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1075
1076 let (tx, rx) = futures::channel::oneshot::channel();
1077
1078 // `tiny_http` is blocking, so we run it on a background thread.
1079 // The `recv_timeout` loop lets us check for cancellation (the receiver
1080 // being dropped) and enforce an overall timeout.
1081 std::thread::spawn(move || {
1082 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1083
1084 loop {
1085 if tx.is_canceled() {
1086 return;
1087 }
1088 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1089 if remaining.is_zero() {
1090 return;
1091 }
1092
1093 let timeout = remaining.min(Duration::from_millis(500));
1094 let Some(request) = (match server.recv_timeout(timeout) {
1095 Ok(req) => req,
1096 Err(_) => {
1097 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1098 return;
1099 }
1100 }) else {
1101 // Timeout with no request — loop back and check cancellation.
1102 continue;
1103 };
1104
1105 let result = handle_callback_request(&request);
1106
1107 let (status_code, body) = match &result {
1108 Ok(_) => (
1109 200,
1110 "<html><body><h1>Authorization successful</h1>\
1111 <p>You can close this tab and return to Zed.</p></body></html>",
1112 ),
1113 Err(err) => {
1114 log::error!("OAuth callback error: {}", err);
1115 (
1116 400,
1117 "<html><body><h1>Authorization failed</h1>\
1118 <p>Something went wrong. Please try again from Zed.</p></body></html>",
1119 )
1120 }
1121 };
1122
1123 let response = tiny_http::Response::from_string(body)
1124 .with_status_code(status_code)
1125 .with_header(
1126 tiny_http::Header::from_str("Content-Type: text/html")
1127 .expect("failed to construct response header"),
1128 )
1129 .with_header(
1130 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1131 .expect("failed to construct response header"),
1132 );
1133 request.respond(response).log_err();
1134
1135 let _ = tx.send(result);
1136 return;
1137 }
1138 });
1139
1140 Ok((redirect_uri, rx))
1141}
1142
1143/// Extract the `code` and `state` query parameters from an OAuth callback
1144/// request to `/callback`.
1145fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1146 let url = Url::parse(&format!("http://localhost{}", request.url()))
1147 .context("malformed callback request URL")?;
1148
1149 if url.path() != "/callback" {
1150 bail!("unexpected path in OAuth callback: {}", url.path());
1151 }
1152
1153 let query = url
1154 .query()
1155 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1156 OAuthCallback::parse_query(query)
1157}
1158
1159// -- JSON fetch helper -------------------------------------------------------
1160
1161async fn fetch_json<T: serde::de::DeserializeOwned>(
1162 http_client: &Arc<dyn HttpClient>,
1163 url: &Url,
1164) -> Result<T> {
1165 validate_oauth_url(url)?;
1166
1167 let request = Request::builder()
1168 .method(http_client::http::Method::GET)
1169 .uri(url.as_str())
1170 .header("Accept", "application/json")
1171 .body(AsyncBody::default())?;
1172
1173 let mut response = http_client.send(request).await?;
1174
1175 if !response.status().is_success() {
1176 bail!("HTTP {} fetching {}", response.status(), url);
1177 }
1178
1179 let mut body = String::new();
1180 response.body_mut().read_to_string(&mut body).await?;
1181 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1182}
1183
1184// -- Serde response types for discovery --------------------------------------
1185
/// Raw deserialization target for the Protected Resource Metadata document
/// fetched during discovery.
///
/// Every field tolerates absence on the wire (`#[serde(default)]`), so sparse
/// documents still parse. NOTE(review): required-field checks presumably
/// happen when this is converted into `ProtectedResourceMetadata`; confirm.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// Identifier of the protected resource, if advertised.
    #[serde(default)]
    resource: Option<Url>,
    /// Authorization server URLs; empty when the field is missing.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// Scopes the resource advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1195
/// Raw deserialization target for Authorization Server Metadata fetched
/// during discovery.
///
/// All fields are optional on the wire (`#[serde(default)]`) so that partial
/// documents still parse. NOTE(review): presumably validated and defaulted
/// (e.g. a missing `client_id_metadata_document_supported` treated as
/// `false`) when converted into `AuthServerMetadata`; confirm.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// Issuer identifier of the authorization server.
    #[serde(default)]
    issuer: Option<Url>,
    /// Endpoint the browser is sent to for the authorization request.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// Endpoint used for code exchange and token refresh.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// Dynamic Client Registration endpoint, if the server supports DCR.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// Scopes the authorization server advertises, if any.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// PKCE `code_challenge_method` values the server accepts (e.g. `S256`).
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// Whether the server accepts a client metadata document URL (CIMD) as
    /// the `client_id`.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1213
/// Response body from a Dynamic Client Registration request.
///
/// Only `client_id` is required. `client_secret` may be omitted by the server
/// (Zed registers with `token_endpoint_auth_method: "none"`).
#[derive(Debug, Deserialize)]
struct DcrResponse {
    /// Client identifier issued by the authorization server.
    client_id: String,
    /// Client secret, if the server issued one.
    #[serde(default)]
    client_secret: Option<String>,
}
1220
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementations must be `Send + Sync`, and both methods take `&self`, so an
/// implementation that updates its tokens needs interior mutability.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    ///
    /// `None` when the implementation has no usable token to offer.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    ///
    /// `Ok(false)` means no new token was obtained.
    async fn try_refresh(&self) -> Result<bool>;
}
1234
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// Current session: tokens plus the token endpoint, resource, and client
    /// registration needed to refresh them. This is a synchronous mutex, so
    /// its guard must never be held across an `.await`.
    session: SyncMutex<OAuthSession>,
    /// Client used to call the token endpoint when refreshing.
    http_client: Arc<dyn HttpClient>,
    /// Optional sink that receives a clone of the whole session after every
    /// successful refresh, presumably so the owner can re-persist it.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1244
1245impl McpOAuthTokenProvider {
1246 pub fn new(
1247 session: OAuthSession,
1248 http_client: Arc<dyn HttpClient>,
1249 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1250 ) -> Self {
1251 Self {
1252 session: SyncMutex::new(session),
1253 http_client,
1254 token_refresh_tx,
1255 }
1256 }
1257
1258 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1259 tokens.expires_at.is_some_and(|expires_at| {
1260 SystemTime::now()
1261 .checked_add(Duration::from_secs(30))
1262 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1263 })
1264 }
1265}
1266
1267#[async_trait]
1268impl OAuthTokenProvider for McpOAuthTokenProvider {
1269 fn access_token(&self) -> Option<String> {
1270 let session = self.session.lock();
1271 if Self::access_token_is_expired(&session.tokens) {
1272 return None;
1273 }
1274 Some(session.tokens.access_token.clone())
1275 }
1276
1277 async fn try_refresh(&self) -> Result<bool> {
1278 let (refresh_token, token_endpoint, resource, client_id) = {
1279 let session = self.session.lock();
1280 match session.tokens.refresh_token.clone() {
1281 Some(refresh_token) => (
1282 refresh_token,
1283 session.token_endpoint.clone(),
1284 session.resource.clone(),
1285 session.client_registration.client_id.clone(),
1286 ),
1287 None => return Ok(false),
1288 }
1289 };
1290
1291 let resource_str = canonical_server_uri(&resource);
1292
1293 match refresh_tokens(
1294 &self.http_client,
1295 &token_endpoint,
1296 &refresh_token,
1297 &client_id,
1298 &resource_str,
1299 )
1300 .await
1301 {
1302 Ok(mut new_tokens) => {
1303 if new_tokens.refresh_token.is_none() {
1304 new_tokens.refresh_token = Some(refresh_token);
1305 }
1306
1307 {
1308 let mut session = self.session.lock();
1309 session.tokens = new_tokens;
1310
1311 if let Some(ref tx) = self.token_refresh_tx {
1312 tx.unbounded_send(session.clone()).ok();
1313 }
1314 }
1315
1316 Ok(true)
1317 }
1318 Err(err) => {
1319 log::warn!("OAuth token refresh failed: {}", err);
1320 Ok(false)
1321 }
1322 }
1323 }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328 use super::*;
1329 use http_client::Response;
1330
1331 // -- require_https_or_loopback tests ------------------------------------
1332
1333 #[test]
1334 fn test_require_https_or_loopback_accepts_https() {
1335 let url = Url::parse("https://auth.example.com/token").unwrap();
1336 assert!(require_https_or_loopback(&url).is_ok());
1337 }
1338
1339 #[test]
1340 fn test_require_https_or_loopback_rejects_http_remote() {
1341 let url = Url::parse("http://auth.example.com/token").unwrap();
1342 assert!(require_https_or_loopback(&url).is_err());
1343 }
1344
1345 #[test]
1346 fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
1347 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1348 assert!(require_https_or_loopback(&url).is_ok());
1349 }
1350
1351 #[test]
1352 fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
1353 let url = Url::parse("http://[::1]:8080/callback").unwrap();
1354 assert!(require_https_or_loopback(&url).is_ok());
1355 }
1356
1357 #[test]
1358 fn test_require_https_or_loopback_accepts_http_localhost() {
1359 let url = Url::parse("http://localhost:8080/callback").unwrap();
1360 assert!(require_https_or_loopback(&url).is_ok());
1361 }
1362
1363 #[test]
1364 fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
1365 let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
1366 assert!(require_https_or_loopback(&url).is_ok());
1367 }
1368
1369 #[test]
1370 fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
1371 let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
1372 assert!(require_https_or_loopback(&url).is_err());
1373 }
1374
1375 #[test]
1376 fn test_require_https_or_loopback_rejects_ftp() {
1377 let url = Url::parse("ftp://auth.example.com/token").unwrap();
1378 assert!(require_https_or_loopback(&url).is_err());
1379 }
1380
1381 // -- validate_oauth_url (SSRF) tests ------------------------------------
1382
1383 #[test]
1384 fn test_validate_oauth_url_accepts_https_public() {
1385 let url = Url::parse("https://auth.example.com/token").unwrap();
1386 assert!(validate_oauth_url(&url).is_ok());
1387 }
1388
1389 #[test]
1390 fn test_validate_oauth_url_rejects_private_ipv4_10() {
1391 let url = Url::parse("https://10.0.0.1/token").unwrap();
1392 assert!(validate_oauth_url(&url).is_err());
1393 }
1394
1395 #[test]
1396 fn test_validate_oauth_url_rejects_private_ipv4_172() {
1397 let url = Url::parse("https://172.16.0.1/token").unwrap();
1398 assert!(validate_oauth_url(&url).is_err());
1399 }
1400
1401 #[test]
1402 fn test_validate_oauth_url_rejects_private_ipv4_192() {
1403 let url = Url::parse("https://192.168.1.1/token").unwrap();
1404 assert!(validate_oauth_url(&url).is_err());
1405 }
1406
1407 #[test]
1408 fn test_validate_oauth_url_rejects_link_local() {
1409 let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
1410 assert!(validate_oauth_url(&url).is_err());
1411 }
1412
1413 #[test]
1414 fn test_validate_oauth_url_rejects_ipv6_ula() {
1415 let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
1416 assert!(validate_oauth_url(&url).is_err());
1417 }
1418
1419 #[test]
1420 fn test_validate_oauth_url_rejects_ipv6_unspecified() {
1421 let url = Url::parse("https://[::]/token").unwrap();
1422 assert!(validate_oauth_url(&url).is_err());
1423 }
1424
1425 #[test]
1426 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
1427 let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
1428 assert!(validate_oauth_url(&url).is_err());
1429 }
1430
1431 #[test]
1432 fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
1433 let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
1434 assert!(validate_oauth_url(&url).is_err());
1435 }
1436
1437 #[test]
1438 fn test_validate_oauth_url_allows_http_loopback() {
1439 // Loopback is permitted (it's our callback server).
1440 let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
1441 assert!(validate_oauth_url(&url).is_ok());
1442 }
1443
1444 #[test]
1445 fn test_validate_oauth_url_allows_https_public_ip() {
1446 let url = Url::parse("https://93.184.216.34/token").unwrap();
1447 assert!(validate_oauth_url(&url).is_ok());
1448 }
1449
1450 // -- parse_www_authenticate tests ----------------------------------------
1451
1452 #[test]
1453 fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
1454 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
1455 let result = parse_www_authenticate(header).unwrap();
1456
1457 assert_eq!(
1458 result.resource_metadata.as_ref().map(|u| u.as_str()),
1459 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1460 );
1461 assert_eq!(
1462 result.scope,
1463 Some(vec!["files:read".to_string(), "user:profile".to_string()])
1464 );
1465 assert_eq!(result.error, None);
1466 }
1467
1468 #[test]
1469 fn test_parse_www_authenticate_resource_metadata_only() {
1470 let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
1471 let result = parse_www_authenticate(header).unwrap();
1472
1473 assert_eq!(
1474 result.resource_metadata.as_ref().map(|u| u.as_str()),
1475 Some("https://mcp.example.com/.well-known/oauth-protected-resource")
1476 );
1477 assert_eq!(result.scope, None);
1478 }
1479
1480 #[test]
1481 fn test_parse_www_authenticate_bare_bearer() {
1482 let result = parse_www_authenticate("Bearer").unwrap();
1483 assert_eq!(result.resource_metadata, None);
1484 assert_eq!(result.scope, None);
1485 }
1486
1487 #[test]
1488 fn test_parse_www_authenticate_with_error() {
1489 let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
1490 let result = parse_www_authenticate(header).unwrap();
1491
1492 assert_eq!(result.error, Some(BearerError::InsufficientScope));
1493 assert_eq!(
1494 result.error_description.as_deref(),
1495 Some("Additional file write permission required")
1496 );
1497 assert_eq!(
1498 result.scope,
1499 Some(vec!["files:read".to_string(), "files:write".to_string()])
1500 );
1501 assert!(result.resource_metadata.is_some());
1502 }
1503
1504 #[test]
1505 fn test_parse_www_authenticate_invalid_token_error() {
1506 let header =
1507 r#"Bearer error="invalid_token", error_description="The access token expired""#;
1508 let result = parse_www_authenticate(header).unwrap();
1509 assert_eq!(result.error, Some(BearerError::InvalidToken));
1510 }
1511
1512 #[test]
1513 fn test_parse_www_authenticate_invalid_request_error() {
1514 let header = r#"Bearer error="invalid_request""#;
1515 let result = parse_www_authenticate(header).unwrap();
1516 assert_eq!(result.error, Some(BearerError::InvalidRequest));
1517 }
1518
1519 #[test]
1520 fn test_parse_www_authenticate_unknown_error() {
1521 let header = r#"Bearer error="some_future_error""#;
1522 let result = parse_www_authenticate(header).unwrap();
1523 assert_eq!(result.error, Some(BearerError::Other));
1524 }
1525
1526 #[test]
1527 fn test_parse_www_authenticate_rejects_non_bearer() {
1528 let result = parse_www_authenticate("Basic realm=\"example\"");
1529 assert!(result.is_err());
1530 }
1531
1532 #[test]
1533 fn test_parse_www_authenticate_case_insensitive_scheme() {
1534 let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
1535 let result = parse_www_authenticate(header).unwrap();
1536 assert!(result.resource_metadata.is_some());
1537 }
1538
1539 #[test]
1540 fn test_parse_www_authenticate_multiline_style() {
1541 // Some servers emit the header spread across multiple lines joined by
1542 // whitespace, as shown in the spec examples.
1543 let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
1544 let result = parse_www_authenticate(header).unwrap();
1545 assert!(result.resource_metadata.is_some());
1546 assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
1547 }
1548
1549 #[test]
1550 fn test_protected_resource_metadata_urls_with_path() {
1551 let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
1552 let urls = protected_resource_metadata_urls(&server_url);
1553
1554 assert_eq!(urls.len(), 2);
1555 assert_eq!(
1556 urls[0].as_str(),
1557 "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
1558 );
1559 assert_eq!(
1560 urls[1].as_str(),
1561 "https://api.example.com/.well-known/oauth-protected-resource"
1562 );
1563 }
1564
1565 #[test]
1566 fn test_protected_resource_metadata_urls_without_path() {
1567 let server_url = Url::parse("https://mcp.example.com").unwrap();
1568 let urls = protected_resource_metadata_urls(&server_url);
1569
1570 assert_eq!(urls.len(), 1);
1571 assert_eq!(
1572 urls[0].as_str(),
1573 "https://mcp.example.com/.well-known/oauth-protected-resource"
1574 );
1575 }
1576
1577 #[test]
1578 fn test_auth_server_metadata_urls_with_path() {
1579 let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
1580 let urls = auth_server_metadata_urls(&issuer);
1581
1582 assert_eq!(urls.len(), 3);
1583 assert_eq!(
1584 urls[0].as_str(),
1585 "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
1586 );
1587 assert_eq!(
1588 urls[1].as_str(),
1589 "https://auth.example.com/.well-known/openid-configuration/tenant1"
1590 );
1591 assert_eq!(
1592 urls[2].as_str(),
1593 "https://auth.example.com/tenant1/.well-known/openid-configuration"
1594 );
1595 }
1596
1597 #[test]
1598 fn test_auth_server_metadata_urls_without_path() {
1599 let issuer = Url::parse("https://auth.example.com").unwrap();
1600 let urls = auth_server_metadata_urls(&issuer);
1601
1602 assert_eq!(urls.len(), 2);
1603 assert_eq!(
1604 urls[0].as_str(),
1605 "https://auth.example.com/.well-known/oauth-authorization-server"
1606 );
1607 assert_eq!(
1608 urls[1].as_str(),
1609 "https://auth.example.com/.well-known/openid-configuration"
1610 );
1611 }
1612
1613 // -- Canonical server URI tests ------------------------------------------
1614
1615 #[test]
1616 fn test_canonical_server_uri_simple() {
1617 let url = Url::parse("https://mcp.example.com").unwrap();
1618 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1619 }
1620
1621 #[test]
1622 fn test_canonical_server_uri_with_path() {
1623 let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
1624 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
1625 }
1626
1627 #[test]
1628 fn test_canonical_server_uri_strips_trailing_slash() {
1629 let url = Url::parse("https://mcp.example.com/").unwrap();
1630 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
1631 }
1632
1633 #[test]
1634 fn test_canonical_server_uri_preserves_port() {
1635 let url = Url::parse("https://mcp.example.com:8443").unwrap();
1636 assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
1637 }
1638
1639 #[test]
1640 fn test_canonical_server_uri_lowercases() {
1641 let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
1642 assert_eq!(
1643 canonical_server_uri(&url),
1644 "https://mcp.example.com/Server/MCP"
1645 );
1646 }
1647
1648 // -- Scope selection tests -----------------------------------------------
1649
1650 #[test]
1651 fn test_select_scopes_prefers_www_authenticate() {
1652 let www_auth = WwwAuthenticate {
1653 resource_metadata: None,
1654 scope: Some(vec!["files:read".into()]),
1655 error: None,
1656 error_description: None,
1657 };
1658 let resource_meta = ProtectedResourceMetadata {
1659 resource: Url::parse("https://example.com").unwrap(),
1660 authorization_servers: vec![],
1661 scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
1662 };
1663 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
1664 }
1665
1666 #[test]
1667 fn test_select_scopes_falls_back_to_resource_metadata() {
1668 let www_auth = WwwAuthenticate {
1669 resource_metadata: None,
1670 scope: None,
1671 error: None,
1672 error_description: None,
1673 };
1674 let resource_meta = ProtectedResourceMetadata {
1675 resource: Url::parse("https://example.com").unwrap(),
1676 authorization_servers: vec![],
1677 scopes_supported: Some(vec!["admin".into()]),
1678 };
1679 assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
1680 }
1681
1682 #[test]
1683 fn test_select_scopes_empty_when_nothing_available() {
1684 let www_auth = WwwAuthenticate {
1685 resource_metadata: None,
1686 scope: None,
1687 error: None,
1688 error_description: None,
1689 };
1690 let resource_meta = ProtectedResourceMetadata {
1691 resource: Url::parse("https://example.com").unwrap(),
1692 authorization_servers: vec![],
1693 scopes_supported: None,
1694 };
1695 assert!(select_scopes(&www_auth, &resource_meta).is_empty());
1696 }
1697
1698 // -- Client registration strategy tests ----------------------------------
1699
1700 #[test]
1701 fn test_registration_strategy_prefers_cimd() {
1702 let metadata = AuthServerMetadata {
1703 issuer: Url::parse("https://auth.example.com").unwrap(),
1704 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1705 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1706 registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
1707 scopes_supported: None,
1708 code_challenge_methods_supported: Some(vec!["S256".into()]),
1709 client_id_metadata_document_supported: true,
1710 };
1711 assert_eq!(
1712 determine_registration_strategy(&metadata),
1713 ClientRegistrationStrategy::Cimd {
1714 client_id: CIMD_URL.to_string(),
1715 }
1716 );
1717 }
1718
1719 #[test]
1720 fn test_registration_strategy_falls_back_to_dcr() {
1721 let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
1722 let metadata = AuthServerMetadata {
1723 issuer: Url::parse("https://auth.example.com").unwrap(),
1724 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1725 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1726 registration_endpoint: Some(reg_endpoint.clone()),
1727 scopes_supported: None,
1728 code_challenge_methods_supported: Some(vec!["S256".into()]),
1729 client_id_metadata_document_supported: false,
1730 };
1731 assert_eq!(
1732 determine_registration_strategy(&metadata),
1733 ClientRegistrationStrategy::Dcr {
1734 registration_endpoint: reg_endpoint,
1735 }
1736 );
1737 }
1738
1739 #[test]
1740 fn test_registration_strategy_unavailable() {
1741 let metadata = AuthServerMetadata {
1742 issuer: Url::parse("https://auth.example.com").unwrap(),
1743 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1744 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1745 registration_endpoint: None,
1746 scopes_supported: None,
1747 code_challenge_methods_supported: Some(vec!["S256".into()]),
1748 client_id_metadata_document_supported: false,
1749 };
1750 assert_eq!(
1751 determine_registration_strategy(&metadata),
1752 ClientRegistrationStrategy::Unavailable,
1753 );
1754 }
1755
1756 // -- PKCE tests ----------------------------------------------------------
1757
1758 #[test]
1759 fn test_pkce_challenge_verifier_length() {
1760 let pkce = generate_pkce_challenge();
1761 // 32 random bytes → 43 base64url chars (no padding).
1762 assert_eq!(pkce.verifier.len(), 43);
1763 }
1764
1765 #[test]
1766 fn test_pkce_challenge_is_valid_base64url() {
1767 let pkce = generate_pkce_challenge();
1768 for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
1769 assert!(
1770 c.is_ascii_alphanumeric() || c == '-' || c == '_',
1771 "invalid base64url character: {}",
1772 c
1773 );
1774 }
1775 }
1776
1777 #[test]
1778 fn test_pkce_challenge_is_s256_of_verifier() {
1779 let pkce = generate_pkce_challenge();
1780 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
1781 let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
1782 let expected_challenge = engine.encode(expected_digest);
1783 assert_eq!(pkce.challenge, expected_challenge);
1784 }
1785
1786 #[test]
1787 fn test_pkce_challenges_are_unique() {
1788 let a = generate_pkce_challenge();
1789 let b = generate_pkce_challenge();
1790 assert_ne!(a.verifier, b.verifier);
1791 }
1792
1793 // -- Authorization URL tests ---------------------------------------------
1794
1795 #[test]
1796 fn test_build_authorization_url() {
1797 let metadata = AuthServerMetadata {
1798 issuer: Url::parse("https://auth.example.com").unwrap(),
1799 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1800 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1801 registration_endpoint: None,
1802 scopes_supported: None,
1803 code_challenge_methods_supported: Some(vec!["S256".into()]),
1804 client_id_metadata_document_supported: true,
1805 };
1806 let pkce = PkceChallenge {
1807 verifier: "test_verifier".into(),
1808 challenge: "test_challenge".into(),
1809 };
1810 let url = build_authorization_url(
1811 &metadata,
1812 "https://zed.dev/oauth/client-metadata.json",
1813 "http://127.0.0.1:12345/callback",
1814 &["files:read".into(), "files:write".into()],
1815 "https://mcp.example.com",
1816 &pkce,
1817 "random_state_123",
1818 );
1819
1820 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1821 assert_eq!(pairs.get("response_type").unwrap(), "code");
1822 assert_eq!(
1823 pairs.get("client_id").unwrap(),
1824 "https://zed.dev/oauth/client-metadata.json"
1825 );
1826 assert_eq!(
1827 pairs.get("redirect_uri").unwrap(),
1828 "http://127.0.0.1:12345/callback"
1829 );
1830 assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
1831 assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
1832 assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
1833 assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
1834 assert_eq!(pairs.get("state").unwrap(), "random_state_123");
1835 }
1836
1837 #[test]
1838 fn test_build_authorization_url_omits_empty_scope() {
1839 let metadata = AuthServerMetadata {
1840 issuer: Url::parse("https://auth.example.com").unwrap(),
1841 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
1842 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
1843 registration_endpoint: None,
1844 scopes_supported: None,
1845 code_challenge_methods_supported: Some(vec!["S256".into()]),
1846 client_id_metadata_document_supported: false,
1847 };
1848 let pkce = PkceChallenge {
1849 verifier: "v".into(),
1850 challenge: "c".into(),
1851 };
1852 let url = build_authorization_url(
1853 &metadata,
1854 "client_123",
1855 "http://127.0.0.1:9999/callback",
1856 &[],
1857 "https://mcp.example.com",
1858 &pkce,
1859 "state",
1860 );
1861
1862 let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
1863 assert!(!pairs.contains_key("scope"));
1864 }
1865
1866 // -- Token exchange / refresh param tests --------------------------------
1867
1868 #[test]
1869 fn test_token_exchange_params() {
1870 let params = token_exchange_params(
1871 "auth_code_abc",
1872 "client_xyz",
1873 "http://127.0.0.1:5555/callback",
1874 "verifier_123",
1875 "https://mcp.example.com",
1876 );
1877 let map: std::collections::HashMap<&str, &str> =
1878 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1879
1880 assert_eq!(map["grant_type"], "authorization_code");
1881 assert_eq!(map["code"], "auth_code_abc");
1882 assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
1883 assert_eq!(map["client_id"], "client_xyz");
1884 assert_eq!(map["code_verifier"], "verifier_123");
1885 assert_eq!(map["resource"], "https://mcp.example.com");
1886 }
1887
1888 #[test]
1889 fn test_token_refresh_params() {
1890 let params =
1891 token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
1892 let map: std::collections::HashMap<&str, &str> =
1893 params.iter().map(|(k, v)| (*k, v.as_str())).collect();
1894
1895 assert_eq!(map["grant_type"], "refresh_token");
1896 assert_eq!(map["refresh_token"], "refresh_token_abc");
1897 assert_eq!(map["client_id"], "client_xyz");
1898 assert_eq!(map["resource"], "https://mcp.example.com");
1899 }
1900
1901 // -- Token response tests ------------------------------------------------
1902
1903 #[test]
1904 fn test_token_response_into_tokens_with_expiry() {
1905 let response: TokenResponse = serde_json::from_str(
1906 r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
1907 )
1908 .unwrap();
1909
1910 let tokens = response.into_tokens();
1911 assert_eq!(tokens.access_token, "at_123");
1912 assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
1913 assert!(tokens.expires_at.is_some());
1914 }
1915
1916 #[test]
1917 fn test_token_response_into_tokens_minimal() {
1918 let response: TokenResponse =
1919 serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
1920
1921 let tokens = response.into_tokens();
1922 assert_eq!(tokens.access_token, "at_789");
1923 assert_eq!(tokens.refresh_token, None);
1924 assert_eq!(tokens.expires_at, None);
1925 }
1926
1927 // -- DCR body test -------------------------------------------------------
1928
1929 #[test]
1930 fn test_dcr_registration_body_shape() {
1931 let body = dcr_registration_body("http://127.0.0.1:12345/callback");
1932 assert_eq!(body["client_name"], "Zed");
1933 assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
1934 assert_eq!(body["grant_types"][0], "authorization_code");
1935 assert_eq!(body["response_types"][0], "code");
1936 assert_eq!(body["token_endpoint_auth_method"], "none");
1937 }
1938
1939 // -- Test helpers for async/HTTP tests -----------------------------------
1940
1941 fn make_fake_http_client(
1942 handler: impl Fn(
1943 http_client::Request<AsyncBody>,
1944 ) -> std::pin::Pin<
1945 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1946 > + Send
1947 + Sync
1948 + 'static,
1949 ) -> Arc<dyn HttpClient> {
1950 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1951 }
1952
1953 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1954 Ok(Response::builder()
1955 .status(status)
1956 .header("Content-Type", "application/json")
1957 .body(AsyncBody::from(body.as_bytes().to_vec()))
1958 .unwrap())
1959 }
1960
1961 // -- Discovery integration tests -----------------------------------------
1962
1963 #[test]
1964 fn test_fetch_protected_resource_metadata() {
1965 smol::block_on(async {
1966 let client = make_fake_http_client(|req| {
1967 Box::pin(async move {
1968 let uri = req.uri().to_string();
1969 if uri.contains(".well-known/oauth-protected-resource") {
1970 json_response(
1971 200,
1972 r#"{
1973 "resource": "https://mcp.example.com",
1974 "authorization_servers": ["https://auth.example.com"],
1975 "scopes_supported": ["read", "write"]
1976 }"#,
1977 )
1978 } else {
1979 json_response(404, "{}")
1980 }
1981 })
1982 });
1983
1984 let server_url = Url::parse("https://mcp.example.com").unwrap();
1985 let www_auth = WwwAuthenticate {
1986 resource_metadata: None,
1987 scope: None,
1988 error: None,
1989 error_description: None,
1990 };
1991
1992 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
1993 .await
1994 .unwrap();
1995
1996 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
1997 assert_eq!(metadata.authorization_servers.len(), 1);
1998 assert_eq!(
1999 metadata.authorization_servers[0].as_str(),
2000 "https://auth.example.com/"
2001 );
2002 assert_eq!(
2003 metadata.scopes_supported,
2004 Some(vec!["read".to_string(), "write".to_string()])
2005 );
2006 });
2007 }
2008
2009 #[test]
2010 fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
2011 smol::block_on(async {
2012 let client = make_fake_http_client(|req| {
2013 Box::pin(async move {
2014 let uri = req.uri().to_string();
2015 if uri == "https://mcp.example.com/custom-resource-metadata" {
2016 json_response(
2017 200,
2018 r#"{
2019 "resource": "https://mcp.example.com",
2020 "authorization_servers": ["https://auth.example.com"]
2021 }"#,
2022 )
2023 } else {
2024 json_response(500, r#"{"error": "should not be called"}"#)
2025 }
2026 })
2027 });
2028
2029 let server_url = Url::parse("https://mcp.example.com").unwrap();
2030 let www_auth = WwwAuthenticate {
2031 resource_metadata: Some(
2032 Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
2033 ),
2034 scope: None,
2035 error: None,
2036 error_description: None,
2037 };
2038
2039 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2040 .await
2041 .unwrap();
2042
2043 assert_eq!(metadata.authorization_servers.len(), 1);
2044 });
2045 }
2046
2047 #[test]
2048 fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
2049 smol::block_on(async {
2050 let client = make_fake_http_client(|req| {
2051 Box::pin(async move {
2052 let uri = req.uri().to_string();
2053 // The cross-origin URL should NOT be fetched; only the
2054 // well-known fallback at the server's own origin should be.
2055 if uri.contains("attacker.example.com") {
2056 panic!("should not fetch cross-origin resource_metadata URL");
2057 } else if uri.contains(".well-known/oauth-protected-resource") {
2058 json_response(
2059 200,
2060 r#"{
2061 "resource": "https://mcp.example.com",
2062 "authorization_servers": ["https://auth.example.com"]
2063 }"#,
2064 )
2065 } else {
2066 json_response(404, "{}")
2067 }
2068 })
2069 });
2070
2071 let server_url = Url::parse("https://mcp.example.com").unwrap();
2072 let www_auth = WwwAuthenticate {
2073 resource_metadata: Some(
2074 Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
2075 ),
2076 scope: None,
2077 error: None,
2078 error_description: None,
2079 };
2080
2081 let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
2082 .await
2083 .unwrap();
2084
2085 // Should have used the fallback well-known URL, not the attacker's.
2086 assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
2087 });
2088 }
2089
2090 #[test]
2091 fn test_fetch_auth_server_metadata() {
2092 smol::block_on(async {
2093 let client = make_fake_http_client(|req| {
2094 Box::pin(async move {
2095 let uri = req.uri().to_string();
2096 if uri.contains(".well-known/oauth-authorization-server") {
2097 json_response(
2098 200,
2099 r#"{
2100 "issuer": "https://auth.example.com",
2101 "authorization_endpoint": "https://auth.example.com/authorize",
2102 "token_endpoint": "https://auth.example.com/token",
2103 "registration_endpoint": "https://auth.example.com/register",
2104 "code_challenge_methods_supported": ["S256"],
2105 "client_id_metadata_document_supported": true
2106 }"#,
2107 )
2108 } else {
2109 json_response(404, "{}")
2110 }
2111 })
2112 });
2113
2114 let issuer = Url::parse("https://auth.example.com").unwrap();
2115 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2116
2117 assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
2118 assert_eq!(
2119 metadata.authorization_endpoint.as_str(),
2120 "https://auth.example.com/authorize"
2121 );
2122 assert_eq!(
2123 metadata.token_endpoint.as_str(),
2124 "https://auth.example.com/token"
2125 );
2126 assert!(metadata.registration_endpoint.is_some());
2127 assert!(metadata.client_id_metadata_document_supported);
2128 assert_eq!(
2129 metadata.code_challenge_methods_supported,
2130 Some(vec!["S256".to_string()])
2131 );
2132 });
2133 }
2134
2135 #[test]
2136 fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
2137 smol::block_on(async {
2138 let client = make_fake_http_client(|req| {
2139 Box::pin(async move {
2140 let uri = req.uri().to_string();
2141 if uri.contains("openid-configuration") {
2142 json_response(
2143 200,
2144 r#"{
2145 "issuer": "https://auth.example.com",
2146 "authorization_endpoint": "https://auth.example.com/authorize",
2147 "token_endpoint": "https://auth.example.com/token",
2148 "code_challenge_methods_supported": ["S256"]
2149 }"#,
2150 )
2151 } else {
2152 json_response(404, "{}")
2153 }
2154 })
2155 });
2156
2157 let issuer = Url::parse("https://auth.example.com").unwrap();
2158 let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
2159
2160 assert_eq!(
2161 metadata.authorization_endpoint.as_str(),
2162 "https://auth.example.com/authorize"
2163 );
2164 assert!(!metadata.client_id_metadata_document_supported);
2165 });
2166 }
2167
2168 #[test]
2169 fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
2170 smol::block_on(async {
2171 let client = make_fake_http_client(|req| {
2172 Box::pin(async move {
2173 let uri = req.uri().to_string();
2174 if uri.contains(".well-known/oauth-authorization-server") {
2175 // Response claims to be a different issuer.
2176 json_response(
2177 200,
2178 r#"{
2179 "issuer": "https://evil.example.com",
2180 "authorization_endpoint": "https://evil.example.com/authorize",
2181 "token_endpoint": "https://evil.example.com/token",
2182 "code_challenge_methods_supported": ["S256"]
2183 }"#,
2184 )
2185 } else {
2186 json_response(404, "{}")
2187 }
2188 })
2189 });
2190
2191 let issuer = Url::parse("https://auth.example.com").unwrap();
2192 let result = fetch_auth_server_metadata(&client, &issuer).await;
2193
2194 assert!(result.is_err());
2195 let err_msg = result.unwrap_err().to_string();
2196 assert!(
2197 err_msg.contains("issuer mismatch"),
2198 "unexpected error: {}",
2199 err_msg
2200 );
2201 });
2202 }
2203
2204 // -- Full discover integration tests -------------------------------------
2205
2206 #[test]
2207 fn test_full_discover_with_cimd() {
2208 smol::block_on(async {
2209 let client = make_fake_http_client(|req| {
2210 Box::pin(async move {
2211 let uri = req.uri().to_string();
2212 if uri.contains("oauth-protected-resource") {
2213 json_response(
2214 200,
2215 r#"{
2216 "resource": "https://mcp.example.com",
2217 "authorization_servers": ["https://auth.example.com"],
2218 "scopes_supported": ["mcp:read"]
2219 }"#,
2220 )
2221 } else if uri.contains("oauth-authorization-server") {
2222 json_response(
2223 200,
2224 r#"{
2225 "issuer": "https://auth.example.com",
2226 "authorization_endpoint": "https://auth.example.com/authorize",
2227 "token_endpoint": "https://auth.example.com/token",
2228 "code_challenge_methods_supported": ["S256"],
2229 "client_id_metadata_document_supported": true
2230 }"#,
2231 )
2232 } else {
2233 json_response(404, "{}")
2234 }
2235 })
2236 });
2237
2238 let server_url = Url::parse("https://mcp.example.com").unwrap();
2239 let www_auth = WwwAuthenticate {
2240 resource_metadata: None,
2241 scope: None,
2242 error: None,
2243 error_description: None,
2244 };
2245
2246 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2247 let registration =
2248 resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
2249 .await
2250 .unwrap();
2251
2252 assert_eq!(registration.client_id, CIMD_URL);
2253 assert_eq!(registration.client_secret, None);
2254 assert_eq!(discovery.scopes, vec!["mcp:read"]);
2255 });
2256 }
2257
2258 #[test]
2259 fn test_full_discover_with_dcr_fallback() {
2260 smol::block_on(async {
2261 let client = make_fake_http_client(|req| {
2262 Box::pin(async move {
2263 let uri = req.uri().to_string();
2264 if uri.contains("oauth-protected-resource") {
2265 json_response(
2266 200,
2267 r#"{
2268 "resource": "https://mcp.example.com",
2269 "authorization_servers": ["https://auth.example.com"]
2270 }"#,
2271 )
2272 } else if uri.contains("oauth-authorization-server") {
2273 json_response(
2274 200,
2275 r#"{
2276 "issuer": "https://auth.example.com",
2277 "authorization_endpoint": "https://auth.example.com/authorize",
2278 "token_endpoint": "https://auth.example.com/token",
2279 "registration_endpoint": "https://auth.example.com/register",
2280 "code_challenge_methods_supported": ["S256"],
2281 "client_id_metadata_document_supported": false
2282 }"#,
2283 )
2284 } else if uri.contains("/register") {
2285 json_response(
2286 201,
2287 r#"{
2288 "client_id": "dcr-minted-id-123",
2289 "client_secret": "dcr-secret-456"
2290 }"#,
2291 )
2292 } else {
2293 json_response(404, "{}")
2294 }
2295 })
2296 });
2297
2298 let server_url = Url::parse("https://mcp.example.com").unwrap();
2299 let www_auth = WwwAuthenticate {
2300 resource_metadata: None,
2301 scope: Some(vec!["files:read".into()]),
2302 error: None,
2303 error_description: None,
2304 };
2305
2306 let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
2307 let registration =
2308 resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
2309 .await
2310 .unwrap();
2311
2312 assert_eq!(registration.client_id, "dcr-minted-id-123");
2313 assert_eq!(
2314 registration.client_secret.as_deref(),
2315 Some("dcr-secret-456")
2316 );
2317 assert_eq!(discovery.scopes, vec!["files:read"]);
2318 });
2319 }
2320
2321 #[test]
2322 fn test_discover_fails_without_pkce_support() {
2323 smol::block_on(async {
2324 let client = make_fake_http_client(|req| {
2325 Box::pin(async move {
2326 let uri = req.uri().to_string();
2327 if uri.contains("oauth-protected-resource") {
2328 json_response(
2329 200,
2330 r#"{
2331 "resource": "https://mcp.example.com",
2332 "authorization_servers": ["https://auth.example.com"]
2333 }"#,
2334 )
2335 } else if uri.contains("oauth-authorization-server") {
2336 json_response(
2337 200,
2338 r#"{
2339 "issuer": "https://auth.example.com",
2340 "authorization_endpoint": "https://auth.example.com/authorize",
2341 "token_endpoint": "https://auth.example.com/token"
2342 }"#,
2343 )
2344 } else {
2345 json_response(404, "{}")
2346 }
2347 })
2348 });
2349
2350 let server_url = Url::parse("https://mcp.example.com").unwrap();
2351 let www_auth = WwwAuthenticate {
2352 resource_metadata: None,
2353 scope: None,
2354 error: None,
2355 error_description: None,
2356 };
2357
2358 let result = discover(&client, &server_url, &www_auth).await;
2359 assert!(result.is_err());
2360 let err_msg = result.unwrap_err().to_string();
2361 assert!(
2362 err_msg.contains("code_challenge_methods_supported"),
2363 "unexpected error: {}",
2364 err_msg
2365 );
2366 });
2367 }
2368
2369 // -- Token exchange integration tests ------------------------------------
2370
2371 #[test]
2372 fn test_exchange_code_success() {
2373 smol::block_on(async {
2374 let client = make_fake_http_client(|req| {
2375 Box::pin(async move {
2376 let uri = req.uri().to_string();
2377 if uri.contains("/token") {
2378 json_response(
2379 200,
2380 r#"{
2381 "access_token": "new_access_token",
2382 "refresh_token": "new_refresh_token",
2383 "expires_in": 3600,
2384 "token_type": "Bearer"
2385 }"#,
2386 )
2387 } else {
2388 json_response(404, "{}")
2389 }
2390 })
2391 });
2392
2393 let metadata = AuthServerMetadata {
2394 issuer: Url::parse("https://auth.example.com").unwrap(),
2395 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2396 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2397 registration_endpoint: None,
2398 scopes_supported: None,
2399 code_challenge_methods_supported: Some(vec!["S256".into()]),
2400 client_id_metadata_document_supported: true,
2401 };
2402
2403 let tokens = exchange_code(
2404 &client,
2405 &metadata,
2406 "auth_code_123",
2407 CIMD_URL,
2408 "http://127.0.0.1:9999/callback",
2409 "verifier_abc",
2410 "https://mcp.example.com",
2411 )
2412 .await
2413 .unwrap();
2414
2415 assert_eq!(tokens.access_token, "new_access_token");
2416 assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
2417 assert!(tokens.expires_at.is_some());
2418 });
2419 }
2420
2421 #[test]
2422 fn test_refresh_tokens_success() {
2423 smol::block_on(async {
2424 let client = make_fake_http_client(|req| {
2425 Box::pin(async move {
2426 let uri = req.uri().to_string();
2427 if uri.contains("/token") {
2428 json_response(
2429 200,
2430 r#"{
2431 "access_token": "refreshed_token",
2432 "expires_in": 1800,
2433 "token_type": "Bearer"
2434 }"#,
2435 )
2436 } else {
2437 json_response(404, "{}")
2438 }
2439 })
2440 });
2441
2442 let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
2443
2444 let tokens = refresh_tokens(
2445 &client,
2446 &token_endpoint,
2447 "old_refresh_token",
2448 CIMD_URL,
2449 "https://mcp.example.com",
2450 )
2451 .await
2452 .unwrap();
2453
2454 assert_eq!(tokens.access_token, "refreshed_token");
2455 assert_eq!(tokens.refresh_token, None);
2456 assert!(tokens.expires_at.is_some());
2457 });
2458 }
2459
2460 #[test]
2461 fn test_exchange_code_failure() {
2462 smol::block_on(async {
2463 let client = make_fake_http_client(|_req| {
2464 Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
2465 });
2466
2467 let metadata = AuthServerMetadata {
2468 issuer: Url::parse("https://auth.example.com").unwrap(),
2469 authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
2470 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2471 registration_endpoint: None,
2472 scopes_supported: None,
2473 code_challenge_methods_supported: Some(vec!["S256".into()]),
2474 client_id_metadata_document_supported: true,
2475 };
2476
2477 let result = exchange_code(
2478 &client,
2479 &metadata,
2480 "bad_code",
2481 "client",
2482 "http://127.0.0.1:1/callback",
2483 "verifier",
2484 "https://mcp.example.com",
2485 )
2486 .await;
2487
2488 assert!(result.is_err());
2489 assert!(result.unwrap_err().to_string().contains("400"));
2490 });
2491 }
2492
2493 // -- DCR integration tests -----------------------------------------------
2494
2495 #[test]
2496 fn test_perform_dcr() {
2497 smol::block_on(async {
2498 let client = make_fake_http_client(|_req| {
2499 Box::pin(async move {
2500 json_response(
2501 201,
2502 r#"{
2503 "client_id": "dynamic-client-001",
2504 "client_secret": "dynamic-secret-001"
2505 }"#,
2506 )
2507 })
2508 });
2509
2510 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2511 let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
2512 .await
2513 .unwrap();
2514
2515 assert_eq!(registration.client_id, "dynamic-client-001");
2516 assert_eq!(
2517 registration.client_secret.as_deref(),
2518 Some("dynamic-secret-001")
2519 );
2520 });
2521 }
2522
2523 #[test]
2524 fn test_perform_dcr_failure() {
2525 smol::block_on(async {
2526 let client = make_fake_http_client(|_req| {
2527 Box::pin(
2528 async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
2529 )
2530 });
2531
2532 let endpoint = Url::parse("https://auth.example.com/register").unwrap();
2533 let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
2534
2535 assert!(result.is_err());
2536 assert!(result.unwrap_err().to_string().contains("403"));
2537 });
2538 }
2539
2540 // -- OAuthCallback parse tests -------------------------------------------
2541
2542 #[test]
2543 fn test_oauth_callback_parse_query() {
2544 let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
2545 assert_eq!(callback.code, "test_auth_code");
2546 assert_eq!(callback.state, "test_state");
2547 }
2548
2549 #[test]
2550 fn test_oauth_callback_parse_query_reversed_order() {
2551 let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
2552 assert_eq!(callback.code, "test_auth_code");
2553 assert_eq!(callback.state, "test_state");
2554 }
2555
2556 #[test]
2557 fn test_oauth_callback_parse_query_with_extra_params() {
2558 let callback =
2559 OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
2560 .unwrap();
2561 assert_eq!(callback.code, "test_auth_code");
2562 assert_eq!(callback.state, "test_state");
2563 }
2564
2565 #[test]
2566 fn test_oauth_callback_parse_query_missing_code() {
2567 let result = OAuthCallback::parse_query("state=test_state");
2568 assert!(result.is_err());
2569 assert!(result.unwrap_err().to_string().contains("code"));
2570 }
2571
2572 #[test]
2573 fn test_oauth_callback_parse_query_missing_state() {
2574 let result = OAuthCallback::parse_query("code=test_auth_code");
2575 assert!(result.is_err());
2576 assert!(result.unwrap_err().to_string().contains("state"));
2577 }
2578
2579 #[test]
2580 fn test_oauth_callback_parse_query_empty_code() {
2581 let result = OAuthCallback::parse_query("code=&state=test_state");
2582 assert!(result.is_err());
2583 }
2584
2585 #[test]
2586 fn test_oauth_callback_parse_query_empty_state() {
2587 let result = OAuthCallback::parse_query("code=test_auth_code&state=");
2588 assert!(result.is_err());
2589 }
2590
2591 #[test]
2592 fn test_oauth_callback_parse_query_url_encoded_values() {
2593 let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
2594 assert_eq!(callback.code, "abc def");
2595 assert_eq!(callback.state, "test=state");
2596 }
2597
2598 #[test]
2599 fn test_oauth_callback_parse_query_error_response() {
2600 let result = OAuthCallback::parse_query(
2601 "error=access_denied&error_description=User%20denied%20access&state=abc",
2602 );
2603 assert!(result.is_err());
2604 let err_msg = result.unwrap_err().to_string();
2605 assert!(
2606 err_msg.contains("access_denied"),
2607 "unexpected error: {}",
2608 err_msg
2609 );
2610 assert!(
2611 err_msg.contains("User denied access"),
2612 "unexpected error: {}",
2613 err_msg
2614 );
2615 }
2616
2617 #[test]
2618 fn test_oauth_callback_parse_query_error_without_description() {
2619 let result = OAuthCallback::parse_query("error=server_error&state=abc");
2620 assert!(result.is_err());
2621 let err_msg = result.unwrap_err().to_string();
2622 assert!(
2623 err_msg.contains("server_error"),
2624 "unexpected error: {}",
2625 err_msg
2626 );
2627 assert!(
2628 err_msg.contains("no description"),
2629 "unexpected error: {}",
2630 err_msg
2631 );
2632 }
2633
2634 // -- McpOAuthTokenProvider tests -----------------------------------------
2635
2636 fn make_test_session(
2637 access_token: &str,
2638 refresh_token: Option<&str>,
2639 expires_at: Option<SystemTime>,
2640 ) -> OAuthSession {
2641 OAuthSession {
2642 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2643 resource: Url::parse("https://mcp.example.com").unwrap(),
2644 client_registration: OAuthClientRegistration {
2645 client_id: "test-client".into(),
2646 client_secret: None,
2647 },
2648 tokens: OAuthTokens {
2649 access_token: access_token.into(),
2650 refresh_token: refresh_token.map(String::from),
2651 expires_at,
2652 },
2653 }
2654 }
2655
2656 #[test]
2657 fn test_mcp_oauth_provider_returns_none_when_token_expired() {
2658 let expired = SystemTime::now() - Duration::from_secs(60);
2659 let session = make_test_session("stale-token", Some("rt"), Some(expired));
2660 let provider = McpOAuthTokenProvider::new(
2661 session,
2662 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2663 None,
2664 );
2665
2666 assert_eq!(provider.access_token(), None);
2667 }
2668
2669 #[test]
2670 fn test_mcp_oauth_provider_returns_token_when_not_expired() {
2671 let far_future = SystemTime::now() + Duration::from_secs(3600);
2672 let session = make_test_session("valid-token", Some("rt"), Some(far_future));
2673 let provider = McpOAuthTokenProvider::new(
2674 session,
2675 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2676 None,
2677 );
2678
2679 assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
2680 }
2681
2682 #[test]
2683 fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
2684 let session = make_test_session("no-expiry-token", Some("rt"), None);
2685 let provider = McpOAuthTokenProvider::new(
2686 session,
2687 make_fake_http_client(|_| Box::pin(async { unreachable!() })),
2688 None,
2689 );
2690
2691 assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
2692 }
2693
2694 #[test]
2695 fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
2696 smol::block_on(async {
2697 let session = make_test_session("token", None, None);
2698 let provider = McpOAuthTokenProvider::new(
2699 session,
2700 make_fake_http_client(|_| {
2701 Box::pin(async { unreachable!("no HTTP call expected") })
2702 }),
2703 None,
2704 );
2705
2706 let refreshed = provider.try_refresh().await.unwrap();
2707 assert!(!refreshed);
2708 });
2709 }
2710
2711 #[test]
2712 fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
2713 smol::block_on(async {
2714 let session = make_test_session("old-access", Some("my-refresh-token"), None);
2715 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2716
2717 let http_client = make_fake_http_client(|_req| {
2718 Box::pin(async {
2719 json_response(
2720 200,
2721 r#"{
2722 "access_token": "new-access",
2723 "refresh_token": "new-refresh",
2724 "expires_in": 1800
2725 }"#,
2726 )
2727 })
2728 });
2729
2730 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2731
2732 let refreshed = provider.try_refresh().await.unwrap();
2733 assert!(refreshed);
2734 assert_eq!(provider.access_token().as_deref(), Some("new-access"));
2735
2736 let notified_session = rx.try_recv().expect("channel should have a session");
2737 assert_eq!(notified_session.tokens.access_token, "new-access");
2738 assert_eq!(
2739 notified_session.tokens.refresh_token.as_deref(),
2740 Some("new-refresh")
2741 );
2742 });
2743 }
2744
2745 #[test]
2746 fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
2747 smol::block_on(async {
2748 let session = make_test_session("old-access", Some("original-refresh"), None);
2749 let (tx, mut rx) = futures::channel::mpsc::unbounded();
2750
2751 let http_client = make_fake_http_client(|_req| {
2752 Box::pin(async {
2753 json_response(
2754 200,
2755 r#"{
2756 "access_token": "new-access",
2757 "expires_in": 900
2758 }"#,
2759 )
2760 })
2761 });
2762
2763 let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
2764
2765 let refreshed = provider.try_refresh().await.unwrap();
2766 assert!(refreshed);
2767
2768 let notified_session = rx.try_recv().expect("channel should have a session");
2769 assert_eq!(notified_session.tokens.access_token, "new-access");
2770 assert_eq!(
2771 notified_session.tokens.refresh_token.as_deref(),
2772 Some("original-refresh"),
2773 );
2774 });
2775 }
2776
2777 #[test]
2778 fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
2779 smol::block_on(async {
2780 let session = make_test_session("old-access", Some("my-refresh"), None);
2781
2782 let http_client = make_fake_http_client(|_req| {
2783 Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
2784 });
2785
2786 let provider = McpOAuthTokenProvider::new(session, http_client, None);
2787
2788 let refreshed = provider.try_refresh().await.unwrap();
2789 assert!(!refreshed);
2790 // The old token should still be in place.
2791 assert_eq!(provider.access_token().as_deref(), Some("old-access"));
2792 });
2793 }
2794}