From dc2879759caac0c0ddab35adc7c36acb826a89c8 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Tue, 2 Sep 2025 00:26:42 -0600 Subject: [PATCH] Progress towards scoping zeta edit history to projects --- .../zed/src/zed/edit_prediction_registry.rs | 4 +- crates/zeta/src/zeta.rs | 189 ++++++++++++------ 2 files changed, 130 insertions(+), 63 deletions(-) diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 7b8b98018e6d6c608574ab81e912e8a98e363046..d9f1a0ed40628ede3ffcc8a6188d4b2d0b2b4ba9 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -207,9 +207,11 @@ fn assign_edit_prediction_provider( if let Some(buffer) = &singleton_buffer && buffer.read(cx).file().is_some() + // todo! + && let Some(project) = editor.project() { zeta.update(cx, |zeta, cx| { - zeta.register_buffer(buffer, cx); + zeta.register_buffer(buffer, project, cx); }); } diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 11c51e317a172973af27c78d56c2e1005f9ef56a..bf4b9706b51a9b2bb538807e8666360a6f4b98da 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -39,12 +39,13 @@ use multi_buffer::MultiBufferPoint; use project::{Project, ProjectPath}; use release_channel::AppVersion; use settings::WorktreeId; +use std::collections::hash_map; +use std::mem; use std::str::FromStr; use std::{ cmp, fmt::Write, future::Future, - mem, ops::Range, path::Path, rc::Rc, @@ -55,6 +56,7 @@ use telemetry_events::EditPredictionRating; use thiserror::Error; use util::{ResultExt, maybe}; use uuid::Uuid; +use workspace::Workspace; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; use worktree::Worktree; @@ -235,8 +237,6 @@ impl std::fmt::Debug for EditPrediction { pub struct Zeta { client: Arc, - events: VecDeque, - registered_buffers: HashMap, shown_completions: VecDeque, rated_completions: HashSet, data_collection_choice: Entity, @@ -246,6 +246,12 @@ pub struct Zeta { update_required: bool, user_store: Entity, license_detection_watchers: HashMap>, + projects: HashMap, +} + +struct ZetaProject { + events: VecDeque, + registered_buffers: HashMap, recent_editors: VecDeque, last_activity_state: Option, _activity_poll_task: Option>>, @@ -296,7 +302,9 @@ impl Zeta { } pub fn clear_history(&mut self) { - self.events.clear(); + for zeta_project in self.projects.values_mut() { + zeta_project.events.clear(); + } } pub fn usage(&self, cx: &App) -> Option { @@ -309,6 +317,7 @@ impl Zeta { let data_collection_choice = Self::load_data_collection_choices(); let data_collection_choice = cx.new(|_| data_collection_choice); + /* todo! let mut activity_poll_task = None; if let Some(workspace) = &workspace { @@ -344,13 +353,12 @@ impl Zeta { } })); } + */ Self { client, - events: VecDeque::with_capacity(MAX_EVENT_COUNT), shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT), rated_completions: HashSet::default(), - registered_buffers: HashMap::default(), data_collection_choice, llm_token: LlmApiToken::default(), _llm_token_subscription: cx.subscribe( @@ -368,18 +376,42 @@ impl Zeta { update_required: false, license_detection_watchers: HashMap::default(), user_store, - recent_editors: VecDeque::new(), - last_activity_state: None, - _activity_poll_task: activity_poll_task, + projects: HashMap::default(), } } - fn push_event(&mut self, event: Event) { + fn get_mut_or_init_zeta_project( + &mut self, + project: &Entity, + cx: &mut Context, + ) -> &mut ZetaProject { + let project_id = project.entity_id(); + match self.projects.entry(project_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + cx.observe_release(project, move |this, _, _cx| { + this.projects.remove(&project_id); + }); + entry.insert(ZetaProject { + events: VecDeque::with_capacity(MAX_EVENT_COUNT), + registered_buffers: HashMap::default(), + recent_editors: VecDeque::new(), + last_activity_state: None, + // todo! + _activity_poll_task: None, + }) + } + } + } + + fn push_event(zeta_project: &mut ZetaProject, event: Event) { + let events = &mut zeta_project.events; + if let Some(Event::BufferChange { new_snapshot: last_new_snapshot, timestamp: last_timestamp, .. - }) = self.events.back_mut() + }) = events.back_mut() { // Coalesce edits for the same buffer when they happen one after the other. let Event::BufferChange { @@ -398,51 +430,65 @@ impl Zeta { } } - if self.events.len() >= MAX_EVENT_COUNT { + if events.len() >= MAX_EVENT_COUNT { // These are halved instead of popping to improve prompt caching. - self.events.drain(..MAX_EVENT_COUNT / 2); + events.drain(..MAX_EVENT_COUNT / 2); } - self.events.push_back(event); + events.push_back(event); } - pub fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { - let buffer_id = buffer.entity_id(); - let weak_buffer = buffer.downgrade(); - - if let std::collections::hash_map::Entry::Vacant(entry) = - self.registered_buffers.entry(buffer_id) - { - let snapshot = buffer.read(cx).snapshot(); - - entry.insert(RegisteredBuffer { - snapshot, - _subscriptions: [ - cx.subscribe(buffer, move |this, buffer, event, cx| { - this.handle_buffer_event(buffer, event, cx); - }), - cx.observe_release(buffer, move |this, _buffer, _cx| { - this.registered_buffers.remove(&weak_buffer.entity_id()); - }), - ], - }); - }; - } - - fn handle_buffer_event( + pub fn register_buffer( &mut self, - buffer: Entity, - event: &language::BufferEvent, + buffer: &Entity, + project: &Entity, cx: &mut Context, ) { - if let language::BufferEvent::Edited = event { - self.report_changes_for_buffer(&buffer, cx); + let zeta_project = self.get_mut_or_init_zeta_project(project, cx); + Self::register_buffer_impl(zeta_project, buffer, project, cx); + } + + fn register_buffer_impl<'a>( + zeta_project: &'a mut ZetaProject, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + match zeta_project.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(zeta_project) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + zeta_project.registered_buffers.remove(&buffer_id); + }), + ], + }) + } } } fn request_completion_impl( &mut self, - project: Option>, + project: &Entity, buffer: &Entity, cursor: language::Anchor, can_collect_data: CanCollectData, @@ -457,9 +503,12 @@ impl Zeta { { let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); - let snapshot = self.report_changes_for_buffer(&buffer, cx); + let snapshot = self.report_changes_for_buffer(&buffer, project, cx); let zeta = cx.entity(); - let events = self.events.clone(); + let events = self + .get_mut_or_init_zeta_project(project, cx) + .events + .clone(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); @@ -487,6 +536,8 @@ impl Zeta { } = gather_task.await?; let done_gathering_context_at = Instant::now(); + let additional_context_task: Option> = None; + /* todo! let additional_context_task = if matches!(can_collect_data, CanCollectData(true)) && let Some(file) = snapshot.file() && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx)) @@ -503,7 +554,7 @@ impl Zeta { snapshot, &buffer_snapshotted_at, project_path, - project.as_ref(), + &project, cx, ) }) { @@ -515,6 +566,7 @@ impl Zeta { } else { None }; + */ log::debug!( "Events:\n{}\nExcerpt:\n{:?}", @@ -606,6 +658,7 @@ impl Zeta { ); } + /* todo! if let Some(additional_context_task) = additional_context_task { cx.background_spawn(async move { if let Some(additional_context) = additional_context_task.await { @@ -618,6 +671,7 @@ impl Zeta { }) .detach(); } + */ edit_prediction }) @@ -626,6 +680,7 @@ impl Zeta { // Generates several example completions of various states to fill the Zeta completion modal #[cfg(any(test, feature = "test-support"))] pub fn fill_with_fake_completions(&mut self, cx: &mut Context) -> Task<()> { + /* use language::Point; let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line @@ -773,12 +828,14 @@ and then another }) .ok(); }) + */ + todo!() } #[cfg(any(test, feature = "test-support"))] pub fn fake_completion( &mut self, - project: Option>, + project: &Entity, buffer: &Entity, position: language::Anchor, response: PredictEditsResponse, @@ -798,7 +855,7 @@ and then another pub fn request_completion( &mut self, - project: Option>, + project: &Entity, buffer: &Entity, position: language::Anchor, can_collect_data: CanCollectData, @@ -1155,23 +1212,23 @@ and then another fn report_changes_for_buffer( &mut self, buffer: &Entity, + project: &Entity, cx: &mut Context, ) -> BufferSnapshot { - self.register_buffer(buffer, cx); + let zeta_project = self.get_mut_or_init_zeta_project(project, cx); + let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx); - let registered_buffer = self - .registered_buffers - .get_mut(&buffer.entity_id()) - .unwrap(); let new_snapshot = buffer.read(cx).snapshot(); - if new_snapshot.version != registered_buffer.snapshot.version { let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); - self.push_event(Event::BufferChange { - old_snapshot, - new_snapshot: new_snapshot.clone(), - timestamp: Instant::now(), - }); + Self::push_event( + zeta_project, + Event::BufferChange { + old_snapshot, + new_snapshot: new_snapshot.clone(), + timestamp: Instant::now(), + }, + ); } new_snapshot @@ -1194,6 +1251,7 @@ and then another } } + /* fn gather_additional_context( &mut self, cursor_point: language::Point, @@ -1201,10 +1259,10 @@ and then another snapshot: BufferSnapshot, buffer_snapshotted_at: &Instant, project_path: ProjectPath, - project: Option<&Entity>, + project: &WeakEntity, cx: &mut Context, ) -> Option> { - let project = project?.read(cx); + let project = project.upgrade()?.read(cx); let entry = project.entry_for_path(&project_path, cx)?; if !worktree_entry_is_eligible_for_collection(&entry) { return None; @@ -1511,6 +1569,7 @@ and then another } results } + */ } fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point { @@ -1926,6 +1985,10 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { if self.zeta.read(cx).update_required { return; } + // todo! Don't require a project + let Some(project) = project else { + return; + }; if self .zeta @@ -1964,7 +2027,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider { let completion_request = this.update(cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(project, &buffer, position, can_collect_data, cx) + zeta.request_completion(&project, &buffer, position, can_collect_data, cx) }) }); @@ -2140,6 +2203,7 @@ fn tokens_for_bytes(bytes: usize) -> usize { bytes / BYTES_PER_TOKEN_GUESS } +/* todo! #[cfg(test)] mod tests { use client::UserStore; @@ -2510,3 +2574,4 @@ mod tests { zlog::init_test(); } } +*/