zeta2: Compute smaller edits (#38786)

Agus Zubiaga created

The new cloud endpoint returns structured edits, but they may include
more of the input excerpt than what we want to display in the preview,
so we compute a smaller diff on the client side against the snapshot.

Release Notes:

- N/A

Change summary

Cargo.lock                |   1 
crates/zeta2/Cargo.toml   |   1 
crates/zeta2/src/zeta2.rs | 139 +++++++++++++++++++++++++++++++++-------
3 files changed, 114 insertions(+), 27 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -20587,6 +20587,7 @@ dependencies = [
  "edit_prediction_context",
  "futures 0.3.31",
  "gpui",
+ "indoc",
  "language",
  "language_model",
  "log",

crates/zeta2/Cargo.toml 🔗

@@ -22,6 +22,7 @@ edit_prediction.workspace = true
 edit_prediction_context.workspace = true
 futures.workspace = true
 gpui.workspace = true
+indoc.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true

crates/zeta2/src/zeta2.rs 🔗

@@ -21,11 +21,13 @@ use gpui::{
 };
 use language::{
     Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
+    text_diff,
 };
 use language::{BufferSnapshot, EditPreview};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
+use std::borrow::Cow;
 use std::cmp;
 use std::collections::{HashMap, VecDeque, hash_map};
 use std::path::PathBuf;
@@ -438,7 +440,10 @@ impl Zeta {
                         .ok();
                 }
 
-                anyhow::Ok(Some(response?))
+                let (response, usage) = response?;
+                let edits = Self::compute_edits(&response.edits, &snapshot);
+
+                anyhow::Ok(Some((response.request_id, edits, usage)))
             }
         });
 
@@ -446,9 +451,7 @@ impl Zeta {
 
         cx.spawn(async move |this, cx| {
             match request_task.await {
-                Ok(Some((response, usage))) => {
-                    log::debug!("predicted edits: {:?}", &response.edits);
-
+                Ok(Some((id, edits, usage))) => {
                     if let Some(usage) = usage {
                         this.update(cx, |this, cx| {
                             this.user_store.update(cx, |user_store, cx| {
@@ -459,28 +462,6 @@ impl Zeta {
                     }
 
                     // TODO telemetry: duration, etc
-
-                    // TODO produce smaller edits by diffing against snapshot first
-                    //
-                    // Cloud returns entire snippets/excerpts ranges as they were included
-                    // in the request, but we should display smaller edits to the user.
-                    //
-                    // We can do this by computing a diff of each one against the snapshot.
-                    // Similar to zeta::Zeta::compute_edits, but per edit.
-                    let edits = response
-                        .edits
-                        .into_iter()
-                        .map(|edit| {
-                            // TODO edits to different files
-                            (
-                                snapshot.anchor_before(edit.range.start)
-                                    ..snapshot.anchor_before(edit.range.end),
-                                edit.content,
-                            )
-                        })
-                        .collect::<Vec<_>>()
-                        .into();
-
                     let Some((edits, snapshot, edit_preview_task)) =
                         buffer.read_with(cx, |buffer, cx| {
                             let new_snapshot = buffer.snapshot();
@@ -493,7 +474,7 @@ impl Zeta {
                     };
 
                     Ok(Some(EditPrediction {
-                        id: EditPredictionId(response.request_id),
+                        id: EditPredictionId(id),
                         edits,
                         snapshot,
                         edit_preview: edit_preview_task.await,
@@ -604,6 +585,62 @@ impl Zeta {
         }
     }
 
+    fn compute_edits(
+        edits: &[predict_edits_v3::Edit],
+        snapshot: &BufferSnapshot,
+    ) -> Arc<[(Range<Anchor>, String)]> {
+        edits
+            .iter()
+            .flat_map(|edit| {
+                // TODO multi-file edits
+                let old_text = snapshot.text_for_range(edit.range.clone());
+
+                Self::compute_excerpt_edits(
+                    old_text.collect::<Cow<str>>(),
+                    &edit.content,
+                    edit.range.start,
+                    &snapshot,
+                )
+            })
+            .collect::<Vec<_>>()
+            .into()
+    }
+
+    fn compute_excerpt_edits(
+        old_text: Cow<str>,
+        new_text: &str,
+        offset: usize,
+        snapshot: &BufferSnapshot,
+    ) -> impl Iterator<Item = (Range<Anchor>, String)> {
+        text_diff(&old_text, new_text)
+            .into_iter()
+            .map(move |(mut old_range, new_text)| {
+                old_range.start += offset;
+                old_range.end += offset;
+
+                let prefix_len = common_prefix(
+                    snapshot.chars_for_range(old_range.clone()),
+                    new_text.chars(),
+                );
+                old_range.start += prefix_len;
+
+                let suffix_len = common_prefix(
+                    snapshot.reversed_chars_for_range(old_range.clone()),
+                    new_text[prefix_len..].chars().rev(),
+                );
+                old_range.end = old_range.end.saturating_sub(suffix_len);
+
+                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
+                let range = if old_range.is_empty() {
+                    let anchor = snapshot.anchor_after(old_range.start);
+                    anchor..anchor
+                } else {
+                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
+                };
+                (range, new_text)
+            })
+    }
+
     fn gather_nearby_diagnostics(
         cursor_offset: usize,
         diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
@@ -713,6 +750,13 @@ impl Zeta {
     }
 }
 
+fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
+    a.zip(b)
+        .take_while(|(a, b)| a == b)
+        .map(|(a, _)| a.len_utf8())
+        .sum()
+}
+
 #[derive(Error, Debug)]
 #[error(
     "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
@@ -1222,6 +1266,47 @@ fn interpolate(
 mod tests {
     use super::*;
     use gpui::TestAppContext;
+    use indoc::indoc;
+
+    #[gpui::test]
+    async fn test_compute_edits(cx: &mut TestAppContext) {
+        let old = indoc! {r#"
+            fn main() {
+                let args =
+                println!("{}", args[1])
+            }
+        "#};
+
+        let new = indoc! {r#"
+            fn main() {
+                let args = std::env::args();
+                println!("{}", args[1]);
+            }
+        "#};
+
+        let buffer = cx.new(|cx| Buffer::local(old, cx));
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+
+        // TODO cover more cases when multi-file is supported
+        let big_edits = vec![predict_edits_v3::Edit {
+            path: PathBuf::from("test.txt"),
+            range: 0..old.len(),
+            content: new.into(),
+        }];
+
+        let edits = Zeta::compute_edits(&big_edits, &snapshot);
+        assert_eq!(edits.len(), 2);
+        assert_eq!(
+            edits[0].0.to_point(&snapshot).start,
+            language::Point::new(1, 14)
+        );
+        assert_eq!(edits[0].1, " std::env::args();");
+        assert_eq!(
+            edits[1].0.to_point(&snapshot).start,
+            language::Point::new(2, 27)
+        );
+        assert_eq!(edits[1].1, ";");
+    }
 
     #[gpui::test]
     async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {