From 121cac83c0c5902220ba2ada52824bf3cbd25a12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Houl=C3=A9?= Date: Wed, 8 Apr 2026 17:38:09 +0200 Subject: [PATCH] context_server: Mirror authorization server grant_types_supported In MCP OAuth, mirror the authorization server's grant_types_supported in the DCR registration body instead of hardcoding just authorization_code. Logfire's auth server requires both authorization_code and refresh_token in grant_types, and we already uses refresh tokens, so the only issue was not advertising the capability during registration. The DCR body now intersects our supported grant types with what the server advertises, or sends all of ours when the server metadata omits grant_types_supported. Without this change, the Pydantic Logfire MCP auth server refuses our client registration. --- crates/context_server/src/oauth.rs | 84 ++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 10 deletions(-) diff --git a/crates/context_server/src/oauth.rs b/crates/context_server/src/oauth.rs index 1a314de2fca9b9987336decb15b208ffd7759dea..d1ce778bdcce02716f2fa78ed41ca2750837ed20 100644 --- a/crates/context_server/src/oauth.rs +++ b/crates/context_server/src/oauth.rs @@ -146,6 +146,7 @@ pub struct AuthServerMetadata { pub token_endpoint: Url, pub registration_endpoint: Option, pub scopes_supported: Option>, + pub grant_types_supported: Option>, pub code_challenge_methods_supported: Option>, pub client_id_metadata_document_supported: bool, } @@ -672,11 +673,30 @@ pub fn token_refresh_params( /// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict /// redirect URI matching even for loopback addresses, so we register the /// exact URI we intend to use. -pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value { +/// The grant types Zed can use. Intersected with the server's +/// `grant_types_supported` to build the DCR request. +const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"]; + +pub fn dcr_registration_body( + redirect_uri: &str, + server_grant_types: Option<&[String]>, +) -> serde_json::Value { + // Use the intersection of what we support and what the server advertises. + // When the server doesn't advertise grant_types_supported, send all of + // ours — the server will reject what it doesn't like. + let grant_types: Vec<&str> = match server_grant_types { + Some(server) => SUPPORTED_GRANT_TYPES + .iter() + .copied() + .filter(|gt| server.iter().any(|s| s == *gt)) + .collect(), + None => SUPPORTED_GRANT_TYPES.to_vec(), + }; + serde_json::json!({ "client_name": "Zed", "redirect_uris": [redirect_uri], - "grant_types": ["authorization_code"], + "grant_types": grant_types, "response_types": ["code"], "token_endpoint_auth_method": "none" }) @@ -760,6 +780,7 @@ pub async fn fetch_auth_server_metadata( return Ok(AuthServerMetadata { issuer: reported_issuer, + grant_types_supported: response.grant_types_supported, authorization_endpoint: response .authorization_endpoint .ok_or_else(|| anyhow!("missing authorization_endpoint"))?, @@ -846,7 +867,18 @@ pub async fn resolve_client_registration( }), ClientRegistrationStrategy::Dcr { registration_endpoint, - } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await, + } => { + perform_dcr( + http_client, + ®istration_endpoint, + redirect_uri, + discovery + .auth_server_metadata + .grant_types_supported + .as_deref(), + ) + .await + } ClientRegistrationStrategy::Unavailable => { bail!("authorization server supports neither CIMD nor DCR") } @@ -860,10 +892,11 @@ pub async fn perform_dcr( http_client: &Arc, registration_endpoint: &Url, redirect_uri: &str, + server_grant_types: Option<&[String]>, ) -> Result { validate_oauth_url(registration_endpoint)?; - let body = dcr_registration_body(redirect_uri); + let body = dcr_registration_body(redirect_uri, server_grant_types); let body_bytes = serde_json::to_vec(&body)?; let request = Request::builder() @@ -1206,6 +1239,8 @@ struct AuthServerMetadataResponse { #[serde(default)] scopes_supported: Option>, #[serde(default)] + grant_types_supported: Option>, + #[serde(default)] code_challenge_methods_supported: Option>, #[serde(default)] client_id_metadata_document_supported: Option, @@ -1707,6 +1742,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1727,6 +1763,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1746,6 +1783,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1802,6 +1840,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let pkce = PkceChallenge { verifier: "test_verifier".into(), @@ -1844,6 +1883,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; let pkce = PkceChallenge { verifier: "v".into(), @@ -1927,15 +1967,35 @@ mod tests { // -- DCR body test ------------------------------------------------------- #[test] - fn test_dcr_registration_body_shape() { - let body = dcr_registration_body("http://127.0.0.1:12345/callback"); + fn test_dcr_registration_body_without_server_metadata() { + // When server metadata is unavailable, include all supported grant types. + let body = dcr_registration_body("http://127.0.0.1:12345/callback", None); assert_eq!(body["client_name"], "Zed"); assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback"); assert_eq!(body["grant_types"][0], "authorization_code"); + assert_eq!(body["grant_types"][1], "refresh_token"); assert_eq!(body["response_types"][0], "code"); assert_eq!(body["token_endpoint_auth_method"], "none"); } + #[test] + fn test_dcr_registration_body_mirrors_server_grant_types() { + // When the server only supports authorization_code, omit refresh_token. + let server_types = vec!["authorization_code".to_string()]; + let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types)); + assert_eq!(body["grant_types"][0], "authorization_code"); + assert!(body["grant_types"].as_array().unwrap().len() == 1); + + // When the server supports both, include both. + let server_types = vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ]; + let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types)); + assert_eq!(body["grant_types"][0], "authorization_code"); + assert_eq!(body["grant_types"][1], "refresh_token"); + } + // -- Test helpers for async/HTTP tests ----------------------------------- fn make_fake_http_client( @@ -2398,6 +2458,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let tokens = exchange_code( @@ -2472,6 +2533,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let result = exchange_code( @@ -2508,9 +2570,10 @@ mod tests { }); let endpoint = Url::parse("https://auth.example.com/register").unwrap(); - let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback") - .await - .unwrap(); + let registration = + perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None) + .await + .unwrap(); assert_eq!(registration.client_id, "dynamic-client-001"); assert_eq!( @@ -2530,7 +2593,8 @@ mod tests { }); let endpoint = Url::parse("https://auth.example.com/register").unwrap(); - let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await; + let result = + perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("403"));