diff --git a/crates/context_server/src/oauth.rs b/crates/context_server/src/oauth.rs index 1a314de2fca9b9987336decb15b208ffd7759dea..d1ce778bdcce02716f2fa78ed41ca2750837ed20 100644 --- a/crates/context_server/src/oauth.rs +++ b/crates/context_server/src/oauth.rs @@ -146,6 +146,7 @@ pub struct AuthServerMetadata { pub token_endpoint: Url, pub registration_endpoint: Option, pub scopes_supported: Option>, + pub grant_types_supported: Option>, pub code_challenge_methods_supported: Option>, pub client_id_metadata_document_supported: bool, } @@ -672,11 +673,30 @@ pub fn token_refresh_params( /// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict /// redirect URI matching even for loopback addresses, so we register the /// exact URI we intend to use. -pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value { +/// The grant types Zed can use. Intersected with the server's +/// `grant_types_supported` to build the DCR request. +const SUPPORTED_GRANT_TYPES: &[&str] = &["authorization_code", "refresh_token"]; + +pub fn dcr_registration_body( + redirect_uri: &str, + server_grant_types: Option<&[String]>, +) -> serde_json::Value { + // Use the intersection of what we support and what the server advertises. + // When the server doesn't advertise grant_types_supported, send all of + // ours — the server will reject what it doesn't like. + let grant_types: Vec<&str> = match server_grant_types { + Some(server) => SUPPORTED_GRANT_TYPES + .iter() + .copied() + .filter(|gt| server.iter().any(|s| s == *gt)) + .collect(), + None => SUPPORTED_GRANT_TYPES.to_vec(), + }; + serde_json::json!({ "client_name": "Zed", "redirect_uris": [redirect_uri], - "grant_types": ["authorization_code"], + "grant_types": grant_types, "response_types": ["code"], "token_endpoint_auth_method": "none" }) @@ -760,6 +780,7 @@ pub async fn fetch_auth_server_metadata( return Ok(AuthServerMetadata { issuer: reported_issuer, + grant_types_supported: response.grant_types_supported, authorization_endpoint: response .authorization_endpoint .ok_or_else(|| anyhow!("missing authorization_endpoint"))?, @@ -846,7 +867,18 @@ pub async fn resolve_client_registration( }), ClientRegistrationStrategy::Dcr { registration_endpoint, - } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await, + } => { + perform_dcr( + http_client, + ®istration_endpoint, + redirect_uri, + discovery + .auth_server_metadata + .grant_types_supported + .as_deref(), + ) + .await + } ClientRegistrationStrategy::Unavailable => { bail!("authorization server supports neither CIMD nor DCR") } @@ -860,10 +892,11 @@ pub async fn perform_dcr( http_client: &Arc, registration_endpoint: &Url, redirect_uri: &str, + server_grant_types: Option<&[String]>, ) -> Result { validate_oauth_url(registration_endpoint)?; - let body = dcr_registration_body(redirect_uri); + let body = dcr_registration_body(redirect_uri, server_grant_types); let body_bytes = serde_json::to_vec(&body)?; let request = Request::builder() @@ -1206,6 +1239,8 @@ struct AuthServerMetadataResponse { #[serde(default)] scopes_supported: Option>, #[serde(default)] + grant_types_supported: Option>, + #[serde(default)] code_challenge_methods_supported: Option>, #[serde(default)] client_id_metadata_document_supported: Option, @@ -1707,6 +1742,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1727,6 +1763,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1746,6 +1783,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; assert_eq!( determine_registration_strategy(&metadata), @@ -1802,6 +1840,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let pkce = PkceChallenge { verifier: "test_verifier".into(), @@ -1844,6 +1883,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: false, + grant_types_supported: None, }; let pkce = PkceChallenge { verifier: "v".into(), @@ -1927,15 +1967,35 @@ mod tests { // -- DCR body test ------------------------------------------------------- #[test] - fn test_dcr_registration_body_shape() { - let body = dcr_registration_body("http://127.0.0.1:12345/callback"); + fn test_dcr_registration_body_without_server_metadata() { + // When server metadata is unavailable, include all supported grant types. + let body = dcr_registration_body("http://127.0.0.1:12345/callback", None); assert_eq!(body["client_name"], "Zed"); assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback"); assert_eq!(body["grant_types"][0], "authorization_code"); + assert_eq!(body["grant_types"][1], "refresh_token"); assert_eq!(body["response_types"][0], "code"); assert_eq!(body["token_endpoint_auth_method"], "none"); } + #[test] + fn test_dcr_registration_body_mirrors_server_grant_types() { + // When the server only supports authorization_code, omit refresh_token. + let server_types = vec!["authorization_code".to_string()]; + let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types)); + assert_eq!(body["grant_types"][0], "authorization_code"); + assert!(body["grant_types"].as_array().unwrap().len() == 1); + + // When the server supports both, include both. + let server_types = vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ]; + let body = dcr_registration_body("http://127.0.0.1:12345/callback", Some(&server_types)); + assert_eq!(body["grant_types"][0], "authorization_code"); + assert_eq!(body["grant_types"][1], "refresh_token"); + } + // -- Test helpers for async/HTTP tests ----------------------------------- fn make_fake_http_client( @@ -2398,6 +2458,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let tokens = exchange_code( @@ -2472,6 +2533,7 @@ mod tests { scopes_supported: None, code_challenge_methods_supported: Some(vec!["S256".into()]), client_id_metadata_document_supported: true, + grant_types_supported: None, }; let result = exchange_code( @@ -2508,9 +2570,10 @@ mod tests { }); let endpoint = Url::parse("https://auth.example.com/register").unwrap(); - let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback") - .await - .unwrap(); + let registration = + perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None) + .await + .unwrap(); assert_eq!(registration.client_id, "dynamic-client-001"); assert_eq!( @@ -2530,7 +2593,8 @@ mod tests { }); let endpoint = Url::parse("https://auth.example.com/register").unwrap(); - let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await; + let result = + perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback", None).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("403"));