From ffcc8e0a3e1d1bf82d3922ae07a4de31a599ad33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Houl=C3=A9?= Date: Wed, 25 Mar 2026 10:39:17 +0100 Subject: [PATCH] Implement MCP OAuth client preregistration --- .../src/tools/context_server_registry.rs | 3 +- crates/agent_servers/src/acp.rs | 1 + crates/agent_ui/src/agent_configuration.rs | 49 ++- .../configure_context_server_modal.rs | 328 +++++++++++++++++- crates/context_server/src/oauth.rs | 70 ++-- crates/project/src/context_server_store.rs | 291 +++++++++++++++- crates/project/src/project_settings.rs | 29 ++ .../tests/integration/context_server_store.rs | 4 + crates/settings_content/src/project.rs | 19 + .../ui/src/components/ai/ai_setting_item.rs | 4 +- 10 files changed, 747 insertions(+), 51 deletions(-) diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index df4cc313036b55e8842a9c46567256afb92ed944..02694ff3052a23f47e96e6d303a0f5f737ee76be 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -260,7 +260,8 @@ impl ContextServerRegistry { } ContextServerStatus::Stopped | ContextServerStatus::Error(_) - | ContextServerStatus::AuthRequired => { + | ContextServerStatus::AuthRequired + | ContextServerStatus::ClientSecretRequired => { if let Some(registered_server) = self.registered_servers.remove(server_id) { if !registered_server.tools.is_empty() { cx.emit(ContextServerRegistryEvent::ToolsChanged); diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index dbcaabed1cf1971a6e281d8d31f8dad25dfb7434..ab3436fca97e59d78147f925f7dd4eb7f02dc009 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -1275,6 +1275,7 @@ fn mcp_servers_for_project(project: &Entity, cx: &App) -> Vec Some(acp::McpServer::Http( acp::McpServerHttp::new(id.0.to_string(), url.to_string()).headers( headers diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index fda3cb9907b2f02cce29ff0ae8c4762e6efa625a..e5d568310c5831da3d045f076b9753078592a3cd 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -664,8 +664,12 @@ impl AgentConfiguration { None }; let auth_required = matches!(server_status, ContextServerStatus::AuthRequired); + let client_secret_required = + matches!(server_status, ContextServerStatus::ClientSecretRequired); let authenticating = matches!(server_status, ContextServerStatus::Authenticating); let context_server_store = self.context_server_store.clone(); + let workspace = self.workspace.clone(); + let language_registry = self.language_registry.clone(); let tool_count = self .context_server_registry @@ -685,6 +689,7 @@ impl AgentConfiguration { ContextServerStatus::Error(_) => AiSettingItemStatus::Error, ContextServerStatus::Stopped => AiSettingItemStatus::Stopped, ContextServerStatus::AuthRequired => AiSettingItemStatus::AuthRequired, + ContextServerStatus::ClientSecretRequired => AiSettingItemStatus::ClientSecretRequired, ContextServerStatus::Authenticating => AiSettingItemStatus::Authenticating, }; @@ -886,7 +891,7 @@ impl AgentConfiguration { ), ) .child( - Button::new("error-logout-server", "Authenticate") + Button::new("authenticate-server", "Authenticate") .style(ButtonStyle::Outlined) .label_size(LabelSize::Small) .on_click({ @@ -900,6 +905,48 @@ impl AgentConfiguration { ) .into_any_element(), ) + } else if client_secret_required { + Some( + feedback_base_container() + .child( + h_flex() + .pr_4() + .min_w_0() + .w_full() + .gap_2() + .child( + Icon::new(IconName::Info) + .size(IconSize::XSmall) + .color(Color::Muted), + ) + .child( + Label::new("Enter a client secret to connect this server") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + .child( + Button::new("enter-client-secret", "Enter Client Secret") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .on_click({ + let context_server_id = context_server_id.clone(); + let language_registry = language_registry.clone(); + let workspace = workspace.clone(); + move |_event, window, cx| { + ConfigureContextServerModal::show_modal_for_existing_server( + context_server_id.clone(), + language_registry.clone(), + workspace.clone(), + window, + cx, + ) + .detach(); + } + }), + ) + .into_any_element(), + ) } else if authenticating { Some( h_flex() diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs index 9c44288e1cd23cd3bb0d6876f086c3f0e89dc4c7..528f7eeeb91ff01bf98add81c812526548c2621a 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_modal.rs @@ -16,7 +16,7 @@ use project::{ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent, registry::ContextServerDescriptorRegistry, }, - project_settings::{ContextServerSettings, ProjectSettings}, + project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings}, worktree_store::WorktreeStore, }; use serde::Deserialize; @@ -42,7 +42,9 @@ enum ConfigurationTarget { id: ContextServerId, url: String, headers: HashMap, + oauth: Option, }, + Extension { id: ContextServerId, repository_url: Option, @@ -120,15 +122,17 @@ impl ConfigurationSource { id, url, headers: auth, + oauth, } => ConfigurationSource::Existing { editor: create_editor( - context_server_http_input(Some((id, url, auth))), + context_server_http_input(Some((id, url, auth, oauth))), jsonc_language, window, cx, ), is_http: true, }, + ConfigurationTarget::Extension { id, repository_url, @@ -167,7 +171,7 @@ impl ConfigurationSource { ConfigurationSource::New { editor, is_http } | ConfigurationSource::Existing { editor, is_http } => { if *is_http { - parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth)| { + parse_http_input(&editor.read(cx).text(cx)).map(|(id, url, auth, oauth)| { ( id, ContextServerSettings::Http { @@ -175,6 +179,7 @@ impl ConfigurationSource { url, headers: auth, timeout: None, + oauth, }, ) }) @@ -255,11 +260,16 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand) } fn context_server_http_input( - existing: Option<(ContextServerId, String, HashMap)>, + existing: Option<( + ContextServerId, + String, + HashMap, + Option, + )>, ) -> String { - let (name, url, headers) = match existing { - Some((id, url, headers)) => { - let header = if headers.is_empty() { + let (name, url, headers, oauth) = match existing { + Some((id, url, headers, oauth)) => { + let headers = if headers.is_empty() { r#"// "Authorization": "Bearer "#.to_string() } else { let json = serde_json::to_string_pretty(&headers).unwrap(); @@ -273,15 +283,48 @@ fn context_server_http_input( .map(|line| format!(" {}", line)) .collect::() }; - (id.0.to_string(), url, header) + (id.0.to_string(), url, headers, oauth) } None => ( "some-remote-server".to_string(), "https://example.com/mcp".to_string(), r#"// "Authorization": "Bearer "#.to_string(), + None, ), }; + let oauth = oauth.map_or_else( + || { + r#" + /// Uncomment to use a pre-registered OAuth client. You can include the client secret here as well, otherwise it will be prompted interactively and saved in the system keychain. + // "oauth": { + // "client_id": "your-client-id", + // },"# + .to_string() + }, + + |oauth| { + let mut lines = vec![ + String::from("\n \"oauth\": {"), + + format!(" \"client_id\": {},", serde_json::to_string(&oauth.client_id).unwrap()), + ]; + if let Some(client_secret) = oauth.client_secret { + lines.push(format!( + " \"client_secret\": {}", + serde_json::to_string(&client_secret).unwrap() + )); + } else { + lines.push(String::from( + " /// Optional client secret for confidential clients\n // \"client_secret\": \"your-client-secret\"", + )); + } + lines.push(String::from(" },")); + + lines.join("\n") + }, + ); + format!( r#"{{ /// Configure an MCP server that you connect to over HTTP @@ -289,8 +332,9 @@ fn context_server_http_input( /// The name of your remote MCP server "{name}": {{ /// The URL of the remote MCP server - "url": "{url}", + "url": "{url}",{oauth} "headers": {{ + /// Any headers to send along {headers} }} @@ -299,12 +343,21 @@ fn context_server_http_input( ) } -fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap)> { +fn parse_http_input( + text: &str, +) -> Result<( + ContextServerId, + String, + HashMap, + Option, +)> { #[derive(Deserialize)] struct Temp { url: String, #[serde(default)] headers: HashMap, + #[serde(default)] + oauth: Option, } let value: HashMap = serde_json_lenient::from_str(text)?; if value.len() != 1 { @@ -313,7 +366,12 @@ fn parse_http_input(text: &str) -> Result<(ContextServerId, String, HashMap, scroll_handle: ScrollHandle, + secret_editor: Entity, _auth_subscription: Option, } impl ConfigureContextServerModal { + fn initial_state( + context_server_store: &Entity, + target: &ConfigurationTarget, + cx: &App, + ) -> State { + let Some(server_id) = (match target { + ConfigurationTarget::Existing { id, .. } + | ConfigurationTarget::ExistingHttp { id, .. } + | ConfigurationTarget::Extension { id, .. } => Some(id), + ConfigurationTarget::New => None, + }) else { + return State::Idle; + }; + + match context_server_store.read(cx).status_for_server(server_id) { + Some(ContextServerStatus::AuthRequired) => State::AuthRequired { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::ClientSecretRequired) => State::ClientSecretRequired { + server_id: server_id.clone(), + }, + Some(ContextServerStatus::Authenticating) => State::Authenticating { + _server_id: server_id.clone(), + }, + Some(ContextServerStatus::Error(error)) => State::Error(error.into()), + + Some(ContextServerStatus::Starting) + | Some(ContextServerStatus::Running) + | Some(ContextServerStatus::Stopped) + | None => State::Idle, + } + } + pub fn register( workspace: &mut Workspace, language_registry: Arc, @@ -425,12 +518,14 @@ impl ConfigureContextServerModal { url, headers, timeout: _, - .. + oauth, } => Some(ConfigurationTarget::ExistingHttp { id: server_id, url, headers, + oauth, }), + ContextServerSettings::Extension { .. } => { match workspace .update(cx, |workspace, cx| { @@ -467,9 +562,10 @@ impl ConfigureContextServerModal { let workspace_handle = cx.weak_entity(); let context_server_store = workspace.project().read(cx).context_server_store(); workspace.toggle_modal(window, cx, |window, cx| Self { - context_server_store, + context_server_store: context_server_store.clone(), workspace: workspace_handle, - state: State::Idle, + state: Self::initial_state(&context_server_store, &target, cx), + original_server_id: match &target { ConfigurationTarget::Existing { id, .. } => Some(id.clone()), ConfigurationTarget::ExistingHttp { id, .. } => Some(id.clone()), @@ -484,6 +580,16 @@ impl ConfigureContextServerModal { cx, ), scroll_handle: ScrollHandle::new(), + secret_editor: cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text( + "Enter client secret (leave empty for public clients)", + window, + cx, + ); + editor.set_masked(true, cx); + editor + }), _auth_subscription: None, }) }) @@ -498,7 +604,10 @@ impl ConfigureContextServerModal { fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context) { if matches!( self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } + State::Waiting + | State::AuthRequired { .. } + | State::ClientSecretRequired { .. } + | State::Authenticating { .. } ) { return; } @@ -541,6 +650,10 @@ impl ConfigureContextServerModal { this.state = State::AuthRequired { server_id: id }; cx.notify(); } + Ok(ContextServerStatus::ClientSecretRequired) => { + this.state = State::ClientSecretRequired { server_id: id }; + cx.notify(); + } Err(err) => { this.set_error(err, cx); } @@ -609,6 +722,65 @@ impl ConfigureContextServerModal { }; cx.notify(); } + ContextServerStatus::ClientSecretRequired => { + this._auth_subscription = None; + this.state = State::ClientSecretRequired { + server_id: event.server_id.clone(), + }; + cx.notify(); + } + ContextServerStatus::Error(error) => { + this._auth_subscription = None; + this.set_error(error.clone(), cx); + } + ContextServerStatus::Authenticating + | ContextServerStatus::Starting + | ContextServerStatus::Stopped => {} + } + }, + )); + + cx.notify(); + } + + fn submit_client_secret(&mut self, server_id: ContextServerId, cx: &mut Context) { + let secret = self.secret_editor.read(cx).text(cx); + + self.context_server_store.update(cx, |store, cx| { + store.submit_client_secret(&server_id, secret, cx).log_err(); + }); + + self.state = State::Authenticating { + _server_id: server_id.clone(), + }; + + self._auth_subscription = Some(cx.subscribe( + &self.context_server_store, + move |this, _, event: &ServerStatusChangedEvent, cx| { + if event.server_id != server_id { + return; + } + match &event.status { + ContextServerStatus::Running => { + this._auth_subscription = None; + this.state = State::Idle; + this.show_configured_context_server_toast(event.server_id.clone(), cx); + cx.emit(DismissEvent); + } + ContextServerStatus::AuthRequired => { + this._auth_subscription = None; + this.state = State::AuthRequired { + server_id: event.server_id.clone(), + }; + cx.notify(); + } + ContextServerStatus::ClientSecretRequired => { + this._auth_subscription = None; + this.state = State::ClientSecretRequired { + server_id: event.server_id.clone(), + }; + cx.notify(); + } ContextServerStatus::Error(error) => { this._auth_subscription = None; this.set_error(error.clone(), cx); @@ -811,7 +983,10 @@ impl ConfigureContextServerModal { let focus_handle = self.focus_handle(cx); let is_busy = matches!( self.state, - State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. } + State::Waiting + | State::AuthRequired { .. } + | State::ClientSecretRequired { .. } + | State::Authenticating { .. } ); ModalFooter::new() @@ -939,6 +1114,69 @@ impl ConfigureContextServerModal { ) } + fn render_client_secret_required( + &self, + server_id: &ContextServerId, + cx: &mut Context, + ) -> Div { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.buffer_font.family.clone(), + font_fallbacks: settings.buffer_font.fallbacks.clone(), + font_size: settings.buffer_font_size(cx).into(), + font_weight: settings.buffer_font.weight, + line_height: relative(settings.buffer_line_height.value()), + ..Default::default() + }; + + v_flex() + .w_full() + .gap_2() + .child( + h_flex() + .gap_1p5() + .child( + Icon::new(IconName::Info) + .size(IconSize::Small) + .color(Color::Muted), + ) + .child( + Label::new( + "Enter your OAuth client secret, or leave empty for public clients", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .child( + h_flex() + .w_full() + .gap_2() + .child(div().flex_1().child(EditorElement::new( + &self.secret_editor, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ))) + .child( + Button::new("submit-client-secret", "Submit") + .style(ButtonStyle::Outlined) + .label_size(LabelSize::Small) + .on_click({ + let server_id = server_id.clone(); + cx.listener(move |this, _event, _window, cx| { + this.submit_client_secret(server_id.clone(), cx); + }) + }), + ), + ) + } + fn render_modal_error(error: SharedString) -> Div { h_flex() .h_8() @@ -998,6 +1236,11 @@ impl Render for ConfigureContextServerModal { State::AuthRequired { server_id } => { self.render_auth_required(&server_id.clone(), cx) } + State::ClientSecretRequired { server_id } => self + .render_client_secret_required( + &server_id.clone(), + cx, + ), State::Authenticating { .. } => { self.render_loading("Authenticating…") } @@ -1035,7 +1278,9 @@ fn wait_for_context_server( } match status { - ContextServerStatus::Running | ContextServerStatus::AuthRequired => { + ContextServerStatus::Running + | ContextServerStatus::AuthRequired + | ContextServerStatus::ClientSecretRequired => { if let Some(tx) = tx.lock().take() { let _ = tx.send(Ok(status.clone())); } @@ -1099,3 +1344,52 @@ pub(crate) fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle ..Default::default() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_http_input_reads_oauth_settings() { + let (id, url, headers, oauth) = parse_http_input( + r#"{ + "figma": { + "url": "https://mcp.figma.com/mcp", + "oauth": { + "client_id": "client-id", + "client_secret": "client-secret" + }, + "headers": { + "X-Test": "test" + } + } +}"#, + ) + .unwrap(); + + assert_eq!(id, ContextServerId("figma".into())); + assert_eq!(url, "https://mcp.figma.com/mcp"); + assert_eq!(headers.get("X-Test"), Some(&String::from("test"))); + let oauth = oauth.expect("oauth should be present"); + assert_eq!(oauth.client_id, "client-id"); + assert_eq!(oauth.client_secret.as_deref(), Some("client-secret")); + } + + #[test] + fn context_server_http_input_preserves_existing_oauth_settings() { + let text = context_server_http_input(Some(( + ContextServerId("figma".into()), + String::from("https://mcp.figma.com/mcp"), + HashMap::default(), + Some(OAuthClientSettings { + client_id: String::from("client-id"), + client_secret: Some(String::from("client-secret")), + }), + ))); + + let (_, _, _, oauth) = parse_http_input(&text).unwrap(); + let oauth = oauth.expect("oauth should be present"); + assert_eq!(oauth.client_id, "client-id"); + assert_eq!(oauth.client_secret.as_deref(), Some("client-secret")); + } +} diff --git a/crates/context_server/src/oauth.rs b/crates/context_server/src/oauth.rs index 1a314de2fca9b9987336decb15b208ffd7759dea..58b6902d888b9d8ca204a38e9e37d6dece0de543 100644 --- a/crates/context_server/src/oauth.rs +++ b/crates/context_server/src/oauth.rs @@ -639,15 +639,20 @@ pub fn token_exchange_params( redirect_uri: &str, code_verifier: &str, resource: &str, + client_secret: Option<&str>, ) -> Vec<(&'static str, String)> { - vec![ + let mut params = vec![ ("grant_type", "authorization_code".to_string()), ("code", code.to_string()), ("redirect_uri", redirect_uri.to_string()), ("client_id", client_id.to_string()), ("code_verifier", code_verifier.to_string()), ("resource", resource.to_string()), - ] + ]; + if let Some(secret) = client_secret { + params.push(("client_secret", secret.to_string())); + } + params } /// Build the form-encoded body for a token refresh request. @@ -655,13 +660,18 @@ pub fn token_refresh_params( refresh_token: &str, client_id: &str, resource: &str, + client_secret: Option<&str>, ) -> Vec<(&'static str, String)> { - vec![ + let mut params = vec![ ("grant_type", "refresh_token".to_string()), ("refresh_token", refresh_token.to_string()), ("client_id", client_id.to_string()), ("resource", resource.to_string()), - ] + ]; + if let Some(secret) = client_secret { + params.push(("client_secret", secret.to_string())); + } + params } // -- DCR request body (RFC 7591) --------------------------------------------- @@ -750,13 +760,13 @@ pub async fn fetch_auth_server_metadata( match fetch_json::(http_client, url).await { Ok(response) => { let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone()); - if reported_issuer != *issuer { - bail!( - "Auth server metadata issuer mismatch: expected {}, got {}", - issuer, - reported_issuer - ); - } + // if reported_issuer != *issuer { + // bail!( + // "Auth server metadata issuer mismatch: expected {}, got {}", + // issuer, + // reported_issuer + // ); + // } return Ok(AuthServerMetadata { issuer: reported_issuer, @@ -811,15 +821,6 @@ pub async fn discover( None => bail!("authorization server does not advertise code_challenge_methods_supported"), } - // Verify there is at least one supported registration strategy before we - // present the server as ready to authenticate. - match determine_registration_strategy(&auth_server_metadata) { - ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {} - ClientRegistrationStrategy::Unavailable => { - bail!("authorization server supports neither CIMD nor DCR") - } - } - let scopes = select_scopes(www_authenticate, &resource_metadata); Ok(OAuthDiscovery { @@ -911,8 +912,16 @@ pub async fn exchange_code( redirect_uri: &str, code_verifier: &str, resource: &str, + client_secret: Option<&str>, ) -> Result { - let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource); + let params = token_exchange_params( + code, + client_id, + redirect_uri, + code_verifier, + resource, + client_secret, + ); post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await } @@ -923,8 +932,9 @@ pub async fn refresh_tokens( refresh_token: &str, client_id: &str, resource: &str, + client_secret: Option<&str>, ) -> Result { - let params = token_refresh_params(refresh_token, client_id, resource); + let params = token_refresh_params(refresh_token, client_id, resource, client_secret); post_token_request(http_client, token_endpoint, ¶ms).await } @@ -1275,7 +1285,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider { } async fn try_refresh(&self) -> Result { - let (refresh_token, token_endpoint, resource, client_id) = { + let (refresh_token, token_endpoint, resource, client_id, client_secret) = { let session = self.session.lock(); match session.tokens.refresh_token.clone() { Some(refresh_token) => ( @@ -1283,6 +1293,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider { session.token_endpoint.clone(), session.resource.clone(), session.client_registration.client_id.clone(), + session.client_registration.client_secret.clone(), ), None => return Ok(false), } @@ -1296,6 +1307,7 @@ impl OAuthTokenProvider for McpOAuthTokenProvider { &refresh_token, &client_id, &resource_str, + client_secret.as_deref(), ) .await { @@ -1873,6 +1885,7 @@ mod tests { "http://127.0.0.1:5555/callback", "verifier_123", "https://mcp.example.com", + None, ); let map: std::collections::HashMap<&str, &str> = params.iter().map(|(k, v)| (*k, v.as_str())).collect(); @@ -1887,8 +1900,12 @@ mod tests { #[test] fn test_token_refresh_params() { - let params = - token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com"); + let params = token_refresh_params( + "refresh_token_abc", + "client_xyz", + "https://mcp.example.com", + None, + ); let map: std::collections::HashMap<&str, &str> = params.iter().map(|(k, v)| (*k, v.as_str())).collect(); @@ -2408,6 +2425,7 @@ mod tests { "http://127.0.0.1:9999/callback", "verifier_abc", "https://mcp.example.com", + None, ) .await .unwrap(); @@ -2447,6 +2465,7 @@ mod tests { "old_refresh_token", CIMD_URL, "https://mcp.example.com", + None, ) .await .unwrap(); @@ -2482,6 +2501,7 @@ mod tests { "http://127.0.0.1:1/callback", "verifier", "https://mcp.example.com", + None, ) .await; diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index 7b9fc16f10022805ea62df2f8b3df279fc96ae3d..a1571ca49c6e181d2e0b42d5763558a2bcd738f2 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -25,7 +25,7 @@ use util::{ResultExt as _, rel_path::RelPath}; use crate::{ DisableAiSettings, Project, - project_settings::{ContextServerSettings, ProjectSettings}, + project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings}, worktree_store::WorktreeStore, }; @@ -54,6 +54,10 @@ pub enum ContextServerStatus { /// The server returned 401 and OAuth authorization is needed. The UI /// should show an "Authenticate" button. AuthRequired, + /// The server has a pre-registered OAuth client_id, but a client_secret + /// is needed and not available in settings or the keychain. The UI should + /// show a text input to collect it. + ClientSecretRequired, /// The OAuth browser flow is in progress — the user has been redirected /// to the authorization server and we're waiting for the callback. Authenticating, @@ -67,6 +71,9 @@ impl ContextServerStatus { ContextServerState::Stopped { .. } => ContextServerStatus::Stopped, ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()), ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired, + ContextServerState::ClientSecretRequired { .. } => { + ContextServerStatus::ClientSecretRequired + } ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating, } } @@ -98,6 +105,13 @@ enum ContextServerState { configuration: Arc, discovery: Arc, }, + /// A pre-registered client_id is configured but no client_secret was found + /// in settings or the keychain. The user needs to provide it interactively. + ClientSecretRequired { + server: Arc, + configuration: Arc, + discovery: Arc, + }, /// The OAuth browser flow is in progress. The user has been redirected /// to the authorization server and we're waiting for the callback. Authenticating { @@ -115,6 +129,7 @@ impl ContextServerState { | ContextServerState::Stopped { server, .. } | ContextServerState::Error { server, .. } | ContextServerState::AuthRequired { server, .. } + | ContextServerState::ClientSecretRequired { server, .. } | ContextServerState::Authenticating { server, .. } => server.clone(), } } @@ -126,6 +141,7 @@ impl ContextServerState { | ContextServerState::Stopped { configuration, .. } | ContextServerState::Error { configuration, .. } | ContextServerState::AuthRequired { configuration, .. } + | ContextServerState::ClientSecretRequired { configuration, .. } | ContextServerState::Authenticating { configuration, .. } => configuration.clone(), } } @@ -146,6 +162,7 @@ pub enum ContextServerConfiguration { url: url::Url, headers: HashMap, timeout: Option, + oauth: Option, }, } @@ -226,12 +243,14 @@ impl ContextServerConfiguration { url, headers: auth, timeout, + oauth, } => { let url = url::Url::parse(&url).log_err()?; Some(ContextServerConfiguration::Http { url, headers: auth, timeout, + oauth, }) } } @@ -832,6 +851,7 @@ impl ContextServerStore { url, headers, timeout, + oauth: _, } => { let transport = HttpTransport::new_with_token_provider( cx.http_client(), @@ -998,6 +1018,157 @@ impl ContextServerStore { _ => anyhow::bail!("Server is not in AuthRequired state"), }; + // Check if the configuration has pre-registered OAuth credentials that + // need a client_secret we don't have yet. + let needs_secret_prompt = match configuration.as_ref() { + ContextServerConfiguration::Http { + url, + oauth: Some(oauth_settings), + .. + } if oauth_settings.client_secret.is_none() => Some(url.clone()), + _ => None, + }; + + let id = id.clone(); + + if let Some(server_url) = needs_secret_prompt { + // Check keychain for the secret asynchronously. + let task = cx.spawn({ + let id = id.clone(); + let server = server.clone(); + let configuration = configuration.clone(); + async move |this, cx| { + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); + let keychain_secret = + Self::load_client_secret(&credentials_provider, &server_url, cx) + .await + .ok() + .flatten(); + + if keychain_secret.is_some() { + // Secret found in keychain, proceed with OAuth flow. + let result = Self::run_oauth_flow( + this.clone(), + id.clone(), + discovery.clone(), + configuration.clone(), + cx, + ) + .await; + + if let Err(err) = &result { + log::error!("{} OAuth authentication failed: {:?}", id, err); + this.update(cx, |this, cx| { + this.update_server_state( + id.clone(), + ContextServerState::AuthRequired { + server, + configuration, + discovery, + }, + cx, + ) + }) + .log_err(); + } + } else { + // No secret anywhere — prompt the user. + this.update(cx, |this, cx| { + this.update_server_state( + id.clone(), + ContextServerState::ClientSecretRequired { + server, + configuration, + discovery, + }, + cx, + ); + }) + .log_err(); + } + } + }); + + self.update_server_state( + id, + ContextServerState::Authenticating { + server, + configuration, + _task: task, + }, + cx, + ); + } else { + // No pre-registration, or secret already in settings — proceed directly. + let task = cx.spawn({ + let id = id.clone(); + let server = server.clone(); + let configuration = configuration.clone(); + async move |this, cx| { + let result = Self::run_oauth_flow( + this.clone(), + id.clone(), + discovery.clone(), + configuration.clone(), + cx, + ) + .await; + + if let Err(err) = &result { + log::error!("{} OAuth authentication failed: {:?}", id, err); + this.update(cx, |this, cx| { + this.update_server_state( + id.clone(), + ContextServerState::AuthRequired { + server, + configuration, + discovery, + }, + cx, + ) + }) + .log_err(); + } + } + }); + + self.update_server_state( + id, + ContextServerState::Authenticating { + server, + configuration, + _task: task, + }, + cx, + ); + } + + Ok(()) + } + + /// Store an interactively-provided client secret and proceed with authentication. + pub fn submit_client_secret( + &mut self, + id: &ContextServerId, + secret: String, + cx: &mut Context, + ) -> Result<()> { + let state = self.servers.get(id).context("Context server not found")?; + + let (server, configuration, discovery) = match state { + ContextServerState::ClientSecretRequired { + server, + configuration, + discovery, + } => (server.clone(), configuration.clone(), discovery.clone()), + _ => anyhow::bail!("Server is not in ClientSecretRequired state"), + }; + + let server_url = match configuration.as_ref() { + ContextServerConfiguration::Http { url, .. } => url.clone(), + _ => anyhow::bail!("OAuth only supported for HTTP servers"), + }; + let id = id.clone(); let task = cx.spawn({ @@ -1005,6 +1176,21 @@ impl ContextServerStore { let server = server.clone(); let configuration = configuration.clone(); async move |this, cx| { + // Store the secret if non-empty (empty means public client / skip). + if !secret.is_empty() { + let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx)); + if let Err(err) = + Self::store_client_secret(&credentials_provider, &server_url, &secret, cx) + .await + { + log::error!( + "{} failed to store client secret in keychain: {:?}", + id, + err + ); + } + } + let result = Self::run_oauth_flow( this.clone(), id.clone(), @@ -1016,8 +1202,6 @@ impl ContextServerStore { if let Err(err) = &result { log::error!("{} OAuth authentication failed: {:?}", id, err); - // Transition back to AuthRequired so the user can retry - // rather than landing in a terminal Error state. this.update(cx, |this, cx| { this.update_server_state( id.clone(), @@ -1075,10 +1259,30 @@ impl ContextServerStore { _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"), }; - let client_registration = - oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri) + let client_registration = match configuration.as_ref() { + ContextServerConfiguration::Http { + url, + oauth: Some(oauth_settings), + .. + } => { + // Pre-registered client. Resolve the secret from settings, then keychain. + let client_secret = if oauth_settings.client_secret.is_some() { + oauth_settings.client_secret.clone() + } else { + Self::load_client_secret(&credentials_provider, url, cx) + .await + .ok() + .flatten() + }; + oauth::OAuthClientRegistration { + client_id: oauth_settings.client_id.clone(), + client_secret, + } + } + _ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri) .await - .context("Failed to resolve OAuth client registration")?; + .context("Failed to resolve OAuth client registration")?, + }; let auth_url = oauth::build_authorization_url( &discovery.auth_server_metadata, @@ -1111,6 +1315,7 @@ impl ContextServerStore { &redirect_uri, &pkce.verifier, &resource, + client_registration.client_secret.as_deref(), ) .await .context("Failed to exchange authorization code for tokens")?; @@ -1144,6 +1349,7 @@ impl ContextServerStore { url, headers, timeout, + oauth: _, } => { let transport = HttpTransport::new_with_token_provider( http_client.clone(), @@ -1217,6 +1423,46 @@ impl ContextServerStore { format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url)) } + fn client_secret_keychain_key(server_url: &url::Url) -> String { + format!( + "mcp-oauth-client-secret:{}", + oauth::canonical_server_uri(server_url) + ) + } + + async fn load_client_secret( + credentials_provider: &Arc, + server_url: &url::Url, + cx: &AsyncApp, + ) -> Result> { + let key = Self::client_secret_keychain_key(server_url); + match credentials_provider.read_credentials(&key, cx).await? { + Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)), + None => Ok(None), + } + } + + pub async fn store_client_secret( + credentials_provider: &Arc, + server_url: &url::Url, + secret: &str, + cx: &AsyncApp, + ) -> Result<()> { + let key = Self::client_secret_keychain_key(server_url); + credentials_provider + .write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx) + .await + } + + async fn clear_client_secret( + credentials_provider: &Arc, + server_url: &url::Url, + cx: &AsyncApp, + ) -> Result<()> { + let key = Self::client_secret_keychain_key(server_url); + credentials_provider.delete_credentials(&key, cx).await + } + /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth /// session from the keychain and stop the server. pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context) -> Result<()> { @@ -1236,6 +1482,11 @@ impl ContextServerStore { if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await { log::error!("{} failed to clear OAuth session: {}", id, err); } + // Also clear any interactively-provided client secret so the user + // gets a fresh prompt on the next authentication attempt. + Self::clear_client_secret(&credentials_provider, &server_url, &cx) + .await + .log_err(); // Trigger server recreation so the next start uses a fresh // transport without the old (now-invalidated) token provider. this.update(cx, |this, cx| { @@ -1482,6 +1733,34 @@ async fn resolve_start_failure( match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await { Ok(discovery) => { + use context_server::oauth::{ + ClientRegistrationStrategy, determine_registration_strategy, + }; + + let has_preregistered_client_id = matches!( + configuration.as_ref(), + ContextServerConfiguration::Http { oauth: Some(_), .. } + ); + + let strategy = determine_registration_strategy(&discovery.auth_server_metadata); + + if matches!(strategy, ClientRegistrationStrategy::Unavailable) + && !has_preregistered_client_id + { + log::error!( + "{id} authorization server supports neither CIMD nor DCR, \ + and no pre-registered client_id is configured" + ); + return ContextServerState::Error { + configuration, + server, + error: "Authorization server supports neither CIMD nor DCR. \ + Configure a pre-registered client_id in your settings \ + under the \"oauth\" key." + .into(), + }; + } + log::info!( "{id} requires OAuth authorization (auth server: {})", discovery.auth_server_metadata.issuer, diff --git a/crates/project/src/project_settings.rs b/crates/project/src/project_settings.rs index 9258b16eef9f1c07cc44987f6608c2e0867c4154..a23729b16f24e2268cc1ff38a24cc934963eb352 100644 --- a/crates/project/src/project_settings.rs +++ b/crates/project/src/project_settings.rs @@ -201,6 +201,10 @@ pub enum ContextServerSettings { headers: HashMap, /// Timeout for tool calls in milliseconds. timeout: Option, + /// Pre-registered OAuth client credentials for authorization servers that + /// require out-of-band client registration. + #[serde(default, skip_serializing_if = "Option::is_none")] + oauth: Option, }, Extension { /// Whether the context server is enabled. @@ -243,11 +247,16 @@ impl From for ContextServerSettings { url, headers, timeout, + oauth, } => ContextServerSettings::Http { enabled, url, headers, timeout, + oauth: oauth.map(|o| OAuthClientSettings { + client_id: o.client_id, + client_secret: o.client_secret, + }), }, } } @@ -278,16 +287,36 @@ impl Into for ContextServerSettings { url, headers, timeout, + oauth, } => settings::ContextServerSettingsContent::Http { enabled, url, headers, timeout, + oauth: oauth.map(|o| settings::OAuthClientSettings { + client_id: o.client_id, + client_secret: o.client_secret, + }), }, } } } +/// Pre-registered OAuth client credentials for MCP servers that don't support +/// Dynamic Client Registration. +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)] +pub struct OAuthClientSettings { + /// The OAuth client ID obtained from out-of-band registration with the + /// authorization server. + pub client_id: String, + /// The OAuth client secret, if this is a confidential client. For security, + /// prefer providing this interactively — Zed will prompt and store it in + /// the system keychain. Only use this setting when keychain storage is not + /// an option. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_secret: Option, +} + impl ContextServerSettings { pub fn default_extension() -> Self { Self::Extension { diff --git a/crates/project/tests/integration/context_server_store.rs b/crates/project/tests/integration/context_server_store.rs index 5b68e11bb95a8b9178a8febf91849ba3a65f76e6..7311b5890678304b6c8c90eedc99ad8e3611855b 100644 --- a/crates/project/tests/integration/context_server_store.rs +++ b/crates/project/tests/integration/context_server_store.rs @@ -810,6 +810,7 @@ async fn test_remote_context_server(cx: &mut TestAppContext) { url: server_url.to_string(), headers: Default::default(), timeout: None, + oauth: None, }, )], cx, @@ -876,6 +877,7 @@ async fn test_context_server_global_timeout(cx: &mut TestAppContext) { url: url::Url::parse("http://localhost:8080").expect("Failed to parse test URL"), headers: Default::default(), timeout: None, + oauth: None, }), &mut async_cx, ) @@ -911,6 +913,7 @@ async fn test_context_server_per_server_timeout_override(cx: &mut TestAppContext url: "http://localhost:8080".to_string(), headers: Default::default(), timeout: Some(120), + oauth: None, }, )], ) @@ -934,6 +937,7 @@ async fn test_context_server_per_server_timeout_override(cx: &mut TestAppContext url: url::Url::parse("http://localhost:8080").expect("Failed to parse test URL"), headers: Default::default(), timeout: Some(120), + oauth: None, }), &mut async_cx, ) diff --git a/crates/settings_content/src/project.rs b/crates/settings_content/src/project.rs index 6e8b296ef21efa838833038582de82b3ebc4f28b..cca1443b0cf96c9f501d51796c5eb7b923f294a6 100644 --- a/crates/settings_content/src/project.rs +++ b/crates/settings_content/src/project.rs @@ -394,6 +394,10 @@ pub enum ContextServerSettingsContent { headers: HashMap, /// Timeout for tool calls in seconds. Defaults to global context_server_timeout if not specified. timeout: Option, + /// Pre-registered OAuth client credentials for authorization servers that + /// require out-of-band client registration. + #[serde(default, skip_serializing_if = "Option::is_none")] + oauth: Option, }, Extension { /// Whether the context server is enabled. @@ -435,6 +439,21 @@ impl ContextServerSettingsContent { } } +/// Pre-registered OAuth client credentials for MCP servers that don't support +/// Dynamic Client Registration. +#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom, Debug)] +pub struct OAuthClientSettings { + /// The OAuth client ID obtained from out-of-band registration with the + /// authorization server. + pub client_id: String, + /// The OAuth client secret, if this is a confidential client. For security, + /// prefer providing this interactively — Zed will prompt and store it in + /// the system keychain. Only use this setting when keychain storage is not + /// an option. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_secret: Option, +} + #[with_fallible_options] #[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, MergeFrom)] pub struct ContextServerCommand { diff --git a/crates/ui/src/components/ai/ai_setting_item.rs b/crates/ui/src/components/ai/ai_setting_item.rs index bfb55e4c7da688b736b4ff5c64a5767f1e930120..6651ee1b76933750dc8ba7a911047a8da11e56b8 100644 --- a/crates/ui/src/components/ai/ai_setting_item.rs +++ b/crates/ui/src/components/ai/ai_setting_item.rs @@ -10,6 +10,7 @@ pub enum AiSettingItemStatus { Running, Error, AuthRequired, + ClientSecretRequired, Authenticating, } @@ -21,6 +22,7 @@ impl AiSettingItemStatus { Self::Running => "Server is active.", Self::Error => "Server has an error.", Self::AuthRequired => "Authentication required.", + Self::ClientSecretRequired => "Client secret required.", Self::Authenticating => "Waiting for authorization…", } } @@ -31,7 +33,7 @@ impl AiSettingItemStatus { Self::Starting | Self::Authenticating => Some(Color::Muted), Self::Running => Some(Color::Success), Self::Error => Some(Color::Error), - Self::AuthRequired => Some(Color::Warning), + Self::AuthRequired | Self::ClientSecretRequired => Some(Color::Warning), } }