diff --git a/Cargo.lock b/Cargo.lock index f4c94f8078b1ab392ed1a50e15c71dab1921f0a3..dbcea05ea9bc52288defc8c299d82eb508337544 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20841,6 +20841,7 @@ dependencies = [ "language_model", "log", "menu", + "parking_lot", "postage", "project", "rand 0.9.1", diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 4f009ccb0b1197f11b034ac48b89dd37b6f41278..ae26427fc6547079b163235f5d1c3df26a489795 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -8,7 +8,7 @@ use settings::SettingsStore; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; -use zeta::{ProviderDataCollection, ZetaEditPredictionProvider}; +use zeta::ZetaEditPredictionProvider; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { let editors: Rc, AnyWindowHandle>>> = Rc::default(); @@ -214,11 +214,8 @@ fn assign_edit_prediction_provider( }); } - let data_collection = - ProviderDataCollection::new(zeta.clone(), singleton_buffer, cx); - let provider = - cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, data_collection)); + cx.new(|_| zeta::ZetaEditPredictionProvider::new(zeta, singleton_buffer)); editor.set_edit_prediction_provider(Some(provider), window, cx); } diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index a9c2a7619f4db22e51c014672aa2100b30a2539a..09bcfa7f542ce9c01802c9cebc11dfc9a8da2542 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -72,6 +72,7 @@ gpui = { workspace = true, features = ["test-support"] } http_client = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } +parking_lot.workspace = true reqwest_client = { workspace = true, features = ["test-support"] } rpc = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta/src/input_excerpt.rs b/crates/zeta/src/input_excerpt.rs index dd1bbed1d72e8668e9ed55c9b66b911addfcdd43..06bff5b1bea0f099b2ccd98605ac5de5bb5e6360 100644 --- a/crates/zeta/src/input_excerpt.rs +++ b/crates/zeta/src/input_excerpt.rs @@ -1,6 +1,6 @@ use crate::{ CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER, - tokens_for_bytes, + guess_token_count, }; use language::{BufferSnapshot, Point}; use std::{fmt::Write, ops::Range}; @@ -22,7 +22,7 @@ pub fn excerpt_for_cursor_position( let mut remaining_edit_tokens = editable_region_token_limit; while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { - let parent_tokens = tokens_for_bytes(parent.byte_range().len()); + let parent_tokens = guess_token_count(parent.byte_range().len()); let parent_point_range = Point::new( parent.start_position().row as u32, parent.start_position().column as u32, @@ -99,7 +99,7 @@ fn expand_range( if remaining_tokens > 0 && expanded_range.start.row > 0 { expanded_range.start.row -= 1; let line_tokens = - tokens_for_bytes(snapshot.line_len(expanded_range.start.row) as usize); + guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); remaining_tokens = remaining_tokens.saturating_sub(line_tokens); expanded = true; } @@ -107,7 +107,7 @@ fn expand_range( if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { expanded_range.end.row += 1; expanded_range.end.column = snapshot.line_len(expanded_range.end.row); - let line_tokens = tokens_for_bytes(expanded_range.end.column as usize); + let line_tokens = guess_token_count(expanded_range.end.column as usize); remaining_tokens = remaining_tokens.saturating_sub(line_tokens); expanded = true; } diff --git a/crates/zeta/src/license_detection.rs b/crates/zeta/src/license_detection.rs index 5f207a44e8bd2028e6a2b416e978f101cfe5bd57..e06e1577a66cc160efa00213b80c6ca407f7be85 100644 --- a/crates/zeta/src/license_detection.rs +++ b/crates/zeta/src/license_detection.rs @@ -358,7 +358,6 @@ impl LicenseDetectionWatcher { #[cfg(test)] mod tests { - use fs::FakeFs; use gpui::TestAppContext; use serde_json::json; diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 3851d16755783209fd9da4f468a494779a7d9fe7..dfcf98f025c2e020d6545efca64d4ab12579e370 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -29,7 +29,7 @@ use gpui::{ use http_client::{AsyncBody, HttpClient, Method, Request, Response}; use input_excerpt::excerpt_for_cursor_position; use language::{ - Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff, + Anchor, Buffer, BufferSnapshot, EditPreview, File, OffsetRangeExt, ToOffset, ToPoint, text_diff, }; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::{Project, ProjectPath}; @@ -65,7 +65,6 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch const MAX_CONTEXT_TOKENS: usize = 150; const MAX_REWRITE_TOKENS: usize = 350; const MAX_EVENT_TOKENS: usize = 500; -const MAX_DIAGNOSTIC_GROUPS: usize = 10; /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; @@ -216,7 +215,7 @@ pub struct Zeta { client: Arc, shown_completions: VecDeque, rated_completions: HashSet, - data_collection_choice: Entity, + data_collection_choice: DataCollectionChoice, llm_token: LlmApiToken, _llm_token_subscription: Subscription, /// Whether an update to a newer version of Zed is required to continue using Zeta. @@ -271,10 +270,7 @@ impl Zeta { fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - - let data_collection_choice = Self::load_data_collection_choices(); - let data_collection_choice = cx.new(|_| data_collection_choice); - + let data_collection_choice = Self::load_data_collection_choice(); Self { projects: HashMap::default(), client, @@ -408,7 +404,6 @@ impl Zeta { project: &Entity, buffer: &Entity, cursor: language::Anchor, - can_collect_data: bool, cx: &mut Context, perform_predict_edits: F, ) -> Task>> @@ -422,15 +417,25 @@ impl Zeta { let buffer_snapshotted_at = Instant::now(); let snapshot = self.report_changes_for_buffer(&buffer, project, cx); let zeta = cx.entity(); - let events = self.get_or_init_zeta_project(project, cx).events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) { - git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + let zeta_project = self.get_or_init_zeta_project(project, cx); + let mut events = Vec::with_capacity(zeta_project.events.len()); + events.extend(zeta_project.events.iter().cloned()); + let events = Arc::new(events); + + let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { + let can_collect_file = self.can_collect_file(file, cx); + let git_info = if can_collect_file { + git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + } else { + None + }; + (git_info, can_collect_file) } else { - None + (None, false) }; let full_path: Arc = snapshot @@ -440,25 +445,35 @@ impl Zeta { let full_path_str = full_path.to_string_lossy().to_string(); let cursor_point = cursor.to_point(&snapshot); let cursor_offset = cursor_point.to_offset(&snapshot); - let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS); + let prompt_for_events = { + let events = events.clone(); + move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS) + }; let gather_task = gather_context( - project, full_path_str, &snapshot, cursor_point, - make_events_prompt, - can_collect_data, - git_info, + prompt_for_events, cx, ); cx.spawn(async move |this, cx| { let GatherContextOutput { - body, + mut body, editable_range, + included_events_count, } = gather_task.await?; let done_gathering_context_at = Instant::now(); + let included_events = &events[events.len() - included_events_count..events.len()]; + body.can_collect_data = can_collect_file + && this + .read_with(cx, |this, cx| this.can_collect_events(included_events, cx)) + .unwrap_or(false); + if body.can_collect_data { + body.git_info = git_info; + } + log::debug!( "Events:\n{}\nExcerpt:\n{:?}", body.input_events, @@ -563,10 +578,8 @@ impl Zeta { response: PredictEditsResponse, cx: &mut Context, ) -> Task>> { - use std::future::ready; - - self.request_completion_impl(project, buffer, position, false, cx, |_params| { - ready(Ok((response, None))) + self.request_completion_impl(project, buffer, position, cx, |_params| { + std::future::ready(Ok((response, None))) }) } @@ -575,17 +588,9 @@ impl Zeta { project: &Entity, buffer: &Entity, position: language::Anchor, - can_collect_data: bool, cx: &mut Context, ) -> Task>> { - self.request_completion_impl( - project, - buffer, - position, - can_collect_data, - cx, - Self::perform_predict_edits, - ) + self.request_completion_impl(project, buffer, position, cx, Self::perform_predict_edits) } pub fn perform_predict_edits( @@ -954,7 +959,58 @@ impl Zeta { new_snapshot } - fn load_data_collection_choices() -> DataCollectionChoice { + fn can_collect_file(&self, file: &Arc, cx: &App) -> bool { + self.data_collection_choice.is_enabled() && self.is_file_open_source(file, cx) + } + + fn can_collect_events(&self, events: &[Event], cx: &App) -> bool { + if !self.data_collection_choice.is_enabled() { + return false; + } + let mut last_checked_file = None; + for event in events { + match event { + Event::BufferChange { + old_snapshot, + new_snapshot, + .. + } => { + if let Some(old_file) = old_snapshot.file() + && let Some(new_file) = new_snapshot.file() + { + if let Some(last_checked_file) = last_checked_file + && Arc::ptr_eq(last_checked_file, old_file) + && Arc::ptr_eq(last_checked_file, new_file) + { + continue; + } + if !self.can_collect_file(old_file, cx) { + return false; + } + if !Arc::ptr_eq(old_file, new_file) && !self.can_collect_file(new_file, cx) + { + return false; + } + last_checked_file = Some(new_file); + } else { + return false; + } + } + } + } + true + } + + fn is_file_open_source(&self, file: &Arc, cx: &App) -> bool { + if !file.is_local() || file.is_private() { + return false; + } + self.license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + } + + fn load_data_collection_choice() -> DataCollectionChoice { let choice = KEY_VALUE_STORE .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) .log_err() @@ -970,6 +1026,17 @@ impl Zeta { None => DataCollectionChoice::NotAnswered, } } + + fn toggle_data_collection_choice(&mut self, cx: &mut Context) { + self.data_collection_choice = self.data_collection_choice.toggle(); + let new_choice = self.data_collection_choice; + db::write_and_log(cx, move || { + KEY_VALUE_STORE.write_kvp( + ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), + new_choice.is_enabled().to_string(), + ) + }); + } } pub struct PerformPredictEditsParams { @@ -1026,48 +1093,19 @@ fn git_info_for_file( pub struct GatherContextOutput { pub body: PredictEditsBody, pub editable_range: Range, + pub included_events_count: usize, } pub fn gather_context( - project: &Entity, full_path_str: String, snapshot: &BufferSnapshot, cursor_point: language::Point, - make_events_prompt: impl FnOnce() -> String + Send + 'static, - can_collect_data: bool, - git_info: Option, + prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static, cx: &App, ) -> Task> { - let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local(); - let diagnostic_groups: Vec<(String, serde_json::Value)> = - if can_collect_data && let Some(local_lsp_store) = local_lsp_store { - snapshot - .diagnostic_groups(None) - .into_iter() - .filter_map(|(language_server_id, diagnostic_group)| { - let language_server = - local_lsp_store.running_language_server_for_id(language_server_id)?; - let diagnostic_group = diagnostic_group.resolve::(snapshot); - let language_server_name = language_server.name().to_string(); - let serialized = serde_json::to_value(diagnostic_group).unwrap(); - Some((language_server_name, serialized)) - }) - .collect::>() - } else { - Vec::new() - }; - cx.background_spawn({ let snapshot = snapshot.clone(); async move { - let diagnostic_groups = if diagnostic_groups.is_empty() - || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS - { - None - } else { - Some(diagnostic_groups) - }; - let input_excerpt = excerpt_for_cursor_position( cursor_point, &full_path_str, @@ -1075,15 +1113,15 @@ pub fn gather_context( MAX_REWRITE_TOKENS, MAX_CONTEXT_TOKENS, ); - let input_events = make_events_prompt(); + let (input_events, included_events_count) = prompt_for_events(); let editable_range = input_excerpt.editable_range.to_offset(&snapshot); let body = PredictEditsBody { input_events, input_excerpt: input_excerpt.prompt, - can_collect_data, - diagnostic_groups, - git_info, + can_collect_data: false, + diagnostic_groups: None, + git_info: None, outline: None, speculated_output: None, }; @@ -1091,18 +1129,19 @@ pub fn gather_context( Ok(GatherContextOutput { body, editable_range, + included_events_count, }) } }) } -fn prompt_for_events(events: &VecDeque, mut remaining_tokens: usize) -> String { +fn prompt_for_events_impl(events: &[Event], mut remaining_tokens: usize) -> (String, usize) { let mut result = String::new(); - for event in events.iter().rev() { + for (ix, event) in events.iter().rev().enumerate() { let event_string = event.to_prompt(); - let event_tokens = tokens_for_bytes(event_string.len()); + let event_tokens = guess_token_count(event_string.len()); if event_tokens > remaining_tokens { - break; + return (result, ix); } if !result.is_empty() { @@ -1111,7 +1150,7 @@ fn prompt_for_events(events: &VecDeque, mut remaining_tokens: usize) -> S result.insert_str(0, &event_string); remaining_tokens -= event_tokens; } - result + return (result, events.len()); } struct RegisteredBuffer { @@ -1222,6 +1261,7 @@ impl DataCollectionChoice { } } + #[must_use] pub fn toggle(&self) -> DataCollectionChoice { match self { Self::Enabled => Self::Disabled, @@ -1240,79 +1280,6 @@ impl From for DataCollectionChoice { } } -pub struct ProviderDataCollection { - /// When set to None, data collection is not possible in the provider buffer - choice: Option>, - license_detection_watcher: Option>, -} - -impl ProviderDataCollection { - pub fn new(zeta: Entity, buffer: Option>, cx: &mut App) -> Self { - let choice_and_watcher = buffer.and_then(|buffer| { - let file = buffer.read(cx).file()?; - - if !file.is_local() || file.is_private() { - return None; - } - - let zeta = zeta.read(cx); - let choice = zeta.data_collection_choice.clone(); - - let license_detection_watcher = zeta - .license_detection_watchers - .get(&file.worktree_id(cx)) - .cloned()?; - - Some((choice, license_detection_watcher)) - }); - - if let Some((choice, watcher)) = choice_and_watcher { - ProviderDataCollection { - choice: Some(choice), - license_detection_watcher: Some(watcher), - } - } else { - ProviderDataCollection { - choice: None, - license_detection_watcher: None, - } - } - } - - pub fn can_collect_data(&self, cx: &App) -> bool { - self.is_data_collection_enabled(cx) && self.is_project_open_source() - } - - pub fn is_data_collection_enabled(&self, cx: &App) -> bool { - self.choice - .as_ref() - .is_some_and(|choice| choice.read(cx).is_enabled()) - } - - fn is_project_open_source(&self) -> bool { - self.license_detection_watcher - .as_ref() - .is_some_and(|watcher| watcher.is_project_open_source()) - } - - pub fn toggle(&mut self, cx: &mut App) { - if let Some(choice) = self.choice.as_mut() { - let new_choice = choice.update(cx, |choice, _cx| { - let new_choice = choice.toggle(); - *choice = new_choice; - new_choice - }); - - db::write_and_log(cx, move || { - KEY_VALUE_STORE.write_kvp( - ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), - new_choice.is_enabled().to_string(), - ) - }); - } - } -} - async fn llm_token_retry( llm_token: &LlmApiToken, client: &Arc, @@ -1343,24 +1310,23 @@ async fn llm_token_retry( pub struct ZetaEditPredictionProvider { zeta: Entity, + singleton_buffer: Option>, pending_completions: ArrayVec, next_pending_completion_id: usize, current_completion: Option, - /// None if this is entirely disabled for this provider - provider_data_collection: ProviderDataCollection, last_request_timestamp: Instant, } impl ZetaEditPredictionProvider { pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - pub fn new(zeta: Entity, provider_data_collection: ProviderDataCollection) -> Self { + pub fn new(zeta: Entity, singleton_buffer: Option>) -> Self { Self { zeta, + singleton_buffer, pending_completions: ArrayVec::new(), next_pending_completion_id: 0, current_completion: None, - provider_data_collection, last_request_timestamp: Instant::now(), } } @@ -1384,21 +1350,29 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } fn data_collection_state(&self, cx: &App) -> DataCollectionState { - let is_project_open_source = self.provider_data_collection.is_project_open_source(); - - if self.provider_data_collection.is_data_collection_enabled(cx) { - DataCollectionState::Enabled { - is_project_open_source, + if let Some(buffer) = &self.singleton_buffer + && let Some(file) = buffer.read(cx).file() + { + let is_project_open_source = self.zeta.read(cx).is_file_open_source(file, cx); + if self.zeta.read(cx).data_collection_choice.is_enabled() { + DataCollectionState::Enabled { + is_project_open_source, + } + } else { + DataCollectionState::Disabled { + is_project_open_source, + } } } else { - DataCollectionState::Disabled { - is_project_open_source, - } + return DataCollectionState::Disabled { + is_project_open_source: false, + }; } } fn toggle_data_collection(&mut self, cx: &mut App) { - self.provider_data_collection.toggle(cx); + self.zeta + .update(cx, |zeta, cx| zeta.toggle_data_collection_choice(cx)); } fn usage(&self, cx: &App) -> Option { @@ -1456,7 +1430,6 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let pending_completion_id = self.next_pending_completion_id; self.next_pending_completion_id += 1; - let can_collect_data = self.provider_data_collection.can_collect_data(cx); let last_request_timestamp = self.last_request_timestamp; let task = cx.spawn(async move |this, cx| { @@ -1469,7 +1442,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, position, can_collect_data, cx) + zeta.request_completion(&project, &buffer, position, cx) }) }); @@ -1638,10 +1611,11 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { } } -fn tokens_for_bytes(bytes: usize) -> usize { - /// Typical number of string bytes per token for the purposes of limiting model input. This is - /// intentionally low to err on the side of underestimating limits. - const BYTES_PER_TOKEN_GUESS: usize = 3; +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. +const BYTES_PER_TOKEN_GUESS: usize = 3; + +fn guess_token_count(bytes: usize) -> usize { bytes / BYTES_PER_TOKEN_GUESS } @@ -1654,11 +1628,15 @@ mod tests { use http_client::FakeHttpClient; use indoc::indoc; use language::Point; + use parking_lot::Mutex; + use serde_json::json; use settings::SettingsStore; use util::path; use super::*; + const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); + #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); @@ -1778,77 +1756,65 @@ mod tests { #[gpui::test] async fn test_clean_up_diff(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - }); + init_test(cx); - let edits = edits_for_prediction( - indoc! {" - fn main() { - let word_1 = \"lorem\"; - let range = word.len()..word.len(); - } - "}, + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word.len()..word.len(); + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + + <|editable_region_end|> + "}, + cx, + ) + .await, indoc! {" - <|editable_region_start|> fn main() { let word_1 = \"lorem\"; let range = word_1.len()..word_1.len(); } - - <|editable_region_end|> "}, - cx, - ) - .await; - assert_eq!( - edits, - [ - (Point::new(2, 20)..Point::new(2, 20), "_1".to_string()), - (Point::new(2, 32)..Point::new(2, 32), "_1".to_string()), - ] ); - let edits = edits_for_prediction( - indoc! {" - fn main() { - let story = \"the quick\" - } - "}, + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let story = \"the quick\" + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + + <|editable_region_end|> + "}, + cx, + ) + .await, indoc! {" - <|editable_region_start|> fn main() { let story = \"the quick brown fox jumps over the lazy dog\"; } - - <|editable_region_end|> "}, - cx, - ) - .await; - assert_eq!( - edits, - [ - ( - Point::new(1, 26)..Point::new(1, 26), - " brown fox jumps over the lazy dog".to_string() - ), - (Point::new(1, 27)..Point::new(1, 27), ";".to_string()), - ] ); } #[gpui::test] async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - client::init_settings(cx); - Project::init_settings(cx); - }); + init_test(cx); let buffer_content = "lorem\n"; let completion_response = indoc! {" @@ -1860,98 +1826,404 @@ mod tests { <|editable_region_end|> ```"}; - let http_client = FakeHttpClient::create(move |req| async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::parse_str("7e86480f-3536-4d2c-9334-8213e3445d45") - .unwrap(), - output_excerpt: completion_response.to_string(), - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), - } + assert_eq!( + apply_edit_prediction(buffer_content, completion_response, cx).await, + "lorem\nipsum" + ); + } + + #[gpui::test] + async fn test_can_collect_data(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/src/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled }); - let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); - cx.update(|cx| { - RefreshLlmTokenListener::register(client.clone(), cx); + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Disabled }); - // Construct the fake server to authenticate. - let _server = FakeServer::for_client(42, &client, cx).await; + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let buffer = cx.new(|_cx| { + Buffer::remote( + language::BufferId::new(1).unwrap(), + 1, + language::Capability::ReadWrite, + "fn main() {\n println!(\"Hello\");\n}", + ) + }); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "LICENSE": BSD_0_TXT, + ".env": "SECRET_KEY=secret" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/.env", cx) + }) + .await + .unwrap(); - let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); - let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, cursor, false, cx) + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let buffer = cx.new(|cx| Buffer::local("", cx)); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/main.rs", cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + } + + #[gpui::test] + async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/open_source_worktree"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + ) + .await; + fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/open_source_worktree").as_ref(), + path!("/closed_source_worktree").as_ref(), + ], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled }); - let completion = completion_task.await.unwrap().unwrap(); + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + let closed_source_file = project + .update(cx, |project, cx| { + let worktree2 = project + .worktree_for_root_name("closed_source_worktree", cx) + .unwrap(); + worktree2.update(cx, |worktree2, cx| { + worktree2.load_file(Path::new("main.rs"), cx) + }) + }) + .await + .unwrap() + .file; + buffer.update(cx, |buffer, cx| { - buffer.edit(completion.edits.iter().cloned(), None, cx) + buffer.file_updated(closed_source_file, cx); }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; assert_eq!( - buffer.read_with(cx, |buffer, _| buffer.text()), - "lorem\nipsum" + captured_request.lock().clone().unwrap().can_collect_data, + false ); } - async fn edits_for_prediction( + #[gpui::test] + async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/worktree1"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + ) + .await; + fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree1/main.rs"), cx) + }) + .await + .unwrap(); + let private_buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree2/file.rs"), cx) + }) + .await + .unwrap(); + + let (zeta, captured_request, _) = make_test_zeta(&project, cx).await; + zeta.update(cx, |zeta, _cx| { + zeta.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + // this has a side effect of registering the buffer to watch for edits + run_edit_prediction(&private_buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + private_buffer.update(cx, |private_buffer, cx| { + private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + // make an edit that uses too many bytes, causing private_buffer edit to not be able to be + // included + buffer.update(cx, |buffer, cx| { + buffer.edit( + [(0..0, " ".repeat(MAX_EVENT_TOKENS * BYTES_PER_TOKEN_GUESS))], + None, + cx, + ); + }); + + run_edit_prediction(&buffer, &project, &zeta, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + client::init_settings(cx); + Project::init_settings(cx); + }); + } + + async fn apply_edit_prediction( buffer_content: &str, completion_response: &str, cx: &mut TestAppContext, - ) -> Vec<(Range, String)> { - let completion_response = completion_response.to_string(); - let http_client = FakeHttpClient::create(move |req| { - let completion = completion_response.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&CreateLlmTokenResponse { - token: LlmToken("the-llm-token".to_string()), - }) - .unwrap() - .into(), - ) - .unwrap()), - (&Method::POST, "/predict_edits/v2") => Ok(http_client::Response::builder() - .status(200) - .body( - serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), - output_excerpt: completion, - }) - .unwrap() - .into(), - ) - .unwrap()), - _ => Ok(http_client::Response::builder() - .status(404) - .body("Not Found".into()) - .unwrap()), + ) -> String { + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); + let (zeta, _, response) = make_test_zeta(&project, cx).await; + *response.lock() = completion_response.to_string(); + let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await; + buffer.update(cx, |buffer, cx| { + buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) + }); + buffer.read_with(cx, |buffer, _| buffer.text()) + } + + async fn run_edit_prediction( + buffer: &Entity, + project: &Entity, + zeta: &Entity, + cx: &mut TestAppContext, + ) -> EditPrediction { + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx)); + cx.background_executor.run_until_parked(); + let completion_task = zeta.update(cx, |zeta, cx| { + zeta.request_completion(&project, buffer, cursor, cx) + }); + completion_task.await.unwrap().unwrap() + } + + async fn make_test_zeta( + project: &Entity, + cx: &mut TestAppContext, + ) -> ( + Entity, + Arc>>, + Arc>, + ) { + let default_response = indoc! {" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + hello world + <|editable_region_end|> + ```" + }; + let captured_request: Arc>> = Arc::new(Mutex::new(None)); + let completion_response: Arc> = + Arc::new(Mutex::new(default_response.to_string())); + let http_client = FakeHttpClient::create({ + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + move |req| { + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => { + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + (&Method::POST, "/predict_edits/v2") => { + let mut request_body = String::new(); + req.into_body().read_to_string(&mut request_body).await?; + *captured_request.lock() = + Some(serde_json::from_str(&request_body).unwrap()); + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: Uuid::new_v4(), + output_excerpt: completion_response.lock().clone(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } } } }); @@ -1960,25 +2232,23 @@ mod tests { cx.update(|cx| { RefreshLlmTokenListener::register(client.clone(), cx); }); - // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); - let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); - let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); - let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&project, &buffer, cursor, false, cx) + let zeta = cx.new(|cx| { + let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx); + + let worktrees = project.read(cx).worktrees(cx).collect::>(); + for worktree in worktrees { + let worktree_id = worktree.read(cx).id(); + zeta.license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); + } + + zeta }); - let completion = completion_task.await.unwrap().unwrap(); - completion - .edits - .iter() - .map(|(old_range, new_text)| (old_range.to_point(&snapshot), new_text.clone())) - .collect::>() + (zeta, captured_request, completion_response) } fn to_completion_edits( diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index e66eeed80920a0c31c5c06e119e17d418fbc294c..e7cec26b19358056cee4c8e253c54c0b2c794b33 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -189,30 +189,17 @@ async fn get_context( Some(events) => events.read_to_string().await?, None => String::new(), }; - // Enable gathering extra data not currently needed for edit predictions - let can_collect_data = true; - let git_info = None; - let mut gather_context_output = cx - .update(|cx| { - gather_context( - &project, - full_path_str, - &snapshot, - clipped_cursor, - move || events, - can_collect_data, - git_info, - cx, - ) - })? - .await; - - // Disable data collection for these requests, as this is currently just used for evals - if let Ok(gather_context_output) = gather_context_output.as_mut() { - gather_context_output.body.can_collect_data = false - } - - gather_context_output + let prompt_for_events = move || (events, 0); + cx.update(|cx| { + gather_context( + full_path_str, + &snapshot, + clipped_cursor, + prompt_for_events, + cx, + ) + })? + .await } pub async fn open_buffer_with_language_server(