1use language::{BufferSnapshot, LanguageId};
2use std::ops::Range;
3use text::{Point, ToOffset as _, ToPoint as _};
4use tree_sitter::{Node, TreeCursor};
5use util::RangeExt;
6
7use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
8
9// TODO:
10//
11// - Test parent signatures
12//
13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
14// planning.
15//
16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
17// paragraph).
18//
19// - Truncation of long lines.
20//
21// - Filter outer syntax layers that don't support edit prediction.
22
/// Size and balance constraints used when selecting an excerpt around the cursor.
///
/// Excerpt sizes compared against these limits are the excerpt's `size`, which counts the body
/// range plus the lengths of any parent declaration ranges.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window.
    pub target_before_cursor_over_total_bytes: f32,
}
33
34// TODO: consider merging these
/// A selected region of a buffer, together with the declarations that enclose it.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt body within the buffer.
    pub range: Range<usize>,
    /// Enclosing declarations, each paired with a byte range in the buffer. When built by
    /// `ExcerptSelector::make_excerpt`, the range is the declaration's `signature_range`.
    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
    /// Length of `range` plus the summed lengths of all `parent_declarations` ranges.
    pub size: usize,
}
41
/// The textual contents of an `EditPredictionExcerpt`, as produced by
/// `EditPredictionExcerpt::text`.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
    /// Buffer text covered by the excerpt's `range`.
    pub body: String,
    /// Buffer text for each entry of the excerpt's `parent_declarations`, in the same order.
    pub parent_signatures: Vec<String>,
    /// Id of the buffer's language, or `None` when the buffer has no language.
    pub language_id: Option<LanguageId>,
}
48
49impl EditPredictionExcerpt {
50 pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
51 let body = buffer
52 .text_for_range(self.range.clone())
53 .collect::<String>();
54 let parent_signatures = self
55 .parent_declarations
56 .iter()
57 .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
58 .collect();
59 let language_id = buffer.language().map(|l| l.id());
60 EditPredictionExcerptText {
61 body,
62 parent_signatures,
63 language_id,
64 }
65 }
66
67 /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
68 /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
69 /// cursor.
70 ///
71 /// When `index` is provided, the excerpt will include the signatures of parent outline items.
72 ///
73 /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
74 /// expansion.
75 ///
76 /// Returns `None` if the line around the cursor doesn't fit.
77 pub fn select_from_buffer(
78 query_point: Point,
79 buffer: &BufferSnapshot,
80 options: &EditPredictionExcerptOptions,
81 syntax_index: Option<&SyntaxIndexState>,
82 ) -> Option<Self> {
83 if buffer.len() <= options.max_bytes {
84 log::debug!(
85 "using entire file for excerpt since source length ({}) <= window max bytes ({})",
86 buffer.len(),
87 options.max_bytes
88 );
89 return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
90 }
91
92 let query_offset = query_point.to_offset(buffer);
93 let query_range = Point::new(query_point.row, 0).to_offset(buffer)
94 ..Point::new(query_point.row + 1, 0).to_offset(buffer);
95 if query_range.len() >= options.max_bytes {
96 return None;
97 }
98
99 let parent_declarations = if let Some(syntax_index) = syntax_index {
100 syntax_index
101 .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
102 .collect()
103 } else {
104 Vec::new()
105 };
106
107 let excerpt_selector = ExcerptSelector {
108 query_offset,
109 query_range,
110 parent_declarations: &parent_declarations,
111 buffer,
112 options,
113 };
114
115 if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
116 if excerpt.size >= options.min_bytes {
117 return Some(excerpt);
118 }
119 log::debug!(
120 "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
121 excerpt.size,
122 options.min_bytes
123 );
124 } else {
125 log::debug!(
126 "couldn't find excerpt via tree-sitter, falling back on line-based selection"
127 );
128 }
129
130 excerpt_selector.select_lines()
131 }
132
133 fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
134 let size = range.len()
135 + parent_declarations
136 .iter()
137 .map(|(_, range)| range.len())
138 .sum::<usize>();
139 Self {
140 range,
141 parent_declarations,
142 size,
143 }
144 }
145
146 fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
147 if !new_range.contains_inclusive(&self.range) {
148 // this is an issue because parent_signature_ranges may be incorrect
149 log::error!("bug: with_expanded_range called with disjoint range");
150 }
151 let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
152 for (declaration_id, range) in &self.parent_declarations {
153 if !range.contains_inclusive(&new_range) {
154 break;
155 }
156 parent_declarations.push((*declaration_id, range.clone()));
157 }
158 Self::new(new_range, parent_declarations)
159 }
160
161 fn parent_signatures_size(&self) -> usize {
162 self.size - self.range.len()
163 }
164}
165
/// Borrowed inputs shared by the excerpt selection strategies.
struct ExcerptSelector<'a> {
    /// Byte offset of the cursor in the buffer.
    query_offset: usize,
    /// Byte range from the start of the cursor's row to the start of the following row.
    query_range: Range<usize>,
    /// Candidate parent declarations; `make_excerpt` keeps those whose `item_range` contains a
    /// given excerpt range.
    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
    /// Snapshot of the buffer being excerpted.
    buffer: &'a BufferSnapshot,
    /// Size and ratio constraints for the excerpt.
    options: &'a EditPredictionExcerptOptions,
}
173
174impl<'a> ExcerptSelector<'a> {
175 /// Finds the largest node that is smaller than the window size and contains `query_range`.
176 fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
177 let selected_layer_root = self.select_syntax_layer()?;
178 let mut cursor = selected_layer_root.walk();
179
180 loop {
181 let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
182 ..node_line_end(cursor.node()).to_offset(&self.buffer);
183 if excerpt_range.contains_inclusive(&self.query_range) {
184 let excerpt = self.make_excerpt(excerpt_range);
185 if excerpt.size <= self.options.max_bytes {
186 return Some(self.expand_to_siblings(&mut cursor, excerpt));
187 }
188 } else {
189 // TODO: Should still be able to handle this case via AST nodes. For example, this
190 // can happen if the cursor is between two methods in a large class file.
191 return None;
192 }
193
194 if cursor
195 .goto_first_child_for_byte(self.query_range.start)
196 .is_none()
197 {
198 return None;
199 }
200 }
201 }
202
203 /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
204 fn select_syntax_layer(&self) -> Option<Node<'_>> {
205 let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
206 let mut largest: Option<Node<'_>> = None;
207 for layer in self
208 .buffer
209 .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
210 {
211 let layer_range = layer.node().byte_range();
212 if !layer_range.contains_inclusive(&self.query_range) {
213 continue;
214 }
215
216 if layer_range.len() > self.options.max_bytes {
217 match &smallest_exceeding_max_len {
218 None => smallest_exceeding_max_len = Some(layer.node()),
219 Some(existing) => {
220 if layer_range.len() < existing.byte_range().len() {
221 smallest_exceeding_max_len = Some(layer.node());
222 }
223 }
224 }
225 } else {
226 match &largest {
227 None => largest = Some(layer.node()),
228 Some(existing) if layer_range.len() > existing.byte_range().len() => {
229 largest = Some(layer.node())
230 }
231 _ => {}
232 }
233 }
234 }
235
236 smallest_exceeding_max_len.or(largest)
237 }
238
239 // motivation for this and `goto_previous_named_sibling` is to avoid including things like
240 // trailing unnamed "}" in body nodes
241 fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
242 while cursor.goto_next_sibling() {
243 if cursor.node().is_named() {
244 return true;
245 }
246 }
247 false
248 }
249
250 fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
251 while cursor.goto_previous_sibling() {
252 if cursor.node().is_named() {
253 return true;
254 }
255 }
256 false
257 }
258
    /// Grows `excerpt` to cover named sibling nodes on either side of the node at `cursor`, for as
    /// long as the grown excerpt's size stays within `options.max_bytes`.
    ///
    /// Each round computes at most one candidate per direction: the excerpt extended to the line
    /// end of a following named sibling (forward), or to the line start of a preceding named
    /// sibling (backward). When both candidates fit, `is_better_excerpt` chooses between them;
    /// when only one fits, that one is taken. A direction is finished permanently once it runs
    /// out of named siblings or produces a candidate that exceeds the byte limit.
    ///
    /// `cursor` is reused as the backward cursor, so it is left wherever backward expansion
    /// stopped.
    fn expand_to_siblings(
        &self,
        cursor: &mut TreeCursor,
        mut excerpt: EditPredictionExcerpt,
    ) -> EditPredictionExcerpt {
        // Two independent cursors: a clone walks forward, the caller's cursor walks backward.
        let mut forward_cursor = cursor.clone();
        let backward_cursor = cursor;
        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
        loop {
            if backward_done && forward_done {
                break;
            }

            // Find the next forward sibling that actually extends the excerpt's end; siblings
            // that end on a line already covered are skipped.
            let mut forward = None;
            while !forward_done {
                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
                if new_end > excerpt.range.end {
                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
                    if new_excerpt.size <= self.options.max_bytes {
                        forward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting forward expansion, as it doesn't fit");
                        forward_done = true;
                        break;
                    }
                }
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            }

            // Same search in the backward direction, against the excerpt's start.
            let mut backward = None;
            while !backward_done {
                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
                if new_start < excerpt.range.start {
                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
                    if new_excerpt.size <= self.options.max_bytes {
                        backward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting backward expansion, as it doesn't fit");
                        backward_done = true;
                        break;
                    }
                }
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }

            // Commit exactly one candidate per round. The candidate that was not chosen is
            // discarded and recomputed against the updated excerpt on the next round.
            let go_forward = match (forward, backward) {
                (Some(forward), Some(backward)) => {
                    let go_forward = self.is_better_excerpt(&forward, &backward);
                    if go_forward {
                        excerpt = forward;
                    } else {
                        excerpt = backward;
                    }
                    go_forward
                }
                (Some(forward), None) => {
                    log::debug!("expanding forward, since backward expansion has halted");
                    excerpt = forward;
                    true
                }
                (None, Some(backward)) => {
                    log::debug!("expanding backward, since forward expansion has halted");
                    excerpt = backward;
                    false
                }
                (None, None) => break,
            };

            // Only the cursor for the direction that was taken moves past its consumed sibling.
            if go_forward {
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            } else {
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }
        }

        excerpt
    }
