1use language::{BufferSnapshot, LanguageId};
2use std::ops::Range;
3use text::{Point, ToOffset, ToPoint as _};
4use tree_sitter::{Node, TreeCursor};
5use util::RangeExt;
6
7use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
8
9// TODO:
10//
11// - Test parent signatures
12//
13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
14// planning.
15//
16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
17// paragraph).
18//
19// - Truncation of long lines.
20//
21// - Filter outer syntax layers that don't support edit prediction.
22
/// Size and balance constraints used by [`EditPredictionExcerpt::select_from_buffer`].
///
/// Sizes are compared against an excerpt's `size`, which counts the excerpt body plus the parent
/// signature ranges attached to it.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor. Buffers no larger than this
    /// are used in their entirety.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window. Used to pick
    /// the direction of sibling expansion and to position the line-based window.
    pub target_before_cursor_over_total_bytes: f32,
}
33
34// TODO: consider merging these
/// A selected region of a buffer around the cursor, plus the parent declarations whose
/// signatures accompany it.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt body within the buffer.
    pub range: Range<usize>,
    /// Line (row) range corresponding to `range`.
    pub line_range: Range<Line>,
    /// Enclosing declarations, each paired with a byte range in the buffer. When built by
    /// `ExcerptSelector::make_excerpt` the range is the declaration's `signature_range`.
    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
    /// `range.len()` plus the summed lengths of the ranges in `parent_declarations`.
    pub size: usize,
}
42
/// Owned text for an [`EditPredictionExcerpt`], as produced by [`EditPredictionExcerpt::text`].
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
    /// Buffer text covered by the excerpt's `range`.
    pub body: String,
    /// Buffer text for each entry of the excerpt's `parent_declarations`, in the same order.
    pub parent_signatures: Vec<String>,
    /// Id of the buffer's language, or `None` when the buffer has no language.
    pub language_id: Option<LanguageId>,
}
49
50impl EditPredictionExcerpt {
51 pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
52 let body = buffer
53 .text_for_range(self.range.clone())
54 .collect::<String>();
55 let parent_signatures = self
56 .parent_declarations
57 .iter()
58 .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
59 .collect();
60 let language_id = buffer.language().map(|l| l.id());
61 EditPredictionExcerptText {
62 body,
63 parent_signatures,
64 language_id,
65 }
66 }
67
68 /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
69 /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
70 /// cursor.
71 ///
72 /// When `index` is provided, the excerpt will include the signatures of parent outline items.
73 ///
74 /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
75 /// expansion.
76 ///
77 /// Returns `None` if the line around the cursor doesn't fit.
78 pub fn select_from_buffer(
79 query_point: Point,
80 buffer: &BufferSnapshot,
81 options: &EditPredictionExcerptOptions,
82 syntax_index: Option<&SyntaxIndexState>,
83 ) -> Option<Self> {
84 if buffer.len() <= options.max_bytes {
85 log::debug!(
86 "using entire file for excerpt since source length ({}) <= window max bytes ({})",
87 buffer.len(),
88 options.max_bytes
89 );
90 let offset_range = 0..buffer.len();
91 let line_range = Line(0)..Line(buffer.max_point().row);
92 return Some(EditPredictionExcerpt::new(
93 offset_range,
94 line_range,
95 Vec::new(),
96 ));
97 }
98
99 let query_offset = query_point.to_offset(buffer);
100 let query_line_range = query_point.row..query_point.row + 1;
101 let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
102 ..Point::new(query_line_range.end, 0).to_offset(buffer);
103 if query_range.len() >= options.max_bytes {
104 return None;
105 }
106
107 let parent_declarations = if let Some(syntax_index) = syntax_index {
108 syntax_index
109 .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
110 .collect()
111 } else {
112 Vec::new()
113 };
114
115 let excerpt_selector = ExcerptSelector {
116 query_offset,
117 query_range,
118 query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
119 parent_declarations: &parent_declarations,
120 buffer,
121 options,
122 };
123
124 if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
125 if excerpt.size >= options.min_bytes {
126 return Some(excerpt);
127 }
128 log::debug!(
129 "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
130 excerpt.size,
131 options.min_bytes
132 );
133 } else {
134 log::debug!(
135 "couldn't find excerpt via tree-sitter, falling back on line-based selection"
136 );
137 }
138
139 excerpt_selector.select_lines()
140 }
141
142 fn new(
143 range: Range<usize>,
144 line_range: Range<Line>,
145 parent_declarations: Vec<(DeclarationId, Range<usize>)>,
146 ) -> Self {
147 let size = range.len()
148 + parent_declarations
149 .iter()
150 .map(|(_, range)| range.len())
151 .sum::<usize>();
152 Self {
153 range,
154 parent_declarations,
155 size,
156 line_range,
157 }
158 }
159
160 fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
161 if !new_range.contains_inclusive(&self.range) {
162 // this is an issue because parent_signature_ranges may be incorrect
163 log::error!("bug: with_expanded_range called with disjoint range");
164 }
165 let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
166 for (declaration_id, range) in &self.parent_declarations {
167 if !range.contains_inclusive(&new_range) {
168 break;
169 }
170 parent_declarations.push((*declaration_id, range.clone()));
171 }
172 Self::new(new_range, new_line_range, parent_declarations)
173 }
174
175 fn parent_signatures_size(&self) -> usize {
176 self.size - self.range.len()
177 }
178}
179
/// Borrowed inputs shared by the excerpt selection strategies (`select_tree_sitter_nodes` and
/// `select_lines`). Built by `EditPredictionExcerpt::select_from_buffer`.
struct ExcerptSelector<'a> {
    /// Byte offset of the cursor in the buffer.
    query_offset: usize,
    /// Byte range of the line containing the cursor (start of its row to start of the next row).
    query_range: Range<usize>,
    /// Row range of the cursor line, `row..row + 1`.
    query_line_range: Range<Line>,
    /// Declarations reported by the syntax index as containing `query_range`; empty when no
    /// index was supplied.
    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
    /// Snapshot the excerpt is selected from.
    buffer: &'a BufferSnapshot,
    /// Size limits and target cursor ratio.
    options: &'a EditPredictionExcerptOptions,
}
188
189impl<'a> ExcerptSelector<'a> {
190 /// Finds the largest node that is smaller than the window size and contains `query_range`.
191 fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
192 let selected_layer_root = self.select_syntax_layer()?;
193 let mut cursor = selected_layer_root.walk();
194
195 loop {
196 let line_start = node_line_start(cursor.node());
197 let line_end = node_line_end(cursor.node());
198 let line_range = Line(line_start.row)..Line(line_end.row);
199 let excerpt_range =
200 line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
201 if excerpt_range.contains_inclusive(&self.query_range) {
202 let excerpt = self.make_excerpt(excerpt_range, line_range);
203 if excerpt.size <= self.options.max_bytes {
204 return Some(self.expand_to_siblings(&mut cursor, excerpt));
205 }
206 } else {
207 // TODO: Should still be able to handle this case via AST nodes. For example, this
208 // can happen if the cursor is between two methods in a large class file.
209 return None;
210 }
211
212 if cursor
213 .goto_first_child_for_byte(self.query_range.start)
214 .is_none()
215 {
216 return None;
217 }
218 }
219 }
220
221 /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
222 fn select_syntax_layer(&self) -> Option<Node<'_>> {
223 let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
224 let mut largest: Option<Node<'_>> = None;
225 for layer in self
226 .buffer
227 .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
228 {
229 let layer_range = layer.node().byte_range();
230 if !layer_range.contains_inclusive(&self.query_range) {
231 continue;
232 }
233
234 if layer_range.len() > self.options.max_bytes {
235 match &smallest_exceeding_max_len {
236 None => smallest_exceeding_max_len = Some(layer.node()),
237 Some(existing) => {
238 if layer_range.len() < existing.byte_range().len() {
239 smallest_exceeding_max_len = Some(layer.node());
240 }
241 }
242 }
243 } else {
244 match &largest {
245 None => largest = Some(layer.node()),
246 Some(existing) if layer_range.len() > existing.byte_range().len() => {
247 largest = Some(layer.node())
248 }
249 _ => {}
250 }
251 }
252 }
253
254 smallest_exceeding_max_len.or(largest)
255 }
256
257 // motivation for this and `goto_previous_named_sibling` is to avoid including things like
258 // trailing unnamed "}" in body nodes
259 fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
260 while cursor.goto_next_sibling() {
261 if cursor.node().is_named() {
262 return true;
263 }
264 }
265 false
266 }
267
268 fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
269 while cursor.goto_previous_sibling() {
270 if cursor.node().is_named() {
271 return true;
272 }
273 }
274 false
275 }
276
    /// Grows `excerpt` over the named siblings of the node under `cursor` for as long as the
    /// result stays within `max_bytes`.
    ///
    /// Two cursors walk outwards from the starting node: a clone moves forward, while `cursor`
    /// itself is reused for the backward walk (so the caller's cursor is moved). Each round
    /// computes at most one forward candidate (end extended to the line after the next named
    /// sibling) and one backward candidate (start extended to that sibling's first line). When
    /// both fit, `is_better_excerpt` decides which one to take; only the cursor of the chosen
    /// direction is advanced afterwards. A direction is abandoned once it runs out of siblings
    /// or its candidate no longer fits.
    fn expand_to_siblings(
        &self,
        cursor: &mut TreeCursor,
        mut excerpt: EditPredictionExcerpt,
    ) -> EditPredictionExcerpt {
        let mut forward_cursor = cursor.clone();
        let backward_cursor = cursor;
        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
        loop {
            if backward_done && forward_done {
                break;
            }

            // Find the next forward sibling that actually extends the excerpt's end.
            let mut forward = None;
            while !forward_done {
                let new_end_point = node_line_end(forward_cursor.node());
                let new_end = new_end_point.to_offset(&self.buffer);
                if new_end > excerpt.range.end {
                    let new_excerpt = excerpt.with_expanded_range(
                        excerpt.range.start..new_end,
                        excerpt.line_range.start..Line(new_end_point.row),
                    );
                    if new_excerpt.size <= self.options.max_bytes {
                        forward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting forward expansion, as it doesn't fit");
                        forward_done = true;
                        break;
                    }
                }
                // This sibling ends within lines already covered; skip to the next one.
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            }

            // Same search in the backward direction, extending the excerpt's start.
            let mut backward = None;
            while !backward_done {
                let new_start_point = node_line_start(backward_cursor.node());
                let new_start = new_start_point.to_offset(&self.buffer);
                if new_start < excerpt.range.start {
                    let new_excerpt = excerpt.with_expanded_range(
                        new_start..excerpt.range.end,
                        Line(new_start_point.row)..excerpt.line_range.end,
                    );
                    if new_excerpt.size <= self.options.max_bytes {
                        backward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting backward expansion, as it doesn't fit");
                        backward_done = true;
                        break;
                    }
                }
                // This sibling starts within lines already covered; skip to the previous one.
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }

            // Commit exactly one candidate per round. The candidate that is not taken is
            // recomputed against the updated excerpt on the next iteration.
            let go_forward = match (forward, backward) {
                (Some(forward), Some(backward)) => {
                    let go_forward = self.is_better_excerpt(&forward, &backward);
                    if go_forward {
                        excerpt = forward;
                    } else {
                        excerpt = backward;
                    }
                    go_forward
                }
                (Some(forward), None) => {
                    log::debug!("expanding forward, since backward expansion has halted");
                    excerpt = forward;
                    true
                }
                (None, Some(backward)) => {
                    log::debug!("expanding backward, since forward expansion has halted");
                    excerpt = backward;
                    false
                }
                (None, None) => break,
            };

            // Advance past the sibling that was just absorbed.
            if go_forward {
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            } else {
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }
        }

        excerpt
    }
