chunking.rs

  1use language::{with_parser, with_query_cursor, Grammar};
  2use serde::{Deserialize, Serialize};
  3use sha2::{Digest, Sha256};
  4use std::{
  5    cmp::{self, Reverse},
  6    ops::Range,
  7    sync::Arc,
  8};
  9use tree_sitter::QueryCapture;
 10use util::ResultExt as _;
 11
 12#[derive(Copy, Clone)]
 13struct ChunkSizeRange {
 14    min: usize,
 15    max: usize,
 16}
 17
 18const CHUNK_SIZE_RANGE: ChunkSizeRange = ChunkSizeRange {
 19    min: 1024,
 20    max: 8192,
 21};
 22
 23#[derive(Debug, Clone, Serialize, Deserialize)]
 24pub struct Chunk {
 25    pub range: Range<usize>,
 26    pub digest: [u8; 32],
 27}
 28
 29pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
 30    chunk_text_with_size_range(text, grammar, CHUNK_SIZE_RANGE)
 31}
 32
 33fn chunk_text_with_size_range(
 34    text: &str,
 35    grammar: Option<&Arc<Grammar>>,
 36    size_config: ChunkSizeRange,
 37) -> Vec<Chunk> {
 38    let mut syntactic_ranges = Vec::new();
 39
 40    if let Some(grammar) = grammar {
 41        if let Some(outline) = grammar.outline_config.as_ref() {
 42            let tree = with_parser(|parser| {
 43                parser.set_language(&grammar.ts_language).log_err()?;
 44                parser.parse(&text, None)
 45            });
 46
 47            if let Some(tree) = tree {
 48                with_query_cursor(|cursor| {
 49                    // Retrieve a list of ranges of outline items (types, functions, etc) in the document.
 50                    // Omit single-line outline items (e.g. struct fields, constant declarations), because
 51                    // we'll already be attempting to split on lines.
 52                    syntactic_ranges = cursor
 53                        .matches(&outline.query, tree.root_node(), text.as_bytes())
 54                        .filter_map(|mat| {
 55                            mat.captures
 56                                .iter()
 57                                .find_map(|QueryCapture { node, index }| {
 58                                    if *index == outline.item_capture_ix {
 59                                        if node.end_position().row > node.start_position().row {
 60                                            return Some(node.byte_range());
 61                                        }
 62                                    }
 63                                    None
 64                                })
 65                        })
 66                        .collect::<Vec<_>>();
 67                    syntactic_ranges
 68                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
 69                });
 70            }
 71        }
 72    }
 73
 74    chunk_text_with_syntactic_ranges(text, &syntactic_ranges, size_config)
 75}
 76
 77fn chunk_text_with_syntactic_ranges(
 78    text: &str,
 79    mut syntactic_ranges: &[Range<usize>],
 80    size_config: ChunkSizeRange,
 81) -> Vec<Chunk> {
 82    let mut chunks = Vec::new();
 83    let mut range = 0..0;
 84    let mut range_end_nesting_depth = 0;
 85
 86    // Try to split the text at line boundaries.
 87    let mut line_ixs = text
 88        .match_indices('\n')
 89        .map(|(ix, _)| ix + 1)
 90        .chain(if text.ends_with('\n') {
 91            None
 92        } else {
 93            Some(text.len())
 94        })
 95        .peekable();
 96
 97    while let Some(&line_ix) = line_ixs.peek() {
 98        // If the current position is beyond the maximum chunk size, then
 99        // start a new chunk.
100        if line_ix - range.start > size_config.max {
101            if range.is_empty() {
102                range.end = cmp::min(range.start + size_config.max, line_ix);
103                while !text.is_char_boundary(range.end) {
104                    range.end -= 1;
105                }
106            }
107
108            chunks.push(Chunk {
109                range: range.clone(),
110                digest: Sha256::digest(&text[range.clone()]).into(),
111            });
112            range_end_nesting_depth = 0;
113            range.start = range.end;
114            continue;
115        }
116
117        // Discard any syntactic ranges that end before the current position.
118        while let Some(first_item) = syntactic_ranges.first() {
119            if first_item.end < line_ix {
120                syntactic_ranges = &syntactic_ranges[1..];
121                continue;
122            } else {
123                break;
124            }
125        }
126
127        // Count how many syntactic ranges contain the current position.
128        let mut nesting_depth = 0;
129        for range in syntactic_ranges {
130            if range.start > line_ix {
131                break;
132            }
133            if range.start < line_ix && range.end > line_ix {
134                nesting_depth += 1;
135            }
136        }
137
138        // Extend the current range to this position, unless an earlier candidate
139        // end position was less nested syntactically.
140        if range.len() < size_config.min || nesting_depth <= range_end_nesting_depth {
141            range.end = line_ix;
142            range_end_nesting_depth = nesting_depth;
143        }
144
145        line_ixs.next();
146    }
147
148    if !range.is_empty() {
149        chunks.push(Chunk {
150            range: range.clone(),
151            digest: Sha256::digest(&text[range.clone()]).into(),
152        });
153    }
154
155    chunks
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
    use unindent::Unindent as _;

    /// Chunk boundaries should follow the syntactic structure of the text
    /// when a grammar with an outline query is available.
    #[test]
    fn test_chunk_text_with_syntax() {
        let rust = rust_language();

        let person_source = "
            struct Person {
                first_name: String,
                last_name: String,
                age: u32,
            }

            impl Person {
                fn new(first_name: String, last_name: String, age: u32) -> Self {
                    Self { first_name, last_name, age }
                }

                fn first_name(&self) -> &str {
                    &self.first_name
                }

                fn last_name(&self) -> &str {
                    &self.last_name
                }

                fn age(&self) -> usize {
                    self.ages
                }
            }
        "
        .unindent();

        // Size bounds are expressed as offsets of landmarks in the fixture.
        let person_bounds = ChunkSizeRange {
            min: person_source.find('}').unwrap(),
            max: person_source.find("Self {").unwrap(),
        };
        let person_chunks =
            chunk_text_with_size_range(&person_source, rust.grammar(), person_bounds);

        // The entire impl cannot fit in a chunk, so it is split.
        // Within the impl, two methods can fit in a chunk.
        let expected_person_prefixes = [
            "struct Person {", // ...
            "impl Person {",
            "    fn first_name",
            "    fn age",
        ];
        assert_chunks(&person_source, &person_chunks, &expected_person_prefixes);

        let structs_source = "
            struct T {}
            struct U {}
            struct V {}
            struct W {
                a: T,
                b: U,
            }
        "
        .unindent();

        let structs_bounds = ChunkSizeRange {
            min: structs_source.find('{').unwrap(),
            max: structs_source.find('V').unwrap(),
        };
        let structs_chunks =
            chunk_text_with_size_range(&structs_source, rust.grammar(), structs_bounds);

        // Two single-line structs can fit in a chunk.
        // The last struct cannot fit in a chunk together
        // with the previous single-line struct.
        let expected_structs_prefixes = [
            "struct T", // ...
            "struct V", // ...
            "struct W", // ...
            "}",
        ];
        assert_chunks(&structs_source, &structs_chunks, &expected_structs_prefixes);
    }

    /// Lines longer than the maximum chunk size must be cut mid-line.
    #[test]
    fn test_chunk_with_long_lines() {
        let rust = rust_language();

        let source = "
            struct S { a: u32 }
            struct T { a: u64 }
            struct U { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
            struct W { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
        "
        .unindent();

        let bounds = ChunkSizeRange { min: 32, max: 64 };
        let chunks = chunk_text_with_size_range(&source, rust.grammar(), bounds);

        // The line is too long to fit in one chunk
        let expected_prefixes = [
            "struct S {", // ...
            "struct U",
            "4, h: u64, i: u64", // ...
            "struct W",
            "4, h: u64, i: u64", // ...
        ];
        assert_chunks(&source, &chunks, &expected_prefixes);
    }

    /// Asserts that `chunks` satisfy the chunking invariants for `text` and
    /// that the i-th chunk's text starts with the i-th expected prefix.
    #[track_caller]
    fn assert_chunks(text: &str, chunks: &[Chunk], expected_chunk_text_prefixes: &[&str]) {
        check_chunk_invariants(text, chunks);

        assert_eq!(
            chunks.len(),
            expected_chunk_text_prefixes.len(),
            "unexpected number of chunks: {chunks:?}",
        );

        // The lengths were just asserted equal, so zipping visits every chunk.
        let pairs = chunks
            .iter()
            .zip(expected_chunk_text_prefixes.iter().copied());

        let mut search_start = 0;
        for (ix, (chunk, expected_prefix)) in pairs.enumerate() {
            let actual_text = &text[chunk.range.clone()];
            if !actual_text.starts_with(expected_prefix) {
                // Report where the prefix really occurs, if it occurs at all
                // after the end of the previous chunk.
                match text[search_start..].find(expected_prefix) {
                    Some(offset) => panic!(
                        "chunk {ix} starts at unexpected offset {}. expected {}",
                        chunk.range.start,
                        offset + search_start
                    ),
                    None => panic!("invalid expected chunk prefix {ix}: {expected_prefix:?}"),
                }
            }
            search_start = chunk.range.end;
        }
    }

    /// Panics unless `chunks` are contiguous and, for non-empty `text`, span
    /// it from its first byte to its last. Empty `text` must yield no chunks.
    #[track_caller]
    fn check_chunk_invariants(text: &str, chunks: &[Chunk]) {
        for neighbors in chunks.windows(2) {
            if neighbors[1].range.start != neighbors[0].range.end {
                panic!("chunk ranges are not contiguous: {:?}", chunks);
            }
        }

        if text.is_empty() {
            assert!(chunks.is_empty());
            return;
        }

        let first_start = chunks.first().unwrap().range.start;
        let last_end = chunks.last().unwrap().range.end;
        if first_start != 0 || last_end != text.len() {
            panic!("chunks don't cover entire text {:?}", chunks);
        }
    }

    /// Without a grammar, the default bounds decide the number of chunks.
    #[test]
    fn test_chunk_text() {
        let text = "a\n".repeat(1000);
        let expected_len = ((2000_f64) / (CHUNK_SIZE_RANGE.max as f64)).ceil() as usize;

        let chunks = chunk_text(&text, None);

        assert_eq!(chunks.len(), expected_len);
    }

    /// Builds a Rust `Language` whose outline query captures functions,
    /// impls, structs and field declarations as items.
    fn rust_language() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::language()))
            .with_outline_query(
                "
            (function_item name: (_) @name) @item
            (impl_item type: (_) @name) @item
            (struct_item name: (_) @name) @item
            (field_declaration name: (_) @name) @item
        ",
            )
            .unwrap()
    }
}