diff --git a/crates/settings_ui/src/page_data.rs b/crates/settings_ui/src/page_data.rs index 3841811c7f6dee93e354253469e976a5a4c1789e..ca2e23252a4483b365c7c42cfd086105d757a097 100644 --- a/crates/settings_ui/src/page_data.rs +++ b/crates/settings_ui/src/page_data.rs @@ -7659,8 +7659,8 @@ fn edit_prediction_language_settings_section() -> Vec { files: USER, render: Arc::new(|_, window, cx| { let settings_window = cx.entity(); - let page = window.use_state(cx, |window, cx| { - crate::pages::EditPredictionSetupPage::new(settings_window, window, cx) + let page = window.use_state(cx, |_, _| { + crate::pages::EditPredictionSetupPage::new(settings_window) }); page.into_any_element() }), diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index 5c674e501d507b4203887ebd6d03bc311607e8cb..fb8f967613fa195080f62c5ab2ce76a43f3d1e22 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -3,15 +3,10 @@ use edit_prediction::{ mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token}, sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token}, }; -use extension_host::ExtensionStore; use feature_flags::FeatureFlagAppExt as _; -use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*}; -use language_model::{ - ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry, -}; +use gpui::{Entity, ScrollHandle, prelude::*}; use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key}; -use std::collections::HashMap; -use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*}; +use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*}; use crate::{ SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER, @@ -21,133 +16,24 @@ use crate::{ pub struct EditPredictionSetupPage { settings_window: Entity, scroll_handle: ScrollHandle, - extension_oauth_views: HashMap, - _registry_subscription: Subscription, -} - -struct ExtensionOAuthProviderView { - provider_name: SharedString, - provider_icon: IconName, - provider_icon_path: Option, - configuration_view: AnyView, } impl EditPredictionSetupPage { - pub fn new( - settings_window: Entity, - window: &mut Window, - cx: &mut Context, - ) -> Self { - let registry_subscription = cx.subscribe_in( - &LanguageModelRegistry::global(cx), - window, - |this, _, event: &language_model::Event, window, cx| match event { - language_model::Event::AddedProvider(provider_id) => { - this.maybe_add_extension_oauth_view(provider_id, window, cx); - } - language_model::Event::RemovedProvider(provider_id) => { - this.extension_oauth_views.remove(provider_id); - } - _ => {} - }, - ); - - let mut this = Self { + pub fn new(settings_window: Entity) -> Self { + Self { settings_window, scroll_handle: ScrollHandle::new(), - extension_oauth_views: HashMap::default(), - _registry_subscription: registry_subscription, - }; - this.build_extension_oauth_views(window, cx); - this - } - - fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context) { - let oauth_provider_ids = get_extension_oauth_provider_ids(cx); - for provider_id in oauth_provider_ids { - self.maybe_add_extension_oauth_view(&provider_id, window, cx); } } - - fn maybe_add_extension_oauth_view( - &mut self, - provider_id: &LanguageModelProviderId, - window: &mut Window, - cx: &mut Context, - ) { - // Check if this provider has OAuth configured in the extension manifest - if !is_extension_oauth_provider(provider_id, cx) { - return; - } - - let registry = LanguageModelRegistry::global(cx).read(cx); - let Some(provider) = registry.provider(provider_id) else { - return; - }; - - let provider_name = provider.name().0; - let provider_icon = provider.icon(); - let provider_icon_path = provider.icon_path(); - let configuration_view = - provider.configuration_view(ConfigurationViewTargetAgent::EditPrediction, window, cx); - - self.extension_oauth_views.insert( - provider_id.clone(), - ExtensionOAuthProviderView { - provider_name, - provider_icon, - provider_icon_path, - configuration_view, - }, - ); - } } impl Render for EditPredictionSetupPage { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let settings_window = self.settings_window.clone(); - let copilot_extension_installed = ExtensionStore::global(cx) - .read(cx) - .installed_extensions() - .contains_key("copilot-chat"); - - let mut providers: Vec = Vec::new(); - - // Built-in Copilot (hidden if copilot-chat extension is installed) - if !copilot_extension_installed { - providers.push(render_github_copilot_provider(window, cx).into_any_element()); - } - - // Extension providers with OAuth support - for (provider_id, view) in &self.extension_oauth_views { - let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path { - Icon::from_external_svg(icon_path.clone()) - .size(ui::IconSize::Medium) - .into_any_element() - } else { - Icon::new(view.provider_icon) - .size(ui::IconSize::Medium) - .into_any_element() - }; - - providers.push( - v_flex() - .id(SharedString::from(provider_id.0.to_string())) - .min_w_0() - .gap_1p5() - .child( - h_flex().gap_2().items_center().child(icon_element).child( - Headline::new(view.provider_name.clone()).size(HeadlineSize::Small), - ), - ) - .child(view.configuration_view.clone()) - .into_any_element(), - ); - } - - if cx.has_flag::() { - providers.push( + let providers = [ + Some(render_github_copilot_provider(window, cx).into_any_element()), + cx.has_flag::().then(|| { render_api_key_provider( IconName::Inception, "Mercury", @@ -158,12 +44,9 @@ impl Render for EditPredictionSetupPage { window, cx, ) - .into_any_element(), - ); - } - - if cx.has_flag::() { - providers.push( + .into_any_element() + }), + cx.has_flag::().then(|| { render_api_key_provider( IconName::SweepAi, "Sweep", @@ -174,33 +57,32 @@ impl Render for EditPredictionSetupPage { window, cx, ) + .into_any_element() + }), + Some( + render_api_key_provider( + IconName::AiMistral, + "Codestral", + "https://console.mistral.ai/codestral".into(), + codestral_api_key(cx), + |cx| language_models::MistralLanguageModelProvider::api_url(cx), + Some(settings_window.update(cx, |settings_window, cx| { + let codestral_settings = codestral_settings(); + settings_window + .render_sub_page_items_section( + codestral_settings.iter().enumerate(), + None, + window, + cx, + ) + .into_any_element() + })), + window, + cx, + ) .into_any_element(), - ); - } - - providers.push( - render_api_key_provider( - IconName::AiMistral, - "Codestral", - "https://console.mistral.ai/codestral".into(), - codestral_api_key(cx), - |cx| language_models::MistralLanguageModelProvider::api_url(cx), - Some(settings_window.update(cx, |settings_window, cx| { - let codestral_settings = codestral_settings(); - settings_window - .render_sub_page_items_section( - codestral_settings.iter().enumerate(), - None, - window, - cx, - ) - .into_any_element() - })), - window, - cx, - ) - .into_any_element(), - ); + ), + ]; div() .size_full() @@ -214,60 +96,11 @@ impl Render for EditPredictionSetupPage { .pb_16() .overflow_y_scroll() .track_scroll(&self.scroll_handle) - .children(providers), + .children(providers.into_iter().flatten()), ) } } -/// Get extension provider IDs that have OAuth configured. -fn get_extension_oauth_provider_ids(cx: &App) -> Vec { - let extension_store = ExtensionStore::global(cx).read(cx); - - extension_store - .installed_extensions() - .iter() - .flat_map(|(extension_id, entry)| { - entry.manifest.language_model_providers.iter().filter_map( - move |(provider_id, provider_entry)| { - // Check if this provider has OAuth configured - let has_oauth = provider_entry - .auth - .as_ref() - .is_some_and(|auth| auth.oauth.is_some()); - - if has_oauth { - Some(LanguageModelProviderId( - format!("{}:{}", extension_id, provider_id).into(), - )) - } else { - None - } - }, - ) - }) - .collect() -} - -/// Check if a provider ID corresponds to an extension with OAuth configured. -fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool { - // Extension provider IDs are in the format "extension_id:provider_id" - let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else { - return false; - }; - - let extension_store = ExtensionStore::global(cx).read(cx); - let Some(entry) = extension_store.installed_extensions().get(extension_id) else { - return false; - }; - - entry - .manifest - .language_model_providers - .get(local_provider_id) - .and_then(|p| p.auth.as_ref()) - .is_some_and(|auth| auth.oauth.is_some()) -} - fn render_api_key_provider( icon: IconName, title: &'static str,