use language::{with_parser, with_query_cursor, Language};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
    cmp::{self, Reverse},
    ops::Range,
    path::Path,
    sync::Arc,
};
use tree_sitter::QueryCapture;
use util::ResultExt as _;

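/// Bounds on the size of a single chunk, in bytes.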
#[derive(Copy, Clone)]
struct ChunkSizeRange {
    min: usize,
    max: usize,
}

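/// The default chunk size bounds used by `chunk_text`.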
const CHUNK_SIZE_RANGE: ChunkSizeRange = ChunkSizeRange {
    min: 1024,
    max: 8192,
};

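/// A chunk of the source text: a byte range into the text, plus a SHA-256
/// digest of that range's contents.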
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    pub range: Range<usize>,
    pub digest: [u8; 32],
}

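/// Splits `text` into chunks, preferring to split at line boundaries and, when
/// a grammar with an outline query is available, at the boundaries of outline
/// items (types, functions, etc).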
pub fn chunk_text(text: &str, language: Option<&Arc<Language>>, path: &Path) -> Vec<Chunk> {
    chunk_text_with_size_range(text, language, path, CHUNK_SIZE_RANGE)
}

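/// Like `chunk_text`, but with explicit chunk size bounds (also used by tests).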
fn chunk_text_with_size_range(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let ranges = syntactic_ranges(text, language, path).unwrap_or_default();
    chunk_text_with_syntactic_ranges(text, &ranges, size_config)
}

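/// Returns the byte ranges of multi-line outline items in `text`, expanded to
/// include any immediately preceding line comments and sorted by start position.
/// Returns `None` if there is no language, grammar, or outline query for the
/// file, or if parsing fails.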
fn syntactic_ranges(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
) -> Option<Vec<Range<usize>>> {
    let language = language?;
    let grammar = language.grammar()?;
    let outline = grammar.outline_config.as_ref()?;
    let tree = with_parser(|parser| {
        parser.set_language(&grammar.ts_language).log_err()?;
        parser.parse(text, None)
    });

    let Some(tree) = tree else {
        log::error!("failed to parse file {path:?} for chunking");
        return None;
    };

    struct RowInfo {
        offset: usize,
        is_comment: bool,
    }

    let scope = language.default_scope();
    let line_comment_prefixes = scope.line_comment_prefixes();

    // Record each row's starting byte offset and whether it is a line comment.
    let row_infos = text
        .split('\n')
        .map({
            let mut offset = 0;
            move |line| {
                let trimmed = line.trim_start();
                let is_comment = line_comment_prefixes
                    .iter()
                    .any(|prefix| trimmed.starts_with(prefix.as_ref()));
                let result = RowInfo { offset, is_comment };
                // Advance by the untrimmed line length so that `offset` stays
                // the byte offset of each row's start.
                offset += line.len() + 1;
                result
            }
        })
        .collect::<Vec<_>>();

    // Retrieve a list of ranges of outline items (types, functions, etc) in the document.
    // Omit single-line outline items (e.g. struct fields, constant declarations), because
    // we'll already be attempting to split on lines.
    let mut ranges = with_query_cursor(|cursor| {
        cursor
            .matches(&outline.query, tree.root_node(), text.as_bytes())
            .filter_map(|mat| {
                mat.captures
                    .iter()
                    .find_map(|QueryCapture { node, index }| {
                        if *index == outline.item_capture_ix {
                            let mut start_offset = node.start_byte();
                            let mut start_row = node.start_position().row;
                            let end_offset = node.end_byte();
                            let end_row = node.end_position().row;

                            // Expand the range to include any preceding comments.
                            while start_row > 0 && row_infos[start_row - 1].is_comment {
                                start_offset = row_infos[start_row - 1].offset;
                                start_row -= 1;
                            }

                            if end_row > start_row {
                                return Some(start_offset..end_offset);
                            }
                        }
                        None
                    })
            })
            .collect::<Vec<_>>()
    });

    ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
    Some(ranges)
}

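/// Splits `text` into chunks at line boundaries, keeping each chunk within the
/// given size bounds and preferring split points that are nested within as few
/// of the given syntactic ranges as possible.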
fn chunk_text_with_syntactic_ranges(
    text: &str,
    mut syntactic_ranges: &[Range<usize>],
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut range = 0..0;
    let mut range_end_nesting_depth = 0;

    // Try to split the text at line boundaries.
    let mut line_ixs = text
        .match_indices('\n')
        .map(|(ix, _)| ix + 1)
        .chain(if text.ends_with('\n') {
            None
        } else {
            Some(text.len())
        })
        .peekable();

    while let Some(&line_ix) = line_ixs.peek() {
        // If the current position is beyond the maximum chunk size, then
        // start a new chunk.
        if line_ix - range.start > size_config.max {
            if range.is_empty() {
                range.end = cmp::min(range.start + size_config.max, line_ix);
                while !text.is_char_boundary(range.end) {
                    range.end -= 1;
                }
            }

            chunks.push(Chunk {
                range: range.clone(),
                digest: Sha256::digest(&text[range.clone()]).into(),
            });
            range_end_nesting_depth = 0;
            range.start = range.end;
            continue;
        }

        // Discard any syntactic ranges that end before the current position.
        while let Some(first_item) = syntactic_ranges.first() {
            if first_item.end < line_ix {
                syntactic_ranges = &syntactic_ranges[1..];
                continue;
            } else {
                break;
            }
        }

        // Count how many syntactic ranges contain the current position.
        let mut nesting_depth = 0;
        for range in syntactic_ranges {
            if range.start > line_ix {
                break;
            }
            if range.start < line_ix && range.end > line_ix {
                nesting_depth += 1;
            }
        }

        // Extend the current range to this position, unless an earlier candidate
        // end position was less nested syntactically.
        if range.len() < size_config.min || nesting_depth <= range_end_nesting_depth {
            range.end = line_ix;
            range_end_nesting_depth = nesting_depth;
        }

        line_ixs.next();
    }

    if !range.is_empty() {
        chunks.push(Chunk {
            range: range.clone(),
            digest: Sha256::digest(&text[range]).into(),
        });
    }

    chunks
}

#[cfg(test)]
mod tests {
    use super::*;
    use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
    use unindent::Unindent as _;

    #[test]
    fn test_chunk_text_with_syntax() {
        let language = rust_language();

        let text = "
            struct Person {
                first_name: String,
                last_name: String,
                age: u32,
            }

            impl Person {
                fn new(first_name: String, last_name: String, age: u32) -> Self {
                    Self { first_name, last_name, age }
                }

                /// Returns the first name
                /// something something something
                fn first_name(&self) -> &str {
                    &self.first_name
                }

                fn last_name(&self) -> &str {
                    &self.last_name
                }

                fn age(&self) -> u32 {
                    self.age
                }
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('}').unwrap(),
                max: text.find("Self {").unwrap(),
            },
        );

        // The entire impl cannot fit in a chunk, so it is split.
        // Within the impl, two methods can fit in a chunk.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct Person {", // ...
                "impl Person {",
                "    /// Returns the first name",
                "    fn last_name",
            ],
        );

        let text = "
            struct T {}
            struct U {}
            struct V {}
            struct W {
                a: T,
                b: U,
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('{').unwrap(),
                max: text.find('V').unwrap(),
            },
        );

        // Two single-line structs can fit in a chunk.
        // The last struct cannot fit in a chunk together
        // with the previous single-line struct.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct T", // ...
                "struct V", // ...
                "struct W", // ...
                "}",
            ],
        );
    }

    #[test]
    fn test_chunk_with_long_lines() {
        let language = rust_language();

        let text = "
            struct S { a: u32 }
            struct T { a: u64 }
            struct U { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
            struct W { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange { min: 32, max: 64 },
        );

        // The `struct U` and `struct W` lines are too long to fit in one
        // chunk, so they are split mid-line.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct S {", // ...
                "struct U",
                "4, h: u64, i: u64", // ...
                "struct W",
                "4, h: u64, i: u64", // ...
            ],
        );
    }

    #[track_caller]
    fn assert_chunks(text: &str, chunks: &[Chunk], expected_chunk_text_prefixes: &[&str]) {
        check_chunk_invariants(text, chunks);

        assert_eq!(
            chunks.len(),
            expected_chunk_text_prefixes.len(),
            "unexpected number of chunks: {chunks:?}",
        );

        let mut prev_chunk_end = 0;
        for (ix, chunk) in chunks.iter().enumerate() {
            let expected_prefix = expected_chunk_text_prefixes[ix];
            let chunk_text = &text[chunk.range.clone()];
            if !chunk_text.starts_with(expected_prefix) {
                let chunk_prefix_offset = text[prev_chunk_end..].find(expected_prefix);
                if let Some(chunk_prefix_offset) = chunk_prefix_offset {
                    panic!(
                        "chunk {ix} starts at unexpected offset {}. expected {}",
                        chunk.range.start,
                        chunk_prefix_offset + prev_chunk_end
                    );
                } else {
                    panic!("invalid expected chunk prefix {ix}: {expected_prefix:?}");
                }
            }
            prev_chunk_end = chunk.range.end;
        }
    }

    #[track_caller]
    fn check_chunk_invariants(text: &str, chunks: &[Chunk]) {
        for (ix, chunk) in chunks.iter().enumerate() {
            if ix > 0 && chunk.range.start != chunks[ix - 1].range.end {
                panic!("chunk ranges are not contiguous: {:?}", chunks);
            }
        }

        if text.is_empty() {
            assert!(chunks.is_empty())
        } else if chunks.first().unwrap().range.start != 0
            || chunks.last().unwrap().range.end != text.len()
        {
            panic!("chunks don't cover entire text {:?}", chunks);
        }
    }

    #[test]
    fn test_chunk_text() {
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None, Path::new("lib.rs"));
        assert_eq!(
            chunks.len(),
            ((2000_f64) / (CHUNK_SIZE_RANGE.max as f64)).ceil() as usize
        );
    }

    fn rust_language() -> Arc<Language> {
        Arc::new(
            Language::new(
                LanguageConfig {
                    name: "Rust".into(),
                    matcher: LanguageMatcher {
                        path_suffixes: vec!["rs".to_string()],
                        ..Default::default()
                    },
                    ..Default::default()
                },
                Some(tree_sitter_rust::LANGUAGE.into()),
            )
            .with_outline_query(
                "
                (function_item name: (_) @name) @item
                (impl_item type: (_) @name) @item
                (struct_item name: (_) @name) @item
                (field_declaration name: (_) @name) @item
                ",
            )
            .unwrap(),
        )
    }
}