diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 7b8b98018e6d6c608574ab81e912e8a98e363046..4f009ccb0b1197f11b034ac48b89dd37b6f41278 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -207,9 +207,10 @@ fn assign_edit_prediction_provider( if let Some(buffer) = &singleton_buffer && buffer.read(cx).file().is_some() + && let Some(project) = editor.project() { zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, cx); + zeta.register_buffer(buffer, project, cx); }); } diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index e0cfd23dd26cd7ea49181b5aabc16f00f4fd826a..3851d16755783209fd9da4f468a494779a7d9fe7 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -35,12 +35,13 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; +use std::collections::hash_map; +use std::mem; use std::str::FromStr; use std::{ cmp, fmt::Write, future::Future, - mem, ops::Range, path::Path, rc::Rc, @@ -211,9 +212,8 @@ impl std::fmt::Debug for EditPrediction { } pub struct Zeta { + projects: HashMap, client: Arc, - events: VecDeque, - registered_buffers: HashMap, shown_completions: VecDeque, rated_completions: HashSet, data_collection_choice: Entity, @@ -225,6 +225,11 @@ pub struct Zeta { license_detection_watchers: HashMap>, } +struct ZetaProject { + events: VecDeque, + registered_buffers: HashMap, +} + impl Zeta { pub fn global(cx: &mut App) -> Option> { cx.try_global::().map(|global| global.0.clone()) @@ -255,7 +260,9 @@ impl Zeta { } pub fn clear_history(&mut self) { - self.events.clear(); + for zeta_project in self.projects.values_mut() { + zeta_project.events.clear(); + } } pub fn usage(&self, cx: &App) -> Option { @@ -269,11 +276,10 @@ impl Zeta { let data_collection_choice = cx.new(|_| data_collection_choice); Self { + projects: HashMap::default(), client, - events: VecDeque::new(), shown_completions: VecDeque::new(), rated_completions: HashSet::default(), - registered_buffers: HashMap::default(), data_collection_choice, llm_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( @@ -294,12 +300,35 @@ impl Zeta { } } - fn push_event(&mut self, event: Event) { + fn get_or_init_zeta_project( + &mut self, + project: &Entity, + cx: &mut Context, + ) -> &mut ZetaProject { + let project_id = project.entity_id(); + match self.projects.entry(project_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + cx.observe_release(project, move |this, _, _cx| { + this.projects.remove(&project_id); + }) + .detach(); + entry.insert(ZetaProject { + events: VecDeque::with_capacity(MAX_EVENT_COUNT), + registered_buffers: HashMap::default(), + }) + } + } + } + + fn push_event(zeta_project: &mut ZetaProject, event: Event) { + let events = &mut zeta_project.events; + if let Some(Event::BufferChange { new_snapshot: last_new_snapshot, timestamp: last_timestamp, .. - }) = self.events.back_mut() + }) = events.back_mut() { // Coalesce edits for the same buffer when they happen one after the other. let Event::BufferChange { @@ -318,50 +347,65 @@ impl Zeta { } } - self.events.push_back(event); - if self.events.len() >= MAX_EVENT_COUNT { + if events.len() >= MAX_EVENT_COUNT { // These are halved instead of popping to improve prompt caching. - self.events.drain(..MAX_EVENT_COUNT / 2); + events.drain(..MAX_EVENT_COUNT / 2); } - } - - pub fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { - let buffer_id = buffer.entity_id(); - let weak_buffer = buffer.downgrade(); - - if let std::collections::hash_map::Entry::Vacant(entry) = - self.registered_buffers.entry(buffer_id) - { - let snapshot = buffer.read(cx).snapshot(); - entry.insert(RegisteredBuffer { - snapshot, - _subscriptions: [ - cx.subscribe(buffer, move |this, buffer, event, cx| { - this.handle_buffer_event(buffer, event, cx); - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - this.registered_buffers.remove(&weak_buffer.entity_id()); - }), - ], - }); - }; + events.push_back(event); } - fn handle_buffer_event( + pub fn register_buffer( &mut self, - buffer: Entity, - event: &language::BufferEvent, + buffer: &Entity, + project: &Entity, cx: &mut Context, ) { - if let language::BufferEvent::Edited = event { - self.report_changes_for_buffer(&buffer, cx); + let zeta_project = self.get_or_init_zeta_project(project, cx); + Self::register_buffer_impl(zeta_project, buffer, project, cx); + } + + fn register_buffer_impl<'a>( + zeta_project: &'a mut ZetaProject, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + match zeta_project.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + zeta_project.registered_buffers.remove(&buffer_id); + }), + ], + }) + } } } fn request_completion_impl( &mut self, - project: Option<&Entity>, + project: &Entity, buffer: &Entity, cursor: language::Anchor, can_collect_data: bool, @@ -376,16 +420,14 @@ impl Zeta { { let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); - let snapshot = self.report_changes_for_buffer(&buffer, cx); + let snapshot = self.report_changes_for_buffer(&buffer, project, cx); let zeta = cx.entity(); - let events = self.events.clone(); + let events = self.get_or_init_zeta_project(project, cx).events.clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let git_info = if let (true, Some(project), Some(file)) = - (can_collect_data, project, snapshot.file()) - { + let git_info = if let (true, Some(file)) = (can_collect_data, snapshot.file()) { git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { None @@ -512,163 +554,10 @@ impl Zeta { }) } - // Generates several example completions of various states to fill the Zeta completion modal - #[cfg(any(test, feature = "test-support"))] - pub fn fill_with_fake_completions(&mut self, cx: &mut Context) -> Task<()> { - use language::Point; - - let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line - And maybe a short line - - Then a few lines - - and then another - "#}; - - let project = None; - let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx)); - let position = buffer.read(cx).anchor_before(Point::new(1, 0)); - - let completion_tasks = vec![ - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("e7861db5-0cea-4761-b1c5-ad083ac53a80").unwrap(), - output_excerpt: format!("{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -[here's an edit] -And maybe a short line -Then a few lines -and then another -{EDITABLE_REGION_END_MARKER} - ", ), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("077c556a-2c49-44e2-bbc6-dafc09032a5e").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line -[and another edit] -Then a few lines -and then another -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("df8c7b23-3d1d-4f99-a306-1f6264a41277").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line - -Then a few lines - -and then another -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("c743958d-e4d8-44a8-aa5b-eb1e305c5f5c").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line - -Then a few lines - -and then another -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("ff5cd7ab-ad06-4808-986e-d3391e7b8355").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line -Then a few lines -[a third completion] -and then another -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("83cafa55-cdba-4b27-8474-1865ea06be94").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line -and then another -[fourth completion example] -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - self.fake_completion( - project, - &buffer, - position, - PredictEditsResponse { - request_id: Uuid::parse_str("d5bd3afd-8723-47c7-bd77-15a3a926867b").unwrap(), - output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER} -a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line -And maybe a short line -Then a few lines -and then another -[fifth and final completion] -{EDITABLE_REGION_END_MARKER} - "#), - }, - cx, - ), - ]; - - cx.spawn(async move |zeta, cx| { - for task in completion_tasks { - task.await.unwrap(); - } - - zeta.update(cx, |zeta, _cx| { - zeta.shown_completions.get_mut(2).unwrap().edits = Arc::new([]); - zeta.shown_completions.get_mut(3).unwrap().edits = Arc::new([]); - }) - .ok(); - }) - } - #[cfg(any(test, feature = "test-support"))] pub fn fake_completion( &mut self, - project: Option<&Entity>, + project: &Entity, buffer: &Entity, position: language::Anchor, response: PredictEditsResponse, @@ -683,7 +572,7 @@ and then another pub fn request_completion( &mut self, - project: Option<&Entity>, + project: &Entity, buffer: &Entity, position: language::Anchor, can_collect_data: bool, @@ -1043,23 +932,23 @@ and then another fn report_changes_for_buffer( &mut self, buffer: &Entity, + project: &Entity, cx: &mut Context, ) -> BufferSnapshot { - self.register_buffer(buffer, cx); + let zeta_project = self.get_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); - let registered_buffer = self - .registered_buffers - .get_mut(&buffer.entity_id()) - .unwrap(); let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version != registered_buffer.snapshot.version { let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - self.push_event(Event::BufferChange { - old_snapshot, - new_snapshot: new_snapshot.clone(), - timestamp: Instant::now(), - }); + Self::push_event( + zeta_project, + Event::BufferChange { + old_snapshot, + new_snapshot: new_snapshot.clone(), + timestamp: Instant::now(), + }, + ); } new_snapshot @@ -1140,7 +1029,7 @@ pub struct GatherContextOutput { } pub fn gather_context( - project: Option<&Entity>, + project: &Entity, full_path_str: String, snapshot: &BufferSnapshot, cursor_point: language::Point, @@ -1149,8 +1038,7 @@ pub fn gather_context( git_info: Option, cx: &App, ) -> Task> { - let local_lsp_store = - project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let local_lsp_store = project.read(cx).lsp_store().read(cx).as_local(); let diagnostic_groups: Vec<(String, serde_json::Value)> = if can_collect_data && let Some(local_lsp_store) = local_lsp_store { snapshot @@ -1540,6 +1428,9 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { if self.zeta.read(cx).update_required { return; } + let Some(project) = project else { + return; + }; if self .zeta @@ -1578,13 +1469,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion( - project.as_ref(), - &buffer, - position, - can_collect_data, - cx, - ) + zeta.request_completion(&project, &buffer, position, can_collect_data, cx) }) }); @@ -1762,7 +1647,6 @@ fn tokens_for_bytes(bytes: usize) -> usize { #[cfg(test)] mod tests { - use client::UserStore; use client::test::FakeServer; use clock::FakeSystemClock; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; @@ -1771,6 +1655,7 @@ mod tests { use indoc::indoc; use language::Point; use settings::SettingsStore; + use util::path; use super::*; @@ -1897,6 +1782,7 @@ mod tests { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); client::init_settings(cx); + Project::init_settings(cx); }); let edits = edits_for_prediction( @@ -1961,6 +1847,7 @@ mod tests { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); client::init_settings(cx); + Project::init_settings(cx); }); let buffer_content = "lorem\n"; @@ -2010,13 +1897,14 @@ mod tests { }); // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx)); - + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + + let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(None, &buffer, cursor, false, cx) + zeta.request_completion(&project, &buffer, cursor, false, cx) }); let completion = completion_task.await.unwrap().unwrap(); @@ -2074,14 +1962,15 @@ mod tests { }); // Construct the fake server to authenticate. let _server = FakeServer::for_client(42, &client, cx).await; - let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - let zeta = cx.new(|cx| Zeta::new(client, user_store.clone(), cx)); - + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + + let zeta = cx.new(|cx| Zeta::new(client, project.read(cx).user_store(), cx)); let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(None, &buffer, cursor, false, cx) + zeta.request_completion(&project, &buffer, cursor, false, cx) }); let completion = completion_task.await.unwrap().unwrap(); diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 5b2d4cf615be67d9493d617ae7de38fdc8fa4b2f..e66eeed80920a0c31c5c06e119e17d418fbc294c 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -10,7 +10,7 @@ use language::Bias; use language::Buffer; use language::Point; use language_model::LlmApiToken; -use project::{Project, ProjectPath}; +use project::{Project, ProjectPath, Worktree}; use release_channel::AppVersion; use reqwest_client::ReqwestClient; use std::path::{Path, PathBuf}; @@ -129,15 +129,33 @@ async fn get_context( return Err(anyhow!("Absolute paths are not supported in --cursor")); } - let (project, _lsp_open_handle, buffer) = if use_language_server { - let (project, lsp_open_handle, buffer) = - open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?; - (Some(project), Some(lsp_open_handle), buffer) + let project = cx.update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + })?; + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&worktree_path, true, cx) + })? + .await?; + + let (_lsp_open_handle, buffer) = if use_language_server { + let (lsp_open_handle, buffer) = + open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?; + (Some(lsp_open_handle), buffer) } else { let abs_path = worktree_path.join(&cursor.path); let content = smol::fs::read_to_string(&abs_path).await?; let buffer = cx.new(|cx| Buffer::local(content, cx))?; - (None, None, buffer) + (None, buffer) }; let worktree_name = worktree_path @@ -177,7 +195,7 @@ async fn get_context( let mut gather_context_output = cx .update(|cx| { gather_context( - project.as_ref(), + &project, full_path_str, &snapshot, clipped_cursor, @@ -198,29 +216,11 @@ async fn get_context( } pub async fn open_buffer_with_language_server( - worktree_path: &Path, + project: &Entity, + worktree: &Entity, path: &Path, - app_state: &Arc, cx: &mut AsyncApp, -) -> Result<(Entity, Entity>, Entity)> { - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(worktree_path, true, cx) - })? - .await?; - +) -> Result<(Entity>, Entity)> { let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { worktree_id: worktree.id(), path: path.to_path_buf().into(), @@ -237,7 +237,7 @@ pub async fn open_buffer_with_language_server( let log_prefix = path.to_string_lossy().to_string(); wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; - Ok((project, lsp_open_handle, buffer)) + Ok((lsp_open_handle, buffer)) } // TODO: Dedupe with similar function in crates/eval/src/instance.rs