From 21f778c6db7db6fe1221034fe012b0030f0fa2ed Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Wed, 6 Nov 2024 15:07:16 +0100
Subject: [PATCH] Reduce memory footprint for inline transformations (#20296)

Closes https://github.com/zed-industries/zed/issues/18062

This pull request prevents the `scores` matrix for the streaming diff from growing quadratically. Previously, we would store rows and columns respectively for all characters in the old and new text. However, every time we receive a chunk, we will always advance the position in the matrix to the very latest character in the new text. This means we can avoid storing scores for the new characters that were already reported.

Randomized tests still pass and I also made sure that the diffs we produce are identical.

Release Notes:

- Improved memory footprint for inline transformations ([#18062](https://github.com/zed-industries/zed/issues/18062))
---
 crates/assistant/src/streaming_diff.rs | 51 +++++++++++++++++++++-----
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/crates/assistant/src/streaming_diff.rs b/crates/assistant/src/streaming_diff.rs
index 8383e081d53d427b562ad788c7670e52f15a7ef5..5c20dccadb95a5a5667d57161216ad896dfdef69 100644
--- a/crates/assistant/src/streaming_diff.rs
+++ b/crates/assistant/src/streaming_diff.rs
@@ -28,13 +28,36 @@ impl Matrix {
         self.cols = cols;
     }
 
+    fn swap_columns(&mut self, col1: usize, col2: usize) {
+        if col1 == col2 {
+            return;
+        }
+
+        if col1 >= self.cols {
+            panic!("column out of bounds");
+        }
+
+        if col2 >= self.cols {
+            panic!("column out of bounds");
+        }
+
+        unsafe {
+            let ptr = self.cells.as_mut_ptr();
+            std::ptr::swap_nonoverlapping(
+                ptr.add(col1 * self.rows),
+                ptr.add(col2 * self.rows),
+                self.rows,
+            );
+        }
+    }
+
     fn get(&self, row: usize, col: usize) -> f64 {
         if row >= self.rows {
             panic!("row out of bounds")
         }
 
         if col >= self.cols {
-            panic!("col out of bounds")
+            panic!("column out of bounds")
         }
         self.cells[col * self.rows + row]
     }
@@ -45,7 +68,7 @@ impl Matrix {
         }
 
         if col >= self.cols {
-            panic!("col out of bounds")
+            panic!("column out of bounds")
         }
 
         self.cells[col * self.rows + row] = value;
@@ -106,26 +129,32 @@ impl StreamingDiff {
 
     pub fn push_new(&mut self, text: &str) -> Vec {
         self.new.extend(text.chars());
-        self.scores.resize(self.old.len() + 1, self.new.len() + 1);
+        self.scores.swap_columns(0, self.scores.cols - 1);
+        self.scores
+            .resize(self.old.len() + 1, self.new.len() - self.new_text_ix + 1);
+        self.equal_runs.retain(|(_i, j), _| *j == self.new_text_ix);
 
         for j in self.new_text_ix + 1..=self.new.len() {
-            self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
+            let relative_j = j - self.new_text_ix;
+
+            self.scores
+                .set(0, relative_j, j as f64 * Self::INSERTION_SCORE);
             for i in 1..=self.old.len() {
-                let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
-                let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
+                let insertion_score = self.scores.get(i, relative_j - 1) + Self::INSERTION_SCORE;
+                let deletion_score = self.scores.get(i - 1, relative_j) + Self::DELETION_SCORE;
                 let equality_score = if self.old[i - 1] == self.new[j - 1] {
                     let mut equal_run = self.equal_runs.get(&(i - 1, j - 1)).copied().unwrap_or(0);
                     equal_run += 1;
                     self.equal_runs.insert((i, j), equal_run);
 
                     let exponent = cmp::min(equal_run as i32 / 4, Self::MAX_EQUALITY_EXPONENT);
-                    self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
+                    self.scores.get(i - 1, relative_j - 1) + Self::EQUALITY_BASE.powi(exponent)
                 } else {
                     f64::NEG_INFINITY
                 };
 
                 let score = insertion_score.max(deletion_score).max(equality_score);
-                self.scores.set(i, j, score);
+                self.scores.set(i, relative_j, score);
             }
         }
 
@@ -133,7 +162,7 @@ impl StreamingDiff {
         let mut next_old_text_ix = self.old_text_ix;
         let next_new_text_ix = self.new.len();
         for i in self.old_text_ix..=self.old.len() {
-            let score = self.scores.get(i, next_new_text_ix);
+            let score = self.scores.get(i, next_new_text_ix - self.new_text_ix);
             if score > max_score {
                 max_score = score;
                 next_old_text_ix = i;
@@ -174,7 +203,9 @@ impl StreamingDiff {
 
             let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
                 .iter()
-                .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
+                .max_by_key(|cell| {
+                    cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j - self.new_text_ix)))
+                })
                 .unwrap()
                 .unwrap();
 