excerpt.rs

  1use language::BufferSnapshot;
  2use std::ops::Range;
  3use text::{Point, ToOffset as _, ToPoint as _};
  4use tree_sitter::{Node, TreeCursor};
  5use util::RangeExt;
  6
  7use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
  8
  9// TODO:
 10//
 11// - Test parent signatures
 12//
 13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
 14// planning.
 15//
 16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
 17// paragraph).
 18//
 19// - Truncation of long lines.
 20//
 21// - Filter outer syntax layers that don't support edit prediction.
 22
 23#[derive(Debug, Clone)]
 24pub struct EditPredictionExcerptOptions {
 25    /// Limit for the number of bytes in the window around the cursor.
 26    pub max_bytes: usize,
 27    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
 28    /// in an excerpt smaller than this, it will fall back on line-based selection.
 29    pub min_bytes: usize,
 30    /// Target ratio of bytes before the cursor divided by total bytes in the window.
 31    pub target_before_cursor_over_total_bytes: f32,
 32}
 33
 34#[derive(Debug, Clone)]
 35pub struct EditPredictionExcerpt {
 36    pub range: Range<usize>,
 37    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
 38    pub size: usize,
 39}
 40
 41#[derive(Debug, Clone)]
 42pub struct EditPredictionExcerptText {
 43    pub body: String,
 44    pub parent_signatures: Vec<String>,
 45}
 46
 47impl EditPredictionExcerpt {
 48    pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
 49        let body = buffer
 50            .text_for_range(self.range.clone())
 51            .collect::<String>();
 52        let parent_signatures = self
 53            .parent_declarations
 54            .iter()
 55            .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
 56            .collect();
 57        EditPredictionExcerptText {
 58            body,
 59            parent_signatures,
 60        }
 61    }
 62
 63    /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
 64    /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
 65    /// cursor.
 66    ///
 67    /// When `index` is provided, the excerpt will include the signatures of parent outline items.
 68    ///
 69    /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
 70    /// expansion.
 71    ///
 72    /// Returns `None` if the line around the cursor doesn't fit.
 73    pub fn select_from_buffer(
 74        query_point: Point,
 75        buffer: &BufferSnapshot,
 76        options: &EditPredictionExcerptOptions,
 77        syntax_index: Option<&SyntaxIndexState>,
 78    ) -> Option<Self> {
 79        if buffer.len() <= options.max_bytes {
 80            log::debug!(
 81                "using entire file for excerpt since source length ({}) <= window max bytes ({})",
 82                buffer.len(),
 83                options.max_bytes
 84            );
 85            return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
 86        }
 87
 88        let query_offset = query_point.to_offset(buffer);
 89        let query_range = Point::new(query_point.row, 0).to_offset(buffer)
 90            ..Point::new(query_point.row + 1, 0).to_offset(buffer);
 91        if query_range.len() >= options.max_bytes {
 92            return None;
 93        }
 94
 95        let parent_declarations = if let Some(syntax_index) = syntax_index {
 96            syntax_index
 97                .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
 98                .collect()
 99        } else {
100            Vec::new()
101        };
102
103        let excerpt_selector = ExcerptSelector {
104            query_offset,
105            query_range,
106            parent_declarations: &parent_declarations,
107            buffer,
108            options,
109        };
110
111        if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
112            if excerpt.size >= options.min_bytes {
113                return Some(excerpt);
114            }
115            log::debug!(
116                "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
117                excerpt.size,
118                options.min_bytes
119            );
120        } else {
121            log::debug!(
122                "couldn't find excerpt via tree-sitter, falling back on line-based selection"
123            );
124        }
125
126        excerpt_selector.select_lines()
127    }
128
129    fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
130        let size = range.len()
131            + parent_declarations
132                .iter()
133                .map(|(_, range)| range.len())
134                .sum::<usize>();
135        Self {
136            range,
137            parent_declarations,
138            size,
139        }
140    }
141
142    fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
143        if !new_range.contains_inclusive(&self.range) {
144            // this is an issue because parent_signature_ranges may be incorrect
145            log::error!("bug: with_expanded_range called with disjoint range");
146        }
147        let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
148        for (declaration_id, range) in &self.parent_declarations {
149            if !range.contains_inclusive(&new_range) {
150                break;
151            }
152            parent_declarations.push((*declaration_id, range.clone()));
153        }
154        Self::new(new_range, parent_declarations)
155    }
156
157    fn parent_signatures_size(&self) -> usize {
158        self.size - self.range.len()
159    }
160}
161
162struct ExcerptSelector<'a> {
163    query_offset: usize,
164    query_range: Range<usize>,
165    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
166    buffer: &'a BufferSnapshot,
167    options: &'a EditPredictionExcerptOptions,
168}
169
170impl<'a> ExcerptSelector<'a> {
171    /// Finds the largest node that is smaller than the window size and contains `query_range`.
172    fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
173        let selected_layer_root = self.select_syntax_layer()?;
174        let mut cursor = selected_layer_root.walk();
175
176        loop {
177            let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
178                ..node_line_end(cursor.node()).to_offset(&self.buffer);
179            if excerpt_range.contains_inclusive(&self.query_range) {
180                let excerpt = self.make_excerpt(excerpt_range);
181                if excerpt.size <= self.options.max_bytes {
182                    return Some(self.expand_to_siblings(&mut cursor, excerpt));
183                }
184            } else {
185                // TODO: Should still be able to handle this case via AST nodes. For example, this
186                // can happen if the cursor is between two methods in a large class file.
187                return None;
188            }
189
190            if cursor
191                .goto_first_child_for_byte(self.query_range.start)
192                .is_none()
193            {
194                return None;
195            }
196        }
197    }
198
199    /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
200    fn select_syntax_layer(&self) -> Option<Node<'_>> {
201        let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
202        let mut largest: Option<Node<'_>> = None;
203        for layer in self
204            .buffer
205            .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
206        {
207            let layer_range = layer.node().byte_range();
208            if !layer_range.contains_inclusive(&self.query_range) {
209                continue;
210            }
211
212            if layer_range.len() > self.options.max_bytes {
213                match &smallest_exceeding_max_len {
214                    None => smallest_exceeding_max_len = Some(layer.node()),
215                    Some(existing) => {
216                        if layer_range.len() < existing.byte_range().len() {
217                            smallest_exceeding_max_len = Some(layer.node());
218                        }
219                    }
220                }
221            } else {
222                match &largest {
223                    None => largest = Some(layer.node()),
224                    Some(existing) if layer_range.len() > existing.byte_range().len() => {
225                        largest = Some(layer.node())
226                    }
227                    _ => {}
228                }
229            }
230        }
231
232        smallest_exceeding_max_len.or(largest)
233    }
234
235    // motivation for this and `goto_previous_named_sibling` is to avoid including things like
236    // trailing unnamed "}" in body nodes
237    fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
238        while cursor.goto_next_sibling() {
239            if cursor.node().is_named() {
240                return true;
241            }
242        }
243        false
244    }
245
246    fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
247        while cursor.goto_previous_sibling() {
248            if cursor.node().is_named() {
249                return true;
250            }
251        }
252        false
253    }
254
255    fn expand_to_siblings(
256        &self,
257        cursor: &mut TreeCursor,
258        mut excerpt: EditPredictionExcerpt,
259    ) -> EditPredictionExcerpt {
260        let mut forward_cursor = cursor.clone();
261        let backward_cursor = cursor;
262        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
263        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
264        loop {
265            if backward_done && forward_done {
266                break;
267            }
268
269            let mut forward = None;
270            while !forward_done {
271                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
272                if new_end > excerpt.range.end {
273                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
274                    if new_excerpt.size <= self.options.max_bytes {
275                        forward = Some(new_excerpt);
276                        break;
277                    } else {
278                        log::debug!("halting forward expansion, as it doesn't fit");
279                        forward_done = true;
280                        break;
281                    }
282                }
283                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
284            }
285
286            let mut backward = None;
287            while !backward_done {
288                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
289                if new_start < excerpt.range.start {
290                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
291                    if new_excerpt.size <= self.options.max_bytes {
292                        backward = Some(new_excerpt);
293                        break;
294                    } else {
295                        log::debug!("halting backward expansion, as it doesn't fit");
296                        backward_done = true;
297                        break;
298                    }
299                }
300                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
301            }
302
303            let go_forward = match (forward, backward) {
304                (Some(forward), Some(backward)) => {
305                    let go_forward = self.is_better_excerpt(&forward, &backward);
306                    if go_forward {
307                        excerpt = forward;
308                    } else {
309                        excerpt = backward;
310                    }
311                    go_forward
312                }
313                (Some(forward), None) => {
314                    log::debug!("expanding forward, since backward expansion has halted");
315                    excerpt = forward;
316                    true
317                }
318                (None, Some(backward)) => {
319                    log::debug!("expanding backward, since forward expansion has halted");
320                    excerpt = backward;
321                    false
322                }
323                (None, None) => break,
324            };
325
326            if go_forward {
327                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
328            } else {
329                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
330            }
331        }
332
333        excerpt
334    }
335
336    fn select_lines(&self) -> Option<EditPredictionExcerpt> {
337        // early return if line containing query_offset is already too large
338        let excerpt = self.make_excerpt(self.query_range.clone());
339        if excerpt.size > self.options.max_bytes {
340            log::debug!(
341                "excerpt for cursor line is {} bytes, which exceeds the window",
342                excerpt.size
343            );
344            return None;
345        }
346        let signatures_size = excerpt.parent_signatures_size();
347        let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
348
349        let before_bytes =
350            (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
351
352        let start_point = {
353            let offset = self.query_offset.saturating_sub(before_bytes);
354            let point = offset.to_point(self.buffer);
355            Point::new(point.row + 1, 0)
356        };
357        let start_offset = start_point.to_offset(&self.buffer);
358        let end_point = {
359            let offset = start_offset + bytes_remaining;
360            let point = offset.to_point(self.buffer);
361            Point::new(point.row, 0)
362        };
363        let end_offset = end_point.to_offset(&self.buffer);
364
365        // this could be expanded further since recalculated `signature_size` may be smaller, but
366        // skipping that for now for simplicity
367        //
368        // TODO: could also consider checking if lines immediately before / after fit.
369        let excerpt = self.make_excerpt(start_offset..end_offset);
370        if excerpt.size > self.options.max_bytes {
371            log::error!(
372                "bug: line-based excerpt selection has size {}, \
373                which is {} bytes larger than the max size",
374                excerpt.size,
375                excerpt.size - self.options.max_bytes
376            );
377        }
378        return Some(excerpt);
379    }
380
381    fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
382        let parent_declarations = self
383            .parent_declarations
384            .iter()
385            .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
386            .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
387            .collect();
388        EditPredictionExcerpt::new(range, parent_declarations)
389    }
390
391    /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
392    fn is_better_excerpt(
393        &self,
394        forward: &EditPredictionExcerpt,
395        backward: &EditPredictionExcerpt,
396    ) -> bool {
397        let forward_ratio = self.excerpt_range_ratio(forward);
398        let backward_ratio = self.excerpt_range_ratio(backward);
399        let forward_delta =
400            (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
401        let backward_delta =
402            (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
403        let forward_is_better = forward_delta <= backward_delta;
404        if forward_is_better {
405            log::debug!(
406                "expanding forward since {} is closer than {} to {}",
407                forward_ratio,
408                backward_ratio,
409                self.options.target_before_cursor_over_total_bytes
410            );
411        } else {
412            log::debug!(
413                "expanding backward since {} is closer than {} to {}",
414                backward_ratio,
415                forward_ratio,
416                self.options.target_before_cursor_over_total_bytes
417            );
418        }
419        forward_is_better
420    }
421
422    /// Returns the ratio of bytes before the cursor over bytes within the range.
423    fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
424        let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
425            log::error!("bug: edit prediction cursor offset is not outside the excerpt");
426            return 0.0;
427        };
428        bytes_before_cursor as f32 / excerpt.range.len() as f32
429    }
430}
431
432fn node_line_start(node: Node) -> Point {
433    Point::new(node.start_position().row as u32, 0)
434}
435
436fn node_line_end(node: Node) -> Point {
437    Point::new(node.end_position().row as u32 + 1, 0)
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use util::test::{generate_marked_text, marked_text_offsets_by};

    /// Creates a Rust-language buffer holding `text` and returns a snapshot of it.
    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let entity = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
        entity.read_with(cx, |buffer, _| buffer.snapshot())
    }

    /// A minimal Rust language with the outline query loaded.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }

    /// Strips the markers from `text`, returning the plain text, the cursor offset (`ˇ`), and the
    /// expected excerpt range (`«` .. `»`).
    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
        let (unmarked, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
        let cursor = offsets[&'ˇ'][0];
        let expected = offsets[&'«'][0]..offsets[&'»'][0];
        (unmarked, cursor, expected)
    }

    /// Selects an excerpt for the marked text and checks it against the marked range, the size
    /// budget, and cursor containment.
    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
        let (plain_text, cursor, expected_range) = cursor_and_excerpt_range(text);

        let snapshot = create_buffer(&plain_text, cx);
        let query_point = cursor.to_point(&snapshot);

        let selected =
            EditPredictionExcerpt::select_from_buffer(query_point, &snapshot, &options, None)
                .expect("Should select an excerpt");
        let actual_marked =
            generate_marked_text(&plain_text, std::slice::from_ref(&selected.range), false);
        let expected_marked = generate_marked_text(&plain_text, &[expected_range], false);
        pretty_assertions::assert_eq!(actual_marked, expected_marked);
        assert!(selected.size <= options.max_bytes);
        assert!(selected.range.contains(&cursor));
    }

    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        check_example(
            EditPredictionExcerptOptions {
                max_bytes: 20,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            text,
            cx,
        );
    }

    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        check_example(
            EditPredictionExcerptOptions {
                max_bytes: 65,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            text,
            cx,
        );
    }

    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        check_example(
            EditPredictionExcerptOptions {
                max_bytes: 50,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
            },
            text,
            cx,
        );
    }

    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        check_example(
            EditPredictionExcerptOptions {
                max_bytes: 60,
                min_bytes: 45,
                target_before_cursor_over_total_bytes: 0.5,
            },
            text,
            cx,
        );
    }

    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        check_example(
            EditPredictionExcerptOptions {
                max_bytes: 120,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.6,
            },
            text,
            cx,
        );
    }
}
594    }
595}