chunking.rs

  1use language::{with_parser, with_query_cursor, Language};
  2use serde::{Deserialize, Serialize};
  3use sha2::{Digest, Sha256};
  4use std::{
  5    cmp::{self, Reverse},
  6    ops::Range,
  7    path::Path,
  8    sync::Arc,
  9};
 10use tree_sitter::QueryCapture;
 11use util::ResultExt as _;
 12
/// Bounds, in bytes, on the size of the chunks produced by the chunker.
#[derive(Copy, Clone)]
struct ChunkSizeRange {
    /// While the current chunk is shorter than this, it is always extended to the
    /// next candidate line boundary, regardless of syntactic nesting.
    min: usize,
    /// Once a candidate line boundary lies more than this many bytes past the start
    /// of the current chunk, the chunk is closed and a new one is started.
    max: usize,
}
 18
/// Default chunk size bounds, in bytes, that `chunk_text` passes to the chunker.
const CHUNK_SIZE_RANGE: ChunkSizeRange = ChunkSizeRange {
    min: 1024,
    max: 8192,
};
 23
/// A contiguous piece of a chunked text, identified by its position and content hash.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    /// Byte range of this chunk within the text that was chunked.
    pub range: Range<usize>,
    /// SHA-256 digest of the text covered by `range`.
    pub digest: [u8; 32],
}
 29
