streaming_diff.rs

  1use collections::HashMap;
  2use ordered_float::OrderedFloat;
  3use std::{
  4    cmp,
  5    fmt::{self, Debug},
  6    ops::Range,
  7};
  8
  9struct Matrix {
 10    cells: Vec<f64>,
 11    rows: usize,
 12    cols: usize,
 13}
 14
 15impl Matrix {
 16    fn new() -> Self {
 17        Self {
 18            cells: Vec::new(),
 19            rows: 0,
 20            cols: 0,
 21        }
 22    }
 23
 24    fn resize(&mut self, rows: usize, cols: usize) {
 25        self.cells.resize(rows * cols, 0.);
 26        self.rows = rows;
 27        self.cols = cols;
 28    }
 29
 30    fn get(&self, row: usize, col: usize) -> f64 {
 31        if row >= self.rows {
 32            panic!("row out of bounds")
 33        }
 34
 35        if col >= self.cols {
 36            panic!("col out of bounds")
 37        }
 38        self.cells[col * self.rows + row]
 39    }
 40
 41    fn set(&mut self, row: usize, col: usize, value: f64) {
 42        if row >= self.rows {
 43            panic!("row out of bounds")
 44        }
 45
 46        if col >= self.cols {
 47            panic!("col out of bounds")
 48        }
 49
 50        self.cells[col * self.rows + row] = value;
 51    }
 52}
 53
 54impl Debug for Matrix {
 55    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 56        writeln!(f)?;
 57        for i in 0..self.rows {
 58            for j in 0..self.cols {
 59                write!(f, "{:5}", self.get(i, j))?;
 60            }
 61            writeln!(f)?;
 62        }
 63        Ok(())
 64    }
 65}
 66
 67#[derive(Debug)]
 68pub enum Hunk {
 69    Insert { text: String },
 70    Remove { len: usize },
 71    Keep { len: usize },
 72}
 73
