Detailed changes
@@ -4021,7 +4021,7 @@ dependencies = [
"util",
"uuid",
"workspace",
- "zed_predict_tos",
+ "zed_predict_onboarding",
]
[[package]]
@@ -6415,7 +6415,7 @@ dependencies = [
"ui",
"workspace",
"zed_actions",
- "zed_predict_tos",
+ "zed_predict_onboarding",
"zeta",
]
@@ -13541,6 +13541,7 @@ dependencies = [
"windows 0.58.0",
"workspace",
"zed_actions",
+ "zed_predict_onboarding",
]
[[package]]
@@ -16557,7 +16558,7 @@ dependencies = [
"winresource",
"workspace",
"zed_actions",
- "zed_predict_tos",
+ "zed_predict_onboarding",
"zeta",
]
@@ -16672,13 +16673,21 @@ dependencies = [
]
[[package]]
-name = "zed_predict_tos"
+name = "zed_predict_onboarding"
version = "0.1.0"
dependencies = [
+ "chrono",
"client",
+ "db",
+ "feature_flags",
+ "fs",
"gpui",
+ "language",
"menu",
+ "settings",
+ "theme",
"ui",
+ "util",
"workspace",
]
@@ -16872,6 +16881,7 @@ dependencies = [
"collections",
"command_palette_hooks",
"ctor",
+ "db",
"editor",
"env_logger 0.11.6",
"feature_flags",
@@ -16886,6 +16896,7 @@ dependencies = [
"menu",
"reqwest_client",
"rpc",
+ "serde",
"serde_json",
"settings",
"similar",
@@ -152,7 +152,7 @@ members = [
"crates/worktree",
"crates/zed",
"crates/zed_actions",
- "crates/zed_predict_tos",
+ "crates/zed_predict_onboarding",
"crates/zeta",
#
@@ -201,7 +201,6 @@ edition = "2021"
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
-zed_predict_tos = { path = "crates/zed_predict_tos" }
anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
@@ -350,6 +349,7 @@ workspace = { path = "crates/workspace" }
worktree = { path = "crates/worktree" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
+zed_predict_onboarding = { path = "crates/zed_predict_onboarding" }
zeta = { path = "crates/zeta" }
#
@@ -1,5 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<path d="M7 8.9V11C5.34478 11 4.65522 11 3 11V10.4L7 5.6V5H3V7.1" stroke="black" stroke-width="1.5"/>
<path d="M12 5L14 8L12 11" stroke="black" stroke-width="1.5"/>
<path d="M10 6.5L11 8L10 9.5" stroke="black" stroke-width="1.5"/>
+<path d="M7.5 8.9V11C5.43097 11 4.56903 11 2.5 11V10.4L7.5 5.6V5H2.5V7.1" stroke="black" stroke-width="1.5"/>
</svg>
@@ -0,0 +1,19 @@
+<svg width="420" height="128" xmlns="http://www.w3.org/2000/svg">
+ <defs>
+ <pattern id="tilePattern" width="22" height="22" patternUnits="userSpaceOnUse">
+ <svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+ <path d="M12 5L14 8L12 11" stroke="black" stroke-width="1.5"/>
+ <path d="M10 6.5L11 8L10 9.5" stroke="black" stroke-width="1.5"/>
+ <path d="M7.5 8.9V11C5.43097 11 4.56903 11 2.5 11V10.4L7.5 5.6V5H2.5V7.1" stroke="black" stroke-width="1.5"/>
+ </svg>
+ </pattern>
+ <linearGradient id="fade" y2="1" x2="0">
+ <stop offset="0" stop-color="white" stop-opacity=".24"/>
+ <stop offset="1" stop-color="white" stop-opacity="0"/>
+ </linearGradient>
+ <mask id="fadeMask" maskContentUnits="objectBoundingBox">
+ <rect width="1" height="1" fill="url(#fade)"/>
+ </mask>
+ </defs>
+ <rect width="100%" height="100%" fill="url(#tilePattern)" mask="url(#fadeMask)"/>
+</svg>
@@ -823,5 +823,12 @@
"shift-end": "terminal::ScrollToBottom",
"ctrl-shift-space": "terminal::ToggleViMode"
}
+ },
+ {
+ "context": "ZedPredictModal",
+ "use_key_equivalents": true,
+ "bindings": {
+ "escape": "menu::Cancel"
+ }
}
]
@@ -883,7 +883,7 @@
}
},
{
- "context": "ZedPredictTos",
+ "context": "ZedPredictModal",
"use_key_equivalents": true,
"bindings": {
"escape": "menu::Cancel"
@@ -121,9 +121,7 @@ pub enum Event {
},
ShowContacts,
ParticipantIndicesChanged,
- TermsStatusUpdated {
- accepted: bool,
- },
+ PrivateUserInfoUpdated,
}
#[derive(Clone, Copy)]
@@ -227,9 +225,7 @@ impl UserStore {
};
this.set_current_user_accepted_tos_at(accepted_tos_at);
- cx.emit(Event::TermsStatusUpdated {
- accepted: accepted_tos_at.is_some(),
- });
+ cx.emit(Event::PrivateUserInfoUpdated);
})
} else {
anyhow::Ok(())
@@ -244,6 +240,8 @@ impl UserStore {
Status::SignedOut => {
current_user_tx.send(None).await.ok();
this.update(&mut cx, |this, cx| {
+ this.accepted_tos_at = None;
+ cx.emit(Event::PrivateUserInfoUpdated);
cx.notify();
this.clear_contacts()
})?
@@ -714,7 +712,7 @@ impl UserStore {
this.update(&mut cx, |this, cx| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at));
- cx.emit(Event::TermsStatusUpdated { accepted: true });
+ cx.emit(Event::PrivateUserInfoUpdated);
})
} else {
Err(anyhow!("client not found"))
@@ -447,7 +447,7 @@ async fn predict_edits(
));
}
- let sample_input_output = claims.is_staff && rand::random::<f32>() < 0.1;
+ let should_sample = claims.is_staff || params.can_collect_data;
let api_url = state
.config
@@ -541,7 +541,7 @@ async fn predict_edits(
let output = choice.text.clone();
async move {
- let properties = if sample_input_output {
+ let properties = if should_sample {
json!({
"model": model.to_string(),
"headers": response.headers,
@@ -88,7 +88,7 @@ url.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
-zed_predict_tos.workspace = true
+zed_predict_onboarding.workspace = true
[dev-dependencies]
ctor.workspace = true
@@ -652,7 +652,7 @@ impl CompletionsMenu {
)
.on_click(cx.listener(move |editor, _event, window, cx| {
cx.stop_propagation();
- editor.toggle_zed_predict_tos(window, cx);
+ editor.toggle_zed_predict_onboarding(window, cx);
})),
),
@@ -69,7 +69,7 @@ pub use element::{
};
use futures::{future, FutureExt};
use fuzzy::StringMatchCandidate;
-use zed_predict_tos::ZedPredictTos;
+use zed_predict_onboarding::ZedPredictModal;
use code_context_menus::{
AvailableCodeAction, CodeActionContents, CodeActionsItem, CodeActionsMenu, CodeContextMenu,
@@ -3948,12 +3948,21 @@ impl Editor {
self.do_completion(action.item_ix, CompletionIntent::Compose, window, cx)
}
- fn toggle_zed_predict_tos(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ fn toggle_zed_predict_onboarding(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let (Some(workspace), Some(project)) = (self.workspace(), self.project.as_ref()) else {
return;
};
- ZedPredictTos::toggle(workspace, project.read(cx).user_store().clone(), window, cx);
+ let project = project.read(cx);
+
+ ZedPredictModal::toggle(
+ workspace,
+ project.user_store().clone(),
+ project.client().clone(),
+ project.fs().clone(),
+ window,
+ cx,
+ );
}
fn do_completion(
@@ -3985,7 +3994,7 @@ impl Editor {
)) => {
drop(entries);
drop(context_menu);
- self.toggle_zed_predict_tos(window, cx);
+ self.toggle_zed_predict_onboarding(window, cx);
return Some(Task::ready(Ok(())));
}
_ => {}
@@ -87,8 +87,8 @@ define_connection!(
// mtime_seconds: Option<i64>,
// mtime_nanos: Option<i32>,
// )
- pub static ref DB: EditorDb<WorkspaceDb> =
- &[sql! (
+ pub static ref DB: EditorDb<WorkspaceDb> = &[
+ sql! (
CREATE TABLE editors(
item_id INTEGER NOT NULL,
workspace_id INTEGER NOT NULL,
@@ -134,7 +134,7 @@ define_connection!(
ALTER TABLE editors ADD COLUMN mtime_seconds INTEGER DEFAULT NULL;
ALTER TABLE editors ADD COLUMN mtime_nanos INTEGER DEFAULT NULL;
),
- ];
+ ];
);
impl EditorDb {
@@ -18,6 +18,31 @@ pub struct InlineCompletion {
pub edit_preview: Option<language::EditPreview>,
}
+pub enum DataCollectionState {
+ /// The provider doesn't support data collection.
+ Unsupported,
+ /// When there's a file not saved yet. In this case, we can't tell to which project it belongs.
+ Unknown,
+ /// Data collection is enabled
+ Enabled,
+ /// Data collection is disabled or unanswered.
+ Disabled,
+}
+
+impl DataCollectionState {
+ pub fn is_supported(&self) -> bool {
+ !matches!(self, DataCollectionState::Unsupported)
+ }
+
+ pub fn is_unknown(&self) -> bool {
+ matches!(self, DataCollectionState::Unknown)
+ }
+
+ pub fn is_enabled(&self) -> bool {
+ matches!(self, DataCollectionState::Enabled)
+ }
+}
+
pub trait InlineCompletionProvider: 'static + Sized {
fn name() -> &'static str;
fn display_name() -> &'static str;
@@ -26,6 +51,10 @@ pub trait InlineCompletionProvider: 'static + Sized {
fn show_tab_accept_marker() -> bool {
false
}
+ fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+ DataCollectionState::Unsupported
+ }
+ fn toggle_data_collection(&mut self, _cx: &mut App) {}
fn is_enabled(
&self,
buffer: &Entity<Buffer>,
@@ -72,6 +101,8 @@ pub trait InlineCompletionProviderHandle {
fn show_completions_in_menu(&self) -> bool;
fn show_completions_in_normal_mode(&self) -> bool;
fn show_tab_accept_marker(&self) -> bool;
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState;
+ fn toggle_data_collection(&self, cx: &mut App);
fn needs_terms_acceptance(&self, cx: &App) -> bool;
fn is_refreshing(&self, cx: &App) -> bool;
fn refresh(
@@ -122,6 +153,14 @@ where
T::show_tab_accept_marker()
}
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState {
+ self.read(cx).data_collection_state(cx)
+ }
+
+ fn toggle_data_collection(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.toggle_data_collection(cx))
+ }
+
fn is_enabled(
&self,
buffer: &Entity<Buffer>,
@@ -29,7 +29,7 @@ workspace.workspace = true
zed_actions.workspace = true
zeta.workspace = true
client.workspace = true
-zed_predict_tos.workspace = true
+zed_predict_onboarding.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -1,5 +1,5 @@
use anyhow::Result;
-use client::UserStore;
+use client::{Client, UserStore};
use copilot::{Copilot, Status};
use editor::{actions::ShowInlineCompletion, scroll::Autoscroll, Editor};
use feature_flags::{
@@ -20,18 +20,16 @@ use language::{
use settings::{update_settings_file, Settings, SettingsStore};
use std::{path::Path, sync::Arc, time::Duration};
use supermaven::{AccountStatus, Supermaven};
-use ui::{prelude::*, ButtonLike, Color, Icon, IconWithIndicator, Indicator, PopoverMenuHandle};
+use ui::{
+ prelude::*, ButtonLike, Clickable, ContextMenu, ContextMenuEntry, IconButton,
+ IconWithIndicator, Indicator, PopoverMenu, PopoverMenuHandle, Tooltip,
+};
use workspace::{
- create_and_open_local_file,
- item::ItemHandle,
- notifications::NotificationId,
- ui::{
- ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, PopoverMenu, Tooltip,
- },
- StatusItemView, Toast, Workspace,
+ create_and_open_local_file, item::ItemHandle, notifications::NotificationId, StatusItemView,
+ Toast, Workspace,
};
use zed_actions::OpenBrowser;
-use zed_predict_tos::ZedPredictTos;
+use zed_predict_onboarding::ZedPredictModal;
use zeta::RateCompletionModal;
actions!(zeta, [RateCompletions]);
@@ -48,6 +46,7 @@ pub struct InlineCompletionButton {
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
inline_completion_provider: Option<Arc<dyn inline_completion::InlineCompletionProviderHandle>>,
+ client: Arc<Client>,
fs: Arc<dyn Fs>,
workspace: WeakEntity<Workspace>,
user_store: Entity<UserStore>,
@@ -231,14 +230,16 @@ impl Render for InlineCompletionButton {
return div();
}
- if !self
- .user_store
- .read(cx)
- .current_user_has_accepted_terms()
- .unwrap_or(false)
- {
+ let current_user_terms_accepted =
+ self.user_store.read(cx).current_user_has_accepted_terms();
+
+ if !current_user_terms_accepted.unwrap_or(false) {
let workspace = self.workspace.clone();
let user_store = self.user_store.clone();
+ let client = self.client.clone();
+ let fs = self.fs.clone();
+
+ let signed_in = current_user_terms_accepted.is_some();
return div().child(
ButtonLike::new("zeta-pending-tos-icon")
@@ -252,20 +253,29 @@ impl Render for InlineCompletionButton {
))
.into_any_element(),
)
- .tooltip(|window, cx| {
+ .tooltip(move |window, cx| {
Tooltip::with_meta(
"Edit Predictions",
None,
- "Read Terms of Service",
+ if signed_in {
+ "Read Terms of Service"
+ } else {
+ "Sign in to use"
+ },
window,
cx,
)
})
.on_click(cx.listener(move |_, _, window, cx| {
- let user_store = user_store.clone();
-
if let Some(workspace) = workspace.upgrade() {
- ZedPredictTos::toggle(workspace, user_store, window, cx);
+ ZedPredictModal::toggle(
+ workspace,
+ user_store.clone(),
+ client.clone(),
+ fs.clone(),
+ window,
+ cx,
+ );
}
})),
);
@@ -318,6 +328,7 @@ impl InlineCompletionButton {
workspace: WeakEntity<Workspace>,
fs: Arc<dyn Fs>,
user_store: Entity<UserStore>,
+ client: Arc<Client>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
cx: &mut Context<Self>,
) -> Self {
@@ -337,6 +348,7 @@ impl InlineCompletionButton {
inline_completion_provider: None,
popover_menu_handle,
workspace,
+ client,
fs,
user_store,
}
@@ -430,6 +442,22 @@ impl InlineCompletionButton {
move |_, cx| toggle_inline_completions_globally(fs.clone(), cx),
);
+ if let Some(provider) = &self.inline_completion_provider {
+ let data_collection = provider.data_collection_state(cx);
+
+ if data_collection.is_supported() {
+ let provider = provider.clone();
+ menu = menu.separator().item(
+ ContextMenuEntry::new("Data Collection")
+ .toggleable(IconPosition::Start, data_collection.is_enabled())
+ .disabled(data_collection.is_unknown())
+ .handler(move |_, cx| {
+ provider.toggle_data_collection(cx);
+ }),
+ );
+ }
+ }
+
if let Some(editor_focus_handle) = self.editor_focus_handle.clone() {
menu = menu
.separator()
@@ -39,6 +39,9 @@ pub struct PredictEditsParams {
pub outline: Option<String>,
pub input_events: String,
pub input_excerpt: String,
+ /// Whether the user provided consent for sampling this interaction.
+ #[serde(default)]
+ pub can_collect_data: bool,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -48,6 +48,7 @@ telemetry.workspace = true
workspace.workspace = true
zed_actions.workspace = true
git_ui.workspace = true
+zed_predict_onboarding.workspace = true
[target.'cfg(windows)'.dependencies]
windows.workspace = true
@@ -37,6 +37,7 @@ use ui::{
use util::ResultExt;
use workspace::{notifications::NotifyResultExt, Workspace};
use zed_actions::{OpenBrowser, OpenRecent, OpenRemote};
+use zed_predict_onboarding::ZedPredictBanner;
#[cfg(feature = "stories")]
pub use stories::*;
@@ -113,6 +114,7 @@ pub struct TitleBar {
application_menu: Option<Entity<ApplicationMenu>>,
_subscriptions: Vec<Subscription>,
git_ui_enabled: Arc<AtomicBool>,
+ zed_predict_banner: Entity<ZedPredictBanner>,
}
impl Render for TitleBar {
@@ -196,6 +198,7 @@ impl Render for TitleBar {
.on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation()),
)
.child(self.render_collaborator_list(window, cx))
+ .child(self.zed_predict_banner.clone())
.child(
h_flex()
.gap_1()
@@ -271,6 +274,7 @@ impl TitleBar {
let project = workspace.project().clone();
let user_store = workspace.app_state().user_store.clone();
let client = workspace.app_state().client.clone();
+ let fs = workspace.app_state().fs.clone();
let active_call = ActiveCall::global(cx);
let platform_style = PlatformStyle::platform();
@@ -306,6 +310,16 @@ impl TitleBar {
}
}));
+ let zed_predict_banner = cx.new(|cx| {
+ ZedPredictBanner::new(
+ workspace.weak_handle(),
+ user_store.clone(),
+ client.clone(),
+ fs.clone(),
+ cx,
+ )
+ });
+
Self {
platform_style,
content: div().id(id.into()),
@@ -319,6 +333,7 @@ impl TitleBar {
client,
_subscriptions: subscriptions,
git_ui_enabled: is_git_ui_enabled,
+ zed_predict_banner,
}
}
@@ -64,6 +64,11 @@ impl ContextMenuEntry {
}
}
+ pub fn toggleable(mut self, toggle_position: IconPosition, toggled: bool) -> Self {
+ self.toggle = Some((toggle_position, toggled));
+ self
+ }
+
pub fn icon(mut self, icon: IconName) -> Self {
self.icon = Some(icon);
self
@@ -379,6 +379,12 @@ pub mod simple_message_notification {
click_message: Option<SharedString>,
secondary_click_message: Option<SharedString>,
secondary_on_click: Option<Arc<dyn Fn(&mut Window, &mut Context<Self>)>>,
+ tertiary_click_message: Option<SharedString>,
+ tertiary_on_click: Option<Arc<dyn Fn(&mut Window, &mut Context<Self>)>>,
+ more_info_message: Option<SharedString>,
+ more_info_url: Option<Arc<str>>,
+ show_close_button: bool,
+ title: Option<SharedString>,
}
impl EventEmitter<DismissEvent> for MessageNotification {}
@@ -402,6 +408,12 @@ pub mod simple_message_notification {
click_message: None,
secondary_on_click: None,
secondary_click_message: None,
+ tertiary_on_click: None,
+ tertiary_click_message: None,
+ more_info_message: None,
+ more_info_url: None,
+ show_close_button: true,
+ title: None,
}
}
@@ -437,31 +449,85 @@ pub mod simple_message_notification {
self
}
+ pub fn with_tertiary_click_message<S>(mut self, message: S) -> Self
+ where
+ S: Into<SharedString>,
+ {
+ self.tertiary_click_message = Some(message.into());
+ self
+ }
+
+ pub fn on_tertiary_click<F>(mut self, on_click: F) -> Self
+ where
+ F: 'static + Fn(&mut Window, &mut Context<Self>),
+ {
+ self.tertiary_on_click = Some(Arc::new(on_click));
+ self
+ }
+
+ pub fn more_info_message<S>(mut self, message: S) -> Self
+ where
+ S: Into<SharedString>,
+ {
+ self.more_info_message = Some(message.into());
+ self
+ }
+
+ pub fn more_info_url<S>(mut self, url: S) -> Self
+ where
+ S: Into<Arc<str>>,
+ {
+ self.more_info_url = Some(url.into());
+ self
+ }
+
pub fn dismiss(&mut self, cx: &mut Context<Self>) {
cx.emit(DismissEvent);
}
+
+ pub fn show_close_button(mut self, show: bool) -> Self {
+ self.show_close_button = show;
+ self
+ }
+
+ pub fn with_title<S>(mut self, title: S) -> Self
+ where
+ S: Into<SharedString>,
+ {
+ self.title = Some(title.into());
+ self
+ }
}
impl Render for MessageNotification {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.p_3()
- .gap_2()
+ .gap_3()
.elevation_3(cx)
.child(
h_flex()
.gap_4()
.justify_between()
.items_start()
- .child(div().max_w_96().child((self.build_content)(window, cx)))
.child(
- IconButton::new("close", IconName::Close)
- .on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))),
- ),
+ v_flex()
+ .gap_0p5()
+ .when_some(self.title.clone(), |element, title| {
+ element.child(Label::new(title))
+ })
+ .child(div().max_w_96().child((self.build_content)(window, cx))),
+ )
+ .when(self.show_close_button, |this| {
+ this.child(
+ IconButton::new("close", IconName::Close)
+ .on_click(cx.listener(|this, _, _, cx| this.dismiss(cx))),
+ )
+ }),
)
.child(
h_flex()
- .gap_2()
+ .gap_1()
.children(self.click_message.iter().map(|message| {
Button::new(message.clone(), message.clone())
.label_size(LabelSize::Small)
@@ -489,7 +555,40 @@ pub mod simple_message_notification {
};
this.dismiss(cx)
}))
- })),
+ }))
+ .child(
+ h_flex()
+ .w_full()
+ .gap_1()
+ .justify_end()
+ .children(self.tertiary_click_message.iter().map(|message| {
+ Button::new(message.clone(), message.clone())
+ .label_size(LabelSize::Small)
+ .on_click(cx.listener(|this, _, window, cx| {
+ if let Some(on_click) = this.tertiary_on_click.as_ref()
+ {
+ (on_click)(window, cx)
+ };
+ this.dismiss(cx)
+ }))
+ }))
+ .children(
+ self.more_info_message
+ .iter()
+ .zip(self.more_info_url.iter())
+ .map(|(message, url)| {
+ let url = url.clone();
+ Button::new(message.clone(), message.clone())
+ .label_size(LabelSize::Small)
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::Indicator)
+ .icon_color(Color::Muted)
+ .on_click(cx.listener(move |_, _, _, cx| {
+ cx.open_url(&url);
+ }))
+ }),
+ ),
+ ),
)
}
}
@@ -58,7 +58,9 @@ use persistence::{
SerializedWindowBounds, DB,
};
use postage::stream::Stream;
-use project::{DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree};
+use project::{
+ DirectoryLister, Project, ProjectEntryId, ProjectPath, ResolvedPath, Worktree, WorktreeId,
+};
use remote::{ssh_session::ConnectionIdentifier, SshClientDelegate, SshConnectionOptions};
use schemars::JsonSchema;
use serde::Deserialize;
@@ -2200,6 +2202,18 @@ impl Workspace {
}
}
+ pub fn absolute_path_of_worktree(
+ &self,
+ worktree_id: WorktreeId,
+ cx: &mut Context<Self>,
+ ) -> Option<PathBuf> {
+ self.project
+ .read(cx)
+ .worktree_for_id(worktree_id, cx)
+ // TODO: use `abs_path` or `root_dir`
+ .map(|wt| wt.read(cx).abs_path().as_ref().to_path_buf())
+ }
+
fn add_folder_to_project(
&mut self,
_: &AddFolderToProject,
@@ -2751,6 +2751,8 @@ impl Snapshot {
self.entry_for_path("")
}
+ /// TODO: what's the difference between `root_dir` and `abs_path`?
+ /// is there any? if so, document it.
pub fn root_dir(&self) -> Option<Arc<Path>> {
self.root_entry()
.filter(|entry| entry.is_dir())
@@ -16,7 +16,7 @@ path = "src/main.rs"
[dependencies]
activity_indicator.workspace = true
-zed_predict_tos.workspace = true
+zed_predict_onboarding.workspace = true
anyhow.workspace = true
assets.workspace = true
assistant.workspace = true
@@ -439,6 +439,7 @@ fn main() {
inline_completion_registry::init(
app_state.client.clone(),
app_state.user_store.clone(),
+ app_state.fs.clone(),
cx,
);
let prompt_builder = PromptBuilder::load(app_state.fs.clone(), stdout_is_a_pty(), cx);
@@ -176,6 +176,7 @@ pub fn initialize_workspace(
workspace.weak_handle(),
app_state.fs.clone(),
app_state.user_store.clone(),
+ app_state.client.clone(),
popover_menu_handle.clone(),
cx,
)
@@ -5,13 +5,17 @@ use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider};
use editor::{Editor, EditorMode};
use feature_flags::{FeatureFlagAppExt, PredictEditsFeatureFlag};
-use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity, Window};
+use fs::Fs;
+use gpui::{AnyWindowHandle, App, AppContext, Context, Entity, WeakEntity};
use language::language_settings::{all_language_settings, InlineCompletionProvider};
use settings::SettingsStore;
use supermaven::{Supermaven, SupermavenCompletionProvider};
-use zed_predict_tos::ZedPredictTos;
+use ui::Window;
+use workspace::Workspace;
+use zed_predict_onboarding::ZedPredictModal;
+use zeta::ProviderDataCollection;
-pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, fs: Arc<dyn Fs>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new({
let editors = editors.clone();
@@ -37,6 +41,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
})
.detach();
+
editors
.borrow_mut()
.insert(editor_handle, window.window_handle());
@@ -91,6 +96,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors = editors.clone();
let client = client.clone();
let user_store = user_store.clone();
+ let fs = fs.clone();
move |cx| {
let new_provider = all_language_settings(None, cx).inline_completions.provider;
if new_provider != provider {
@@ -123,9 +129,11 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
window
.update(cx, |_, window, cx| {
- ZedPredictTos::toggle(
+ ZedPredictModal::toggle(
workspace,
user_store.clone(),
+ client.clone(),
+ fs.clone(),
window,
cx,
);
@@ -214,17 +222,19 @@ fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut Context<Ed
fn assign_inline_completion_provider(
editor: &mut Editor,
- provider: language::language_settings::InlineCompletionProvider,
+ provider: InlineCompletionProvider,
client: &Arc<Client>,
user_store: Entity<UserStore>,
window: &mut Window,
cx: &mut Context<Editor>,
) {
+ let singleton_buffer = editor.buffer().read(cx).as_singleton();
+
match provider {
- language::language_settings::InlineCompletionProvider::None => {}
- language::language_settings::InlineCompletionProvider::Copilot => {
+ InlineCompletionProvider::None => {}
+ InlineCompletionProvider::Copilot => {
if let Some(copilot) = Copilot::global(cx) {
- if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
+ if let Some(buffer) = singleton_buffer {
if buffer.read(cx).file().is_some() {
copilot.update(cx, |copilot, cx| {
copilot.register_buffer(&buffer, cx);
@@ -235,26 +245,35 @@ fn assign_inline_completion_provider(
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}
- language::language_settings::InlineCompletionProvider::Supermaven => {
+ InlineCompletionProvider::Supermaven => {
if let Some(supermaven) = Supermaven::global(cx) {
let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven));
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}
-
- language::language_settings::InlineCompletionProvider::Zed => {
+ InlineCompletionProvider::Zed => {
if cx.has_flag::<PredictEditsFeatureFlag>()
|| (cfg!(debug_assertions) && client.status().borrow().is_connected())
{
let zeta = zeta::Zeta::register(client.clone(), user_store, cx);
- if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
+ if let Some(buffer) = &singleton_buffer {
if buffer.read(cx).file().is_some() {
zeta.update(cx, |zeta, cx| {
zeta.register_buffer(&buffer, cx);
});
}
}
- let provider = cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta));
+
+ let data_collection = ProviderDataCollection::new(
+ zeta.clone(),
+ window.root::<Workspace>().flatten(),
+ singleton_buffer,
+ cx,
+ );
+
+ let provider =
+ cx.new(|_| zeta::ZetaInlineCompletionProvider::new(zeta, data_collection));
+
editor.set_inline_completion_provider(Some(provider), window, cx);
}
}
@@ -1,5 +1,5 @@
[package]
-name = "zed_predict_tos"
+name = "zed_predict_onboarding"
version = "0.1.0"
edition = "2021"
publish = false
@@ -9,15 +9,23 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
-path = "src/zed_predict_tos.rs"
+path = "src/lib.rs"
doctest = false
[features]
test-support = []
[dependencies]
+chrono.workspace = true
client.workspace = true
+db.workspace = true
+feature_flags.workspace = true
+fs.workspace = true
gpui.workspace = true
+language.workspace = true
+menu.workspace = true
+settings.workspace = true
+theme.workspace = true
ui.workspace = true
+util.workspace = true
workspace.workspace = true
-menu.workspace = true
@@ -0,0 +1,168 @@
+use std::sync::Arc;
+
+use crate::ZedPredictModal;
+use chrono::Utc;
+use client::{Client, UserStore};
+use feature_flags::{FeatureFlagAppExt as _, PredictEditsFeatureFlag};
+use fs::Fs;
+use gpui::{Entity, Subscription, WeakEntity};
+use language::language_settings::{all_language_settings, InlineCompletionProvider};
+use settings::SettingsStore;
+use ui::{prelude::*, ButtonLike, Tooltip};
+use util::ResultExt;
+use workspace::Workspace;
+
+/// Prompts user to try AI inline prediction feature
+pub struct ZedPredictBanner {
+ workspace: WeakEntity<Workspace>,
+ user_store: Entity<UserStore>,
+ client: Arc<Client>,
+ fs: Arc<dyn Fs>,
+ dismissed: bool,
+ _subscription: Subscription,
+}
+
+impl ZedPredictBanner {
+ pub fn new(
+ workspace: WeakEntity<Workspace>,
+ user_store: Entity<UserStore>,
+ client: Arc<Client>,
+ fs: Arc<dyn Fs>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ Self {
+ workspace,
+ user_store,
+ client,
+ fs,
+ dismissed: get_dismissed(),
+ _subscription: cx.observe_global::<SettingsStore>(Self::handle_settings_changed),
+ }
+ }
+
+ fn should_show(&self, cx: &mut App) -> bool {
+ if !cx.has_flag::<PredictEditsFeatureFlag>() || self.dismissed {
+ return false;
+ }
+
+ let provider = all_language_settings(None, cx).inline_completions.provider;
+
+ match provider {
+ InlineCompletionProvider::None
+ | InlineCompletionProvider::Copilot
+ | InlineCompletionProvider::Supermaven => true,
+ InlineCompletionProvider::Zed => false,
+ }
+ }
+
+ fn handle_settings_changed(&mut self, cx: &mut Context<Self>) {
+ if self.dismissed {
+ return;
+ }
+
+ let provider = all_language_settings(None, cx).inline_completions.provider;
+
+ match provider {
+ InlineCompletionProvider::None
+ | InlineCompletionProvider::Copilot
+ | InlineCompletionProvider::Supermaven => {}
+ InlineCompletionProvider::Zed => {
+ self.dismiss(cx);
+ }
+ }
+ }
+
+ fn dismiss(&mut self, cx: &mut Context<Self>) {
+ persist_dismissed(cx);
+ self.dismissed = true;
+ cx.notify();
+ }
+}
+
+const DISMISSED_AT_KEY: &str = "zed_predict_banner_dismissed_at";
+
+pub(crate) fn get_dismissed() -> bool {
+ db::kvp::KEY_VALUE_STORE
+ .read_kvp(DISMISSED_AT_KEY)
+ .log_err()
+ .map_or(false, |dismissed| dismissed.is_some())
+}
+
+pub(crate) fn persist_dismissed(cx: &mut App) {
+ cx.spawn(|_| {
+ let time = Utc::now().to_rfc3339();
+ db::kvp::KEY_VALUE_STORE.write_kvp(DISMISSED_AT_KEY.into(), time)
+ })
+ .detach_and_log_err(cx);
+}
+
+impl Render for ZedPredictBanner {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ if !self.should_show(cx) {
+ return div();
+ }
+
+ let border_color = cx.theme().colors().editor_foreground.opacity(0.3);
+ let banner = h_flex()
+ .rounded_md()
+ .border_1()
+ .border_color(border_color)
+ .child(
+ ButtonLike::new("try-zed-predict")
+ .child(
+ h_flex()
+ .h_full()
+ .items_center()
+ .gap_1p5()
+ .child(Icon::new(IconName::ZedPredict).size(IconSize::Small))
+ .child(
+ h_flex()
+ .gap_0p5()
+ .child(
+ Label::new("Introducing:")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(Label::new("Edit Prediction").size(LabelSize::Small)),
+ ),
+ )
+ .on_click({
+ let workspace = self.workspace.clone();
+ let user_store = self.user_store.clone();
+ let client = self.client.clone();
+ let fs = self.fs.clone();
+ move |_, window, cx| {
+ let Some(workspace) = workspace.upgrade() else {
+ return;
+ };
+ ZedPredictModal::toggle(
+ workspace,
+ user_store.clone(),
+ client.clone(),
+ fs.clone(),
+ window,
+ cx,
+ );
+ }
+ }),
+ )
+ .child(
+ div().border_l_1().border_color(border_color).child(
+ IconButton::new("close", IconName::Close)
+ .icon_size(IconSize::Indicator)
+ .on_click(cx.listener(|this, _, _window, cx| this.dismiss(cx)))
+ .tooltip(|window, cx| {
+ Tooltip::with_meta(
+ "Close Announcement Banner",
+ None,
+ "It won't show again for this feature",
+ window,
+ cx,
+ )
+ }),
+ ),
+ );
+
+ div().pr_1().child(banner)
+ }
+}
@@ -0,0 +1,5 @@
+mod banner;
+mod modal;
+
+pub use banner::ZedPredictBanner;
+pub use modal::ZedPredictModal;
@@ -0,0 +1,313 @@
+use std::{sync::Arc, time::Duration};
+
+use client::{Client, UserStore};
+use feature_flags::FeatureFlagAppExt as _;
+use fs::Fs;
+use gpui::{
+ ease_in_out, svg, Animation, AnimationExt as _, ClickEvent, DismissEvent, Entity, EventEmitter,
+ FocusHandle, Focusable, MouseDownEvent, Render,
+};
+use language::language_settings::{AllLanguageSettings, InlineCompletionProvider};
+use settings::{update_settings_file, Settings};
+use ui::{prelude::*, CheckboxWithLabel, TintColor};
+use workspace::{notifications::NotifyTaskExt, ModalView, Workspace};
+
+/// Introduces user to AI inline prediction feature and terms of service
+pub struct ZedPredictModal {
+ user_store: Entity<UserStore>,
+ client: Arc<Client>,
+ fs: Arc<dyn Fs>,
+ focus_handle: FocusHandle,
+ sign_in_status: SignInStatus,
+ terms_of_service: bool,
+}
+
+#[derive(PartialEq, Eq)]
+enum SignInStatus {
+ /// Signed out or signed in but not from this modal
+ Idle,
+ /// Authentication triggered from this modal
+ Waiting,
+ /// Signed in after authentication from this modal
+ SignedIn,
+}
+
+impl ZedPredictModal {
+ fn new(
+ user_store: Entity<UserStore>,
+ client: Arc<Client>,
+ fs: Arc<dyn Fs>,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ ZedPredictModal {
+ user_store,
+ client,
+ fs,
+ focus_handle: cx.focus_handle(),
+ sign_in_status: SignInStatus::Idle,
+ terms_of_service: false,
+ }
+ }
+
+ pub fn toggle(
+ workspace: Entity<Workspace>,
+ user_store: Entity<UserStore>,
+ client: Arc<Client>,
+ fs: Arc<dyn Fs>,
+ window: &mut Window,
+ cx: &mut App,
+ ) {
+ workspace.update(cx, |this, cx| {
+ this.toggle_modal(window, cx, |_window, cx| {
+ ZedPredictModal::new(user_store, client, fs, cx)
+ });
+ });
+ }
+
+ fn view_terms(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
+ cx.open_url("https://zed.dev/terms-of-service");
+ cx.notify();
+ }
+
+ fn view_blog(&mut self, _: &ClickEvent, _: &mut Window, cx: &mut Context<Self>) {
+ cx.open_url("https://zed.dev/blog/"); // TODO Add the link when live
+ cx.notify();
+ }
+
+ fn accept_and_enable(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
+ let task = self
+ .user_store
+ .update(cx, |this, cx| this.accept_terms_of_service(cx));
+
+ cx.spawn(|this, mut cx| async move {
+ task.await?;
+
+ this.update(&mut cx, |this, cx| {
+ update_settings_file::<AllLanguageSettings>(this.fs.clone(), cx, move |file, _| {
+ file.features
+ .get_or_insert(Default::default())
+ .inline_completion_provider = Some(InlineCompletionProvider::Zed);
+ });
+
+ cx.emit(DismissEvent);
+ })
+ })
+ .detach_and_notify_err(window, cx);
+ }
+
+ fn sign_in(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
+ let client = self.client.clone();
+ self.sign_in_status = SignInStatus::Waiting;
+
+ cx.spawn(move |this, mut cx| async move {
+ let result = client.authenticate_and_connect(true, &cx).await;
+
+ let status = match result {
+ Ok(_) => SignInStatus::SignedIn,
+ Err(_) => SignInStatus::Idle,
+ };
+
+ this.update(&mut cx, |this, cx| {
+ this.sign_in_status = status;
+ cx.notify()
+ })?;
+
+ result
+ })
+ .detach_and_notify_err(window, cx);
+ }
+
+ fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
+ cx.emit(DismissEvent);
+ }
+}
+
+impl EventEmitter<DismissEvent> for ZedPredictModal {}
+
+impl Focusable for ZedPredictModal {
+ fn focus_handle(&self, _cx: &App) -> FocusHandle {
+ self.focus_handle.clone()
+ }
+}
+
+impl ModalView for ZedPredictModal {}
+
+impl Render for ZedPredictModal {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let base = v_flex()
+ .w(px(420.))
+ .p_4()
+ .relative()
+ .gap_2()
+ .overflow_hidden()
+ .elevation_3(cx)
+ .id("zed predict tos")
+ .track_focus(&self.focus_handle(cx))
+ .on_action(cx.listener(Self::cancel))
+ .key_context("ZedPredictModal")
+ .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
+ cx.emit(DismissEvent);
+ }))
+ .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
+ this.focus_handle.focus(window);
+ }))
+ .child(
+ div()
+ .p_1p5()
+ .absolute()
+ .top_0()
+ .left_0()
+ .right_0()
+ .h(px(200.))
+ .child(
+ svg()
+ .path("icons/zed_predict_bg.svg")
+ .text_color(cx.theme().colors().icon_disabled)
+ .w(px(416.))
+ .h(px(128.))
+ .overflow_hidden(),
+ ),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .mb_2()
+ .justify_between()
+ .child(
+ v_flex()
+ .gap_1()
+ .child(
+ Label::new("Introducing Zed AI's")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(Headline::new("Edit Prediction").size(HeadlineSize::Large)),
+ )
+ .child({
+ let tab = |n: usize| {
+ let text_color = cx.theme().colors().text;
+ let border_color = cx.theme().colors().text_accent.opacity(0.4);
+
+ h_flex().child(
+ h_flex()
+ .px_4()
+ .py_0p5()
+ .bg(cx.theme().colors().editor_background)
+ .border_1()
+ .border_color(border_color)
+ .rounded_md()
+ .font(theme::ThemeSettings::get_global(cx).buffer_font.clone())
+ .text_size(TextSize::XSmall.rems(cx))
+ .text_color(text_color)
+ .child("tab")
+ .with_animation(
+ ElementId::Integer(n),
+ Animation::new(Duration::from_secs(2)).repeat(),
+ move |tab, delta| {
+ let delta = (delta - 0.15 * n as f32) / 0.7;
+ let delta = 1.0 - (0.5 - delta).abs() * 2.;
+ let delta = ease_in_out(delta.clamp(0., 1.));
+ let delta = 0.1 + 0.9 * delta;
+
+ tab.border_color(border_color.opacity(delta))
+ .text_color(text_color.opacity(delta))
+ },
+ ),
+ )
+ };
+
+ v_flex()
+ .gap_2()
+ .items_center()
+ .pr_4()
+ .child(tab(0).ml_neg_20())
+ .child(tab(1))
+ .child(tab(2).ml_20())
+ }),
+ )
+ .child(h_flex().absolute().top_2().right_2().child(
+ IconButton::new("cancel", IconName::X).on_click(cx.listener(
+ |_, _: &ClickEvent, _window, cx| {
+ cx.emit(DismissEvent);
+ },
+ )),
+ ));
+
+ let blog_post_button = if cx.is_staff() {
+ Some(
+ Button::new("view-blog", "Read the Blog Post")
+ .full_width()
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::Indicator)
+ .icon_color(Color::Muted)
+ .on_click(cx.listener(Self::view_blog)),
+ )
+ } else {
+ // TODO: put back when blog post is published
+ None
+ };
+
+ if self.user_store.read(cx).current_user().is_some() {
+ let copy = match self.sign_in_status {
+ SignInStatus::Idle => "Get accurate and helpful edit predictions at every keystroke. To set Zed as your inline completions provider, ensure you:",
+ SignInStatus::SignedIn => "Almost there! Ensure you:",
+ SignInStatus::Waiting => unreachable!(),
+ };
+
+ base.child(Label::new(copy).color(Color::Muted))
+ .child(
+ h_flex()
+ .gap_0p5()
+ .child(CheckboxWithLabel::new(
+ "tos-checkbox",
+ Label::new("Have read and accepted the").color(Color::Muted),
+ self.terms_of_service.into(),
+ cx.listener(move |this, state, _window, cx| {
+ this.terms_of_service = *state == ToggleState::Selected;
+ cx.notify()
+ }),
+ ))
+ .child(
+ Button::new("view-tos", "Terms of Service")
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::Indicator)
+ .icon_color(Color::Muted)
+ .on_click(cx.listener(Self::view_terms)),
+ ),
+ )
+ .child(
+ v_flex()
+ .mt_2()
+ .gap_2()
+ .w_full()
+ .child(
+ Button::new("accept-tos", "Enable Edit Predictions")
+ .disabled(!self.terms_of_service)
+ .style(ButtonStyle::Tinted(TintColor::Accent))
+ .full_width()
+ .on_click(cx.listener(Self::accept_and_enable)),
+ )
+ .children(blog_post_button),
+ )
+ } else {
+ base.child(
+ Label::new("To set Zed as your inline completions provider, please sign in.")
+ .color(Color::Muted),
+ )
+ .child(
+ v_flex()
+ .mt_2()
+ .gap_2()
+ .w_full()
+ .child(
+ Button::new("accept-tos", "Sign in with GitHub")
+ .disabled(self.sign_in_status == SignInStatus::Waiting)
+ .style(ButtonStyle::Tinted(TintColor::Accent))
+ .full_width()
+ .on_click(cx.listener(Self::sign_in)),
+ )
+ .children(blog_post_button),
+ )
+ }
+ }
+}
@@ -1,155 +0,0 @@
-//! AI service Terms of Service acceptance modal.
-
-use client::UserStore;
-use gpui::{
- App, ClickEvent, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, MouseDownEvent,
- Render,
-};
-use ui::{prelude::*, TintColor};
-use workspace::{ModalView, Workspace};
-
-/// Terms of acceptance for AI inline prediction.
-pub struct ZedPredictTos {
- focus_handle: FocusHandle,
- user_store: Entity<UserStore>,
- workspace: Entity<Workspace>,
- viewed: bool,
-}
-
-impl ZedPredictTos {
- fn new(
- workspace: Entity<Workspace>,
- user_store: Entity<UserStore>,
- cx: &mut Context<Self>,
- ) -> Self {
- ZedPredictTos {
- viewed: false,
- focus_handle: cx.focus_handle(),
- user_store,
- workspace,
- }
- }
- pub fn toggle(
- workspace: Entity<Workspace>,
- user_store: Entity<UserStore>,
- window: &mut Window,
- cx: &mut App,
- ) {
- workspace.update(cx, |this, cx| {
- let workspace = cx.entity().clone();
- this.toggle_modal(window, cx, |_window, cx| {
- ZedPredictTos::new(workspace, user_store, cx)
- });
- });
- }
-
- fn view_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
- self.viewed = true;
- cx.open_url("https://zed.dev/terms-of-service");
- cx.notify();
- }
-
- fn accept_terms(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
- let task = self
- .user_store
- .update(cx, |this, cx| this.accept_terms_of_service(cx));
-
- let workspace = self.workspace.clone();
-
- cx.spawn(|this, mut cx| async move {
- match task.await {
- Ok(_) => this.update(&mut cx, |_, cx| {
- cx.emit(DismissEvent);
- }),
- Err(err) => workspace.update(&mut cx, |this, cx| {
- this.show_error(&err, cx);
- }),
- }
- })
- .detach_and_log_err(cx);
- }
-
- fn cancel(&mut self, _: &menu::Cancel, _window: &mut Window, cx: &mut Context<Self>) {
- cx.emit(DismissEvent);
- }
-}
-
-impl EventEmitter<DismissEvent> for ZedPredictTos {}
-
-impl Focusable for ZedPredictTos {
- fn focus_handle(&self, _cx: &App) -> FocusHandle {
- self.focus_handle.clone()
- }
-}
-
-impl ModalView for ZedPredictTos {}
-
-impl Render for ZedPredictTos {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- v_flex()
- .id("zed predict tos")
- .track_focus(&self.focus_handle(cx))
- .on_action(cx.listener(Self::cancel))
- .key_context("ZedPredictTos")
- .elevation_3(cx)
- .w_96()
- .items_center()
- .p_4()
- .gap_2()
- .on_action(cx.listener(|_, _: &menu::Cancel, _window, cx| {
- cx.emit(DismissEvent);
- }))
- .on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, _cx| {
- this.focus_handle.focus(window);
- }))
- .child(
- h_flex()
- .w_full()
- .justify_between()
- .child(
- v_flex()
- .gap_0p5()
- .child(
- Label::new("Zed AI")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(Headline::new("Edit Prediction")),
- )
- .child(Icon::new(IconName::ZedPredict).size(IconSize::XLarge)),
- )
- .child(
- Label::new(
- "To use Zed AI's Edit Prediction feature, please read and accept our Terms of Service.",
- )
- .color(Color::Muted),
- )
- .child(
- v_flex()
- .mt_2()
- .gap_0p5()
- .w_full()
- .child(if self.viewed {
- Button::new("accept-tos", "I've Read and Accept the Terms of Service")
- .style(ButtonStyle::Tinted(TintColor::Accent))
- .full_width()
- .on_click(cx.listener(Self::accept_terms))
- } else {
- Button::new("view-tos", "Read Terms of Service")
- .style(ButtonStyle::Tinted(TintColor::Accent))
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::XSmall)
- .icon_position(IconPosition::End)
- .full_width()
- .on_click(cx.listener(Self::view_terms))
- })
- .child(
- Button::new("cancel", "Cancel")
- .full_width()
- .on_click(cx.listener(|_, _: &ClickEvent, _window, cx| {
- cx.emit(DismissEvent);
- })),
- ),
- )
- }
-}
@@ -22,6 +22,7 @@ arrayvec.workspace = true
client.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
+db.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
@@ -34,6 +35,7 @@ language_models.workspace = true
log.workspace = true
menu.workspace = true
rpc.workspace = true
+serde.workspace = true
serde_json.workspace = true
settings.workspace = true
similar.workspace = true
@@ -0,0 +1,54 @@
+use anyhow::Result;
+use collections::HashMap;
+use std::path::{Path, PathBuf};
+use workspace::WorkspaceDb;
+
+use db::sqlez_macros::sql;
+use db::{define_connection, query};
+
+define_connection!(
+ pub static ref DB: ZetaDb<WorkspaceDb> = &[
+ sql! (
+ CREATE TABLE zeta_preferences(
+ worktree_path BLOB NOT NULL PRIMARY KEY,
+ accepted_data_collection INTEGER
+ ) STRICT;
+ ),
+ ];
+);
+
+impl ZetaDb {
+ pub fn get_all_zeta_preferences(&self) -> Result<HashMap<PathBuf, bool>> {
+ Ok(self.get_all_zeta_preferences_query()?.into_iter().collect())
+ }
+
+ query! {
+ fn get_all_zeta_preferences_query() -> Result<Vec<(PathBuf, bool)>> {
+ SELECT worktree_path, accepted_data_collection FROM zeta_preferences
+ }
+ }
+
+ query! {
+ pub fn get_accepted_data_collection(worktree_path: &Path) -> Result<Option<bool>> {
+ SELECT accepted_data_collection FROM zeta_preferences
+ WHERE worktree_path = ?
+ }
+ }
+
+ query! {
+ pub async fn save_accepted_data_collection(worktree_path: PathBuf, accepted_data_collection: bool) -> Result<()> {
+ INSERT INTO zeta_preferences
+ (worktree_path, accepted_data_collection)
+ VALUES
+ (?1, ?2)
+ ON CONFLICT (worktree_path) DO UPDATE SET
+ accepted_data_collection = ?2
+ }
+ }
+
+ query! {
+ pub async fn clear_all_zeta_preferences() -> Result<()> {
+ DELETE FROM zeta_preferences
+ }
+ }
+}
@@ -1,7 +1,10 @@
mod completion_diff_element;
+mod persistence;
mod rate_completion_modal;
pub(crate) use completion_diff_element::*;
+use db::kvp::KEY_VALUE_STORE;
+use inline_completion::DataCollectionState;
pub use rate_completion_modal::*;
use anyhow::{anyhow, Context as _, Result};
@@ -12,6 +15,7 @@ use feature_flags::FeatureFlagAppExt as _;
use futures::AsyncReadExt;
use gpui::{
actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task,
+ WeakEntity,
};
use http_client::{HttpClient, Method};
use language::{
@@ -20,26 +24,33 @@ use language::{
};
use language_models::LlmApiToken;
use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
- cmp,
+ cmp, env,
fmt::Write,
future::Future,
mem,
ops::Range,
- path::Path,
+ path::{Path, PathBuf},
sync::Arc,
time::{Duration, Instant},
};
use telemetry_events::InlineCompletionRating;
use util::ResultExt;
use uuid::Uuid;
+use workspace::{
+ notifications::{simple_message_notification::MessageNotification, NotificationId},
+ Workspace,
+};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
+const ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY: &'static str =
+ "zed_predict_data_collection_never_ask_again";
// TODO(mgsloan): more systematic way to choose or tune these fairly arbitrary constants?
@@ -187,6 +198,7 @@ pub struct Zeta {
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
shown_completions: VecDeque<InlineCompletion>,
rated_completions: HashSet<InlineCompletionId>,
+ data_collection_preferences: DataCollectionPreferences,
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
tos_accepted: bool, // Terms of service accepted
@@ -216,13 +228,13 @@ impl Zeta {
fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
-
Self {
client,
events: VecDeque::new(),
shown_completions: VecDeque::new(),
rated_completions: HashSet::default(),
registered_buffers: HashMap::default(),
+ data_collection_preferences: Self::load_data_collection_preferences(cx),
llm_token: LlmApiToken::default(),
_llm_token_subscription: cx.subscribe(
&refresh_llm_token_listener,
@@ -240,11 +252,16 @@ impl Zeta {
.read(cx)
.current_user_has_accepted_terms()
.unwrap_or(false),
- _user_store_subscription: cx.subscribe(&user_store, |this, _, event, _| match event {
- client::user::Event::TermsStatusUpdated { accepted } => {
- this.tos_accepted = *accepted;
+ _user_store_subscription: cx.subscribe(&user_store, |this, user_store, event, cx| {
+ match event {
+ client::user::Event::PrivateUserInfoUpdated => {
+ this.tos_accepted = user_store
+ .read(cx)
+ .current_user_has_accepted_terms()
+ .unwrap_or(false);
+ }
+ _ => {}
}
- _ => {}
}),
}
}
@@ -308,11 +325,8 @@ impl Zeta {
event: &language::BufferEvent,
cx: &mut Context<Self>,
) {
- match event {
- language::BufferEvent::Edited => {
- self.report_changes_for_buffer(&buffer, cx);
- }
- _ => {}
+ if let language::BufferEvent::Edited = event {
+ self.report_changes_for_buffer(&buffer, cx);
}
}
@@ -320,6 +334,7 @@ impl Zeta {
&mut self,
buffer: &Entity<Buffer>,
cursor: language::Anchor,
+ can_collect_data: bool,
cx: &mut Context<Self>,
perform_predict_edits: F,
) -> Task<Result<Option<InlineCompletion>>>
@@ -370,6 +385,7 @@ impl Zeta {
input_events: input_events.clone(),
input_excerpt: input_excerpt.clone(),
outline: Some(input_outline.clone()),
+ can_collect_data,
};
let response = perform_predict_edits(client, llm_token, is_staff, body).await?;
@@ -540,16 +556,25 @@ and then another
) -> Task<Result<Option<InlineCompletion>>> {
use std::future::ready;
- self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response)))
+ self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| {
+ ready(Ok(response))
+ })
}
pub fn request_completion(
&mut self,
buffer: &Entity<Buffer>,
position: language::Anchor,
+ can_collect_data: bool,
cx: &mut Context<Self>,
) -> Task<Result<Option<InlineCompletion>>> {
- self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
+ self.request_completion_impl(
+ buffer,
+ position,
+ can_collect_data,
+ cx,
+ Self::perform_predict_edits,
+ )
}
fn perform_predict_edits(
@@ -862,6 +887,80 @@ and then another
new_snapshot
}
+
+ pub fn data_collection_choice_at(&self, path: &Path) -> DataCollectionChoice {
+ match self.data_collection_preferences.per_worktree.get(path) {
+ Some(true) => DataCollectionChoice::Enabled,
+ Some(false) => DataCollectionChoice::Disabled,
+ None => DataCollectionChoice::NotAnswered,
+ }
+ }
+
+ fn update_data_collection_choice_for_worktree(
+ &mut self,
+ absolute_path_of_project_worktree: PathBuf,
+ can_collect_data: bool,
+ cx: &mut Context<Self>,
+ ) {
+ self.data_collection_preferences
+ .per_worktree
+ .insert(absolute_path_of_project_worktree.clone(), can_collect_data);
+
+ db::write_and_log(cx, move || {
+ persistence::DB
+ .save_accepted_data_collection(absolute_path_of_project_worktree, can_collect_data)
+ });
+ }
+
+ fn set_never_ask_again_for_data_collection(&mut self, cx: &mut Context<Self>) {
+ self.data_collection_preferences.never_ask_again = true;
+
+ // persist choice
+ db::write_and_log(cx, move || {
+ KEY_VALUE_STORE.write_kvp(
+ ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into(),
+ "true".to_string(),
+ )
+ });
+ }
+
+ fn load_data_collection_preferences(cx: &mut Context<Self>) -> DataCollectionPreferences {
+ if env::var("ZED_PREDICT_CLEAR_DATA_COLLECTION_PREFERENCES").is_ok() {
+ db::write_and_log(cx, move || async move {
+ KEY_VALUE_STORE
+ .delete_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY.into())
+ .await
+ .log_err();
+
+ persistence::DB.clear_all_zeta_preferences().await
+ });
+ return DataCollectionPreferences::default();
+ }
+
+ let never_ask_again = KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_NEVER_ASK_AGAIN_KEY)
+ .log_err()
+ .flatten()
+ .map(|value| value == "true")
+ .unwrap_or(false);
+
+ let preferences_per_project = persistence::DB
+ .get_all_zeta_preferences()
+ .log_err()
+ .unwrap_or_else(HashMap::default);
+
+ DataCollectionPreferences {
+ never_ask_again,
+ per_worktree: preferences_per_project,
+ }
+ }
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+struct DataCollectionPreferences {
+ /// Set when a user clicks on "Never Ask Again", can never be unset.
+ never_ask_again: bool,
+ per_worktree: HashMap<PathBuf, bool>,
}
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
@@ -1276,22 +1375,120 @@ struct PendingCompletion {
_task: Task<()>,
}
+#[derive(Clone, Copy)]
+pub enum DataCollectionChoice {
+ NotAnswered,
+ Enabled,
+ Disabled,
+}
+
+impl DataCollectionChoice {
+ pub fn is_enabled(&self) -> bool {
+ match self {
+ Self::Enabled => true,
+ Self::NotAnswered | Self::Disabled => false,
+ }
+ }
+
+ pub fn is_answered(&self) -> bool {
+ match self {
+ Self::Enabled | Self::Disabled => true,
+ Self::NotAnswered => false,
+ }
+ }
+
+ pub fn toggle(&self) -> DataCollectionChoice {
+ match self {
+ Self::Enabled => Self::Disabled,
+ Self::Disabled => Self::Enabled,
+ Self::NotAnswered => Self::Enabled,
+ }
+ }
+}
+
pub struct ZetaInlineCompletionProvider {
zeta: Entity<Zeta>,
pending_completions: ArrayVec<PendingCompletion, 2>,
next_pending_completion_id: usize,
current_completion: Option<CurrentInlineCompletion>,
+ data_collection: Option<ProviderDataCollection>,
+}
+
+pub struct ProviderDataCollection {
+ workspace: WeakEntity<Workspace>,
+ worktree_root_path: PathBuf,
+ choice: DataCollectionChoice,
+}
+
+impl ProviderDataCollection {
+ pub fn new(
+ zeta: Entity<Zeta>,
+ workspace: Option<Entity<Workspace>>,
+ buffer: Option<Entity<Buffer>>,
+ cx: &mut App,
+ ) -> Option<ProviderDataCollection> {
+ let workspace = workspace?;
+
+ let worktree_root_path = buffer?.update(cx, |buffer, cx| {
+ let file = buffer.file()?;
+
+ if !file.is_local() || file.is_private() {
+ return None;
+ }
+
+ workspace.update(cx, |workspace, cx| {
+ Some(
+ workspace
+ .absolute_path_of_worktree(file.worktree_id(cx), cx)?
+ .to_path_buf(),
+ )
+ })
+ })?;
+
+ let choice = zeta.read(cx).data_collection_choice_at(&worktree_root_path);
+
+ Some(ProviderDataCollection {
+ workspace: workspace.downgrade(),
+ worktree_root_path,
+ choice,
+ })
+ }
+
+ fn set_choice(&mut self, choice: DataCollectionChoice, zeta: &Entity<Zeta>, cx: &mut App) {
+ self.choice = choice;
+
+ let worktree_root_path = self.worktree_root_path.clone();
+
+ zeta.update(cx, |zeta, cx| {
+ zeta.update_data_collection_choice_for_worktree(
+ worktree_root_path,
+ choice.is_enabled(),
+ cx,
+ )
+ });
+ }
+
+ fn toggle_choice(&mut self, zeta: &Entity<Zeta>, cx: &mut App) {
+ self.set_choice(self.choice.toggle(), zeta, cx);
+ }
}
impl ZetaInlineCompletionProvider {
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
- pub fn new(zeta: Entity<Zeta>) -> Self {
+ pub fn new(zeta: Entity<Zeta>, data_collection: Option<ProviderDataCollection>) -> Self {
Self {
zeta,
pending_completions: ArrayVec::new(),
next_pending_completion_id: 0,
current_completion: None,
+ data_collection,
+ }
+ }
+
+ fn set_data_collection_choice(&mut self, choice: DataCollectionChoice, cx: &mut App) {
+ if let Some(data_collection) = self.data_collection.as_mut() {
+ data_collection.set_choice(choice, &self.zeta, cx);
}
}
}
@@ -1302,7 +1499,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
}
fn display_name() -> &'static str {
- "Zed Predict"
+ "Zed's Edit Predictions"
}
fn show_completions_in_menu() -> bool {
@@ -1317,6 +1514,24 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
true
}
+ fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+ let Some(data_collection) = self.data_collection.as_ref() else {
+ return DataCollectionState::Unknown;
+ };
+
+ if data_collection.choice.is_enabled() {
+ DataCollectionState::Enabled
+ } else {
+ DataCollectionState::Disabled
+ }
+ }
+
+ fn toggle_data_collection(&mut self, cx: &mut App) {
+ if let Some(data_collection) = self.data_collection.as_mut() {
+ data_collection.toggle_choice(&self.zeta, cx);
+ }
+ }
+
fn is_enabled(
&self,
buffer: &Entity<Buffer>,
@@ -1362,6 +1577,10 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
let pending_completion_id = self.next_pending_completion_id;
self.next_pending_completion_id += 1;
+ let can_collect_data = self
+ .data_collection
+ .as_ref()
+ .map_or(false, |data_collection| data_collection.choice.is_enabled());
let task = cx.spawn(|this, mut cx| async move {
if debounce {
@@ -1370,7 +1589,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
let completion_request = this.update(&mut cx, |this, cx| {
this.zeta.update(cx, |zeta, cx| {
- zeta.request_completion(&buffer, position, cx)
+ zeta.request_completion(&buffer, position, can_collect_data, cx)
})
});
@@ -1447,8 +1666,80 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide
// Right now we don't support cycling.
}
- fn accept(&mut self, _cx: &mut Context<Self>) {
+ fn accept(&mut self, cx: &mut Context<Self>) {
self.pending_completions.clear();
+
+ let Some(data_collection) = self.data_collection.as_mut() else {
+ return;
+ };
+
+ if data_collection.choice.is_answered()
+ || self
+ .zeta
+ .read(cx)
+ .data_collection_preferences
+ .never_ask_again
+ {
+ return;
+ }
+
+ struct ZetaDataCollectionNotification;
+ let notification_id = NotificationId::unique::<ZetaDataCollectionNotification>();
+
+ const DATA_COLLECTION_INFO_URL: &str = "https://zed.dev/terms-of-service"; // TODO: Replace for a link that's dedicated to Edit Predictions data collection
+
+ let this = cx.entity();
+ data_collection
+ .workspace
+ .update(cx, |workspace, cx| {
+ workspace.show_notification(notification_id, cx, |cx| {
+ let zeta = self.zeta.clone();
+
+ cx.new(move |_cx| {
+ let message =
+ "To allow Zed to suggest better edits, turn on data collection. You \
+ can turn off at any time via the status bar menu.";
+ MessageNotification::new(message)
+ .with_title("Per-Project Data Collection Program")
+ .show_close_button(false)
+ .with_click_message("Turn On")
+ .on_click({
+ let this = this.clone();
+ move |_window, cx| {
+ this.update(cx, |this, cx| {
+ this.set_data_collection_choice(
+ DataCollectionChoice::Enabled,
+ cx,
+ )
+ });
+ }
+ })
+ .with_secondary_click_message("Turn Off")
+ .on_secondary_click({
+ move |_window, cx| {
+ this.update(cx, |this, cx| {
+ this.set_data_collection_choice(
+ DataCollectionChoice::Disabled,
+ cx,
+ )
+ });
+ }
+ })
+ .with_tertiary_click_message("Never Ask Again")
+ .on_tertiary_click({
+ let zeta = zeta.clone();
+ move |_window, cx| {
+ zeta.update(cx, |zeta, cx| {
+ zeta.set_never_ask_again_for_data_collection(cx);
+ });
+ }
+ })
+ .more_info_message("Learn More")
+ .more_info_url(DATA_COLLECTION_INFO_URL)
+ })
+ });
+ })
+ .log_err();
}
fn discard(&mut self, _cx: &mut Context<Self>) {
@@ -1688,8 +1979,9 @@ mod tests {
let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
- let completion_task =
- zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));
+ let completion_task = zeta.update(cx, |zeta, cx| {
+ zeta.request_completion(&buffer, cursor, false, cx)
+ });
let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
server.respond(