1use language::BufferSnapshot;
2use std::ops::Range;
3use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
4use tree_sitter::{Node, TreeCursor};
5use util::RangeExt;
6
7// TODO:
8//
9// - Test parent signatures
10//
11// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
12// planning.
13//
14// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
15// paragraph).
16//
17// - Truncation of long lines.
18//
19// - Filter outer syntax layers that don't support edit prediction.
20
/// Options controlling how an excerpt around the cursor is selected.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window.
    pub target_before_cursor_over_total_bytes: f32,
    /// Whether to collect the signature ranges of the outline items that contain the cursor line
    /// and count them towards the excerpt size.
    pub include_parent_signatures: bool,
}
33
/// A selected region of a buffer, together with the signatures of the outline items enclosing it.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt within the buffer.
    pub range: Range<usize>,
    /// Byte ranges of the signatures of the outline items that contain the excerpt.
    pub parent_signature_ranges: Vec<Range<usize>>,
    /// Total byte size: the length of `range` plus the lengths of all `parent_signature_ranges`.
    pub size: usize,
}
40
41impl EditPredictionExcerpt {
42 /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
43 /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
44 /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
45 /// of parent outline items.
46 ///
47 /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
48 /// expansion.
49 ///
50 /// Returns `None` if the line around the cursor doesn't fit.
51 pub fn select_from_buffer(
52 query_point: Point,
53 buffer: &BufferSnapshot,
54 options: &EditPredictionExcerptOptions,
55 ) -> Option<Self> {
56 if buffer.len() <= options.max_bytes {
57 log::debug!(
58 "using entire file for excerpt since source length ({}) <= window max bytes ({})",
59 buffer.len(),
60 options.max_bytes
61 );
62 return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
63 }
64
65 let query_offset = query_point.to_offset(buffer);
66 let query_range = Point::new(query_point.row, 0).to_offset(buffer)
67 ..Point::new(query_point.row + 1, 0).to_offset(buffer);
68 if query_range.len() >= options.max_bytes {
69 return None;
70 }
71
72 // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
73 let outline_items = if options.include_parent_signatures {
74 buffer
75 .outline_items_containing(query_range.clone(), false, None)
76 .into_iter()
77 .flat_map(|item| {
78 Some(ExcerptOutlineItem {
79 item_range: item.range.to_offset(&buffer),
80 signature_range: item.signature_range?.to_offset(&buffer),
81 })
82 })
83 .collect()
84 } else {
85 Vec::new()
86 };
87
88 let excerpt_selector = ExcerptSelector {
89 query_offset,
90 query_range,
91 outline_items: &outline_items,
92 buffer,
93 options,
94 };
95
96 if let Some(excerpt_ranges) = excerpt_selector.select_tree_sitter_nodes() {
97 if excerpt_ranges.size >= options.min_bytes {
98 return Some(excerpt_ranges);
99 }
100 log::debug!(
101 "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
102 excerpt_ranges.size,
103 options.min_bytes
104 );
105 } else {
106 log::debug!(
107 "couldn't find excerpt via tree-sitter, falling back on line-based selection"
108 );
109 }
110
111 excerpt_selector.select_lines()
112 }
113
114 fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
115 let size = range.len()
116 + parent_signature_ranges
117 .iter()
118 .map(|r| r.len())
119 .sum::<usize>();
120 Self {
121 range,
122 parent_signature_ranges,
123 size,
124 }
125 }
126
127 fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
128 if !new_range.contains_inclusive(&self.range) {
129 // this is an issue because parent_signature_ranges may be incorrect
130 log::error!("bug: with_expanded_range called with disjoint range");
131 }
132 let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
133 let mut size = new_range.len();
134 for range in &self.parent_signature_ranges {
135 if range.contains_inclusive(&new_range) {
136 break;
137 }
138 parent_signature_ranges.push(range.clone());
139 size += range.len();
140 }
141 Self {
142 range: new_range,
143 parent_signature_ranges,
144 size,
145 }
146 }
147
148 fn parent_signatures_size(&self) -> usize {
149 self.size - self.range.len()
150 }
151}
152
/// Bundles the inputs shared by the excerpt selection strategies (syntax-based and line-based).
struct ExcerptSelector<'a> {
    /// Byte offset of the cursor within `buffer`.
    query_offset: usize,
    /// Byte range of the cursor line: from the start of the cursor's row to the start of the
    /// following row.
    query_range: Range<usize>,
    /// Outline items used to look up parent signatures for a candidate excerpt.
    outline_items: &'a [ExcerptOutlineItem],
    /// Snapshot of the buffer the excerpt is selected from.
    buffer: &'a BufferSnapshot,
    /// Size limits and target ratio for the selection.
    options: &'a EditPredictionExcerptOptions,
}
160
/// Byte ranges of an outline item that has a signature.
struct ExcerptOutlineItem {
    /// Byte range of the whole outline item.
    item_range: Range<usize>,
    /// Byte range of the item's signature.
    signature_range: Range<usize>,
}
165
166impl<'a> ExcerptSelector<'a> {
167 /// Finds the largest node that is smaller than the window size and contains `query_range`.
168 fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
169 let selected_layer_root = self.select_syntax_layer()?;
170 let mut cursor = selected_layer_root.walk();
171
172 loop {
173 let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
174 ..node_line_end(cursor.node()).to_offset(&self.buffer);
175 if excerpt_range.contains_inclusive(&self.query_range) {
176 let excerpt = self.make_excerpt(excerpt_range);
177 if excerpt.size <= self.options.max_bytes {
178 return Some(self.expand_to_siblings(&mut cursor, excerpt));
179 }
180 } else {
181 // TODO: Should still be able to handle this case via AST nodes. For example, this
182 // can happen if the cursor is between two methods in a large class file.
183 return None;
184 }
185
186 if cursor
187 .goto_first_child_for_byte(self.query_range.start)
188 .is_none()
189 {
190 return None;
191 }
192 }
193 }
194
195 /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
196 fn select_syntax_layer(&self) -> Option<Node<'_>> {
197 let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
198 let mut largest: Option<Node<'_>> = None;
199 for layer in self
200 .buffer
201 .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
202 {
203 let layer_range = layer.node().byte_range();
204 if !layer_range.contains_inclusive(&self.query_range) {
205 continue;
206 }
207
208 if layer_range.len() > self.options.max_bytes {
209 match &smallest_exceeding_max_len {
210 None => smallest_exceeding_max_len = Some(layer.node()),
211 Some(existing) => {
212 if layer_range.len() < existing.byte_range().len() {
213 smallest_exceeding_max_len = Some(layer.node());
214 }
215 }
216 }
217 } else {
218 match &largest {
219 None => largest = Some(layer.node()),
220 Some(existing) if layer_range.len() > existing.byte_range().len() => {
221 largest = Some(layer.node())
222 }
223 _ => {}
224 }
225 }
226 }
227
228 smallest_exceeding_max_len.or(largest)
229 }
230
231 // motivation for this and `goto_previous_named_sibling` is to avoid including things like
232 // trailing unnamed "}" in body nodes
233 fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
234 while cursor.goto_next_sibling() {
235 if cursor.node().is_named() {
236 return true;
237 }
238 }
239 false
240 }
241
242 fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
243 while cursor.goto_previous_sibling() {
244 if cursor.node().is_named() {
245 return true;
246 }
247 }
248 false
249 }
250
    /// Grows `excerpt` over the named siblings of the node `cursor` points at, for as long as the
    /// grown excerpt stays within `max_bytes`.
    ///
    /// Two cursors are used: a clone of `cursor` walks forward and `cursor` itself walks backward.
    /// Each round computes at most one forward and one backward candidate. When both exist, the
    /// one whose before-cursor ratio is closer to the target is taken (see `is_better_excerpt`).
    /// A direction halts permanently once a candidate in that direction does not fit or its
    /// siblings are exhausted.
    fn expand_to_siblings(
        &self,
        cursor: &mut TreeCursor,
        mut excerpt: EditPredictionExcerpt,
    ) -> EditPredictionExcerpt {
        let mut forward_cursor = cursor.clone();
        let backward_cursor = cursor;
        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
        loop {
            if backward_done && forward_done {
                break;
            }

            // Find the next forward sibling that actually extends the excerpt's end.
            let mut forward = None;
            while !forward_done {
                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
                if new_end > excerpt.range.end {
                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
                    if new_excerpt.size <= self.options.max_bytes {
                        forward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting forward expansion, as it doesn't fit");
                        forward_done = true;
                        break;
                    }
                }
                // This sibling ends within the current excerpt; skip to the next one.
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            }

            // Find the next backward sibling that actually extends the excerpt's start.
            let mut backward = None;
            while !backward_done {
                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
                if new_start < excerpt.range.start {
                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
                    if new_excerpt.size <= self.options.max_bytes {
                        backward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting backward expansion, as it doesn't fit");
                        backward_done = true;
                        break;
                    }
                }
                // This sibling starts within the current excerpt; skip to the previous one.
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }

            // Commit to one direction per round; the other candidate is recomputed next round
            // against the updated excerpt.
            let go_forward = match (forward, backward) {
                (Some(forward), Some(backward)) => {
                    let go_forward = self.is_better_excerpt(&forward, &backward);
                    if go_forward {
                        excerpt = forward;
                    } else {
                        excerpt = backward;
                    }
                    go_forward
                }
                (Some(forward), None) => {
                    log::debug!("expanding forward, since backward expansion has halted");
                    excerpt = forward;
                    true
                }
                (None, Some(backward)) => {
                    log::debug!("expanding backward, since forward expansion has halted");
                    excerpt = backward;
                    false
                }
                (None, None) => break,
            };

            // Only the cursor for the direction just taken advances.
            if go_forward {
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            } else {
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }
        }

        excerpt
    }