/// Splits `text` into contiguous chunks using the default `CHUNK_SIZE_RANGE`.
///
/// `language`, when provided, is used to locate outline items so that chunk
/// boundaries can favor less deeply nested positions. `path` is forwarded for use
/// in the error log emitted when parsing fails.
pub fn chunk_text(text: &str, language: Option<&Arc<Language>>, path: &Path) -> Vec<Chunk> {
    chunk_text_with_size_range(text, language, path, CHUNK_SIZE_RANGE)
}
 33
 34fn chunk_text_with_size_range(
 35    text: &str,
 36    language: Option<&Arc<Language>>,
 37    path: &Path,
 38    size_config: ChunkSizeRange,
 39) -> Vec<Chunk> {
 40    let ranges = syntactic_ranges(text, language, path).unwrap_or_default();
 41    chunk_text_with_syntactic_ranges(text, &ranges, size_config)
 42}
 43
 44fn syntactic_ranges(
 45    text: &str,
 46    language: Option<&Arc<Language>>,
 47    path: &Path,
 48) -> Option<Vec<Range<usize>>> {
 49    let language = language?;
 50    let grammar = language.grammar()?;
 51    let outline = grammar.outline_config.as_ref()?;
 52    let tree = with_parser(|parser| {
 53        parser.set_language(&grammar.ts_language).log_err()?;
 54        parser.parse(text, None)
 55    });
 56
 57    let Some(tree) = tree else {
 58        log::error!("failed to parse file {path:?} for chunking");
 59        return None;
 60    };
 61
 62    struct RowInfo {
 63        offset: usize,
 64        is_comment: bool,
 65    }
 66
 67    let scope = language.default_scope();
 68    let line_comment_prefixes = scope.line_comment_prefixes();
 69    let row_infos = text
 70        .split('\n')
 71        .map({
 72            let mut offset = 0;
 73            move |line| {
 74                let line = line.trim_start();
 75                let is_comment = line_comment_prefixes
 76                    .iter()
 77                    .any(|prefix| line.starts_with(prefix.as_ref()));
 78                let result = RowInfo { offset, is_comment };
 79                offset += line.len() + 1;
 80                result
 81            }
 82        })
 83        .collect::<Vec<_>>();
 84
 85    // Retrieve a list of ranges of outline items (types, functions, etc) in the document.
 86    // Omit single-line outline items (e.g. struct fields, constant declarations), because
 87    // we'll already be attempting to split on lines.
 88    let mut ranges = with_query_cursor(|cursor| {
 89        cursor
 90            .matches(&outline.query, tree.root_node(), text.as_bytes())
 91            .filter_map(|mat| {
 92                mat.captures
 93                    .iter()
 94                    .find_map(|QueryCapture { node, index }| {
 95                        if *index == outline.item_capture_ix {
 96                            let mut start_offset = node.start_byte();
 97                            let mut start_row = node.start_position().row;
 98                            let end_offset = node.end_byte();
 99                            let end_row = node.end_position().row;
100
101                            // Expand the range to include any preceding comments.
102                            while start_row > 0 && row_infos[start_row - 1].is_comment {
103                                start_offset = row_infos[start_row - 1].offset;
104                                start_row -= 1;
105                            }
106
107                            if end_row > start_row {
108                                return Some(start_offset..end_offset);
109                            }
110                        }
111                        None
112                    })
113            })
114            .collect::<Vec<_>>()
115    });
116
117    ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
118    Some(ranges)
119}
120
121fn chunk_text_with_syntactic_ranges(
122    text: &str,
123    mut syntactic_ranges: &[Range<usize>],
124    size_config: ChunkSizeRange,
125) -> Vec<Chunk> {
126    let mut chunks = Vec::new();
127    let mut range = 0..0;
128    let mut range_end_nesting_depth = 0;
129
130    // Try to split the text at line boundaries.
131    let mut line_ixs = text
132        .match_indices('\n')
133        .map(|(ix, _)| ix + 1)
134        .chain(if text.ends_with('\n') {
135            None
136        } else {
137            Some(text.len())
138        })
139        .peekable();
140
141    while let Some(&line_ix) = line_ixs.peek() {
142        // If the current position is beyond the maximum chunk size, then
143        // start a new chunk.
144        if line_ix - range.start > size_config.max {
145            if range.is_empty() {
146                range.end = cmp::min(range.start + size_config.max, line_ix);
147                while !text.is_char_boundary(range.end) {
148                    range.end -= 1;
149                }
150            }
151
152            chunks.push(Chunk {
153                range: range.clone(),
154                digest: Sha256::digest(&text[range.clone()]).into(),
155            });
156            range_end_nesting_depth = 0;
157            range.start = range.end;
158            continue;
159        }
160
161        // Discard any syntactic ranges that end before the current position.
162        while let Some(first_item) = syntactic_ranges.first() {
163            if first_item.end < line_ix {
164                syntactic_ranges = &syntactic_ranges[1..];
165                continue;
166            } else {
167                break;
168            }
169        }
170
171        // Count how many syntactic ranges contain the current position.
172        let mut nesting_depth = 0;
173        for range in syntactic_ranges {
174            if range.start > line_ix {
175                break;
176            }
177            if range.start < line_ix && range.end > line_ix {
178                nesting_depth += 1;
179            }
180        }
181
182        // Extend the current range to this position, unless an earlier candidate
183        // end position was less nested syntactically.
184        if range.len() < size_config.min || nesting_depth <= range_end_nesting_depth {
185            range.end = line_ix;
186            range_end_nesting_depth = nesting_depth;
187        }
188
189        line_ixs.next();
190    }
191
192    if !range.is_empty() {
193        chunks.push(Chunk {
194            range: range.clone(),
195            digest: Sha256::digest(&text[range]).into(),
196        });
197    }
198
199    chunks
200}
201
202#[cfg(test)]
203mod tests {
204    use super::*;
205    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
206    use unindent::Unindent as _;
207
    /// Checks that chunk boundaries follow the outline of Rust source: chunks break
    /// between top-level items and between methods rather than inside them, given
    /// size bounds derived from offsets within the sample text.
    #[test]
    fn test_chunk_text_with_syntax() {
        let language = rust_language();

        let text = "
            struct Person {
                first_name: String,
                last_name: String,
                age: u32,
            }

            impl Person {
                fn new(first_name: String, last_name: String, age: u32) -> Self {
                    Self { first_name, last_name, age }
                }

                /// Returns the first name
                /// something something something
                fn first_name(&self) -> &str {
                    &self.first_name
                }

                fn last_name(&self) -> &str {
                    &self.last_name
                }

                fn age(&self) -> u32 {
                    self.age
                }
            }
        "
        .unindent();

        // The minimum is the offset of the struct's closing brace; the maximum is the
        // offset of the body of `new`.
        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('}').unwrap(),
                max: text.find("Self {").unwrap(),
            },
        );

        // The entire impl cannot fit in a chunk, so it is split.
        // Within the impl, two methods can fit in a chunk.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct Person {", // ...
                "impl Person {",
                "    /// Returns the first name",
                "    fn last_name",
            ],
        );

        let text = "
            struct T {}
            struct U {}
            struct V {}
            struct W {
                a: T,
                b: U,
            }
        "
        .unindent();

        // The minimum is the offset of the first `{`; the maximum is the offset of
        // the name of the third struct.
        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('{').unwrap(),
                max: text.find('V').unwrap(),
            },
        );

        // Two single-line structs can fit in a chunk.
        // The last struct cannot fit in a chunk together
        // with the previous single-line struct.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct T", // ...
                "struct V", // ...
                "struct W", // ...
                "}",
            ],
        );
    }