/// Incrementally computes an edit script from a fixed old text to a new text
/// that is supplied in chunks.
pub struct StreamingDiff {
    /// Characters of the old text.
    old: Vec<char>,
    /// Characters of the new text pushed so far.
    new: Vec<char>,
    /// Score table with `old.len() + 1` rows and `new.len() + 1` columns.
    scores: Matrix,
    /// Position in `old` (in chars) up to which hunks have already been produced.
    old_text_ix: usize,
    /// Position in `new` (in chars) up to which hunks have already been produced.
    new_text_ix: usize,
    /// Length of the run of consecutive equal characters ending at a given
    /// `(old index, new index)` cell of the score table.
    equal_runs: HashMap<(usize, usize), u32>,
}
 82
 83impl StreamingDiff {
 84    const INSERTION_SCORE: f64 = -1.;
 85    const DELETION_SCORE: f64 = -5.;
 86    const EQUALITY_BASE: f64 = 1.4;
 87    const MAX_EQUALITY_EXPONENT: i32 = 64;
 88
 89    pub fn new(old: String) -> Self {
 90        let old = old.chars().collect::<Vec<_>>();
 91        let mut scores = Matrix::new();
 92        scores.resize(old.len() + 1, 1);
 93        for i in 0..=old.len() {
 94            scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
 95        }
 96        Self {
 97            old,
 98            new: Vec::new(),
 99            scores,
100            old_text_ix: 0,
101            new_text_ix: 0,
102            equal_runs: Default::default(),
103        }
104    }
105
106    pub fn push_new(&mut self, text: &str) -> Vec<Hunk> {
107        self.new.extend(text.chars());
108        self.scores.resize(self.old.len() + 1, self.new.len() + 1);
109
110        for j in self.new_text_ix + 1..=self.new.len() {
111            self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
112            for i in 1..=self.old.len() {
113                let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
114                let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
115                let equality_score = if self.old[i - 1] == self.new[j - 1] {
116                    let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
117                    equal_run += 1;
118                    self.equal_runs.insert((i, j), equal_run);
119
120                    if self.old[i - 1] == ' ' {
121                        self.scores.get(i - 1, j - 1)
122                    } else {
123                        let exponent = cmp::min(equal_run as i32, Self::MAX_EQUALITY_EXPONENT);
124                        self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
125                    }
126                } else {
127                    f64::NEG_INFINITY
128                };
129
130                let score = insertion_score.max(deletion_score).max(equality_score);
131                self.scores.set(i, j, score);
132            }
133        }
134
135        let mut max_score = f64::NEG_INFINITY;
136        let mut next_old_text_ix = self.old_text_ix;
137        let next_new_text_ix = self.new.len();
138        for i in self.old_text_ix..=self.old.len() {
139            let score = self.scores.get(i, next_new_text_ix);
140            if score > max_score {
141                max_score = score;
142                next_old_text_ix = i;
143            }
144        }
145
146        let hunks = self.backtrack(next_old_text_ix, next_new_text_ix);
147        self.old_text_ix = next_old_text_ix;
148        self.new_text_ix = next_new_text_ix;
149        hunks
150    }
151
152    fn backtrack(&self, old_text_ix: usize, new_text_ix: usize) -> Vec<Hunk> {
153        let mut pending_insert: Option<Range<usize>> = None;
154        let mut hunks = Vec::new();
155        let mut i = old_text_ix;
156        let mut j = new_text_ix;
157        while (i, j) != (self.old_text_ix, self.new_text_ix) {
158            let insertion_score = if j > self.new_text_ix {
159                Some((i, j - 1))
160            } else {
161                None
162            };
163            let deletion_score = if i > self.old_text_ix {
164                Some((i - 1, j))
165            } else {
166                None
167            };
168            let equality_score = if i > self.old_text_ix && j > self.new_text_ix {
169                if self.old[i - 1] == self.new[j - 1] {
170                    Some((i - 1, j - 1))
171                } else {
172                    None
173                }
174            } else {
175                None
176            };
177
178            let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
179                .iter()
180                .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
181                .unwrap()
182                .unwrap();
183
184            if prev_i == i && prev_j == j - 1 {
185                if let Some(pending_insert) = pending_insert.as_mut() {
186                    pending_insert.start = prev_j;
187                } else {
188                    pending_insert = Some(prev_j..j);
189                }
190            } else {
191                if let Some(range) = pending_insert.take() {
192                    hunks.push(Hunk::Insert {
193                        text: self.new[range].iter().collect(),
194                    });
195                }
196
197                let char_len = self.old[i - 1].len_utf8();
198                if prev_i == i - 1 && prev_j == j {
199                    if let Some(Hunk::Remove { len }) = hunks.last_mut() {
200                        *len += char_len;
201                    } else {
202                        hunks.push(Hunk::Remove { len: char_len })
203                    }
204                } else {
205                    if let Some(Hunk::Keep { len }) = hunks.last_mut() {
206                        *len += char_len;
207                    } else {
208                        hunks.push(Hunk::Keep { len: char_len })
209                    }
210                }
211            }
212
213            i = prev_i;
214            j = prev_j;
215        }
216
217        if let Some(range) = pending_insert.take() {
218            hunks.push(Hunk::Insert {
219                text: self.new[range].iter().collect(),
220            });
221        }
222
223        hunks.reverse();
224        hunks
225    }
226
227    pub fn finish(self) -> Vec<Hunk> {
228        self.backtrack(self.old.len(), self.new.len())
229    }
230}
231
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;
    use rand::prelude::*;

    /// Feeds a random new text to `StreamingDiff` in random-sized chunks and
    /// checks that applying the produced hunks to the old text yields the new
    /// text.
    ///
    /// The text lengths (in chars) can be overridden with the `OLD_TEXT_LEN`
    /// and `NEW_TEXT_LEN` environment variables; both default to 10.
    #[gpui::test(iterations = 100)]
    fn test_random_diffs(mut rng: StdRng) {
        let old_text_len = env::var("OLD_TEXT_LEN")
            .map(|i| i.parse().expect("invalid `OLD_TEXT_LEN` variable"))
            .unwrap_or(10);
        let new_text_len = env::var("NEW_TEXT_LEN")
            .map(|i| i.parse().expect("invalid `NEW_TEXT_LEN` variable"))
            .unwrap_or(10);

        let old = util::RandomCharIter::new(&mut rng)
            .take(old_text_len)
            .collect::<String>();
        log::info!("old text: {:?}", old);

        let mut diff = StreamingDiff::new(old.clone());
        let mut hunks = Vec::new();
        let mut new_len = 0;
        let mut new = String::new();
        while new_len < new_text_len {
            let new_chunk_len = rng.gen_range(1..=new_text_len - new_len);
            // Take exactly the chunk length drawn above, so the chunks add up
            // to `new_text_len` characters in total.
            let new_chunk = util::RandomCharIter::new(&mut rng)
                .take(new_chunk_len)
                .collect::<String>();
            log::info!("new chunk: {:?}", new_chunk);
            new_len += new_chunk_len;
            new.push_str(&new_chunk);
            let new_hunks = diff.push_new(&new_chunk);
            log::info!("hunks: {:?}", new_hunks);
            hunks.extend(new_hunks);
        }
        let final_hunks = diff.finish();
        log::info!("final hunks: {:?}", final_hunks);
        hunks.extend(final_hunks);

        log::info!("new text: {:?}", new);
        // Replay the hunks; `old_ix` and `new_ix` are byte offsets.
        let mut old_ix = 0;
        let mut new_ix = 0;
        let mut patched = String::new();
        for hunk in hunks {
            match hunk {
                Hunk::Keep { len } => {
                    assert_eq!(&old[old_ix..old_ix + len], &new[new_ix..new_ix + len]);
                    patched.push_str(&old[old_ix..old_ix + len]);
                    old_ix += len;
                    new_ix += len;
                }
                Hunk::Remove { len } => {
                    old_ix += len;
                }
                Hunk::Insert { text } => {
                    assert_eq!(text, &new[new_ix..new_ix + text.len()]);
                    patched.push_str(&text);
                    new_ix += text.len();
                }
            }
        }
        assert_eq!(patched, new);
    }
}