WIP

Antonio Scandurra created

Change summary

Cargo.lock                |   1 
crates/ai/Cargo.toml      |   1 
crates/ai/src/refactor.rs | 347 +++++++++++++++++++++++++++-------------
3 files changed, 237 insertions(+), 112 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -106,6 +106,7 @@ dependencies = [
  "fs",
  "futures 0.3.28",
  "gpui",
+ "indoc",
  "isahc",
  "language",
  "menu",

crates/ai/Cargo.toml 🔗

@@ -24,6 +24,7 @@ workspace = { path = "../workspace" }
 anyhow.workspace = true
 chrono = { version = "0.4", features = ["serde"] }
 futures.workspace = true
+indoc.workspace = true
 isahc.workspace = true
 regex.workspace = true
 schemars.workspace = true

crates/ai/src/refactor.rs 🔗

@@ -1,14 +1,13 @@
 use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
-use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
-use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
-use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
+use collections::HashMap;
+use editor::{Editor, ToOffset};
+use futures::StreamExt;
 use gpui::{
     actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
     WeakViewHandle,
 };
 use menu::Confirm;
-use serde::Deserialize;
-use similar::ChangeTag;
+use similar::{Change, ChangeTag, TextDiff};
 use std::{env, iter, ops::Range, sync::Arc};
 use util::TryFutureExt;
 use workspace::{Modal, Workspace};
@@ -33,12 +32,12 @@ impl RefactoringAssistant {
     }
 
     fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
-        let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
+        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
         let selection = editor.read(cx).selections.newest_anchor().clone();
-        let selected_text = buffer
+        let selected_text = snapshot
             .text_for_range(selection.start..selection.end)
             .collect::<String>();
-        let language_name = buffer
+        let language_name = snapshot
             .language_at(selection.start)
             .map(|language| language.name());
         let language_name = language_name.as_deref().unwrap_or("");
@@ -48,7 +47,7 @@ impl RefactoringAssistant {
                 RequestMessage {
                 role: Role::User,
                 content: format!(
-                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
+                    "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code. Preserve indentation."
                 ),
             }],
             stream: true,
@@ -60,86 +59,149 @@ impl RefactoringAssistant {
             editor.id(),
             cx.spawn(|mut cx| {
                 async move {
-                    let selection_start = selection.start.to_offset(&buffer);
-
-                    // Find unique words in the selected text to use as diff boundaries.
-                    let mut duplicate_words = HashSet::default();
-                    let mut unique_old_words = HashMap::default();
-                    for (range, word) in words(&selected_text) {
-                        if !duplicate_words.contains(word) {
-                            if unique_old_words.insert(word, range.end).is_some() {
-                                unique_old_words.remove(word);
-                                duplicate_words.insert(word);
-                            }
-                        }
-                    }
+                    let selection_start = selection.start.to_offset(&snapshot);
 
                     let mut new_text = String::new();
                     let mut messages = response.await?;
-                    let mut new_word_search_start_ix = 0;
-                    let mut last_old_word_end_ix = 0;
-
-                    'outer: loop {
-                        const MIN_DIFF_LEN: usize = 50;
-
-                        let start = new_word_search_start_ix;
-                        let mut words = words(&new_text[start..]);
-                        while let Some((range, new_word)) = words.next() {
-                            // We found a word in the new text that was unique in the old text. We can use
-                            // it as a diff boundary, and start applying edits.
-                            if let Some(old_word_end_ix) = unique_old_words.get(new_word).copied() {
-                                if old_word_end_ix.saturating_sub(last_old_word_end_ix)
-                                    > MIN_DIFF_LEN
-                                {
-                                    drop(words);
-
-                                    let remainder = new_text.split_off(start + range.end);
-                                    let edits = diff(
-                                        selection_start + last_old_word_end_ix,
-                                        &selected_text[last_old_word_end_ix..old_word_end_ix],
-                                        &new_text,
-                                        &buffer,
-                                    );
-                                    editor.update(&mut cx, |editor, cx| {
-                                        editor
-                                            .buffer()
-                                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
-                                    })?;
-
-                                    new_text = remainder;
-                                    new_word_search_start_ix = 0;
-                                    last_old_word_end_ix = old_word_end_ix;
-                                    continue 'outer;
+
+                    let mut transaction = None;
+
+                    while let Some(message) = messages.next().await {
+                        smol::future::yield_now().await;
+                        let mut message = message?;
+                        if let Some(choice) = message.choices.pop() {
+                            if let Some(text) = choice.delta.content {
+                                new_text.push_str(&text);
+
+                                println!("-------------------------------------");
+
+                                println!(
+                                    "{}",
+                                    similar::TextDiff::from_words(&selected_text, &new_text)
+                                        .unified_diff()
+                                );
+
+                                let mut changes =
+                                    similar::TextDiff::from_words(&selected_text, &new_text)
+                                        .iter_all_changes()
+                                        .collect::<Vec<_>>();
+
+                                let mut ix = 0;
+                                while ix < changes.len() {
+                                    let deletion_start_ix = ix;
+                                    let mut deletion_end_ix = ix;
+                                    while changes
+                                        .get(ix)
+                                        .map_or(false, |change| change.tag() == ChangeTag::Delete)
+                                    {
+                                        ix += 1;
+                                        deletion_end_ix += 1;
+                                    }
+
+                                    let insertion_start_ix = ix;
+                                    let mut insertion_end_ix = ix;
+                                    while changes
+                                        .get(ix)
+                                        .map_or(false, |change| change.tag() == ChangeTag::Insert)
+                                    {
+                                        ix += 1;
+                                        insertion_end_ix += 1;
+                                    }
+
+                                    if deletion_end_ix > deletion_start_ix
+                                        && insertion_end_ix > insertion_start_ix
+                                    {
+                                        for _ in deletion_start_ix..deletion_end_ix {
+                                            let deletion = changes.remove(deletion_end_ix);
+                                            changes.insert(insertion_end_ix - 1, deletion);
+                                        }
+                                    }
+
+                                    ix += 1;
                                 }
-                            }
 
-                            new_word_search_start_ix = start + range.end;
-                        }
-                        drop(words);
-
-                        // Buffer incoming text, stopping if the stream was exhausted.
-                        if let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                if let Some(text) = choice.delta.content {
-                                    new_text.push_str(&text);
+                                while changes
+                                    .last()
+                                    .map_or(false, |change| change.tag() != ChangeTag::Insert)
+                                {
+                                    changes.pop();
                                 }
+
+                                editor.update(&mut cx, |editor, cx| {
+                                    editor.buffer().update(cx, |buffer, cx| {
+                                        if let Some(transaction) = transaction.take() {
+                                            buffer.undo(cx); // TODO: Undo the transaction instead
+                                        }
+
+                                        buffer.start_transaction(cx);
+                                        let mut edit_start = selection_start;
+                                        dbg!(&changes);
+                                        for change in changes {
+                                            let value = change.value();
+                                            let edit_end = edit_start + value.len();
+                                            match change.tag() {
+                                                ChangeTag::Equal => {
+                                                    edit_start = edit_end;
+                                                }
+                                                ChangeTag::Delete => {
+                                                    let range = snapshot.anchor_after(edit_start)
+                                                        ..snapshot.anchor_before(edit_end);
+                                                    buffer.edit([(range, "")], None, cx);
+                                                    edit_start = edit_end;
+                                                }
+                                                ChangeTag::Insert => {
+                                                    let insertion_start =
+                                                        snapshot.anchor_after(edit_start);
+                                                    buffer.edit(
+                                                        [(insertion_start..insertion_start, value)],
+                                                        None,
+                                                        cx,
+                                                    );
+                                                }
+                                            }
+                                        }
+                                        transaction = buffer.end_transaction(cx);
+                                    })
+                                })?;
                             }
-                        } else {
-                            break;
                         }
                     }
 
-                    let edits = diff(
-                        selection_start + last_old_word_end_ix,
-                        &selected_text[last_old_word_end_ix..],
-                        &new_text,
-                        &buffer,
-                    );
                     editor.update(&mut cx, |editor, cx| {
-                        editor
-                            .buffer()
-                            .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+                        editor.buffer().update(cx, |buffer, cx| {
+                            if let Some(transaction) = transaction.take() {
+                                buffer.undo(cx); // TODO: Undo the transaction instead
+                            }
+
+                            buffer.start_transaction(cx);
+                            let mut edit_start = selection_start;
+                            for change in similar::TextDiff::from_words(&selected_text, &new_text)
+                                .iter_all_changes()
+                            {
+                                let value = change.value();
+                                let edit_end = edit_start + value.len();
+                                match change.tag() {
+                                    ChangeTag::Equal => {
+                                        edit_start = edit_end;
+                                    }
+                                    ChangeTag::Delete => {
+                                        let range = snapshot.anchor_after(edit_start)
+                                            ..snapshot.anchor_before(edit_end);
+                                        buffer.edit([(range, "")], None, cx);
+                                        edit_start = edit_end;
+                                    }
+                                    ChangeTag::Insert => {
+                                        let insertion_start = snapshot.anchor_after(edit_start);
+                                        buffer.edit(
+                                            [(insertion_start..insertion_start, value)],
+                                            None,
+                                            cx,
+                                        );
+                                    }
+                                }
+                            }
+                            buffer.end_transaction(cx);
+                        })
                     })?;
 
                     anyhow::Ok(())
@@ -197,11 +259,13 @@ impl RefactoringModal {
         {
             workspace.toggle_modal(cx, |_, cx| {
                 let prompt_editor = cx.add_view(|cx| {
-                    Editor::auto_height(
+                    let mut editor = Editor::auto_height(
                         4,
                         Some(Arc::new(|theme| theme.search.editor.input.clone())),
                         cx,
-                    )
+                    );
+                    editor.set_text("Replace with match statement.", cx);
+                    editor
                 });
                 cx.add_view(|_| RefactoringModal {
                     editor,
@@ -242,38 +306,97 @@ fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
     })
 }
 
-fn diff<'a>(
-    start_ix: usize,
-    old_text: &'a str,
-    new_text: &'a str,
-    old_buffer_snapshot: &MultiBufferSnapshot,
-) -> Vec<(Range<Anchor>, &'a str)> {
-    let mut edit_start = start_ix;
-    let mut edits = Vec::new();
-    let diff = similar::TextDiff::from_words(old_text, &new_text);
-    for change in diff.iter_all_changes() {
-        let value = change.value();
-        let edit_end = edit_start + value.len();
-        match change.tag() {
-            ChangeTag::Equal => {
-                edit_start = edit_end;
-            }
-            ChangeTag::Delete => {
-                edits.push((
-                    old_buffer_snapshot.anchor_after(edit_start)
-                        ..old_buffer_snapshot.anchor_before(edit_end),
-                    "",
-                ));
-                edit_start = edit_end;
-            }
-            ChangeTag::Insert => {
-                edits.push((
-                    old_buffer_snapshot.anchor_after(edit_start)
-                        ..old_buffer_snapshot.anchor_after(edit_start),
-                    value,
-                ));
-            }
+fn streaming_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<Change<'a, str>> {
+    let changes = TextDiff::configure()
+        .algorithm(similar::Algorithm::Patience)
+        .diff_words(old_text, new_text);
+    let mut changes = changes.iter_all_changes().peekable();
+
+    let mut result = vec![];
+
+    loop {
+        let mut deletions = vec![];
+        let mut insertions = vec![];
+
+        while changes
+            .peek()
+            .map_or(false, |change| change.tag() == ChangeTag::Delete)
+        {
+            deletions.push(changes.next().unwrap());
         }
+
+        while changes
+            .peek()
+            .map_or(false, |change| change.tag() == ChangeTag::Insert)
+        {
+            insertions.push(changes.next().unwrap());
+        }
+
+        if !deletions.is_empty() && !insertions.is_empty() {
+            result.append(&mut insertions);
+            result.append(&mut deletions);
+        } else {
+            result.append(&mut deletions);
+            result.append(&mut insertions);
+        }
+
+        if let Some(change) = changes.next() {
+            result.push(change);
+        } else {
+            break;
+        }
+    }
+
+    // Remove all non-inserts at the end.
+    while result
+        .last()
+        .map_or(false, |change| change.tag() != ChangeTag::Insert)
+    {
+        result.pop();
+    }
+
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use indoc::indoc;
+
+    #[test]
+    fn test_streaming_diff() {
+        let old_text = indoc! {"
+            match (self.format, src_format) {
+                (Format::A8, Format::A8)
+                | (Format::Rgb24, Format::Rgb24)
+                | (Format::Rgba32, Format::Rgba32) => {
+                    return self
+                        .blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format);
+                }
+                (Format::A8, Format::Rgb24) => {
+                    return self
+                        .blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format);
+                }
+                (Format::Rgb24, Format::A8) => {
+                    return self
+                        .blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format);
+                }
+                (Format::Rgb24, Format::Rgba32) => {
+                    return self.blit_from_with::<BlitRgba32ToRgb24>(
+                        dst_rect, src_bytes, src_stride, src_format,
+                    );
+                }
+                (Format::Rgba32, Format::Rgb24)
+                | (Format::Rgba32, Format::A8)
+                | (Format::A8, Format::Rgba32) => {
+                    unimplemented!()
+                }
+                _ => {}
+            }
+        "};
+        let new_text = indoc! {"
+            if self.format == src_format
+        "};
+        dbg!(streaming_diff(old_text, new_text));
     }
-    edits
 }