From ff6bd7d82efe3b2539cfcfb3dea3b030313c2e94 Mon Sep 17 00:00:00 2001 From: Ben Kunkle Date: Mon, 1 Dec 2025 11:36:49 -0800 Subject: [PATCH] sweep: Add UI for setting Sweep API token in system keychain (#43502) Closes #ISSUE Release Notes: - N/A *or* Added/Fixed/Improved ... --- Cargo.lock | 3 + crates/edit_prediction_button/Cargo.toml | 2 + .../src/edit_prediction_button.rs | 131 ++++++++++++------ .../src/sweep_api_token_modal.rs | 84 +++++++++++ crates/zeta/Cargo.toml | 9 +- crates/zeta/src/provider.rs | 2 +- crates/zeta/src/sweep_ai.rs | 59 +++++++- crates/zeta/src/zeta.rs | 15 +- 8 files changed, 246 insertions(+), 59 deletions(-) create mode 100644 crates/edit_prediction_button/src/sweep_api_token_modal.rs diff --git a/Cargo.lock b/Cargo.lock index 1d891e7a066d0eb6ec3c79e65f291e0bdf93961f..99c8bb19e8c45dd60f36b4234b275ff80ee43f16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5299,6 +5299,7 @@ dependencies = [ "indoc", "language", "lsp", + "menu", "paths", "project", "regex", @@ -5308,6 +5309,7 @@ dependencies = [ "telemetry", "theme", "ui", + "ui_input", "util", "workspace", "zed_actions", @@ -21678,6 +21680,7 @@ dependencies = [ "collections", "command_palette_hooks", "copilot", + "credentials_provider", "ctor", "db", "edit_prediction", diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index b7ec07e1e2b24d1d1b851913195afdbf58376da5..d336cf66926d37ab7c0ebb1d5aa5a2172342350c 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -32,6 +32,8 @@ settings.workspace = true supermaven.workspace = true telemetry.workspace = true ui.workspace = true +ui_input.workspace = true +menu.workspace = true util.workspace = true workspace.workspace = true zed_actions.workspace = true diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index e10cc13c3a9b97fda9c6b79a55878d0fcc960934..8ce8441859b7cc747a2b566dedd913e58259969d 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -1,3 +1,7 @@ +mod sweep_api_token_modal; + +pub use sweep_api_token_modal::SweepApiKeyModal; + use anyhow::Result; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; @@ -40,8 +44,7 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zeta::RateCompletions; -use zeta::{SweepFeatureFlag, Zeta2FeatureFlag}; +use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag}; actions!( edit_prediction, @@ -313,6 +316,10 @@ impl Render for EditPredictionButton { ) ); + let sweep_missing_token = is_sweep + && !zeta::Zeta::try_global(cx) + .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); + let zeta_icon = match (is_sweep, enabled) { (true, _) => IconName::SweepAi, (false, true) => IconName::ZedPredict, @@ -360,19 +367,24 @@ impl Render for EditPredictionButton { let show_editor_predictions = self.editor_show_predictions; let user = self.user_store.read(cx).current_user(); + let indicator_color = if sweep_missing_token { + Some(Color::Error) + } else if enabled && (!show_editor_predictions || over_limit) { + Some(if over_limit { + Color::Error + } else { + Color::Muted + }) + } else { + None + }; + let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon) .shape(IconButtonShape::Square) - .when( - enabled && (!show_editor_predictions || over_limit), - |this| { - this.indicator(Indicator::dot().when_else( - over_limit, - |dot| dot.color(Color::Error), - |dot| dot.color(Color::Muted), - )) + .when_some(indicator_color, |this, color| { + this.indicator(Indicator::dot().color(color)) .indicator_border_color(Some(cx.theme().colors().status_bar_background)) - }, - ) + }) .when(!self.popover_menu_handle.is_deployed(), |element| { let user = user.clone(); element.tooltip(move |_window, cx| { @@ -537,23 +549,23 @@ impl EditPredictionButton { const ZED_AI_CALLOUT: &str = "Zed's edit prediction is powered by Zeta, an open-source, dataset mode."; - const USE_SWEEP_API_TOKEN_CALLOUT: &str = - "Set the SWEEP_API_TOKEN environment variable to use Sweep"; - let other_providers: Vec<_> = available_providers + let providers: Vec<_> = available_providers .into_iter() - .filter(|p| *p != current_provider && *p != EditPredictionProvider::None) + .filter(|p| *p != EditPredictionProvider::None) .collect(); - if !other_providers.is_empty() { - menu = menu.separator().header("Switch Providers"); + if !providers.is_empty() { + menu = menu.separator().header("Providers"); - for provider in other_providers { + for provider in providers { + let is_current = provider == current_provider; let fs = self.fs.clone(); menu = match provider { EditPredictionProvider::Zed => menu.item( ContextMenuEntry::new("Zed AI") + .toggleable(IconPosition::Start, is_current) .documentation_aside( DocumentationSide::Left, DocumentationEdge::Bottom, @@ -563,46 +575,77 @@ impl EditPredictionButton { set_completion_provider(fs.clone(), cx, provider); }), ), - EditPredictionProvider::Copilot => { - menu.entry("GitHub Copilot", None, move |_, cx| { - set_completion_provider(fs.clone(), cx, provider); - }) - } - EditPredictionProvider::Supermaven => { - menu.entry("Supermaven", None, move |_, cx| { - set_completion_provider(fs.clone(), cx, provider); - }) - } - EditPredictionProvider::Codestral => { - menu.entry("Codestral", None, move |_, cx| { - set_completion_provider(fs.clone(), cx, provider); - }) - } + EditPredictionProvider::Copilot => menu.item( + ContextMenuEntry::new("GitHub Copilot") + .toggleable(IconPosition::Start, is_current) + .handler(move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }), + ), + EditPredictionProvider::Supermaven => menu.item( + ContextMenuEntry::new("Supermaven") + .toggleable(IconPosition::Start, is_current) + .handler(move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }), + ), + EditPredictionProvider::Codestral => menu.item( + ContextMenuEntry::new("Codestral") + .toggleable(IconPosition::Start, is_current) + .handler(move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }), + ), EditPredictionProvider::Experimental( EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) => { let has_api_token = zeta::Zeta::try_global(cx) .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); - let entry = ContextMenuEntry::new("Sweep") - .when(!has_api_token, |this| { - this.disabled(true).documentation_aside( + let should_open_modal = !has_api_token || is_current; + + let entry = if has_api_token { + ContextMenuEntry::new("Sweep") + .toggleable(IconPosition::Start, is_current) + } else { + ContextMenuEntry::new("Sweep") + .icon(IconName::XCircle) + .icon_color(Color::Error) + .documentation_aside( DocumentationSide::Left, DocumentationEdge::Bottom, - |_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(), + |_| { + Label::new("Click to configure your Sweep API token") + .into_any_element() + }, ) - }) - .handler(move |_, cx| { + }; + + let entry = entry.handler(move |window, cx| { + if should_open_modal { + if let Some(workspace) = window.root::().flatten() { + workspace.update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + SweepApiKeyModal::new(window, cx) + }); + }); + }; + } else { set_completion_provider(fs.clone(), cx, provider); - }); + } + }); menu.item(entry) } EditPredictionProvider::Experimental( EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, - ) => menu.entry("Zeta2", None, move |_, cx| { - set_completion_provider(fs.clone(), cx, provider); - }), + ) => menu.item( + ContextMenuEntry::new("Zeta2") + .toggleable(IconPosition::Start, is_current) + .handler(move |_, cx| { + set_completion_provider(fs.clone(), cx, provider); + }), + ), EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => { continue; } diff --git a/crates/edit_prediction_button/src/sweep_api_token_modal.rs b/crates/edit_prediction_button/src/sweep_api_token_modal.rs new file mode 100644 index 0000000000000000000000000000000000000000..ab2102f25a2a7291644ca67ab3c89fd47da7ac0a --- /dev/null +++ b/crates/edit_prediction_button/src/sweep_api_token_modal.rs @@ -0,0 +1,84 @@ +use gpui::{ + DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render, +}; +use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*}; +use ui_input::InputField; +use workspace::ModalView; +use zeta::Zeta; + +pub struct SweepApiKeyModal { + api_key_input: Entity, + focus_handle: FocusHandle, +} + +impl SweepApiKeyModal { + pub fn new(window: &mut Window, cx: &mut Context) -> Self { + let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token")); + + Self { + api_key_input, + focus_handle: cx.focus_handle(), + } + } + + fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } + + fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_input.read(cx).text(cx); + let api_key = (!api_key.trim().is_empty()).then_some(api_key); + + if let Some(zeta) = Zeta::try_global(cx) { + zeta.update(cx, |zeta, cx| { + zeta.sweep_ai + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }); + } + + cx.emit(DismissEvent); + } +} + +impl EventEmitter for SweepApiKeyModal {} + +impl ModalView for SweepApiKeyModal {} + +impl Focusable for SweepApiKeyModal { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for SweepApiKeyModal { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + v_flex() + .key_context("SweepApiKeyModal") + .on_action(cx.listener(Self::cancel)) + .on_action(cx.listener(Self::confirm)) + .elevation_2(cx) + .w(px(400.)) + .p_4() + .gap_3() + .child(Headline::new("Sweep API Token").size(HeadlineSize::Small)) + .child(self.api_key_input.clone()) + .child( + h_flex() + .justify_end() + .gap_2() + .child(Button::new("cancel", "Cancel").on_click(cx.listener( + |_, _, _window, cx| { + cx.emit(DismissEvent); + }, + ))) + .child( + Button::new("save", "Save") + .style(ButtonStyle::Filled) + .on_click(cx.listener(|this, _, window, cx| { + this.confirm(&menu::Confirm, window, cx); + })), + ), + ) + } +} diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index 61eeab16229d82dc01d800f37bf729aa11469afd..7429fcb8e8d5e4b485f69ea87c37d7d670c3b199 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -23,9 +23,10 @@ buffer_diff.workspace = true client.workspace = true cloud_llm_client.workspace = true cloud_zeta2_prompt.workspace = true -copilot.workspace = true collections.workspace = true command_palette_hooks.workspace = true +copilot.workspace = true +credentials_provider.workspace = true db.workspace = true edit_prediction.workspace = true edit_prediction_context.workspace = true @@ -43,12 +44,12 @@ lsp.workspace = true markdown.workspace = true menu.workspace = true open_ai.workspace = true -pretty_assertions.workspace = true postage.workspace = true +pretty_assertions.workspace = true project.workspace = true rand.workspace = true -release_channel.workspace = true regex.workspace = true +release_channel.workspace = true semver.workspace = true serde.workspace = true serde_json.workspace = true @@ -60,8 +61,8 @@ telemetry.workspace = true telemetry_events.workspace = true theme.workspace = true thiserror.workspace = true -util.workspace = true ui.workspace = true +util.workspace = true uuid.workspace = true workspace.workspace = true worktree.workspace = true diff --git a/crates/zeta/src/provider.rs b/crates/zeta/src/provider.rs index b91df0963386543fbd1e8645e5893a35fe202cc5..5a2117397b7dd94d1fd61c4fb9880ebe447dbc1f 100644 --- a/crates/zeta/src/provider.rs +++ b/crates/zeta/src/provider.rs @@ -78,7 +78,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { ) -> bool { let zeta = self.zeta.read(cx); if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep { - zeta.sweep_ai.api_token.is_some() + zeta.has_sweep_api_token() } else { true } diff --git a/crates/zeta/src/sweep_ai.rs b/crates/zeta/src/sweep_ai.rs index 427051a6b82d5dd3700b4970d039bfba8d6563c3..8fd5398f3facc807d99951c48c749e9247fb5670 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/zeta/src/sweep_ai.rs @@ -1,6 +1,7 @@ -use anyhow::Result; +use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::Event; -use futures::AsyncReadExt as _; +use credentials_provider::CredentialsProvider; +use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ App, AppContext as _, Entity, Task, http_client::{self, AsyncBody, Method}, @@ -23,18 +24,23 @@ use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionRe const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; pub struct SweepAi { - pub api_token: Option, + pub api_token: Shared>>, pub debug_info: Arc, } impl SweepAi { pub fn new(cx: &App) -> Self { SweepAi { - api_token: std::env::var("SWEEP_AI_TOKEN").ok(), + api_token: load_api_token(cx).shared(), debug_info: debug_info(cx), } } + pub fn set_api_token(&mut self, api_token: Option, cx: &mut App) -> Task> { + self.api_token = Task::ready(api_token.clone()).shared(); + store_api_token_in_keychain(api_token, cx) + } + pub fn request_prediction_with_sweep( &self, project: &Entity, @@ -47,7 +53,7 @@ impl SweepAi { cx: &mut App, ) -> Task>> { let debug_info = self.debug_info.clone(); - let Some(api_token) = self.api_token.clone() else { + let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { return Task::ready(Ok(None)); }; let full_path: Arc = snapshot @@ -260,6 +266,49 @@ impl SweepAi { } } +pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev"; +pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token"; + +pub fn load_api_token(cx: &App) -> Task> { + if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN") + .ok() + .filter(|value| !value.is_empty()) + { + return Task::ready(Some(api_token)); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |cx| { + let (_, credentials) = credentials_provider + .read_credentials(SWEEP_CREDENTIALS_URL, &cx) + .await + .ok()??; + String::from_utf8(credentials).ok() + }) +} + +fn store_api_token_in_keychain(api_token: Option, cx: &App) -> Task> { + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + if let Some(api_token) = api_token { + credentials_provider + .write_credentials( + SWEEP_CREDENTIALS_URL, + SWEEP_CREDENTIALS_USERNAME, + api_token.as_bytes(), + cx, + ) + .await + .context("Failed to save Sweep API token to system keychain") + } else { + credentials_provider + .delete_credentials(SWEEP_CREDENTIALS_URL, cx) + .await + .context("Failed to delete Sweep API token from system keychain") + } + }) +} + #[derive(Debug, Clone, Serialize)] struct AutocompleteRequest { pub debug_info: Arc, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 8fda34133343e465b1b56835b116770b856cfe36..1f0edf99e59c7efca69f6f0ac3b3c9169c33b373 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -20,7 +20,7 @@ use edit_prediction_context::{ }; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; use futures::channel::{mpsc, oneshot}; -use futures::{AsyncReadExt as _, StreamExt as _}; +use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _}; use gpui::{ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, http_client::{self, AsyncBody, Method}, @@ -61,7 +61,7 @@ mod prediction; mod provider; mod rate_prediction_modal; pub mod retrieval_search; -mod sweep_ai; +pub mod sweep_ai; pub mod udiff; mod xml_edits; pub mod zeta1; @@ -80,7 +80,7 @@ use crate::rate_prediction_modal::{ NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction, ThumbsUpActivePrediction, }; -use crate::sweep_ai::SweepAi; +pub use crate::sweep_ai::SweepAi; use crate::zeta1::request_prediction_with_zeta1; pub use provider::ZetaEditPredictionProvider; @@ -193,7 +193,7 @@ pub struct Zeta { #[cfg(feature = "eval-support")] eval_cache: Option>, edit_prediction_model: ZetaEditPredictionModel, - sweep_ai: SweepAi, + pub sweep_ai: SweepAi, data_collection_choice: DataCollectionChoice, rejected_predictions: Vec, reject_predictions_tx: mpsc::UnboundedSender<()>, @@ -553,7 +553,12 @@ impl Zeta { } pub fn has_sweep_api_token(&self) -> bool { - self.sweep_ai.api_token.is_some() + self.sweep_ai + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() } #[cfg(feature = "eval-support")]