1use language::BufferSnapshot;
2use std::ops::Range;
3use text::{Point, ToOffset as _, ToPoint as _};
4use tree_sitter::{Node, TreeCursor};
5use util::RangeExt;
6
7use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
8
9// TODO:
10//
11// - Test parent signatures
12//
13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
14// planning.
15//
16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
17// paragraph).
18//
19// - Truncation of long lines.
20//
21// - Filter outer syntax layers that don't support edit prediction.
22
/// Size limits and cursor-placement preference used by
/// `EditPredictionExcerpt::select_from_buffer`.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window.
    pub target_before_cursor_over_total_bytes: f32,
}
33
/// A byte range of a buffer selected around the cursor, plus the signature ranges of the
/// declarations that enclose it.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte range of the excerpt body within the buffer.
    pub range: Range<usize>,
    /// For each enclosing declaration: its id and the byte range of its signature.
    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
    /// Length of `range` plus the lengths of all signature ranges in `parent_declarations`.
    pub size: usize,
}
40
/// The text of an [`EditPredictionExcerpt`], as produced by `EditPredictionExcerpt::text`.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
    /// Text of the excerpt's `range`.
    pub body: String,
    /// Text of each parent declaration's signature range, in the same order as
    /// `EditPredictionExcerpt::parent_declarations`.
    pub parent_signatures: Vec<String>,
}
46
47impl EditPredictionExcerpt {
48 pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
49 let body = buffer
50 .text_for_range(self.range.clone())
51 .collect::<String>();
52 let parent_signatures = self
53 .parent_declarations
54 .iter()
55 .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
56 .collect();
57 EditPredictionExcerptText {
58 body,
59 parent_signatures,
60 }
61 }
62
63 /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
64 /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
65 /// cursor.
66 ///
67 /// When `index` is provided, the excerpt will include the signatures of parent outline items.
68 ///
69 /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
70 /// expansion.
71 ///
72 /// Returns `None` if the line around the cursor doesn't fit.
73 pub fn select_from_buffer(
74 query_point: Point,
75 buffer: &BufferSnapshot,
76 options: &EditPredictionExcerptOptions,
77 syntax_index: Option<&SyntaxIndexState>,
78 ) -> Option<Self> {
79 if buffer.len() <= options.max_bytes {
80 log::debug!(
81 "using entire file for excerpt since source length ({}) <= window max bytes ({})",
82 buffer.len(),
83 options.max_bytes
84 );
85 return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
86 }
87
88 let query_offset = query_point.to_offset(buffer);
89 let query_range = Point::new(query_point.row, 0).to_offset(buffer)
90 ..Point::new(query_point.row + 1, 0).to_offset(buffer);
91 if query_range.len() >= options.max_bytes {
92 return None;
93 }
94
95 let parent_declarations = if let Some(syntax_index) = syntax_index {
96 syntax_index
97 .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
98 .collect()
99 } else {
100 Vec::new()
101 };
102
103 let excerpt_selector = ExcerptSelector {
104 query_offset,
105 query_range,
106 parent_declarations: &parent_declarations,
107 buffer,
108 options,
109 };
110
111 if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
112 if excerpt.size >= options.min_bytes {
113 return Some(excerpt);
114 }
115 log::debug!(
116 "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
117 excerpt.size,
118 options.min_bytes
119 );
120 } else {
121 log::debug!(
122 "couldn't find excerpt via tree-sitter, falling back on line-based selection"
123 );
124 }
125
126 excerpt_selector.select_lines()
127 }
128
129 fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
130 let size = range.len()
131 + parent_declarations
132 .iter()
133 .map(|(_, range)| range.len())
134 .sum::<usize>();
135 Self {
136 range,
137 parent_declarations,
138 size,
139 }
140 }
141
142 fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
143 if !new_range.contains_inclusive(&self.range) {
144 // this is an issue because parent_signature_ranges may be incorrect
145 log::error!("bug: with_expanded_range called with disjoint range");
146 }
147 let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
148 for (declaration_id, range) in &self.parent_declarations {
149 if !range.contains_inclusive(&new_range) {
150 break;
151 }
152 parent_declarations.push((*declaration_id, range.clone()));
153 }
154 Self::new(new_range, parent_declarations)
155 }
156
157 fn parent_signatures_size(&self) -> usize {
158 self.size - self.range.len()
159 }
160}
161
162struct ExcerptSelector<'a> {
163 query_offset: usize,
164 query_range: Range<usize>,
165 parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
166 buffer: &'a BufferSnapshot,
167 options: &'a EditPredictionExcerptOptions,
168}
169
170impl<'a> ExcerptSelector<'a> {
171 /// Finds the largest node that is smaller than the window size and contains `query_range`.
172 fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
173 let selected_layer_root = self.select_syntax_layer()?;
174 let mut cursor = selected_layer_root.walk();
175
176 loop {
177 let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
178 ..node_line_end(cursor.node()).to_offset(&self.buffer);
179 if excerpt_range.contains_inclusive(&self.query_range) {
180 let excerpt = self.make_excerpt(excerpt_range);
181 if excerpt.size <= self.options.max_bytes {
182 return Some(self.expand_to_siblings(&mut cursor, excerpt));
183 }
184 } else {
185 // TODO: Should still be able to handle this case via AST nodes. For example, this
186 // can happen if the cursor is between two methods in a large class file.
187 return None;
188 }
189
190 if cursor
191 .goto_first_child_for_byte(self.query_range.start)
192 .is_none()
193 {
194 return None;
195 }
196 }
197 }
198
199 /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
200 fn select_syntax_layer(&self) -> Option<Node<'_>> {
201 let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
202 let mut largest: Option<Node<'_>> = None;
203 for layer in self
204 .buffer
205 .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
206 {
207 let layer_range = layer.node().byte_range();
208 if !layer_range.contains_inclusive(&self.query_range) {
209 continue;
210 }
211
212 if layer_range.len() > self.options.max_bytes {
213 match &smallest_exceeding_max_len {
214 None => smallest_exceeding_max_len = Some(layer.node()),
215 Some(existing) => {
216 if layer_range.len() < existing.byte_range().len() {
217 smallest_exceeding_max_len = Some(layer.node());
218 }
219 }
220 }
221 } else {
222 match &largest {
223 None => largest = Some(layer.node()),
224 Some(existing) if layer_range.len() > existing.byte_range().len() => {
225 largest = Some(layer.node())
226 }
227 _ => {}
228 }
229 }
230 }
231
232 smallest_exceeding_max_len.or(largest)
233 }
234
235 // motivation for this and `goto_previous_named_sibling` is to avoid including things like
236 // trailing unnamed "}" in body nodes
237 fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
238 while cursor.goto_next_sibling() {
239 if cursor.node().is_named() {
240 return true;
241 }
242 }
243 false
244 }
245
246 fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
247 while cursor.goto_previous_sibling() {
248 if cursor.node().is_named() {
249 return true;
250 }
251 }
252 false
253 }
254
255 fn expand_to_siblings(
256 &self,
257 cursor: &mut TreeCursor,
258 mut excerpt: EditPredictionExcerpt,
259 ) -> EditPredictionExcerpt {
260 let mut forward_cursor = cursor.clone();
261 let backward_cursor = cursor;
262 let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
263 let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
264 loop {
265 if backward_done && forward_done {
266 break;
267 }
268
269 let mut forward = None;
270 while !forward_done {
271 let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
272 if new_end > excerpt.range.end {
273 let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
274 if new_excerpt.size <= self.options.max_bytes {
275 forward = Some(new_excerpt);
276 break;
277 } else {
278 log::debug!("halting forward expansion, as it doesn't fit");
279 forward_done = true;
280 break;
281 }
282 }
283 forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
284 }
285
286 let mut backward = None;
287 while !backward_done {
288 let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
289 if new_start < excerpt.range.start {
290 let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
291 if new_excerpt.size <= self.options.max_bytes {
292 backward = Some(new_excerpt);
293 break;
294 } else {
295 log::debug!("halting backward expansion, as it doesn't fit");
296 backward_done = true;
297 break;
298 }
299 }
300 backward_done = !Self::goto_previous_named_sibling(backward_cursor);
301 }
302
303 let go_forward = match (forward, backward) {
304 (Some(forward), Some(backward)) => {
305 let go_forward = self.is_better_excerpt(&forward, &backward);
306 if go_forward {
307 excerpt = forward;
308 } else {
309 excerpt = backward;
310 }
311 go_forward
312 }
313 (Some(forward), None) => {
314 log::debug!("expanding forward, since backward expansion has halted");
315 excerpt = forward;
316 true
317 }
318 (None, Some(backward)) => {
319 log::debug!("expanding backward, since forward expansion has halted");
320 excerpt = backward;
321 false
322 }
323 (None, None) => break,
324 };
325
326 if go_forward {
327 forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
328 } else {
329 backward_done = !Self::goto_previous_named_sibling(backward_cursor);
330 }
331 }
332
333 excerpt
334 }
335
336 fn select_lines(&self) -> Option<EditPredictionExcerpt> {
337 // early return if line containing query_offset is already too large
338 let excerpt = self.make_excerpt(self.query_range.clone());
339 if excerpt.size > self.options.max_bytes {
340 log::debug!(
341 "excerpt for cursor line is {} bytes, which exceeds the window",
342 excerpt.size
343 );
344 return None;
345 }
346 let signatures_size = excerpt.parent_signatures_size();
347 let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
348
349 let before_bytes =
350 (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
351
352 let start_point = {
353 let offset = self.query_offset.saturating_sub(before_bytes);
354 let point = offset.to_point(self.buffer);
355 Point::new(point.row + 1, 0)
356 };
357 let start_offset = start_point.to_offset(&self.buffer);
358 let end_point = {
359 let offset = start_offset + bytes_remaining;
360 let point = offset.to_point(self.buffer);
361 Point::new(point.row, 0)
362 };
363 let end_offset = end_point.to_offset(&self.buffer);
364
365 // this could be expanded further since recalculated `signature_size` may be smaller, but
366 // skipping that for now for simplicity
367 //
368 // TODO: could also consider checking if lines immediately before / after fit.
369 let excerpt = self.make_excerpt(start_offset..end_offset);
370 if excerpt.size > self.options.max_bytes {
371 log::error!(
372 "bug: line-based excerpt selection has size {}, \
373 which is {} bytes larger than the max size",
374 excerpt.size,
375 excerpt.size - self.options.max_bytes
376 );
377 }
378 return Some(excerpt);
379 }
380
381 fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
382 let parent_declarations = self
383 .parent_declarations
384 .iter()
385 .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
386 .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
387 .collect();
388 EditPredictionExcerpt::new(range, parent_declarations)
389 }
390
391 /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
392 fn is_better_excerpt(
393 &self,
394 forward: &EditPredictionExcerpt,
395 backward: &EditPredictionExcerpt,
396 ) -> bool {
397 let forward_ratio = self.excerpt_range_ratio(forward);
398 let backward_ratio = self.excerpt_range_ratio(backward);
399 let forward_delta =
400 (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
401 let backward_delta =
402 (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
403 let forward_is_better = forward_delta <= backward_delta;
404 if forward_is_better {
405 log::debug!(
406 "expanding forward since {} is closer than {} to {}",
407 forward_ratio,
408 backward_ratio,
409 self.options.target_before_cursor_over_total_bytes
410 );
411 } else {
412 log::debug!(
413 "expanding backward since {} is closer than {} to {}",
414 backward_ratio,
415 forward_ratio,
416 self.options.target_before_cursor_over_total_bytes
417 );
418 }
419 forward_is_better
420 }
421
422 /// Returns the ratio of bytes before the cursor over bytes within the range.
423 fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
424 let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
425 log::error!("bug: edit prediction cursor offset is not outside the excerpt");
426 return 0.0;
427 };
428 bytes_before_cursor as f32 / excerpt.range.len() as f32
429 }
430}
431
432fn node_line_start(node: Node) -> Point {
433 Point::new(node.start_position().row as u32, 0)
434}
435
436fn node_line_end(node: Node) -> Point {
437 Point::new(node.end_position().row as u32 + 1, 0)
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use util::test::{generate_marked_text, marked_text_offsets_by};

    // NOTE: the fixtures below are byte-sensitive. `max_bytes` / `min_bytes` are compared against
    // byte lengths of the fixture text, so the indentation inside the raw strings is part of each
    // test's setup and must not be reformatted.

    /// Creates a local buffer containing `text` with the Rust test language and returns its
    /// snapshot.
    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    /// A Rust language definition backed by the tree-sitter Rust grammar and the outline query.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Splits marked text into the unmarked text, the cursor offset (`ˇ`) and the expected
    /// excerpt range (`«` .. `»`).
    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
    }

    /// Runs excerpt selection (without a syntax index) on the marked `text` and asserts that the
    /// selected range equals the marked one, fits in `max_bytes` and contains the cursor.
    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);

        let buffer = create_buffer(&text, cx);
        let cursor_point = cursor.to_point(&buffer);

        let excerpt =
            EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
                .expect("Should select an excerpt");
        pretty_assertions::assert_eq!(
            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
            generate_marked_text(&text, &[expected_excerpt], false)
        );
        assert!(
            excerpt.size <= options.max_bytes,
            "excerpt size {} exceeds max_bytes {}",
            excerpt.size,
            options.max_bytes
        );
        assert!(
            excerpt.range.contains(&cursor),
            "excerpt range {:?} does not contain cursor offset {}",
            excerpt.range,
            cursor
        );
    }

    /// With a budget that only fits one statement, the statement under the cursor is selected.
    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    /// With a budget that fits the enclosing function but not its neighbours, the whole function
    /// is selected.
    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    /// When the function does not fit but all of its statements do, selection grows from the
    /// cursor statement over its siblings.
    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    /// Cursor nested inside an `if` block with a relatively high `min_bytes`.
    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    /// A target ratio of 0.6 biases the window towards lines before the cursor.
    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
        };

        check_example(options, text, cx);
    }
}