Make the edit prediction status bar menu work correctly when using sweep (#43203)

Max Brunsfeld and Ben Kunkle created

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/edit_prediction_button/src/edit_prediction_button.rs | 89 ++++--
crates/zed/src/zed/edit_prediction_registry.rs              |  3 
crates/zeta2/src/zeta2.rs                                   | 23 +
3 files changed, 80 insertions(+), 35 deletions(-)

Detailed changes

crates/edit_prediction_button/src/edit_prediction_button.rs 🔗

@@ -83,9 +83,7 @@ impl Render for EditPredictionButton {
 
         let all_language_settings = all_language_settings(None, cx);
 
-        match &all_language_settings.edit_predictions.provider {
-            EditPredictionProvider::None => div().hidden(),
-
+        match all_language_settings.edit_predictions.provider {
             EditPredictionProvider::Copilot => {
                 let Some(copilot) = Copilot::global(cx) else {
                     return div().hidden();
@@ -302,23 +300,23 @@ impl Render for EditPredictionButton {
                         .with_handle(self.popover_menu_handle.clone()),
                 )
             }
-            EditPredictionProvider::Experimental(provider_name) => {
-                if *provider_name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
-                    && cx.has_flag::<SweepFeatureFlag>()
-                {
-                    div().child(Icon::new(IconName::SweepAi))
-                } else {
-                    div()
-                }
-            }
-
-            EditPredictionProvider::Zed => {
+            provider @ (EditPredictionProvider::Experimental(
+                EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+            )
+            | EditPredictionProvider::Zed) => {
                 let enabled = self.editor_enabled.unwrap_or(true);
 
-                let zeta_icon = if enabled {
-                    IconName::ZedPredict
-                } else {
-                    IconName::ZedPredictDisabled
+                let is_sweep = matches!(
+                    provider,
+                    EditPredictionProvider::Experimental(
+                        EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
+                    )
+                );
+
+                let zeta_icon = match (is_sweep, enabled) {
+                    (true, _) => IconName::SweepAi,
+                    (false, true) => IconName::ZedPredict,
+                    (false, false) => IconName::ZedPredictDisabled,
                 };
 
                 if zeta::should_show_upsell_modal() {
@@ -402,8 +400,10 @@ impl Render for EditPredictionButton {
 
                 let mut popover_menu = PopoverMenu::new("zeta")
                     .menu(move |window, cx| {
-                        this.update(cx, |this, cx| this.build_zeta_context_menu(window, cx))
-                            .ok()
+                        this.update(cx, |this, cx| {
+                            this.build_zeta_context_menu(provider, window, cx)
+                        })
+                        .ok()
                     })
                     .anchor(Corner::BottomRight)
                     .with_handle(self.popover_menu_handle.clone());
@@ -429,6 +429,10 @@ impl Render for EditPredictionButton {
 
                 div().child(popover_menu.into_any_element())
             }
+
+            EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
+                div().hidden()
+            }
         }
     }
 }
@@ -487,6 +491,12 @@ impl EditPredictionButton {
             providers.push(EditPredictionProvider::Codestral);
         }
 
+        if cx.has_flag::<SweepFeatureFlag>() {
+            providers.push(EditPredictionProvider::Experimental(
+                EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+            ));
+        }
+
         providers
     }
 
@@ -498,6 +508,11 @@ impl EditPredictionButton {
     ) -> ContextMenu {
         let available_providers = self.get_available_providers(cx);
 
+        const ZED_AI_CALLOUT: &str =
+            "Zed's edit prediction is powered by Zeta, an open-source, dataset mode.";
+        const USE_SWEEP_API_TOKEN_CALLOUT: &str =
+            "Set the SWEEP_API_TOKEN environment variable to use Sweep";
+
         let other_providers: Vec<_> = available_providers
             .into_iter()
             .filter(|p| *p != current_provider && *p != EditPredictionProvider::None)
@@ -514,11 +529,8 @@ impl EditPredictionButton {
                         ContextMenuEntry::new("Zed AI")
                             .documentation_aside(
                                 DocumentationSide::Left,
-                                DocumentationEdge::Top,
-                                |_| {
-                                    Label::new("Zed's edit prediction is powered by Zeta, an open-source, dataset mode.")
-                                        .into_any_element()
-                                },
+                                DocumentationEdge::Bottom,
+                                |_| Label::new(ZED_AI_CALLOUT).into_any_element(),
                             )
                             .handler(move |_, cx| {
                                 set_completion_provider(fs.clone(), cx, provider);
@@ -539,7 +551,29 @@ impl EditPredictionButton {
                             set_completion_provider(fs.clone(), cx, provider);
                         })
                     }
-                    EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => continue,
+                    EditPredictionProvider::Experimental(
+                        EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+                    ) => {
+                        let has_api_token = zeta2::Zeta::try_global(cx)
+                            .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+
+                        let entry = ContextMenuEntry::new("Sweep")
+                            .when(!has_api_token, |this| {
+                                this.disabled(true).documentation_aside(
+                                    DocumentationSide::Left,
+                                    DocumentationEdge::Bottom,
+                                    |_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(),
+                                )
+                            })
+                            .handler(move |_, cx| {
+                                set_completion_provider(fs.clone(), cx, provider);
+                            });
+
+                        menu.item(entry)
+                    }
+                    EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
+                        continue;
+                    }
                 };
             }
         }
@@ -909,6 +943,7 @@ impl EditPredictionButton {
 
     fn build_zeta_context_menu(
         &self,
+        provider: EditPredictionProvider,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Entity<ContextMenu> {
@@ -996,7 +1031,7 @@ impl EditPredictionButton {
             }
 
             let menu = self.build_language_settings_menu(menu, window, cx);
-            let menu = self.add_provider_switching_section(menu, EditPredictionProvider::Zed, cx);
+            let menu = self.add_provider_switching_section(menu, provider, cx);
 
             menu
         })

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -204,6 +204,8 @@ fn assign_edit_prediction_provider(
             editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
         value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
+            let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
+
             if let Some(project) = editor.project() {
                 let mut worktree = None;
                 if let Some(buffer) = &singleton_buffer
@@ -217,7 +219,6 @@ fn assign_edit_prediction_provider(
                     && name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
                     && cx.has_flag::<SweepFeatureFlag>()
                 {
-                    let zeta2 = zeta2::Zeta::global(client, &user_store, cx);
                     let provider = cx.new(|cx| {
                         zeta2::ZetaEditPredictionProvider::new(
                             project.clone(),

crates/zeta2/src/zeta2.rs 🔗

@@ -402,20 +402,21 @@ impl Zeta {
             #[cfg(feature = "eval-support")]
             eval_cache: None,
             edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
-            sweep_api_token: None,
+            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
+                .context("No SWEEP_AI_TOKEN environment variable set")
+                .log_err(),
             sweep_ai_debug_info: sweep_ai::debug_info(cx),
         }
     }
 
     pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
-        if model == ZetaEditPredictionModel::Sweep {
-            self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
-                .context("No SWEEP_AI_TOKEN environment variable set")
-                .log_err();
-        }
         self.edit_prediction_model = model;
     }
 
+    pub fn has_sweep_api_token(&self) -> bool {
+        self.sweep_api_token.is_some()
+    }
+
     #[cfg(feature = "eval-support")]
     pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
         self.eval_cache = Some(cache);
@@ -472,7 +473,11 @@ impl Zeta {
     }
 
     pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
-        self.user_store.read(cx).edit_prediction_usage()
+        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
+            self.user_store.read(cx).edit_prediction_usage()
+        } else {
+            None
+        }
     }
 
     pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
@@ -659,6 +664,10 @@ impl Zeta {
     }
 
     fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
+            return;
+        }
+
         let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
             return;
         };