chunking.rs

use language::{with_parser, Grammar, Tree};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{cmp, ops::Range, sync::Arc};

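/// Maximum size of a single chunk, in bytes of the source text.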
const CHUNK_THRESHOLD: usize = 1500;

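/// A contiguous byte range of the source text, together with the SHA-256
/// digest of the text in that range.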
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    pub range: Range<usize>,
    pub digest: [u8; 32],
}

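/// Splits `text` into chunks of at most `CHUNK_THRESHOLD` bytes.
///
/// When a grammar is available, chunk boundaries follow the parse tree so that
/// chunks tend to end at syntactic boundaries; otherwise the text is split on
/// line boundaries.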
pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
    if let Some(grammar) = grammar {
        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(text, None).expect("invalid language")
        });

        chunk_parse_tree(tree, text, CHUNK_THRESHOLD)
    } else {
        chunk_lines(text)
    }
}

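/// Walks the parse tree depth-first, greedily accumulating sibling nodes into
/// the current chunk and descending into a node's children whenever adding the
/// whole node would push the chunk past `chunk_threshold` bytes.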
fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
    let mut chunk_ranges = Vec::new();
    let mut cursor = tree.walk();

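    // `range` is the chunk currently being accumulated: `range.start` marks the
    // first byte not yet flushed to `chunk_ranges`, and `range.end` tracks the
    // end of the last node consumed so far.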
    let mut range = 0..0;
    loop {
        let node = cursor.node();

        // If adding the node to the current chunk exceeds the threshold
        if node.end_byte() - range.start > chunk_threshold {
            // Try to descend into its first child. If we can't, flush the current
            // range and try again.
            if cursor.goto_first_child() {
                continue;
            } else if !range.is_empty() {
                chunk_ranges.push(range.clone());
                range.start = range.end;
                continue;
            }

            // If we get here, the node itself has no children but is larger than the threshold.
            // Break its text into arbitrary chunks.
            split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
            range.start = node.end_byte();
        }
        range.end = node.end_byte();

        // If we get here, we consumed the node. Advance to the next sibling, ascending if there isn't one.
        while !cursor.goto_next_sibling() {
            if !cursor.goto_parent() {
                if !range.is_empty() {
                    chunk_ranges.push(range);
                }

                return chunk_ranges
                    .into_iter()
                    .map(|range| {
                        let digest = Sha256::digest(&text[range.clone()]).into();
                        Chunk { range, digest }
                    })
                    .collect();
            }
        }
    }
}

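/// Splits `text` into chunks on newline boundaries, accumulating whole lines
/// until adding the next line would push the chunk past `CHUNK_THRESHOLD`
/// bytes. A single line longer than the threshold is broken up arbitrarily by
/// `split_text`.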
fn chunk_lines(text: &str) -> Vec<Chunk> {
    let mut chunk_ranges = Vec::new();
    let mut range = 0..0;

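    // Walk the newline positions, extending the current chunk one line at a
    // time and flushing it once the next line would not fit.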
    let mut newlines = text.match_indices('\n').peekable();
    while let Some((newline_ix, _)) = newlines.peek() {
        let newline_ix = newline_ix + 1;
        if newline_ix - range.start <= CHUNK_THRESHOLD {
            range.end = newline_ix;
            newlines.next();
        } else if range.is_empty() {
            split_text(text, range, newline_ix, &mut chunk_ranges);
            range = newline_ix..newline_ix;
        } else {
            chunk_ranges.push(range.clone());
            range.start = range.end;
        }
    }

    if !range.is_empty() {
        chunk_ranges.push(range);
    }

    chunk_ranges
        .into_iter()
        .map(|range| {
            let digest = Sha256::digest(&text[range.clone()]).into();
            Chunk { range, digest }
        })
        .collect()
}

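/// Breaks the text between `range.start` and `max_end` into pieces of at most
/// `CHUNK_THRESHOLD` bytes and appends their ranges to `chunk_ranges`, backing
/// each piece's end up to the nearest UTF-8 character boundary.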
fn split_text(
    text: &str,
    mut range: Range<usize>,
    max_end: usize,
    chunk_ranges: &mut Vec<Range<usize>>,
) {
    while range.start < max_end {
        range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
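        // Back up to a character boundary so that a multi-byte character is
        // never split across two chunks.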
        while !text.is_char_boundary(range.end) {
            range.end -= 1;
        }
        chunk_ranges.push(range.clone());
        range.start = range.end;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};

    // This example comes from crates/gpui/examples/window_positioning.rs, which
    // has the property that CHUNK_THRESHOLD < TEXT.len() < 2 * CHUNK_THRESHOLD.
    static TEXT: &str = r#"
    use gpui::*;

    struct WindowContent {
        text: SharedString,
    }

    impl Render for WindowContent {
        fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
            div()
                .flex()
                .bg(rgb(0x1e2025))
                .size_full()
                .justify_center()
                .items_center()
                .text_xl()
                .text_color(rgb(0xffffff))
                .child(self.text.clone())
        }
    }

    fn main() {
        App::new().run(|cx: &mut AppContext| {
            // Create several new windows, positioned in the top right corner of each screen

            for screen in cx.displays() {
                let options = {
                    let popup_margin_width = DevicePixels::from(16);
                    let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);

                    let window_size = Size {
                        width: px(400.),
                        height: px(72.),
                    };

                    let screen_bounds = screen.bounds();
                    let size: Size<DevicePixels> = window_size.into();

                    let bounds = gpui::Bounds::<DevicePixels> {
                        origin: screen_bounds.upper_right()
                            - point(size.width + popup_margin_width, popup_margin_height),
                        size: window_size.into(),
                    };

                    WindowOptions {
                        // Set the bounds of the window in screen coordinates
                        bounds: Some(bounds),
                        // Specify the display_id to ensure the window is created on the correct screen
                        display_id: Some(screen.id()),

                        titlebar: None,
                        window_background: WindowBackgroundAppearance::default(),
                        focus: false,
                        show: true,
                        kind: WindowKind::PopUp,
                        is_movable: false,
                        fullscreen: false,
                        app_id: None,
                    }
                };

                cx.open_window(options, |cx| {
                    cx.new_view(|_| WindowContent {
                        text: format!("{:?}", screen.id()).into(),
                    })
                });
            }
        });
    }"#;

    fn setup_rust_language() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
    }

    #[test]
    fn test_chunk_text() {
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None);
        assert_eq!(
            chunks.len(),
            ((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
        );
    }

    #[test]
    fn test_chunk_text_grammar() {
        // Let's set up a big text with some known segments
        // We'll then chunk it and verify that the chunks are correct

        let language = setup_rust_language();

        let chunks = chunk_text(TEXT, language.grammar());
        assert_eq!(chunks.len(), 2);

        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 1498);
        // The break between chunks is right before the "Specify the display_id" comment

        assert_eq!(chunks[1].range.start, 1498);
        assert_eq!(chunks[1].range.end, 2434);
    }

    #[test]
    fn test_chunk_parse_tree() {
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(TEXT, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, TEXT, 250);
        assert_eq!(chunks.len(), 11);
    }

    #[test]
    fn test_chunk_unparsable() {
        // Even if the text is unparsable, we should still be able to chunk it
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let text = r#"fn main() {"#;
        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, text, 250);
        assert_eq!(chunks.len(), 1);

        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 11);
    }

    #[test]
    fn test_empty_text() {
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse("", None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
        assert!(chunks.is_empty(), "Chunks should be empty for empty text");
    }

    #[test]
    fn test_single_large_node() {
        let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";

        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(&large_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);

        assert_eq!(
            chunks.len(),
            3,
            "Large chunks are broken up according to grammar as best as possible"
        );

        // Expect chunks to be static, aaaaaa..., and = 2
        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, "static".len());

        assert_eq!(chunks[1].range.start, "static".len());
        assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);

        assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
        assert_eq!(chunks[2].range.end, large_text.len());
    }

    #[test]
    fn test_multiple_small_nodes() {
        let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(small_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, small_text, 5);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for multiple small nodes"
        );
    }

    #[test]
    fn test_node_with_children() {
        let nested_text = "fn main() { let a = 1; let b = 2; }";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(nested_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, nested_text, 10);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for a node with children"
        );
    }

    #[test]
    fn test_text_with_unparsable_sections() {
        // This test deliberately uses a hit-or-miss chunk size of 11 bytes, roughly one statement per chunk
        let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(mixed_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, mixed_text, 11);
        assert!(
            chunks.len() > 1,
            "Should handle both parsable and unparsable sections correctly"
        );

        let expected_chunks = [
            "fn main() {",
            " let a = 1;",
            " let b = 2;",
            " }",
            " unparsable",
            " bits here",
        ];

        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(
                &mixed_text[chunk.range.clone()],
                expected_chunks[i],
                "Chunk {} should match",
                i
            );
        }
    }
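
    // Sanity check on `split_text`: it should only ever cut on UTF-8 character
    // boundaries, even when `CHUNK_THRESHOLD` falls inside a multi-byte character.
    #[test]
    fn test_split_text_char_boundaries() {
        // One ASCII byte followed by 3-byte '€' characters, so the threshold
        // lands inside a character and the split point must be moved back.
        let text = "a".to_owned() + &"€".repeat(CHUNK_THRESHOLD);
        let mut ranges = Vec::new();
        split_text(&text, 0..0, text.len(), &mut ranges);

        assert!(ranges.len() > 1, "text longer than the threshold should be split");
        assert_eq!(ranges.last().unwrap().end, text.len());
        for range in &ranges {
            assert!(range.end - range.start <= CHUNK_THRESHOLD);
            assert!(text.is_char_boundary(range.start));
            assert!(text.is_char_boundary(range.end));
        }
    }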
}