diff --git a/crates/assistant/src/streaming_diff.rs b/crates/assistant/src/streaming_diff.rs index 8383e081d53d427b562ad788c7670e52f15a7ef5..5c20dccadb95a5a5667d57161216ad896dfdef69 100644 --- a/crates/assistant/src/streaming_diff.rs +++ b/crates/assistant/src/streaming_diff.rs @@ -28,13 +28,36 @@ impl Matrix { self.cols = cols; } + fn swap_columns(&mut self, col1: usize, col2: usize) { + if col1 == col2 { + return; + } + + if col1 >= self.cols { + panic!("column out of bounds"); + } + + if col2 >= self.cols { + panic!("column out of bounds"); + } + + unsafe { + let ptr = self.cells.as_mut_ptr(); + std::ptr::swap_nonoverlapping( + ptr.add(col1 * self.rows), + ptr.add(col2 * self.rows), + self.rows, + ); + } + } + fn get(&self, row: usize, col: usize) -> f64 { if row >= self.rows { panic!("row out of bounds") } if col >= self.cols { - panic!("col out of bounds") + panic!("column out of bounds") } self.cells[col * self.rows + row] } @@ -45,7 +68,7 @@ impl Matrix { } if col >= self.cols { - panic!("col out of bounds") + panic!("column out of bounds") } self.cells[col * self.rows + row] = value; @@ -106,26 +129,32 @@ impl StreamingDiff { pub fn push_new(&mut self, text: &str) -> Vec { self.new.extend(text.chars()); - self.scores.resize(self.old.len() + 1, self.new.len() + 1); + self.scores.swap_columns(0, self.scores.cols - 1); + self.scores + .resize(self.old.len() + 1, self.new.len() - self.new_text_ix + 1); + self.equal_runs.retain(|(_i, j), _| *j == self.new_text_ix); for j in self.new_text_ix + 1..=self.new.len() { - self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE); + let relative_j = j - self.new_text_ix; + + self.scores + .set(0, relative_j, j as f64 * Self::INSERTION_SCORE); for i in 1..=self.old.len() { - let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE; - let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE; + let insertion_score = self.scores.get(i, relative_j - 1) + Self::INSERTION_SCORE; + let deletion_score = self.scores.get(i - 1, relative_j) + Self::DELETION_SCORE; let equality_score = if self.old[i - 1] == self.new[j - 1] { let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0); equal_run += 1; self.equal_runs.insert((i, j), equal_run); let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT); - self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent) + self.scores.get(i - 1, relative_j - 1) + Self::EQUALITY_BASE.powi(exponent) } else { f64::NEG_INFINITY }; let score = insertion_score.max(deletion_score).max(equality_score); - self.scores.set(i, j, score); + self.scores.set(i, relative_j, score); } } @@ -133,7 +162,7 @@ impl StreamingDiff { let mut next_old_text_ix = self.old_text_ix; let next_new_text_ix = self.new.len(); for i in self.old_text_ix..=self.old.len() { - let score = self.scores.get(i, next_new_text_ix); + let score = self.scores.get(i, next_new_text_ix - self.new_text_ix); if score > max_score { max_score = score; next_old_text_ix = i; @@ -174,7 +203,9 @@ impl StreamingDiff { let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score] .iter() - .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j)))) + .max_by_key(|cell| { + cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j - self.new_text_ix))) + }) .unwrap() .unwrap();