diff --git a/Cargo.lock b/Cargo.lock index e6cc573d3f206b17fab95cb00dd9599c20167e12..81fd42efb204b7ebb49e7b06ac6b981962c6f5bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15030,8 +15030,11 @@ dependencies = [ "assets", "bm25", "client", + "component", + "copilot", "copilot_ui", "edit_prediction", + "edit_prediction_ui", "editor", "feature_flags", "fs", diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 8177a509bc60ecb459515e02a0f0d9f75d09cdf9..357b6786a7043bd0c8f3c335ab033925f4b36f5a 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -271,27 +271,26 @@ impl GlobalCopilotAuth { fs: Arc, node_runtime: NodeRuntime, cx: &mut App, - ) { + ) -> GlobalCopilotAuth { let auth = GlobalCopilotAuth(cx.new(|cx| Copilot::new(None, server_id, fs, node_runtime, cx))); - cx.set_global(auth); + cx.set_global(auth.clone()); + auth } pub fn try_global(cx: &mut App) -> Option<&GlobalCopilotAuth> { cx.try_global() } - pub fn get_or_init(cx: &mut App) -> Option { + pub fn get_or_init(app_state: Arc, cx: &mut App) -> GlobalCopilotAuth { if let Some(copilot) = cx.try_global::() { - Some(copilot.clone()) + copilot.clone() } else { - let app_state = AppState::global(cx).upgrade()?; Self::set_global( app_state.languages.next_language_server_id(), app_state.fs.clone(), app_state.node_runtime.clone(), cx, - ); - cx.try_global::().cloned() + ) } } } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot_ui/src/sign_in.rs index cbc252270e118885f117cfd9667802147bcc9f26..2629ade34852e3b3555335afb2378eb8e1f40dd5 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot_ui/src/sign_in.rs @@ -10,7 +10,7 @@ use gpui::{ }; use ui::{ButtonLike, CommonAnimationExt, ConfiguredApiCard, Vector, VectorName, prelude::*}; use util::ResultExt as _; -use workspace::{Toast, Workspace, notifications::NotificationId}; +use workspace::{AppState, Toast, Workspace, notifications::NotificationId}; const COPILOT_SIGN_UP_URL: &str = "https://github.com/features/copilot"; const ERROR_LABEL: &str = @@ -457,7 +457,7 @@ impl Render for CopilotCodeVerification { pub struct ConfigurationView { copilot_status: Option, - is_authenticated: Box bool + 'static>, + is_authenticated: Box bool + 'static>, edit_prediction: bool, _subscription: Option, } @@ -469,11 +469,13 @@ pub enum ConfigurationMode { impl ConfigurationView { pub fn new( - is_authenticated: impl Fn(&App) -> bool + 'static, + is_authenticated: impl Fn(&mut App) -> bool + 'static, mode: ConfigurationMode, cx: &mut Context, ) -> Self { - let copilot = GlobalCopilotAuth::try_global(cx).cloned(); + let copilot = AppState::try_global(cx) + .and_then(|state| state.upgrade()) + .map(|state| GlobalCopilotAuth::get_or_init(state, cx)); Self { copilot_status: copilot.as_ref().map(|copilot| copilot.0.read(cx).status()), @@ -567,7 +569,8 @@ impl ConfigurationView { .icon_position(IconPosition::Start) .icon_size(IconSize::Small) .on_click(|_, window, cx| { - if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { + if let Some(app_state) = AppState::global(cx).upgrade() { + let copilot = GlobalCopilotAuth::get_or_init(app_state, cx); initiate_sign_in(copilot.0, window, cx) } }) @@ -594,8 +597,9 @@ impl ConfigurationView { .icon_position(IconPosition::Start) .icon_size(IconSize::Small) .on_click(|_, window, cx| { - if let Some(copilot) = GlobalCopilotAuth::get_or_init(cx) { - reinstall_and_sign_in(copilot.0, window, cx) + if let Some(app_state) = AppState::global(cx).upgrade() { + let copilot = GlobalCopilotAuth::get_or_init(app_state, cx); + reinstall_and_sign_in(copilot.0, window, cx); } }) } diff --git a/crates/edit_prediction_ui/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs index e4470337564328e5c71622b5ca3ebbe112c4ffcc..94e408cf67e11e30190a82e41d4a7e403ea94f42 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -536,65 +536,13 @@ impl EditPredictionButton { } } - fn get_available_providers(&self, cx: &mut App) -> Vec { - let mut providers = Vec::new(); - - providers.push(EditPredictionProvider::Zed); - - if cx.has_flag::() { - providers.push(EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - )); - } - - if let Some(_) = EditPredictionStore::try_global(cx) - .and_then(|store| store.read(cx).copilot_for_project(&self.project.upgrade()?)) - { - providers.push(EditPredictionProvider::Copilot); - } - - if let Some(supermaven) = Supermaven::global(cx) { - if let Supermaven::Spawned(agent) = supermaven.read(cx) { - if matches!(agent.account_status, AccountStatus::Ready) { - providers.push(EditPredictionProvider::Supermaven); - } - } - } - - if CodestralEditPredictionDelegate::has_api_key(cx) { - providers.push(EditPredictionProvider::Codestral); - } - - if cx.has_flag::() - && edit_prediction::sweep_ai::sweep_api_token(cx) - .read(cx) - .has_key() - { - providers.push(EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, - )); - } - - if cx.has_flag::() - && edit_prediction::mercury::mercury_api_token(cx) - .read(cx) - .has_key() - { - providers.push(EditPredictionProvider::Experimental( - EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, - )); - } - - providers - } - fn add_provider_switching_section( &self, mut menu: ContextMenu, current_provider: EditPredictionProvider, cx: &mut App, ) -> ContextMenu { - let available_providers = self.get_available_providers(cx); + let available_providers = get_available_providers(cx); let providers: Vec<_> = available_providers .into_iter() @@ -605,28 +553,12 @@ impl EditPredictionButton { menu = menu.separator().header("Providers"); for provider in providers { + let Some(name) = provider.display_name() else { + continue; + }; let is_current = provider == current_provider; let fs = self.fs.clone(); - let name = match provider { - EditPredictionProvider::Zed => "Zed AI", - EditPredictionProvider::Copilot => "GitHub Copilot", - EditPredictionProvider::Supermaven => "Supermaven", - EditPredictionProvider::Codestral => "Codestral", - EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, - ) => "Sweep", - EditPredictionProvider::Experimental( - EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, - ) => "Mercury", - EditPredictionProvider::Experimental( - EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - ) => "Zeta2", - EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { - continue; - } - }; - menu = menu.item( ContextMenuEntry::new(name) .toggleable(IconPosition::Start, is_current) @@ -1339,7 +1271,7 @@ async fn open_disabled_globs_setting_in_editor( anyhow::Ok(()) } -fn set_completion_provider(fs: Arc, cx: &mut App, provider: EditPredictionProvider) { +pub fn set_completion_provider(fs: Arc, cx: &mut App, provider: EditPredictionProvider) { update_settings_file(fs, cx, move |settings, _| { settings .project @@ -1350,6 +1282,61 @@ fn set_completion_provider(fs: Arc, cx: &mut App, provider: EditPredicti }); } +pub fn get_available_providers(cx: &mut App) -> Vec { + let mut providers = Vec::new(); + + providers.push(EditPredictionProvider::Zed); + + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + + if let Some(app_state) = workspace::AppState::global(cx).upgrade() + && copilot::GlobalCopilotAuth::get_or_init(app_state, cx) + .0 + .read(cx) + .is_authenticated() + { + providers.push(EditPredictionProvider::Copilot); + }; + + if let Some(supermaven) = Supermaven::global(cx) { + if let Supermaven::Spawned(agent) = supermaven.read(cx) { + if matches!(agent.account_status, AccountStatus::Ready) { + providers.push(EditPredictionProvider::Supermaven); + } + } + } + + if CodestralEditPredictionDelegate::has_api_key(cx) { + providers.push(EditPredictionProvider::Codestral); + } + + if cx.has_flag::() + && edit_prediction::sweep_ai::sweep_api_token(cx) + .read(cx) + .has_key() + { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + + if cx.has_flag::() + && edit_prediction::mercury::mercury_api_token(cx) + .read(cx) + .has_key() + { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + + providers +} + fn toggle_show_edit_predictions_for_language( language: Arc, fs: Arc, diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs index 2ca852a0140651b515734dd144c868bfebe04328..d0fd636a804d4fe86bba281f718645dbf17de8e9 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_ui.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -16,7 +16,9 @@ use std::any::{Any as _, TypeId}; use ui::{App, prelude::*}; use workspace::{SplitDirection, Workspace}; -pub use edit_prediction_button::{EditPredictionButton, ToggleMenu}; +pub use edit_prediction_button::{ + EditPredictionButton, ToggleMenu, get_available_providers, set_completion_provider, +}; use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag; diff --git a/crates/settings_content/src/language.rs b/crates/settings_content/src/language.rs index ef22e1886751a3497a353c7d0a30346ba5b872e3..5e0c553f3a9bd25b85c0deabe6463f556b6d5cdf 100644 --- a/crates/settings_content/src/language.rs +++ b/crates/settings_content/src/language.rs @@ -167,6 +167,25 @@ impl EditPredictionProvider { | EditPredictionProvider::Experimental(_) => false, } } + + pub fn display_name(&self) -> Option<&'static str> { + match self { + EditPredictionProvider::Zed => Some("Zed AI"), + EditPredictionProvider::Copilot => Some("GitHub Copilot"), + EditPredictionProvider::Supermaven => Some("Supermaven"), + EditPredictionProvider::Codestral => Some("Codestral"), + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) => Some("Sweep"), + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => Some("Mercury"), + EditPredictionProvider::Experimental( + EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, + ) => Some("Zeta2"), + EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => None, + } + } } /// The contents of the edit prediction settings. diff --git a/crates/settings_ui/Cargo.toml b/crates/settings_ui/Cargo.toml index 6d65513d250876e901bc624ba1104767358a854b..10410d29262da2337721a98da591677e5530477f 100644 --- a/crates/settings_ui/Cargo.toml +++ b/crates/settings_ui/Cargo.toml @@ -18,8 +18,11 @@ test-support = [] [dependencies] anyhow.workspace = true bm25 = "2.3.2" +component.workspace = true +copilot.workspace = true copilot_ui.workspace = true edit_prediction.workspace = true +edit_prediction_ui.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true diff --git a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs index b170c9a88f505b1873aec3acdbc8499e9dab836a..d5e6688deeff35cecf545b249913549964765692 100644 --- a/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs +++ b/crates/settings_ui/src/pages/edit_prediction_provider_setup.rs @@ -1,13 +1,16 @@ use edit_prediction::{ - ApiKeyState, EditPredictionStore, MercuryFeatureFlag, SweepFeatureFlag, + ApiKeyState, MercuryFeatureFlag, SweepFeatureFlag, mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token}, sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token}, }; +use edit_prediction_ui::{get_available_providers, set_completion_provider}; use feature_flags::FeatureFlagAppExt as _; use gpui::{Entity, ScrollHandle, prelude::*}; +use language::language_settings::AllLanguageSettings; use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key}; -use project::Project; -use ui::{ButtonLink, ConfiguredApiCard, prelude::*}; +use settings::Settings as _; +use ui::{ButtonLink, ConfiguredApiCard, ContextMenu, DropdownMenu, DropdownStyle, prelude::*}; +use workspace::AppState; use crate::{ SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER, @@ -20,15 +23,9 @@ pub(crate) fn render_edit_prediction_setup_page( window: &mut Window, cx: &mut Context, ) -> AnyElement { - let project = settings_window.original_window.as_ref().and_then(|window| { - window - .read_with(cx, |workspace, _| workspace.project().clone()) - .ok() - }); let providers = [ - project.and_then(|project| { - render_github_copilot_provider(project, window, cx).map(IntoElement::into_any_element) - }), + Some(render_provider_dropdown(window, cx)), + render_github_copilot_provider(window, cx).map(IntoElement::into_any_element), cx.has_flag::().then(|| { render_api_key_provider( IconName::Inception, @@ -94,6 +91,65 @@ pub(crate) fn render_edit_prediction_setup_page( .into_any_element() } +fn render_provider_dropdown(window: &mut Window, cx: &mut App) -> AnyElement { + let current_provider = AllLanguageSettings::get_global(cx) + .edit_predictions + .provider; + let current_provider_name = current_provider.display_name().unwrap_or("No provider set"); + + let menu = ContextMenu::build(window, cx, move |mut menu, _, cx| { + let available_providers = get_available_providers(cx); + let fs = ::global(cx); + + for provider in available_providers { + let Some(name) = provider.display_name() else { + continue; + }; + let is_current = provider == current_provider; + + menu = menu.toggleable_entry(name, is_current, IconPosition::Start, None, { + let fs = fs.clone(); + move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + } + }); + } + menu + }); + + v_flex() + .id("provider-selector") + .min_w_0() + .gap_1p5() + .child( + SettingsSectionHeader::new("Active Provider") + .icon(IconName::Sparkle) + .no_padding(true), + ) + .child( + h_flex() + .pt_2p5() + .w_full() + .justify_between() + .child( + v_flex() + .w_full() + .max_w_1_2() + .child(Label::new("Provider")) + .child( + Label::new("Select which provider to use for edit predictions.") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + .child( + DropdownMenu::new("provider-dropdown", current_provider_name, menu) + .style(DropdownStyle::Outlined), + ), + ) + .into_any_element() +} + fn render_api_key_provider( icon: IconName, title: &'static str, @@ -330,20 +386,16 @@ fn codestral_settings() -> Box<[SettingsPageItem]> { ]) } -fn render_github_copilot_provider( - project: Entity, - window: &mut Window, - cx: &mut App, -) -> Option { - let copilot = EditPredictionStore::try_global(cx)? - .read(cx) - .copilot_for_project(&project); +fn render_github_copilot_provider(window: &mut Window, cx: &mut App) -> Option { let configuration_view = window.use_state(cx, |_, cx| { copilot_ui::ConfigurationView::new( move |cx| { - copilot - .as_ref() - .is_some_and(|copilot| copilot.read(cx).is_authenticated()) + if let Some(app_state) = AppState::global(cx).upgrade() { + let copilot = copilot::GlobalCopilotAuth::get_or_init(app_state, cx); + copilot.0.read(cx).is_authenticated() + } else { + false + } }, copilot_ui::ConfigurationMode::EditPrediction, cx, @@ -354,6 +406,7 @@ fn render_github_copilot_provider( v_flex() .id("github-copilot") .min_w_0() + .pt_8() .gap_1p5() .child( SettingsSectionHeader::new("GitHub Copilot")