Detailed changes
@@ -5299,6 +5299,7 @@ dependencies = [
"indoc",
"language",
"lsp",
+ "menu",
"paths",
"project",
"regex",
@@ -5308,6 +5309,7 @@ dependencies = [
"telemetry",
"theme",
"ui",
+ "ui_input",
"util",
"workspace",
"zed_actions",
@@ -21678,6 +21680,7 @@ dependencies = [
"collections",
"command_palette_hooks",
"copilot",
+ "credentials_provider",
"ctor",
"db",
"edit_prediction",
@@ -32,6 +32,8 @@ settings.workspace = true
supermaven.workspace = true
telemetry.workspace = true
ui.workspace = true
+ui_input.workspace = true
+menu.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
@@ -1,3 +1,7 @@
+mod sweep_api_token_modal;
+
+pub use sweep_api_token_modal::SweepApiKeyModal;
+
use anyhow::Result;
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
@@ -40,8 +44,7 @@ use workspace::{
notifications::NotificationId,
};
use zed_actions::OpenBrowser;
-use zeta::RateCompletions;
-use zeta::{SweepFeatureFlag, Zeta2FeatureFlag};
+use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
actions!(
edit_prediction,
@@ -313,6 +316,10 @@ impl Render for EditPredictionButton {
)
);
+ let sweep_missing_token = is_sweep
+ && !zeta::Zeta::try_global(cx)
+ .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+
let zeta_icon = match (is_sweep, enabled) {
(true, _) => IconName::SweepAi,
(false, true) => IconName::ZedPredict,
@@ -360,19 +367,24 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
+ let indicator_color = if sweep_missing_token {
+ Some(Color::Error)
+ } else if enabled && (!show_editor_predictions || over_limit) {
+ Some(if over_limit {
+ Color::Error
+ } else {
+ Color::Muted
+ })
+ } else {
+ None
+ };
+
let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
.shape(IconButtonShape::Square)
- .when(
- enabled && (!show_editor_predictions || over_limit),
- |this| {
- this.indicator(Indicator::dot().when_else(
- over_limit,
- |dot| dot.color(Color::Error),
- |dot| dot.color(Color::Muted),
- ))
+ .when_some(indicator_color, |this, color| {
+ this.indicator(Indicator::dot().color(color))
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
- },
- )
+ })
.when(!self.popover_menu_handle.is_deployed(), |element| {
let user = user.clone();
element.tooltip(move |_window, cx| {
@@ -537,23 +549,23 @@ impl EditPredictionButton {
const ZED_AI_CALLOUT: &str =
"Zed's edit prediction is powered by Zeta, an open-source, dataset mode.";
- const USE_SWEEP_API_TOKEN_CALLOUT: &str =
- "Set the SWEEP_API_TOKEN environment variable to use Sweep";
- let other_providers: Vec<_> = available_providers
+ let providers: Vec<_> = available_providers
.into_iter()
- .filter(|p| *p != current_provider && *p != EditPredictionProvider::None)
+ .filter(|p| *p != EditPredictionProvider::None)
.collect();
- if !other_providers.is_empty() {
- menu = menu.separator().header("Switch Providers");
+ if !providers.is_empty() {
+ menu = menu.separator().header("Providers");
- for provider in other_providers {
+ for provider in providers {
+ let is_current = provider == current_provider;
let fs = self.fs.clone();
menu = match provider {
EditPredictionProvider::Zed => menu.item(
ContextMenuEntry::new("Zed AI")
+ .toggleable(IconPosition::Start, is_current)
.documentation_aside(
DocumentationSide::Left,
DocumentationEdge::Bottom,
@@ -563,46 +575,77 @@ impl EditPredictionButton {
set_completion_provider(fs.clone(), cx, provider);
}),
),
- EditPredictionProvider::Copilot => {
- menu.entry("GitHub Copilot", None, move |_, cx| {
- set_completion_provider(fs.clone(), cx, provider);
- })
- }
- EditPredictionProvider::Supermaven => {
- menu.entry("Supermaven", None, move |_, cx| {
- set_completion_provider(fs.clone(), cx, provider);
- })
- }
- EditPredictionProvider::Codestral => {
- menu.entry("Codestral", None, move |_, cx| {
- set_completion_provider(fs.clone(), cx, provider);
- })
- }
+ EditPredictionProvider::Copilot => menu.item(
+ ContextMenuEntry::new("GitHub Copilot")
+ .toggleable(IconPosition::Start, is_current)
+ .handler(move |_, cx| {
+ set_completion_provider(fs.clone(), cx, provider);
+ }),
+ ),
+ EditPredictionProvider::Supermaven => menu.item(
+ ContextMenuEntry::new("Supermaven")
+ .toggleable(IconPosition::Start, is_current)
+ .handler(move |_, cx| {
+ set_completion_provider(fs.clone(), cx, provider);
+ }),
+ ),
+ EditPredictionProvider::Codestral => menu.item(
+ ContextMenuEntry::new("Codestral")
+ .toggleable(IconPosition::Start, is_current)
+ .handler(move |_, cx| {
+ set_completion_provider(fs.clone(), cx, provider);
+ }),
+ ),
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
let has_api_token = zeta::Zeta::try_global(cx)
.map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
- let entry = ContextMenuEntry::new("Sweep")
- .when(!has_api_token, |this| {
- this.disabled(true).documentation_aside(
+ let should_open_modal = !has_api_token || is_current;
+
+ let entry = if has_api_token {
+ ContextMenuEntry::new("Sweep")
+ .toggleable(IconPosition::Start, is_current)
+ } else {
+ ContextMenuEntry::new("Sweep")
+ .icon(IconName::XCircle)
+ .icon_color(Color::Error)
+ .documentation_aside(
DocumentationSide::Left,
DocumentationEdge::Bottom,
- |_| Label::new(USE_SWEEP_API_TOKEN_CALLOUT).into_any_element(),
+ |_| {
+ Label::new("Click to configure your Sweep API token")
+ .into_any_element()
+ },
)
- })
- .handler(move |_, cx| {
+ };
+
+ let entry = entry.handler(move |window, cx| {
+ if should_open_modal {
+ if let Some(workspace) = window.root::<Workspace>().flatten() {
+ workspace.update(cx, |workspace, cx| {
+ workspace.toggle_modal(window, cx, |window, cx| {
+ SweepApiKeyModal::new(window, cx)
+ });
+ });
+ };
+ } else {
set_completion_provider(fs.clone(), cx, provider);
- });
+ }
+ });
menu.item(entry)
}
EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
- ) => menu.entry("Zeta2", None, move |_, cx| {
- set_completion_provider(fs.clone(), cx, provider);
- }),
+ ) => menu.item(
+ ContextMenuEntry::new("Zeta2")
+ .toggleable(IconPosition::Start, is_current)
+ .handler(move |_, cx| {
+ set_completion_provider(fs.clone(), cx, provider);
+ }),
+ ),
EditPredictionProvider::None | EditPredictionProvider::Experimental(_) => {
continue;
}
@@ -0,0 +1,84 @@
+use gpui::{
+ DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
+};
+use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
+use ui_input::InputField;
+use workspace::ModalView;
+use zeta::Zeta;
+
+pub struct SweepApiKeyModal {
+ api_key_input: Entity<InputField>,
+ focus_handle: FocusHandle,
+}
+
+impl SweepApiKeyModal {
+ pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
+
+ Self {
+ api_key_input,
+ focus_handle: cx.focus_handle(),
+ }
+ }
+
+ fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
+ cx.emit(DismissEvent);
+ }
+
+ fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_input.read(cx).text(cx);
+ let api_key = (!api_key.trim().is_empty()).then_some(api_key);
+
+ if let Some(zeta) = Zeta::try_global(cx) {
+ zeta.update(cx, |zeta, cx| {
+ zeta.sweep_ai
+ .set_api_token(api_key, cx)
+ .detach_and_log_err(cx);
+ });
+ }
+
+ cx.emit(DismissEvent);
+ }
+}
+
+impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
+
+impl ModalView for SweepApiKeyModal {}
+
+impl Focusable for SweepApiKeyModal {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+}
+
+impl Render for SweepApiKeyModal {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ v_flex()
+ .key_context("SweepApiKeyModal")
+ .on_action(cx.listener(Self::cancel))
+ .on_action(cx.listener(Self::confirm))
+ .elevation_2(cx)
+ .w(px(400.))
+ .p_4()
+ .gap_3()
+ .child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
+ .child(self.api_key_input.clone())
+ .child(
+ h_flex()
+ .justify_end()
+ .gap_2()
+ .child(Button::new("cancel", "Cancel").on_click(cx.listener(
+ |_, _, _window, cx| {
+ cx.emit(DismissEvent);
+ },
+ )))
+ .child(
+ Button::new("save", "Save")
+ .style(ButtonStyle::Filled)
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.confirm(&menu::Confirm, window, cx);
+ })),
+ ),
+ )
+ }
+}
@@ -23,9 +23,10 @@ buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
-copilot.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
+copilot.workspace = true
+credentials_provider.workspace = true
db.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
@@ -43,12 +44,12 @@ lsp.workspace = true
markdown.workspace = true
menu.workspace = true
open_ai.workspace = true
-pretty_assertions.workspace = true
postage.workspace = true
+pretty_assertions.workspace = true
project.workspace = true
rand.workspace = true
-release_channel.workspace = true
regex.workspace = true
+release_channel.workspace = true
semver.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -60,8 +61,8 @@ telemetry.workspace = true
telemetry_events.workspace = true
theme.workspace = true
thiserror.workspace = true
-util.workspace = true
ui.workspace = true
+util.workspace = true
uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
@@ -78,7 +78,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
) -> bool {
let zeta = self.zeta.read(cx);
if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
- zeta.sweep_ai.api_token.is_some()
+ zeta.has_sweep_api_token()
} else {
true
}
@@ -1,6 +1,7 @@
-use anyhow::Result;
+use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
-use futures::AsyncReadExt as _;
+use credentials_provider::CredentialsProvider;
+use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
http_client::{self, AsyncBody, Method},
@@ -23,18 +24,23 @@ use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionRe
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
pub struct SweepAi {
- pub api_token: Option<String>,
+ pub api_token: Shared<Task<Option<String>>>,
pub debug_info: Arc<str>,
}
impl SweepAi {
pub fn new(cx: &App) -> Self {
SweepAi {
- api_token: std::env::var("SWEEP_AI_TOKEN").ok(),
+ api_token: load_api_token(cx).shared(),
debug_info: debug_info(cx),
}
}
+ pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
+ self.api_token = Task::ready(api_token.clone()).shared();
+ store_api_token_in_keychain(api_token, cx)
+ }
+
pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
@@ -47,7 +53,7 @@ impl SweepAi {
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
- let Some(api_token) = self.api_token.clone() else {
+ let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
@@ -260,6 +266,49 @@ impl SweepAi {
}
}
+pub const SWEEP_CREDENTIALS_URL: &str = "https://autocomplete.sweep.dev";
+pub const SWEEP_CREDENTIALS_USERNAME: &str = "sweep-api-token";
+
+pub fn load_api_token(cx: &App) -> Task<Option<String>> {
+ if let Some(api_token) = std::env::var("SWEEP_AI_TOKEN")
+ .ok()
+ .filter(|value| !value.is_empty())
+ {
+ return Task::ready(Some(api_token));
+ }
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ cx.spawn(async move |cx| {
+ let (_, credentials) = credentials_provider
+ .read_credentials(SWEEP_CREDENTIALS_URL, &cx)
+ .await
+ .ok()??;
+ String::from_utf8(credentials).ok()
+ })
+}
+
+fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ if let Some(api_token) = api_token {
+ credentials_provider
+ .write_credentials(
+ SWEEP_CREDENTIALS_URL,
+ SWEEP_CREDENTIALS_USERNAME,
+ api_token.as_bytes(),
+ cx,
+ )
+ .await
+ .context("Failed to save Sweep API token to system keychain")
+ } else {
+ credentials_provider
+ .delete_credentials(SWEEP_CREDENTIALS_URL, cx)
+ .await
+ .context("Failed to delete Sweep API token from system keychain")
+ }
+ })
+}
+
#[derive(Debug, Clone, Serialize)]
struct AutocompleteRequest {
pub debug_info: Arc<str>,
@@ -20,7 +20,7 @@ use edit_prediction_context::{
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, StreamExt as _};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
use gpui::{
App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
@@ -61,7 +61,7 @@ mod prediction;
mod provider;
mod rate_prediction_modal;
pub mod retrieval_search;
-mod sweep_ai;
+pub mod sweep_ai;
pub mod udiff;
mod xml_edits;
pub mod zeta1;
@@ -80,7 +80,7 @@ use crate::rate_prediction_modal::{
NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
ThumbsUpActivePrediction,
};
-use crate::sweep_ai::SweepAi;
+pub use crate::sweep_ai::SweepAi;
use crate::zeta1::request_prediction_with_zeta1;
pub use provider::ZetaEditPredictionProvider;
@@ -193,7 +193,7 @@ pub struct Zeta {
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: ZetaEditPredictionModel,
- sweep_ai: SweepAi,
+ pub sweep_ai: SweepAi,
data_collection_choice: DataCollectionChoice,
rejected_predictions: Vec<EditPredictionRejection>,
reject_predictions_tx: mpsc::UnboundedSender<()>,
@@ -553,7 +553,12 @@ impl Zeta {
}
pub fn has_sweep_api_token(&self) -> bool {
- self.sweep_ai.api_token.is_some()
+ self.sweep_ai
+ .api_token
+ .clone()
+ .now_or_never()
+ .flatten()
+ .is_some()
}
#[cfg(feature = "eval-support")]