Support multi-byte characters in diff

Antonio Scandurra created

Change summary

crates/ai/src/diff.rs | 57 +++++++++++++++++++++++++++++---------------
1 file changed, 37 insertions(+), 20 deletions(-)

Detailed changes

crates/ai/src/diff.rs 🔗

@@ -1,6 +1,7 @@
 use std::{
     cmp,
     fmt::{self, Debug},
+    ops::Range,
 };
 
 use collections::BinaryHeap;
@@ -71,8 +72,8 @@ pub enum Hunk {
 }
 
 pub struct Diff {
-    old: String,
-    new: String,
+    old: Vec<char>,
+    new: Vec<char>,
     scores: Matrix,
     old_text_ix: usize,
     new_text_ix: usize,
@@ -84,6 +85,7 @@ impl Diff {
     const EQUALITY_SCORE: isize = 5;
 
     pub fn new(old: String) -> Self {
+        let old = old.chars().collect::<Vec<_>>();
         let mut scores = Matrix::new();
         scores.resize(old.len() + 1, 1);
         for i in 0..=old.len() {
@@ -91,7 +93,7 @@ impl Diff {
         }
         Self {
             old,
-            new: String::new(),
+            new: Vec::new(),
             scores,
             old_text_ix: 0,
             new_text_ix: 0,
@@ -99,7 +101,7 @@ impl Diff {
     }
 
     pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
-        self.new.push_str(text);
+        self.new.extend(text.chars());
         self.scores.resize(self.old.len() + 1, self.new.len() + 1);
 
         for j in self.new_text_ix + 1..=self.new.len() {
@@ -107,7 +109,7 @@ impl Diff {
             for i in 1..=self.old.len() {
                 let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
                 let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
-                let equality_score = if self.old.as_bytes()[i - 1] == self.new.as_bytes()[j - 1] {
+                let equality_score = if self.old[i - 1] == self.new[j - 1] {
                     self.scores.get(i - 1, j - 1) + Self::EQUALITY_SCORE
                 } else {
                     isize::MIN
@@ -138,6 +140,7 @@ impl Diff {
     }
 
     fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
+        let mut pending_insert: Option<Range<usize>> = None;
         let mut hunks = Vec::new();
         let mut i = old_text_ix;
         let mut j = new_text_ix;
@@ -153,7 +156,7 @@ impl Diff {
                 None
             };
             let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
-                if self.old.as_bytes()[i - 1] == self.new.as_bytes()[j - 1] {
+                if self.old[i - 1] == self.new[j - 1] {
                     Some((i - 1, j - 1))
                 } else {
                     None
@@ -169,30 +172,44 @@ impl Diff {
                 .unwrap();
 
             if prev_i == i && prev_j == j - 1 {
-                if let Some(Hunk::Insert { text }) = hunks.last_mut() {
-                    text.insert_str(0, &self.new[prev_j..j]);
+                if let Some(pending_insert) = pending_insert.as_mut() {
+                    pending_insert.start = prev_j;
                 } else {
-                    hunks.push(Hunk::Insert {
-                        text: self.new[prev_j..j].to_string(),
-                    })
-                }
-            } else if prev_i == i - 1 && prev_j == j {
-                if let Some(Hunk::Remove { len }) = hunks.last_mut() {
-                    *len += 1;
-                } else {
-                    hunks.push(Hunk::Remove { len: 1 })
+                    pending_insert = Some(prev_j..j);
                 }
             } else {
-                if let Some(Hunk::Keep { len }) = hunks.last_mut() {
-                    *len += 1;
+                if let Some(range) = pending_insert.take() {
+                    hunks.push(Hunk::Insert {
+                        text: self.new[range].iter().collect(),
+                    });
+                }
+
+                let char_len = self.old[i - 1].len_utf8();
+                if prev_i == i - 1 && prev_j == j {
+                    if let Some(Hunk::Remove { len }) = hunks.last_mut() {
+                        *len += char_len;
+                    } else {
+                        hunks.push(Hunk::Remove { len: char_len })
+                    }
                 } else {
-                    hunks.push(Hunk::Keep { len: 1 })
+                    if let Some(Hunk::Keep { len }) = hunks.last_mut() {
+                        *len += char_len;
+                    } else {
+                        hunks.push(Hunk::Keep { len: char_len })
+                    }
                 }
             }
 
             i = prev_i;
             j = prev_j;
         }
+
+        if let Some(range) = pending_insert.take() {
+            hunks.push(Hunk::Insert {
+                text: self.new[range].iter().collect(),
+            });
+        }
+
         hunks.reverse();
         hunks
     }