From 55c91131772e82e6c9eafd41f2a00b5b866bb20f Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Mon, 15 Dec 2025 14:14:37 -0500 Subject: [PATCH] Support OAuth extensions in settings panel --- Cargo.lock | 1 + .../src/wasm_host/llm_provider.rs | 116 ++++----- crates/settings_ui/Cargo.toml | 1 + crates/settings_ui/src/page_data.rs | 4 +- .../pages/edit_prediction_provider_setup.rs | 237 +++++++++++++++--- 5 files changed, 263 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7c74c46e5696493f6499d5758a25bec7d4bdb9a2..798ef5d5c8cdcbe2dfc6c81e62b289a99726e086 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14831,6 +14831,7 @@ dependencies = [ "gpui", "heck 0.5.0", "language", + "language_model", "language_models", "log", "menu", diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs index e02e25d30da443bee497c460bc98e3d1c406e577..d4f38a66977274914d6f1ecc06e0ad8ba1b68763 100644 --- a/crates/extension_host/src/wasm_host/llm_provider.rs +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -34,7 +34,7 @@ use markdown::{Markdown, MarkdownElement, MarkdownStyle}; use settings::Settings; use std::sync::Arc; use theme::ThemeSettings; -use ui::{Label, LabelSize, prelude::*}; +use ui::{ConfiguredApiCard, Label, LabelSize, prelude::*}; use util::ResultExt as _; use workspace::Workspace; use workspace::oauth_device_flow_modal::{ @@ -658,17 +658,25 @@ impl ExtensionProviderConfigurationView { let icon_path = self.icon_path.clone(); let this_handle = cx.weak_entity(); - // Get workspace to show modal - let Some(workspace) = window.root::().flatten() else { + // Get workspace window handle to show modal - try current window first, then find any workspace window + log::info!("OAuth: Looking for workspace window"); + let workspace_window = window.window_handle().downcast::().or_else(|| { + log::info!("OAuth: Current window is not a workspace, searching other windows"); + cx.windows() + .into_iter() + .find_map(|window_handle| window_handle.downcast::()) + }); + + let Some(workspace_window) = workspace_window else { + log::error!("OAuth: Could not find any workspace window"); self.oauth_in_progress = false; self.oauth_error = Some("Could not access workspace to show sign-in modal".to_string()); cx.notify(); return; }; - - let workspace = workspace.downgrade(); + log::info!("OAuth: Found workspace window"); let state = state.downgrade(); - cx.spawn_in(window, async move |_this, cx| { + cx.spawn(async move |_this, cx| { // Step 1: Start device flow - get prompt info from extension let start_result = extension .call({ @@ -683,12 +691,22 @@ impl ExtensionProviderConfigurationView { }) .await; + log::info!( + "OAuth: Device flow start result: {:?}", + start_result.is_ok() + ); let prompt_info: LlmDeviceFlowPromptInfo = match start_result { - Ok(Ok(Ok(info))) => info, + Ok(Ok(Ok(info))) => { + log::info!( + "OAuth: Got device flow prompt info, user_code: {}", + info.user_code + ); + info + } Ok(Ok(Err(e))) => { - log::error!("Device flow start failed: {}", e); + log::error!("OAuth: Device flow start failed: {}", e); this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_in_progress = false; this.oauth_error = Some(e); cx.notify(); @@ -697,9 +715,9 @@ impl ExtensionProviderConfigurationView { return; } Ok(Err(e)) | Err(e) => { - log::error!("Device flow start error: {}", e); + log::error!("OAuth: Device flow start error: {}", e); this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_in_progress = false; this.oauth_error = Some(e.to_string()); cx.notify(); @@ -721,20 +739,31 @@ impl ExtensionProviderConfigurationView { icon_path, }; - let flow_state: Option> = workspace - .update_in(cx, |workspace, window, cx| { + log::info!("OAuth: Attempting to show modal in workspace window"); + let flow_state: Option> = workspace_window + .update(cx, |workspace, window, cx| { + log::info!("OAuth: Inside workspace.update, creating modal"); + window.activate_window(); let flow_state = cx.new(|_cx| OAuthDeviceFlowState::new(modal_config)); let flow_state_clone = flow_state.clone(); workspace.toggle_modal(window, cx, |_window, cx| { + log::info!("OAuth: Inside toggle_modal callback"); OAuthDeviceFlowModal::new(flow_state_clone, cx) }); flow_state }) .ok(); + log::info!( + "OAuth: workspace_window.update result: {:?}", + flow_state.is_some() + ); let Some(flow_state) = flow_state else { + log::error!( + "OAuth: Failed to show sign-in modal - workspace_window.update returned None" + ); this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_in_progress = false; this.oauth_error = Some("Failed to show sign-in modal".to_string()); cx.notify(); @@ -742,6 +771,7 @@ impl ExtensionProviderConfigurationView { .log_err(); return; }; + log::info!("OAuth: Modal shown successfully, starting poll"); // Step 3: Poll for authentication completion let poll_result = extension @@ -778,7 +808,7 @@ impl ExtensionProviderConfigurationView { }; state - .update_in(cx, |state, _window, cx| { + .update(cx, |state, cx| { state.is_authenticated = true; state.available_models = new_models; cx.notify(); @@ -787,7 +817,7 @@ impl ExtensionProviderConfigurationView { // Update flow state to show success flow_state - .update_in(cx, |state, _window, cx| { + .update(cx, |state, cx| { state.set_status(OAuthDeviceFlowStatus::Authorized, cx); }) .log_err(); @@ -795,12 +825,12 @@ impl ExtensionProviderConfigurationView { Ok(Ok(Err(e))) => { log::error!("Device flow poll failed: {}", e); flow_state - .update_in(cx, |state, _window, cx| { + .update(cx, |state, cx| { state.set_status(OAuthDeviceFlowStatus::Failed(e.clone()), cx); }) .log_err(); this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_error = Some(e); cx.notify(); }) @@ -810,7 +840,7 @@ impl ExtensionProviderConfigurationView { log::error!("Device flow poll error: {}", e); let error_string = e.to_string(); flow_state - .update_in(cx, |state, _window, cx| { + .update(cx, |state, cx| { state.set_status( OAuthDeviceFlowStatus::Failed(error_string.clone()), cx, @@ -818,7 +848,7 @@ impl ExtensionProviderConfigurationView { }) .log_err(); this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_error = Some(error_string); cx.notify(); }) @@ -827,7 +857,7 @@ impl ExtensionProviderConfigurationView { }; this_handle - .update_in(cx, |this, _window, cx| { + .update(cx, |this, cx| { this.oauth_in_progress = false; cx.notify(); }) @@ -958,46 +988,18 @@ impl gpui::Render for ExtensionProviderConfigurationView { // If authenticated, show success state with sign out option if is_authenticated && env_var_name_used.is_none() { - let reset_label = if has_oauth && !has_api_key { - "Sign Out" + let (status_label, button_label) = if has_oauth && !has_api_key { + ("Signed in", "Sign Out") } else { - "Reset Key" - }; - - let status_label = if has_oauth && !has_api_key { - "Signed in" - } else { - "API key configured" + ("API key configured", "Reset Key") }; content = content.child( - h_flex() - .mt_0p5() - .p_1() - .justify_between() - .rounded_md() - .border_1() - .border_color(cx.theme().colors().border) - .bg(cx.theme().colors().background) - .child( - h_flex() - .flex_1() - .min_w_0() - .gap_1() - .child(ui::Icon::new(ui::IconName::Check).color(Color::Success)) - .child(Label::new(status_label).truncate()), - ) - .child( - ui::Button::new("reset-key", reset_label) - .label_size(LabelSize::Small) - .icon(ui::IconName::Undo) - .icon_size(ui::IconSize::Small) - .icon_color(Color::Muted) - .icon_position(ui::IconPosition::Start) - .on_click(cx.listener(|this, _, window, cx| { - this.reset_api_key(window, cx); - })), - ), + ConfiguredApiCard::new(status_label) + .button_label(button_label) + .on_click(cx.listener(|this, _, window, cx| { + this.reset_api_key(window, cx); + })), ); return content.into_any_element(); diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 99889f391be5aa50a51d0cef473c759a27637fef..ee8e2b4552cdd2e32ccb675b1d08c86063a2437f 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -21,6 +21,7 @@ bm25 = "2.3.2" copilot.workspace = true edit_prediction.workspace = true extension_host.workspace = true +language_model.workspace = true language_models.workspace = true editor.workspace = true feature_flags.workspace = true diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 79fc1cc11158399265a184a289fd8d7a71ce8d69..4b9b4c2fd0ec1b546c3683961c1da1336942ad69 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -7479,8 +7479,8 @@ fn edit_prediction_language_settings_section() -> Vec { files: USER, render: Arc::new(|_, window, cx| { let settings_window = cx.entity(); - let page = window.use_state(cx, |_, _| { - crate::pages::EditPredictionSetupPage::new(settings_window) + let page = window.use_state(cx, |window, cx| { + crate::pages::EditPredictionSetupPage::new(settings_window, window, cx) }); page.into_any_element() }), diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 031cc3d1142cf2c14224188a8d58f0c296942fdd..ca37b990afaac3d03c811f6620f226db288dc12d 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -5,9 +5,13 @@ use edit_prediction::{ }; use extension_host::ExtensionStore; use feature_flags::FeatureFlagAppExt as _; -use gpui::{Entity, ScrollHandle, prelude::*}; +use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*}; +use language_model::{ + ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry, +}; use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key}; -use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*}; +use std::collections::HashMap; +use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*}; use crate::{ SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER, @@ -17,14 +21,85 @@ use crate::{ pub struct EditPredictionSetupPage { settings_window: Entity, scroll_handle: ScrollHandle, + extension_oauth_views: HashMap, + _registry_subscription: Subscription, +} + +struct ExtensionOAuthProviderView { + provider_name: SharedString, + provider_icon: IconName, + provider_icon_path: Option, + configuration_view: AnyView, } impl EditPredictionSetupPage { - pub fn new(settings_window: Entity) -> Self { - Self { + pub fn new( + settings_window: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let registry_subscription = cx.subscribe_in( + &LanguageModelRegistry::global(cx), + window, + |this, _, event: &language_model::Event, window, cx| match event { + language_model::Event::AddedProvider(provider_id) => { + this.maybe_add_extension_oauth_view(provider_id, window, cx); + } + language_model::Event::RemovedProvider(provider_id) => { + this.extension_oauth_views.remove(provider_id); + } + _ => {} + }, + ); + + let mut this = Self { settings_window, scroll_handle: ScrollHandle::new(), + extension_oauth_views: HashMap::default(), + _registry_subscription: registry_subscription, + }; + this.build_extension_oauth_views(window, cx); + this + } + + fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context) { + let oauth_provider_ids = get_extension_oauth_provider_ids(cx); + for provider_id in oauth_provider_ids { + self.maybe_add_extension_oauth_view(&provider_id, window, cx); + } + } + + fn maybe_add_extension_oauth_view( + &mut self, + provider_id: &LanguageModelProviderId, + window: &mut Window, + cx: &mut Context, + ) { + // Check if this provider has OAuth configured in the extension manifest + if !is_extension_oauth_provider(provider_id, cx) { + return; } + + let registry = LanguageModelRegistry::global(cx).read(cx); + let Some(provider) = registry.provider(provider_id) else { + return; + }; + + let provider_name = provider.name().0.clone(); + let provider_icon = provider.icon(); + let provider_icon_path = provider.icon_path(); + let configuration_view = + provider.configuration_view(ConfigurationViewTargetAgent::ZedAgent, window, cx); + + self.extension_oauth_views.insert( + provider_id.clone(), + ExtensionOAuthProviderView { + provider_name, + provider_icon, + provider_icon_path, + configuration_view, + }, + ); } } @@ -37,10 +112,43 @@ impl Render for EditPredictionSetupPage { .installed_extensions() .contains_key("copilot-chat"); - let providers = [ - (!copilot_extension_installed) - .then(|| render_github_copilot_provider(window, cx).into_any_element()), - cx.has_flag::().then(|| { + let mut providers: Vec = Vec::new(); + + // Built-in Copilot (hidden if copilot-chat extension is installed) + if !copilot_extension_installed { + providers.push(render_github_copilot_provider(window, cx).into_any_element()); + } + + // Extension providers with OAuth support + for (provider_id, view) in &self.extension_oauth_views { + let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path { + Icon::from_external_svg(icon_path.clone()) + .size(ui::IconSize::Medium) + .into_any_element() + } else { + Icon::new(view.provider_icon) + .size(ui::IconSize::Medium) + .into_any_element() + }; + + providers.push( + v_flex() + .id(SharedString::from(provider_id.0.to_string())) + .min_w_0() + .gap_1p5() + .child( + h_flex().gap_2().items_center().child(icon_element).child( + Headline::new(view.provider_name.clone()).size(HeadlineSize::Small), + ), + ) + .child(view.configuration_view.clone()) + .into_any_element(), + ); + } + + // Mercury (feature flagged) + if cx.has_flag::() { + providers.push( render_api_key_provider( IconName::Inception, "Mercury", @@ -51,9 +159,13 @@ impl Render for EditPredictionSetupPage { window, cx, ) - .into_any_element() - }), - cx.has_flag::().then(|| { + .into_any_element(), + ); + } + + // Sweep (feature flagged) + if cx.has_flag::() { + providers.push( render_api_key_provider( IconName::SweepAi, "Sweep", @@ -64,32 +176,34 @@ impl Render for EditPredictionSetupPage { window, cx, ) - .into_any_element() - }), - Some( - render_api_key_provider( - IconName::AiMistral, - "Codestral", - "https://console.mistral.ai/codestral".into(), - codestral_api_key(cx), - |cx| language_models::MistralLanguageModelProvider::api_url(cx), - Some(settings_window.update(cx, |settings_window, cx| { - let codestral_settings = codestral_settings(); - settings_window - .render_sub_page_items_section( - codestral_settings.iter().enumerate(), - None, - window, - cx, - ) - .into_any_element() - })), - window, - cx, - ) .into_any_element(), - ), - ]; + ); + } + + // Codestral + providers.push( + render_api_key_provider( + IconName::AiMistral, + "Codestral", + "https://console.mistral.ai/codestral".into(), + codestral_api_key(cx), + |cx| language_models::MistralLanguageModelProvider::api_url(cx), + Some(settings_window.update(cx, |settings_window, cx| { + let codestral_settings = codestral_settings(); + settings_window + .render_sub_page_items_section( + codestral_settings.iter().enumerate(), + None, + window, + cx, + ) + .into_any_element() + })), + window, + cx, + ) + .into_any_element(), + ); div() .size_full() @@ -103,11 +217,60 @@ impl Render for EditPredictionSetupPage { .pb_16() .overflow_y_scroll() .track_scroll(&self.scroll_handle) - .children(providers.into_iter().flatten()), + .children(providers), ) } } +/// Get extension provider IDs that have OAuth configured. +fn get_extension_oauth_provider_ids(cx: &App) -> Vec { + let extension_store = ExtensionStore::global(cx).read(cx); + + extension_store + .installed_extensions() + .iter() + .flat_map(|(extension_id, entry)| { + entry.manifest.language_model_providers.iter().filter_map( + move |(provider_id, provider_entry)| { + // Check if this provider has OAuth configured + let has_oauth = provider_entry + .auth + .as_ref() + .is_some_and(|auth| auth.oauth.is_some()); + + if has_oauth { + Some(LanguageModelProviderId( + format!("{}:{}", extension_id, provider_id).into(), + )) + } else { + None + } + }, + ) + }) + .collect() +} + +/// Check if a provider ID corresponds to an extension with OAuth configured. +fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool { + // Extension provider IDs are in the format "extension_id:provider_id" + let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else { + return false; + }; + + let extension_store = ExtensionStore::global(cx).read(cx); + let Some(entry) = extension_store.installed_extensions().get(extension_id) else { + return false; + }; + + entry + .manifest + .language_model_providers + .get(local_provider_id) + .and_then(|p| p.auth.as_ref()) + .is_some_and(|auth| auth.oauth.is_some()) +} + fn render_api_key_provider( icon: IconName, title: &'static str,