chunking.rs

  1use language::{with_parser, Grammar, Tree};
  2use serde::{Deserialize, Serialize};
  3use sha2::{Digest, Sha256};
  4use std::{cmp, ops::Range, sync::Arc};
  5
/// Size limit, in bytes, that the chunkers use when deciding how much text to group into a
/// single chunk. All comparisons against it are made on byte offsets into the text.
const CHUNK_THRESHOLD: usize = 1500;
  7
/// One piece of a chunked text, described by its location and a hash of its contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte offsets of the chunk within the text it was produced from.
    pub range: Range<usize>,
    /// SHA-256 digest of the chunk's text, i.e. of `text[range]`.
    pub digest: [u8; 32],
}
 13
 14pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
 15    if let Some(grammar) = grammar {
 16        let tree = with_parser(|parser| {
 17            parser
 18                .set_language(&grammar.ts_language)
 19                .expect("incompatible grammar");
 20            parser.parse(&text, None).expect("invalid language")
 21        });
 22
 23        chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
 24    } else {
 25        chunk_lines(&text)
 26    }
 27}
 28
 29fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
 30    let mut chunk_ranges = Vec::new();
 31    let mut cursor = tree.walk();
 32
 33    let mut range = 0..0;
 34    loop {
 35        let node = cursor.node();
 36
 37        // If adding the node to the current chunk exceeds the threshold
 38        if node.end_byte() - range.start > chunk_threshold {
 39            // Try to descend into its first child. If we can't, flush the current
 40            // range and try again.
 41            if cursor.goto_first_child() {
 42                continue;
 43            } else if !range.is_empty() {
 44                chunk_ranges.push(range.clone());
 45                range.start = range.end;
 46                continue;
 47            }
 48
 49            // If we get here, the node itself has no children but is larger than the threshold.
 50            // Break its text into arbitrary chunks.
 51            split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
 52        }
 53        range.end = node.end_byte();
 54
 55        // If we get here, we consumed the node. Advance to the next child, ascending if there isn't one.
 56        while !cursor.goto_next_sibling() {
 57            if !cursor.goto_parent() {
 58                if !range.is_empty() {
 59                    chunk_ranges.push(range);
 60                }
 61
 62                return chunk_ranges
 63                    .into_iter()
 64                    .map(|range| {
 65                        let digest = Sha256::digest(&text[range.clone()]).into();
 66                        Chunk { range, digest }
 67                    })
 68                    .collect();
 69            }
 70        }
 71    }
 72}
 73
 74fn chunk_lines(text: &str) -> Vec<Chunk> {
 75    let mut chunk_ranges = Vec::new();
 76    let mut range = 0..0;
 77
 78    let mut newlines = text.match_indices('\n').peekable();
 79    while let Some((newline_ix, _)) = newlines.peek() {
 80        let newline_ix = newline_ix + 1;
 81        if newline_ix - range.start <= CHUNK_THRESHOLD {
 82            range.end = newline_ix;
 83            newlines.next();
 84        } else {
 85            if range.is_empty() {
 86                split_text(text, range, newline_ix, &mut chunk_ranges);
 87                range = newline_ix..newline_ix;
 88            } else {
 89                chunk_ranges.push(range.clone());
 90                range.start = range.end;
 91            }
 92        }
 93    }
 94
 95    if !range.is_empty() {
 96        chunk_ranges.push(range);
 97    }
 98
 99    chunk_ranges
100        .into_iter()
101        .map(|range| {
102            let mut hasher = Sha256::new();
103            hasher.update(&text[range.clone()]);
104            let mut digest = [0u8; 32];
105            digest.copy_from_slice(hasher.finalize().as_slice());
106            Chunk { range, digest }
107        })
108        .collect()
109}
110
111fn split_text(
112    text: &str,
113    mut range: Range<usize>,
114    max_end: usize,
115    chunk_ranges: &mut Vec<Range<usize>>,
116) {
117    while range.start < max_end {
118        range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
119        while !text.is_char_boundary(range.end) {
120            range.end -= 1;
121        }
122        chunk_ranges.push(range.clone());
123        range.start = range.end;
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};

    // This example comes from crates/gpui/examples/window_positioning.rs which
    // has the property of being CHUNK_THRESHOLD < TEXT.len() < 2*CHUNK_THRESHOLD
    static TEXT: &str = r#"
    use gpui::*;

    struct WindowContent {
        text: SharedString,
    }

    impl Render for WindowContent {
        fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
            div()
                .flex()
                .bg(rgb(0x1e2025))
                .size_full()
                .justify_center()
                .items_center()
                .text_xl()
                .text_color(rgb(0xffffff))
                .child(self.text.clone())
        }
    }

    fn main() {
        App::new().run(|cx: &mut AppContext| {
            // Create several new windows, positioned in the top right corner of each screen

            for screen in cx.displays() {
                let options = {
                    let popup_margin_width = DevicePixels::from(16);
                    let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);

                    let window_size = Size {
                        width: px(400.),
                        height: px(72.),
                    };

                    let screen_bounds = screen.bounds();
                    let size: Size<DevicePixels> = window_size.into();

                    let bounds = gpui::Bounds::<DevicePixels> {
                        origin: screen_bounds.upper_right()
                            - point(size.width + popup_margin_width, popup_margin_height),
                        size: window_size.into(),
                    };

                    WindowOptions {
                        // Set the bounds of the window in screen coordinates
                        bounds: Some(bounds),
                        // Specify the display_id to ensure the window is created on the correct screen
                        display_id: Some(screen.id()),

                        titlebar: None,
                        window_background: WindowBackgroundAppearance::default(),
                        focus: false,
                        show: true,
                        kind: WindowKind::PopUp,
                        is_movable: false,
                        fullscreen: false,
                    }
                };

                cx.open_window(options, |cx| {
                    cx.new_view(|_| WindowContent {
                        text: format!("{:?}", screen.id()).into(),
                    })
                });
            }
        });
    }"#;

    fn setup_rust_language() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
    }

    #[test]
    fn test_chunk_text() {
        // Without a grammar the text is chunked by lines; 2000 bytes need two chunks.
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None);
        assert_eq!(
            chunks.len(),
            ((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
        );
    }

    #[test]
    fn test_chunk_text_grammar() {
        // Let's set up a big text with some known segments
        // We'll then chunk it and verify that the chunks are correct

        let language = setup_rust_language();

        let chunks = chunk_text(TEXT, language.grammar());
        assert_eq!(chunks.len(), 2);

        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 1498);
        // The first chunk ends right after the `WindowOptions {` line; the second chunk starts
        // with the "Set the bounds of the window in screen coordinates" comment.

        // 2396 is TEXT.len(): the second chunk runs to the end of the text.
        assert_eq!(chunks[1].range.start, 1498);
        assert_eq!(chunks[1].range.end, 2396);
    }

    #[test]
    fn test_chunk_parse_tree() {
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(TEXT, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, TEXT, 250);
        assert_eq!(chunks.len(), 11);
    }

    #[test]
    fn test_chunk_unparsable() {
        // Even if a chunk is unparsable, we should still be able to chunk it
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let text = r#"fn main() {"#;
        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, text, 250);
        assert_eq!(chunks.len(), 1);

        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, 11);
    }

    #[test]
    fn test_empty_text() {
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse("", None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
        assert!(chunks.is_empty(), "Chunks should be empty for empty text");
    }

    #[test]
    fn test_single_large_node() {
        let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";

        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(&large_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);

        assert_eq!(
            chunks.len(),
            3,
            "Large chunks are broken up according to grammar as best as possible"
        );

        // Expect the chunks to be "static", " aaaaaa...", and " = 2" (the whitespace before a
        // node belongs to the chunk containing that node).
        assert_eq!(chunks[0].range.start, 0);
        assert_eq!(chunks[0].range.end, "static".len());

        assert_eq!(chunks[1].range.start, "static".len());
        assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);

        assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
        assert_eq!(chunks[2].range.end, large_text.len());
    }

    #[test]
    fn test_multiple_small_nodes() {
        let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(small_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, small_text, 5);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for multiple small nodes"
        );
    }

    #[test]
    fn test_node_with_children() {
        let nested_text = "fn main() { let a = 1; let b = 2; }";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(nested_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, nested_text, 10);
        assert!(
            chunks.len() > 1,
            "Should have multiple chunks for a node with children"
        );
    }

    #[test]
    fn test_text_with_unparsable_sections() {
        // The 11-byte threshold is chosen so that most chunks land exactly on the limit,
        // exercising the boundary of the size comparison.
        let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
        let language = setup_rust_language();
        let grammar = language.grammar().unwrap();

        let tree = with_parser(|parser| {
            parser
                .set_language(&grammar.ts_language)
                .expect("incompatible grammar");
            parser.parse(mixed_text, None).expect("invalid language")
        });

        let chunks = chunk_parse_tree(tree, mixed_text, 11);
        assert!(
            chunks.len() > 1,
            "Should handle both parsable and unparsable sections correctly"
        );

        let expected_chunks = [
            "fn main() {",
            " let a = 1;",
            " let b = 2;",
            " }",
            " unparsable",
            " bits here",
        ];

        // Compare whole sequences so that missing or extra chunks fail the test too, rather
        // than passing silently or panicking on an out-of-bounds index.
        let chunk_texts: Vec<&str> = chunks
            .iter()
            .map(|chunk| &mixed_text[chunk.range.clone()])
            .collect();
        assert_eq!(chunk_texts, expected_chunks);
    }
}