1//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
2//! PKCE flow, per the MCP spec's OAuth profile.
3//!
4//! The flow is split into two phases:
5//!
6//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
7//! Authorization Server Metadata. This can happen early (e.g. on a 401
8//! during server startup) because it doesn't need the redirect URI yet.
9//!
10//! 2. **Client registration** ([`resolve_client_registration`]) is separate
11//! because DCR requires the actual loopback redirect URI, which includes an
12//! ephemeral port that only exists once the callback server has started.
13//!
14//! After authentication, the full state is captured in [`OAuthSession`] which
15//! is persisted to the keychain. On next startup, the stored session feeds
16//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
17//! without requiring another browser flow.
18
19use anyhow::{Context as _, Result, anyhow, bail};
20use async_trait::async_trait;
21use base64::Engine as _;
22use futures::AsyncReadExt as _;
23use futures::channel::mpsc;
24use http_client::{AsyncBody, HttpClient, Request};
25use parking_lot::Mutex as SyncMutex;
26use rand::Rng as _;
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29
30use std::str::FromStr;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime};
33use url::Url;
34use util::ResultExt as _;
35
36/// The CIMD URL where Zed's OAuth client metadata document is hosted.
37pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
38
39/// Validate that a URL is safe to use as an OAuth endpoint.
40///
41/// OAuth endpoints carry sensitive material (authorization codes, PKCE
42/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
43/// loopback addresses, per RFC 8252 Section 8.3.
44fn require_https_or_loopback(url: &Url) -> Result<()> {
45 if url.scheme() == "https" {
46 return Ok(());
47 }
48 if url.scheme() == "http" {
49 if let Some(host) = url.host() {
50 match host {
51 url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
52 url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
53 url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
54 _ => {}
55 }
56 }
57 }
58 bail!(
59 "OAuth endpoint must use HTTPS (got {}://{})",
60 url.scheme(),
61 url.host_str().unwrap_or("?")
62 )
63}
64
65/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
66/// protections against private/reserved IP ranges.
67///
68/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
69/// an attacker-controlled MCP server from directing Zed to fetch internal
70/// network resources via metadata URLs.
71///
72/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
73/// blocked here — full mitigation requires resolver-level validation (e.g. a
74/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
75fn validate_oauth_url(url: &Url) -> Result<()> {
76 require_https_or_loopback(url)?;
77
78 if let Some(host) = url.host() {
79 match host {
80 url::Host::Ipv4(ip) => {
81 // Loopback is already allowed by require_https_or_loopback.
82 if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
83 {
84 bail!(
85 "OAuth endpoint must not point to private/reserved IP: {}",
86 ip
87 );
88 }
89 }
90 url::Host::Ipv6(ip) => {
91 // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
92 // could bypass the IPv4 checks above.
93 if let Some(mapped_v4) = ip.to_ipv4_mapped() {
94 if mapped_v4.is_private()
95 || mapped_v4.is_link_local()
96 || mapped_v4.is_broadcast()
97 || mapped_v4.is_unspecified()
98 {
99 bail!(
100 "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
101 mapped_v4
102 );
103 }
104 }
105
106 if ip.is_unspecified() || ip.is_multicast() {
107 bail!(
108 "OAuth endpoint must not point to reserved IPv6 address: {}",
109 ip
110 );
111 }
112 // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
113 // nightly-only, so check the prefix manually.
114 if (ip.segments()[0] & 0xfe00) == 0xfc00 {
115 bail!(
116 "OAuth endpoint must not point to IPv6 unique-local address: {}",
117 ip
118 );
119 }
120 }
121 url::Host::Domain(_) => {
122 // Domain-based SSRF prevention requires resolver-level checks.
123 // See known limitation in the doc comment above.
124 }
125 }
126 }
127
128 Ok(())
129}
130
131/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
132/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct ProtectedResourceMetadata {
135 pub resource: Url,
136 pub authorization_servers: Vec<Url>,
137 pub scopes_supported: Option<Vec<String>>,
138}
139
140/// Parsed from the authorization server's .well-known endpoint
141/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct AuthServerMetadata {
144 pub issuer: Url,
145 pub authorization_endpoint: Url,
146 pub token_endpoint: Url,
147 pub registration_endpoint: Option<Url>,
148 pub scopes_supported: Option<Vec<String>>,
149 pub code_challenge_methods_supported: Option<Vec<String>>,
150 pub client_id_metadata_document_supported: bool,
151}
152
153/// The result of client registration — either CIMD or DCR.
154#[derive(Clone, Serialize, Deserialize)]
155pub struct OAuthClientRegistration {
156 pub client_id: String,
157 /// Only present for DCR-minted registrations.
158 pub client_secret: Option<String>,
159}
160
161impl std::fmt::Debug for OAuthClientRegistration {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("OAuthClientRegistration")
164 .field("client_id", &self.client_id)
165 .field(
166 "client_secret",
167 &self.client_secret.as_ref().map(|_| "[redacted]"),
168 )
169 .finish()
170 }
171}
172
173/// Access and refresh tokens obtained from the token endpoint.
174#[derive(Clone, Serialize, Deserialize)]
175pub struct OAuthTokens {
176 pub access_token: String,
177 pub refresh_token: Option<String>,
178 pub expires_at: Option<SystemTime>,
179}
180
181impl std::fmt::Debug for OAuthTokens {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("OAuthTokens")
184 .field("access_token", &"[redacted]")
185 .field(
186 "refresh_token",
187 &self.refresh_token.as_ref().map(|_| "[redacted]"),
188 )
189 .field("expires_at", &self.expires_at)
190 .finish()
191 }
192}
193
194/// Everything discovered before the browser flow starts. Client registration is
195/// resolved separately, once the real redirect URI is known.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct OAuthDiscovery {
198 pub resource_metadata: ProtectedResourceMetadata,
199 pub auth_server_metadata: AuthServerMetadata,
200 pub scopes: Vec<String>,
201}
202
203/// The persisted OAuth session for a context server.
204///
205/// Stored in the keychain so startup can restore a refresh-capable provider
206/// without another browser flow. Deliberately excludes the full discovery
207/// metadata to keep the serialized size well within keychain item limits.
208#[derive(Clone, Serialize, Deserialize)]
209pub struct OAuthSession {
210 pub token_endpoint: Url,
211 pub resource: Url,
212 pub client_registration: OAuthClientRegistration,
213 pub tokens: OAuthTokens,
214}
215
216impl std::fmt::Debug for OAuthSession {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("OAuthSession")
219 .field("token_endpoint", &self.token_endpoint)
220 .field("resource", &self.resource)
221 .field("client_registration", &self.client_registration)
222 .field("tokens", &self.tokens)
223 .finish()
224 }
225}
226
/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerError {
    /// The request is missing a required parameter, includes an unsupported
    /// parameter or parameter value, or is otherwise malformed.
    InvalidRequest,
    /// The access token provided is expired, revoked, malformed, or invalid.
    InvalidToken,
    /// The request requires higher privileges than provided by the access token.
    InsufficientScope,
    /// An unrecognized error code (extension or future spec addition).
    Other,
}