339
340 fn select_lines(&self) -> Option<EditPredictionExcerpt> {
341 // early return if line containing query_offset is already too large
342 let excerpt = self.make_excerpt(self.query_range.clone());
343 if excerpt.size > self.options.max_bytes {
344 log::debug!(
345 "excerpt for cursor line is {} bytes, which exceeds the window",
346 excerpt.size
347 );
348 return None;
349 }
350 let signatures_size = excerpt.parent_signatures_size();
351 let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
352
353 let before_bytes =
354 (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
355
356 let start_point = {
357 let offset = self.query_offset.saturating_sub(before_bytes);
358 let point = offset.to_point(self.buffer);
359 Point::new(point.row + 1, 0)
360 };
361 let start_offset = start_point.to_offset(&self.buffer);
362 let end_point = {
363 let offset = start_offset + bytes_remaining;
364 let point = offset.to_point(self.buffer);
365 Point::new(point.row, 0)
366 };
367 let end_offset = end_point.to_offset(&self.buffer);
368
369 // this could be expanded further since recalculated `signature_size` may be smaller, but
370 // skipping that for now for simplicity
371 //
372 // TODO: could also consider checking if lines immediately before / after fit.
373 let excerpt = self.make_excerpt(start_offset..end_offset);
374 if excerpt.size > self.options.max_bytes {
375 log::error!(
376 "bug: line-based excerpt selection has size {}, \
377 which is {} bytes larger than the max size",
378 excerpt.size,
379 excerpt.size - self.options.max_bytes
380 );
381 }
382 return Some(excerpt);
383 }
384
385 fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
386 let parent_declarations = self
387 .parent_declarations
388 .iter()
389 .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
390 .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
391 .collect();
392 EditPredictionExcerpt::new(range, parent_declarations)
393 }
394
395 /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
396 fn is_better_excerpt(
397 &self,
398 forward: &EditPredictionExcerpt,
399 backward: &EditPredictionExcerpt,
400 ) -> bool {
401 let forward_ratio = self.excerpt_range_ratio(forward);
402 let backward_ratio = self.excerpt_range_ratio(backward);
403 let forward_delta =
404 (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
405 let backward_delta =
406 (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
407 let forward_is_better = forward_delta <= backward_delta;
408 if forward_is_better {
409 log::debug!(
410 "expanding forward since {} is closer than {} to {}",
411 forward_ratio,
412 backward_ratio,
413 self.options.target_before_cursor_over_total_bytes
414 );
415 } else {
416 log::debug!(
417 "expanding backward since {} is closer than {} to {}",
418 backward_ratio,
419 forward_ratio,
420 self.options.target_before_cursor_over_total_bytes
421 );
422 }
423 forward_is_better
424 }
425
426 /// Returns the ratio of bytes before the cursor over bytes within the range.
427 fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
428 let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
429 log::error!("bug: edit prediction cursor offset is not outside the excerpt");
430 return 0.0;
431 };
432 bytes_before_cursor as f32 / excerpt.range.len() as f32
433 }
434}
435
436fn node_line_start(node: Node) -> Point {
437 Point::new(node.start_position().row as u32, 0)
438}
439
440fn node_line_end(node: Node) -> Point {
441 Point::new(node.end_position().row as u32 + 1, 0)
442}
443
444#[cfg(test)]
445mod tests {
446 use super::*;
447 use gpui::{AppContext, TestAppContext};
448 use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
449 use util::test::{generate_marked_text, marked_text_offsets_by};
450
451 fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
452 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
453 buffer.read_with(cx, |buffer, _| buffer.snapshot())
454 }
455
456 fn rust_lang() -> Language {
457 Language::new(
458 LanguageConfig {
459 name: "Rust".into(),
460 matcher: LanguageMatcher {
461 path_suffixes: vec!["rs".to_string()],
462 ..Default::default()
463 },
464 ..Default::default()
465 },
466 Some(tree_sitter_rust::LANGUAGE.into()),
467 )
468 .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
469 .unwrap()
470 }
471
472 fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
473 let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
474 (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
475 }
476
477 fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
478 let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
479
480 let buffer = create_buffer(&text, cx);
481 let cursor_point = cursor.to_point(&buffer);
482
483 let excerpt =
484 EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
485 .expect("Should select an excerpt");
486 pretty_assertions::assert_eq!(
487 generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
488 generate_marked_text(&text, &[expected_excerpt], false)
489 );
490 assert!(excerpt.size <= options.max_bytes);
491 assert!(excerpt.range.contains(&cursor));
492 }
493
    /// Expects the excerpt to be exactly the line holding the cursor's statement.
    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // Whitespace inside the raw string is significant: the byte limits below are measured
        // against it.
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
512
    /// Expects the excerpt to be the whole `main` function, without the neighboring functions.
    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // Whitespace inside the raw string is significant: the byte limits below are measured
        // against it.
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
535
    /// Expects the excerpt to grow from the cursor's statement to all three statement lines,
    /// excluding the function's opening and closing lines.
    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        // Whitespace inside the raw string is significant: the byte limits below are measured
        // against it.
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
554
    /// Expects the excerpt to cover the `if` block and the statement after it, but not the
    /// statement before it.
    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        // Whitespace inside the raw string is significant: the byte limits below are measured
        // against it.
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }
575
    /// Uses a before-cursor target ratio of 0.6 and expects the excerpt to cover statements `a`
    /// through `f`, leaving out the blank line and `g`.
    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        // Whitespace inside the raw string is significant: the byte limits below are measured
        // against it.
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
        };

        check_example(options, text, cx);
    }
599}