From 5453553cfa4eb8b1a21a066402ed3ba82067a240 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 22 Aug 2023 08:16:22 +0200 Subject: [PATCH] WIP --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/refactor.rs | 347 ++++++++++++++++++++++++++------------ 3 files changed, 237 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f802d907390439aa22f32034998f995aacacfcdd..af16a88596217e1aaf0c04ce3f169f1b5a9051e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,6 +106,7 @@ dependencies = [ "fs", "futures 0.3.28", "gpui", + "indoc", "isahc", "language", "menu", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index bae20f7537c7442e29eeb7013e223b8d7ce422b2..5ef371e3425ce9e85da94d956a05b92b4a308100 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -24,6 +24,7 @@ workspace = { path = "../workspace" } anyhow.workspace = true chrono = { version = "0.4", features = ["serde"] } futures.workspace = true +indoc.workspace = true isahc.workspace = true regex.workspace = true schemars.workspace = true diff --git a/crates/ai/src/refactor.rs b/crates/ai/src/refactor.rs index 1a1d02cf1f2c55e78ee500993aa8f2621b49ffe0..1923ef7845f7d43e37253c019c67f600145dfdba 100644 --- a/crates/ai/src/refactor.rs +++ b/crates/ai/src/refactor.rs @@ -1,14 +1,13 @@ use crate::{stream_completion, OpenAIRequest, RequestMessage, Role}; -use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; -use editor::{Anchor, Editor, MultiBuffer, MultiBufferSnapshot, ToOffset}; -use futures::{io::BufWriter, AsyncReadExt, AsyncWriteExt, StreamExt}; +use collections::HashMap; +use editor::{Editor, ToOffset}; +use futures::StreamExt; use gpui::{ actions, elements::*, AnyViewHandle, AppContext, Entity, Task, View, ViewContext, ViewHandle, WeakViewHandle, }; use menu::Confirm; -use serde::Deserialize; -use similar::ChangeTag; +use similar::{Change, ChangeTag, TextDiff}; use std::{env, iter, ops::Range, sync::Arc}; use util::TryFutureExt; use workspace::{Modal, Workspace}; @@ -33,12 +32,12 @@ impl RefactoringAssistant { } fn refactor(&mut self, editor: &ViewHandle, prompt: &str, cx: &mut AppContext) { - let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); + let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx); let selection = editor.read(cx).selections.newest_anchor().clone(); - let selected_text = buffer + let selected_text = snapshot .text_for_range(selection.start..selection.end) .collect::(); - let language_name = buffer + let language_name = snapshot .language_at(selection.start) .map(|language| language.name()); let language_name = language_name.as_deref().unwrap_or(""); @@ -48,7 +47,7 @@ impl RefactoringAssistant { RequestMessage { role: Role::User, content: format!( - "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code." + "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code. Preserve indentation." ), }], stream: true, @@ -60,86 +59,149 @@ impl RefactoringAssistant { editor.id(), cx.spawn(|mut cx| { async move { - let selection_start = selection.start.to_offset(&buffer); - - // Find unique words in the selected text to use as diff boundaries. - let mut duplicate_words = HashSet::default(); - let mut unique_old_words = HashMap::default(); - for (range, word) in words(&selected_text) { - if !duplicate_words.contains(word) { - if unique_old_words.insert(word, range.end).is_some() { - unique_old_words.remove(word); - duplicate_words.insert(word); - } - } - } + let selection_start = selection.start.to_offset(&snapshot); let mut new_text = String::new(); let mut messages = response.await?; - let mut new_word_search_start_ix = 0; - let mut last_old_word_end_ix = 0; - - 'outer: loop { - const MIN_DIFF_LEN: usize = 50; - - let start = new_word_search_start_ix; - let mut words = words(&new_text[start..]); - while let Some((range, new_word)) = words.next() { - // We found a word in the new text that was unique in the old text. We can use - // it as a diff boundary, and start applying edits. - if let Some(old_word_end_ix) = unique_old_words.get(new_word).copied() { - if old_word_end_ix.saturating_sub(last_old_word_end_ix) - > MIN_DIFF_LEN - { - drop(words); - - let remainder = new_text.split_off(start + range.end); - let edits = diff( - selection_start + last_old_word_end_ix, - &selected_text[last_old_word_end_ix..old_word_end_ix], - &new_text, - &buffer, - ); - editor.update(&mut cx, |editor, cx| { - editor - .buffer() - .update(cx, |buffer, cx| buffer.edit(edits, None, cx)) - })?; - - new_text = remainder; - new_word_search_start_ix = 0; - last_old_word_end_ix = old_word_end_ix; - continue 'outer; + + let mut transaction = None; + + while let Some(message) = messages.next().await { + smol::future::yield_now().await; + let mut message = message?; + if let Some(choice) = message.choices.pop() { + if let Some(text) = choice.delta.content { + new_text.push_str(&text); + + println!("-------------------------------------"); + + println!( + "{}", + similar::TextDiff::from_words(&selected_text, &new_text) + .unified_diff() + ); + + let mut changes = + similar::TextDiff::from_words(&selected_text, &new_text) + .iter_all_changes() + .collect::>(); + + let mut ix = 0; + while ix < changes.len() { + let deletion_start_ix = ix; + let mut deletion_end_ix = ix; + while changes + .get(ix) + .map_or(false, |change| change.tag() == ChangeTag::Delete) + { + ix += 1; + deletion_end_ix += 1; + } + + let insertion_start_ix = ix; + let mut insertion_end_ix = ix; + while changes + .get(ix) + .map_or(false, |change| change.tag() == ChangeTag::Insert) + { + ix += 1; + insertion_end_ix += 1; + } + + if deletion_end_ix > deletion_start_ix + && insertion_end_ix > insertion_start_ix + { + for _ in deletion_start_ix..deletion_end_ix { + let deletion = changes.remove(deletion_end_ix); + changes.insert(insertion_end_ix - 1, deletion); + } + } + + ix += 1; } - } - new_word_search_start_ix = start + range.end; - } - drop(words); - - // Buffer incoming text, stopping if the stream was exhausted. - if let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - if let Some(text) = choice.delta.content { - new_text.push_str(&text); + while changes + .last() + .map_or(false, |change| change.tag() != ChangeTag::Insert) + { + changes.pop(); } + + editor.update(&mut cx, |editor, cx| { + editor.buffer().update(cx, |buffer, cx| { + if let Some(transaction) = transaction.take() { + buffer.undo(cx); // TODO: Undo the transaction instead + } + + buffer.start_transaction(cx); + let mut edit_start = selection_start; + dbg!(&changes); + for change in changes { + let value = change.value(); + let edit_end = edit_start + value.len(); + match change.tag() { + ChangeTag::Equal => { + edit_start = edit_end; + } + ChangeTag::Delete => { + let range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + buffer.edit([(range, "")], None, cx); + edit_start = edit_end; + } + ChangeTag::Insert => { + let insertion_start = + snapshot.anchor_after(edit_start); + buffer.edit( + [(insertion_start..insertion_start, value)], + None, + cx, + ); + } + } + } + transaction = buffer.end_transaction(cx); + }) + })?; } - } else { - break; } } - let edits = diff( - selection_start + last_old_word_end_ix, - &selected_text[last_old_word_end_ix..], - &new_text, - &buffer, - ); editor.update(&mut cx, |editor, cx| { - editor - .buffer() - .update(cx, |buffer, cx| buffer.edit(edits, None, cx)) + editor.buffer().update(cx, |buffer, cx| { + if let Some(transaction) = transaction.take() { + buffer.undo(cx); // TODO: Undo the transaction instead + } + + buffer.start_transaction(cx); + let mut edit_start = selection_start; + for change in similar::TextDiff::from_words(&selected_text, &new_text) + .iter_all_changes() + { + let value = change.value(); + let edit_end = edit_start + value.len(); + match change.tag() { + ChangeTag::Equal => { + edit_start = edit_end; + } + ChangeTag::Delete => { + let range = snapshot.anchor_after(edit_start) + ..snapshot.anchor_before(edit_end); + buffer.edit([(range, "")], None, cx); + edit_start = edit_end; + } + ChangeTag::Insert => { + let insertion_start = snapshot.anchor_after(edit_start); + buffer.edit( + [(insertion_start..insertion_start, value)], + None, + cx, + ); + } + } + } + buffer.end_transaction(cx); + }) })?; anyhow::Ok(()) @@ -197,11 +259,13 @@ impl RefactoringModal { { workspace.toggle_modal(cx, |_, cx| { let prompt_editor = cx.add_view(|cx| { - Editor::auto_height( + let mut editor = Editor::auto_height( 4, Some(Arc::new(|theme| theme.search.editor.input.clone())), cx, - ) + ); + editor.set_text("Replace with match statement.", cx); + editor }); cx.add_view(|_| RefactoringModal { editor, @@ -242,38 +306,97 @@ fn words(text: &str) -> impl Iterator, &str)> { }) } -fn diff<'a>( - start_ix: usize, - old_text: &'a str, - new_text: &'a str, - old_buffer_snapshot: &MultiBufferSnapshot, -) -> Vec<(Range, &'a str)> { - let mut edit_start = start_ix; - let mut edits = Vec::new(); - let diff = similar::TextDiff::from_words(old_text, &new_text); - for change in diff.iter_all_changes() { - let value = change.value(); - let edit_end = edit_start + value.len(); - match change.tag() { - ChangeTag::Equal => { - edit_start = edit_end; - } - ChangeTag::Delete => { - edits.push(( - old_buffer_snapshot.anchor_after(edit_start) - ..old_buffer_snapshot.anchor_before(edit_end), - "", - )); - edit_start = edit_end; - } - ChangeTag::Insert => { - edits.push(( - old_buffer_snapshot.anchor_after(edit_start) - ..old_buffer_snapshot.anchor_after(edit_start), - value, - )); - } +fn streaming_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec> { + let changes = TextDiff::configure() + .algorithm(similar::Algorithm::Patience) + .diff_words(old_text, new_text); + let mut changes = changes.iter_all_changes().peekable(); + + let mut result = vec![]; + + loop { + let mut deletions = vec![]; + let mut insertions = vec![]; + + while changes + .peek() + .map_or(false, |change| change.tag() == ChangeTag::Delete) + { + deletions.push(changes.next().unwrap()); } + + while changes + .peek() + .map_or(false, |change| change.tag() == ChangeTag::Insert) + { + insertions.push(changes.next().unwrap()); + } + + if !deletions.is_empty() && !insertions.is_empty() { + result.append(&mut insertions); + result.append(&mut deletions); + } else { + result.append(&mut deletions); + result.append(&mut insertions); + } + + if let Some(change) = changes.next() { + result.push(change); + } else { + break; + } + } + + // Remove all non-inserts at the end. + while result + .last() + .map_or(false, |change| change.tag() != ChangeTag::Insert) + { + result.pop(); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use indoc::indoc; + + #[test] + fn test_streaming_diff() { + let old_text = indoc! {" + match (self.format, src_format) { + (Format::A8, Format::A8) + | (Format::Rgb24, Format::Rgb24) + | (Format::Rgba32, Format::Rgba32) => { + return self + .blit_from_with::(dst_rect, src_bytes, src_stride, src_format); + } + (Format::A8, Format::Rgb24) => { + return self + .blit_from_with::(dst_rect, src_bytes, src_stride, src_format); + } + (Format::Rgb24, Format::A8) => { + return self + .blit_from_with::(dst_rect, src_bytes, src_stride, src_format); + } + (Format::Rgb24, Format::Rgba32) => { + return self.blit_from_with::( + dst_rect, src_bytes, src_stride, src_format, + ); + } + (Format::Rgba32, Format::Rgb24) + | (Format::Rgba32, Format::A8) + | (Format::A8, Format::Rgba32) => { + unimplemented!() + } + _ => {} + } + "}; + let new_text = indoc! {" + if self.format == src_format + "}; + dbg!(streaming_diff(old_text, new_text)); } - edits }