impl BearerError {
    /// Map the value of an `error` parameter to its variant.
    ///
    /// Matching is exact and case-sensitive; anything that is not one of the
    /// three known codes becomes [`BearerError::Other`].
    fn parse(value: &str) -> Self {
        const KNOWN_CODES: [(&str, BearerError); 3] = [
            ("invalid_request", BearerError::InvalidRequest),
            ("invalid_token", BearerError::InvalidToken),
            ("insufficient_scope", BearerError::InsufficientScope),
        ];

        KNOWN_CODES
            .iter()
            .find(|(code, _)| *code == value)
            .map_or(BearerError::Other, |&(_, kind)| kind)
    }
}
251
252/// Fields extracted from a `WWW-Authenticate: Bearer` header.
253///
254/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
255/// at the Protected Resource Metadata document. The optional `scope` parameter
256/// (RFC 6750 Section 3) indicates scopes required for the request.
257#[derive(Debug, Clone, PartialEq, Eq)]
258pub struct WwwAuthenticate {
259 pub resource_metadata: Option<Url>,
260 pub scope: Option<Vec<String>>,
261 /// The parsed `error` parameter per RFC 6750 Section 3.1.
262 pub error: Option<BearerError>,
263 pub error_description: Option<String>,
264}
265
266/// Parse a `WWW-Authenticate` header value.
267///
268/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
269/// Per RFC 6750 and RFC 9728, the relevant parameters are:
270/// - `resource_metadata` — URL of the Protected Resource Metadata document
271/// - `scope` — space-separated list of required scopes
272/// - `error` — error code (e.g. "insufficient_scope")
273/// - `error_description` — human-readable error description
274pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
275 let header = header.trim();
276
277 let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
278 header[6..].trim()
279 } else {
280 bail!("WWW-Authenticate header does not use Bearer scheme");
281 };
282
283 if params_str.is_empty() {
284 return Ok(WwwAuthenticate {
285 resource_metadata: None,
286 scope: None,
287 error: None,
288 error_description: None,
289 });
290 }
291
292 let params = parse_auth_params(params_str);
293
294 let resource_metadata = params
295 .get("resource_metadata")
296 .map(|v| Url::parse(v))
297 .transpose()
298 .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
299
300 let scope = params
301 .get("scope")
302 .map(|v| v.split_whitespace().map(String::from).collect());
303
304 let error = params.get("error").map(|v| BearerError::parse(v));
305 let error_description = params.get("error_description").cloned();
306
307 Ok(WwwAuthenticate {
308 resource_metadata,
309 scope,
310 error,
311 error_description,
312 })
313}
314
315/// Parse comma-separated `key="value"` or `key=token` parameters from an
316/// auth-param list (RFC 7235 Section 2.1).
317fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
318 let mut params = collections::HashMap::default();
319 let mut remaining = input.trim();
320
321 while !remaining.is_empty() {
322 // Skip leading whitespace and commas.
323 remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
324 if remaining.is_empty() {
325 break;
326 }
327
328 // Find the key (everything before '=').
329 let eq_pos = match remaining.find('=') {
330 Some(pos) => pos,
331 None => break,
332 };
333
334 let key = remaining[..eq_pos].trim().to_lowercase();
335 remaining = &remaining[eq_pos + 1..];
336 remaining = remaining.trim_start();
337
338 // Parse the value: either quoted or unquoted (token).
339 let value;
340 if remaining.starts_with('"') {
341 // Quoted string: find the closing quote, handling escaped chars.
342 remaining = &remaining[1..]; // skip opening quote
343 let mut val = String::new();
344 let mut chars = remaining.char_indices();
345 loop {
346 match chars.next() {
347 Some((_, '\\')) => {
348 // Escaped character — take the next char literally.
349 if let Some((_, c)) = chars.next() {
350 val.push(c);
351 }
352 }
353 Some((i, '"')) => {
354 remaining = &remaining[i + 1..];
355 break;
356 }
357 Some((_, c)) => val.push(c),
358 None => {
359 remaining = "";
360 break;
361 }
362 }
363 }
364 value = val;
365 } else {
366 // Unquoted token: read until comma or whitespace.
367 let end = remaining
368 .find(|c: char| c == ',' || c.is_whitespace())
369 .unwrap_or(remaining.len());
370 value = remaining[..end].to_string();
371 remaining = &remaining[end..];
372 }
373
374 if !key.is_empty() {
375 params.insert(key, value);
376 }
377 }
378
379 params
380}
381
382/// Construct the well-known Protected Resource Metadata URIs for a given MCP
383/// server URL, per RFC 9728 Section 3.
384///
385/// Returns URIs in priority order:
386/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
387/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
388pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
389 let mut urls = Vec::new();
390 let base = format!("{}://{}", server_url.scheme(), server_url.authority());
391
392 let path = server_url.path().trim_start_matches('/');
393 if !path.is_empty() {
394 if let Ok(url) = Url::parse(&format!(
395 "{}/.well-known/oauth-protected-resource/{}",
396 base, path
397 )) {
398 urls.push(url);
399 }
400 }
401
402 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
403 urls.push(url);
404 }
405
406 urls
407}
408
409/// Construct the well-known Authorization Server Metadata URIs for a given
410/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
411///
412/// Returns URIs in priority order, which differs depending on whether the
413/// issuer URL has a path component.
414pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
415 let mut urls = Vec::new();
416 let base = format!("{}://{}", issuer.scheme(), issuer.authority());
417 let path = issuer.path().trim_matches('/');
418
419 if !path.is_empty() {
420 // Issuer with path: try path-inserted variants first.
421 if let Ok(url) = Url::parse(&format!(
422 "{}/.well-known/oauth-authorization-server/{}",
423 base, path
424 )) {
425 urls.push(url);
426 }
427 if let Ok(url) = Url::parse(&format!(
428 "{}/.well-known/openid-configuration/{}",
429 base, path
430 )) {
431 urls.push(url);
432 }
433 if let Ok(url) = Url::parse(&format!(
434 "{}/{}/.well-known/openid-configuration",
435 base, path
436 )) {
437 urls.push(url);
438 }
439 } else {
440 // No path: standard well-known locations.
441 if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
442 urls.push(url);
443 }
444 if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
445 urls.push(url);
446 }
447 }
448
449 urls
450}
451
452// -- Canonical server URI (RFC 8707) -----------------------------------------
453
454/// Derive the canonical resource URI for an MCP server URL, suitable for the
455/// `resource` parameter in authorization and token requests per RFC 8707.
456///
457/// Lowercases the scheme and host, preserves the path (without trailing slash),
458/// strips fragments and query strings.
459pub fn canonical_server_uri(server_url: &Url) -> String {
460 let mut uri = format!(
461 "{}://{}",
462 server_url.scheme().to_ascii_lowercase(),
463 server_url.host_str().unwrap_or("").to_ascii_lowercase(),
464 );
465 if let Some(port) = server_url.port() {
466 uri.push_str(&format!(":{}", port));
467 }
468 let path = server_url.path();
469 if path != "/" {
470 uri.push_str(path.trim_end_matches('/'));
471 }
472 uri
473}
474
475// -- Scope selection ---------------------------------------------------------
476
477/// Select scopes following the MCP spec's Scope Selection Strategy:
478/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
479/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
480/// 3. Return empty if neither is available.
481pub fn select_scopes(
482 www_authenticate: &WwwAuthenticate,
483 resource_metadata: &ProtectedResourceMetadata,
484) -> Vec<String> {
485 if let Some(ref scopes) = www_authenticate.scope {
486 if !scopes.is_empty() {
487 return scopes.clone();
488 }
489 }
490 resource_metadata
491 .scopes_supported
492 .clone()
493 .unwrap_or_default()
494}
495
496// -- Client registration strategy --------------------------------------------
497
498/// The registration approach to use, determined from auth server metadata.
499#[derive(Debug, Clone, PartialEq, Eq)]
500pub enum ClientRegistrationStrategy {
501 /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
502 Cimd { client_id: String },
503 /// The auth server has a registration endpoint. Caller must POST to it.
504 Dcr { registration_endpoint: Url },
505 /// No supported registration mechanism.
506 Unavailable,
507}
508
509/// Determine how to register with the authorization server, following the
510/// spec's recommended priority: CIMD first, DCR fallback.
511pub fn determine_registration_strategy(
512 auth_server_metadata: &AuthServerMetadata,
513) -> ClientRegistrationStrategy {
514 if auth_server_metadata.client_id_metadata_document_supported {
515 ClientRegistrationStrategy::Cimd {
516 client_id: CIMD_URL.to_string(),
517 }
518 } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
519 ClientRegistrationStrategy::Dcr {
520 registration_endpoint: endpoint.clone(),
521 }
522 } else {
523 ClientRegistrationStrategy::Unavailable
524 }
525}
526
527// -- PKCE (RFC 7636) ---------------------------------------------------------
528
/// A PKCE code verifier and its S256 challenge.
#[derive(Clone)]
pub struct PkceChallenge {
    /// The secret code verifier, sent only in the token exchange.
    pub verifier: String,
    /// The challenge derived from the verifier, sent in the authorization URL.
    pub challenge: String,
}

