Progress towards scoping zeta edit history to projects

Michael Sloan created

Change summary

crates/zed/src/zed/edit_prediction_registry.rs |   4 
crates/zeta/src/zeta.rs                        | 189 +++++++++++++------
2 files changed, 130 insertions(+), 63 deletions(-)

Detailed changes

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -207,9 +207,11 @@ fn assign_edit_prediction_provider(
 
                 if let Some(buffer) = &singleton_buffer
                     && buffer.read(cx).file().is_some()
+                    // todo!
+                    && let Some(project) = editor.project()
                 {
                     zeta.update(cx, |zeta, cx| {
-                        zeta.register_buffer(buffer, cx);
+                        zeta.register_buffer(buffer, project, cx);
                     });
                 }
 

crates/zeta/src/zeta.rs 🔗

@@ -39,12 +39,13 @@ use multi_buffer::MultiBufferPoint;
 use project::{Project, ProjectPath};
 use release_channel::AppVersion;
 use settings::WorktreeId;
+use std::collections::hash_map;
+use std::mem;
 use std::str::FromStr;
 use std::{
     cmp,
     fmt::Write,
     future::Future,
-    mem,
     ops::Range,
     path::Path,
     rc::Rc,
@@ -55,6 +56,7 @@ use telemetry_events::EditPredictionRating;
 use thiserror::Error;
 use util::{ResultExt, maybe};
 use uuid::Uuid;
+use workspace::Workspace;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 use worktree::Worktree;
 
@@ -235,8 +237,6 @@ impl std::fmt::Debug for EditPrediction {
 
 pub struct Zeta {
     client: Arc<Client>,
-    events: VecDeque<Event>,
-    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     shown_completions: VecDeque<EditPrediction>,
     rated_completions: HashSet<EditPredictionId>,
     data_collection_choice: Entity<DataCollectionChoice>,
@@ -246,6 +246,12 @@ pub struct Zeta {
     update_required: bool,
     user_store: Entity<UserStore>,
     license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+    projects: HashMap<EntityId, ZetaProject>,
+}
+
+struct ZetaProject {
+    events: VecDeque<Event>,
+    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     recent_editors: VecDeque<RecentEditor>,
     last_activity_state: Option<ActivityState>,
     _activity_poll_task: Option<Task<Result<()>>>,
@@ -296,7 +302,9 @@ impl Zeta {
     }
 
     pub fn clear_history(&mut self) {
-        self.events.clear();
+        for zeta_project in self.projects.values_mut() {
+            zeta_project.events.clear();
+        }
     }
 
     pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -309,6 +317,7 @@ impl Zeta {
         let data_collection_choice = Self::load_data_collection_choices();
         let data_collection_choice = cx.new(|_| data_collection_choice);
 
+        /* todo!
         let mut activity_poll_task = None;
 
         if let Some(workspace) = &workspace {
@@ -344,13 +353,12 @@ impl Zeta {
                 }
             }));
         }
+        */
 
         Self {
             client,
-            events: VecDeque::with_capacity(MAX_EVENT_COUNT),
             shown_completions: VecDeque::with_capacity(MAX_SHOWN_COMPLETION_COUNT),
             rated_completions: HashSet::default(),
-            registered_buffers: HashMap::default(),
             data_collection_choice,
             llm_token: LlmApiToken::default(),
             _llm_token_subscription: cx.subscribe(
@@ -368,18 +376,42 @@ impl Zeta {
             update_required: false,
             license_detection_watchers: HashMap::default(),
             user_store,
-            recent_editors: VecDeque::new(),
-            last_activity_state: None,
-            _activity_poll_task: activity_poll_task,
+            projects: HashMap::default(),
         }
     }
 
-    fn push_event(&mut self, event: Event) {
+    fn get_mut_or_init_zeta_project(
+        &mut self,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &mut ZetaProject {
+        let project_id = project.entity_id();
+        match self.projects.entry(project_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                cx.observe_release(project, move |this, _, _cx| {
+                    this.projects.remove(&project_id);
+                });
+                entry.insert(ZetaProject {
+                    events: VecDeque::with_capacity(MAX_EVENT_COUNT),
+                    registered_buffers: HashMap::default(),
+                    recent_editors: VecDeque::new(),
+                    last_activity_state: None,
+                    // todo!
+                    _activity_poll_task: None,
+                })
+            }
+        }
+    }
+
+    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
+        let events = &mut zeta_project.events;
+
         if let Some(Event::BufferChange {
             new_snapshot: last_new_snapshot,
             timestamp: last_timestamp,
             ..
-        }) = self.events.back_mut()
+        }) = events.back_mut()
         {
             // Coalesce edits for the same buffer when they happen one after the other.
             let Event::BufferChange {
@@ -398,51 +430,65 @@ impl Zeta {
             }
         }
 
-        if self.events.len() >= MAX_EVENT_COUNT {
+        if events.len() >= MAX_EVENT_COUNT {
             // These are halved instead of popping to improve prompt caching.
-            self.events.drain(..MAX_EVENT_COUNT / 2);
+            events.drain(..MAX_EVENT_COUNT / 2);
         }
 
-        self.events.push_back(event);
+        events.push_back(event);
     }
 
-    pub fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
-        let buffer_id = buffer.entity_id();
-        let weak_buffer = buffer.downgrade();
-
-        if let std::collections::hash_map::Entry::Vacant(entry) =
-            self.registered_buffers.entry(buffer_id)
-        {
-            let snapshot = buffer.read(cx).snapshot();
-
-            entry.insert(RegisteredBuffer {
-                snapshot,
-                _subscriptions: [
-                    cx.subscribe(buffer, move |this, buffer, event, cx| {
-                        this.handle_buffer_event(buffer, event, cx);
-                    }),
-                    cx.observe_release(buffer, move |this, _buffer, _cx| {
-                        this.registered_buffers.remove(&weak_buffer.entity_id());
-                    }),
-                ],
-            });
-        };
-    }
-
-    fn handle_buffer_event(
+    pub fn register_buffer(
         &mut self,
-        buffer: Entity<Buffer>,
-        event: &language::BufferEvent,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
         cx: &mut Context<Self>,
     ) {
-        if let language::BufferEvent::Edited = event {
-            self.report_changes_for_buffer(&buffer, cx);
+        let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
+        Self::register_buffer_impl(zeta_project, buffer, project, cx);
+    }
+
+    fn register_buffer_impl<'a>(
+        zeta_project: &'a mut ZetaProject,
+        buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
+        cx: &mut Context<Self>,
+    ) -> &'a mut RegisteredBuffer {
+        let buffer_id = buffer.entity_id();
+        match zeta_project.registered_buffers.entry(buffer_id) {
+            hash_map::Entry::Occupied(entry) => entry.into_mut(),
+            hash_map::Entry::Vacant(entry) => {
+                let snapshot = buffer.read(cx).snapshot();
+                let project_entity_id = project.entity_id();
+                entry.insert(RegisteredBuffer {
+                    snapshot,
+                    _subscriptions: [
+                        cx.subscribe(buffer, {
+                            let project = project.downgrade();
+                            move |this, buffer, event, cx| {
+                                if let language::BufferEvent::Edited = event
+                                    && let Some(project) = project.upgrade()
+                                {
+                                    this.report_changes_for_buffer(&buffer, &project, cx);
+                                }
+                            }
+                        }),
+                        cx.observe_release(buffer, move |this, _buffer, _cx| {
+                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
+                            else {
+                                return;
+                            };
+                            zeta_project.registered_buffers.remove(&buffer_id);
+                        }),
+                    ],
+                })
+            }
         }
     }
 
     fn request_completion_impl<F, R>(
         &mut self,
-        project: Option<Entity<Project>>,
+        project: &Entity<Project>,
         buffer: &Entity<Buffer>,
         cursor: language::Anchor,
         can_collect_data: CanCollectData,
@@ -457,9 +503,12 @@ impl Zeta {
     {
         let buffer = buffer.clone();
         let buffer_snapshotted_at = Instant::now();
-        let snapshot = self.report_changes_for_buffer(&buffer, cx);
+        let snapshot = self.report_changes_for_buffer(&buffer, project, cx);
         let zeta = cx.entity();
-        let events = self.events.clone();
+        let events = self
+            .get_mut_or_init_zeta_project(project, cx)
+            .events
+            .clone();
         let client = self.client.clone();
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
@@ -487,6 +536,8 @@ impl Zeta {
             } = gather_task.await?;
             let done_gathering_context_at = Instant::now();
 
+            let additional_context_task: Option<Task<PredictEditsAdditionalContext>> = None;
+            /* todo!
             let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
                 && let Some(file) = snapshot.file()
                 && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
@@ -503,7 +554,7 @@ impl Zeta {
                             snapshot,
                             &buffer_snapshotted_at,
                             project_path,
-                            project.as_ref(),
+                            &project,
                             cx,
                         )
                     }) {
@@ -515,6 +566,7 @@ impl Zeta {
             } else {
                 None
             };
+            */
 
             log::debug!(
                 "Events:\n{}\nExcerpt:\n{:?}",
@@ -606,6 +658,7 @@ impl Zeta {
                 );
             }
 
+            /* todo!
             if let Some(additional_context_task) = additional_context_task {
                 cx.background_spawn(async move {
                     if let Some(additional_context) = additional_context_task.await {
@@ -618,6 +671,7 @@ impl Zeta {
                 })
                 .detach();
             }
+            */
 
             edit_prediction
         })
@@ -626,6 +680,7 @@ impl Zeta {
     // Generates several example completions of various states to fill the Zeta completion modal
     #[cfg(any(test, feature = "test-support"))]
     pub fn fill_with_fake_completions(&mut self, cx: &mut Context<Self>) -> Task<()> {
+        /*
         use language::Point;
 
         let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
@@ -773,12 +828,14 @@ and then another
             })
             .ok();
         })
+        */
+        todo!()
     }
 
     #[cfg(any(test, feature = "test-support"))]
     pub fn fake_completion(
         &mut self,
-        project: Option<Entity<Project>>,
+        project: &Entity<Project>,
         buffer: &Entity<Buffer>,
         position: language::Anchor,
         response: PredictEditsResponse,
@@ -798,7 +855,7 @@ and then another
 
     pub fn request_completion(
         &mut self,
-        project: Option<Entity<Project>>,
+        project: &Entity<Project>,
         buffer: &Entity<Buffer>,
         position: language::Anchor,
         can_collect_data: CanCollectData,
@@ -1155,23 +1212,23 @@ and then another
     fn report_changes_for_buffer(
         &mut self,
         buffer: &Entity<Buffer>,
+        project: &Entity<Project>,
         cx: &mut Context<Self>,
     ) -> BufferSnapshot {
-        self.register_buffer(buffer, cx);
+        let zeta_project = self.get_mut_or_init_zeta_project(project, cx);
+        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 
-        let registered_buffer = self
-            .registered_buffers
-            .get_mut(&buffer.entity_id())
-            .unwrap();
         let new_snapshot = buffer.read(cx).snapshot();
-
         if new_snapshot.version != registered_buffer.snapshot.version {
             let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
-            self.push_event(Event::BufferChange {
-                old_snapshot,
-                new_snapshot: new_snapshot.clone(),
-                timestamp: Instant::now(),
-            });
+            Self::push_event(
+                zeta_project,
+                Event::BufferChange {
+                    old_snapshot,
+                    new_snapshot: new_snapshot.clone(),
+                    timestamp: Instant::now(),
+                },
+            );
         }
 
         new_snapshot
@@ -1194,6 +1251,7 @@ and then another
         }
     }
 
+    /*
     fn gather_additional_context(
         &mut self,
         cursor_point: language::Point,
@@ -1201,10 +1259,10 @@ and then another
         snapshot: BufferSnapshot,
         buffer_snapshotted_at: &Instant,
         project_path: ProjectPath,
-        project: Option<&Entity<Project>>,
+        project: &WeakEntity<Project>,
         cx: &mut Context<Self>,
     ) -> Option<Task<PredictEditsAdditionalContext>> {
-        let project = project?.read(cx);
+        let project = project.upgrade()?.read(cx);
         let entry = project.entry_for_path(&project_path, cx)?;
         if !worktree_entry_is_eligible_for_collection(&entry) {
             return None;
@@ -1511,6 +1569,7 @@ and then another
         }
         results
     }
+    */
 }
 
 fn to_cloud_llm_client_point(point: language::Point) -> cloud_llm_client::Point {
@@ -1926,6 +1985,10 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
         if self.zeta.read(cx).update_required {
             return;
         }
+        // todo! Don't require a project
+        let Some(project) = project else {
+            return;
+        };
 
         if self
             .zeta
@@ -1964,7 +2027,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
             let completion_request = this.update(cx, |this, cx| {
                 this.last_request_timestamp = Instant::now();
                 this.zeta.update(cx, |zeta, cx| {
-                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
+                    zeta.request_completion(&project, &buffer, position, can_collect_data, cx)
                 })
             });
 
@@ -2140,6 +2203,7 @@ fn tokens_for_bytes(bytes: usize) -> usize {
     bytes / BYTES_PER_TOKEN_GUESS
 }
 
+/* todo!
 #[cfg(test)]
 mod tests {
     use client::UserStore;
@@ -2510,3 +2574,4 @@ mod tests {
         zlog::init_test();
     }
 }
+*/