excerpt.rs

  1use language::BufferSnapshot;
  2use std::ops::Range;
  3use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _};
  4use tree_sitter::{Node, TreeCursor};
  5use util::RangeExt;
  6
  7// TODO:
  8//
  9// - Test parent signatures
 10//
 11// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
 12// planning.
 13//
 14// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
 15// paragraph).
 16//
 17// - Truncation of long lines.
 18//
 19// - Filter outer syntax layers that don't support edit prediction.
 20
 21#[derive(Debug, Clone)]
 22pub struct EditPredictionExcerptOptions {
 23    /// Limit for the number of bytes in the window around the cursor.
 24    pub max_bytes: usize,
 25    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
 26    /// in an excerpt smaller than this, it will fall back on line-based selection.
 27    pub min_bytes: usize,
 28    /// Target ratio of bytes before the cursor divided by total bytes in the window.
 29    pub target_before_cursor_over_total_bytes: f32,
 30    /// Whether to include parent signatures
 31    pub include_parent_signatures: bool,
 32}
 33
 34#[derive(Clone)]
 35pub struct EditPredictionExcerpt {
 36    pub range: Range<usize>,
 37    pub parent_signature_ranges: Vec<Range<usize>>,
 38    pub size: usize,
 39}
 40
 41impl EditPredictionExcerpt {
 42    /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
 43    /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
 44    /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures
 45    /// of parent outline items.
 46    ///
 47    /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
 48    /// expansion.
 49    ///
 50    /// Returns `None` if the line around the cursor doesn't fit.
 51    pub fn select_from_buffer(
 52        query_point: Point,
 53        buffer: &BufferSnapshot,
 54        options: &EditPredictionExcerptOptions,
 55    ) -> Option<Self> {
 56        if buffer.len() <= options.max_bytes {
 57            log::debug!(
 58                "using entire file for excerpt since source length ({}) <= window max bytes ({})",
 59                buffer.len(),
 60                options.max_bytes
 61            );
 62            return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
 63        }
 64
 65        let query_offset = query_point.to_offset(buffer);
 66        let query_range = Point::new(query_point.row, 0).to_offset(buffer)
 67            ..Point::new(query_point.row + 1, 0).to_offset(buffer);
 68        if query_range.len() >= options.max_bytes {
 69            return None;
 70        }
 71
 72        // TODO: Don't compute text / annotation_range / skip converting to and from anchors.
 73        let outline_items = if options.include_parent_signatures {
 74            buffer
 75                .outline_items_containing(query_range.clone(), false, None)
 76                .into_iter()
 77                .flat_map(|item| {
 78                    Some(ExcerptOutlineItem {
 79                        item_range: item.range.to_offset(&buffer),
 80                        signature_range: item.signature_range?.to_offset(&buffer),
 81                    })
 82                })
 83                .collect()
 84        } else {
 85            Vec::new()
 86        };
 87
 88        let excerpt_selector = ExcerptSelector {
 89            query_offset,
 90            query_range,
 91            outline_items: &outline_items,
 92            buffer,
 93            options,
 94        };
 95
 96        if let Some(excerpt_ranges) = excerpt_selector.select_tree_sitter_nodes() {
 97            if excerpt_ranges.size >= options.min_bytes {
 98                return Some(excerpt_ranges);
 99            }
100            log::debug!(
101                "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
102                excerpt_ranges.size,
103                options.min_bytes
104            );
105        } else {
106            log::debug!(
107                "couldn't find excerpt via tree-sitter, falling back on line-based selection"
108            );
109        }
110
111        excerpt_selector.select_lines()
112    }
113
114    fn new(range: Range<usize>, parent_signature_ranges: Vec<Range<usize>>) -> Self {
115        let size = range.len()
116            + parent_signature_ranges
117                .iter()
118                .map(|r| r.len())
119                .sum::<usize>();
120        Self {
121            range,
122            parent_signature_ranges,
123            size,
124        }
125    }
126
127    fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
128        if !new_range.contains_inclusive(&self.range) {
129            // this is an issue because parent_signature_ranges may be incorrect
130            log::error!("bug: with_expanded_range called with disjoint range");
131        }
132        let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len());
133        let mut size = new_range.len();
134        for range in &self.parent_signature_ranges {
135            if range.contains_inclusive(&new_range) {
136                break;
137            }
138            parent_signature_ranges.push(range.clone());
139            size += range.len();
140        }
141        Self {
142            range: new_range,
143            parent_signature_ranges,
144            size,
145        }
146    }
147
148    fn parent_signatures_size(&self) -> usize {
149        self.size - self.range.len()
150    }
151}
152
153struct ExcerptSelector<'a> {
154    query_offset: usize,
155    query_range: Range<usize>,
156    outline_items: &'a [ExcerptOutlineItem],
157    buffer: &'a BufferSnapshot,
158    options: &'a EditPredictionExcerptOptions,
159}
160
/// An outline item reduced to the two byte ranges that excerpt selection needs.
struct ExcerptOutlineItem {
    /// Byte range of the outline item.
    item_range: Range<usize>,
    /// Byte range of the outline item's signature.
    signature_range: Range<usize>,
}
165
166impl<'a> ExcerptSelector<'a> {
167    /// Finds the largest node that is smaller than the window size and contains `query_range`.
168    fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
169        let selected_layer_root = self.select_syntax_layer()?;
170        let mut cursor = selected_layer_root.walk();
171
172        loop {
173            let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
174                ..node_line_end(cursor.node()).to_offset(&self.buffer);
175            if excerpt_range.contains_inclusive(&self.query_range) {
176                let excerpt = self.make_excerpt(excerpt_range);
177                if excerpt.size <= self.options.max_bytes {
178                    return Some(self.expand_to_siblings(&mut cursor, excerpt));
179                }
180            } else {
181                // TODO: Should still be able to handle this case via AST nodes. For example, this
182                // can happen if the cursor is between two methods in a large class file.
183                return None;
184            }
185
186            if cursor
187                .goto_first_child_for_byte(self.query_range.start)
188                .is_none()
189            {
190                return None;
191            }
192        }
193    }
194
195    /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
196    fn select_syntax_layer(&self) -> Option<Node<'_>> {
197        let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
198        let mut largest: Option<Node<'_>> = None;
199        for layer in self
200            .buffer
201            .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
202        {
203            let layer_range = layer.node().byte_range();
204            if !layer_range.contains_inclusive(&self.query_range) {
205                continue;
206            }
207
208            if layer_range.len() > self.options.max_bytes {
209                match &smallest_exceeding_max_len {
210                    None => smallest_exceeding_max_len = Some(layer.node()),
211                    Some(existing) => {
212                        if layer_range.len() < existing.byte_range().len() {
213                            smallest_exceeding_max_len = Some(layer.node());
214                        }
215                    }
216                }
217            } else {
218                match &largest {
219                    None => largest = Some(layer.node()),
220                    Some(existing) if layer_range.len() > existing.byte_range().len() => {
221                        largest = Some(layer.node())
222                    }
223                    _ => {}
224                }
225            }
226        }
227
228        smallest_exceeding_max_len.or(largest)
229    }
230
231    // motivation for this and `goto_previous_named_sibling` is to avoid including things like
232    // trailing unnamed "}" in body nodes
233    fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
234        while cursor.goto_next_sibling() {
235            if cursor.node().is_named() {
236                return true;
237            }
238        }
239        false
240    }
241
242    fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
243        while cursor.goto_previous_sibling() {
244            if cursor.node().is_named() {
245                return true;
246            }
247        }
248        false
249    }
250
    /// Grows `excerpt` over the named siblings of the node `cursor` points at, one sibling at a
    /// time, for as long as the result stays within `max_bytes`.
    ///
    /// Each round computes at most one forward candidate (range end moved to the end line of a
    /// following sibling) and one backward candidate (range start moved to the start line of a
    /// preceding sibling). When both fit, `is_better_excerpt` decides which one is taken; when
    /// only one fits, that one is taken. A direction stops for good once it runs out of
    /// siblings or its candidate no longer fits.
    ///
    /// `cursor` itself is used for the backward traversal, so it is moved by this call.
    fn expand_to_siblings(
        &self,
        cursor: &mut TreeCursor,
        mut excerpt: EditPredictionExcerpt,
    ) -> EditPredictionExcerpt {
        // Two independent cursors: a clone walks forward, the caller's cursor walks backward.
        let mut forward_cursor = cursor.clone();
        let backward_cursor = cursor;
        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
        loop {
            if backward_done && forward_done {
                break;
            }

            // Find the next forward candidate, skipping siblings that would not move the end of
            // the range.
            let mut forward = None;
            while !forward_done {
                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
                if new_end > excerpt.range.end {
                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
                    if new_excerpt.size <= self.options.max_bytes {
                        forward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting forward expansion, as it doesn't fit");
                        forward_done = true;
                        break;
                    }
                }
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            }

            // Same search in the other direction, skipping siblings that would not move the
            // start of the range.
            let mut backward = None;
            while !backward_done {
                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
                if new_start < excerpt.range.start {
                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
                    if new_excerpt.size <= self.options.max_bytes {
                        backward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting backward expansion, as it doesn't fit");
                        backward_done = true;
                        break;
                    }
                }
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }

            // Commit exactly one of the candidates; the other direction's cursor stays where it
            // is, so its sibling is reconsidered against the updated excerpt next round.
            let go_forward = match (forward, backward) {
                (Some(forward), Some(backward)) => {
                    let go_forward = self.is_better_excerpt(&forward, &backward);
                    if go_forward {
                        excerpt = forward;
                    } else {
                        excerpt = backward;
                    }
                    go_forward
                }
                (Some(forward), None) => {
                    log::debug!("expanding forward, since backward expansion has halted");
                    excerpt = forward;
                    true
                }
                (None, Some(backward)) => {
                    log::debug!("expanding backward, since forward expansion has halted");
                    excerpt = backward;
                    false
                }
                (None, None) => break,
            };

            // Only the cursor whose candidate was taken advances to its next sibling.
            if go_forward {
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            } else {
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }
        }

        excerpt
    }