299
    /// Checks that a line longer than the maximum chunk size is split mid-line
    /// instead of producing an oversized chunk.
    #[test]
    fn test_chunk_with_long_lines() {
        let language = rust_language();

        let text = "
            struct S { a: u32 }
            struct T { a: u64 }
            struct U { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
            struct W { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange { min: 32, max: 64 },
        );

        // The two short structs share a chunk. Each of the long lines exceeds the
        // maximum size, so it is cut into two chunks partway through the line.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct S {", // ...
                "struct U",
                "4, h: u64, i: u64", // ...
                "struct W",
                "4, h: u64, i: u64", // ...
            ],
        );
    }
332
333    #[track_caller]
334    fn assert_chunks(text: &str, chunks: &[Chunk], expected_chunk_text_prefixes: &[&str]) {
335        check_chunk_invariants(text, chunks);
336
337        assert_eq!(
338            chunks.len(),
339            expected_chunk_text_prefixes.len(),
340            "unexpected number of chunks: {chunks:?}",
341        );
342
343        let mut prev_chunk_end = 0;
344        for (ix, chunk) in chunks.iter().enumerate() {
345            let expected_prefix = expected_chunk_text_prefixes[ix];
346            let chunk_text = &text[chunk.range.clone()];
347            if !chunk_text.starts_with(expected_prefix) {
348                let chunk_prefix_offset = text[prev_chunk_end..].find(expected_prefix);
349                if let Some(chunk_prefix_offset) = chunk_prefix_offset {
350                    panic!(
351                        "chunk {ix} starts at unexpected offset {}. expected {}",
352                        chunk.range.start,
353                        chunk_prefix_offset + prev_chunk_end
354                    );
355                } else {
356                    panic!("invalid expected chunk prefix {ix}: {expected_prefix:?}");
357                }
358            }
359            prev_chunk_end = chunk.range.end;
360        }
361    }
362
363    #[track_caller]
364    fn check_chunk_invariants(text: &str, chunks: &[Chunk]) {
365        for (ix, chunk) in chunks.iter().enumerate() {
366            if ix > 0 && chunk.range.start != chunks[ix - 1].range.end {
367                panic!("chunk ranges are not contiguous: {:?}", chunks);
368            }
369        }
370
371        if text.is_empty() {
372            assert!(chunks.is_empty())
373        } else if chunks.first().unwrap().range.start != 0
374            || chunks.last().unwrap().range.end != text.len()
375        {
376            panic!("chunks don't cover entire text {:?}", chunks);
377        }
378    }
379
380    #[test]
381    fn test_chunk_text() {
382        let text = "a\n".repeat(1000);
383        let chunks = chunk_text(&text, None, Path::new("lib.rs"));
384        assert_eq!(
385            chunks.len(),
386            ((2000_f64) / (CHUNK_SIZE_RANGE.max as f64)).ceil() as usize
387        );
388    }
389
    /// Builds a Rust `Language` for the tests, backed by the tree-sitter Rust grammar
    /// and an outline query that captures functions, impls, structs, and struct
    /// fields as outline items.
    fn rust_language() -> Arc<Language> {
        Arc::new(
            Language::new(
                LanguageConfig {
                    name: "Rust".into(),
                    matcher: LanguageMatcher {
                        path_suffixes: vec!["rs".to_string()],
                        ..Default::default()
                    },
                    ..Default::default()
                },
                Some(tree_sitter_rust::LANGUAGE.into()),
            )
            .with_outline_query(
                "
            (function_item name: (_) @name) @item
            (impl_item type: (_) @name) @item
            (struct_item name: (_) @name) @item
            (field_declaration name: (_) @name) @item
        ",
            )
            .unwrap(),
        )
    }
414}