context_server: Mirror authorization server grant_types_supported

Tom Houlé created

In MCP OAuth, mirror the authorization server's grant_types_supported in the
DCR registration body instead of hardcoding just authorization_code.
Logfire's auth server requires both authorization_code and refresh_token
in grant_types, and we already uses refresh tokens, so the only issue
was not advertising the capability during registration. The DCR body now
intersects our supported grant types with what the server advertises, or
sends all of ours when the server metadata omits grant_types_supported.

Without this change, the Pydantic Logfire MCP auth server refuses our client registration.

Change summary

crates/context_server/src/oauth.rs | 84 ++++++++++++++++++++++++++++---
1 file changed, 74 insertions(+), 10 deletions(-)

Detailed changes

crates/context_server/src/oauth.rs 🔗

@@ -146,6 +146,7 @@ pub struct AuthServerMetadata {
     pub token_endpoint: Url,
     pub registration_endpoint: Option<Url>,
     pub scopes_supported: Option<Vec<String>>,
+    pub grant_types_supported: Option<Vec<String>>,
     pub code_challenge_methods_supported: Option<Vec<String>>,
     pub client_id_metadata_document_supported: bool,
 }
@@ -672,11 +673,30 @@ pub fn token_refresh_params(
 /// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
 /// redirect URI matching even for loopback addresses, so we register the
 /// exact URI we intend to use.
-pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
+/// The grant types Zed can use. Intersected with the server's
+/// `grant_types_supported` to build the DCR request.
+const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"];
+
+pub fn dcr_registration_body(
+    redirect_uri: &str,
+    server_grant_types: Option<&[String]>,
+) -> serde_json::Value {
+    // Use the intersection of what we support and what the server advertises.
+    // When the server doesn't advertise grant_types_supported, send all of
+    // ours — the server will reject what it doesn't like.
+    let grant_types: Vec<&str> = match server_grant_types {
+        Some(server) => SUPPORTED_GRANT_TYPES
+            .iter()
+            .copied()
+            .filter(|gt| server.iter().any(|s| s == *gt))
+            .collect(),
+        None => SUPPORTED_GRANT_TYPES.to_vec(),
+    };
+
     serde_json::json!({
         "client_name": "Zed",
         "redirect_uris": [redirect_uri],
-        "grant_types": ["authorization_code"],
+        "grant_types": grant_types,
         "response_types": ["code"],
         "token_endpoint_auth_method": "none"
     })
@@ -760,6 +780,7 @@ pub async fn fetch_auth_server_metadata(
 
                 return Ok(AuthServerMetadata {
                     issuer: reported_issuer,
+                    grant_types_supported: response.grant_types_supported,
                     authorization_endpoint: response
                         .authorization_endpoint
                         .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
@@ -846,7 +867,18 @@ pub async fn resolve_client_registration(
         }),
         ClientRegistrationStrategy::Dcr {
             registration_endpoint,
-        } => perform_dcr(http_client, &registration_endpoint, redirect_uri).await,
+        } => {
+            perform_dcr(
+                http_client,
+                &registration_endpoint,
+                redirect_uri,
+                discovery
+                    .auth_server_metadata
+                    .grant_types_supported
+                    .as_deref(),
+            )
+            .await
+        }
         ClientRegistrationStrategy::Unavailable => {
             bail!("authorization server supports neither CIMD nor DCR")
         }
@@ -860,10 +892,11 @@ pub async fn perform_dcr(
     http_client: &Arc<dyn HttpClient>,
     registration_endpoint: &Url,
     redirect_uri: &str,
+    server_grant_types: Option<&[String]>,
 ) -> Result<OAuthClientRegistration> {
     validate_oauth_url(registration_endpoint)?;
 
-    let body = dcr_registration_body(redirect_uri);
+    let body = dcr_registration_body(redirect_uri, server_grant_types);
     let body_bytes = serde_json::to_vec(&body)?;
 
     let request = Request::builder()
@@ -1206,6 +1239,8 @@ struct AuthServerMetadataResponse {
     #[serde(default)]
     scopes_supported: Option<Vec<String>>,
     #[serde(default)]
+    grant_types_supported: Option<Vec<String>>,
+    #[serde(default)]
     code_challenge_methods_supported: Option<Vec<String>>,
     #[serde(default)]
     client_id_metadata_document_supported: Option<bool>,
@@ -1707,6 +1742,7 @@ mod tests {
             scopes_supported: None,
             code_challenge_methods_supported: Some(vec!["S256".into()]),
             client_id_metadata_document_supported: true,
+            grant_types_supported: None,
         };
         assert_eq!(
             determine_registration_strategy(&metadata),
@@ -1727,6 +1763,7 @@ mod tests {
             scopes_supported: None,
             code_challenge_methods_supported: Some(vec!["S256".into()]),
             client_id_metadata_document_supported: false,
+            grant_types_supported: None,
         };
         assert_eq!(
             determine_registration_strategy(&metadata),
@@ -1746,6 +1783,7 @@ mod tests {
             scopes_supported: None,
             code_challenge_methods_supported: Some(vec!["S256".into()]),
             client_id_metadata_document_supported: false,
+            grant_types_supported: None,
         };
         assert_eq!(
             determine_registration_strategy(&metadata),
@@ -1802,6 +1840,7 @@ mod tests {
             scopes_supported: None,
             code_challenge_methods_supported: Some(vec!["S256".into()]),
             client_id_metadata_document_supported: true,
+            grant_types_supported: None,
         };
         let pkce = PkceChallenge {
             verifier: "test_verifier".into(),
@@ -1844,6 +1883,7 @@ mod tests {
             scopes_supported: None,
             code_challenge_methods_supported: Some(vec!["S256".into()]),
             client_id_metadata_document_supported: false,
+            grant_types_supported: None,
         };
         let pkce = PkceChallenge {
             verifier: "v".into(),
@@ -1927,15 +1967,35 @@ mod tests {
     // -- DCR body test -------------------------------------------------------
 
     #[test]
-    fn test_dcr_registration_body_shape() {
-        let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+    fn test_dcr_registration_body_without_server_metadata() {
+        // When server metadata is unavailable, include all supported grant types.
+        let body = dcr_registration_body("http://127.0.0.1:12345/callback", None);
         assert_eq!(body["client_name"], "Zed");
         assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
         assert_eq!(body["grant_types"][0], "authorization_code");
+        assert_eq!(body["grant_types"][1], "refresh_token");
         assert_eq!(body["response_types"][0], "code");
         assert_eq!(body["token_endpoint_auth_method"], "none");
     }
 
+    #[test]
+    fn test_dcr_registration_body_mirrors_server_grant_types() {
+        // When the server only supports authorization_code, omit refresh_token.
+        let server_types = vec!["authorization_code".to_string()];
+        let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
+        assert_eq!(body["grant_types"][0], "authorization_code");
+        assert!(body["grant_types"].as_array().unwrap().len() == 1);
+
+        // When the server supports both, include both.
+        let server_types = vec![
+            "authorization_code".to_string(),
+            "refresh_token".to_string(),
+        ];
+        let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
+        assert_eq!(body["grant_types"][0], "authorization_code");
+        assert_eq!(body["grant_types"][1], "refresh_token");
+    }
+
     // -- Test helpers for async/HTTP tests -----------------------------------
 
     fn make_fake_http_client(
@@ -2398,6 +2458,7 @@ mod tests {
                 scopes_supported: None,
                 code_challenge_methods_supported: Some(vec!["S256".into()]),
                 client_id_metadata_document_supported: true,
+                grant_types_supported: None,
             };
 
             let tokens = exchange_code(
@@ -2472,6 +2533,7 @@ mod tests {
                 scopes_supported: None,
                 code_challenge_methods_supported: Some(vec!["S256".into()]),
                 client_id_metadata_document_supported: true,
+                grant_types_supported: None,
             };
 
             let result = exchange_code(
@@ -2508,9 +2570,10 @@ mod tests {
             });
 
             let endpoint = Url::parse("https://auth.example.com/register").unwrap();
-            let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
-                .await
-                .unwrap();
+            let registration =
+                perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None)
+                    .await
+                    .unwrap();
 
             assert_eq!(registration.client_id, "dynamic-client-001");
             assert_eq!(
@@ -2530,7 +2593,8 @@ mod tests {
             });
 
             let endpoint = Url::parse("https://auth.example.com/register").unwrap();
-            let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+            let result =
+                perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await;
 
             assert!(result.is_err());
             assert!(result.unwrap_err().to_string().contains("403"));