ep: Encode cursor position in the predicted patch (#49450)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/example_spec.rs     | 97 ++++++++++---------
crates/edit_prediction_cli/src/parse_output.rs |  3 
2 files changed, 54 insertions(+), 46 deletions(-)

Detailed changes

crates/edit_prediction/src/example_spec.rs 🔗

@@ -12,6 +12,56 @@ pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
 /// falling back to git-based loading.
 pub const MAX_CURSOR_FILE_SIZE: usize = 64 * 1024;
 
+/// Encodes a cursor position into a diff patch by adding a comment line with a caret
+/// pointing to the cursor column.
+///
+/// The cursor offset is a byte offset from the start of the new text (addition and context lines).
+/// Returns the patch with a cursor marker comment line inserted after the addition line that holds the cursor.
+pub fn encode_cursor_in_patch(patch: &str, cursor_offset: Option<usize>) -> String {
+    let Some(cursor_offset) = cursor_offset else {
+        return patch.to_string();
+    };
+
+    let mut result = String::new();
+    let mut line_start_offset = 0usize;
+
+    for line in patch.lines() {
+        if !result.is_empty() {
+            result.push('\n');
+        }
+        result.push_str(line);
+
+        match DiffLine::parse(line) {
+            DiffLine::Addition(content) => {
+                let line_end_offset = line_start_offset + content.len();
+
+                if cursor_offset >= line_start_offset && cursor_offset <= line_end_offset {
+                    let cursor_column = cursor_offset - line_start_offset;
+
+                    result.push('\n');
+                    result.push('#');
+                    for _ in 0..cursor_column {
+                        result.push(' ');
+                    }
+                    write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
+                }
+
+                line_start_offset = line_end_offset + 1;
+            }
+            DiffLine::Context(content) => {
+                line_start_offset += content.len() + 1;
+            }
+            _ => {}
+        }
+    }
+
+    if patch.ends_with('\n') {
+        result.push('\n');
+    }
+
+    result
+}
+
 #[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
 pub struct ExampleSpec {
     #[serde(default)]
@@ -567,52 +617,7 @@ impl ExampleSpec {
     ) {
         self.expected_patches = patches
             .into_iter()
-            .map(|(patch, cursor_editable_region_offset)| {
-                let Some(cursor_offset) = cursor_editable_region_offset else {
-                    return patch;
-                };
-
-                let mut result = String::new();
-                let mut line_start_offset = 0usize;
-
-                for line in patch.lines() {
-                    if !result.is_empty() {
-                        result.push('\n');
-                    }
-                    result.push_str(line);
-
-                    match DiffLine::parse(line) {
-                        DiffLine::Addition(content) => {
-                            let line_end_offset = line_start_offset + content.len();
-
-                            if cursor_offset >= line_start_offset
-                                && cursor_offset <= line_end_offset
-                            {
-                                let cursor_column = cursor_offset - line_start_offset;
-
-                                result.push('\n');
-                                result.push('#');
-                                for _ in 0..cursor_column {
-                                    result.push(' ');
-                                }
-                                write!(result, "^{}", CURSOR_POSITION_MARKER).unwrap();
-                            }
-
-                            line_start_offset = line_end_offset + 1;
-                        }
-                        DiffLine::Context(content) => {
-                            line_start_offset += content.len() + 1;
-                        }
-                        _ => {}
-                    }
-                }
-
-                if patch.ends_with('\n') {
-                    result.push('\n');
-                }
-
-                result
-            })
+            .map(|(patch, cursor_offset)| encode_cursor_in_patch(&patch, cursor_offset))
             .collect();
     }
 }

crates/edit_prediction_cli/src/parse_output.rs 🔗

@@ -5,6 +5,7 @@ use crate::{
     repair,
 };
 use anyhow::{Context as _, Result};
+use edit_prediction::example_spec::encode_cursor_in_patch;
 use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
 
 pub fn run_parse_output(example: &mut Example) -> Result<()> {
@@ -162,6 +163,8 @@ fn parse_zeta2_output(
         path = example.spec.cursor_path.to_string_lossy(),
     );
 
+    let formatted_diff = encode_cursor_in_patch(&formatted_diff, cursor_offset);
+
     let actual_cursor = cursor_offset.map(|editable_region_cursor_offset| {
         ActualCursor::from_editable_region(
             &example.spec.cursor_path,