@@ -146,6 +146,7 @@ pub struct AuthServerMetadata {
pub token_endpoint: Url,
pub registration_endpoint: Option<Url>,
pub scopes_supported: Option<Vec<String>>,
+ pub grant_types_supported: Option<Vec<String>>,
pub code_challenge_methods_supported: Option<Vec<String>>,
pub client_id_metadata_document_supported: bool,
}
@@ -672,11 +673,30 @@ pub fn token_refresh_params(
/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
/// redirect URI matching even for loopback addresses, so we register the
/// exact URI we intend to use.
-pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
+/// The grant types Zed can use. Intersected with the server's
+/// `grant_types_supported` to build the DCR request.
+const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"];
+
+pub fn dcr_registration_body(
+ redirect_uri: &str,
+ server_grant_types: Option<&[String]>,
+) -> serde_json::Value {
+ // Use the intersection of what we support and what the server advertises.
+ // When the server doesn't advertise grant_types_supported, send all of
+ // ours — the server will reject what it doesn't like.
+ let grant_types: Vec<&str> = match server_grant_types {
+ Some(server) => SUPPORTED_GRANT_TYPES
+ .iter()
+ .copied()
+ .filter(|gt| server.iter().any(|s| s == *gt))
+ .collect(),
+ None => SUPPORTED_GRANT_TYPES.to_vec(),
+ };
+
serde_json::json!({
"client_name": "Zed",
"redirect_uris": [redirect_uri],
- "grant_types": ["authorization_code"],
+ "grant_types": grant_types,
"response_types": ["code"],
"token_endpoint_auth_method": "none"
})
@@ -760,6 +780,7 @@ pub async fn fetch_auth_server_metadata(
return Ok(AuthServerMetadata {
issuer: reported_issuer,
+ grant_types_supported: response.grant_types_supported,
authorization_endpoint: response
.authorization_endpoint
.ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
@@ -846,7 +867,18 @@ pub async fn resolve_client_registration(
}),
ClientRegistrationStrategy::Dcr {
registration_endpoint,
- } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
+ } => {
+ perform_dcr(
+ http_client,
+ ®istration_endpoint,
+ redirect_uri,
+ discovery
+ .auth_server_metadata
+ .grant_types_supported
+ .as_deref(),
+ )
+ .await
+ }
ClientRegistrationStrategy::Unavailable => {
bail!("authorization server supports neither CIMD nor DCR")
}
@@ -860,10 +892,11 @@ pub async fn perform_dcr(
http_client: &Arc<dyn HttpClient>,
registration_endpoint: &Url,
redirect_uri: &str,
+ server_grant_types: Option<&[String]>,
) -> Result<OAuthClientRegistration> {
validate_oauth_url(registration_endpoint)?;
- let body = dcr_registration_body(redirect_uri);
+ let body = dcr_registration_body(redirect_uri, server_grant_types);
let body_bytes = serde_json::to_vec(&body)?;
let request = Request::builder()
@@ -1206,6 +1239,8 @@ struct AuthServerMetadataResponse {
#[serde(default)]
scopes_supported: Option<Vec<String>>,
#[serde(default)]
+ grant_types_supported: Option<Vec<String>>,
+ #[serde(default)]
code_challenge_methods_supported: Option<Vec<String>>,
#[serde(default)]
client_id_metadata_document_supported: Option<bool>,
@@ -1707,6 +1742,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: true,
+ grant_types_supported: None,
};
assert_eq!(
determine_registration_strategy(&metadata),
@@ -1727,6 +1763,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: false,
+ grant_types_supported: None,
};
assert_eq!(
determine_registration_strategy(&metadata),
@@ -1746,6 +1783,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: false,
+ grant_types_supported: None,
};
assert_eq!(
determine_registration_strategy(&metadata),
@@ -1802,6 +1840,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: true,
+ grant_types_supported: None,
};
let pkce = PkceChallenge {
verifier: "test_verifier".into(),
@@ -1844,6 +1883,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: false,
+ grant_types_supported: None,
};
let pkce = PkceChallenge {
verifier: "v".into(),
@@ -1927,15 +1967,35 @@ mod tests {
// -- DCR body test -------------------------------------------------------
#[test]
- fn test_dcr_registration_body_shape() {
- let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+ fn test_dcr_registration_body_without_server_metadata() {
+ // When server metadata is unavailable, include all supported grant types.
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback", None);
assert_eq!(body["client_name"], "Zed");
assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
assert_eq!(body["grant_types"][0], "authorization_code");
+ assert_eq!(body["grant_types"][1], "refresh_token");
assert_eq!(body["response_types"][0], "code");
assert_eq!(body["token_endpoint_auth_method"], "none");
}
+ #[test]
+ fn test_dcr_registration_body_mirrors_server_grant_types() {
+ // When the server only supports authorization_code, omit refresh_token.
+ let server_types = vec!["authorization_code".to_string()];
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
+ assert_eq!(body["grant_types"][0], "authorization_code");
+ assert!(body["grant_types"].as_array().unwrap().len() == 1);
+
+ // When the server supports both, include both.
+ let server_types = vec![
+ "authorization_code".to_string(),
+ "refresh_token".to_string(),
+ ];
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types));
+ assert_eq!(body["grant_types"][0], "authorization_code");
+ assert_eq!(body["grant_types"][1], "refresh_token");
+ }
+
// -- Test helpers for async/HTTP tests -----------------------------------
fn make_fake_http_client(
@@ -2398,6 +2458,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: true,
+ grant_types_supported: None,
};
let tokens = exchange_code(
@@ -2472,6 +2533,7 @@ mod tests {
scopes_supported: None,
code_challenge_methods_supported: Some(vec!["S256".into()]),
client_id_metadata_document_supported: true,
+ grant_types_supported: None,
};
let result = exchange_code(
@@ -2508,9 +2570,10 @@ mod tests {
});
let endpoint = Url::parse("https://auth.example.com/register").unwrap();
- let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
- .await
- .unwrap();
+ let registration =
+ perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None)
+ .await
+ .unwrap();
assert_eq!(registration.client_id, "dynamic-client-001");
assert_eq!(
@@ -2530,7 +2593,8 @@ mod tests {
});
let endpoint = Url::parse("https://auth.example.com/register").unwrap();
- let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+ let result =
+ perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("403"));