chunking.rs

  1use language::{with_parser, Grammar, Tree};
  2use serde::{Deserialize, Serialize};
  3use sha2::{Digest, Sha256};
  4use std::{cmp, ops::Range, sync::Arc};
  5
  6const CHUNK_THRESHOLD: usize = 1500;
  7
/// A contiguous byte range of a chunked text, together with a SHA-256 digest of the text it
/// covers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte offsets of the chunk within the text it was cut from.
    pub range: Range<usize>,
    /// SHA-256 digest of the chunk's text, i.e. of `&text[range]`.
    pub digest: [u8; 32],
}
 13
 14pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
 15    if let Some(grammar) = grammar {
 16        let tree = with_parser(|parser| {
 17            parser
 18                .set_language(&grammar.ts_language)
 19                .expect("incompatible grammar");
 20            parser.parse(&text, None).expect("invalid language")
 21        });
 22
 23        chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
 24    } else {
 25        chunk_lines(&text)
 26    }
 27}
 28
 29fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
 30    let mut chunk_ranges = Vec::new();
 31    let mut cursor = tree.walk();
 32
 33    let mut range = 0..0;
 34    loop {
 35        let node = cursor.node();
 36
 37        // If adding the node to the current chunk exceeds the threshold
 38        if node.end_byte() - range.start > chunk_threshold {
 39            // Try to descend into its first child. If we can't, flush the current
 40            // range and try again.
 41            if cursor.goto_first_child() {
 42                continue;
 43            } else if !range.is_empty() {
 44                chunk_ranges.push(range.clone());
 45                range.start = range.end;
 46                continue;
 47            }
 48
 49            // If we get here, the node itself has no children but is larger than the threshold.
 50            // Break its text into arbitrary chunks.
 51            split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
 52        }
 53        range.end = node.end_byte();
 54
 55        // If we get here, we consumed the node. Advance to the next child, ascending if there isn't one.
 56        while !cursor.goto_next_sibling() {
 57            if !cursor.goto_parent() {
 58                if !range.is_empty() {
 59                    chunk_ranges.push(range);
 60                }
 61
 62                return chunk_ranges
 63                    .into_iter()
 64                    .map(|range| {
 65                        let digest = Sha256::digest(&text[range.clone()]).into();
 66                        Chunk { range, digest }
 67                    })
 68                    .collect();
 69            }
 70        }
 71    }
 72}
 73
 74fn chunk_lines(text: &str) -> Vec<Chunk> {
 75    let mut chunk_ranges = Vec::new();
 76    let mut range = 0..0;
 77
 78    let mut newlines = text.match_indices('\n').peekable();
 79    while let Some((newline_ix, _)) = newlines.peek() {
 80        let newline_ix = newline_ix + 1;
 81        if newline_ix - range.start <= CHUNK_THRESHOLD {
 82            range.end = newline_ix;
 83            newlines.next();
 84        } else {
 85            if range.is_empty() {
 86                split_text(text, range, newline_ix, &mut chunk_ranges);
 87                range = newline_ix..newline_ix;
 88            } else {
 89                chunk_ranges.push(range.clone());
 90                range.start = range.end;
 91            }
 92        }
 93    }
 94
 95    if !range.is_empty() {
 96        chunk_ranges.push(range);
 97    }
 98
 99    chunk_ranges
100        .into_iter()
101        .map(|range| Chunk {
102            digest: Sha256::digest(&text[range.clone()]).into(),
103            range,
104        })
105        .collect()
106}
107
108fn split_text(
109    text: &str,
110    mut range: Range<usize>,
111    max_end: usize,
112    chunk_ranges: &mut Vec<Range<usize>>,
113) {
114    while range.start < max_end {
115        range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
116        while !text.is_char_boundary(range.end) {
117            range.end -= 1;
118        }
119        chunk_ranges.push(range.clone());
120        range.start = range.end;
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};

    // This example comes from crates/gpui/examples/window_positioning.rs which
    // has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD
    static TEXT: &str = r#"
    use gpui::*;

    struct WindowContent {
        text: SharedString,
    }

    impl Render for WindowContent {
        fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
            div()
                .flex()
                .bg(rgb(0x1e2025))
                .size_full()
                .justify_center()
                .items_center()
                .text_xl()
                .text_color(rgb(0xffffff))
                .child(self.text.clone())
        }
    }

    fn main() {
        App::new().run(|cx: &mut AppContext| {
            // Create several new windows, positioned in the top right corner of each screen

            for screen in cx.displays() {
                let options = {
                    let popup_margin_width = DevicePixels::from(16);
                    let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);

                    let window_size = Size {
                        width: px(400.),
                        height: px(72.),
                    };

                    let screen_bounds = screen.bounds();
                    let size: Size<DevicePixels> = window_size.into();

                    let bounds = gpui::Bounds::<DevicePixels> {
                        origin: screen_bounds.upper_right()
                            - point(size.width + popup_margin_width, popup_margin_height),
                        size: window_size.into(),
                    };

                    WindowOptions {
                        // Set the bounds of the window in screen coordinates
                        bounds: Some(bounds),
                        // Specify the display_id to ensure the window is created on the correct screen
                        display_id: Some(screen.id()),

                        titlebar: None,
                        window_background: WindowBackgroundAppearance::default(),
                        focus: false,
                        show: true,
                        kind: WindowKind::PopUp,
                        is_movable: false,
                        fullscreen: false,
                        app_id: None,
                    }
                };

                cx.open_window(options, |cx| {
                    cx.new_view(|_| WindowContent {
                        text: format!("{:?}", screen.id()).into(),
                    })
                });
            }
        });
    }"#;

    /// Builds a Rust [`Language`] backed by the tree-sitter Rust grammar.
    fn setup_rust_language() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
    }

    /// Parses `text` with the grammar from [`setup_rust_language`].
    fn parse_rust(text: &str) -> Tree {
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();
        with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(text, None).expect("invalid language")
        })
    }

    #[test]
    fn test_chunk_text() {
        // Without a grammar the text is chunked by lines; every line here is 2 bytes and
        // CHUNK_THRESHOLD is even, so chunks are filled completely except the last one.
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None);
        assert_eq!(
            chunks.len(),
            ((text.len() as f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
        );
    }

    #[test]
    fn test_chunk_text_grammar() {
        let language = setup_rust_language();

        let chunks = chunk_text(TEXT, language.grammar());
        assert_eq!(chunks.len(), 2);

        // The break between chunks falls right after `WindowOptions {` (byte 1498), just
        // before the "Set the bounds of the window" comment.
        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 1498);

        assert_eq!(chunks[1].range.start, 1498);
        assert_eq!(chunks[1].range.end, 2434);
    }

    #[test]
    fn test_chunk_parse_tree() {
        let chunks = chunk_parse_tree(parse_rust(TEXT), TEXT, 250);
        assert_eq!(chunks.len(), 11);
    }

    #[test]
    fn test_chunk_unparsable() {
        // Even if a chunk is unparsable, we should still be able to chunk it
        let text = r#"fn main() {"#;
        let chunks = chunk_parse_tree(parse_rust(text), text, 250);
        assert_eq!(chunks.len(), 1);

        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 11);
    }

    #[test]
    fn test_empty_text() {
        let chunks = chunk_parse_tree(parse_rust(""), "", CHUNK_THRESHOLD);
        assert!(chunks.is_empty(), "Chunks should be empty for empty text");
    }

    #[test]
    fn test_single_large_node() {
        let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";

        let chunks = chunk_parse_tree(parse_rust(&large_text), &large_text, CHUNK_THRESHOLD);

        assert_eq!(
            chunks.len(),
            3,
            "Large chunks are broken up according to grammar as best as possible"
        );

        // Expect chunks to be "static", " aaaaaa...", and " = 2"
        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, "static".len());

        assert_eq!(chunks[1].range.start, "static".len());
        assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);

        assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
        assert_eq!(chunks[2].range.end, large_text.len());
    }

    #[test]
    fn test_multiple_small_nodes() {
        let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
        let chunks = chunk_parse_tree(parse_rust(small_text), small_text, 5);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for multiple small nodes"
        );
    }

    #[test]
    fn test_node_with_children() {
        let nested_text = "fn main() { let a = 1; let b = 2; }";
        let chunks = chunk_parse_tree(parse_rust(nested_text), nested_text, 10);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for a node with children"
        );
    }

    #[test]
    fn test_text_with_unparsable_sections() {
        // This test uses purposefully hit-or-miss sizing of 11 characters per likely chunk
        let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
        let chunks = chunk_parse_tree(parse_rust(mixed_text), mixed_text, 11);

        // Compare the full list so that both missing and extra chunks are reported clearly.
        let chunk_texts: Vec<&str> = chunks
            .iter()
            .map(|chunk| &mixed_text[chunk.range.clone()])
            .collect();
        assert_eq!(
            chunk_texts,
            [
                "fn main() {",
                " let a = 1;",
                " let b = 2;",
                " }",
                " unparsable",
                " bits here",
            ],
            "Should handle both parsable and unparsable sections correctly"
        );
    }
}
406    }
407}