use language::{Language, with_parser, with_query_cursor};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{
    cmp::{self, Reverse},
    ops::Range,
    path::Path,
    sync::Arc,
};
use streaming_iterator::StreamingIterator;
use tree_sitter::QueryCapture;
use util::ResultExt as _;

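/// Bounds, in bytes, on the size of a single chunk: the chunker keeps extending a
/// chunk while it is shorter than `min`, and starts a new chunk before the current
/// one would exceed `max`.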
#[derive(Copy, Clone)]
struct ChunkSizeRange {
    min: usize,
    max: usize,
}

const CHUNK_SIZE_RANGE: ChunkSizeRange = ChunkSizeRange {
    min: 1024,
    max: 8192,
};

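/// A contiguous byte range of the chunked text, along with a SHA-256 digest of
/// the text within that range.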
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
    pub range: Range<usize>,
    pub digest: [u8; 32],
}

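/// Splits `text` into chunks using the default size range, preferring syntactic
/// boundaries when `language` has an outline query, and falling back to plain
/// line boundaries otherwise. `path` is only used for error logging.
///
/// A minimal usage sketch (plain text, no language available):
///
/// ```ignore
/// let text = "fn main() {}\n";
/// let chunks = chunk_text(text, None, Path::new("main.rs"));
/// // The file fits well within the maximum chunk size, so it becomes a
/// // single chunk covering the entire text.
/// assert_eq!(chunks.len(), 1);
/// assert_eq!(chunks[0].range, 0..text.len());
/// ```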
pub fn chunk_text(text: &str, language: Option<&Arc<Language>>, path: &Path) -> Vec<Chunk> {
    chunk_text_with_size_range(text, language, path, CHUNK_SIZE_RANGE)
}

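/// Like [`chunk_text`], but with an explicit size range, so that tests can
/// exercise chunking behavior with small chunk sizes.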
fn chunk_text_with_size_range(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let ranges = syntactic_ranges(text, language, path).unwrap_or_default();
    chunk_text_with_syntactic_ranges(text, &ranges, size_config)
}

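/// Parses `text` with the language's tree-sitter grammar and returns the byte
/// ranges of multi-line outline items (functions, types, etc.), each extended
/// backwards over any immediately preceding line comments. Returns `None` if
/// there is no language, grammar, or outline query, or if parsing fails.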
fn syntactic_ranges(
    text: &str,
    language: Option<&Arc<Language>>,
    path: &Path,
) -> Option<Vec<Range<usize>>> {
    let language = language?;
    let grammar = language.grammar()?;
    let outline = grammar.outline_config.as_ref()?;
    let tree = with_parser(|parser| {
        parser.set_language(&grammar.ts_language).log_err()?;
        parser.parse(text, None)
    });

    let Some(tree) = tree else {
        log::error!("failed to parse file {path:?} for chunking");
        return None;
    };

    struct RowInfo {
        offset: usize,
        is_comment: bool,
    }

    let scope = language.default_scope();
    let line_comment_prefixes = scope.line_comment_prefixes();
    let row_infos = text
        .split('\n')
        .map({
            let mut offset = 0;
            move |line| {
                let line = line.trim_start();
                let is_comment = line_comment_prefixes
                    .iter()
                    .any(|prefix| line.starts_with(prefix.as_ref()));
                let result = RowInfo { offset, is_comment };
                offset += line.len() + 1;
                result
            }
        })
        .collect::<Vec<_>>();

    // Retrieve a list of ranges of outline items (types, functions, etc) in the document.
    // Omit single-line outline items (e.g. struct fields, constant declarations), because
    // we'll already be attempting to split on lines.
    let mut ranges = with_query_cursor(|cursor| {
        cursor
            .matches(&outline.query, tree.root_node(), text.as_bytes())
            .filter_map_deref(|mat| {
                mat.captures
                    .iter()
                    .find_map(|QueryCapture { node, index }| {
                        if *index == outline.item_capture_ix {
                            let mut start_offset = node.start_byte();
                            let mut start_row = node.start_position().row;
                            let end_offset = node.end_byte();
                            let end_row = node.end_position().row;

                            // Expand the range to include any preceding comments.
                            while start_row > 0 && row_infos[start_row - 1].is_comment {
                                start_offset = row_infos[start_row - 1].offset;
                                start_row -= 1;
                            }

                            if end_row > start_row {
                                return Some(start_offset..end_offset);
                            }
                        }
                        None
                    })
            })
            .collect::<Vec<_>>()
    });

    ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
    Some(ranges)
}

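/// Accumulates whole lines into chunks. The candidate chunk end advances while
/// the chunk is shorter than the minimum size, or while the current line
/// boundary is no more deeply nested in syntactic ranges than the best end
/// position found so far; once adding a line would push the chunk past the
/// maximum size, the chunk is emitted and a new one begins. A single line
/// longer than the maximum size is split at char boundaries.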
fn chunk_text_with_syntactic_ranges(
    text: &str,
    mut syntactic_ranges: &[Range<usize>],
    size_config: ChunkSizeRange,
) -> Vec<Chunk> {
    let mut chunks = Vec::new();
    let mut range = 0..0;
    let mut range_end_nesting_depth = 0;

    // Try to split the text at line boundaries.
    let mut line_ixs = text
        .match_indices('\n')
        .map(|(ix, _)| ix + 1)
        .chain(if text.ends_with('\n') {
            None
        } else {
            Some(text.len())
        })
        .peekable();

    while let Some(&line_ix) = line_ixs.peek() {
        // If the current position is beyond the maximum chunk size, then
        // start a new chunk.
        if line_ix - range.start > size_config.max {
            if range.is_empty() {
                // A single line is longer than the maximum chunk size, so
                // split it, backing up to the nearest UTF-8 char boundary.
                range.end = cmp::min(range.start + size_config.max, line_ix);
                while !text.is_char_boundary(range.end) {
                    range.end -= 1;
                }
            }

            chunks.push(Chunk {
                range: range.clone(),
                digest: Sha256::digest(&text[range.clone()]).into(),
            });
            range_end_nesting_depth = 0;
            range.start = range.end;
            // Re-check the same line boundary against the new, empty chunk.
            continue;
        }

        // Discard any syntactic ranges that end before the current position.
        while let Some(first_item) = syntactic_ranges.first() {
            if first_item.end < line_ix {
                syntactic_ranges = &syntactic_ranges[1..];
                continue;
            } else {
                break;
            }
        }

        // Count how many syntactic ranges contain the current position.
        let mut nesting_depth = 0;
        for range in syntactic_ranges {
            if range.start > line_ix {
                break;
            }
            if range.start < line_ix && range.end > line_ix {
                nesting_depth += 1;
            }
        }

        // Extend the current range to this position, unless an earlier candidate
        // end position was less nested syntactically.
        if range.len() < size_config.min || nesting_depth <= range_end_nesting_depth {
            range.end = line_ix;
            range_end_nesting_depth = nesting_depth;
        }

        line_ixs.next();
    }

    if !range.is_empty() {
        chunks.push(Chunk {
            range: range.clone(),
            digest: Sha256::digest(&text[range]).into(),
        });
    }

    chunks
}

#[cfg(test)]
mod tests {
    use super::*;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use unindent::Unindent as _;

    #[test]
    fn test_chunk_text_with_syntax() {
        let language = rust_language();

        let text = "
            struct Person {
                first_name: String,
                last_name: String,
                age: u32,
            }

            impl Person {
                fn new(first_name: String, last_name: String, age: u32) -> Self {
                    Self { first_name, last_name, age }
                }

                /// Returns the first name
                /// something something something
                fn first_name(&self) -> &str {
                    &self.first_name
                }

                fn last_name(&self) -> &str {
                    &self.last_name
                }

                fn age(&self) -> u32 {
                    self.age
                }
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('}').unwrap(),
                max: text.find("Self {").unwrap(),
            },
        );

        // The entire impl cannot fit in a chunk, so it is split.
        // Within the impl, two methods can fit in a chunk.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct Person {", // ...
                "impl Person {",
                "    /// Returns the first name",
                "    fn last_name",
            ],
        );

        let text = "
            struct T {}
            struct U {}
            struct V {}
            struct W {
                a: T,
                b: U,
            }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange {
                min: text.find('{').unwrap(),
                max: text.find('V').unwrap(),
            },
        );

        // Two single-line structs can fit in a chunk.
        // The last struct cannot fit in a chunk together
        // with the previous single-line struct.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct T", // ...
                "struct V", // ...
                "struct W", // ...
                "}",
            ],
        );
    }

    #[test]
    fn test_chunk_with_long_lines() {
        let language = rust_language();

        let text = "
            struct S { a: u32 }
            struct T { a: u64 }
            struct U { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
            struct W { a: u64, b: u64, c: u64, d: u64, e: u64, f: u64, g: u64, h: u64, i: u64, j: u64 }
        "
        .unindent();

        let chunks = chunk_text_with_size_range(
            &text,
            Some(&language),
            Path::new("lib.rs"),
            ChunkSizeRange { min: 32, max: 64 },
        );

        // The line is too long to fit in one chunk.
        assert_chunks(
            &text,
            &chunks,
            &[
                "struct S {", // ...
                "struct U",
                "4, h: u64, i: u64", // ...
                "struct W",
                "4, h: u64, i: u64", // ...
            ],
        );
    }

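    /// Asserts that `chunks` upholds the basic invariants and that each chunk's
    /// text starts with the corresponding expected prefix, reporting where the
    /// prefix was actually found when it does not.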
    #[track_caller]
    fn assert_chunks(text: &str, chunks: &[Chunk], expected_chunk_text_prefixes: &[&str]) {
        check_chunk_invariants(text, chunks);

        assert_eq!(
            chunks.len(),
            expected_chunk_text_prefixes.len(),
            "unexpected number of chunks: {chunks:?}",
        );

        let mut prev_chunk_end = 0;
        for (ix, chunk) in chunks.iter().enumerate() {
            let expected_prefix = expected_chunk_text_prefixes[ix];
            let chunk_text = &text[chunk.range.clone()];
            if !chunk_text.starts_with(expected_prefix) {
                let chunk_prefix_offset = text[prev_chunk_end..].find(expected_prefix);
                if let Some(chunk_prefix_offset) = chunk_prefix_offset {
                    panic!(
                        "chunk {ix} starts at unexpected offset {}. expected {}",
                        chunk.range.start,
                        chunk_prefix_offset + prev_chunk_end
                    );
                } else {
                    panic!("invalid expected chunk prefix {ix}: {expected_prefix:?}");
                }
            }
            prev_chunk_end = chunk.range.end;
        }
    }

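    /// Checks that the chunks are contiguous and that together they cover the
    /// entire text (or that there are no chunks at all for empty text).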
    #[track_caller]
    fn check_chunk_invariants(text: &str, chunks: &[Chunk]) {
        for (ix, chunk) in chunks.iter().enumerate() {
            if ix > 0 && chunk.range.start != chunks[ix - 1].range.end {
                panic!("chunk ranges are not contiguous: {:?}", chunks);
            }
        }

        if text.is_empty() {
            assert!(chunks.is_empty())
        } else if chunks.first().unwrap().range.start != 0
            || chunks.last().unwrap().range.end != text.len()
        {
            panic!("chunks don't cover entire text {:?}", chunks);
        }
    }

    #[test]
    fn test_chunk_text() {
        let text = "a\n".repeat(1000);
        let chunks = chunk_text(&text, None, Path::new("lib.rs"));
        assert_eq!(
            chunks.len(),
            ((2000_f64) / (CHUNK_SIZE_RANGE.max as f64)).ceil() as usize
        );
    }

    fn rust_language() -> Arc<Language> {
        Arc::new(
            Language::new(
                LanguageConfig {
                    name: "Rust".into(),
                    matcher: LanguageMatcher {
                        path_suffixes: vec!["rs".to_string()],
                        ..Default::default()
                    },
                    ..Default::default()
                },
                Some(tree_sitter_rust::LANGUAGE.into()),
            )
            .with_outline_query(
                "
                (function_item name: (_) @name) @item
                (impl_item type: (_) @name) @item
                (struct_item name: (_) @name) @item
                (field_declaration name: (_) @name) @item
                ",
            )
            .unwrap(),
        )
    }
}