diff --git a/crates/ai/src/diff.rs b/crates/ai/src/diff.rs index 5e73c94ff89545f5a2e8ba18ee12066bb96e13a6..1b5b4cbd20f5da7b401ddde701cb22406d7dcfe5 100644 --- a/crates/ai/src/diff.rs +++ b/crates/ai/src/diff.rs @@ -65,7 +65,7 @@ impl Debug for Matrix { #[derive(Debug)] pub enum Hunk { - Insert { len: usize }, + Insert { text: String }, Remove { len: usize }, Keep { len: usize }, } @@ -75,37 +75,42 @@ pub struct Diff { new: String, scores: Matrix, old_text_ix: usize, + new_text_ix: usize, } impl Diff { + const INSERTION_SCORE: isize = -1; + const DELETION_SCORE: isize = -4; + const EQUALITY_SCORE: isize = 5; + pub fn new(old: String) -> Self { let mut scores = Matrix::new(); scores.resize(old.len() + 1, 1); for i in 0..=old.len() { - scores.set(i, 0, -(i as isize)); + scores.set(i, 0, i as isize * Self::DELETION_SCORE); } Self { old, new: String::new(), scores, old_text_ix: 0, + new_text_ix: 0, } } pub fn push_new(&mut self, text: &str) -> Vec { - let new_text_ix = self.new.len(); self.new.push_str(text); self.scores.resize(self.old.len() + 1, self.new.len() + 1); - for j in new_text_ix + 1..=self.new.len() { - self.scores.set(0, j, -(j as isize)); + for j in self.new_text_ix + 1..=self.new.len() { + self.scores.set(0, j, j as isize * Self::INSERTION_SCORE); for i in 1..=self.old.len() { - let insertion_score = self.scores.get(i, j - 1) - 1; - let deletion_score = self.scores.get(i - 1, j) - 10; + let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE; + let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE; let equality_score = if self.old.as_bytes()[i - 1] == self.new.as_bytes()[j - 1] { - self.scores.get(i - 1, j - 1) + 5 + self.scores.get(i - 1, j - 1) + Self::EQUALITY_SCORE } else { - self.scores.get(i - 1, j - 1) - 20 + isize::MIN }; let score = insertion_score.max(deletion_score).max(equality_score); self.scores.set(i, j, score); @@ -114,19 +119,30 @@ impl Diff { let mut max_score = isize::MIN; let mut best_row = self.old_text_ix; + let mut best_col = self.new_text_ix; for i in self.old_text_ix..=self.old.len() { - let score = self.scores.get(i, self.new.len()); - if score > max_score { - max_score = score; - best_row = i; + for j in self.new_text_ix..=self.new.len() { + let score = self.scores.get(i, j); + if score > max_score { + max_score = score; + best_row = i; + best_col = j; + } } } + let hunks = self.backtrack(best_row, best_col); + self.old_text_ix = best_row; + self.new_text_ix = best_col; + hunks + } + + fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec { let mut hunks = Vec::new(); - let mut i = best_row; - let mut j = self.new.len(); - while (i, j) != (self.old_text_ix, new_text_ix) { - let insertion_score = if j > new_text_ix { + let mut i = old_text_ix; + let mut j = new_text_ix; + while (i, j) != (self.old_text_ix, self.new_text_ix) { + let insertion_score = if j > self.new_text_ix { Some((i, j - 1)) } else { None @@ -136,8 +152,12 @@ impl Diff { } else { None }; - let equality_score = if i > self.old_text_ix && j > new_text_ix { - Some((i - 1, j - 1)) + let equality_score = if i > self.old_text_ix && j > self.new_text_ix { + if self.old.as_bytes()[i - 1] == self.new.as_bytes()[j - 1] { + Some((i - 1, j - 1)) + } else { + None + } } else { None }; @@ -149,10 +169,12 @@ impl Diff { .unwrap(); if prev_i == i && prev_j == j - 1 { - if let Some(Hunk::Insert { len }) = hunks.last_mut() { - *len += 1; + if let Some(Hunk::Insert { text }) = hunks.last_mut() { + text.insert_str(0, &self.new[prev_j..j]); } else { - hunks.push(Hunk::Insert { len: 1 }) + hunks.push(Hunk::Insert { + text: self.new[prev_j..j].to_string(), + }) } } else if prev_i == i - 1 && prev_j == j { if let Some(Hunk::Remove { len }) = hunks.last_mut() { @@ -171,19 +193,12 @@ impl Diff { i = prev_i; j = prev_j; } - self.old_text_ix = best_row; hunks.reverse(); hunks } - pub fn finish(self) -> Option { - if self.old_text_ix < self.old.len() { - Some(Hunk::Remove { - len: self.old.len() - self.old_text_ix, - }) - } else { - None - } + pub fn finish(self) -> Vec { + self.backtrack(self.old.len(), self.new.len()) } } @@ -194,9 +209,9 @@ mod tests { #[test] fn test_diff() { let mut diff = Diff::new("hello world".to_string()); - diff.push_new("hello"); - diff.push_new(" ciaone"); - diff.push_new(" world"); - diff.finish(); + dbg!(diff.push_new("hello")); + dbg!(diff.push_new(" ciaone")); + // dbg!(diff.push_new(" world")); + dbg!(diff.finish()); } } diff --git a/crates/ai/src/refactor.rs b/crates/ai/src/refactor.rs index dcec04deefd76a3553ee4fc884c371f31cf15708..87f7495fcf5b812b0e925d3703647cd167daf19c 100644 --- a/crates/ai/src/refactor.rs +++ b/crates/ai/src/refactor.rs @@ -47,7 +47,7 @@ impl RefactoringAssistant { RequestMessage { role: Role::User, content: format!( - "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Avoid making remarks and reply only with the new code. Preserve indentation." + "Given the following {language_name} snippet:\n{selected_text}\n{prompt}. Never make remarks and reply only with the new code. Never change the leading whitespace on each line." ), }], stream: true, @@ -81,9 +81,7 @@ impl RefactoringAssistant { hunks_tx.send((hunks, new_text)).await?; } - if let Some(hunk) = diff.finish() { - hunks_tx.send((vec![hunk], String::new())).await?; - } + hunks_tx.send((diff.finish(), String::new())).await?; anyhow::Ok(()) }); @@ -92,14 +90,11 @@ impl RefactoringAssistant { editor.update(&mut cx, |editor, cx| { editor.buffer().update(cx, |buffer, cx| { buffer.start_transaction(cx); - let mut new_text_ix = 0; for hunk in hunks { match hunk { - crate::diff::Hunk::Insert { len } => { - let text = &new_text[new_text_ix..new_text_ix + len]; + crate::diff::Hunk::Insert { text } => { let edit_start = snapshot.anchor_after(edit_start); buffer.edit([(edit_start..edit_start, text)], None, cx); - new_text_ix += len; } crate::diff::Hunk::Remove { len } => { let edit_end = edit_start + len; @@ -110,7 +105,6 @@ impl RefactoringAssistant { } crate::diff::Hunk::Keep { len } => { edit_start += len; - new_text_ix += len; } } }