Omit events from other worktrees from captured examples

Max Brunsfeld created

Change summary

crates/edit_prediction/src/capture_example.rs | 39 +++++++++++++++-----
1 file changed, 28 insertions(+), 11 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -7,7 +7,7 @@ use buffer_diff::BufferDiffSnapshot;
 use collections::HashMap;
 use gpui::{App, Entity, Task};
 use language::{Buffer, ToPoint as _};
-use project::Project;
+use project::{Project, WorktreeId};
 use std::{collections::hash_map, fmt::Write as _, path::Path, sync::Arc};
 use text::BufferSnapshot as TextBufferSnapshot;
 
@@ -35,14 +35,26 @@ pub fn capture_example(
         .or_else(|| repository_snapshot.remote_upstream_url.clone())?;
     let revision = repository_snapshot.head_commit.as_ref()?.sha.to_string();
 
-    let events = ep_store.update(cx, |store, cx| {
+    let mut events = ep_store.update(cx, |store, cx| {
         store.edit_history_for_project_with_pause_split_last_event(&project, cx)
     });
 
     let git_store = project.read(cx).git_store().clone();
 
     Some(cx.spawn(async move |mut cx| {
-        let snapshots_by_path = collect_snapshots(&project, &git_store, &events, &mut cx).await?;
+        let snapshots_by_path =
+            collect_snapshots(&project, &git_store, worktree_id, &events, &mut cx).await?;
+
+        events.retain(|stored_event| {
+            match stored_event.event.as_ref() {
+                zeta_prompt::Event::BufferChange { path, .. } => {
+                    if !snapshots_by_path.contains_key(path) {
+                        return false;
+                    }
+                }
+            }
+            true
+        });
 
         let line_comment_prefix = snapshot
             .language()
@@ -100,21 +112,26 @@ fn compute_cursor_excerpt(
 async fn collect_snapshots(
     project: &Entity<Project>,
     git_store: &Entity<project::git_store::GitStore>,
+    worktree_id: WorktreeId,
     events: &[StoredEvent],
     cx: &mut gpui::AsyncApp,
 ) -> Result<HashMap<Arc<Path>, (TextBufferSnapshot, BufferDiffSnapshot)>> {
     let mut snapshots_by_path = HashMap::default();
+    let root_name = project.read_with(cx, |project, cx| {
+        project
+            .worktree_for_id(worktree_id, cx)
+            .unwrap()
+            .read(cx)
+            .root_name()
+            .to_owned()
+    })?;
     for stored_event in events {
         let zeta_prompt::Event::BufferChange { path, .. } = stored_event.event.as_ref();
         if let Some((project_path, full_path)) = project.read_with(cx, |project, cx| {
-            let project_path = project.find_project_path(path, cx)?;
-            let full_path = project
-                .worktree_for_id(project_path.worktree_id, cx)?
-                .read(cx)
-                .root_name()
-                .join(&project_path.path)
-                .as_std_path()
-                .into();
+            let project_path = project
+                .find_project_path(path, cx)
+                .filter(|path| path.worktree_id == worktree_id)?;
+            let full_path = root_name.join(&project_path.path).as_std_path().into();
             Some((project_path, full_path))
         })? {
             if let hash_map::Entry::Vacant(entry) = snapshots_by_path.entry(full_path) {