ep cli: Handle opening buffers from files created by the edit history (#46254)

Agus Zubiaga created

Since we don't persist new files to disk, they don't have entries, so we
have to look them up in memory first.

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/udiff.rs            | 50 ++++++++++++-------
crates/edit_prediction_cli/src/load_project.rs | 50 +++++++++++--------
2 files changed, 61 insertions(+), 39 deletions(-)

Detailed changes

crates/edit_prediction/src/udiff.rs 🔗

@@ -8,7 +8,7 @@ use std::{
 };
 
 use anyhow::{Context as _, Result, anyhow};
-use collections::HashMap;
+use collections::{HashMap, hash_map::Entry};
 use gpui::{AsyncApp, Entity};
 use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot, text_diff};
 use postage::stream::Stream as _;
@@ -17,7 +17,13 @@ use util::{paths::PathStyle, rel_path::RelPath};
 use worktree::Worktree;
 
 #[derive(Clone, Debug)]
-pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
+pub struct OpenedBuffers(HashMap<String, Entity<Buffer>>);
+
+impl OpenedBuffers {
+    pub fn get(&self, path: &str) -> Option<&Entity<Buffer>> {
+        self.0.get(path)
+    }
+}
 
 #[must_use]
 pub async fn apply_diff(
@@ -42,12 +48,12 @@ pub async fn apply_diff(
         .collect();
     refresh_worktree_entries(&worktree, paths.iter().map(|p| p.as_path()), cx).await?;
 
-    let mut included_files = HashMap::default();
+    let mut included_files: HashMap<String, Entity<Buffer>> = HashMap::default();
 
     let ranges = [Anchor::MIN..Anchor::MAX];
     let mut diff = DiffParser::new(diff_str);
     let mut current_file = None;
-    let mut edits = vec![];
+    let mut edits: Vec<(std::ops::Range<Anchor>, Arc<str>)> = vec![];
 
     while let Some(event) = diff.next()? {
         match event {
@@ -58,21 +64,29 @@ pub async fn apply_diff(
             } => {
                 let buffer = match current_file {
                     None => {
-                        let buffer = if is_new_file {
-                            project
-                                .update(cx, |project, cx| project.create_buffer(true, cx))?
-                                .await?
-                        } else {
-                            let project_path = project
-                                .update(cx, |project, cx| {
-                                    project.find_project_path(path.as_ref(), cx)
-                                })?
-                                .with_context(|| format!("no such path: {}", path))?;
-                            project
-                                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-                                .await?
+                        let buffer = match included_files.entry(path.to_string()) {
+                            Entry::Occupied(entry) => entry.get().clone(),
+                            Entry::Vacant(entry) => {
+                                let buffer = if is_new_file {
+                                    project
+                                        .update(cx, |project, cx| project.create_buffer(true, cx))?
+                                        .await?
+                                } else {
+                                    let project_path = project
+                                        .update(cx, |project, cx| {
+                                            project.find_project_path(path.as_ref(), cx)
+                                        })?
+                                        .with_context(|| format!("no such path: {}", path))?;
+                                    project
+                                        .update(cx, |project, cx| {
+                                            project.open_buffer(project_path, cx)
+                                        })?
+                                        .await?
+                                };
+                                entry.insert(buffer.clone());
+                                buffer
+                            }
                         };
-                        included_files.insert(path.to_string(), buffer.clone());
                         current_file = Some(buffer);
                         current_file.as_ref().unwrap()
                     }

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -27,9 +27,10 @@ pub async fn run_load_project(
 
     let project = setup_project(example, &app_state, &progress, &mut cx).await?;
 
-    let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
+    let open_buffers = apply_edit_history(example, &project, &mut cx).await?;
 
-    let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
+    let (buffer, cursor_position) =
+        cursor_position(example, &project, &open_buffers, &mut cx).await?;
     let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
         let cursor_point = cursor_position.to_point(&buffer);
         let language_name = buffer
@@ -54,7 +55,7 @@ pub async fn run_load_project(
         buffer,
         project,
         cursor_position,
-        _open_buffers,
+        _open_buffers: open_buffers,
     });
     Ok(())
 }
@@ -62,6 +63,7 @@ pub async fn run_load_project(
 async fn cursor_position(
     example: &Example,
     project: &Entity<Project>,
+    open_buffers: &OpenedBuffers,
     cx: &mut AsyncApp,
 ) -> Result<(Entity<Buffer>, Anchor)> {
     let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
@@ -75,25 +77,31 @@ async fn cursor_position(
         return Err(error);
     }
 
-    // Since the worktree scanner is disabled, manually refresh entries for the cursor path.
-    if let Some(worktree) = project.read_with(cx, |project, cx| project.worktrees(cx).next())? {
-        refresh_worktree_entries(&worktree, [&*example.spec.cursor_path], cx).await?;
-    }
-
-    let cursor_path = project
-        .read_with(cx, |project, cx| {
-            project.find_project_path(&example.spec.cursor_path, cx)
-        })?
-        .with_context(|| {
-            format!(
-                "failed to find cursor path {}",
-                example.spec.cursor_path.display()
-            )
-        })?;
+    let cursor_path_str = example.spec.cursor_path.to_string_lossy();
+    // We try open_buffers first because the file might be new and not saved to disk
+    let cursor_buffer = if let Some(buffer) = open_buffers.get(&cursor_path_str) {
+        buffer.clone()
+    } else {
+        // Since the worktree scanner is disabled, manually refresh entries for the cursor path.
+        if let Some(worktree) = project.read_with(cx, |project, cx| project.worktrees(cx).next())? {
+            refresh_worktree_entries(&worktree, [&*example.spec.cursor_path], cx).await?;
+        }
 
-    let cursor_buffer = project
-        .update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
-        .await?;
+        let cursor_path = project
+            .read_with(cx, |project, cx| {
+                project.find_project_path(&example.spec.cursor_path, cx)
+            })?
+            .with_context(|| {
+                format!(
+                    "failed to find cursor path {}",
+                    example.spec.cursor_path.display()
+                )
+            })?;
+
+        project
+            .update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
+            .await?
+    };
 
     let (cursor_excerpt, cursor_offset_within_excerpt) = example.spec.cursor_excerpt()?;