Expand tabs correctly in `TabMap`'s highlighted chunks iterator

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/src/editor/display_map/tab_map.rs  | 22 +++++---
zed/src/editor/display_map/wrap_map.rs | 67 +++++++++++++++++++++------
2 files changed, 65 insertions(+), 24 deletions(-)

Detailed changes

zed/src/editor/display_map/tab_map.rs 🔗

@@ -145,17 +145,18 @@ impl Snapshot {
     }
 
     pub fn highlighted_chunks(&mut self, range: Range<OutputPoint>) -> HighlightedChunks {
-        let input_start = self
-            .input
-            .to_output_offset(self.to_input_point(range.start, Bias::Left).0);
+        let (input_start, expanded_char_column, to_next_stop) =
+            self.to_input_point(range.start, Bias::Left);
+        let input_start = self.input.to_output_offset(input_start);
         let input_end = self
             .input
             .to_output_offset(self.to_input_point(range.end, Bias::Left).0);
         HighlightedChunks {
-            input_chunks: self.input.highlighted_chunks(input_start..input_end),
-            column: 0,
+            fold_chunks: self.input.highlighted_chunks(input_start..input_end),
+            column: expanded_char_column,
             tab_size: self.tab_size,
-            chunk: "",
+            chunk: &SPACES[0..to_next_stop],
+            skip_leading_tab: to_next_stop > 0,
             style_id: Default::default(),
         }
     }
@@ -429,11 +430,12 @@ impl<'a> Iterator for Chunks<'a> {
 }
 
 pub struct HighlightedChunks<'a> {
-    input_chunks: InputHighlightedChunks<'a>,
+    fold_chunks: InputHighlightedChunks<'a>,
     chunk: &'a str,
     style_id: StyleId,
     column: usize,
     tab_size: usize,
+    skip_leading_tab: bool,
 }
 
 impl<'a> Iterator for HighlightedChunks<'a> {
@@ -441,9 +443,13 @@ impl<'a> Iterator for HighlightedChunks<'a> {
 
     fn next(&mut self) -> Option<Self::Item> {
         if self.chunk.is_empty() {
-            if let Some((chunk, style_id)) = self.input_chunks.next() {
+            if let Some((chunk, style_id)) = self.fold_chunks.next() {
                 self.chunk = chunk;
                 self.style_id = style_id;
+                if self.skip_leading_tab {
+                    self.chunk = &self.chunk[1..];
+                    self.skip_leading_tab = false;
+                }
             } else {
                 return None;
             }

zed/src/editor/display_map/wrap_map.rs 🔗

@@ -54,7 +54,7 @@ pub struct WrapPoint(super::Point);
 pub struct Chunks<'a> {
     input_chunks: tab_map::Chunks<'a>,
     input_chunk: &'a str,
-    input_position: TabPoint,
+    output_position: WrapPoint,
     transforms: Cursor<'a, Transform, WrapPoint, TabPoint>,
 }
 
@@ -451,7 +451,7 @@ impl Snapshot {
         Chunks {
             input_chunks,
             transforms,
-            input_position,
+            output_position: point,
             input_chunk: "",
         }
     }
@@ -547,6 +547,7 @@ impl<'a> Iterator for Chunks<'a> {
     fn next(&mut self) -> Option<Self::Item> {
         let transform = self.transforms.item()?;
         if let Some(display_text) = transform.display_text {
+            self.output_position.0 += transform.summary.output.lines;
             self.transforms.next(&());
             return Some(display_text);
         }
@@ -556,18 +557,18 @@ impl<'a> Iterator for Chunks<'a> {
         }
 
         let mut input_len = 0;
-        let transform_end = self.transforms.sum_end(&());
+        let transform_end = self.transforms.seek_end(&());
         for c in self.input_chunk.chars() {
             let char_len = c.len_utf8();
             input_len += char_len;
             if c == '\n' {
-                *self.input_position.row_mut() += 1;
-                *self.input_position.column_mut() = 0;
+                *self.output_position.row_mut() += 1;
+                *self.output_position.column_mut() = 0;
             } else {
-                *self.input_position.column_mut() += char_len as u32;
+                *self.output_position.column_mut() += char_len as u32;
             }
 
-            if self.input_position >= transform_end {
+            if self.output_position >= transform_end {
                 self.transforms.next(&());
                 break;
             }
@@ -837,18 +838,19 @@ mod tests {
                 let (folds_snapshot, edits) = cx.read(|cx| fold_map.read(cx));
                 let (tabs_snapshot, edits) = tab_map.sync(folds_snapshot, edits);
                 interpolated_snapshot.interpolate(tabs_snapshot.clone(), &edits);
-                interpolated_snapshot.check_invariants();
+                interpolated_snapshot.check_invariants(&mut rng);
 
                 let unwrapped_text = tabs_snapshot.text();
                 let expected_text = wrap_text(&unwrapped_text, wrap_width, &mut line_wrapper);
                 let mut snapshot = cx.read(|cx| wrap_map.sync(tabs_snapshot.clone(), edits, cx));
+                snapshot.check_invariants(&mut rng);
 
                 if wrap_map.is_rewrapping() {
                     notifications.recv().await;
                     snapshot = cx.read(|cx| wrap_map.sync(tabs_snapshot, Vec::new(), cx));
                 }
 
-                snapshot.check_invariants();
+                snapshot.check_invariants(&mut rng);
                 let actual_text = snapshot.text();
                 assert_eq!(
                     actual_text, expected_text,
@@ -884,18 +886,51 @@ mod tests {
             self.chunks_at(WrapPoint::zero()).collect()
         }
 
-        fn check_invariants(&self) {
+        fn check_invariants(&mut self, rng: &mut impl Rng) {
             assert_eq!(
                 TabPoint::from(self.transforms.summary().input.lines),
                 self.tab_snapshot.max_point()
             );
 
-            let mut transforms = self.transforms.cursor::<(), ()>().peekable();
-            while let Some(transform) = transforms.next() {
-                let next_transform = transforms.peek();
-                assert!(
-                    !transform.is_isomorphic()
-                        || next_transform.map_or(true, |t| !t.is_isomorphic())
+            {
+                let mut transforms = self.transforms.cursor::<(), ()>().peekable();
+                while let Some(transform) = transforms.next() {
+                    let next_transform = transforms.peek();
+                    assert!(
+                        !transform.is_isomorphic()
+                            || next_transform.map_or(true, |t| !t.is_isomorphic())
+                    );
+                }
+            }
+
+            for _ in 0..5 {
+                let mut end_row = rng.gen_range(0..=self.max_point().row());
+                let start_row = rng.gen_range(0..=end_row);
+                end_row += 1;
+
+                let mut expected_text = self
+                    .chunks_at(WrapPoint::new(start_row, 0))
+                    .collect::<String>();
+                if expected_text.ends_with("\n") {
+                    expected_text.push('\n');
+                }
+                let mut expected_text = expected_text
+                    .lines()
+                    .take((end_row - start_row) as usize)
+                    .collect::<Vec<_>>()
+                    .join("\n");
+                if end_row <= self.max_point().row() {
+                    expected_text.push('\n');
+                }
+                let actual_text = self
+                    .highlighted_chunks_for_rows(start_row..end_row)
+                    .map(|c| c.0)
+                    .collect::<String>();
+                assert_eq!(
+                    expected_text,
+                    actual_text,
+                    "chunks != highlighted_chunks for rows {:?}",
+                    start_row..end_row
                 );
             }
         }