Fix issues with predicted cursor positions (#48205)

Max Brunsfeld and Zed Zippy created

Release Notes:

- N/A

---------

Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com>

Change summary

crates/edit_prediction/src/zed_edit_prediction_delegate.rs |  2 
crates/edit_prediction/src/zeta1.rs                        | 63 +++++
crates/edit_prediction/src/zeta2.rs                        | 79 -------
crates/edit_prediction_types/src/edit_prediction_types.rs  |  2 
crates/edit_prediction_ui/src/rate_prediction_modal.rs     | 36 +++
crates/editor/src/edit_prediction_tests.rs                 |  5 
crates/language/src/buffer.rs                              |  6 
7 files changed, 108 insertions(+), 85 deletions(-)

Detailed changes

crates/edit_prediction/src/zed_edit_prediction_delegate.rs 🔗

@@ -244,7 +244,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
             Some(edit_prediction_types::EditPrediction::Local {
                 id: Some(prediction.id.to_string().into()),
                 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
-                cursor_position: None,
+                cursor_position: prediction.cursor_position,
                 edit_preview: Some(prediction.edit_preview.clone()),
             })
         })

crates/edit_prediction/src/zeta1.rs 🔗

@@ -10,12 +10,14 @@ use anyhow::{Context as _, Result};
 use cloud_llm_client::{
     PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 };
+use edit_prediction_types::PredictedCursorPosition;
 use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 use language::{
     Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 };
 use project::{Project, ProjectPath};
 use release_channel::AppVersion;
+use text::Bias;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 use zeta_prompt::{Event, ZetaPromptInput};
 
