sweep: Add UI for setting Sweep API token in system keychain (#43502)

Ben Kunkle created

Closes #ISSUE

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

Cargo.lock                                                  |   3 
crates/edit_prediction_button/Cargo.toml                    |   2 
crates/edit_prediction_button/src/edit_prediction_button.rs | 131 ++++--
crates/edit_prediction_button/src/sweep_api_token_modal.rs  |  84 ++++
crates/zeta/Cargo.toml                                      |   9 
crates/zeta/src/provider.rs                                 |   2 
crates/zeta/src/sweep_ai.rs                                 |  59 ++
crates/zeta/src/zeta.rs                                     |  15 
8 files changed, 246 insertions(+), 59 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5299,6 +5299,7 @@ dependencies = [
  "indoc",
  "language",
  "lsp",
+ "menu",
  "paths",
  "project",
  "regex",
@@ -5308,6 +5309,7 @@ dependencies = [
  "telemetry",
  "theme",
  "ui",
+ "ui_input",
  "util",
  "workspace",
  "zed_actions",
@@ -21678,6 +21680,7 @@ dependencies = [
  "collections",
  "command_palette_hooks",
  "copilot",
+ "credentials_provider",
  "ctor",
  "db",
  "edit_prediction",

crates/edit_prediction_button/Cargo.toml 🔗

@@ -32,6 +32,8 @@ settings.workspace = true
 supermaven.workspace = true
 telemetry.workspace = true
 ui.workspace = true
+ui_input.workspace = true
+menu.workspace = true
 util.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true

crates/edit_prediction_button/src/edit_prediction_button.rs 🔗

@@ -1,3 +1,7 @@
+mod sweep_api_token_modal;
+
+pub use sweep_api_token_modal::SweepApiKeyModal;
+
 use anyhow::Result;
 use client::{Client, UserStore, zed_urls};
 use cloud_llm_client::UsageLimit;
@@ -40,8 +44,7 @@ use workspace::{
     notifications::NotificationId,
 };
 use zed_actions::OpenBrowser;
-use zeta::RateCompletions;
-use zeta::{SweepFeatureFlag, Zeta2FeatureFlag};
+use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
 
 actions!(
     edit_prediction,
@@ -313,6 +316,10 @@ impl Render for EditPredictionButton {
                     )
                 );
 
+                let sweep_missing_token = is_sweep
+                    && !zeta::Zeta::try_global(cx)
+                        .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+
                 let zeta_icon = match (is_sweep, enabled) {
                     (true, _) => IconName::SweepAi,
                     (false, true) => IconName::ZedPredict,
@@ -360,19 +367,24 @@ impl Render for EditPredictionButton {
                 let show_editor_predictions = self.editor_show_predictions;
                 let user = self.user_store.read(cx).current_user();
 
+                let indicator_color = if sweep_missing_token {
+                    Some(Color::Error)
+                } else if enabled && (!show_editor_predictions || over_limit) {
+                    Some(if over_limit {
+                        Color::Error
+                    } else {
+                        Color::Muted
+                    })
+                } else {
+                    None
+                };
+
                 let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
                     .shape(IconButtonShape::Square)
-                    .when(
-                        enabled && (!show_editor_predictions || over_limit),
-                        |this| {
-                            this.indicator(Indicator::dot().when_else(
-                                over_limit,
-                                |dot| dot.color(Color::Error),
-                                |dot| dot.color(Color::Muted),
-                            ))
+                    .when_some(indicator_color, |this, color| {
+                        this.indicator(Indicator::dot().color(color))
                             .indicator_border_color(Some(cx.theme().colors().status_bar_background))
-                        },
-                    )
+                    })
                     .when(!self.popover_menu_handle.is_deployed(), |element| {
                         let user = user.clone();
                         element.tooltip(move |_window, cx| {
@@ -537,23 +549,23 @@ impl EditPredictionButton {
 
         const ZED_AI_CALLOUT: &str =
             "Zed's edit prediction is powered by Zeta, an open-source, dataset mode.";
-        const USE_SWEEP_API_TOKEN_CALLOUT: &str =
-            "Set the SWEEP_API_TOKEN environment variable to use Sweep";
 
-        let other_providers: Vec<_> = available_providers
+        let providers: Vec<_> = available_providers
             .into_iter()
-            .filter(|p| *p != current_provider && *p != EditPredictionProvider::None)
+            .filter(|p| *p != EditPredictionProvider::None)
             .collect();
 
-        if !other_providers.is_empty() {
-            menu = menu.separator().header("Switch Providers");
+        if !providers.is_empty() {
+            menu = menu.separator().header("Providers");
 
-            for provider in other_providers {
+            for provider in providers {
+                let is_current = provider == current_provider;
                 let fs = self.fs.clone();
 
                 menu = match provider {
                     EditPredictionProvider::Zed => menu.item(
                         ContextMenuEntry::new("Zed AI")
+                            .toggleable(IconPosition::Start, is_current)
                             .documentation_aside(
                                 DocumentationSide::Left,
                                 DocumentationEdge::Bottom,
@@ -563,46 +575,77 @@ impl EditPredictionButton {
                                 set_completion_provider(fs.clone(), cx, provider);
                             }),
                     ),
-                    EditPredictionProvider::Copilot => {
-                        menu.entry("GitHub Copilot", None, move |_, cx| {
-                            set_completion_provider(fs.clone(), cx, provider);
-                        })
-                    }
-                    EditPredictionProvider::Supermaven => {
-                        menu.entry("Supermaven", None, move |_, cx| {
-                            set_completion_provider(fs.clone(), cx, provider);
-                        })
-                    }
-                    EditPredictionProvider::Codestral => {
-                        menu.entry("Codestral", None, move |_, cx| {
-                            set_completion_provider(fs.clone(), cx, provider);
-                        })
-                    }
+                    EditPredictionProvider::Copilot => menu.item(
+                        ContextMenuEntry::new("GitHub Copilot")
+                            .toggleable(IconPosition::Start, is_current)
+                            .handler(move |_, cx| {
+                                set_completion_provider(fs.clone(), cx, provider);
+                            }),
+                    ),
+                    EditPredictionProvider::Supermaven => menu.item(
+                        ContextMenuEntry::new("Supermaven")
+                            .toggleable(IconPosition::Start, is_current)
+                            .handler(move |_, cx| {
+                                set_completion_provider(fs.clone(), cx, provider);
+                            }),
+                    ),
+                    EditPredictionProvider::Codestral => menu.item(
+                        ContextMenuEntry::new("Codestral")
+                            .toggleable(IconPosition::Start, is_current)
+                            .handler(move |_, cx| {
+                                set_completion_provider(fs.clone(), cx, provider);
+                            }),
+                    ),
                     EditPredictionProvider::Experimental(
                         EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
                     ) => {
                         let has_api_token = zeta::Zeta::try_global(cx)
                             .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
 
-                        let entry = ContextMenuEntry::new("Sweep")
-                            .when(!has_api_token, |this| {
-                                this.disabled(true).documentation_aside(
+                        let should_open_modal = !has_api_token || is_current;
+
+                        let entry = if has_api_token {
+                            ContextMenuEntry::new("Sweep")
+                                .toggleable(IconPosition::Start, is_current)
+                        } else {
+                            ContextMenuEntry::new("Sweep")
+                                .icon(IconName::XCircle)
+                                .icon_color(Color::Error)
+                                .documentation_aside(
                                     DocumentationSide::Left,
                                     DocumentationEdge::Bottom,
-                                    |_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(),
+                                    |_| {
+                                        Label::new("Click to configure your Sweep API token")
+                                            .into_any_element()
+                                    },
                                 )
-                            })
-                            .handler(move |_, cx| {
+                        };
+
+                        let entry = entry.handler(move |window, cx| {
+                            if should_open_modal {
+                                if let Some(workspace) = window.root::<Workspace>().flatten() {
+                                    workspace.update(cx, |workspace, cx| {
+                                        workspace.toggle_modal(window, cx, |window, cx| {
+                                            SweepApiKeyModal::new(window, cx)
+                                        });
+                                    });
+                                };
+                            } else {
                                 set_completion_provider(fs.clone(), cx, provider);
-                            });
+                            }
+                        });
 
                         menu.item(entry)
                     }
                     EditPredictionProvider::Experimental(
                         EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
-                    ) => menu.entry("Zeta2", None, move |_, cx| {
-                        set_completion_provider(fs.clone(), cx, provider);
-                    }),
+                    ) => menu.item(
+                        ContextMenuEntry::new("Zeta2")
+                            .toggleable(IconPosition::Start, is_current)
+                            .handler(move |_, cx| {
+                                set_completion_provider(fs.clone(), cx, provider);
+                            }),
+                    ),
                     EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
                         continue;
                     }

crates/edit_prediction_button/src/sweep_api_token_modal.rs 🔗

@@ -0,0 +1,84 @@
+use gpui::{
+    DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
+};
+use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
+use ui_input::InputField;
+use workspace::ModalView;
+use zeta::Zeta;
+
+pub struct SweepApiKeyModal {
+    api_key_input: Entity<InputField>,
+    focus_handle: FocusHandle,
+}
+
+impl SweepApiKeyModal {
+    pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
+        let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
+
+        Self {
+            api_key_input,
+            focus_handle: cx.focus_handle(),
+        }
+    }
+
+    fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
+        cx.emit(DismissEvent);
+    }
+
+    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
+        let api_key = self.api_key_input.read(cx).text(cx);
+        let api_key = (!api_key.trim().is_empty()).then_some(api_key);
+
+        if let Some(zeta) = Zeta::try_global(cx) {
+            zeta.update(cx, |zeta, cx| {
+                zeta.sweep_ai
+                    .set_api_token(api_key, cx)
+                    .detach_and_log_err(cx);
+            });
+        }
+
+        cx.emit(DismissEvent);
+    }
+}
+
+impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
+
+impl ModalView for SweepApiKeyModal {}
+
+impl Focusable for SweepApiKeyModal {
+    fn focus_handle(&self, _cx: &App) -> FocusHandle {
+        self.focus_handle.clone()
+    }
+}
+
+impl Render for SweepApiKeyModal {
+    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+        v_flex()
+            .key_context("SweepApiKeyModal")
+            .on_action(cx.listener(Self::cancel))
+            .on_action(cx.listener(Self::confirm))
+            .elevation_2(cx)
+            .w(px(400.))
+            .p_4()
+            .gap_3()
+            .child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
+            .child(self.api_key_input.clone())
+            .child(
+                h_flex()
+                    .justify_end()
+                    .gap_2()
+                    .child(Button::new("cancel", "Cancel").on_click(cx.listener(
+                        |_, _, _window, cx| {
+                            cx.emit(DismissEvent);
+                        },
+                    )))
+                    .child(
+                        Button::new("save", "Save")
+                            .style(ButtonStyle::Filled)
+                            .on_click(cx.listener(|this, _, window, cx| {
+                                this.confirm(&menu::Confirm, window, cx);
+                            })),
+                    ),
+            )
+    }
+}

crates/zeta/Cargo.toml 🔗

@@ -23,9 +23,10 @@ buffer_diff.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
 cloud_zeta2_prompt.workspace = true
-copilot.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
+copilot.workspace = true
+credentials_provider.workspace = true
 db.workspace = true
 edit_prediction.workspace = true
 edit_prediction_context.workspace = true
@@ -43,12 +44,12 @@ lsp.workspace = true
 markdown.workspace = true
 menu.workspace = true
 open_ai.workspace = true
-pretty_assertions.workspace = true
 postage.workspace = true
+pretty_assertions.workspace = true
 project.workspace = true
 rand.workspace = true
-release_channel.workspace = true
 regex.workspace = true
+release_channel.workspace = true
 semver.workspace = true
 serde.workspace = true
 serde_json.workspace = true
@@ -60,8 +61,8 @@ telemetry.workspace = true
 telemetry_events.workspace = true
 theme.workspace = true
 thiserror.workspace = true
-util.workspace = true
 ui.workspace = true
+util.workspace = true
 uuid.workspace = true
 workspace.workspace = true
 worktree.workspace = true

crates/zeta/src/provider.rs 🔗

@@ -78,7 +78,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
     ) -> bool {
         let zeta = self.zeta.read(cx);
         if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
-            zeta.sweep_ai.api_token.is_some()
+            zeta.has_sweep_api_token()
         } else {
             true
         }

crates/zeta/src/sweep_ai.rs 🔗

@@ -1,6 +1,7 @@
-use anyhow::Result;
+use anyhow::{Context as _, Result};
 use cloud_llm_client::predict_edits_v3::Event;
-use futures::AsyncReadExt as _;
+use credentials_provider::CredentialsProvider;
+use futures::{AsyncReadExt as _, FutureExt, future::Shared};
 use gpui::{
     App, AppContext as _, Entity, Task,
     http_client::{self, AsyncBody, Method},
@@ -23,18 +24,23 @@ use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionRe
 const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 
 pub struct SweepAi {
-    pub api_token: Option<String>,
+    pub api_token: Shared<Task<Option<String>>>,
     pub debug_info: Arc<str>,
 }
 
 impl SweepAi {
     pub fn new(cx: &App) -> Self {
         SweepAi {
-            api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
+            api_token: load_api_token(cx).shared(),
             debug_info: debug_info(cx),
         }
     }
 
+    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
+        self.api_token = Task::ready(api_token.clone()).shared();
+        store_api_token_in_keychain(api_token, cx)
+    }
+
     pub fn request_prediction_with_sweep(
         &self,
         project: &Entity<Project>,
@@ -47,7 +53,7 @@ impl SweepAi {
         cx: &mut App,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         let debug_info = self.debug_info.clone();
-        let Some(api_token) = self.api_token.clone() else {
+        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
             return Task::ready(Ok(None));
         };
         let full_path: Arc<Path> = snapshot
@@ -260,6 +266,49 @@ impl SweepAi {
     }
 }
 
+pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
+pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
+
+pub fn load_api_token(cx: &App) -> Task<Option<String>> {
+    if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
+        .ok()
+        .filter(|value| !value.is_empty())
+    {
+        return Task::ready(Some(api_token));
+    }
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+    cx.spawn(async move |cx| {
+        let (_, credentials) = credentials_provider
+            .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
+            .await
+            .ok()??;
+        String::from_utf8(credentials).ok()
+    })
+}
+
+fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
+    let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+    cx.spawn(async move |cx| {
+        if let Some(api_token) = api_token {
+            credentials_provider
+                .write_credentials(
+                    SWEEP_CREDENTIALS_URL,
+                    SWEEP_CREDENTIALS_USERNAME,
+                    api_token.as_bytes(),
+                    cx,
+                )
+                .await
+                .context("Failed to save Sweep API token to system keychain")
+        } else {
+            credentials_provider
+                .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
+                .await
+                .context("Failed to delete Sweep API token from system keychain")
+        }
+    })
+}
+
 #[derive(Debug, Clone, Serialize)]
 struct AutocompleteRequest {
     pub debug_info: Arc<str>,

crates/zeta/src/zeta.rs 🔗

@@ -20,7 +20,7 @@ use edit_prediction_context::{
 };
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
 use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, StreamExt as _};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
 use gpui::{
     App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
     http_client::{self, AsyncBody, Method},
@@ -61,7 +61,7 @@ mod prediction;
 mod provider;
 mod rate_prediction_modal;
 pub mod retrieval_search;
-mod sweep_ai;
+pub mod sweep_ai;
 pub mod udiff;
 mod xml_edits;
 pub mod zeta1;
@@ -80,7 +80,7 @@ use crate::rate_prediction_modal::{
     NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
     ThumbsUpActivePrediction,
 };
-use crate::sweep_ai::SweepAi;
+pub use crate::sweep_ai::SweepAi;
 use crate::zeta1::request_prediction_with_zeta1;
 pub use provider::ZetaEditPredictionProvider;
 
@@ -193,7 +193,7 @@ pub struct Zeta {
     #[cfg(feature = "eval-support")]
     eval_cache: Option<Arc<dyn EvalCache>>,
     edit_prediction_model: ZetaEditPredictionModel,
-    sweep_ai: SweepAi,
+    pub sweep_ai: SweepAi,
     data_collection_choice: DataCollectionChoice,
     rejected_predictions: Vec<EditPredictionRejection>,
     reject_predictions_tx: mpsc::UnboundedSender<()>,
@@ -553,7 +553,12 @@ impl Zeta {
     }
 
     pub fn has_sweep_api_token(&self) -> bool {
-        self.sweep_ai.api_token.is_some()
+        self.sweep_ai
+            .api_token
+            .clone()
+            .now_or_never()
+            .flatten()
+            .is_some()
     }
 
     #[cfg(feature = "eval-support")]