1use language::{BufferSnapshot, LanguageId};
2use std::ops::Range;
3use text::{Point, ToOffset as _, ToPoint as _};
4use tree_sitter::{Node, TreeCursor};
5use util::RangeExt;
6
7use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
8
9// TODO:
10//
11// - Test parent signatures
12//
13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
14// planning.
15//
16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
17// paragraph).
18//
19// - Truncation of long lines.
20//
21// - Filter outer syntax layers that don't support edit prediction.
22
/// Size and balance constraints used when choosing an edit prediction excerpt.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerptOptions {
    /// Upper bound on the excerpt size in bytes. Parent signatures, when present, count toward
    /// this bound.
    pub max_bytes: usize,
    /// Lower bound on the excerpt size in bytes. A syntax-tree based excerpt below this size is
    /// discarded in favor of line-based selection.
    pub min_bytes: usize,
    /// Desired fraction of the excerpt's bytes that precede the cursor (bytes before the cursor
    /// divided by the total bytes in the excerpt).
    pub target_before_cursor_over_total_bytes: f32,
}
33
// TODO: consider merging these
/// A region of a buffer chosen around a cursor position, plus the signatures of enclosing
/// declarations that should accompany it.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt body within the buffer.
    pub range: Range<usize>,
    /// Rows covered by `range`. NOTE(review): the end row appears to be exclusive at the
    /// construction sites in this module — confirm before relying on it.
    pub line_range: Range<Line>,
    /// Enclosing declarations, each paired with the byte range of its signature in the buffer.
    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
    /// Total size in bytes: the body length plus the lengths of all parent signature ranges.
    pub size: usize,
}
42
/// The textual contents of an `EditPredictionExcerpt`, read out of a buffer snapshot.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
    /// Text of the excerpt's byte range.
    pub body: String,
    /// Text of each parent signature, in the same order as the excerpt's `parent_declarations`.
    pub parent_signatures: Vec<String>,
    /// Id of the buffer's language, or `None` if the buffer has no language.
    pub language_id: Option<LanguageId>,
}
49
50impl EditPredictionExcerpt {
51 pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
52 let body = buffer
53 .text_for_range(self.range.clone())
54 .collect::<String>();
55 let parent_signatures = self
56 .parent_declarations
57 .iter()
58 .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
59 .collect();
60 let language_id = buffer.language().map(|l| l.id());
61 EditPredictionExcerptText {
62 body,
63 parent_signatures,
64 language_id,
65 }
66 }
67
68 /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
69 /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
70 /// cursor.
71 ///
72 /// When `index` is provided, the excerpt will include the signatures of parent outline items.
73 ///
74 /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
75 /// expansion.
76 ///
77 /// Returns `None` if the line around the cursor doesn't fit.
78 pub fn select_from_buffer(
79 query_point: Point,
80 buffer: &BufferSnapshot,
81 options: &EditPredictionExcerptOptions,
82 syntax_index: Option<&SyntaxIndexState>,
83 ) -> Option<Self> {
84 if buffer.len() <= options.max_bytes {
85 log::debug!(
86 "using entire file for excerpt since source length ({}) <= window max bytes ({})",
87 buffer.len(),
88 options.max_bytes
89 );
90 let offset_range = 0..buffer.len();
91 let line_range = Line(0)..Line(buffer.max_point().row);
92 return Some(EditPredictionExcerpt::new(
93 offset_range,
94 line_range,
95 Vec::new(),
96 ));
97 }
98
99 let query_offset = query_point.to_offset(buffer);
100 let query_line_range = query_point.row..query_point.row + 1;
101 let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
102 ..Point::new(query_line_range.end, 0).to_offset(buffer);
103 if query_range.len() >= options.max_bytes {
104 return None;
105 }
106
107 let parent_declarations = if let Some(syntax_index) = syntax_index {
108 syntax_index
109 .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
110 .collect()
111 } else {
112 Vec::new()
113 };
114
115 let excerpt_selector = ExcerptSelector {
116 query_offset,
117 query_range,
118 query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
119 parent_declarations: &parent_declarations,
120 buffer,
121 options,
122 };
123
124 if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
125 if excerpt.size >= options.min_bytes {
126 return Some(excerpt);
127 }
128 log::debug!(
129 "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
130 excerpt.size,
131 options.min_bytes
132 );
133 } else {
134 log::debug!(
135 "couldn't find excerpt via tree-sitter, falling back on line-based selection"
136 );
137 }
138
139 excerpt_selector.select_lines()
140 }
141
142 fn new(
143 range: Range<usize>,
144 line_range: Range<Line>,
145 parent_declarations: Vec<(DeclarationId, Range<usize>)>,
146 ) -> Self {
147 let size = range.len()
148 + parent_declarations
149 .iter()
150 .map(|(_, range)| range.len())
151 .sum::<usize>();
152 Self {
153 range,
154 parent_declarations,
155 size,
156 line_range,
157 }
158 }
159
160 fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
161 if !new_range.contains_inclusive(&self.range) {
162 // this is an issue because parent_signature_ranges may be incorrect
163 log::error!("bug: with_expanded_range called with disjoint range");
164 }
165 let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
166 for (declaration_id, range) in &self.parent_declarations {
167 if !range.contains_inclusive(&new_range) {
168 break;
169 }
170 parent_declarations.push((*declaration_id, range.clone()));
171 }
172 Self::new(new_range, new_line_range, parent_declarations)
173 }
174
175 fn parent_signatures_size(&self) -> usize {
176 self.size - self.range.len()
177 }
178}
179
/// Working state for choosing an excerpt around a single cursor position.
struct ExcerptSelector<'a> {
    /// Byte offset of the cursor.
    query_offset: usize,
    /// Byte range of the cursor's line, from its start to the start of the following line
    /// (clamped to the buffer end).
    query_range: Range<usize>,
    /// Row range of the cursor's line, `row..row + 1`.
    query_line_range: Range<Line>,
    /// Declarations returned by `buffer_declarations_containing_range` for the cursor line
    /// (empty when no syntax index was supplied).
    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
    buffer: &'a BufferSnapshot,
    options: &'a EditPredictionExcerptOptions,
}
188
189impl<'a> ExcerptSelector<'a> {
190 /// Finds the largest node that is smaller than the window size and contains `query_range`.
191 fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
192 let selected_layer_root = self.select_syntax_layer()?;
193 let mut cursor = selected_layer_root.walk();
194
195 loop {
196 let line_start = node_line_start(cursor.node());
197 let line_end = node_line_end(cursor.node());
198 let line_range = Line(line_start.row)..Line(line_end.row);
199 let excerpt_range =
200 line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
201 if excerpt_range.contains_inclusive(&self.query_range) {
202 let excerpt = self.make_excerpt(excerpt_range, line_range);
203 if excerpt.size <= self.options.max_bytes {
204 return Some(self.expand_to_siblings(&mut cursor, excerpt));
205 }
206 } else {
207 // TODO: Should still be able to handle this case via AST nodes. For example, this
208 // can happen if the cursor is between two methods in a large class file.
209 return None;
210 }
211
212 if cursor
213 .goto_first_child_for_byte(self.query_range.start)
214 .is_none()
215 {
216 return None;
217 }
218 }
219 }
220
221 /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
222 fn select_syntax_layer(&self) -> Option<Node<'_>> {
223 let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
224 let mut largest: Option<Node<'_>> = None;
225 for layer in self
226 .buffer
227 .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
228 {
229 let layer_range = layer.node().byte_range();
230 if !layer_range.contains_inclusive(&self.query_range) {
231 continue;
232 }
233
234 if layer_range.len() > self.options.max_bytes {
235 match &smallest_exceeding_max_len {
236 None => smallest_exceeding_max_len = Some(layer.node()),
237 Some(existing) => {
238 if layer_range.len() < existing.byte_range().len() {
239 smallest_exceeding_max_len = Some(layer.node());
240 }
241 }
242 }
243 } else {
244 match &largest {
245 None => largest = Some(layer.node()),
246 Some(existing) if layer_range.len() > existing.byte_range().len() => {
247 largest = Some(layer.node())
248 }
249 _ => {}
250 }
251 }
252 }
253
254 smallest_exceeding_max_len.or(largest)
255 }
256
257 // motivation for this and `goto_previous_named_sibling` is to avoid including things like
258 // trailing unnamed "}" in body nodes
259 fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
260 while cursor.goto_next_sibling() {
261 if cursor.node().is_named() {
262 return true;
263 }
264 }
265 false
266 }
267
268 fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
269 while cursor.goto_previous_sibling() {
270 if cursor.node().is_named() {
271 return true;
272 }
273 }
274 false
275 }
276
277 fn expand_to_siblings(
278 &self,
279 cursor: &mut TreeCursor,
280 mut excerpt: EditPredictionExcerpt,
281 ) -> EditPredictionExcerpt {
282 let mut forward_cursor = cursor.clone();
283 let backward_cursor = cursor;
284 let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
285 let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
286 loop {
287 if backward_done && forward_done {
288 break;
289 }
290
291 let mut forward = None;
292 while !forward_done {
293 let new_end_point = node_line_end(forward_cursor.node());
294 let new_end = new_end_point.to_offset(&self.buffer);
295 if new_end > excerpt.range.end {
296 let new_excerpt = excerpt.with_expanded_range(
297 excerpt.range.start..new_end,
298 excerpt.line_range.start..Line(new_end_point.row),
299 );
300 if new_excerpt.size <= self.options.max_bytes {
301 forward = Some(new_excerpt);
302 break;
303 } else {
304 log::debug!("halting forward expansion, as it doesn't fit");
305 forward_done = true;
306 break;
307 }
308 }
309 forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
310 }
311
312 let mut backward = None;
313 while !backward_done {
314 let new_start_point = node_line_start(backward_cursor.node());
315 let new_start = new_start_point.to_offset(&self.buffer);
316 if new_start < excerpt.range.start {
317 let new_excerpt = excerpt.with_expanded_range(
318 new_start..excerpt.range.end,
319 Line(new_start_point.row)..excerpt.line_range.end,
320 );
321 if new_excerpt.size <= self.options.max_bytes {
322 backward = Some(new_excerpt);
323 break;
324 } else {
325 log::debug!("halting backward expansion, as it doesn't fit");
326 backward_done = true;
327 break;
328 }
329 }
330 backward_done = !Self::goto_previous_named_sibling(backward_cursor);
331 }
332
333 let go_forward = match (forward, backward) {
334 (Some(forward), Some(backward)) => {
335 let go_forward = self.is_better_excerpt(&forward, &backward);
336 if go_forward {
337 excerpt = forward;
338 } else {
339 excerpt = backward;
340 }
341 go_forward
342 }
343 (Some(forward), None) => {
344 log::debug!("expanding forward, since backward expansion has halted");
345 excerpt = forward;
346 true
347 }
348 (None, Some(backward)) => {
349 log::debug!("expanding backward, since forward expansion has halted");
350 excerpt = backward;
351 false
352 }
353 (None, None) => break,
354 };
355
356 if go_forward {
357 forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
358 } else {
359 backward_done = !Self::goto_previous_named_sibling(backward_cursor);
360 }
361 }
362
363 excerpt
364 }
365
366 fn select_lines(&self) -> Option<EditPredictionExcerpt> {
367 // early return if line containing query_offset is already too large
368 let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
369 if excerpt.size > self.options.max_bytes {
370 log::debug!(
371 "excerpt for cursor line is {} bytes, which exceeds the window",
372 excerpt.size
373 );
374 return None;
375 }
376 let signatures_size = excerpt.parent_signatures_size();
377 let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
378
379 let before_bytes =
380 (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
381
382 let start_line = {
383 let offset = self.query_offset.saturating_sub(before_bytes);
384 let point = offset.to_point(self.buffer);
385 Line(point.row + 1)
386 };
387 let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
388 let end_line = {
389 let offset = start_offset + bytes_remaining;
390 let point = offset.to_point(self.buffer);
391 Line(point.row)
392 };
393 let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
394
395 // this could be expanded further since recalculated `signature_size` may be smaller, but
396 // skipping that for now for simplicity
397 //
398 // TODO: could also consider checking if lines immediately before / after fit.
399 let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
400 if excerpt.size > self.options.max_bytes {
401 log::error!(
402 "bug: line-based excerpt selection has size {}, \
403 which is {} bytes larger than the max size",
404 excerpt.size,
405 excerpt.size - self.options.max_bytes
406 );
407 }
408 return Some(excerpt);
409 }
410
411 fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
412 let parent_declarations = self
413 .parent_declarations
414 .iter()
415 .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
416 .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
417 .collect();
418 EditPredictionExcerpt::new(range, line_range, parent_declarations)
419 }
420
421 /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
422 fn is_better_excerpt(
423 &self,
424 forward: &EditPredictionExcerpt,
425 backward: &EditPredictionExcerpt,
426 ) -> bool {
427 let forward_ratio = self.excerpt_range_ratio(forward);
428 let backward_ratio = self.excerpt_range_ratio(backward);
429 let forward_delta =
430 (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
431 let backward_delta =
432 (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
433 let forward_is_better = forward_delta <= backward_delta;
434 if forward_is_better {
435 log::debug!(
436 "expanding forward since {} is closer than {} to {}",
437 forward_ratio,
438 backward_ratio,
439 self.options.target_before_cursor_over_total_bytes
440 );
441 } else {
442 log::debug!(
443 "expanding backward since {} is closer than {} to {}",
444 backward_ratio,
445 forward_ratio,
446 self.options.target_before_cursor_over_total_bytes
447 );
448 }
449 forward_is_better
450 }
451
452 /// Returns the ratio of bytes before the cursor over bytes within the range.
453 fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
454 let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
455 log::error!("bug: edit prediction cursor offset is not outside the excerpt");
456 return 0.0;
457 };
458 bytes_before_cursor as f32 / excerpt.range.len() as f32
459 }
460}
461
462fn node_line_start(node: Node) -> Point {
463 Point::new(node.start_position().row as u32, 0)
464}
465
466fn node_line_end(node: Node) -> Point {
467 Point::new(node.end_position().row as u32 + 1, 0)
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use util::test::{generate_marked_text, marked_text_offsets_by};

    /// Creates a local buffer containing `text` with the Rust language attached, and returns a
    /// snapshot of it.
    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    /// A minimal Rust language using the tree-sitter Rust grammar and the Rust `outline.scm`
    /// query from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Strips the markers from `text` and returns the clean text, the cursor offset (marked
    /// with `ˇ`), and the expected excerpt range (between `«` and `»`).
    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
    }

    /// Selects an excerpt at the marked cursor (without a syntax index) and asserts that it
    /// equals the marked range, fits within `max_bytes`, and contains the cursor.
    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);

        let buffer = create_buffer(&text, cx);
        let cursor_point = cursor.to_point(&buffer);

        let excerpt =
            EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
                .expect("Should select an excerpt");
        pretty_assertions::assert_eq!(
            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
            generate_marked_text(&text, &[expected_excerpt], false)
        );
        assert!(excerpt.size <= options.max_bytes);
        assert!(excerpt.range.contains(&cursor));
    }

    // Only the statement containing the cursor fits within 20 bytes.
    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    // The enclosing function fits, but adding either neighboring function would not.
    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    // The cursor's statement grows to include its sibling statements but not the braces.
    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    // NOTE(review): by a manual trace, the tree-sitter excerpt here (the `if` statement plus
    // `let z`) is about 54 bytes, which meets `min_bytes: 45`, so this may not exercise the
    // line-based fallback despite its name — confirm.
    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    // NOTE(review): with `min_bytes: 10`, a manual trace suggests the tree-sitter sibling
    // expansion produces this excerpt, not the line-based fallback — confirm. Either way, the
    // 0.6 ratio biases the excerpt toward lines before the cursor.
    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
        };

        check_example(options, text, cx);
    }
}