331
332    fn select_lines(&self) -> Option<EditPredictionExcerpt> {
333        // early return if line containing query_offset is already too large
334        let excerpt = self.make_excerpt(self.query_range.clone());
335        if excerpt.size > self.options.max_bytes {
336            log::debug!(
337                "excerpt for cursor line is {} bytes, which exceeds the window",
338                excerpt.size
339            );
340            return None;
341        }
342        let signatures_size = excerpt.parent_signatures_size();
343        let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
344
345        let before_bytes =
346            (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
347
348        let start_point = {
349            let offset = self.query_offset.saturating_sub(before_bytes);
350            let point = offset.to_point(self.buffer);
351            Point::new(point.row + 1, 0)
352        };
353        let start_offset = start_point.to_offset(&self.buffer);
354        let end_point = {
355            let offset = start_offset + bytes_remaining;
356            let point = offset.to_point(self.buffer);
357            Point::new(point.row, 0)
358        };
359        let end_offset = end_point.to_offset(&self.buffer);
360
361        // this could be expanded further since recalculated `signature_size` may be smaller, but
362        // skipping that for now for simplicity
363        //
364        // TODO: could also consider checking if lines immediately before / after fit.
365        let excerpt = self.make_excerpt(start_offset..end_offset);
366        if excerpt.size > self.options.max_bytes {
367            log::error!(
368                "bug: line-based excerpt selection has size {}, \
369                which is {} bytes larger than the max size",
370                excerpt.size,
371                excerpt.size - self.options.max_bytes
372            );
373        }
374        return Some(excerpt);
375    }
376
377    fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
378        let parent_signature_ranges = self
379            .outline_items
380            .iter()
381            .filter(|item| item.item_range.contains_inclusive(&range))
382            .map(|item| item.signature_range.clone())
383            .collect();
384        EditPredictionExcerpt::new(range, parent_signature_ranges)
385    }
386
387    /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
388    fn is_better_excerpt(
389        &self,
390        forward: &EditPredictionExcerpt,
391        backward: &EditPredictionExcerpt,
392    ) -> bool {
393        let forward_ratio = self.excerpt_range_ratio(forward);
394        let backward_ratio = self.excerpt_range_ratio(backward);
395        let forward_delta =
396            (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
397        let backward_delta =
398            (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
399        let forward_is_better = forward_delta <= backward_delta;
400        if forward_is_better {
401            log::debug!(
402                "expanding forward since {} is closer than {} to {}",
403                forward_ratio,
404                backward_ratio,
405                self.options.target_before_cursor_over_total_bytes
406            );
407        } else {
408            log::debug!(
409                "expanding backward since {} is closer than {} to {}",
410                backward_ratio,
411                forward_ratio,
412                self.options.target_before_cursor_over_total_bytes
413            );
414        }
415        forward_is_better
416    }
417
418    /// Returns the ratio of bytes before the cursor over bytes within the range.
419    fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
420        let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
421            log::error!("bug: edit prediction cursor offset is not outside the excerpt");
422            return 0.0;
423        };
424        bytes_before_cursor as f32 / excerpt.range.len() as f32
425    }
426}
427
428fn node_line_start(node: Node) -> Point {
429    Point::new(node.start_position().row as u32, 0)
430}
431
432fn node_line_end(node: Node) -> Point {
433    Point::new(node.end_position().row as u32 + 1, 0)
434}
435
436#[cfg(test)]
437mod tests {
438    use super::*;
439    use gpui::{AppContext, TestAppContext};
440    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
441    use util::test::{generate_marked_text, marked_text_offsets_by};
442
443    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
444        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
445        buffer.read_with(cx, |buffer, _| buffer.snapshot())
446    }
447
448    fn rust_lang() -> Language {
449        Language::new(
450            LanguageConfig {
451                name: "Rust".into(),
452                matcher: LanguageMatcher {
453                    path_suffixes: vec!["rs".to_string()],
454                    ..Default::default()
455                },
456                ..Default::default()
457            },
458            Some(tree_sitter_rust::LANGUAGE.into()),
459        )
460        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
461        .unwrap()
462    }
463
464    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
465        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
466        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
467    }
468
469    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
470        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
471
472        let buffer = create_buffer(&text, cx);
473        let cursor_point = cursor.to_point(&buffer);
474
475        let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
476            .expect("Should select an excerpt");
477        pretty_assertions::assert_eq!(
478            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
479            generate_marked_text(&text, &[expected_excerpt], false)
480        );
481        assert!(excerpt.size <= options.max_bytes);
482        assert!(excerpt.range.contains(&cursor));
483    }
484
485    #[gpui::test]
486    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
487        zlog::init_test();
488        let text = r#"
489fn main() {
490    let x = 1;
491«    let ˇy = 2;
492»    let z = 3;
493}"#;
494
495        let options = EditPredictionExcerptOptions {
496            max_bytes: 20,
497            min_bytes: 10,
498            target_before_cursor_over_total_bytes: 0.5,
499            include_parent_signatures: false,
500        };
501
502        check_example(options, text, cx);
503    }
504
505    #[gpui::test]
506    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
507        zlog::init_test();
508        let text = r#"
509fn foo() {}
510
511«fn main() {
512    let x = 1;
513    let ˇy = 2;
514    let z = 3;
515}
516»
517fn bar() {}"#;
518
519        let options = EditPredictionExcerptOptions {
520            max_bytes: 65,
521            min_bytes: 10,
522            target_before_cursor_over_total_bytes: 0.5,
523            include_parent_signatures: false,
524        };
525
526        check_example(options, text, cx);
527    }
528
529    #[gpui::test]
530    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
531        zlog::init_test();
532        let text = r#"
533fn main() {
534«    let x = 1;
535    let ˇy = 2;
536    let z = 3;
537»}"#;
538
539        let options = EditPredictionExcerptOptions {
540            max_bytes: 50,
541            min_bytes: 10,
542            target_before_cursor_over_total_bytes: 0.5,
543            include_parent_signatures: false,
544        };
545
546        check_example(options, text, cx);
547    }
548
549    #[gpui::test]
550    fn test_line_based_selection(cx: &mut TestAppContext) {
551        zlog::init_test();
552        let text = r#"
553fn main() {
554    let x = 1;
555«    if true {
556        let ˇy = 2;
557    }
558    let z = 3;
559»}"#;
560
561        let options = EditPredictionExcerptOptions {
562            max_bytes: 60,
563            min_bytes: 45,
564            target_before_cursor_over_total_bytes: 0.5,
565            include_parent_signatures: false,
566        };
567
568        check_example(options, text, cx);
569    }
570
571    #[gpui::test]
572    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
573        zlog::init_test();
574        let text = r#"
575    fn main() {
576«        let a = 1;
577        let b = 2;
578        let c = 3;
579        let ˇd = 4;
580        let e = 5;
581        let f = 6;
582»
583        let g = 7;
584    }"#;
585
586        let options = EditPredictionExcerptOptions {
587            max_bytes: 120,
588            min_bytes: 10,
589            target_before_cursor_over_total_bytes: 0.6,
590            include_parent_signatures: false,
591        };
592
593        check_example(options, text, cx);
594    }
595}