331
332 fn select_lines(&self) -> Option<EditPredictionExcerpt> {
333 // early return if line containing query_offset is already too large
334 let excerpt = self.make_excerpt(self.query_range.clone());
335 if excerpt.size > self.options.max_bytes {
336 log::debug!(
337 "excerpt for cursor line is {} bytes, which exceeds the window",
338 excerpt.size
339 );
340 return None;
341 }
342 let signatures_size = excerpt.parent_signatures_size();
343 let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
344
345 let before_bytes =
346 (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
347
348 let start_point = {
349 let offset = self.query_offset.saturating_sub(before_bytes);
350 let point = offset.to_point(self.buffer);
351 Point::new(point.row + 1, 0)
352 };
353 let start_offset = start_point.to_offset(&self.buffer);
354 let end_point = {
355 let offset = start_offset + bytes_remaining;
356 let point = offset.to_point(self.buffer);
357 Point::new(point.row, 0)
358 };
359 let end_offset = end_point.to_offset(&self.buffer);
360
361 // this could be expanded further since recalculated `signature_size` may be smaller, but
362 // skipping that for now for simplicity
363 //
364 // TODO: could also consider checking if lines immediately before / after fit.
365 let excerpt = self.make_excerpt(start_offset..end_offset);
366 if excerpt.size > self.options.max_bytes {
367 log::error!(
368 "bug: line-based excerpt selection has size {}, \
369 which is {} bytes larger than the max size",
370 excerpt.size,
371 excerpt.size - self.options.max_bytes
372 );
373 }
374 return Some(excerpt);
375 }
376
377 fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
378 let parent_signature_ranges = self
379 .outline_items
380 .iter()
381 .filter(|item| item.item_range.contains_inclusive(&range))
382 .map(|item| item.signature_range.clone())
383 .collect();
384 EditPredictionExcerpt::new(range, parent_signature_ranges)
385 }
386
387 /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
388 fn is_better_excerpt(
389 &self,
390 forward: &EditPredictionExcerpt,
391 backward: &EditPredictionExcerpt,
392 ) -> bool {
393 let forward_ratio = self.excerpt_range_ratio(forward);
394 let backward_ratio = self.excerpt_range_ratio(backward);
395 let forward_delta =
396 (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
397 let backward_delta =
398 (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
399 let forward_is_better = forward_delta <= backward_delta;
400 if forward_is_better {
401 log::debug!(
402 "expanding forward since {} is closer than {} to {}",
403 forward_ratio,
404 backward_ratio,
405 self.options.target_before_cursor_over_total_bytes
406 );
407 } else {
408 log::debug!(
409 "expanding backward since {} is closer than {} to {}",
410 backward_ratio,
411 forward_ratio,
412 self.options.target_before_cursor_over_total_bytes
413 );
414 }
415 forward_is_better
416 }
417
418 /// Returns the ratio of bytes before the cursor over bytes within the range.
419 fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
420 let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
421 log::error!("bug: edit prediction cursor offset is not outside the excerpt");
422 return 0.0;
423 };
424 bytes_before_cursor as f32 / excerpt.range.len() as f32
425 }
426}
427
428fn node_line_start(node: Node) -> Point {
429 Point::new(node.start_position().row as u32, 0)
430}
431
432fn node_line_end(node: Node) -> Point {
433 Point::new(node.end_position().row as u32 + 1, 0)
434}
435
436#[cfg(test)]
437mod tests {
438 use super::*;
439 use gpui::{AppContext, TestAppContext};
440 use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
441 use util::test::{generate_marked_text, marked_text_offsets_by};
442
443 fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
444 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
445 buffer.read_with(cx, |buffer, _| buffer.snapshot())
446 }
447
448 fn rust_lang() -> Language {
449 Language::new(
450 LanguageConfig {
451 name: "Rust".into(),
452 matcher: LanguageMatcher {
453 path_suffixes: vec!["rs".to_string()],
454 ..Default::default()
455 },
456 ..Default::default()
457 },
458 Some(tree_sitter_rust::LANGUAGE.into()),
459 )
460 .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
461 .unwrap()
462 }
463
464 fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
465 let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
466 (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
467 }
468
469 fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
470 let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
471
472 let buffer = create_buffer(&text, cx);
473 let cursor_point = cursor.to_point(&buffer);
474
475 let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
476 .expect("Should select an excerpt");
477 pretty_assertions::assert_eq!(
478 generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
479 generate_marked_text(&text, &[expected_excerpt], false)
480 );
481 assert!(excerpt.size <= options.max_bytes);
482 assert!(excerpt.range.contains(&cursor));
483 }
484
    /// With a 20-byte budget, the expected excerpt (`«…»`) is only the line of the statement
    /// that contains the cursor (`ˇ`).
    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // Fixture whitespace is significant: the byte budget below is tuned to it.
        let text = r#"
