Edit prediction changes (#46169)

Oleksiy Syvokon and Agus Zubiaga created

1. Handle diffs with no trailing new lines 
2. ep: Don't assume workdir name in edit history paths
3. Fix `imitate_human_edits()` for pure insertions

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/edit_prediction/src/udiff.rs            | 154 ++++++++++++++++---
crates/edit_prediction_cli/src/load_project.rs |  23 +-
crates/edit_prediction_cli/src/main.rs         |   2 
crates/edit_prediction_cli/src/split_commit.rs | 100 +++++++++++-
4 files changed, 230 insertions(+), 49 deletions(-)

Detailed changes

crates/edit_prediction/src/udiff.rs 🔗

@@ -62,36 +62,26 @@ pub async fn apply_diff(
                 hunk,
                 is_new_file,
             } => {
+                let worktree_id = worktree.read_with(cx, |wt, _| wt.id())?;
+                let rel_path = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?;
+                let project_path = project::ProjectPath {
+                    worktree_id,
+                    path: rel_path.into_arc(),
+                };
+
                 let buffer = match current_file {
                     None => {
-                        let buffer = if is_new_file {
+                        if is_new_file {
                             // New file - create it first, then open the buffer
-                            let worktree_id = worktree.read_with(cx, |wt, _| wt.id())?;
-                            let rel_path =
-                                RelPath::new(Path::new(path.as_ref()), PathStyle::Posix)?;
-                            let project_path = project::ProjectPath {
-                                worktree_id,
-                                path: rel_path.into_arc(),
-                            };
                             project
                                 .update(cx, |project, cx| {
                                     project.create_entry(project_path.clone(), false, cx)
                                 })?
                                 .await?;
-                            project
-                                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-                                .await?
-                        } else {
-                            // Existing file - find and open it
-                            let project_path = project
-                                .update(cx, |project, cx| {
-                                    project.find_project_path(path.as_ref(), cx)
-                                })?
-                                .context("no such path")?;
-                            project
-                                .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-                                .await?
-                        };
+                        }
+                        let buffer = project
+                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+                            .await?;
                         included_files.insert(path.to_string(), buffer.clone());
                         current_file = Some(buffer);
                         current_file.as_ref().unwrap()
@@ -191,7 +181,8 @@ pub fn strip_diff_metadata(diff: &str) -> String {
             | DiffLine::HunkHeader(_)
             | DiffLine::Context(_)
             | DiffLine::Deletion(_)
-            | DiffLine::Addition(_) => {
+            | DiffLine::Addition(_)
+            | DiffLine::NoNewlineAtEOF => {
                 result.push_str(line);
                 result.push('\n');
             }
@@ -426,6 +417,23 @@ impl<'a> DiffParser<'a> {
                             }
                         }
                     }
+                    DiffLine::NoNewlineAtEOF => {
+                        if let Some(last_edit) = self.hunk.edits.last_mut() {
+                            if last_edit.text.ends_with('\n') {
+                                // Previous line was an addition (has trailing newline in text)
+                                last_edit.text.pop();
+                            } else if !last_edit.range.is_empty()
+                                && last_edit.range.end == self.hunk.context.len()
+                            {
+                                // Previous line was a deletion (non-empty range at end of context)
+                                self.hunk.context.pop();
+                                last_edit.range.end -= 1;
+                            }
+                        } else {
+                            // Previous line was context (no edits)
+                            self.hunk.context.pop();
+                        }
+                    }
                     DiffLine::Garbage(_) => {}
                 }
 
@@ -460,7 +468,13 @@ fn resolve_hunk_edits_in_buffer(
                 offset = Some(range.start + ix);
             }
         }
-        offset.ok_or_else(|| anyhow!("Failed to match context:\n{}", hunk.context))
+        offset.ok_or_else(|| {
+            anyhow!(
+                "Failed to match context:\n\n```\n{}```\n\nBuffer contents:\n\n```\n{}```",
+                hunk.context,
+                buffer.text()
+            )
+        })
     }?;
     let iter = hunk.edits.into_iter().flat_map(move |edit| {
         let old_text = buffer
@@ -488,6 +502,7 @@ pub enum DiffLine<'a> {
     Context(&'a str),
     Deletion(&'a str),
     Addition(&'a str),
+    NoNewlineAtEOF,
     Garbage(&'a str),
 }
 
@@ -505,6 +520,9 @@ impl<'a> DiffLine<'a> {
     }
 
     fn try_parse(line: &'a str) -> Option<Self> {
+        if line.starts_with("\\ No newline") {
+            return Some(Self::NoNewlineAtEOF);
+        }
         if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
             let path = parse_header_path("a/", header);
             Some(Self::OldPath { path })
@@ -561,6 +579,7 @@ impl<'a> Display for DiffLine<'a> {
             DiffLine::Context(content) => write!(f, " {content}"),
             DiffLine::Deletion(content) => write!(f, "-{content}"),
             DiffLine::Addition(content) => write!(f, "+{content}"),
+            DiffLine::NoNewlineAtEOF => write!(f, "\\ No newline at end of file"),
             DiffLine::Garbage(line) => write!(f, "{line}"),
         }
     }
@@ -831,6 +850,93 @@ mod tests {
         )
     }
 
+    #[test]
+    fn test_no_newline_at_eof() {
+        let diff = indoc! {"
+            --- a/file.py
+            +++ b/file.py
+            @@ -55,7 +55,3 @@ class CustomDataset(Dataset):
+                         torch.set_rng_state(state)
+                         mask = self.transform(mask)
+
+            -        if self.mode == 'Training':
+            -            return (img, mask, name)
+            -        else:
+            -            return (img, mask, name)
+            \\ No newline at end of file
+        "};
+
+        let mut events = Vec::new();
+        let mut parser = DiffParser::new(diff);
+        while let Some(event) = parser.next().unwrap() {
+            events.push(event);
+        }
+
+        assert_eq!(
+            events,
+            &[
+                DiffEvent::Hunk {
+                    path: "file.py".into(),
+                    hunk: Hunk {
+                        context: concat!(
+                            "            torch.set_rng_state(state)\n",
+                            "            mask = self.transform(mask)\n",
+                            "\n",
+                            "        if self.mode == 'Training':\n",
+                            "            return (img, mask, name)\n",
+                            "        else:\n",
+                            "            return (img, mask, name)",
+                        )
+                        .into(),
+                        edits: vec![Edit {
+                            range: 80..203,
+                            text: "".into()
+                        }],
+                    },
+                    is_new_file: false,
+                },
+                DiffEvent::FileEnd { renamed_to: None }
+            ],
+        );
+    }
+
+    #[test]
+    fn test_no_newline_at_eof_addition() {
+        let diff = indoc! {"
+            --- a/file.txt
+            +++ b/file.txt
+            @@ -1,2 +1,3 @@
+             context
+            -deleted
+            +added line
+            \\ No newline at end of file
+        "};
+
+        let mut events = Vec::new();
+        let mut parser = DiffParser::new(diff);
+        while let Some(event) = parser.next().unwrap() {
+            events.push(event);
+        }
+
+        assert_eq!(
+            events,
+            &[
+                DiffEvent::Hunk {
+                    path: "file.txt".into(),
+                    hunk: Hunk {
+                        context: "context\ndeleted\n".into(),
+                        edits: vec![Edit {
+                            range: 8..16,
+                            text: "added line".into()
+                        }],
+                    },
+                    is_new_file: false,
+                },
+                DiffEvent::FileEnd { renamed_to: None }
+            ],
+        );
+    }
+
     #[gpui::test]
     async fn test_apply_diff_successful(cx: &mut TestAppContext) {
         let fs = init_test(cx);

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -11,9 +11,10 @@ use edit_prediction::udiff::OpenedBuffers;
 use futures::AsyncWriteExt as _;
 use gpui::{AsyncApp, Entity};
 use language::{Anchor, Buffer, LanguageNotFound, ToOffset, ToPoint};
-use project::Project;
 use project::buffer_store::BufferStoreEvent;
+use project::{Project, ProjectPath};
 use std::{fs, path::PathBuf, sync::Arc};
+use util::{paths::PathStyle, rel_path::RelPath};
 
 pub async fn run_load_project(
     example: &mut Example,
@@ -76,16 +77,16 @@ async fn cursor_position(
         return Err(error);
     }
 
-    let cursor_path = project
-        .read_with(cx, |project, cx| {
-            project.find_project_path(&example.spec.cursor_path, cx)
-        })?
-        .with_context(|| {
-            format!(
-                "failed to find cursor path {}",
-                example.spec.cursor_path.display()
-            )
-        })?;
+    let worktree = project
+        .read_with(cx, |project, cx| project.visible_worktrees(cx).next())?
+        .context("project has no worktree")?;
+
+    let worktree_id = worktree.read_with(cx, |wt, _| wt.id())?;
+    let cursor_path = ProjectPath {
+        worktree_id,
+        path: RelPath::new(example.spec.cursor_path.as_ref(), PathStyle::Posix)?.into_arc(),
+    };
+
     let cursor_buffer = project
         .update(cx, |project, cx| project.open_buffer(cursor_path, cx))?
         .await?;

crates/edit_prediction_cli/src/main.rs 🔗

@@ -429,7 +429,7 @@ fn main() {
                                     .join(format!("{}_err.txt", example.spec.name));
                                 app_state
                                     .fs
-                                    .write(&err_path, e.to_string().as_bytes())
+                                    .write(&err_path, format!("{e:?}").as_bytes())
                                     .await
                                     .unwrap();
 

crates/edit_prediction_cli/src/split_commit.rs 🔗

@@ -799,20 +799,15 @@ pub fn imitate_human_edits(
             hunk.new_count += 1;
         }
     } else {
-        // For pure insertions, we need to add or modify a hunk
-        // Check if the source hunk exists AND has enough lines for the target's line index
-        let can_insert_in_existing_hunk = new_src_patch
-            .hunks
-            .get(tgt_edit_loc.hunk_index)
-            .map_or(false, |hunk| {
-                tgt_edit_loc.line_index_within_hunk <= hunk.lines.len()
-            });
-
-        if can_insert_in_existing_hunk {
-            if let Some(hunk) = new_src_patch.hunks.get_mut(tgt_edit_loc.hunk_index) {
-                // Insert the partial line at the same position as target
+        // For pure insertions, insert after the last edit in source patch
+        // This imitates human typing - the intermediate content is what the user is currently typing
+        let last_src_edit = locate_edited_line(&new_src_patch, -1);
+
+        if let Some(src_loc) = last_src_edit {
+            // Insert after the last edit in source
+            if let Some(hunk) = new_src_patch.hunks.get_mut(src_loc.hunk_index) {
                 hunk.lines.insert(
-                    tgt_edit_loc.line_index_within_hunk,
+                    src_loc.line_index_within_hunk + 1,
                     PatchLine::Addition(new_src.clone()),
                 );
                 hunk.new_count += 1;
@@ -1672,6 +1667,85 @@ index 123..456 789
         );
     }
 
+    #[test]
+    fn test_imitate_human_edits_inserts_after_last_source_edit() {
+        // Regression test: intermediate content should appear after the last edit
+        // in the source patch, not at the position of the first target edit.
+        // This ensures the diff output correctly imitates human typing order.
+        //
+        // The bug was: when source has edits and target has a pure insertion,
+        // the intermediate content was inserted at tgt_edit_loc.line_index_within_hunk
+        // (position of first target edit) instead of after the last source edit.
+        //
+        // Source patch has edits at lines 1-4, target has a new edit at line 10
+        // (different location to avoid the "same line" early return)
+        let source = r#"--- a/test.py
++++ b/test.py
+@@ -1,4 +1,5 @@
++import foo
+ import bar
+-import old
+ import baz
++import qux
+"#;
+        // Target has a pure insertion at a different line (line 10, not overlapping with source)
+        let target = r#"--- a/test.py
++++ b/test.py
+@@ -10,3 +10,4 @@
+ def main():
++    print("hello world")
+     pass
+"#;
+
+        // Use a seed that produces a partial result
+        let (new_src, _new_tgt, cursor) = imitate_human_edits(source, target, 42);
+
+        // The function should produce a modified patch
+        assert!(cursor.is_some(), "Should produce intermediate state");
+
+        let src_patch = Patch::parse_unified_diff(&new_src);
+        let all_additions: Vec<_> = src_patch
+            .hunks
+            .iter()
+            .flat_map(|h| h.lines.iter())
+            .filter_map(|l| match l {
+                PatchLine::Addition(s) => Some(s.as_str()),
+                _ => None,
+            })
+            .collect();
+
+        // The intermediate content (partial 'print("hello world")') should be
+        // the LAST addition, appearing after "+import qux" (the last source edit)
+        let last_addition = all_additions.last().expect("Should have additions");
+        assert!(
+            last_addition.trim_start().starts_with("pr"),
+            "Intermediate content should be the last addition (partial 'print'), but last was: {:?}",
+            last_addition
+        );
+
+        // Verify the original source edits are still in order before the intermediate
+        let foo_pos = all_additions.iter().position(|s| *s == "import foo");
+        let qux_pos = all_additions.iter().position(|s| *s == "import qux");
+        let intermediate_pos = all_additions
+            .iter()
+            .position(|s| s.trim_start().starts_with("pr"));
+
+        assert!(foo_pos.is_some(), "Should have 'import foo'");
+        assert!(qux_pos.is_some(), "Should have 'import qux'");
+        assert!(
+            intermediate_pos.is_some(),
+            "Should have intermediate content"
+        );
+
+        assert!(
+            foo_pos < qux_pos && qux_pos < intermediate_pos,
+            "Order should be: foo < qux < intermediate. Got foo={:?}, qux={:?}, intermediate={:?}",
+            foo_pos,
+            qux_pos,
+            intermediate_pos
+        );
+    }
+
     #[test]
     fn test_cursor_excerpt_with_multibyte_utf8() {
         // Test that cursor excerpt handles multi-byte UTF-8 characters correctly