Cargo.lock 🔗
@@ -106,6 +106,7 @@ dependencies = [
"fs",
"futures 0.3.28",
"gpui",
+ "indoc",
"isahc",
"language",
"menu",
Antonio Scandurra created
Cargo.lock | 1
crates/ai/Cargo.toml | 1
crates/ai/src/refactor.rs | 347 +++++++++++++++++++++++++++-------------
3 files changed, 237 insertions(+), 112 deletions(-)
@@ -106,6 +106,7 @@ dependencies = [
"fs",
"futures 0.3.28",
"gpui",
+ "indoc",
"isahc",
"language",
"menu",
@@ -24,6 +24,7 @@ workspace = { path = "../workspace" }
anyhow.workspace = true
chrono = { version = "0.4", features = ["serde"] }
futures.workspace = true
+indoc.workspace = true
isahc.workspace = true
regex.workspace = true
schemars.workspace = true
@@ -1,14 +1,13 @@
use crate::{stream_completion, OpenAIRequest, RequestMessage, Role};
-use collections::{BTreeMap, BTreeSet, HashMap, HashSet};
-use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset};
-use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt};
+use collections::HashMap;
+use editor::{Editor, ToOffset};
+use futures::StreamExt;
use gpui::{
actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle,
WeakViewHandle,
};
use menu::Confirm;
-use serde::Deserialize;
-use similar::ChangeTag;
+use similar::{Change, ChangeTag, TextDiff};
use std::{env, iter, ops::Range, sync::Arc};
use util::TryFutureExt;
use workspace::{Modal, Workspace};
@@ -33,12 +32,12 @@ impl RefactoringAssistant {
}
fn refactor(&mut self, editor: &ViewHandle<Editor>, prompt: &str, cx: &mut AppContext) {
- let buffer = editor.read(cx).buffer().read(cx).snapshot(cx);
+ let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
let selection = editor.read(cx).selections.newest_anchor().clone();
- let selected_text = buffer
+ let selected_text = snapshot
.text_for_range(selection.start..selection.end)
.collect::<String>();
- let language_name = buffer
+ let language_name = snapshot
.language_at(selection.start)
.map(|language| language.name());
let language_name = language_name.as_deref().unwrap_or("");
@@ -48,7 +47,7 @@ impl RefactoringAssistant {
RequestMessage {
role: Role::User,
content: format!(
- "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code."
+ "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code. Preserve indentation."
),
}],
stream: true,
@@ -60,86 +59,149 @@ impl RefactoringAssistant {
editor.id(),
cx.spawn(|mut cx| {
async move {
- let selection_start = selection.start.to_offset(&buffer);
-
- // Find unique words in the selected text to use as diff boundaries.
- let mut duplicate_words = HashSet::default();
- let mut unique_old_words = HashMap::default();
- for (range, word) in words(&selected_text) {
- if !duplicate_words.contains(word) {
- if unique_old_words.insert(word, range.end).is_some() {
- unique_old_words.remove(word);
- duplicate_words.insert(word);
- }
- }
- }
+ let selection_start = selection.start.to_offset(&snapshot);
let mut new_text = String::new();
let mut messages = response.await?;
- let mut new_word_search_start_ix = 0;
- let mut last_old_word_end_ix = 0;
-
- 'outer: loop {
- const MIN_DIFF_LEN: usize = 50;
-
- let start = new_word_search_start_ix;
- let mut words = words(&new_text[start..]);
- while let Some((range, new_word)) = words.next() {
- // We found a word in the new text that was unique in the old text. We can use
- // it as a diff boundary, and start applying edits.
- if let Some(old_word_end_ix) = unique_old_words.get(new_word).copied() {
- if old_word_end_ix.saturating_sub(last_old_word_end_ix)
- > MIN_DIFF_LEN
- {
- drop(words);
-
- let remainder = new_text.split_off(start + range.end);
- let edits = diff(
- selection_start + last_old_word_end_ix,
- &selected_text[last_old_word_end_ix..old_word_end_ix],
- &new_text,
- &buffer,
- );
- editor.update(&mut cx, |editor, cx| {
- editor
- .buffer()
- .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
- })?;
-
- new_text = remainder;
- new_word_search_start_ix = 0;
- last_old_word_end_ix = old_word_end_ix;
- continue 'outer;
+
+ let mut transaction = None;
+
+ while let Some(message) = messages.next().await {
+ smol::future::yield_now().await;
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ if let Some(text) = choice.delta.content {
+ new_text.push_str(&text);
+
+ println!("-------------------------------------");
+
+ println!(
+ "{}",
+ similar::TextDiff::from_words(&selected_text, &new_text)
+ .unified_diff()
+ );
+
+ let mut changes =
+ similar::TextDiff::from_words(&selected_text, &new_text)
+ .iter_all_changes()
+ .collect::<Vec<_>>();
+
+ let mut ix = 0;
+ while ix < changes.len() {
+ let deletion_start_ix = ix;
+ let mut deletion_end_ix = ix;
+ while changes
+ .get(ix)
+ .map_or(false, |change| change.tag() == ChangeTag::Delete)
+ {
+ ix += 1;
+ deletion_end_ix += 1;
+ }
+
+ let insertion_start_ix = ix;
+ let mut insertion_end_ix = ix;
+ while changes
+ .get(ix)
+ .map_or(false, |change| change.tag() == ChangeTag::Insert)
+ {
+ ix += 1;
+ insertion_end_ix += 1;
+ }
+
+ if deletion_end_ix > deletion_start_ix
+ && insertion_end_ix > insertion_start_ix
+ {
+ for _ in deletion_start_ix..deletion_end_ix {
+ let deletion = changes.remove(deletion_end_ix);
+ changes.insert(insertion_end_ix - 1, deletion);
+ }
+ }
+
+ ix += 1;
}
- }
- new_word_search_start_ix = start + range.end;
- }
- drop(words);
-
- // Buffer incoming text, stopping if the stream was exhausted.
- if let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- if let Some(text) = choice.delta.content {
- new_text.push_str(&text);
+ while changes
+ .last()
+ .map_or(false, |change| change.tag() != ChangeTag::Insert)
+ {
+ changes.pop();
}
+
+ editor.update(&mut cx, |editor, cx| {
+ editor.buffer().update(cx, |buffer, cx| {
+ if let Some(transaction) = transaction.take() {
+ buffer.undo(cx); // TODO: Undo the transaction instead
+ }
+
+ buffer.start_transaction(cx);
+ let mut edit_start = selection_start;
+ dbg!(&changes);
+ for change in changes {
+ let value = change.value();
+ let edit_end = edit_start + value.len();
+ match change.tag() {
+ ChangeTag::Equal => {
+ edit_start = edit_end;
+ }
+ ChangeTag::Delete => {
+ let range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ buffer.edit([(range, "")], None, cx);
+ edit_start = edit_end;
+ }
+ ChangeTag::Insert => {
+ let insertion_start =
+ snapshot.anchor_after(edit_start);
+ buffer.edit(
+ [(insertion_start..insertion_start, value)],
+ None,
+ cx,
+ );
+ }
+ }
+ }
+ transaction = buffer.end_transaction(cx);
+ })
+ })?;
}
- } else {
- break;
}
}
- let edits = diff(
- selection_start + last_old_word_end_ix,
- &selected_text[last_old_word_end_ix..],
- &new_text,
- &buffer,
- );
editor.update(&mut cx, |editor, cx| {
- editor
- .buffer()
- .update(cx, |buffer, cx| buffer.edit(edits, None, cx))
+ editor.buffer().update(cx, |buffer, cx| {
+ if let Some(transaction) = transaction.take() {
+ buffer.undo(cx); // TODO: Undo the transaction instead
+ }
+
+ buffer.start_transaction(cx);
+ let mut edit_start = selection_start;
+ for change in similar::TextDiff::from_words(&selected_text, &new_text)
+ .iter_all_changes()
+ {
+ let value = change.value();
+ let edit_end = edit_start + value.len();
+ match change.tag() {
+ ChangeTag::Equal => {
+ edit_start = edit_end;
+ }
+ ChangeTag::Delete => {
+ let range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ buffer.edit([(range, "")], None, cx);
+ edit_start = edit_end;
+ }
+ ChangeTag::Insert => {
+ let insertion_start = snapshot.anchor_after(edit_start);
+ buffer.edit(
+ [(insertion_start..insertion_start, value)],
+ None,
+ cx,
+ );
+ }
+ }
+ }
+ buffer.end_transaction(cx);
+ })
})?;
anyhow::Ok(())
@@ -197,11 +259,13 @@ impl RefactoringModal {
{
workspace.toggle_modal(cx, |_, cx| {
let prompt_editor = cx.add_view(|cx| {
- Editor::auto_height(
+ let mut editor = Editor::auto_height(
4,
Some(Arc::new(|theme| theme.search.editor.input.clone())),
cx,
- )
+ );
+ editor.set_text("Replace with match statement.", cx);
+ editor
});
cx.add_view(|_| RefactoringModal {
editor,
@@ -242,38 +306,97 @@ fn words(text: &str) -> impl Iterator<Item = (Range<usize>, &str)> {
})
}
-fn diff<'a>(
- start_ix: usize,
- old_text: &'a str,
- new_text: &'a str,
- old_buffer_snapshot: &MultiBufferSnapshot,
-) -> Vec<(Range<Anchor>, &'a str)> {
- let mut edit_start = start_ix;
- let mut edits = Vec::new();
- let diff = similar::TextDiff::from_words(old_text, &new_text);
- for change in diff.iter_all_changes() {
- let value = change.value();
- let edit_end = edit_start + value.len();
- match change.tag() {
- ChangeTag::Equal => {
- edit_start = edit_end;
- }
- ChangeTag::Delete => {
- edits.push((
- old_buffer_snapshot.anchor_after(edit_start)
- ..old_buffer_snapshot.anchor_before(edit_end),
- "",
- ));
- edit_start = edit_end;
- }
- ChangeTag::Insert => {
- edits.push((
- old_buffer_snapshot.anchor_after(edit_start)
- ..old_buffer_snapshot.anchor_after(edit_start),
- value,
- ));
- }
+fn streaming_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<Change<'a, str>> {
+ let changes = TextDiff::configure()
+ .algorithm(similar::Algorithm::Patience)
+ .diff_words(old_text, new_text);
+ let mut changes = changes.iter_all_changes().peekable();
+
+ let mut result = vec![];
+
+ loop {
+ let mut deletions = vec![];
+ let mut insertions = vec![];
+
+ while changes
+ .peek()
+ .map_or(false, |change| change.tag() == ChangeTag::Delete)
+ {
+ deletions.push(changes.next().unwrap());
}
+
+ while changes
+ .peek()
+ .map_or(false, |change| change.tag() == ChangeTag::Insert)
+ {
+ insertions.push(changes.next().unwrap());
+ }
+
+ if !deletions.is_empty() && !insertions.is_empty() {
+ result.append(&mut insertions);
+ result.append(&mut deletions);
+ } else {
+ result.append(&mut deletions);
+ result.append(&mut insertions);
+ }
+
+ if let Some(change) = changes.next() {
+ result.push(change);
+ } else {
+ break;
+ }
+ }
+
+ // Remove all non-inserts at the end.
+ while result
+ .last()
+ .map_or(false, |change| change.tag() != ChangeTag::Insert)
+ {
+ result.pop();
+ }
+
+ result
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use indoc::indoc;
+
+ #[test]
+ fn test_streaming_diff() {
+ let old_text = indoc! {"
+ match (self.format, src_format) {
+ (Format::A8, Format::A8)
+ | (Format::Rgb24, Format::Rgb24)
+ | (Format::Rgba32, Format::Rgba32) => {
+ return self
+ .blit_from_with::<BlitMemcpy>(dst_rect, src_bytes, src_stride, src_format);
+ }
+ (Format::A8, Format::Rgb24) => {
+ return self
+ .blit_from_with::<BlitRgb24ToA8>(dst_rect, src_bytes, src_stride, src_format);
+ }
+ (Format::Rgb24, Format::A8) => {
+ return self
+ .blit_from_with::<BlitA8ToRgb24>(dst_rect, src_bytes, src_stride, src_format);
+ }
+ (Format::Rgb24, Format::Rgba32) => {
+ return self.blit_from_with::<BlitRgba32ToRgb24>(
+ dst_rect, src_bytes, src_stride, src_format,
+ );
+ }
+ (Format::Rgba32, Format::Rgb24)
+ | (Format::Rgba32, Format::A8)
+ | (Format::A8, Format::Rgba32) => {
+ unimplemented!()
+ }
+ _ => {}
+ }
+ "};
+ let new_text = indoc! {"
+ if self.format == src_format
+ "};
+ dbg!(streaming_diff(old_text, new_text));
}
- edits
}