diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index ba00e95c488dc8e8704274638087c8334f96e1a3..051ca6e85ccb985ba6b325cda725f83029aa3193 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -83,9 +83,7 @@ impl Render for EditPredictionButton { let all_language_settings = all_language_settings(None, cx); - match &all_language_settings.edit_predictions.provider { - EditPredictionProvider::None => div().hidden(), - + match all_language_settings.edit_predictions.provider { EditPredictionProvider::Copilot => { let Some(copilot) = Copilot::global(cx) else { return div().hidden(); @@ -302,23 +300,23 @@ impl Render for EditPredictionButton { .with_handle(self.popover_menu_handle.clone()), ) } - EditPredictionProvider::Experimental(provider_name) => { - if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - && cx.has_flag::() - { - div().child(Icon::new(IconName::SweepAi)) - } else { - div() - } - } - - EditPredictionProvider::Zed => { + provider @ (EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) + | EditPredictionProvider::Zed) => { let enabled = self.editor_enabled.unwrap_or(true); - let zeta_icon = if enabled { - IconName::ZedPredict - } else { - IconName::ZedPredictDisabled + let is_sweep = matches!( + provider, + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME + ) + ); + + let zeta_icon = match (is_sweep, enabled) { + (true, _) => IconName::SweepAi, + (false, true) => IconName::ZedPredict, + (false, false) => IconName::ZedPredictDisabled, }; if zeta::should_show_upsell_modal() { @@ -402,8 +400,10 @@ impl Render for EditPredictionButton { let mut popover_menu = PopoverMenu::new("zeta") .menu(move |window, cx| { - this.update(cx, |this, cx| this.build_zeta_context_menu(window, cx)) - .ok() + this.update(cx, |this, cx| { + this.build_zeta_context_menu(provider, window, cx) + }) + .ok() }) .anchor(Corner::BottomRight) .with_handle(self.popover_menu_handle.clone()); @@ -429,6 +429,10 @@ impl Render for EditPredictionButton { div().child(popover_menu.into_any_element()) } + + EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { + div().hidden() + } } } } @@ -487,6 +491,12 @@ impl EditPredictionButton { providers.push(EditPredictionProvider::Codestral); } + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + providers } @@ -498,6 +508,11 @@ impl EditPredictionButton { ) -> ContextMenu { let available_providers = self.get_available_providers(cx); + const ZED_AI_CALLOUT: &str = + "Zed's edit prediction is powered by Zeta, an open-source, dataset mode."; + const USE_SWEEP_API_TOKEN_CALLOUT: &str = + "Set the SWEEP_API_TOKEN environment variable to use Sweep"; + let other_providers: Vec<_> = available_providers .into_iter() .filter(|p| *p != current_provider && *p != EditPredictionProvider::None) @@ -514,11 +529,8 @@ impl EditPredictionButton { ContextMenuEntry::new("Zed AI") .documentation_aside( DocumentationSide::Left, - DocumentationEdge::Top, - |_| { - Label::new("Zed's edit prediction is powered by Zeta, an open-source, dataset mode.") - .into_any_element() - }, + DocumentationEdge::Bottom, + |_| Label::new(ZED_AI_CALLOUT).into_any_element(), ) .handler(move |_, cx| { set_completion_provider(fs.clone(), cx, provider); @@ -539,7 +551,29 @@ impl EditPredictionButton { set_completion_provider(fs.clone(), cx, provider); }) } - EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue, + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + let has_api_token = zeta2::Zeta::try_global(cx) + .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); + + let entry = ContextMenuEntry::new("Sweep") + .when(!has_api_token, |this| { + this.disabled(true).documentation_aside( + DocumentationSide::Left, + DocumentationEdge::Bottom, + |_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(), + ) + }) + .handler(move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }); + + menu.item(entry) + } + EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { + continue; + } }; } } @@ -909,6 +943,7 @@ impl EditPredictionButton { fn build_zeta_context_menu( &self, + provider: EditPredictionProvider, window: &mut Window, cx: &mut Context, ) -> Entity { @@ -996,7 +1031,7 @@ impl EditPredictionButton { } let menu = self.build_language_settings_menu(menu, window, cx); - let menu = self.add_provider_switching_section(menu, EditPredictionProvider::Zed, cx); + let menu = self.add_provider_switching_section(menu, provider, cx); menu }) diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 250a2b5a0e585d5acad7658a25f89bce12f766d2..577e81c6a9b36bc29a4b1d1f0cda63170c75d5a2 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -204,6 +204,8 @@ fn assign_edit_prediction_provider( editor.set_edit_prediction_provider(Some(provider), window, cx); } value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { + let zeta2 = zeta2::Zeta::global(client, &user_store, cx); + if let Some(project) = editor.project() { let mut worktree = None; if let Some(buffer) = &singleton_buffer @@ -217,7 +219,6 @@ fn assign_edit_prediction_provider( && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME && cx.has_flag::() { - let zeta2 = zeta2::Zeta::global(client, &user_store, cx); let provider = cx.new(|cx| { zeta2::ZetaEditPredictionProvider::new( project.clone(), diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 6eacc5190f403594ad20f7365512b011d2226719..0d0f4f3d39e9c997282695828ba16e7eccd7d8e2 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -402,20 +402,21 @@ impl Zeta { #[cfg(feature = "eval-support")] eval_cache: None, edit_prediction_model: ZetaEditPredictionModel::ZedCloud, - sweep_api_token: None, + sweep_api_token: std::env::var("SWEEP_AI_TOKEN") + .context("No SWEEP_AI_TOKEN environment variable set") + .log_err(), sweep_ai_debug_info: sweep_ai::debug_info(cx), } } pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { - if model == ZetaEditPredictionModel::Sweep { - self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN") - .context("No SWEEP_AI_TOKEN environment variable set") - .log_err(); - } self.edit_prediction_model = model; } + pub fn has_sweep_api_token(&self) -> bool { + self.sweep_api_token.is_some() + } + #[cfg(feature = "eval-support")] pub fn with_eval_cache(&mut self, cache: Arc) { self.eval_cache = Some(cache); @@ -472,7 +473,11 @@ impl Zeta { } pub fn usage(&self, cx: &App) -> Option { - self.user_store.read(cx).edit_prediction_usage() + if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud { + self.user_store.read(cx).edit_prediction_usage() + } else { + None + } } pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { @@ -659,6 +664,10 @@ impl Zeta { } fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud { + return; + } + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { return; };