fn main() {
 let x = 1;
« let ˇy = 2;
» let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
504
    /// With a 65-byte budget, the expected excerpt (`«…»`) is the whole `fn main` item around
    /// the cursor (`ˇ`), without the neighbouring `foo` and `bar` functions.
    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // Fixture whitespace is significant: the byte budget below is tuned to it.
        let text = r#"
fn foo() {}

«fn main() {
 let x = 1;
 let ˇy = 2;
 let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
528
    /// With a 50-byte budget, the expected excerpt (`«…»`) covers the cursor's statement and both
    /// sibling statements, but not the `fn main() {` line or the closing brace.
    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        // Fixture whitespace is significant: the byte budget below is tuned to it.
        let text = r#"
fn main() {
« let x = 1;
 let ˇy = 2;
 let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
548
    /// Uses a comparatively high `min_bytes` (45, with `max_bytes` 60); the expected excerpt
    /// (`«…»`) runs from the `if true {` line through the `let z = 3;` line.
    ///
    /// NOTE(review): presumably the high `min_bytes` is what forces the line-based fallback named
    /// in the test - confirm against the fixture's byte counts.
    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        // Fixture whitespace is significant: the byte limits below are tuned to it.
        let text = r#"
fn main() {
 let x = 1;
« if true {
 let ˇy = 2;
 }
 let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
570
    /// Uses a before-cursor target ratio of 0.6 with a 120-byte budget; the expected excerpt
    /// (`«…»`) runs from the `let a = 1;` line through the `let f = 6;` line.
    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        // Fixture whitespace is significant: the byte budget below is tuned to it.
        let text = r#"
 fn main() {
« let a = 1;
 let b = 2;
 let c = 3;
 let ˇd = 4;
 let e = 5;
 let f = 6;
»
 let g = 7;
 }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
            include_parent_signatures: false,
        };

        check_example(options, text, cx);
    }
595}