>,
+)> {
+ let server = tiny_http::Server::http("127.0.0.1:0")
+ .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
+ let port = server
+ .server_addr()
+ .to_ip()
+ .context("server not bound to a TCP address")?
+ .port();
+
+ let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
+
+ let (tx, rx) = futures::channel::oneshot::channel();
+
+ // `tiny_http` is blocking, so we run it on a background thread.
+ // The `recv_timeout` loop lets us check for cancellation (the receiver
+ // being dropped) and enforce an overall timeout.
+ std::thread::spawn(move || {
+ let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
+
+ loop {
+ if tx.is_canceled() {
+ return;
+ }
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ return;
+ }
+
+ let timeout = remaining.min(Duration::from_millis(500));
+ let Some(request) = (match server.recv_timeout(timeout) {
+ Ok(req) => req,
+ Err(_) => {
+ let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
+ return;
+ }
+ }) else {
+ // Timeout with no request — loop back and check cancellation.
+ continue;
+ };
+
+ let result = handle_callback_request(&request);
+
+ let (status_code, body) = match &result {
+ Ok(_) => (
+ 200,
+ "Authorization successful
\
+ You can close this tab and return to Zed.
",
+ ),
+ Err(err) => {
+ log::error!("OAuth callback error: {}", err);
+ (
+ 400,
+ "Authorization failed
\
+ Something went wrong. Please try again from Zed.
",
+ )
+ }
+ };
+
+ let response = tiny_http::Response::from_string(body)
+ .with_status_code(status_code)
+ .with_header(
+ tiny_http::Header::from_str("Content-Type: text/html")
+ .expect("failed to construct response header"),
+ )
+ .with_header(
+ tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
+ .expect("failed to construct response header"),
+ );
+ request.respond(response).log_err();
+
+ let _ = tx.send(result);
+ return;
+ }
+ });
+
+ Ok((redirect_uri, rx))
+}
+
+/// Extract the `code` and `state` query parameters from an OAuth callback
+/// request to `/callback`.
+fn handle_callback_request(request: &tiny_http::Request) -> Result {
+ let url = Url::parse(&format!("http://localhost{}", request.url()))
+ .context("malformed callback request URL")?;
+
+ if url.path() != "/callback" {
+ bail!("unexpected path in OAuth callback: {}", url.path());
+ }
+
+ let query = url
+ .query()
+ .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
+ OAuthCallback::parse_query(query)
+}
+
+// -- JSON fetch helper -------------------------------------------------------
+
+async fn fetch_json(
+ http_client: &Arc,
+ url: &Url,
+) -> Result {
+ validate_oauth_url(url)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::GET)
+ .uri(url.as_str())
+ .header("Accept", "application/json")
+ .body(AsyncBody::default())?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ bail!("HTTP {} fetching {}", response.status(), url);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
+}
+
+// -- Serde response types for discovery --------------------------------------
+
+#[derive(Debug, Deserialize)]
+struct ProtectedResourceMetadataResponse {
+ #[serde(default)]
+ resource: Option,
+ #[serde(default)]
+ authorization_servers: Vec,
+ #[serde(default)]
+ scopes_supported: Option>,
+}
+
+#[derive(Debug, Deserialize)]
+struct AuthServerMetadataResponse {
+ #[serde(default)]
+ issuer: Option,
+ #[serde(default)]
+ authorization_endpoint: Option,
+ #[serde(default)]
+ token_endpoint: Option,
+ #[serde(default)]
+ registration_endpoint: Option,
+ #[serde(default)]
+ scopes_supported: Option>,
+ #[serde(default)]
+ code_challenge_methods_supported: Option>,
+ #[serde(default)]
+ client_id_metadata_document_supported: Option,
+}
+
+#[derive(Debug, Deserialize)]
+struct DcrResponse {
+ client_id: String,
+ #[serde(default)]
+ client_secret: Option,
+}
+
+/// Provides OAuth tokens to the HTTP transport layer.
+///
+/// The transport calls `access_token()` before each request. On a 401 response
+/// it calls `try_refresh()` and retries once if the refresh succeeds.
+#[async_trait]
+pub trait OAuthTokenProvider: Send + Sync {
+ /// Returns the current access token, if one is available.
+ fn access_token(&self) -> Option;
+
+ /// Attempts to refresh the access token. Returns `true` if a new token was
+ /// obtained and the request should be retried.
+ async fn try_refresh(&self) -> Result;
+}
+
+/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
+/// an HTTP client for token refresh. The same provider type is used both after
+/// an interactive authentication flow and when restoring a saved session from
+/// the keychain on startup.
+pub struct McpOAuthTokenProvider {
+ session: SyncMutex,
+ http_client: Arc,
+ token_refresh_tx: Option>,
+}
+
+impl McpOAuthTokenProvider {
+ pub fn new(
+ session: OAuthSession,
+ http_client: Arc,
+ token_refresh_tx: Option>,
+ ) -> Self {
+ Self {
+ session: SyncMutex::new(session),
+ http_client,
+ token_refresh_tx,
+ }
+ }
+
+ fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
+ tokens.expires_at.is_some_and(|expires_at| {
+ SystemTime::now()
+ .checked_add(Duration::from_secs(30))
+ .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
+ })
+ }
+}
+
+#[async_trait]
+impl OAuthTokenProvider for McpOAuthTokenProvider {
+ fn access_token(&self) -> Option {
+ let session = self.session.lock();
+ if Self::access_token_is_expired(&session.tokens) {
+ return None;
+ }
+ Some(session.tokens.access_token.clone())
+ }
+
+ async fn try_refresh(&self) -> Result {
+ let (refresh_token, token_endpoint, resource, client_id) = {
+ let session = self.session.lock();
+ match session.tokens.refresh_token.clone() {
+ Some(refresh_token) => (
+ refresh_token,
+ session.token_endpoint.clone(),
+ session.resource.clone(),
+ session.client_registration.client_id.clone(),
+ ),
+ None => return Ok(false),
+ }
+ };
+
+ let resource_str = canonical_server_uri(&resource);
+
+ match refresh_tokens(
+ &self.http_client,
+ &token_endpoint,
+ &refresh_token,
+ &client_id,
+ &resource_str,
+ )
+ .await
+ {
+ Ok(mut new_tokens) => {
+ if new_tokens.refresh_token.is_none() {
+ new_tokens.refresh_token = Some(refresh_token);
+ }
+
+ {
+ let mut session = self.session.lock();
+ session.tokens = new_tokens;
+
+ if let Some(ref tx) = self.token_refresh_tx {
+ tx.unbounded_send(session.clone()).ok();
+ }
+ }
+
+ Ok(true)
+ }
+ Err(err) => {
+ log::warn!("OAuth token refresh failed: {}", err);
+ Ok(false)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use http_client::Response;
+
+ // -- require_https_or_loopback tests ------------------------------------
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_https() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_remote() {
+ let url = Url::parse("http://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
+ let url = Url::parse("http://[::1]:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost() {
+ let url = Url::parse("http://localhost:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
+ let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
+ let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_ftp() {
+ let url = Url::parse("ftp://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ // -- validate_oauth_url (SSRF) tests ------------------------------------
+
+ #[test]
+ fn test_validate_oauth_url_accepts_https_public() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_10() {
+ let url = Url::parse("https://10.0.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_172() {
+ let url = Url::parse("https://172.16.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_192() {
+ let url = Url::parse("https://192.168.1.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_link_local() {
+ let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_ula() {
+ let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_unspecified() {
+ let url = Url::parse("https://[::]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
+ let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
+ let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_http_loopback() {
+ // Loopback is permitted (it's our callback server).
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_https_public_ip() {
+ let url = Url::parse("https://93.184.216.34/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ // -- parse_www_authenticate tests ----------------------------------------
+
+ #[test]
+ fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "user:profile".to_string()])
+ );
+ assert_eq!(result.error, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_resource_metadata_only() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_bare_bearer() {
+ let result = parse_www_authenticate("Bearer").unwrap();
+ assert_eq!(result.resource_metadata, None);
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_with_error() {
+ let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(result.error, Some(BearerError::InsufficientScope));
+ assert_eq!(
+ result.error_description.as_deref(),
+ Some("Additional file write permission required")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "files:write".to_string()])
+ );
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_token_error() {
+ let header =
+ r#"Bearer error="invalid_token", error_description="The access token expired""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidToken));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_request_error() {
+ let header = r#"Bearer error="invalid_request""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidRequest));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_unknown_error() {
+ let header = r#"Bearer error="some_future_error""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::Other));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_rejects_non_bearer() {
+ let result = parse_www_authenticate("Basic realm=\"example\"");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_case_insensitive_scheme() {
+ let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_multiline_style() {
+ // Some servers emit the header spread across multiple lines joined by
+ // whitespace, as shown in the spec examples.
+ let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_with_path() {
+ let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_without_path() {
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 1);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://mcp.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_with_path() {
+ let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 3);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration/tenant1"
+ );
+ assert_eq!(
+ urls[2].as_str(),
+ "https://auth.example.com/tenant1/.well-known/openid-configuration"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_without_path() {
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration"
+ );
+ }
+
+ // -- Canonical server URI tests ------------------------------------------
+
+ #[test]
+ fn test_canonical_server_uri_simple() {
+ let url = Url::parse("https://mcp.example.com").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_with_path() {
+ let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_strips_trailing_slash() {
+ let url = Url::parse("https://mcp.example.com/").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_preserves_port() {
+ let url = Url::parse("https://mcp.example.com:8443").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_lowercases() {
+ let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
+ assert_eq!(
+ canonical_server_uri(&url),
+ "https://mcp.example.com/Server/MCP"
+ );
+ }
+
+ // -- Scope selection tests -----------------------------------------------
+
+ #[test]
+ fn test_select_scopes_prefers_www_authenticate() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
+ }
+
+ #[test]
+ fn test_select_scopes_falls_back_to_resource_metadata() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["admin".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
+ }
+
+ #[test]
+ fn test_select_scopes_empty_when_nothing_available() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: None,
+ };
+ assert!(select_scopes(&www_auth, &resource_meta).is_empty());
+ }
+
+ // -- Client registration strategy tests ----------------------------------
+
+ #[test]
+ fn test_registration_strategy_prefers_cimd() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_falls_back_to_dcr() {
+ let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(reg_endpoint.clone()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: reg_endpoint,
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_unavailable() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Unavailable,
+ );
+ }
+
+ // -- PKCE tests ----------------------------------------------------------
+
+ #[test]
+ fn test_pkce_challenge_verifier_length() {
+ let pkce = generate_pkce_challenge();
+ // 32 random bytes → 43 base64url chars (no padding).
+ assert_eq!(pkce.verifier.len(), 43);
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_valid_base64url() {
+ let pkce = generate_pkce_challenge();
+ for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
+ assert!(
+ c.is_ascii_alphanumeric() || c == '-' || c == '_',
+ "invalid base64url character: {}",
+ c
+ );
+ }
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_s256_of_verifier() {
+ let pkce = generate_pkce_challenge();
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
+ let expected_challenge = engine.encode(expected_digest);
+ assert_eq!(pkce.challenge, expected_challenge);
+ }
+
+ #[test]
+ fn test_pkce_challenges_are_unique() {
+ let a = generate_pkce_challenge();
+ let b = generate_pkce_challenge();
+ assert_ne!(a.verifier, b.verifier);
+ }
+
+ // -- Authorization URL tests ---------------------------------------------
+
+ #[test]
+ fn test_build_authorization_url() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ let pkce = PkceChallenge {
+ verifier: "test_verifier".into(),
+ challenge: "test_challenge".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "https://zed.dev/oauth/client-metadata.json",
+ "http://127.0.0.1:12345/callback",
+ &["files:read".into(), "files:write".into()],
+ "https://mcp.example.com",
+ &pkce,
+ "random_state_123",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert_eq!(pairs.get("response_type").unwrap(), "code");
+ assert_eq!(
+ pairs.get("client_id").unwrap(),
+ "https://zed.dev/oauth/client-metadata.json"
+ );
+ assert_eq!(
+ pairs.get("redirect_uri").unwrap(),
+ "http://127.0.0.1:12345/callback"
+ );
+ assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
+ assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
+ assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
+ assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
+ assert_eq!(pairs.get("state").unwrap(), "random_state_123");
+ }
+
+ #[test]
+ fn test_build_authorization_url_omits_empty_scope() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ let pkce = PkceChallenge {
+ verifier: "v".into(),
+ challenge: "c".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "client_123",
+ "http://127.0.0.1:9999/callback",
+ &[],
+ "https://mcp.example.com",
+ &pkce,
+ "state",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert!(!pairs.contains_key("scope"));
+ }
+
+ // -- Token exchange / refresh param tests --------------------------------
+
+ #[test]
+ fn test_token_exchange_params() {
+ let params = token_exchange_params(
+ "auth_code_abc",
+ "client_xyz",
+ "http://127.0.0.1:5555/callback",
+ "verifier_123",
+ "https://mcp.example.com",
+ );
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "authorization_code");
+ assert_eq!(map["code"], "auth_code_abc");
+ assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["code_verifier"], "verifier_123");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_token_refresh_params() {
+ let params =
+ token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "refresh_token");
+ assert_eq!(map["refresh_token"], "refresh_token_abc");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ // -- Token response tests ------------------------------------------------
+
+ #[test]
+ fn test_token_response_into_tokens_with_expiry() {
+ let response: TokenResponse = serde_json::from_str(
+ r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
+ )
+ .unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_123");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
+ assert!(tokens.expires_at.is_some());
+ }
+
+ #[test]
+ fn test_token_response_into_tokens_minimal() {
+ let response: TokenResponse =
+ serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_789");
+ assert_eq!(tokens.refresh_token, None);
+ assert_eq!(tokens.expires_at, None);
+ }
+
+ // -- DCR body test -------------------------------------------------------
+
+ #[test]
+ fn test_dcr_registration_body_shape() {
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+ assert_eq!(body["client_name"], "Zed");
+ assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
+ assert_eq!(body["grant_types"][0], "authorization_code");
+ assert_eq!(body["response_types"][0], "code");
+ assert_eq!(body["token_endpoint_auth_method"], "none");
+ }
+
+ // -- Test helpers for async/HTTP tests -----------------------------------
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request,
+ ) -> std::pin::Pin<
+ Box>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc {
+ http_client::FakeHttpClient::create(handler) as Arc
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ // -- Discovery integration tests -----------------------------------------
+
+ #[test]
+ fn test_fetch_protected_resource_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["read", "write"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ assert_eq!(
+ metadata.authorization_servers[0].as_str(),
+ "https://auth.example.com/"
+ );
+ assert_eq!(
+ metadata.scopes_supported,
+ Some(vec!["read".to_string(), "write".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri == "https://mcp.example.com/custom-resource-metadata" {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(500, r#"{"error": "should not be called"}"#)
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ // The cross-origin URL should NOT be fetched; only the
+ // well-known fallback at the server's own origin should be.
+ if uri.contains("attacker.example.com") {
+ panic!("should not fetch cross-origin resource_metadata URL");
+ } else if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ // Should have used the fallback well-known URL, not the attacker's.
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert_eq!(
+ metadata.token_endpoint.as_str(),
+ "https://auth.example.com/token"
+ );
+ assert!(metadata.registration_endpoint.is_some());
+ assert!(metadata.client_id_metadata_document_supported);
+ assert_eq!(
+ metadata.code_challenge_methods_supported,
+ Some(vec!["S256".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("openid-configuration") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert!(!metadata.client_id_metadata_document_supported);
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ // Response claims to be a different issuer.
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://evil.example.com",
+ "authorization_endpoint": "https://evil.example.com/authorize",
+ "token_endpoint": "https://evil.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let result = fetch_auth_server_metadata(&client, &issuer).await;
+
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("issuer mismatch"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Full discover integration tests -------------------------------------
+
+ #[test]
+ fn test_full_discover_with_cimd() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["mcp:read"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, CIMD_URL);
+ assert_eq!(registration.client_secret, None);
+ assert_eq!(discovery.scopes, vec!["mcp:read"]);
+ });
+ }
+
+ #[test]
+ fn test_full_discover_with_dcr_fallback() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": false
+ }"#,
+ )
+ } else if uri.contains("/register") {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dcr-minted-id-123",
+ "client_secret": "dcr-secret-456"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dcr-minted-id-123");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dcr-secret-456")
+ );
+ assert_eq!(discovery.scopes, vec!["files:read"]);
+ });
+ }
+
+ #[test]
+ fn test_discover_fails_without_pkce_support() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let result = discover(&client, &server_url, &www_auth).await;
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("code_challenge_methods_supported"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Token exchange integration tests ------------------------------------
+
+ #[test]
+ fn test_exchange_code_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new_access_token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 3600,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let tokens = exchange_code(
+ &client,
+ &metadata,
+ "auth_code_123",
+ CIMD_URL,
+ "http://127.0.0.1:9999/callback",
+ "verifier_abc",
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "new_access_token");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_refresh_tokens_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "refreshed_token",
+ "expires_in": 1800,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
+
+ let tokens = refresh_tokens(
+ &client,
+ &token_endpoint,
+ "old_refresh_token",
+ CIMD_URL,
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "refreshed_token");
+ assert_eq!(tokens.refresh_token, None);
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_exchange_code_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let result = exchange_code(
+ &client,
+ &metadata,
+ "bad_code",
+ "client",
+ "http://127.0.0.1:1/callback",
+ "verifier",
+ "https://mcp.example.com",
+ )
+ .await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("400"));
+ });
+ }
+
+ // -- DCR integration tests -----------------------------------------------
+
+ #[test]
+ fn test_perform_dcr() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dynamic-client-001",
+ "client_secret": "dynamic-secret-001"
+ }"#,
+ )
+ })
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dynamic-client-001");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dynamic-secret-001")
+ );
+ });
+ }
+
+ #[test]
+ fn test_perform_dcr_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(
+ async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
+ )
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("403"));
+ });
+ }
+
+ // -- OAuthCallback parse tests -------------------------------------------
+
+ #[test]
+ fn test_oauth_callback_parse_query() {
+ let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_reversed_order() {
+ let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_with_extra_params() {
+ let callback =
+ OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
+ .unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_code() {
+ let result = OAuthCallback::parse_query("state=test_state");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("code"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("state"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_code() {
+ let result = OAuthCallback::parse_query("code=&state=test_state");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code&state=");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_url_encoded_values() {
+ let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
+ assert_eq!(callback.code, "abc def");
+ assert_eq!(callback.state, "test=state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_response() {
+ let result = OAuthCallback::parse_query(
+ "error=access_denied&error_description=User%20denied%20access&state=abc",
+ );
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("access_denied"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("User denied access"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_without_description() {
+ let result = OAuthCallback::parse_query("error=server_error&state=abc");
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("server_error"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("no description"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ // -- McpOAuthTokenProvider tests -----------------------------------------
+
+ fn make_test_session(
+ access_token: &str,
+ refresh_token: Option<&str>,
+ expires_at: Option,
+ ) -> OAuthSession {
+ OAuthSession {
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ resource: Url::parse("https://mcp.example.com").unwrap(),
+ client_registration: OAuthClientRegistration {
+ client_id: "test-client".into(),
+ client_secret: None,
+ },
+ tokens: OAuthTokens {
+ access_token: access_token.into(),
+ refresh_token: refresh_token.map(String::from),
+ expires_at,
+ },
+ }
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_none_when_token_expired() {
+ let expired = SystemTime::now() - Duration::from_secs(60);
+ let session = make_test_session("stale-token", Some("rt"), Some(expired));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token(), None);
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_not_expired() {
+ let far_future = SystemTime::now() + Duration::from_secs(3600);
+ let session = make_test_session("valid-token", Some("rt"), Some(far_future));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
+ let session = make_test_session("no-expiry-token", Some("rt"), None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
+ smol::block_on(async {
+ let session = make_test_session("token", None, None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| {
+ Box::pin(async { unreachable!("no HTTP call expected") })
+ }),
+ None,
+ );
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh-token"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "refresh_token": "new-refresh",
+ "expires_in": 1800
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+ assert_eq!(provider.access_token().as_deref(), Some("new-access"));
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("new-refresh")
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("original-refresh"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "expires_in": 900
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("original-refresh"),
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh"), None);
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, None);
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ // The old token should still be in place.
+ assert_eq!(provider.access_token().as_deref(), Some("old-access"));
+ });
+ }
+}
diff --git a/crates/context_server/src/transport/http.rs b/crates/context_server/src/transport/http.rs
index 70248f0278fcf80024d75d7f78cae5c29f26cc43..3e002983b5e49026d668c8baabfe8f856e4c5fe7 100644
--- a/crates/context_server/src/transport/http.rs
+++ b/crates/context_server/src/transport/http.rs
@@ -8,8 +8,30 @@ use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
+use crate::oauth::{self, OAuthTokenProvider, WwwAuthenticate};
use crate::transport::Transport;
+/// Typed errors returned by the HTTP transport that callers can downcast from
+/// `anyhow::Error` to handle specific failure modes.
+#[derive(Debug)]
+pub enum TransportError {
+ /// The server returned 401 and token refresh either wasn't possible or
+ /// failed. The caller should initiate the OAuth authorization flow.
+ AuthRequired { www_authenticate: WwwAuthenticate },
+}
+
+impl std::fmt::Display for TransportError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TransportError::AuthRequired { .. } => {
+ write!(f, "OAuth authorization required")
+ }
+ }
+ }
+}
+
+impl std::error::Error for TransportError {}
+
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
@@ -25,8 +47,11 @@ pub struct HttpTransport {
response_rx: channel::Receiver,
error_tx: channel::Sender,
error_rx: channel::Receiver,
- // Authentication headers to include in requests
+ /// Static headers to include in every request (e.g. from server config).
headers: HashMap,
+ /// When set, the transport attaches `Authorization: Bearer` headers and
+ /// handles 401 responses with token refresh + retry.
+ token_provider: Option>,
}
impl HttpTransport {
@@ -35,6 +60,16 @@ impl HttpTransport {
endpoint: String,
headers: HashMap,
executor: BackgroundExecutor,
+ ) -> Self {
+ Self::new_with_token_provider(http_client, endpoint, headers, executor, None)
+ }
+
+ pub fn new_with_token_provider(
+ http_client: Arc,
+ endpoint: String,
+ headers: HashMap,
+ executor: BackgroundExecutor,
+ token_provider: Option>,
) -> Self {
let (response_tx, response_rx) = channel::unbounded();
let (error_tx, error_rx) = channel::unbounded();
@@ -49,14 +84,14 @@ impl HttpTransport {
error_tx,
error_rx,
headers,
+ token_provider,
}
}
- /// Send a message and handle the response based on content type
- async fn send_message(&self, message: String) -> Result<()> {
- let is_notification =
- !message.contains("\"id\":") || message.contains("notifications/initialized");
-
+ /// Build a POST request for the given message body, attaching all standard
+ /// headers (content-type, accept, session ID, static headers, and bearer
+ /// token if available).
+ fn build_request(&self, message: &[u8]) -> Result> {
let mut request_builder = Request::builder()
.method(Method::POST)
.uri(&self.endpoint)
@@ -70,15 +105,71 @@ impl HttpTransport {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
- // Add session ID if we have one (except for initialize)
+ // Attach bearer token when a token provider is present.
+ if let Some(token) = self.token_provider.as_ref().and_then(|p| p.access_token()) {
+ request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
+ // Add session ID if we have one (except for initialize).
if let Some(ref session_id) = *self.session_id.lock() {
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
}
- let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
+ Ok(request_builder.body(AsyncBody::from(message.to_vec()))?)
+ }
+
+ /// Send a message and handle the response based on content type.
+ async fn send_message(&self, message: String) -> Result<()> {
+ let is_notification =
+ !message.contains("\"id\":") || message.contains("notifications/initialized");
+
+ // If we currently have no access token, try refreshing before sending
+ // the request so restored but expired sessions do not need an initial
+ // 401 round-trip before they can recover.
+ if let Some(ref provider) = self.token_provider {
+ if provider.access_token().is_none() {
+ provider.try_refresh().await.unwrap_or(false);
+ }
+ }
+
+ let request = self.build_request(message.as_bytes())?;
let mut response = self.http_client.send(request).await?;
- // Handle different response types based on status and content-type
+ // On 401, try refreshing the token and retry once.
+ if response.status().as_u16() == 401 {
+ let www_auth_header = response
+ .headers()
+ .get("www-authenticate")
+ .and_then(|v| v.to_str().ok())
+ .unwrap_or("Bearer");
+
+ let www_authenticate =
+ oauth::parse_www_authenticate(www_auth_header).unwrap_or(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+
+ if let Some(ref provider) = self.token_provider {
+ if provider.try_refresh().await.unwrap_or(false) {
+ // Retry with the refreshed token.
+ let retry_request = self.build_request(message.as_bytes())?;
+ response = self.http_client.send(retry_request).await?;
+
+ // If still 401 after refresh, give up.
+ if response.status().as_u16() == 401 {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ }
+
+ // Handle different response types based on status and content-type.
match response.status() {
status if status.is_success() => {
// Check content type
@@ -233,6 +324,7 @@ impl Drop for HttpTransport {
let endpoint = self.endpoint.clone();
let session_id = self.session_id.lock().clone();
let headers = self.headers.clone();
+ let access_token = self.token_provider.as_ref().and_then(|p| p.access_token());
if let Some(session_id) = session_id {
self.executor
@@ -242,11 +334,17 @@ impl Drop for HttpTransport {
.uri(&endpoint)
.header(HEADER_SESSION_ID, &session_id);
- // Add authentication headers if present
+ // Add static authentication headers.
for (key, value) in headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
+ // Attach bearer token if available.
+ if let Some(token) = access_token {
+ request_builder =
+ request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
let request = request_builder.body(AsyncBody::empty());
if let Ok(request) = request {
@@ -257,3 +355,402 @@ impl Drop for HttpTransport {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use async_trait::async_trait;
+ use gpui::TestAppContext;
+ use parking_lot::Mutex as SyncMutex;
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+ /// A mock token provider that returns a configurable token and tracks
+ /// refresh attempts.
+ struct FakeTokenProvider {
+ token: SyncMutex