chunking.rs

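//! Splits text into content-addressed chunks, preferring to break at
//! syntactic boundaries (outline items such as functions and types) when a
//! language with an outline query is available.
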
use language::{Language, with_parser, with_query_cursor};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
    cmp::{self, Reverse},
    ops::Range,
    path::Path,
    sync::Arc,
};
use streaming_iterator::StreamingIterator;
use tree_sitter::QueryCapture;
use util::ResultExt as _;

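/// Minimum and maximum chunk sizes, in bytes.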
#[derive(Copy, Clone)]
struct ChunkSizeRange {
    min: usize,
    max: usize,
}

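/// Default chunk size bounds used by `chunk_text`.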
const CHUNK_SIZE_RANGE: ChunkSizeRange = ChunkSizeRange {
    min: 1024,
    max: 8192,
};

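/// A contiguous byte range of the chunked text, together with a SHA-256
/// digest of the text within that range.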
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    pub range: Range<usize>,
    pub digest: [u8; 32],
}

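/// Splits `text` into chunks using the default size bounds, breaking at
/// syntactic boundaries when `language` provides an outline query.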
pub fn chunk_text(text: &str, language: Option<&Arc<Language>>, path: &Path) -> Vec<Chunk> {
    chunk_text_with_size_range(text, language, path, CHUNK_SIZE_RANGE)
}

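/// Like `chunk_text`, but with caller-supplied size bounds.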
fn chunk_text_with_size_range(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let ranges = syntactic_ranges(text, language, path).unwrap_or_default();
    chunk_text_with_syntactic_ranges(text, &ranges, size_config)
}

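/// Returns the byte ranges of multi-line outline items (types, functions, etc.)
/// in `text`, each expanded to include any immediately preceding comment lines.
/// Returns `None` when no language, grammar, or outline query is available, or
/// when parsing fails.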
fn syntactic_ranges(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
) -> Option<Vec<Range<usize>>> {
    let language = language?;
    let grammar = language.grammar()?;
    let outline = grammar.outline_config.as_ref()?;
    let tree = with_parser(|parser| {
        parser.set_language(&grammar.ts_language).log_err()?;
        parser.parse(text, None)
    });

    let Some(tree) = tree else {
        log::error!("failed to parse file {path:?} for chunking");
        return None;
    };

    struct RowInfo {
        offset: usize,
        is_comment: bool,
    }

    let scope = language.default_scope();
    let line_comment_prefixes = scope.line_comment_prefixes();
    // Record, for each line, its byte offset and whether it is a line comment.
    let row_infos = text
        .split('\n')
        .map({
            let mut offset = 0;
            move |line| {
                let is_comment = line_comment_prefixes
                    .iter()
                    .any(|prefix| line.trim_start().starts_with(prefix.as_ref()));
                let result = RowInfo { offset, is_comment };
                // Advance by the full, untrimmed line length plus the newline.
                offset += line.len() + 1;
                result
            }
        })
        .collect::<Vec<_>>();

    // Retrieve a list of ranges of outline items (types, functions, etc) in the document.
    // Omit single-line outline items (e.g. struct fields, constant declarations), because
    // we'll already be attempting to split on lines.
    let mut ranges = with_query_cursor(|cursor| {
        cursor
            .matches(&outline.query, tree.root_node(), text.as_bytes())
            .filter_map_deref(|mat| {
                mat.captures
                    .iter()
                    .find_map(|QueryCapture { node, index }| {
                        if *index == outline.item_capture_ix {
                            let mut start_offset = node.start_byte();
                            let mut start_row = node.start_position().row;
                            let end_offset = node.end_byte();
                            let end_row = node.end_position().row;

                            // Expand the range to include any preceding comments.
                            while start_row > 0 && row_infos[start_row - 1].is_comment {
                                start_offset = row_infos[start_row - 1].offset;
                                start_row -= 1;
                            }

                            if end_row > start_row {
                                return Some(start_offset..end_offset);
                            }
                        }
                        None
                    })
            })
            .collect::<Vec<_>>()
    });

    ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
    Some(ranges)
}

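/// Accumulates whole lines into chunks of at most `size_config.max` bytes.
/// Once a chunk reaches `size_config.min` bytes, its end is only advanced to
/// line boundaries at an equal or shallower syntactic nesting depth.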
fn chunk_text_with_syntactic_ranges(
    text: &str,
    mut syntactic_ranges: &[Range<usize>],
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut range = 0..0;
    let mut range_end_nesting_depth = 0;

    // Try to split the text at line boundaries.
    let mut line_ixs = text
        .match_indices('\n')
        .map(|(ix, _)| ix + 1)
        .chain(if text.ends_with('\n') {
            None
        } else {
            Some(text.len())
        })
        .peekable();

    while let Some(&line_ix) = line_ixs.peek() {
        // If the current position is beyond the maximum chunk size, then
        // start a new chunk.
        if line_ix - range.start > size_config.max {
            if range.is_empty() {
                range.end = cmp::min(range.start + size_config.max, line_ix);
                while !text.is_char_boundary(range.end) {
                    range.end -= 1;
                }
            }

            chunks.push(Chunk {
                range: range.clone(),
                digest: Sha256::digest(&text[range.clone()]).into(),
            });
            range_end_nesting_depth = 0;
            range.start = range.end;
            continue;
        }

        // Discard any syntactic ranges that end before the current position.
        while let Some(first_item) = syntactic_ranges.first() {
            if first_item.end < line_ix {
                syntactic_ranges = &syntactic_ranges[1..];
                continue;
            } else {
                break;
            }
        }

        // Count how many syntactic ranges contain the current position.
        let mut nesting_depth = 0;
        for range in syntactic_ranges {
            if range.start > line_ix {
                break;
            }
            if range.start < line_ix && range.end > line_ix {
                nesting_depth += 1;
            }
        }

        // Extend the current range to this position, unless an earlier candidate
        // end position was less nested syntactically.
        if range.len() < size_config.min || nesting_depth <= range_end_nesting_depth {
            range.end = line_ix;
            range_end_nesting_depth = nesting_depth;
        }

        line_ixs.next();
    }

    if !range.is_empty() {
        chunks.push(Chunk {
            range: range.clone(),
            digest: Sha256::digest(&text[range]).into(),
        });
    }

    chunks
}

#[cfg(test)]
mod tests {
    use super::*;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use unindent::Unindent as _;

    #[test]
    fn test_chunk_text_with_syntax() {
        let language = rust_language();

        let text = "
            struct Person {
                first_name: String,
                last_name: String,
                age: u32,
            }

            impl Person {
                fn new(first_name: String, last_name: String, age: u32) -> Self {
                    Self { first_name, last_name, age }
                }

                /// Returns the first name
                /// something something something
                fn first_name(&self) -> &str {
                    &self.first_name
                }

                fn last_name(&self) -> &str {
                    &self.last_name
                }

                fn age(&self) -> u32 {
                    self.age
                }
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('}').unwrap(),
                max: text.find("Self {").unwrap(),
            },
        );

        // The entire impl cannot fit in a chunk, so it is split.
        // Within the impl, two methods can fit in a chunk.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct Person {", // ...
                "impl Person {",
                "    /// Returns the first name",
                "    fn last_name",
            ],
        );

        let text = "
            struct T {}
            struct U {}
            struct V {}
            struct W {
                a: T,
                b: U,
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('{').unwrap(),
                max: text.find('V').unwrap(),
            },
        );

        // Two single-line structs can fit in a chunk.
        // The last struct cannot fit in a chunk together
        // with the previous single-line struct.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct T", // ...
                "struct V", // ...
                "struct W", // ...
                "}",
            ],
        );
    }

    #[test]
    fn test_chunk_with_long_lines() {
        let language = rust_language();

        let text = "
            struct S { a: u32 }
            struct T { a: u64 }
            struct U { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
            struct W { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange { min: 32, max: 64 },
        );

        // The line is too long to fit in one chunk
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct S {", // ...
                "struct U",
                "4, h: u64, i: u64", // ...
                "struct W",
                "4, h: u64, i: u64", // ...
            ],
        );
    }

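    // Asserts that each chunk's text starts with the corresponding expected
    // prefix, after checking basic chunk invariants.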
    #[track_caller]
    fn assert_chunks(text: &str, chunks: &[Chunk], expected_chunk_text_prefixes: &[&str]) {
        check_chunk_invariants(text, chunks);

        assert_eq!(
            chunks.len(),
            expected_chunk_text_prefixes.len(),
            "unexpected number of chunks: {chunks:?}",
        );

        let mut prev_chunk_end = 0;
        for (ix, chunk) in chunks.iter().enumerate() {
            let expected_prefix = expected_chunk_text_prefixes[ix];
            let chunk_text = &text[chunk.range.clone()];
            if !chunk_text.starts_with(expected_prefix) {
                let chunk_prefix_offset = text[prev_chunk_end..].find(expected_prefix);
                if let Some(chunk_prefix_offset) = chunk_prefix_offset {
                    panic!(
                        "chunk {ix} starts at unexpected offset {}. expected {}",
                        chunk.range.start,
                        chunk_prefix_offset + prev_chunk_end
                    );
                } else {
                    panic!("invalid expected chunk prefix {ix}: {expected_prefix:?}");
                }
            }
            prev_chunk_end = chunk.range.end;
        }
    }

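    // Verifies that the chunks are contiguous and together cover the entire text.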
    #[track_caller]
    fn check_chunk_invariants(text: &str, chunks: &[Chunk]) {
        for (ix, chunk) in chunks.iter().enumerate() {
            if ix > 0 && chunk.range.start != chunks[ix - 1].range.end {
                panic!("chunk ranges are not contiguous: {:?}", chunks);
            }
        }

        if text.is_empty() {
            assert!(chunks.is_empty())
        } else if chunks.first().unwrap().range.start != 0
            || chunks.last().unwrap().range.end != text.len()
        {
            panic!("chunks don't cover entire text {:?}", chunks);
        }
    }

    #[test]
    fn test_chunk_text() {
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None, Path::new("lib.rs"));
        assert_eq!(
            chunks.len(),
            ((2000_f64) / (CHUNK_SIZE_RANGE.max as f64)).ceil() as usize
        );
    }

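    // A minimal Rust language definition with an outline query covering
    // functions, impls, structs, and fields, for use in these tests.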
    fn rust_language() -> Arc<Language> {
        Arc::new(
            Language::new(
                LanguageConfig {
                    name: "Rust".into(),
                    matcher: LanguageMatcher {
                        path_suffixes: vec!["rs".to_string()],
                        ..Default::default()
                    },
                    ..Default::default()
                },
                Some(tree_sitter_rust::LANGUAGE.into()),
            )
            .with_outline_query(
                "
            (function_item name: (_) @name) @item
            (impl_item type: (_) @name) @item
            (struct_item name: (_) @name) @item
            (field_declaration name: (_) @name) @item
        ",
            )
            .unwrap(),
        )
    }
}