Integrate the new diff algorithm into the modal assistant

Created by Antonio Scandurra

Change summary

crates/ai/src/diff.rs     |  72 +++++++---
crates/ai/src/refactor.rs | 262 ++++++++--------------------------------
2 files changed, 100 insertions(+), 234 deletions(-)

Detailed changes

crates/ai/src/diff.rs 🔗

@@ -64,41 +64,40 @@ impl Debug for Matrix {
 }
 
 #[derive(Debug)]
-enum Hunk {
-    Insert(char),
-    Remove(char),
-    Keep(char),
+pub enum Hunk {
+    Insert { len: usize },
+    Remove { len: usize },
+    Keep { len: usize },
 }
 
-struct Diff {
+pub struct Diff {
     old: String,
     new: String,
     scores: Matrix,
-    last_diff_row: usize,
+    old_text_ix: usize,
 }
 
 impl Diff {
-    fn new(old: String) -> Self {
+    pub fn new(old: String) -> Self {
         let mut scores = Matrix::new();
         scores.resize(old.len() + 1, 1);
         for i in 0..=old.len() {
             scores.set(i, 0, -(i as isize));
         }
-        dbg!(&scores);
         Self {
             old,
             new: String::new(),
             scores,
-            last_diff_row: 0,
+            old_text_ix: 0,
         }
     }
 
-    fn push_new(&mut self, text: &str) -> Vec<Hunk> {
-        let last_diff_column = self.new.len();
+    pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
+        let new_text_ix = self.new.len();
         self.new.push_str(text);
         self.scores.resize(self.old.len() + 1, self.new.len() + 1);
 
-        for j in last_diff_column + 1..=self.new.len() {
+        for j in new_text_ix + 1..=self.new.len() {
             self.scores.set(0, j, -(j as isize));
             for i in 1..=self.old.len() {
                 let insertion_score = self.scores.get(i, j - 1) - 1;
@@ -114,8 +113,8 @@ impl Diff {
         }
 
         let mut max_score = isize::MIN;
-        let mut best_row = self.last_diff_row;
-        for i in self.last_diff_row..=self.old.len() {
+        let mut best_row = self.old_text_ix;
+        for i in self.old_text_ix..=self.old.len() {
             let score = self.scores.get(i, self.new.len());
             if score > max_score {
                 max_score = score;
@@ -126,18 +125,18 @@ impl Diff {
         let mut hunks = Vec::new();
         let mut i = best_row;
         let mut j = self.new.len();
-        while (i, j) != (self.last_diff_row, last_diff_column) {
-            let insertion_score = if j > last_diff_column {
+        while (i, j) != (self.old_text_ix, new_text_ix) {
+            let insertion_score = if j > new_text_ix {
                 Some((i, j - 1))
             } else {
                 None
             };
-            let deletion_score = if i > self.last_diff_row {
+            let deletion_score = if i > self.old_text_ix {
                 Some((i - 1, j))
             } else {
                 None
             };
-            let equality_score = if i > self.last_diff_row && j > last_diff_column {
+            let equality_score = if i > self.old_text_ix && j > new_text_ix {
                 Some((i - 1, j - 1))
             } else {
                 None
@@ -150,20 +149,42 @@ impl Diff {
                 .unwrap();
 
             if prev_i == i && prev_j == j - 1 {
-                hunks.push(Hunk::Insert(self.new.chars().skip(j - 1).next().unwrap()));
+                if let Some(Hunk::Insert { len }) = hunks.last_mut() {
+                    *len += 1;
+                } else {
+                    hunks.push(Hunk::Insert { len: 1 })
+                }
             } else if prev_i == i - 1 && prev_j == j {
-                hunks.push(Hunk::Remove(self.old.chars().skip(i - 1).next().unwrap()));
+                if let Some(Hunk::Remove { len }) = hunks.last_mut() {
+                    *len += 1;
+                } else {
+                    hunks.push(Hunk::Remove { len: 1 })
+                }
             } else {
-                hunks.push(Hunk::Keep(self.old.chars().skip(i - 1).next().unwrap()));
+                if let Some(Hunk::Keep { len }) = hunks.last_mut() {
+                    *len += 1;
+                } else {
+                    hunks.push(Hunk::Keep { len: 1 })
+                }
             }
 
             i = prev_i;
             j = prev_j;
         }
-        self.last_diff_row = best_row;
+        self.old_text_ix = best_row;
         hunks.reverse();
         hunks
     }
+
+    pub fn finish(self) -> Option<Hunk> {
+        if self.old_text_ix < self.old.len() {
+            Some(Hunk::Remove {
+                len: self.old.len() - self.old_text_ix,
+            })
+        } else {
+            None
+        }
+    }
 }
 
 #[cfg(test)]
@@ -173,8 +194,9 @@ mod tests {
     #[test]
     fn test_diff() {
         let mut diff = Diff::new("hello world".to_string());
-        dbg!(diff.push_new("hello"));
-        dbg!(diff.push_new(" ciaone"));
-        dbg!(diff.push_new(" world"));
+        diff.push_new("hello");
+        diff.push_new(" ciaone");
+        diff.push_new(" world");
+        diff.finish();
     }
 }

crates/ai/src/refactor.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
 use collections::HashMap;
 use editor::{Editor, ToOffset};
-use futures::StreamExt;
+use futures::{channel::mpsc, SinkExt, StreamExt};
 use gpui::{
     actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
     WeakViewHandle,
@@ -59,151 +59,67 @@ impl RefactoringAssistant {
             editor.id(),
             cx.spawn(|mut cx| {
                 async move {
-                    let selection_start = selection.start.to_offset(&snapshot);
-
-                    let mut new_text = String::new();
-                    let mut messages = response.await?;
-
-                    let mut transaction = None;
-
-                    while let Some(message) = messages.next().await {
-                        smol::future::yield_now().await;
-                        let mut message = message?;
-                        if let Some(choice) = message.choices.pop() {
-                            if let Some(text) = choice.delta.content {
-                                new_text.push_str(&text);
-
-                                println!("-------------------------------------");
-
-                                println!(
-                                    "{}",
-                                    similar::TextDiff::from_words(&selected_text, &new_text)
-                                        .unified_diff()
-                                );
-
-                                let mut changes =
-                                    similar::TextDiff::from_words(&selected_text, &new_text)
-                                        .iter_all_changes()
-                                        .collect::<Vec<_>>();
-
-                                let mut ix = 0;
-                                while ix < changes.len() {
-                                    let deletion_start_ix = ix;
-                                    let mut deletion_end_ix = ix;
-                                    while changes
-                                        .get(ix)
-                                        .map_or(false, |change| change.tag() == ChangeTag::Delete)
-                                    {
-                                        ix += 1;
-                                        deletion_end_ix += 1;
+                    let mut edit_start = selection.start.to_offset(&snapshot);
+
+                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+                    let diff = cx.background().spawn(async move {
+                        let mut messages = response.await?.ready_chunks(4);
+                        let mut diff = crate::diff::Diff::new(selected_text);
+
+                        while let Some(messages) = messages.next().await {
+                            let mut new_text = String::new();
+                            for message in messages {
+                                let mut message = message?;
+                                if let Some(choice) = message.choices.pop() {
+                                    if let Some(text) = choice.delta.content {
+                                        new_text.push_str(&text);
                                     }
-
-                                    let insertion_start_ix = ix;
-                                    let mut insertion_end_ix = ix;
-                                    while changes
-                                        .get(ix)
-                                        .map_or(false, |change| change.tag() == ChangeTag::Insert)
-                                    {
-                                        ix += 1;
-                                        insertion_end_ix += 1;
-                                    }
-
-                                    if deletion_end_ix > deletion_start_ix
-                                        && insertion_end_ix > insertion_start_ix
-                                    {
-                                        for _ in deletion_start_ix..deletion_end_ix {
-                                            let deletion = changes.remove(deletion_end_ix);
-                                            changes.insert(insertion_end_ix - 1, deletion);
-                                        }
-                                    }
-
-                                    ix += 1;
                                 }
-
-                                while changes
-                                    .last()
-                                    .map_or(false, |change| change.tag() != ChangeTag::Insert)
-                                {
-                                    changes.pop();
-                                }
-
-                                editor.update(&mut cx, |editor, cx| {
-                                    editor.buffer().update(cx, |buffer, cx| {
-                                        if let Some(transaction) = transaction.take() {
-                                            buffer.undo(cx); // TODO: Undo the transaction instead
-                                        }
-
-                                        buffer.start_transaction(cx);
-                                        let mut edit_start = selection_start;
-                                        dbg!(&changes);
-                                        for change in changes {
-                                            let value = change.value();
-                                            let edit_end = edit_start + value.len();
-                                            match change.tag() {
-                                                ChangeTag::Equal => {
-                                                    edit_start = edit_end;
-                                                }
-                                                ChangeTag::Delete => {
-                                                    let range = snapshot.anchor_after(edit_start)
-                                                        ..snapshot.anchor_before(edit_end);
-                                                    buffer.edit([(range, "")], None, cx);
-                                                    edit_start = edit_end;
-                                                }
-                                                ChangeTag::Insert => {
-                                                    let insertion_start =
-                                                        snapshot.anchor_after(edit_start);
-                                                    buffer.edit(
-                                                        [(insertion_start..insertion_start, value)],
-                                                        None,
-                                                        cx,
-                                                    );
-                                                }
-                                            }
-                                        }
-                                        transaction = buffer.end_transaction(cx);
-                                    })
-                                })?;
                             }
+
+                            let hunks = diff.push_new(&new_text);
+                            hunks_tx.send((hunks, new_text)).await?;
                         }
-                    }
 
-                    editor.update(&mut cx, |editor, cx| {
-                        editor.buffer().update(cx, |buffer, cx| {
-                            if let Some(transaction) = transaction.take() {
-                                buffer.undo(cx); // TODO: Undo the transaction instead
-                            }
+                        if let Some(hunk) = diff.finish() {
+                            hunks_tx.send((vec![hunk], String::new())).await?;
+                        }
 
-                            buffer.start_transaction(cx);
-                            let mut edit_start = selection_start;
-                            for change in similar::TextDiff::from_words(&selected_text, &new_text)
-                                .iter_all_changes()
-                            {
-                                let value = change.value();
-                                let edit_end = edit_start + value.len();
-                                match change.tag() {
-                                    ChangeTag::Equal => {
-                                        edit_start = edit_end;
-                                    }
-                                    ChangeTag::Delete => {
-                                        let range = snapshot.anchor_after(edit_start)
-                                            ..snapshot.anchor_before(edit_end);
-                                        buffer.edit([(range, "")], None, cx);
-                                        edit_start = edit_end;
-                                    }
-                                    ChangeTag::Insert => {
-                                        let insertion_start = snapshot.anchor_after(edit_start);
-                                        buffer.edit(
-                                            [(insertion_start..insertion_start, value)],
-                                            None,
-                                            cx,
-                                        );
+                        anyhow::Ok(())
+                    });
+
+                    while let Some((hunks, new_text)) = hunks_rx.next().await {
+                        editor.update(&mut cx, |editor, cx| {
+                            editor.buffer().update(cx, |buffer, cx| {
+                                buffer.start_transaction(cx);
+                                let mut new_text_ix = 0;
+                                for hunk in hunks {
+                                    match hunk {
+                                        crate::diff::Hunk::Insert { len } => {
+                                            let text = &new_text[new_text_ix..new_text_ix + len];
+                                            let edit_start = snapshot.anchor_after(edit_start);
+                                            buffer.edit([(edit_start..edit_start, text)], None, cx);
+                                            new_text_ix += len;
+                                        }
+                                        crate::diff::Hunk::Remove { len } => {
+                                            let edit_end = edit_start + len;
+                                            let edit_range = snapshot.anchor_after(edit_start)
+                                                ..snapshot.anchor_before(edit_end);
+                                            buffer.edit([(edit_range, "")], None, cx);
+                                            edit_start = edit_end;
+                                        }
+                                        crate::diff::Hunk::Keep { len } => {
+                                            edit_start += len;
+                                            new_text_ix += len;
+                                        }
                                     }
                                 }
-                            }
-                            buffer.end_transaction(cx);
-                        })
-                    })?;
+                                buffer.end_transaction(cx);
+                            })
+                        })?;
+                    }
 
+                    diff.await?;
                     anyhow::Ok(())
                 }
                 .log_err()
@@ -285,75 +201,3 @@ impl RefactoringModal {
         }
     }
 }
-fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
-    let mut word_start_ix = None;
-    let mut chars = text.char_indices();
-    iter::from_fn(move || {
-        while let Some((ix, ch)) = chars.next() {
-            if let Some(start_ix) = word_start_ix {
-                if !ch.is_alphanumeric() {
-                    let word = &text[start_ix..ix];
-                    word_start_ix.take();
-                    return Some((start_ix..ix, word));
-                }
-            } else {
-                if ch.is_alphanumeric() {
-                    word_start_ix = Some(ix);
-                }
-            }
-        }
-        None
-    })
-}
-
-fn streaming_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<Change<'a, str>> {
-    let changes = TextDiff::configure()
-        .algorithm(similar::Algorithm::Patience)
-        .diff_words(old_text, new_text);
-    let mut changes = changes.iter_all_changes().peekable();
-
-    let mut result = vec![];
-
-    loop {
-        let mut deletions = vec![];
-        let mut insertions = vec![];
-
-        while changes
-            .peek()
-            .map_or(false, |change| change.tag() == ChangeTag::Delete)
-        {
-            deletions.push(changes.next().unwrap());
-        }
-
-        while changes
-            .peek()
-            .map_or(false, |change| change.tag() == ChangeTag::Insert)
-        {
-            insertions.push(changes.next().unwrap());
-        }
-
-        if !deletions.is_empty() && !insertions.is_empty() {
-            result.append(&mut insertions);
-            result.append(&mut deletions);
-        } else {
-            result.append(&mut deletions);
-            result.append(&mut insertions);
-        }
-
-        if let Some(change) = changes.next() {
-            result.push(change);
-        } else {
-            break;
-        }
-    }
-
-    // Remove all non-inserts at the end.
-    while result
-        .last()
-        .map_or(false, |change| change.tag() != ChangeTag::Insert)
-    {
-        result.pop();
-    }
-
-    result
-}