365
366 fn select_lines(&self) -> Option<EditPredictionExcerpt> {
367 // early return if line containing query_offset is already too large
368 let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
369 if excerpt.size > self.options.max_bytes {
370 log::debug!(
371 "excerpt for cursor line is {} bytes, which exceeds the window",
372 excerpt.size
373 );
374 return None;
375 }
376 let signatures_size = excerpt.parent_signatures_size();
377 let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
378
379 let before_bytes =
380 (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
381
382 let start_line = {
383 let offset = self.query_offset.saturating_sub(before_bytes);
384 let point = offset.to_point(self.buffer);
385 Line(point.row + 1)
386 };
387 let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
388 let end_line = {
389 let offset = start_offset + bytes_remaining;
390 let point = offset.to_point(self.buffer);
391 Line(point.row)
392 };
393 let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
394
395 // this could be expanded further since recalculated `signature_size` may be smaller, but
396 // skipping that for now for simplicity
397 //
398 // TODO: could also consider checking if lines immediately before / after fit.
399 let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
400 if excerpt.size > self.options.max_bytes {
401 log::error!(
402 "bug: line-based excerpt selection has size {}, \
403 which is {} bytes larger than the max size",
404 excerpt.size,
405 excerpt.size - self.options.max_bytes
406 );
407 }
408 return Some(excerpt);
409 }
410
411 fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
412 let parent_declarations = self
413 .parent_declarations
414 .iter()
415 .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
416 .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
417 .collect();
418 EditPredictionExcerpt::new(range, line_range, parent_declarations)
419 }
420
421 /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
422 fn is_better_excerpt(
423 &self,
424 forward: &EditPredictionExcerpt,
425 backward: &EditPredictionExcerpt,
426 ) -> bool {
427 let forward_ratio = self.excerpt_range_ratio(forward);
428 let backward_ratio = self.excerpt_range_ratio(backward);
429 let forward_delta =
430 (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
431 let backward_delta =
432 (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
433 let forward_is_better = forward_delta <= backward_delta;
434 if forward_is_better {
435 log::debug!(
436 "expanding forward since {} is closer than {} to {}",
437 forward_ratio,
438 backward_ratio,
439 self.options.target_before_cursor_over_total_bytes
440 );
441 } else {
442 log::debug!(
443 "expanding backward since {} is closer than {} to {}",
444 backward_ratio,
445 forward_ratio,
446 self.options.target_before_cursor_over_total_bytes
447 );
448 }
449 forward_is_better
450 }
451
452 /// Returns the ratio of bytes before the cursor over bytes within the range.
453 fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
454 let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
455 log::error!("bug: edit prediction cursor offset is not outside the excerpt");
456 return 0.0;
457 };
458 bytes_before_cursor as f32 / excerpt.range.len() as f32
459 }
460}
461
462/// Returns the next line start if the offset is not already at the start of a line.
463pub(crate) fn next_line_start(offset: usize, buffer: &text::BufferSnapshot) -> Point {
464 if offset == 0 {
465 Point::new(0, 0)
466 } else {
467 let point = offset.to_point(buffer);
468 Point::new(point.row + 1, 0)
469 }
470}
471
472/// Returns the previous line start if the offset is not already at the start of a line.
473pub(crate) fn previous_line_start(offset: usize, buffer: &text::BufferSnapshot) -> Point {
474 if offset >= buffer.len() {
475 buffer.max_point()
476 } else {
477 let point = offset.to_point(buffer);
478 Point::new(point.row, 0)
479 }
480}
481
482fn node_line_start(node: Node) -> Point {
483 Point::new(node.start_position().row as u32, 0)
484}
485
486fn node_line_end(node: Node) -> Point {
487 Point::new(node.end_position().row as u32 + 1, 0)
488}
489
490#[cfg(test)]
491mod tests {
492 use super::*;
493 use gpui::{AppContext, TestAppContext};
494 use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
495 use util::test::{generate_marked_text, marked_text_offsets_by};
496
497 fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
498 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
499 buffer.read_with(cx, |buffer, _| buffer.snapshot())
500 }
501
502 fn rust_lang() -> Language {
503 Language::new(
504 LanguageConfig {
505 name: "Rust".into(),
506 matcher: LanguageMatcher {
507 path_suffixes: vec!["rs".to_string()],
508 ..Default::default()
509 },
510 ..Default::default()
511 },
512 Some(tree_sitter_rust::LANGUAGE.into()),
513 )
514 .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
515 .unwrap()
516 }
517
518 fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
519 let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
520 (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
521 }
522
523 fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
524 let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
525
526 let buffer = create_buffer(&text, cx);
527 let cursor_point = cursor.to_point(&buffer);
528
529 let excerpt =
530 EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
531 .expect("Should select an excerpt");
532 pretty_assertions::assert_eq!(
533 generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
534 generate_marked_text(&text, &[expected_excerpt], false)
535 );
536 assert!(excerpt.size <= options.max_bytes);
537 assert!(excerpt.range.contains(&cursor));
538 }
539
    /// With a small `max_bytes`, the expected excerpt (`«…»`) is just the line holding the
    /// statement under the cursor (`ˇ`).
    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
558
    /// With a `max_bytes` large enough for the enclosing function but not for its neighbours,
    /// the expected excerpt (`«…»`) is the whole `main` function around the cursor (`ˇ`).
    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
581
    /// The expected excerpt (`«…»`) covers the statement under the cursor (`ˇ`) together with
    /// its sibling statements, but not the enclosing function's first and last lines.
    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
600
    /// Uses a relatively high `min_bytes` so that a small syntax-based excerpt is rejected and
    /// selection falls back on whole lines around the cursor (`ˇ`); the expected excerpt is the
    /// `«…»` span.
    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
621
    /// Same kind of check with a before-cursor target ratio of 0.6; the expected excerpt is the
    /// `«…»` span around the cursor (`ˇ`).
    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
        };

        check_example(options, text, cx);
    }
645}