@@ -347,9 +349,52 @@ pub fn compute_edits(
     offset: usize,
     snapshot: &BufferSnapshot,
 ) -> Vec<(Range<Anchor>, Arc<str>)> {
-    text_diff(&old_text, new_text)
-        .into_iter()
-        .map(|(mut old_range, new_text)| {
+    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
+}
+
+pub fn compute_edits_and_cursor_position(
+    old_text: String,
+    new_text: &str,
+    offset: usize,
+    cursor_offset_in_new_text: Option<usize>,
+    snapshot: &BufferSnapshot,
+) -> (
+    Vec<(Range<Anchor>, Arc<str>)>,
+    Option<PredictedCursorPosition>,
+) {
+    let diffs = text_diff(&old_text, new_text);
+
+    // Delta represents the cumulative change in byte count from all preceding edits.
+    // new_offset = old_offset + delta, so old_offset = new_offset - delta
+    let mut delta: isize = 0;
+    let mut cursor_position: Option<PredictedCursorPosition> = None;
+
+    let edits = diffs
+        .iter()
+        .map(|(raw_old_range, new_text)| {
+            // Compute cursor position if it falls within or before this edit.
+            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
+                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
+                let edit_end_in_new = edit_start_in_new + new_text.len();
+
+                if cursor_offset < edit_start_in_new {
+                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
+                    cursor_position = Some(PredictedCursorPosition::at_anchor(
+                        snapshot.anchor_after(offset + cursor_in_old),
+                    ));
+                } else if cursor_offset < edit_end_in_new {
+                    let offset_within_insertion = cursor_offset - edit_start_in_new;
+                    cursor_position = Some(PredictedCursorPosition::new(
+                        snapshot.anchor_before(offset + raw_old_range.start),
+                        offset_within_insertion,
+                    ));
+                }
+
+                delta += new_text.len() as isize - raw_old_range.len() as isize;
+            }
+
+            // Compute the edit with prefix/suffix trimming.
+            let mut old_range = raw_old_range.clone();
             let old_slice = &old_text[old_range.clone()];
 
             let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
@@ -372,7 +417,17 @@ pub fn compute_edits(
             };
             (range, new_text)
         })
-        .collect()
+        .collect();
+
+    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
+        let cursor_in_old = (cursor_offset as isize - delta) as usize;
+        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
+        cursor_position = Some(PredictedCursorPosition::at_anchor(
+            snapshot.anchor_after(buffer_offset),
+        ));
+    }
+
+    (edits, cursor_position)
 }
 
 fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {

crates/edit_prediction/src/zeta2.rs 🔗

@@ -1,5 +1,5 @@
 use crate::prediction::EditPredictionResult;
-use crate::zeta1::compute_edits;
+use crate::zeta1::compute_edits_and_cursor_position;
 use crate::{
     CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
     EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
@@ -8,9 +8,8 @@ use crate::{
 use anyhow::{Result, anyhow};
 use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
-use edit_prediction_types::PredictedCursorPosition;
 use gpui::{App, Task, prelude::*};
-use language::{OffsetRangeExt as _, ToOffset as _, ToPoint, text_diff};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 use release_channel::AppVersion;
 
 use std::env;
@@ -173,23 +172,14 @@ pub fn request_prediction_with_zeta2(
                 old_text.push('\n');
             }
 
-            let edits = compute_edits(
-                old_text.clone(),
+            let (edits, cursor_position) = compute_edits_and_cursor_position(
+                old_text,
                 &output_text,
                 editable_offset_range.start,
+                cursor_offset_in_output,
                 &snapshot,
             );
 
-            let cursor_position = cursor_offset_in_output.map(|cursor_offset| {
-                compute_predicted_cursor_position(
-                    &old_text,
-                    &output_text,
-                    cursor_offset,
-                    editable_offset_range.start,
-                    &snapshot,
-                )
-            });
-
             anyhow::Ok((
                 Some((
                     request_id,
@@ -246,65 +236,6 @@ pub fn request_prediction_with_zeta2(
     })
 }
 
-/// Computes a `PredictedCursorPosition` from a cursor offset in the output text.
-///
-/// The cursor offset is relative to `new_text`. We need to determine if the cursor
-/// falls inside an edit's inserted text or in unchanged text:
-/// - If inside an edit: anchor = start of edit range, offset = position within insertion
-/// - If in unchanged text: anchor = corresponding position in old buffer, offset = 0
-fn compute_predicted_cursor_position(
-    old_text: &str,
-    new_text: &str,
-    cursor_offset_in_new: usize,
-    editable_region_start: usize,
-    snapshot: &language::BufferSnapshot,
-) -> PredictedCursorPosition {
-    let diffs = text_diff(old_text, new_text);
-
-    // Track position in both old and new text as we walk through diffs
-    let mut old_pos = 0usize;
-    let mut new_pos = 0usize;
-
-    for (old_range, new_text_chunk) in &diffs {
-        // Text before this diff is unchanged
-        let unchanged_len = old_range.start - old_pos;
-        let unchanged_end_in_new = new_pos + unchanged_len;
-
-        if cursor_offset_in_new < unchanged_end_in_new {
-            // Cursor is in unchanged text before this diff
-            let offset_in_unchanged = cursor_offset_in_new - new_pos;
-            let buffer_offset = editable_region_start + old_pos + offset_in_unchanged;
-            return PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset));
-        }
-
-        // Move past the unchanged portion in new_text coordinates
-        new_pos = unchanged_end_in_new;
-
-        // Check if cursor is within this edit's new text
-        let edit_new_text_end = new_pos + new_text_chunk.len();
-        if cursor_offset_in_new < edit_new_text_end {
-            // Cursor is inside this edit's inserted text.
-            // Use anchor_before (left bias) so the anchor stays at the insertion point
-            // rather than moving past the inserted text.
-            let offset_within_insertion = cursor_offset_in_new - new_pos;
-            let buffer_offset = editable_region_start + old_range.start;
-            return PredictedCursorPosition::new(
-                snapshot.anchor_before(buffer_offset),
-                offset_within_insertion,
-            );
-        }
-
-        // Move past this edit
-        old_pos = old_range.end;
-        new_pos = edit_new_text_end;
-    }
-
-    // Cursor is in unchanged text after all diffs
-    let offset_in_unchanged = cursor_offset_in_new - new_pos;
-    let buffer_offset = (editable_region_start + old_pos + offset_in_unchanged).min(snapshot.len());
-    PredictedCursorPosition::at_anchor(snapshot.anchor_after(buffer_offset))
-}
-
 pub fn zeta2_prompt_input(
     snapshot: &language::BufferSnapshot,
     related_files: Vec<zeta_prompt::RelatedFile>,

crates/edit_prediction_types/src/edit_prediction_types.rs 🔗

@@ -52,7 +52,7 @@ impl EditPredictionIconSet {
 /// exist in the original buffer, we store an anchor (which points to a position
 /// in the original buffer, typically the start of an edit) plus an offset into
 /// the inserted text.
-#[derive(Clone, Debug)]
+#[derive(Copy, Clone, Debug)]
 pub struct PredictedCursorPosition {
     /// An anchor in the original buffer. If the cursor is inside an edit,
     /// this points to the start of that edit's range.

crates/edit_prediction_ui/src/rate_prediction_modal.rs 🔗

@@ -1,6 +1,6 @@
 use buffer_diff::BufferDiff;
 use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
-use editor::{Editor, ExcerptRange, MultiBuffer};
+use editor::{Editor, ExcerptRange, Inlay, MultiBuffer};
 use feature_flags::FeatureFlag;
 use gpui::{
     App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
@@ -8,7 +8,9 @@ use gpui::{
 };
 use language::{Buffer, CodeLabel, LanguageRegistry, Point, ToOffset, language_settings};
 use markdown::{Markdown, MarkdownStyle};
-use project::{Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource};
+use project::{
+    Completion, CompletionDisplayOptions, CompletionResponse, CompletionSource, InlayId,
+};
 use settings::Settings as _;
 use std::rc::Rc;
 use std::{fmt::Write, sync::Arc, time::Duration};
@@ -352,9 +354,9 @@ impl RatePredictionsModal {
                 });
 
                 editor.disable_header_for_buffer(new_buffer_id, cx);
-                editor.buffer().update(cx, |multibuffer, cx| {
+                let excerpt_id = editor.buffer().update(cx, |multibuffer, cx| {
                     multibuffer.clear(cx);
-                    multibuffer.push_excerpts(
+                    let excerpt_ids = multibuffer.push_excerpts(
                         new_buffer,
                         vec![ExcerptRange {
                             context: start..end,
@@ -363,7 +365,33 @@ impl RatePredictionsModal {
                         cx,
                     );
                     multibuffer.add_diff(diff, cx);
+                    excerpt_ids.into_iter().next()
                 });
+
+                if let Some((excerpt_id, cursor_position)) =
+                    excerpt_id.zip(prediction.cursor_position.as_ref())
+                {
+                    let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
+                    if let Some(buffer_snapshot) =
+                        multibuffer_snapshot.buffer_for_excerpt(excerpt_id)
+                    {
+                        let cursor_offset = prediction
+                            .edit_preview
+                            .anchor_to_offset_in_result(cursor_position.anchor)
+                            + cursor_position.offset;
+                        let cursor_anchor = buffer_snapshot.anchor_after(cursor_offset);
+
+                        if let Some(anchor) =
+                            multibuffer_snapshot.anchor_in_excerpt(excerpt_id, cursor_anchor)
+                        {
+                            editor.splice_inlays(
+                                &[InlayId::EditPrediction(0)],
+                                vec![Inlay::edit_prediction(0, anchor, "▏")],
+                                cx,
+                            );
+                        }
+                    }
+                }
             });
 
             let mut formatted_inputs = String::new();

crates/editor/src/edit_prediction_tests.rs 🔗

@@ -37,10 +37,13 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
 
 #[gpui::test]
 async fn test_edit_prediction_cursor_position_inside_insertion(cx: &mut gpui::TestAppContext) {
-    init_test(cx, |_| {});
+    init_test(cx, |_| {
+        eprintln!("");
+    });
 
     let mut cx = EditorTestContext::new(cx).await;
     let provider = cx.new(|_| FakeEditPredictionDelegate::default());
+
     assign_editor_completion_provider(provider.clone(), &mut cx);
     // Buffer: "fn foo() {}" - we'll insert text and position cursor inside the insertion
     cx.set_state("fn foo() ˇ{}");

crates/language/src/buffer.rs 🔗

@@ -918,6 +918,12 @@ impl EditPreview {
         })
     }
 
+    pub fn anchor_to_offset_in_result(&self, anchor: Anchor) -> usize {
+        anchor
+            .bias_right(&self.old_snapshot)
+            .to_offset(&self.applied_edits_snapshot)
+    }
+
     pub fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<Point>> {
         let (first, _) = edits.first()?;
         let (last, _) = edits.last()?;