From 0e33a3afe0f37b1a7a37fec6b068abe9dd984009 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Sun, 7 Sep 2025 11:16:49 -0600 Subject: [PATCH] zeta: Check whether data collection is allowed for recent edit history (#37680) Also: * Adds tests for can_collect_data. * Temporarily removes collection of diagnostics. Release Notes: - Edit Prediction: Fixed a bug where requests were marked eligible for data collection despite the recent edit history in the request involving files that may not be open source. The requests affected by this bug will not be used in training data. --- Cargo.lock | 1 + .../zed/src/zed/edit_prediction_registry.rs | 7 +- crates/zeta/Cargo.toml | 1 + crates/zeta/src/input_excerpt.rs | 8 +- crates/zeta/src/license_detection.rs | 1 - crates/zeta/src/zeta.rs | 882 ++++++++++++------ crates/zeta_cli/src/main.rs | 35 +- 7 files changed, 595 insertions(+), 340 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4c94f8078b1ab392ed1a50e15c71dab1921f0a3..dbcea05ea9bc52288defc8c299d82eb508337544 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20841,6 +20841,7 @@ dependencies = [ "language_model", "log", "menu", + "parking_lot", "postage", "project", "rand 0.9.1", diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 4f009ccb0b1197f11b034ac48b89dd37b6f41278..ae26427fc6547079b163235f5d1c3df26a489795 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -8,7 +8,7 @@ use settings::SettingsStore; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; -use zeta::{ProviderDataCollection, ZetaEditPredictionProvider}; +use zeta::ZetaEditPredictionProvider; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); @@ -214,11 +214,8 @@ fn assign_edit_prediction_provider( }); } - let data_collection = - ProviderDataCollection::new(zeta.clone(), singleton_buffer, cx); - let provider = - cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, data_collection)); + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer)); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index a9c2a7619f4db22e51c014672aa2100b30a2539a..09bcfa7f542ce9c01802c9cebc11dfc9a8da2542 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -72,6 +72,7 @@ gpui = { workspace = true, features = ["test-support"] } http_client = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } +parking_lot.workspace = true reqwest_client = { workspace = true, features = ["test-support"] } rpc = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/input_excerpt.rs index dd1bbed1d72e8668e9ed55c9b66b911addfcdd43..06bff5b1bea0f099b2ccd98605ac5de5bb5e6360 100644 --- a/crates/zeta/src/input_excerpt.rs +++ b/crates/zeta/src/input_excerpt.rs @@ -1,6 +1,6 @@ use crate::{ CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, - tokens_for_bytes, + guess_token_count, }; use language::{BufferSnapshot, Point}; use std::{fmt::Write, ops::Range}; @@ -22,7 +22,7 @@ pub fn excerpt_for_cursor_position( let mut remaining_edit_tokens = editable_region_token_limit; while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = tokens_for_bytes(parent.byte_range().len()); + let parent_tokens = guess_token_count(parent.byte_range().len()); let parent_point_range = Point::new( parent.start_position().row as u32, parent.start_position().column as u32, @@ -99,7 +99,7 @@ fn expand_range( if remaining_tokens > 0 && expanded_range.start.row > 0 { expanded_range.start.row -= 1; let line_tokens = - tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize); + guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); remaining_tokens = remaining_tokens.saturating_sub(line_tokens); expanded = true; } @@ -107,7 +107,7 @@ fn expand_range( if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { expanded_range.end.row += 1; expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = tokens_for_bytes(expanded_range.end.column as usize); + let line_tokens = guess_token_count(expanded_range.end.column as usize); remaining_tokens = remaining_tokens.saturating_sub(line_tokens); expanded = true; } diff --git a/crates/zeta/src/license_detection.rs b/crates/zeta/src/license_detection.rs index 5f207a44e8bd2028e6a2b416e978f101cfe5bd57..e06e1577a66cc160efa00213b80c6ca407f7be85 100644 --- a/crates/zeta/src/license_detection.rs +++ b/crates/zeta/src/license_detection.rs @@ -358,7 +358,6 @@ impl LicenseDetectionWatcher { #[cfg(test)] mod tests { - use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 3851d16755783209fd9da4f468a494779a7d9fe7..dfcf98f025c2e020d6545efca64d4ab12579e370 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -29,7 +29,7 @@ use gpui::{ use http_client::{AsyncBody, HttpClient, Method, Request, Response}; use input_excerpt::excerpt_for_cursor_position; use language::{ - Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, + Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff, }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::{Project, ProjectPath}; @@ -65,7 +65,6 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch const MAX_CONTEXT_TOKENS: usize = 150; const MAX_REWRITE_TOKENS: usize = 350; const MAX_EVENT_TOKENS: usize = 500; -const MAX_DIAGNOSTIC_GROUPS: usize = 10; /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -216,7 +215,7 @@ pub struct Zeta { client: Arc, shown_completions: VecDeque, rated_completions: HashSet, - data_collection_choice: Entity, + data_collection_choice: DataCollectionChoice, llm_token: LlmApiToken, _llm_token_subscription: Subscription, /// Whether an update to a newer version of Zed is required to continue using Zeta. @@ -271,10 +270,7 @@ impl Zeta { fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - - let data_collection_choice = Self::load_data_collection_choices(); - let data_collection_choice = cx.new(|_| data_collection_choice); - + let data_collection_choice = Self::load_data_collection_choice(); Self { projects: HashMap::default(), client, @@ -408,7 +404,6 @@ impl Zeta { project: &Entity, buffer: &Entity, cursor: language::Anchor, - can_collect_data: bool, cx: &mut Context, perform_predict_edits: F, ) -> Task>> @@ -422,15 +417,25 @@ impl Zeta { let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, project, cx); let zeta = cx.entity(); - let events = self.get_or_init_zeta_project(project, cx).events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) { - git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + let zeta_project = self.get_or_init_zeta_project(project, cx); + let mut events = Vec::with_capacity(zeta_project.events.len()); + events.extend(zeta_project.events.iter().cloned()); + let events = Arc::new(events); + + let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { + let can_collect_file = self.can_collect_file(file, cx); + let git_info = if can_collect_file { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + } else { + None + }; + (git_info, can_collect_file) } else { - None + (None, false) }; let full_path: Arc = snapshot @@ -440,25 +445,35 @@ impl Zeta { let full_path_str = full_path.to_string_lossy().to_string(); let cursor_point = cursor.to_point(&snapshot); let cursor_offset = cursor_point.to_offset(&snapshot); - let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let prompt_for_events = { + let events = events.clone(); + move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) + }; let gather_task = gather_context( - project, full_path_str, &snapshot, cursor_point, - make_events_prompt, - can_collect_data, - git_info, + prompt_for_events, cx, ); cx.spawn(async move |this, cx| { let GatherContextOutput { - body, + mut body, editable_range, + included_events_count, } = gather_task.await?; let done_gathering_context_at = Instant::now(); + let included_events = &events[events.len() - included_events_count..events.len()]; + body.can_collect_data = can_collect_file + && this + .read_with(cx, |this, cx| this.can_collect_events(included_events, cx)) + .unwrap_or(false); + if body.can_collect_data { + body.git_info = git_info; + } + log::debug!( "Events:\n{}\nExcerpt:\n{:?}", body.input_events, @@ -563,10 +578,8 @@ impl Zeta { response: PredictEditsResponse, cx: &mut Context, ) -> Task>> { - use std::future::ready; - - self.request_completion_impl(project, buffer, position, false, cx, |_params| { - ready(Ok((response, None))) + self.request_completion_impl(project, buffer, position, cx, |_params| { + std::future::ready(Ok((response, None))) }) } @@ -575,17 +588,9 @@ impl Zeta { project: &Entity, buffer: &Entity, position: language::Anchor, - can_collect_data: bool, cx: &mut Context, ) -> Task>> { - self.request_completion_impl( - project, - buffer, - position, - can_collect_data, - cx, - Self::perform_predict_edits, - ) + self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits) } pub fn perform_predict_edits( @@ -954,7 +959,58 @@ impl Zeta { new_snapshot } - fn load_data_collection_choices() -> DataCollectionChoice { + fn can_collect_file(&self, file: &Arc, cx: &App) -> bool { + self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx) + } + + fn can_collect_events(&self, events: &[Event], cx: &App) -> bool { + if !self.data_collection_choice.is_enabled() { + return false; + } + let mut last_checked_file = None; + for event in events { + match event { + Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + if let Some(old_file) = old_snapshot.file() + && let Some(new_file) = new_snapshot.file() + { + if let Some(last_checked_file) = last_checked_file + && Arc::ptr_eq(last_checked_file, old_file) + && Arc::ptr_eq(last_checked_file, new_file) + { + continue; + } + if !self.can_collect_file(old_file, cx) { + return false; + } + if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx) + { + return false; + } + last_checked_file = Some(new_file); + } else { + return false; + } + } + } + } + true + } + + fn is_file_open_source(&self, file: &Arc, cx: &App) -> bool { + if !file.is_local() || file.is_private() { + return false; + } + self.license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + } + + fn load_data_collection_choice() -> DataCollectionChoice { let choice = KEY_VALUE_STORE .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) .log_err() @@ -970,6 +1026,17 @@ impl Zeta { None => DataCollectionChoice::NotAnswered, } } + + fn toggle_data_collection_choice(&mut self, cx: &mut Context) { + self.data_collection_choice = self.data_collection_choice.toggle(); + let new_choice = self.data_collection_choice; + db::write_and_log(cx, move || { + KEY_VALUE_STORE.write_kvp( + ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), + new_choice.is_enabled().to_string(), + ) + }); + } } pub struct PerformPredictEditsParams { @@ -1026,48 +1093,19 @@ fn git_info_for_file( pub struct GatherContextOutput { pub body: PredictEditsBody, pub editable_range: Range, + pub included_events_count: usize, } pub fn gather_context( - project: &Entity, full_path_str: String, snapshot: &BufferSnapshot, cursor_point: language::Point, - make_events_prompt: impl FnOnce() -> String + Send + 'static, - can_collect_data: bool, - git_info: Option, + prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, cx: &App, ) -> Task> { - let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local(); - let diagnostic_groups: Vec<(String, serde_json::Value)> = - if can_collect_data && let Some(local_lsp_store) = local_lsp_store { - snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - let diagnostic_group = diagnostic_group.resolve::(snapshot); - let language_server_name = language_server.name().to_string(); - let serialized = serde_json::to_value(diagnostic_group).unwrap(); - Some((language_server_name, serialized)) - }) - .collect::>() - } else { - Vec::new() - }; - cx.background_spawn({ let snapshot = snapshot.clone(); async move { - let diagnostic_groups = if diagnostic_groups.is_empty() - || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS - { - None - } else { - Some(diagnostic_groups) - }; - let input_excerpt = excerpt_for_cursor_position( cursor_point, &full_path_str, @@ -1075,15 +1113,15 @@ pub fn gather_context( MAX_REWRITE_TOKENS, MAX_CONTEXT_TOKENS, ); - let input_events = make_events_prompt(); + let (input_events, included_events_count) = prompt_for_events(); let editable_range = input_excerpt.editable_range.to_offset(&snapshot); let body = PredictEditsBody { input_events, input_excerpt: input_excerpt.prompt, - can_collect_data, - diagnostic_groups, - git_info, + can_collect_data: false, + diagnostic_groups: None, + git_info: None, outline: None, speculated_output: None, }; @@ -1091,18 +1129,19 @@ pub fn gather_context( Ok(GatherContextOutput { body, editable_range, + included_events_count, }) } }) } -fn prompt_for_events(events: &VecDeque, mut remaining_tokens: usize) -> String { +fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) { let mut result = String::new(); - for event in events.iter().rev() { + for (ix, event) in events.iter().rev().enumerate() { let event_string = event.to_prompt(); - let event_tokens = tokens_for_bytes(event_string.len()); + let event_tokens = guess_token_count(event_string.len()); if event_tokens > remaining_tokens { - break; + return (result, ix); } if !result.is_empty() { @@ -1111,7 +1150,7 @@ fn prompt_for_events(events: &VecDeque, mut remaining_tokens: usize) -> S result.insert_str(0, &event_string); remaining_tokens -= event_tokens; } - result + return (result, events.len()); } struct RegisteredBuffer { @@ -1222,6 +1261,7 @@ impl DataCollectionChoice { } } + #[must_use] pub fn toggle(&self) -> DataCollectionChoice { match self { Self::Enabled => Self::Disabled, @@ -1240,79 +1280,6 @@ impl From for DataCollectionChoice { } } -pub struct ProviderDataCollection { - /// When set to None, data collection is not possible in the provider buffer - choice: Option>, - license_detection_watcher: Option>, -} - -impl ProviderDataCollection { - pub fn new(zeta: Entity, buffer: Option>, cx: &mut App) -> Self { - let choice_and_watcher = buffer.and_then(|buffer| { - let file = buffer.read(cx).file()?; - - if !file.is_local() || file.is_private() { - return None; - } - - let zeta = zeta.read(cx); - let choice = zeta.data_collection_choice.clone(); - - let license_detection_watcher = zeta - .license_detection_watchers - .get(&file.worktree_id(cx)) - .cloned()?; - - Some((choice, license_detection_watcher)) - }); - - if let Some((choice, watcher)) = choice_and_watcher { - ProviderDataCollection { - choice: Some(choice), - license_detection_watcher: Some(watcher), - } - } else { - ProviderDataCollection { - choice: None, - license_detection_watcher: None, - } - } - } - - pub fn can_collect_data(&self, cx: &App) -> bool { - self.is_data_collection_enabled(cx) && self.is_project_open_source() - } - - pub fn is_data_collection_enabled(&self, cx: &App) -> bool { - self.choice - .as_ref() - .is_some_and(|choice| choice.read(cx).is_enabled()) - } - - fn is_project_open_source(&self) -> bool { - self.license_detection_watcher - .as_ref() - .is_some_and(|watcher| watcher.is_project_open_source()) - } - - pub fn toggle(&mut self, cx: &mut App) { - if let Some(choice) = self.choice.as_mut() { - let new_choice = choice.update(cx, |choice, _cx| { - let new_choice = choice.toggle(); - *choice = new_choice; - new_choice - }); - - db::write_and_log(cx, move || { - KEY_VALUE_STORE.write_kvp( - ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), - new_choice.is_enabled().to_string(), - ) - }); - } - } -} - async fn llm_token_retry( llm_token: &LlmApiToken, client: &Arc, @@ -1343,24 +1310,23 @@ async fn llm_token_retry( pub struct ZetaEditPredictionProvider { zeta: Entity, + singleton_buffer: Option>, pending_completions: ArrayVec, next_pending_completion_id: usize, current_completion: Option, - /// None if this is entirely disabled for this provider - provider_data_collection: ProviderDataCollection, last_request_timestamp: Instant, } impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - pub fn new(zeta: Entity, provider_data_collection: ProviderDataCollection) -> Self { + pub fn new(zeta: Entity, singleton_buffer: Option>) -> Self { Self { zeta, + singleton_buffer, pending_completions: ArrayVec::new(), next_pending_completion_id: 0, current_completion: None, - provider_data_collection, last_request_timestamp: Instant::now(), } } @@ -1384,21 +1350,29 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } fn data_collection_state(&self, cx: &App) -> DataCollectionState { - let is_project_open_source = self.provider_data_collection.is_project_open_source(); - - if self.provider_data_collection.is_data_collection_enabled(cx) { - DataCollectionState::Enabled { - is_project_open_source, + if let Some(buffer) = &self.singleton_buffer + && let Some(file) = buffer.read(cx).file() + { + let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx); + if self.zeta.read(cx).data_collection_choice.is_enabled() { + DataCollectionState::Enabled { + is_project_open_source, + } + } else { + DataCollectionState::Disabled { + is_project_open_source, + } } } else { - DataCollectionState::Disabled { - is_project_open_source, - } + return DataCollectionState::Disabled { + is_project_open_source: false, + }; } } fn toggle_data_collection(&mut self, cx: &mut App) { - self.provider_data_collection.toggle(cx); + self.zeta + .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx)); } fn usage(&self, cx: &App) -> Option { @@ -1456,7 +1430,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let pending_completion_id = self.next_pending_completion_id; self.next_pending_completion_id += 1; - let can_collect_data = self.provider_data_collection.can_collect_data(cx); let last_request_timestamp = self.last_request_timestamp; let task = cx.spawn(async move |this, cx| { @@ -1469,7 +1442,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, position, can_collect_data, cx) + zeta.request_completion(&project, &buffer, position, cx) }) }); @@ -1638,10 +1611,11 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } } -fn tokens_for_bytes(bytes: usize) -> usize { - /// Typical number of string bytes per token for the purposes of limiting model input. This is - /// intentionally low to err on the side of underestimating limits. - const BYTES_PER_TOKEN_GUESS: usize = 3; +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +const BYTES_PER_TOKEN_GUESS: usize = 3; + +fn guess_token_count(bytes: usize) -> usize { bytes / BYTES_PER_TOKEN_GUESS } @@ -1654,11 +1628,15 @@ mod tests { use http_client::FakeHttpClient; use indoc::indoc; use language::Point; + use parking_lot::Mutex; + use serde_json::json; use settings::SettingsStore; use util::path; use super::*; + const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); + #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -1778,77 +1756,65 @@ mod tests { #[gpui::test] async fn test_clean_up_diff(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - }); + init_test(cx); - let edits = edits_for_prediction( - indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word.len()..word.len(); - } - "}, + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word.len()..word.len(); + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + + <|editable_region_end|> + "}, + cx, + ) + .await, indoc! {" - <|editable_region_start|> fn main() { let word_1 = \"lorem\"; let range = word_1.len()..word_1.len(); } - - <|editable_region_end|> "}, - cx, - ) - .await; - assert_eq!( - edits, - [ - (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()), - (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()), - ] ); - let edits = edits_for_prediction( - indoc! {" - fn main() { - let story = \"the quick\" - } - "}, + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let story = \"the quick\" + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + + <|editable_region_end|> + "}, + cx, + ) + .await, indoc! {" - <|editable_region_start|> fn main() { let story = \"the quick brown fox jumps over the lazy dog\"; } - - <|editable_region_end|> "}, - cx, - ) - .await; - assert_eq!( - edits, - [ - ( - Point::new(1, 26)..Point::new(1, 26), - " brown fox jumps over the lazy dog".to_string() - ), - (Point::new(1, 27)..Point::new(1, 27), ";".to_string()), - ] ); } #[gpui::test] async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - }); + init_test(cx); let buffer_content = "lorem\n"; let completion_response = indoc! {" @@ -1860,98 +1826,404 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |req| async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } + assert_eq!( + apply_edit_prediction(buffer_content, completion_response, cx).await, + "lorem\nipsum" + ); + } + + #[gpui::test] + async fn test_can_collect_data(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/src/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled }); - let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - RefreshLlmTokenListener::register(client.clone(), cx); + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Disabled }); - // Construct the fake server to authenticate. - let _server = FakeServer::for_client(42, &client, cx).await; + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let buffer = cx.new(|_cx| { + Buffer::remote( + language::BufferId::new(1).unwrap(), + 1, + language::Capability::ReadWrite, + "fn main() {\n println!(\"Hello\");\n}", + ) + }); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "LICENSE": BSD_0_TXT, + ".env": "SECRET_KEY=secret" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/.env", cx) + }) + .await + .unwrap(); - let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); - let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, cursor, false, cx) + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let buffer = cx.new(|cx| Buffer::local("", cx)); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/main.rs", cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/open_source_worktree"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + ) + .await; + fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/open_source_worktree").as_ref(), + path!("/closed_source_worktree").as_ref(), + ], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled }); - let completion = completion_task.await.unwrap().unwrap(); + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + let closed_source_file = project + .update(cx, |project, cx| { + let worktree2 = project + .worktree_for_root_name("closed_source_worktree", cx) + .unwrap(); + worktree2.update(cx, |worktree2, cx| { + worktree2.load_file(Path::new("main.rs"), cx) + }) + }) + .await + .unwrap() + .file; + buffer.update(cx, |buffer, cx| { - buffer.edit(completion.edits.iter().cloned(), None, cx) + buffer.file_updated(closed_source_file, cx); }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; assert_eq!( - buffer.read_with(cx, |buffer, _| buffer.text()), - "lorem\nipsum" + captured_request.lock().clone().unwrap().can_collect_data, + false ); } - async fn edits_for_prediction( + #[gpui::test] + async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/worktree1"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + ) + .await; + fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree1/main.rs"), cx) + }) + .await + .unwrap(); + let private_buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree2/file.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + // this has a side effect of registering the buffer to watch for edits + run_edit_prediction(&private_buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + private_buffer.update(cx, |private_buffer, cx| { + private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + // make an edit that uses too many bytes, causing private_buffer edit to not be able to be + // included + buffer.update(cx, |buffer, cx| { + buffer.edit( + [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))], + None, + cx, + ); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + client::init_settings(cx); + Project::init_settings(cx); + }); + } + + async fn apply_edit_prediction( buffer_content: &str, completion_response: &str, cx: &mut TestAppContext, - ) -> Vec<(Range, String)> { - let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |req| { - let completion = completion_response.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), + ) -> String { + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); + let (zeta, _, response) = make_test_zeta(&project, cx).await; + *response.lock() = completion_response.to_string(); + let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await; + buffer.update(cx, |buffer, cx| { + buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) + }); + buffer.read_with(cx, |buffer, _| buffer.text()) + } + + async fn run_edit_prediction( + buffer: &Entity, + project: &Entity, + zeta: &Entity, + cx: &mut TestAppContext, + ) -> EditPrediction { + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx)); + cx.background_executor.run_until_parked(); + let completion_task = zeta.update(cx, |zeta, cx| { + zeta.request_completion(&project, buffer, cursor, cx) + }); + completion_task.await.unwrap().unwrap() + } + + async fn make_test_zeta( + project: &Entity, + cx: &mut TestAppContext, + ) -> ( + Entity, + Arc>>, + Arc>, + ) { + let default_response = indoc! {" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + hello world + <|editable_region_end|> + ```" + }; + let captured_request: Arc>> = Arc::new(Mutex::new(None)); + let completion_response: Arc> = + Arc::new(Mutex::new(default_response.to_string())); + let http_client = FakeHttpClient::create({ + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + move |req| { + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => { + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + (&Method::POST, "/predict_edits/v2") => { + let mut request_body = String::new(); + req.into_body().read_to_string(&mut request_body).await?; + *captured_request.lock() = + Some(serde_json::from_str(&request_body).unwrap()); + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion_response.lock().clone(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } } }); @@ -1960,25 +2232,23 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); - let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); - let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, cursor, false, cx) + let zeta = cx.new(|cx| { + let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx); + + let worktrees = project.read(cx).worktrees(cx).collect::>(); + for worktree in worktrees { + let worktree_id = worktree.read(cx).id(); + zeta.license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); + } + + zeta }); - let completion = completion_task.await.unwrap().unwrap(); - completion - .edits - .iter() - .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone())) - .collect::>() + (zeta, captured_request, completion_response) } fn to_completion_edits( diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index e66eeed80920a0c31c5c06e119e17d418fbc294c..e7cec26b19358056cee4c8e253c54c0b2c794b33 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -189,30 +189,17 @@ async fn get_context( Some(events) => events.read_to_string().await?, None => String::new(), }; - // Enable gathering extra data not currently needed for edit predictions - let can_collect_data = true; - let git_info = None; - let mut gather_context_output = cx - .update(|cx| { - gather_context( - &project, - full_path_str, - &snapshot, - clipped_cursor, - move || events, - can_collect_data, - git_info, - cx, - ) - })? - .await; - - // Disable data collection for these requests, as this is currently just used for evals - if let Ok(gather_context_output) = gather_context_output.as_mut() { - gather_context_output.body.can_collect_data = false - } - - gather_context_output + let prompt_for_events = move || (events, 0); + cx.update(|cx| { + gather_context( + full_path_str, + &snapshot, + clipped_cursor, + prompt_for_events, + cx, + ) + })? + .await } pub async fn open_buffer_with_language_server(