// Hand-written so the verifier is never printed; the challenge is shown.
impl std::fmt::Debug for PkceChallenge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PkceChallenge");
        out.field("verifier", &"[redacted]");
        out.field("challenge", &self.challenge);
        out.finish()
    }
}
544
545/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
546///
547/// The verifier is 43 base64url characters derived from 32 random bytes.
548/// The challenge is `BASE64URL(SHA256(verifier))`.
549pub fn generate_pkce_challenge() -> PkceChallenge {
550 let mut random_bytes = [0u8; 32];
551 rand::rng().fill(&mut random_bytes);
552 let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
553 let verifier = engine.encode(&random_bytes);
554
555 let digest = Sha256::digest(verifier.as_bytes());
556 let challenge = engine.encode(digest);
557
558 PkceChallenge {
559 verifier,
560 challenge,
561 }
562}
563
564// -- Authorization URL construction ------------------------------------------
565
566/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
567pub fn build_authorization_url(
568 auth_server_metadata: &AuthServerMetadata,
569 client_id: &str,
570 redirect_uri: &str,
571 scopes: &[String],
572 resource: &str,
573 pkce: &PkceChallenge,
574 state: &str,
575) -> Url {
576 let mut url = auth_server_metadata.authorization_endpoint.clone();
577 {
578 let mut query = url.query_pairs_mut();
579 query.append_pair("response_type", "code");
580 query.append_pair("client_id", client_id);
581 query.append_pair("redirect_uri", redirect_uri);
582 if !scopes.is_empty() {
583 query.append_pair("scope", &scopes.join(" "));
584 }
585 query.append_pair("resource", resource);
586 query.append_pair("code_challenge", &pkce.challenge);
587 query.append_pair("code_challenge_method", "S256");
588 query.append_pair("state", state);
589 }
590 url
591}
592
593// -- Token endpoint request bodies -------------------------------------------
594
595/// The JSON body returned by the token endpoint on success.
596#[derive(Deserialize)]
597pub struct TokenResponse {
598 pub access_token: String,
599 #[serde(default)]
600 pub refresh_token: Option<String>,
601 #[serde(default)]
602 pub expires_in: Option<u64>,
603 #[serde(default)]
604 pub token_type: Option<String>,
605}
606
607impl std::fmt::Debug for TokenResponse {
608 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
609 f.debug_struct("TokenResponse")
610 .field("access_token", &"[redacted]")
611 .field(
612 "refresh_token",
613 &self.refresh_token.as_ref().map(|_| "[redacted]"),
614 )
615 .field("expires_in", &self.expires_in)
616 .field("token_type", &self.token_type)
617 .finish()
618 }
619}
620
621impl TokenResponse {
622 /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
623 pub fn into_tokens(self) -> OAuthTokens {
624 let expires_at = self
625 .expires_in
626 .map(|secs| SystemTime::now() + Duration::from_secs(secs));
627 OAuthTokens {
628 access_token: self.access_token,
629 refresh_token: self.refresh_token,
630 expires_at,
631 }
632 }
633}
634
/// Assemble the form parameters for an authorization code token exchange.
///
/// The parameters are returned in a fixed order: `grant_type`
/// (`authorization_code`), `code`, `redirect_uri`, `client_id`,
/// `code_verifier`, `resource`, and finally `client_secret` when one is
/// supplied.
pub fn token_exchange_params(
    code: &str,
    client_id: &str,
    redirect_uri: &str,
    code_verifier: &str,
    resource: &str,
    client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
    let required: Vec<(&'static str, &str)> = vec![
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
        ("resource", resource),
    ];
    let optional_secret = client_secret.map(|secret| ("client_secret", secret));

    required
        .into_iter()
        .chain(optional_secret)
        .map(|(name, value)| (name, value.to_string()))
        .collect()
}
657
/// Assemble the form parameters for a token refresh request.
///
/// The parameters are returned in a fixed order: `grant_type`
/// (`refresh_token`), `refresh_token`, `client_id`, `resource`, and finally
/// `client_secret` when one is supplied.
pub fn token_refresh_params(
    refresh_token: &str,
    client_id: &str,
    resource: &str,
    client_secret: Option<&str>,
) -> Vec<(&'static str, String)> {
    let mut form = Vec::with_capacity(5);
    form.push(("grant_type", String::from("refresh_token")));
    form.push(("refresh_token", String::from(refresh_token)));
    form.push(("client_id", String::from(client_id)));
    form.push(("resource", String::from(resource)));
    form.extend(client_secret.map(|secret| ("client_secret", String::from(secret))));
    form
}
676
677// -- DCR request body (RFC 7591) ---------------------------------------------
678
679/// Build the JSON body for a Dynamic Client Registration request.
680///
681/// The `redirect_uri` should be the actual loopback URI with the ephemeral
682/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
683/// redirect URI matching even for loopback addresses, so we register the
684/// exact URI we intend to use.
685pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
686 serde_json::json!({
687 "client_name": "Zed",
688 "redirect_uris": [redirect_uri],
689 "grant_types": ["authorization_code"],
690 "response_types": ["code"],
691 "token_endpoint_auth_method": "none"
692 })
693}
694
695// -- Discovery (async, hits real endpoints) ----------------------------------
696
697/// Fetch Protected Resource Metadata from the MCP server.
698///
699/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
700/// then falls back to well-known URIs constructed from `server_url`.
701pub async fn fetch_protected_resource_metadata(
702 http_client: &Arc<dyn HttpClient>,
703 server_url: &Url,
704 www_authenticate: &WwwAuthenticate,
705) -> Result<ProtectedResourceMetadata> {
706 let candidate_urls = match &www_authenticate.resource_metadata {
707 Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
708 Some(url) => {
709 log::warn!(
710 "Ignoring cross-origin resource_metadata URL {} \
711 (server origin: {})",
712 url,
713 server_url.origin().unicode_serialization()
714 );
715 protected_resource_metadata_urls(server_url)
716 }
717 None => protected_resource_metadata_urls(server_url),
718 };
719
720 for url in &candidate_urls {
721 match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
722 Ok(response) => {
723 if response.authorization_servers.is_empty() {
724 bail!(
725 "Protected Resource Metadata at {} has no authorization_servers",
726 url
727 );
728 }
729 return Ok(ProtectedResourceMetadata {
730 resource: response.resource.unwrap_or_else(|| server_url.clone()),
731 authorization_servers: response.authorization_servers,
732 scopes_supported: response.scopes_supported,
733 });
734 }
735 Err(err) => {
736 log::debug!(
737 "Failed to fetch Protected Resource Metadata from {}: {}",
738 url,
739 err
740 );
741 }
742 }
743 }
744
745 bail!(
746 "Could not fetch Protected Resource Metadata for {}",
747 server_url
748 )
749}
750
751/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
752/// endpoints in the priority order specified by the MCP spec.
753pub async fn fetch_auth_server_metadata(
754 http_client: &Arc<dyn HttpClient>,
755 issuer: &Url,
756) -> Result<AuthServerMetadata> {
757 let candidate_urls = auth_server_metadata_urls(issuer);
758
759 for url in &candidate_urls {
760 match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
761 Ok(response) => {
762 let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
763 // if reported_issuer != *issuer {
764 // bail!(
765 // "Auth server metadata issuer mismatch: expected {}, got {}",
766 // issuer,
767 // reported_issuer
768 // );
769 // }
770
771 return Ok(AuthServerMetadata {
772 issuer: reported_issuer,
773 authorization_endpoint: response
774 .authorization_endpoint
775 .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
776 token_endpoint: response
777 .token_endpoint
778 .ok_or_else(|| anyhow!("missing token_endpoint"))?,
779 registration_endpoint: response.registration_endpoint,
780 scopes_supported: response.scopes_supported,
781 code_challenge_methods_supported: response.code_challenge_methods_supported,
782 client_id_metadata_document_supported: response
783 .client_id_metadata_document_supported
784 .unwrap_or(false),
785 });
786 }
787 Err(err) => {
788 log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
789 }
790 }
791 }
792
793 bail!(
794 "Could not fetch Authorization Server Metadata for {}",
795 issuer
796 )
797}
798
799/// Run the full discovery flow: fetch resource metadata, then auth server
800/// metadata, then select scopes. Client registration is resolved separately,
801/// once the real redirect URI is known.
802pub async fn discover(
803 http_client: &Arc<dyn HttpClient>,
804 server_url: &Url,
805 www_authenticate: &WwwAuthenticate,
806) -> Result<OAuthDiscovery> {
807 let resource_metadata =
808 fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
809
810 let auth_server_url = resource_metadata
811 .authorization_servers
812 .first()
813 .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
814
815 let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
816
817 // Verify PKCE S256 support (spec requirement).
818 match &auth_server_metadata.code_challenge_methods_supported {
819 Some(methods) if methods.iter().any(|m| m == "S256") => {}
820 Some(_) => bail!("authorization server does not support S256 PKCE"),
821 None => bail!("authorization server does not advertise code_challenge_methods_supported"),
822 }
823
824 let scopes = select_scopes(www_authenticate, &resource_metadata);
825
826 Ok(OAuthDiscovery {
827 resource_metadata,
828 auth_server_metadata,
829 scopes,
830 })
831}
832
833/// Resolve the OAuth client registration for an authorization flow.
834///
835/// CIMD uses the static client metadata document directly. For DCR, a fresh
836/// registration is performed each time because the loopback redirect URI
837/// includes an ephemeral port that changes every flow.
838pub async fn resolve_client_registration(
839 http_client: &Arc<dyn HttpClient>,
840 discovery: &OAuthDiscovery,
841 redirect_uri: &str,
842) -> Result<OAuthClientRegistration> {
843 match determine_registration_strategy(&discovery.auth_server_metadata) {
844 ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
845 client_id,
846 client_secret: None,
847 }),
848 ClientRegistrationStrategy::Dcr {
849 registration_endpoint,
850 } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
851 ClientRegistrationStrategy::Unavailable => {
852 bail!("authorization server supports neither CIMD nor DCR")
853 }
854 }
855}
856
857// -- Dynamic Client Registration (RFC 7591) ----------------------------------
858
859/// Perform Dynamic Client Registration with the authorization server.
860pub async fn perform_dcr(
861 http_client: &Arc<dyn HttpClient>,
862 registration_endpoint: &Url,
863 redirect_uri: &str,
864) -> Result<OAuthClientRegistration> {
865 validate_oauth_url(registration_endpoint)?;
866
867 let body = dcr_registration_body(redirect_uri);
868 let body_bytes = serde_json::to_vec(&body)?;
869
870 let request = Request::builder()
871 .method(http_client::http::Method::POST)
872 .uri(registration_endpoint.as_str())
873 .header("Content-Type", "application/json")
874 .header("Accept", "application/json")
875 .body(AsyncBody::from(body_bytes))?;
876
877 let mut response = http_client.send(request).await?;
878
879 if !response.status().is_success() {
880 let mut error_body = String::new();
881 response.body_mut().read_to_string(&mut error_body).await?;
882 bail!(
883 "DCR failed with status {}: {}",
884 response.status(),
885 error_body
886 );
887 }
888
889 let mut response_body = String::new();
890 response
891 .body_mut()
892 .read_to_string(&mut response_body)
893 .await?;
894
895 let dcr_response: DcrResponse =
896 serde_json::from_str(&response_body).context("failed to parse DCR response")?;
897
898 Ok(OAuthClientRegistration {
899 client_id: dcr_response.client_id,
900 client_secret: dcr_response.client_secret,
901 })
902}
903
904// -- Token exchange and refresh (async) --------------------------------------
905
906/// Exchange an authorization code for tokens at the token endpoint.
907pub async fn exchange_code(
908 http_client: &Arc<dyn HttpClient>,
909 auth_server_metadata: &AuthServerMetadata,
910 code: &str,
911 client_id: &str,
912 redirect_uri: &str,
913 code_verifier: &str,
914 resource: &str,
915 client_secret: Option<&str>,
916) -> Result<OAuthTokens> {
917 let params = token_exchange_params(
918 code,
919 client_id,
920 redirect_uri,
921 code_verifier,
922 resource,
923 client_secret,
924 );
925 post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
926}
927
928/// Refresh tokens using a refresh token.
929pub async fn refresh_tokens(
930 http_client: &Arc<dyn HttpClient>,
931 token_endpoint: &Url,
932 refresh_token: &str,
933 client_id: &str,
934 resource: &str,
935 client_secret: Option<&str>,
936) -> Result<OAuthTokens> {
937 let params = token_refresh_params(refresh_token, client_id, resource, client_secret);
938 post_token_request(http_client, token_endpoint, ¶ms).await
939}
940
941/// POST form-encoded parameters to a token endpoint and parse the response.
942async fn post_token_request(
943 http_client: &Arc<dyn HttpClient>,
944 token_endpoint: &Url,
945 params: &[(&str, String)],
946) -> Result<OAuthTokens> {
947 validate_oauth_url(token_endpoint)?;
948
949 let body = url::form_urlencoded::Serializer::new(String::new())
950 .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
951 .finish();
952
953 let request = Request::builder()
954 .method(http_client::http::Method::POST)
955 .uri(token_endpoint.as_str())
956 .header("Content-Type", "application/x-www-form-urlencoded")
957 .header("Accept", "application/json")
958 .body(AsyncBody::from(body.into_bytes()))?;
959
960 let mut response = http_client.send(request).await?;
961
962 if !response.status().is_success() {
963 let mut error_body = String::new();
964 response.body_mut().read_to_string(&mut error_body).await?;
965 bail!(
966 "token request failed with status {}: {}",
967 response.status(),
968 error_body
969 );
970 }
971
972 let mut response_body = String::new();
973 response
974 .body_mut()
975 .read_to_string(&mut response_body)
976 .await?;
977
978 let token_response: TokenResponse =
979 serde_json::from_str(&response_body).context("failed to parse token response")?;
980
981 Ok(token_response.into_tokens())
982}
983
984// -- Loopback HTTP callback server -------------------------------------------
985
/// An OAuth authorization callback received via the loopback HTTP server.
///
/// `Debug` is implemented by hand for this type and prints a placeholder for
/// both fields.
pub struct OAuthCallback {
    /// Value of the `code` query parameter of the callback request.
    pub code: String,
    /// Value of the `state` query parameter of the callback request.
    pub state: String,
}
991
992impl std::fmt::Debug for OAuthCallback {
993 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
994 f.debug_struct("OAuthCallback")
995 .field("code", &"[redacted]")
996 .field("state", &"[redacted]")
997 .finish()
998 }
999}
1000
1001impl OAuthCallback {
1002 /// Parse the query string from a callback URL like
1003 /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
1004 pub fn parse_query(query: &str) -> Result<Self> {
1005 let mut code: Option<String> = None;
1006 let mut state: Option<String> = None;
1007 let mut error: Option<String> = None;
1008 let mut error_description: Option<String> = None;
1009
1010 for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
1011 match key.as_ref() {
1012 "code" => {
1013 if !value.is_empty() {
1014 code = Some(value.into_owned());
1015 }
1016 }
1017 "state" => {
1018 if !value.is_empty() {
1019 state = Some(value.into_owned());
1020 }
1021 }
1022 "error" => {
1023 if !value.is_empty() {
1024 error = Some(value.into_owned());
1025 }
1026 }
1027 "error_description" => {
1028 if !value.is_empty() {
1029 error_description = Some(value.into_owned());
1030 }
1031 }
1032 _ => {}
1033 }
1034 }
1035
1036 // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
1037 // checking for missing code/state.
1038 if let Some(error_code) = error {
1039 bail!(
1040 "OAuth authorization failed: {} ({})",
1041 error_code,
1042 error_description.as_deref().unwrap_or("no description")
1043 );
1044 }
1045
1046 let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
1047 let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
1048
1049 Ok(Self { code, state })
1050 }
1051}
1052
/// Upper bound on how long the loopback listener waits for the browser to
/// finish the OAuth flow; once it elapses the loopback port is released.
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
1056
1057/// Start a loopback HTTP server to receive the OAuth authorization callback.
1058///
1059/// Binds to an ephemeral loopback port for each flow.
1060///
1061/// Returns `(redirect_uri, callback_future)`. The caller should use the
1062/// redirect URI in the authorization request, open the browser, then await
1063/// the future to receive the callback.
1064///
1065/// The server accepts exactly one request on `/callback`, validates that it
1066/// contains `code` and `state` query parameters, responds with a minimal
1067/// HTML page telling the user they can close the tab, and shuts down.
1068///
1069/// The callback server shuts down when the returned oneshot receiver is dropped
1070/// (e.g. because the authentication task was cancelled), or after a timeout
1071/// ([CALLBACK_TIMEOUT]).
1072pub async fn start_callback_server() -> Result<(
1073 String,
1074 futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
1075)> {
1076 let server = tiny_http::Server::http("127.0.0.1:0")
1077 .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
1078 let port = server
1079 .server_addr()
1080 .to_ip()
1081 .context("server not bound to a TCP address")?
1082 .port();
1083
1084 let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
1085
1086 let (tx, rx) = futures::channel::oneshot::channel();
1087
1088 // `tiny_http` is blocking, so we run it on a background thread.
1089 // The `recv_timeout` loop lets us check for cancellation (the receiver
1090 // being dropped) and enforce an overall timeout.
1091 std::thread::spawn(move || {
1092 let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
1093
1094 loop {
1095 if tx.is_canceled() {
1096 return;
1097 }
1098 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
1099 if remaining.is_zero() {
1100 return;
1101 }
1102
1103 let timeout = remaining.min(Duration::from_millis(500));
1104 let Some(request) = (match server.recv_timeout(timeout) {
1105 Ok(req) => req,
1106 Err(_) => {
1107 let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
1108 return;
1109 }
1110 }) else {
1111 // Timeout with no request — loop back and check cancellation.
1112 continue;
1113 };
1114
1115 let result = handle_callback_request(&request);
1116
1117 let (status_code, body) = match &result {
1118 Ok(_) => (
1119 200,
1120 "<html><body><h1>Authorization successful</h1>\
1121 <p>You can close this tab and return to Zed.</p></body></html>",
1122 ),
1123 Err(err) => {
1124 log::error!("OAuth callback error: {}", err);
1125 (
1126 400,
1127 "<html><body><h1>Authorization failed</h1>\
1128 <p>Something went wrong. Please try again from Zed.</p></body></html>",
1129 )
1130 }
1131 };
1132
1133 let response = tiny_http::Response::from_string(body)
1134 .with_status_code(status_code)
1135 .with_header(
1136 tiny_http::Header::from_str("Content-Type: text/html")
1137 .expect("failed to construct response header"),
1138 )
1139 .with_header(
1140 tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
1141 .expect("failed to construct response header"),
1142 );
1143 request.respond(response).log_err();
1144
1145 let _ = tx.send(result);
1146 return;
1147 }
1148 });
1149
1150 Ok((redirect_uri, rx))
1151}
1152
1153/// Extract the `code` and `state` query parameters from an OAuth callback
1154/// request to `/callback`.
1155fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
1156 let url = Url::parse(&format!("http://localhost{}", request.url()))
1157 .context("malformed callback request URL")?;
1158
1159 if url.path() != "/callback" {
1160 bail!("unexpected path in OAuth callback: {}", url.path());
1161 }
1162
1163 let query = url
1164 .query()
1165 .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
1166 OAuthCallback::parse_query(query)
1167}
1168
1169// -- JSON fetch helper -------------------------------------------------------
1170
1171async fn fetch_json<T: serde::de::DeserializeOwned>(
1172 http_client: &Arc<dyn HttpClient>,
1173 url: &Url,
1174) -> Result<T> {
1175 validate_oauth_url(url)?;
1176
1177 let request = Request::builder()
1178 .method(http_client::http::Method::GET)
1179 .uri(url.as_str())
1180 .header("Accept", "application/json")
1181 .body(AsyncBody::default())?;
1182
1183 let mut response = http_client.send(request).await?;
1184
1185 if !response.status().is_success() {
1186 bail!("HTTP {} fetching {}", response.status(), url);
1187 }
1188
1189 let mut body = String::new();
1190 response.body_mut().read_to_string(&mut body).await?;
1191 serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
1192}
1193
1194// -- Serde response types for discovery --------------------------------------
1195
/// Raw serde shape of a Protected Resource Metadata document.
///
/// Every field is `#[serde(default)]`, so a document that omits any of them
/// still deserializes; missing values surface as `None` or an empty list.
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadataResponse {
    /// The `resource` member, if present.
    #[serde(default)]
    resource: Option<Url>,
    /// The `authorization_servers` member; empty when absent.
    #[serde(default)]
    authorization_servers: Vec<Url>,
    /// The `scopes_supported` member, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
}
1205
/// Raw serde shape of an Authorization Server Metadata document.
///
/// Every field is optional at this level (`#[serde(default)]`), so a document
/// that omits any of them still deserializes.
#[derive(Debug, Deserialize)]
struct AuthServerMetadataResponse {
    /// The `issuer` member, if present.
    #[serde(default)]
    issuer: Option<Url>,
    /// The `authorization_endpoint` member, if present.
    #[serde(default)]
    authorization_endpoint: Option<Url>,
    /// The `token_endpoint` member, if present.
    #[serde(default)]
    token_endpoint: Option<Url>,
    /// The `registration_endpoint` member, if present.
    #[serde(default)]
    registration_endpoint: Option<Url>,
    /// The `scopes_supported` member, if present.
    #[serde(default)]
    scopes_supported: Option<Vec<String>>,
    /// The `code_challenge_methods_supported` member, if present.
    #[serde(default)]
    code_challenge_methods_supported: Option<Vec<String>>,
    /// The `client_id_metadata_document_supported` member, if present.
    #[serde(default)]
    client_id_metadata_document_supported: Option<bool>,
}
1223
1224#[derive(Debug, Deserialize)]
1225struct DcrResponse {
1226 client_id: String,
1227 #[serde(default)]
1228 client_secret: Option<String>,
1229}
1230
/// Provides OAuth tokens to the HTTP transport layer.
///
/// The transport calls `access_token()` before each request. On a 401 response
/// it calls `try_refresh()` and retries once if the refresh succeeds.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait OAuthTokenProvider: Send + Sync {
    /// Returns the current access token, if one is available.
    fn access_token(&self) -> Option<String>;

    /// Attempts to refresh the access token. Returns `true` if a new token was
    /// obtained and the request should be retried.
    async fn try_refresh(&self) -> Result<bool>;
}
1244
/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
/// an HTTP client for token refresh. The same provider type is used both after
/// an interactive authentication flow and when restoring a saved session from
/// the keychain on startup.
pub struct McpOAuthTokenProvider {
    /// Current session; its tokens are replaced after a successful refresh.
    session: SyncMutex<OAuthSession>,
    /// Client used to send the refresh request.
    http_client: Arc<dyn HttpClient>,
    /// When set, a clone of the session is sent here after each successful
    /// refresh.
    token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
}
1254
1255impl McpOAuthTokenProvider {
1256 pub fn new(
1257 session: OAuthSession,
1258 http_client: Arc<dyn HttpClient>,
1259 token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
1260 ) -> Self {
1261 Self {
1262 session: SyncMutex::new(session),
1263 http_client,
1264 token_refresh_tx,
1265 }
1266 }
1267
1268 fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
1269 tokens.expires_at.is_some_and(|expires_at| {
1270 SystemTime::now()
1271 .checked_add(Duration::from_secs(30))
1272 .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
1273 })
1274 }
1275}
1276
1277#[async_trait]
1278impl OAuthTokenProvider for McpOAuthTokenProvider {
1279 fn access_token(&self) -> Option<String> {
1280 let session = self.session.lock();
1281 if Self::access_token_is_expired(&session.tokens) {
1282 return None;
1283 }
1284 Some(session.tokens.access_token.clone())
1285 }
1286
1287 async fn try_refresh(&self) -> Result<bool> {
1288 let (refresh_token, token_endpoint, resource, client_id, client_secret) = {
1289 let session = self.session.lock();
1290 match session.tokens.refresh_token.clone() {
1291 Some(refresh_token) => (
1292 refresh_token,
1293 session.token_endpoint.clone(),
1294 session.resource.clone(),
1295 session.client_registration.client_id.clone(),
1296 session.client_registration.client_secret.clone(),
1297 ),
1298 None => return Ok(false),
1299 }
1300 };
1301
1302 let resource_str = canonical_server_uri(&resource);
1303
1304 match refresh_tokens(
1305 &self.http_client,
1306 &token_endpoint,
1307 &refresh_token,
1308 &client_id,
1309 &resource_str,
1310 client_secret.as_deref(),
1311 )
1312 .await
1313 {
1314 Ok(mut new_tokens) => {
1315 if new_tokens.refresh_token.is_none() {
1316 new_tokens.refresh_token = Some(refresh_token);
1317 }
1318
1319 {
1320 let mut session = self.session.lock();
1321 session.tokens = new_tokens;
1322
1323 if let Some(ref tx) = self.token_refresh_tx {
1324 tx.unbounded_send(session.clone()).ok();
1325 }
1326 }
1327
1328 Ok(true)
1329 }
1330 Err(err) => {
1331 log::warn!("OAuth token refresh failed: {}", err);
1332 Ok(false)
1333 }
1334 }
1335 }
1336}
1337
1338#[cfg(test)]
1339mod tests {
1340 use super::*;
1341 use http_client::Response;
1342
1343 // -- require_https_or_loopback tests ------------------------------------
1344
/// An `https` URL on a remote host is accepted.
#[test]
fn test_require_https_or_loopback_accepts_https() {
    let endpoint = Url::parse("https://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1350
/// Plain `http` to a remote host name is rejected.
#[test]
fn test_require_https_or_loopback_rejects_http_remote() {
    let endpoint = Url::parse("http://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1356
/// Plain `http` to `127.0.0.1` is accepted.
#[test]
fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
    let endpoint = Url::parse("http://127.0.0.1:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1362
/// Plain `http` to the IPv6 loopback address `[::1]` is accepted.
#[test]
fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
    let endpoint = Url::parse("http://[::1]:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1368
/// Plain `http` to `localhost` is accepted.
#[test]
fn test_require_https_or_loopback_accepts_http_localhost() {
    let endpoint = Url::parse("http://localhost:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1374
/// An upper-case `LOCALHOST` host is accepted just like `localhost`.
#[test]
fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
    let endpoint = Url::parse("http://LOCALHOST:8080/callback").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_ok());
}
1380
/// Plain `http` to a non-loopback IP address is rejected.
#[test]
fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
    let endpoint = Url::parse("http://192.168.1.1:8080/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1386
/// Schemes other than `http`/`https` are rejected.
#[test]
fn test_require_https_or_loopback_rejects_ftp() {
    let endpoint = Url::parse("ftp://auth.example.com/token").unwrap();
    let verdict = require_https_or_loopback(&endpoint);
    assert!(verdict.is_err());
}
1392
1393 // -- validate_oauth_url (SSRF) tests ------------------------------------
1394
/// An `https` URL with a public host name passes validation.
#[test]
fn test_validate_oauth_url_accepts_https_public() {
    let endpoint = Url::parse("https://auth.example.com/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_ok());
}
1400
/// An address in `10.0.0.0/8` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_10() {
    let endpoint = Url::parse("https://10.0.0.1/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1406
/// The private address `172.16.0.1` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_172() {
    let endpoint = Url::parse("https://172.16.0.1/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1412
/// The private address `192.168.1.1` is rejected even over `https`.
#[test]
fn test_validate_oauth_url_rejects_private_ipv4_192() {
    let endpoint = Url::parse("https://192.168.1.1/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1418
/// The link-local address `169.254.169.254` is rejected.
#[test]
fn test_validate_oauth_url_rejects_link_local() {
    let endpoint = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1424
/// An IPv6 unique-local address (`fd..`) is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv6_ula() {
    let endpoint = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1430
/// The IPv6 unspecified address `[::]` is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv6_unspecified() {
    let endpoint = Url::parse("https://[::]/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1436
/// A private IPv4 address written as an IPv4-mapped IPv6 address is rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
    let endpoint = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1442
/// A link-local IPv4 address written as an IPv4-mapped IPv6 address is
/// rejected.
#[test]
fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
    let endpoint = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_err());
}
1448
/// Plain `http` to `127.0.0.1` passes validation.
#[test]
fn test_validate_oauth_url_allows_http_loopback() {
    // Loopback is permitted (it's our callback server).
    let endpoint = Url::parse("http://127.0.0.1:8080/callback").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_ok());
}
1455
/// An `https` URL whose host is a public IP literal passes validation.
#[test]
fn test_validate_oauth_url_allows_https_public_ip() {
    let endpoint = Url::parse("https://93.184.216.34/token").unwrap();
    let verdict = validate_oauth_url(&endpoint);
    assert!(verdict.is_ok());
}
1461
1462 // -- parse_www_authenticate tests ----------------------------------------
1463
/// A header carrying both `resource_metadata` and `scope` yields both values
/// and no error.
#[test]
fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
    let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
    let parsed = parse_www_authenticate(header).unwrap();

    let metadata_url = parsed.resource_metadata.as_ref().map(|u| u.as_str());
    assert_eq!(
        metadata_url,
        Some("https://mcp.example.com/.well-known/oauth-protected-resource")
    );

    let expected_scope = vec!["files:read".to_string(), "user:profile".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
    assert_eq!(parsed.error, None);
}
1479
/// A header with only `resource_metadata` leaves `scope` unset.
#[test]
fn test_parse_www_authenticate_resource_metadata_only() {
    let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
    let parsed = parse_www_authenticate(header).unwrap();

    let metadata_url = parsed.resource_metadata.as_ref().map(|u| u.as_str());
    assert_eq!(
        metadata_url,
        Some("https://mcp.example.com/.well-known/oauth-protected-resource")
    );
    assert_eq!(parsed.scope, None);
}
1491
/// A bare `Bearer` challenge parses with no metadata URL and no scope.
#[test]
fn test_parse_www_authenticate_bare_bearer() {
    let parsed = parse_www_authenticate("Bearer").unwrap();
    assert!(parsed.resource_metadata.is_none());
    assert!(parsed.scope.is_none());
}
1498
/// `error`, `error_description`, `scope` and `resource_metadata` are all
/// picked up from one header, regardless of their order.
#[test]
fn test_parse_www_authenticate_with_error() {
    let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
    let parsed = parse_www_authenticate(header).unwrap();

    assert_eq!(parsed.error, Some(BearerError::InsufficientScope));

    let description = parsed.error_description.as_deref();
    assert_eq!(
        description,
        Some("Additional file write permission required")
    );

    let expected_scope = vec!["files:read".to_string(), "files:write".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
    assert!(parsed.resource_metadata.is_some());
}
1515
/// `error="invalid_token"` maps to `BearerError::InvalidToken`.
#[test]
fn test_parse_www_authenticate_invalid_token_error() {
    let header =
        r#"Bearer error="invalid_token", error_description="The access token expired""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::InvalidToken);
    assert_eq!(parsed.error, expected);
}
1523
/// `error="invalid_request"` maps to `BearerError::InvalidRequest`.
#[test]
fn test_parse_www_authenticate_invalid_request_error() {
    let header = r#"Bearer error="invalid_request""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::InvalidRequest);
    assert_eq!(parsed.error, expected);
}
1530
/// An unrecognized error code maps to `BearerError::Other`.
#[test]
fn test_parse_www_authenticate_unknown_error() {
    let header = r#"Bearer error="some_future_error""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let expected = Some(BearerError::Other);
    assert_eq!(parsed.error, expected);
}
1537
/// A challenge with a scheme other than `Bearer` is an error.
#[test]
fn test_parse_www_authenticate_rejects_non_bearer() {
    let outcome = parse_www_authenticate("Basic realm=\"example\"");
    assert!(outcome.is_err());
}
1543
/// A lower-case `bearer` scheme is accepted.
#[test]
fn test_parse_www_authenticate_case_insensitive_scheme() {
    let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
    let parsed = parse_www_authenticate(header).unwrap();
    let has_metadata = parsed.resource_metadata.is_some();
    assert!(has_metadata);
}
1550
/// Parameters separated by a comma followed by a newline and whitespace are
/// still parsed.
#[test]
fn test_parse_www_authenticate_multiline_style() {
    // Some servers emit the header spread across multiple lines joined by
    // whitespace, as shown in the spec examples.
    let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
    let parsed = parse_www_authenticate(header).unwrap();

    assert!(parsed.resource_metadata.is_some());
    let expected_scope = vec!["files:read".to_string()];
    assert_eq!(parsed.scope, Some(expected_scope));
}
1560
/// A server URL with a path yields the path-suffixed well-known URL first,
/// then the root well-known URL.
#[test]
fn test_protected_resource_metadata_urls_with_path() {
    let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
    let candidates = protected_resource_metadata_urls(&server_url);
    let rendered: Vec<&str> = candidates.iter().map(|url| url.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp",
            "https://api.example.com/.well-known/oauth-protected-resource",
        ]
    );
}
1576
/// A server URL without a path yields only the root well-known URL.
#[test]
fn test_protected_resource_metadata_urls_without_path() {
    let server_url = Url::parse("https://mcp.example.com").unwrap();
    let candidates = protected_resource_metadata_urls(&server_url);
    let rendered: Vec<&str> = candidates.iter().map(|url| url.as_str()).collect();

    assert_eq!(
        rendered,
        ["https://mcp.example.com/.well-known/oauth-protected-resource"]
    );
}
1588
/// An issuer with a path yields three candidate metadata URLs, in order.
#[test]
fn test_auth_server_metadata_urls_with_path() {
    let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
    let candidates = auth_server_metadata_urls(&issuer);
    let rendered: Vec<&str> = candidates.iter().map(|url| url.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://auth.example.com/.well-known/oauth-authorization-server/tenant1",
            "https://auth.example.com/.well-known/openid-configuration/tenant1",
            "https://auth.example.com/tenant1/.well-known/openid-configuration",
        ]
    );
}
1608
/// An issuer without a path yields two candidate metadata URLs, in order.
#[test]
fn test_auth_server_metadata_urls_without_path() {
    let issuer = Url::parse("https://auth.example.com").unwrap();
    let candidates = auth_server_metadata_urls(&issuer);
    let rendered: Vec<&str> = candidates.iter().map(|url| url.as_str()).collect();

    assert_eq!(
        rendered,
        [
            "https://auth.example.com/.well-known/oauth-authorization-server",
            "https://auth.example.com/.well-known/openid-configuration",
        ]
    );
}
1624
1625 // -- Canonical server URI tests ------------------------------------------
1626
/// A bare origin is rendered without a trailing slash.
#[test]
fn test_canonical_server_uri_simple() {
    let server = Url::parse("https://mcp.example.com").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com");
}
1632
/// A path component is kept as-is.
#[test]
fn test_canonical_server_uri_with_path() {
    let server = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com/v1/mcp");
}
1638
/// A lone trailing slash is removed.
#[test]
fn test_canonical_server_uri_strips_trailing_slash() {
    let server = Url::parse("https://mcp.example.com/").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com");
}
1644
/// An explicit non-default port is kept.
#[test]
fn test_canonical_server_uri_preserves_port() {
    let server = Url::parse("https://mcp.example.com:8443").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com:8443");
}
1650
/// Scheme and host come out lower-case while the path keeps its case.
#[test]
fn test_canonical_server_uri_lowercases() {
    let server = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
    let canonical = canonical_server_uri(&server);
    assert_eq!(canonical, "https://mcp.example.com/Server/MCP");
}
1659
1660 // -- Scope selection tests -----------------------------------------------
1661
/// Scopes from the `WWW-Authenticate` challenge win over the resource
/// metadata's `scopes_supported`.
#[test]
fn test_select_scopes_prefers_www_authenticate() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: Some(vec!["files:read".into()]),
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
    };

    let selected = select_scopes(&challenge, &metadata);
    assert_eq!(selected, vec!["files:read"]);
}
1677
/// Without a challenge scope, the resource metadata's `scopes_supported` is
/// used.
#[test]
fn test_select_scopes_falls_back_to_resource_metadata() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: Some(vec!["admin".into()]),
    };

    let selected = select_scopes(&challenge, &metadata);
    assert_eq!(selected, vec!["admin"]);
}
1693
/// With no scope information anywhere, the selection is empty.
#[test]
fn test_select_scopes_empty_when_nothing_available() {
    let challenge = WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let metadata = ProtectedResourceMetadata {
        resource: Url::parse("https://example.com").unwrap(),
        authorization_servers: vec![],
        scopes_supported: None,
    };

    let selected = select_scopes(&challenge, &metadata);
    assert!(selected.is_empty());
}
1709
1710 // -- Client registration strategy tests ----------------------------------
1711
/// When the server supports client ID metadata documents, CIMD is chosen even
/// though a registration endpoint is also advertised.
#[test]
fn test_registration_strategy_prefers_cimd() {
    let metadata = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };

    let expected = ClientRegistrationStrategy::Cimd {
        client_id: CIMD_URL.to_string(),
    };
    let chosen = determine_registration_strategy(&metadata);
    assert_eq!(chosen, expected);
}
1730
/// Without CIMD support, an advertised registration endpoint selects DCR.
#[test]
fn test_registration_strategy_falls_back_to_dcr() {
    let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
    let metadata = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: Some(reg_endpoint.clone()),
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    let expected = ClientRegistrationStrategy::Dcr {
        registration_endpoint: reg_endpoint,
    };
    let chosen = determine_registration_strategy(&metadata);
    assert_eq!(chosen, expected);
}
1750
/// With neither CIMD support nor a registration endpoint, no strategy is
/// available.
#[test]
fn test_registration_strategy_unavailable() {
    let metadata = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };

    let chosen = determine_registration_strategy(&metadata);
    assert_eq!(chosen, ClientRegistrationStrategy::Unavailable);
}
1767
1768 // -- PKCE tests ----------------------------------------------------------
1769
/// The generated verifier has the expected length.
#[test]
fn test_pkce_challenge_verifier_length() {
    let pkce = generate_pkce_challenge();
    // 32 random bytes → 43 base64url chars (no padding).
    let verifier_len = pkce.verifier.len();
    assert_eq!(verifier_len, 43);
}
1776
/// Verifier and challenge contain only unpadded base64url characters.
#[test]
fn test_pkce_challenge_is_valid_base64url() {
    let pkce = generate_pkce_challenge();
    let is_base64url = |c: char| c.is_ascii_alphanumeric() || c == '-' || c == '_';

    let all_chars = pkce.verifier.chars().chain(pkce.challenge.chars());
    for c in all_chars {
        assert!(is_base64url(c), "invalid base64url character: {}", c);
    }
}
1788
/// The challenge is the unpadded base64url encoding of SHA-256(verifier).
#[test]
fn test_pkce_challenge_is_s256_of_verifier() {
    let pkce = generate_pkce_challenge();

    let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
    let expected_challenge =
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(expected_digest);

    assert_eq!(pkce.challenge, expected_challenge);
}
1797
/// Two consecutive challenges have different verifiers.
#[test]
fn test_pkce_challenges_are_unique() {
    let first = generate_pkce_challenge();
    let second = generate_pkce_challenge();
    assert_ne!(first.verifier, second.verifier);
}
1804
1805 // -- Authorization URL tests ---------------------------------------------
1806
/// Every expected query parameter is present on the authorization URL with
/// the value that was passed in.
#[test]
fn test_build_authorization_url() {
    let metadata = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: true,
    };
    let pkce = PkceChallenge {
        verifier: "test_verifier".into(),
        challenge: "test_challenge".into(),
    };
    let url = build_authorization_url(
        &metadata,
        "https://zed.dev/oauth/client-metadata.json",
        "http://127.0.0.1:12345/callback",
        &["files:read".into(), "files:write".into()],
        "https://mcp.example.com",
        &pkce,
        "random_state_123",
    );

    let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
    let expected = [
        ("response_type", "code"),
        ("client_id", "https://zed.dev/oauth/client-metadata.json"),
        ("redirect_uri", "http://127.0.0.1:12345/callback"),
        ("scope", "files:read files:write"),
        ("resource", "https://mcp.example.com"),
        ("code_challenge", "test_challenge"),
        ("code_challenge_method", "S256"),
        ("state", "random_state_123"),
    ];
    for &(name, value) in expected.iter() {
        assert_eq!(pairs.get(name).unwrap(), value);
    }
}
1848
/// With an empty scope list, no `scope` parameter is emitted.
#[test]
fn test_build_authorization_url_omits_empty_scope() {
    let metadata = AuthServerMetadata {
        issuer: Url::parse("https://auth.example.com").unwrap(),
        authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
        token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
        registration_endpoint: None,
        scopes_supported: None,
        code_challenge_methods_supported: Some(vec!["S256".into()]),
        client_id_metadata_document_supported: false,
    };
    let pkce = PkceChallenge {
        verifier: "v".into(),
        challenge: "c".into(),
    };
    let url = build_authorization_url(
        &metadata,
        "client_123",
        "http://127.0.0.1:9999/callback",
        &[],
        "https://mcp.example.com",
        &pkce,
        "state",
    );

    let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
    let scope_param = pairs.get("scope");
    assert!(scope_param.is_none());
}
1877
1878 // -- Token exchange / refresh param tests --------------------------------
1879
/// The authorization-code grant parameters carry every input under the
/// expected form key.
#[test]
fn test_token_exchange_params() {
    let params = token_exchange_params(
        "auth_code_abc",
        "client_xyz",
        "http://127.0.0.1:5555/callback",
        "verifier_123",
        "https://mcp.example.com",
        None,
    );

    let mut by_name: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
    for (name, value) in params.iter() {
        by_name.insert(*name, value.as_str());
    }

    let expected = [
        ("grant_type", "authorization_code"),
        ("code", "auth_code_abc"),
        ("redirect_uri", "http://127.0.0.1:5555/callback"),
        ("client_id", "client_xyz"),
        ("code_verifier", "verifier_123"),
        ("resource", "https://mcp.example.com"),
    ];
    for &(name, value) in expected.iter() {
        assert_eq!(by_name[name], value);
    }
}
1900
/// The refresh grant parameters carry every input under the expected form
/// key.
#[test]
fn test_token_refresh_params() {
    let params = token_refresh_params(
        "refresh_token_abc",
        "client_xyz",
        "https://mcp.example.com",
        None,
    );

    let mut by_name: std::collections::HashMap<&str, &str> = std::collections::HashMap::new();
    for (name, value) in params.iter() {
        by_name.insert(*name, value.as_str());
    }

    let expected = [
        ("grant_type", "refresh_token"),
        ("refresh_token", "refresh_token_abc"),
        ("client_id", "client_xyz"),
        ("resource", "https://mcp.example.com"),
    ];
    for &(name, value) in expected.iter() {
        assert_eq!(by_name[name], value);
    }
}
1917
1918 // -- Token response tests ------------------------------------------------
1919
/// A token response that includes `expires_in` and a refresh token converts
/// into tokens with both the refresh token and an expiry populated.
#[test]
fn test_token_response_into_tokens_with_expiry() {
    let json = r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#;
    let parsed: TokenResponse = serde_json::from_str(json).unwrap();

    let converted = parsed.into_tokens();

    assert_eq!(converted.access_token, "at_123");
    assert_eq!(converted.refresh_token.as_deref(), Some("rt_456"));
    assert!(converted.expires_at.is_some());
}
1932
/// A response containing only `access_token` converts into tokens with no
/// refresh token and no expiry.
#[test]
fn test_token_response_into_tokens_minimal() {
    let json = r#"{"access_token": "at_789"}"#;
    let parsed: TokenResponse = serde_json::from_str(json).unwrap();

    let converted = parsed.into_tokens();

    assert_eq!(converted.access_token, "at_789");
    assert!(converted.refresh_token.is_none());
    assert!(converted.expires_at.is_none());
}
1943
1944 // -- DCR body test -------------------------------------------------------
1945
/// Pins the JSON shape of the Dynamic Client Registration request body.
#[test]
fn test_dcr_registration_body_shape() {
    let redirect_uri = "http://127.0.0.1:12345/callback";

    let registration = dcr_registration_body(redirect_uri);

    assert_eq!(registration["client_name"], "Zed");
    assert_eq!(registration["redirect_uris"][0], redirect_uri);
    assert_eq!(registration["grant_types"][0], "authorization_code");
    assert_eq!(registration["response_types"][0], "code");
    assert_eq!(registration["token_endpoint_auth_method"], "none");
}
1955
1956 // -- Test helpers for async/HTTP tests -----------------------------------
1957
1958 fn make_fake_http_client(
1959 handler: impl Fn(
1960 http_client::Request<AsyncBody>,
1961 ) -> std::pin::Pin<
1962 Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
1963 > + Send
1964 + Sync
1965 + 'static,
1966 ) -> Arc<dyn HttpClient> {
1967 http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
1968 }
1969
1970 fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
1971 Ok(Response::builder()
1972 .status(status)
1973 .header("Content-Type", "application/json")
1974 .body(AsyncBody::from(body.as_bytes().to_vec()))
1975 .unwrap())
1976 }
1977
1978 // -- Discovery integration tests -----------------------------------------
1979
/// With no `resource_metadata` hint in the challenge, the document served
/// under the well-known protected-resource path is fetched and parsed.
#[test]
fn test_fetch_protected_resource_metadata() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"],
        "scopes_supported": ["read", "write"]
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_well_known = requested.contains(".well-known/oauth-protected-resource");
                if is_well_known {
                    json_response(200, RESOURCE_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let fetched = fetch_protected_resource_metadata(&http, &mcp_url, &challenge)
            .await
            .unwrap();

        // `Url` normalizes an empty path to "/", hence the trailing slashes.
        assert_eq!(fetched.resource.as_str(), "https://mcp.example.com/");
        assert_eq!(fetched.authorization_servers.len(), 1);
        assert_eq!(
            fetched.authorization_servers[0].as_str(),
            "https://auth.example.com/"
        );
        assert_eq!(
            fetched.scopes_supported,
            Some(vec![String::from("read"), String::from("write")])
        );
    });
}
2025
/// When the challenge names a same-origin `resource_metadata` URL, that URL
/// is the one fetched; the fake server answers 500 for anything else.
#[test]
fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_hinted_url =
                    requested == "https://mcp.example.com/custom-resource-metadata";
                if is_hinted_url {
                    json_response(200, RESOURCE_METADATA)
                } else {
                    json_response(500, r#"{"error": "should not be called"}"#)
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let hinted_url = Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: Some(hinted_url),
            scope: None,
            error: None,
            error_description: None,
        };

        let fetched = fetch_protected_resource_metadata(&http, &mcp_url, &challenge)
            .await
            .unwrap();

        assert_eq!(fetched.authorization_servers.len(), 1);
    });
}
2063
/// A `resource_metadata` URL on a different origin must never be requested;
/// the metadata is expected to come from the server's own well-known path.
#[test]
fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                // Any request to the attacker's host fails the test outright.
                assert!(
                    !requested.contains("attacker.example.com"),
                    "should not fetch cross-origin resource_metadata URL"
                );
                let is_well_known = requested.contains(".well-known/oauth-protected-resource");
                if is_well_known {
                    json_response(200, RESOURCE_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let foreign_url = Url::parse("https://attacker.example.com/fake-metadata").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: Some(foreign_url),
            scope: None,
            error: None,
            error_description: None,
        };

        let fetched = fetch_protected_resource_metadata(&http, &mcp_url, &challenge)
            .await
            .unwrap();

        // The result must come from the well-known fallback, not the attacker's URL.
        assert_eq!(fetched.resource.as_str(), "https://mcp.example.com/");
    });
}
2106
/// Authorization Server Metadata served from the OAuth well-known path is
/// parsed into all of its fields.
#[test]
fn test_fetch_auth_server_metadata() {
    const SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "registration_endpoint": "https://auth.example.com/register",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": true
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_oauth_well_known =
                    requested.contains(".well-known/oauth-authorization-server");
                if is_oauth_well_known {
                    json_response(200, SERVER_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let issuer_url = Url::parse("https://auth.example.com").unwrap();
        let fetched = fetch_auth_server_metadata(&http, &issuer_url)
            .await
            .unwrap();

        assert_eq!(fetched.issuer.as_str(), "https://auth.example.com/");
        assert_eq!(
            fetched.authorization_endpoint.as_str(),
            "https://auth.example.com/authorize"
        );
        assert_eq!(
            fetched.token_endpoint.as_str(),
            "https://auth.example.com/token"
        );
        assert!(fetched.registration_endpoint.is_some());
        assert!(fetched.client_id_metadata_document_supported);
        assert_eq!(
            fetched.code_challenge_methods_supported,
            Some(vec![String::from("S256")])
        );
    });
}
2151
/// When only the OpenID Connect configuration document is available (all
/// other paths answer 404), metadata is still obtained from it.
#[test]
fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
    const OIDC_CONFIGURATION: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "code_challenge_methods_supported": ["S256"]
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_oidc_document = requested.contains("openid-configuration");
                if is_oidc_document {
                    json_response(200, OIDC_CONFIGURATION)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let issuer_url = Url::parse("https://auth.example.com").unwrap();
        let fetched = fetch_auth_server_metadata(&http, &issuer_url)
            .await
            .unwrap();

        assert_eq!(
            fetched.authorization_endpoint.as_str(),
            "https://auth.example.com/authorize"
        );
        // The document omits the CIMD flag, so it must come back as false.
        assert!(!fetched.client_id_metadata_document_supported);
    });
}
2184
/// Metadata whose `issuer` differs from the issuer that was queried must be
/// rejected with an "issuer mismatch" error.
#[test]
fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
    // The document claims to belong to a different issuer than the one queried.
    const FOREIGN_METADATA: &str = r#"{
        "issuer": "https://evil.example.com",
        "authorization_endpoint": "https://evil.example.com/authorize",
        "token_endpoint": "https://evil.example.com/token",
        "code_challenge_methods_supported": ["S256"]
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                let is_oauth_well_known =
                    requested.contains(".well-known/oauth-authorization-server");
                if is_oauth_well_known {
                    json_response(200, FOREIGN_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let issuer_url = Url::parse("https://auth.example.com").unwrap();
        let outcome = fetch_auth_server_metadata(&http, &issuer_url).await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("issuer mismatch"),
            "unexpected error: {err_msg}"
        );
    });
}
2220
2221 // -- Full discover integration tests -------------------------------------
2222
/// End-to-end discovery against a server that advertises CIMD support: the
/// resolved client id is `CIMD_URL`, there is no client secret, and the
/// scopes match the resource metadata's `scopes_supported`.
#[test]
fn test_full_discover_with_cimd() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"],
        "scopes_supported": ["mcp:read"]
    }"#;
    const SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": true
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                if requested.contains("oauth-protected-resource") {
                    json_response(200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    json_response(200, SERVER_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let discovered = discover(&http, &mcp_url, &challenge).await.unwrap();
        let client =
            resolve_client_registration(&http, &discovered, "http://127.0.0.1:12345/callback")
                .await
                .unwrap();

        assert_eq!(client.client_id, CIMD_URL);
        assert!(client.client_secret.is_none());
        assert_eq!(discovered.scopes, vec!["mcp:read"]);
    });
}
2274
/// End-to-end discovery against a server without CIMD support but with a
/// registration endpoint: the client credentials are the ones the
/// registration endpoint returned, and the scopes match the challenge.
#[test]
fn test_full_discover_with_dcr_fallback() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;
    const SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token",
        "registration_endpoint": "https://auth.example.com/register",
        "code_challenge_methods_supported": ["S256"],
        "client_id_metadata_document_supported": false
    }"#;
    const REGISTRATION_REPLY: &str = r#"{
        "client_id": "dcr-minted-id-123",
        "client_secret": "dcr-secret-456"
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                if requested.contains("oauth-protected-resource") {
                    json_response(200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    json_response(200, SERVER_METADATA)
                } else if requested.contains("/register") {
                    json_response(201, REGISTRATION_REPLY)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: Some(vec!["files:read".into()]),
            error: None,
            error_description: None,
        };

        let discovered = discover(&http, &mcp_url, &challenge).await.unwrap();
        let client =
            resolve_client_registration(&http, &discovered, "http://127.0.0.1:9999/callback")
                .await
                .unwrap();

        assert_eq!(client.client_id, "dcr-minted-id-123");
        assert_eq!(client.client_secret.as_deref(), Some("dcr-secret-456"));
        assert_eq!(discovered.scopes, vec!["files:read"]);
    });
}
2337
/// Discovery must fail when the authorization server metadata omits
/// `code_challenge_methods_supported`, and the error must name that field.
#[test]
fn test_discover_fails_without_pkce_support() {
    const RESOURCE_METADATA: &str = r#"{
        "resource": "https://mcp.example.com",
        "authorization_servers": ["https://auth.example.com"]
    }"#;
    // Deliberately lacks `code_challenge_methods_supported`.
    const SERVER_METADATA: &str = r#"{
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/authorize",
        "token_endpoint": "https://auth.example.com/token"
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                if requested.contains("oauth-protected-resource") {
                    json_response(200, RESOURCE_METADATA)
                } else if requested.contains("oauth-authorization-server") {
                    json_response(200, SERVER_METADATA)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let mcp_url = Url::parse("https://mcp.example.com").unwrap();
        let challenge = WwwAuthenticate {
            resource_metadata: None,
            scope: None,
            error: None,
            error_description: None,
        };

        let outcome = discover(&http, &mcp_url, &challenge).await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("code_challenge_methods_supported"),
            "unexpected error: {err_msg}"
        );
    });
}
2385
2386 // -- Token exchange integration tests ------------------------------------
2387
/// A 200 reply from the token endpoint is turned into access and refresh
/// tokens with an expiry.
#[test]
fn test_exchange_code_success() {
    const TOKEN_REPLY: &str = r#"{
        "access_token": "new_access_token",
        "refresh_token": "new_refresh_token",
        "expires_in": 3600,
        "token_type": "Bearer"
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                if requested.contains("/token") {
                    json_response(200, TOKEN_REPLY)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let server = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let issued = exchange_code(
            &http,
            &server,
            "auth_code_123",
            CIMD_URL,
            "http://127.0.0.1:9999/callback",
            "verifier_abc",
            "https://mcp.example.com",
            None,
        )
        .await
        .unwrap();

        assert_eq!(issued.access_token, "new_access_token");
        assert_eq!(issued.refresh_token.as_deref(), Some("new_refresh_token"));
        assert!(issued.expires_at.is_some());
    });
}
2438
/// A 200 reply to a refresh request yields the new access token; since the
/// reply carries no refresh token, the returned tokens have none either.
#[test]
fn test_refresh_tokens_success() {
    const TOKEN_REPLY: &str = r#"{
        "access_token": "refreshed_token",
        "expires_in": 1800,
        "token_type": "Bearer"
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|request| {
            Box::pin(async move {
                let requested = request.uri().to_string();
                if requested.contains("/token") {
                    json_response(200, TOKEN_REPLY)
                } else {
                    json_response(404, "{}")
                }
            })
        });

        let endpoint = Url::parse("https://auth.example.com/token").unwrap();

        let issued = refresh_tokens(
            &http,
            &endpoint,
            "old_refresh_token",
            CIMD_URL,
            "https://mcp.example.com",
            None,
        )
        .await
        .unwrap();

        assert_eq!(issued.access_token, "refreshed_token");
        assert!(issued.refresh_token.is_none());
        assert!(issued.expires_at.is_some());
    });
}
2478
/// A 400 reply from the token endpoint must surface as an error whose
/// message includes the status code.
#[test]
fn test_exchange_code_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
        });

        let metadata = AuthServerMetadata {
            issuer: Url::parse("https://auth.example.com").unwrap(),
            authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
            token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
            registration_endpoint: None,
            scopes_supported: None,
            code_challenge_methods_supported: Some(vec!["S256".into()]),
            client_id_metadata_document_supported: true,
        };

        let result = exchange_code(
            &client,
            &metadata,
            "bad_code",
            "client",
            "http://127.0.0.1:1/callback",
            "verifier",
            "https://mcp.example.com",
            None,
        )
        .await;

        // Print the actual error on failure, matching the other error-path tests,
        // so a regression shows what the message turned into.
        let err_msg = result
            .expect_err("token exchange should fail on a 400 response")
            .to_string();
        assert!(err_msg.contains("400"), "unexpected error: {err_msg}");
    });
}
2512
2513 // -- DCR integration tests -----------------------------------------------
2514
/// A 201 reply from the registration endpoint is parsed into the client id
/// and client secret it contains.
#[test]
fn test_perform_dcr() {
    const REGISTRATION_REPLY: &str = r#"{
        "client_id": "dynamic-client-001",
        "client_secret": "dynamic-secret-001"
    }"#;

    smol::block_on(async {
        let http = make_fake_http_client(|_request| {
            Box::pin(async move { json_response(201, REGISTRATION_REPLY) })
        });

        let register_url = Url::parse("https://auth.example.com/register").unwrap();
        let client = perform_dcr(&http, &register_url, "http://127.0.0.1:9999/callback")
            .await
            .unwrap();

        assert_eq!(client.client_id, "dynamic-client-001");
        assert_eq!(client.client_secret.as_deref(), Some("dynamic-secret-001"));
    });
}
2542
/// A 403 reply from the registration endpoint must surface as an error whose
/// message includes the status code.
#[test]
fn test_perform_dcr_failure() {
    smol::block_on(async {
        let client = make_fake_http_client(|_req| {
            Box::pin(
                async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
            )
        });

        let endpoint = Url::parse("https://auth.example.com/register").unwrap();
        let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;

        // Print the actual error on failure, matching the other error-path tests.
        let err_msg = result
            .expect_err("registration should fail on a 403 response")
            .to_string();
        assert!(err_msg.contains("403"), "unexpected error: {err_msg}");
    });
}
2559
2560 // -- OAuthCallback parse tests -------------------------------------------
2561
/// A query with both `code` and `state` parses into matching fields.
#[test]
fn test_oauth_callback_parse_query() {
    let query = "code=test_auth_code&state=test_state";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2568
/// Parameter order in the query string does not affect the parsed result.
#[test]
fn test_oauth_callback_parse_query_reversed_order() {
    let query = "state=test_state&code=test_auth_code";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2575
/// Unrecognized query parameters do not prevent `code` and `state` from
/// being parsed.
#[test]
fn test_oauth_callback_parse_query_with_extra_params() {
    let query = "code=test_auth_code&state=test_state&extra=ignored";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "test_auth_code");
    assert_eq!(parsed.state, "test_state");
}
2584
/// A query without `code` is rejected, and the error mentions "code".
#[test]
fn test_oauth_callback_parse_query_missing_code() {
    let result = OAuthCallback::parse_query("state=test_state");

    // Print the actual error on failure, matching the error-response tests.
    let err_msg = result
        .expect_err("a query without `code` should be rejected")
        .to_string();
    assert!(err_msg.contains("code"), "unexpected error: {err_msg}");
}
2591
/// A query without `state` is rejected, and the error mentions "state".
#[test]
fn test_oauth_callback_parse_query_missing_state() {
    let result = OAuthCallback::parse_query("code=test_auth_code");

    // Print the actual error on failure, matching the error-response tests.
    let err_msg = result
        .expect_err("a query without `state` should be rejected")
        .to_string();
    assert!(err_msg.contains("state"), "unexpected error: {err_msg}");
}
2598
/// An empty `code` value is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_code() {
    let outcome = OAuthCallback::parse_query("code=&state=test_state");

    assert!(outcome.is_err());
}
2604
/// An empty `state` value is rejected.
#[test]
fn test_oauth_callback_parse_query_empty_state() {
    let outcome = OAuthCallback::parse_query("code=test_auth_code&state=");

    assert!(outcome.is_err());
}
2610
/// Percent-encoded values are decoded before being stored.
#[test]
fn test_oauth_callback_parse_query_url_encoded_values() {
    let query = "code=abc%20def&state=test%3Dstate";
    let parsed = OAuthCallback::parse_query(query).unwrap();

    assert_eq!(parsed.code, "abc def");
    assert_eq!(parsed.state, "test=state");
}
2617
/// An OAuth error callback is reported as an error containing both the error
/// code and the decoded description.
#[test]
fn test_oauth_callback_parse_query_error_response() {
    let query = "error=access_denied&error_description=User%20denied%20access&state=abc";
    let outcome = OAuthCallback::parse_query(query);

    assert!(outcome.is_err());
    let err_msg = outcome.unwrap_err().to_string();
    for expected in ["access_denied", "User denied access"] {
        assert!(err_msg.contains(expected), "unexpected error: {err_msg}");
    }
}
2636
/// An OAuth error callback lacking `error_description` still yields an error
/// containing the error code plus the text "no description".
#[test]
fn test_oauth_callback_parse_query_error_without_description() {
    let outcome = OAuthCallback::parse_query("error=server_error&state=abc");

    assert!(outcome.is_err());
    let err_msg = outcome.unwrap_err().to_string();
    for expected in ["server_error", "no description"] {
        assert!(err_msg.contains(expected), "unexpected error: {err_msg}");
    }
}
2653
2654 // -- McpOAuthTokenProvider tests -----------------------------------------
2655
2656 fn make_test_session(
2657 access_token: &str,
2658 refresh_token: Option<&str>,
2659 expires_at: Option<SystemTime>,
2660 ) -> OAuthSession {
2661 OAuthSession {
2662 token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
2663 resource: Url::parse("https://mcp.example.com").unwrap(),
2664 client_registration: OAuthClientRegistration {
2665 client_id: "test-client".into(),
2666 client_secret: None,
2667 },
2668 tokens: OAuthTokens {
2669 access_token: access_token.into(),
2670 refresh_token: refresh_token.map(String::from),
2671 expires_at,
2672 },
2673 }
2674 }
2675
/// A token whose expiry lies in the past is not handed out.
#[test]
fn test_mcp_oauth_provider_returns_none_when_token_expired() {
    let a_minute_ago = SystemTime::now() - Duration::from_secs(60);
    // The handler panics if invoked: reading the token must not hit the network.
    let http = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("stale-token", Some("rt"), Some(a_minute_ago)),
        http,
        None,
    );

    assert_eq!(provider.access_token(), None);
}
2688
/// A token whose expiry lies an hour ahead is handed out as-is.
#[test]
fn test_mcp_oauth_provider_returns_token_when_not_expired() {
    let in_an_hour = SystemTime::now() + Duration::from_secs(3600);
    // The handler panics if invoked: reading the token must not hit the network.
    let http = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("valid-token", Some("rt"), Some(in_an_hour)),
        http,
        None,
    );

    assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
}
2701
/// A token with no recorded expiry is handed out.
#[test]
fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
    // The handler panics if invoked: reading the token must not hit the network.
    let http = make_fake_http_client(|_| Box::pin(async { unreachable!() }));
    let provider = McpOAuthTokenProvider::new(
        make_test_session("no-expiry-token", Some("rt"), None),
        http,
        None,
    );

    assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
}
2713
/// Without a stored refresh token, `try_refresh` reports `false` and makes
/// no HTTP request (the fake handler panics if called).
#[test]
fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
    smol::block_on(async {
        let http = make_fake_http_client(|_| {
            Box::pin(async { unreachable!("no HTTP call expected") })
        });
        let provider =
            McpOAuthTokenProvider::new(make_test_session("token", None, None), http, None);

        let did_refresh = provider.try_refresh().await.unwrap();

        assert!(!did_refresh);
    });
}
2730
/// A successful refresh replaces the access token held by the provider and
/// sends the updated session, including the new refresh token, over the
/// channel given at construction.
#[test]
fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
    const TOKEN_REPLY: &str = r#"{
        "access_token": "new-access",
        "refresh_token": "new-refresh",
        "expires_in": 1800
    }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http = make_fake_http_client(|_request| {
            Box::pin(async { json_response(200, TOKEN_REPLY) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh-token"), None),
            http,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);
        assert_eq!(provider.access_token().as_deref(), Some("new-access"));

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        assert_eq!(published.tokens.access_token, "new-access");
        assert_eq!(
            published.tokens.refresh_token.as_deref(),
            Some("new-refresh")
        );
    });
}
2764
/// When the refresh reply carries no `refresh_token`, the session sent over
/// the channel still holds the original refresh token.
#[test]
fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
    const TOKEN_REPLY: &str = r#"{
        "access_token": "new-access",
        "expires_in": 900
    }"#;

    smol::block_on(async {
        let (session_tx, mut session_rx) = futures::channel::mpsc::unbounded();
        let http = make_fake_http_client(|_request| {
            Box::pin(async { json_response(200, TOKEN_REPLY) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("original-refresh"), None),
            http,
            Some(session_tx),
        );

        let did_refresh = provider.try_refresh().await.unwrap();
        assert!(did_refresh);

        let published = session_rx
            .try_recv()
            .expect("channel should have a session");
        assert_eq!(published.tokens.access_token, "new-access");
        assert_eq!(
            published.tokens.refresh_token.as_deref(),
            Some("original-refresh"),
        );
    });
}
2796
/// A 401 reply to the refresh request makes `try_refresh` report `false`
/// and leaves the previously stored access token in place.
#[test]
fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
    smol::block_on(async {
        let http = make_fake_http_client(|_request| {
            Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
        });
        let provider = McpOAuthTokenProvider::new(
            make_test_session("old-access", Some("my-refresh"), None),
            http,
            None,
        );

        let did_refresh = provider.try_refresh().await.unwrap();

        assert!(!did_refresh);
        // The failed refresh must not have disturbed the existing token.
        assert_eq!(provider.access_token().as_deref(), Some("old-access"));
    });
}
2814}