excerpt.rs

  1use language::{BufferSnapshot, LanguageId};
  2use std::ops::Range;
  3use text::{Point, ToOffset as _, ToPoint as _};
  4use tree_sitter::{Node, TreeCursor};
  5use util::RangeExt;
  6
  7use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
  8
  9// TODO:
 10//
 11// - Test parent signatures
 12//
 13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
 14// planning.
 15//
 16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
 17// paragraph).
 18//
 19// - Truncation of long lines.
 20//
 21// - Filter outer syntax layers that don't support edit prediction.
 22
 23#[derive(Debug, Clone, PartialEq)]
 24pub struct EditPredictionExcerptOptions {
 25    /// Limit for the number of bytes in the window around the cursor.
 26    pub max_bytes: usize,
 27    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
 28    /// in an excerpt smaller than this, it will fall back on line-based selection.
 29    pub min_bytes: usize,
 30    /// Target ratio of bytes before the cursor divided by total bytes in the window.
 31    pub target_before_cursor_over_total_bytes: f32,
 32}
 33
 34// TODO: consider merging these
 35#[derive(Debug, Clone)]
 36pub struct EditPredictionExcerpt {
 37    pub range: Range<usize>,
 38    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
 39    pub size: usize,
 40}
 41
 42#[derive(Debug, Clone)]
 43pub struct EditPredictionExcerptText {
 44    pub body: String,
 45    pub parent_signatures: Vec<String>,
 46    pub language_id: Option<LanguageId>,
 47}
 48
 49impl EditPredictionExcerpt {
 50    pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
 51        let body = buffer
 52            .text_for_range(self.range.clone())
 53            .collect::<String>();
 54        let parent_signatures = self
 55            .parent_declarations
 56            .iter()
 57            .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
 58            .collect();
 59        let language_id = buffer.language().map(|l| l.id());
 60        EditPredictionExcerptText {
 61            body,
 62            parent_signatures,
 63            language_id,
 64        }
 65    }
 66
 67    /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
 68    /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
 69    /// cursor.
 70    ///
 71    /// When `index` is provided, the excerpt will include the signatures of parent outline items.
 72    ///
 73    /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
 74    /// expansion.
 75    ///
 76    /// Returns `None` if the line around the cursor doesn't fit.
 77    pub fn select_from_buffer(
 78        query_point: Point,
 79        buffer: &BufferSnapshot,
 80        options: &EditPredictionExcerptOptions,
 81        syntax_index: Option<&SyntaxIndexState>,
 82    ) -> Option<Self> {
 83        if buffer.len() <= options.max_bytes {
 84            log::debug!(
 85                "using entire file for excerpt since source length ({}) <= window max bytes ({})",
 86                buffer.len(),
 87                options.max_bytes
 88            );
 89            return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
 90        }
 91
 92        let query_offset = query_point.to_offset(buffer);
 93        let query_range = Point::new(query_point.row, 0).to_offset(buffer)
 94            ..Point::new(query_point.row + 1, 0).to_offset(buffer);
 95        if query_range.len() >= options.max_bytes {
 96            return None;
 97        }
 98
 99        let parent_declarations = if let Some(syntax_index) = syntax_index {
100            syntax_index
101                .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
102                .collect()
103        } else {
104            Vec::new()
105        };
106
107        let excerpt_selector = ExcerptSelector {
108            query_offset,
109            query_range,
110            parent_declarations: &parent_declarations,
111            buffer,
112            options,
113        };
114
115        if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
116            if excerpt.size >= options.min_bytes {
117                return Some(excerpt);
118            }
119            log::debug!(
120                "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
121                excerpt.size,
122                options.min_bytes
123            );
124        } else {
125            log::debug!(
126                "couldn't find excerpt via tree-sitter, falling back on line-based selection"
127            );
128        }
129
130        excerpt_selector.select_lines()
131    }
132
133    fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
134        let size = range.len()
135            + parent_declarations
136                .iter()
137                .map(|(_, range)| range.len())
138                .sum::<usize>();
139        Self {
140            range,
141            parent_declarations,
142            size,
143        }
144    }
145
146    fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
147        if !new_range.contains_inclusive(&self.range) {
148            // this is an issue because parent_signature_ranges may be incorrect
149            log::error!("bug: with_expanded_range called with disjoint range");
150        }
151        let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
152        for (declaration_id, range) in &self.parent_declarations {
153            if !range.contains_inclusive(&new_range) {
154                break;
155            }
156            parent_declarations.push((*declaration_id, range.clone()));
157        }
158        Self::new(new_range, parent_declarations)
159    }
160
161    fn parent_signatures_size(&self) -> usize {
162        self.size - self.range.len()
163    }
164}
165
166struct ExcerptSelector<'a> {
167    query_offset: usize,
168    query_range: Range<usize>,
169    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
170    buffer: &'a BufferSnapshot,
171    options: &'a EditPredictionExcerptOptions,
172}
173
174impl<'a> ExcerptSelector<'a> {
175    /// Finds the largest node that is smaller than the window size and contains `query_range`.
176    fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
177        let selected_layer_root = self.select_syntax_layer()?;
178        let mut cursor = selected_layer_root.walk();
179
180        loop {
181            let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
182                ..node_line_end(cursor.node()).to_offset(&self.buffer);
183            if excerpt_range.contains_inclusive(&self.query_range) {
184                let excerpt = self.make_excerpt(excerpt_range);
185                if excerpt.size <= self.options.max_bytes {
186                    return Some(self.expand_to_siblings(&mut cursor, excerpt));
187                }
188            } else {
189                // TODO: Should still be able to handle this case via AST nodes. For example, this
190                // can happen if the cursor is between two methods in a large class file.
191                return None;
192            }
193
194            if cursor
195                .goto_first_child_for_byte(self.query_range.start)
196                .is_none()
197            {
198                return None;
199            }
200        }
201    }
202
203    /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
204    fn select_syntax_layer(&self) -> Option<Node<'_>> {
205        let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
206        let mut largest: Option<Node<'_>> = None;
207        for layer in self
208            .buffer
209            .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
210        {
211            let layer_range = layer.node().byte_range();
212            if !layer_range.contains_inclusive(&self.query_range) {
213                continue;
214            }
215
216            if layer_range.len() > self.options.max_bytes {
217                match &smallest_exceeding_max_len {
218                    None => smallest_exceeding_max_len = Some(layer.node()),
219                    Some(existing) => {
220                        if layer_range.len() < existing.byte_range().len() {
221                            smallest_exceeding_max_len = Some(layer.node());
222                        }
223                    }
224                }
225            } else {
226                match &largest {
227                    None => largest = Some(layer.node()),
228                    Some(existing) if layer_range.len() > existing.byte_range().len() => {
229                        largest = Some(layer.node())
230                    }
231                    _ => {}
232                }
233            }
234        }
235
236        smallest_exceeding_max_len.or(largest)
237    }
238
239    // motivation for this and `goto_previous_named_sibling` is to avoid including things like
240    // trailing unnamed "}" in body nodes
241    fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
242        while cursor.goto_next_sibling() {
243            if cursor.node().is_named() {
244                return true;
245            }
246        }
247        false
248    }
249
250    fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
251        while cursor.goto_previous_sibling() {
252            if cursor.node().is_named() {
253                return true;
254            }
255        }
256        false
257    }
258
259    fn expand_to_siblings(
260        &self,
261        cursor: &mut TreeCursor,
262        mut excerpt: EditPredictionExcerpt,
263    ) -> EditPredictionExcerpt {
264        let mut forward_cursor = cursor.clone();
265        let backward_cursor = cursor;
266        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
267        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
268        loop {
269            if backward_done && forward_done {
270                break;
271            }
272
273            let mut forward = None;
274            while !forward_done {
275                let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
276                if new_end > excerpt.range.end {
277                    let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
278                    if new_excerpt.size <= self.options.max_bytes {
279                        forward = Some(new_excerpt);
280                        break;
281                    } else {
282                        log::debug!("halting forward expansion, as it doesn't fit");
283                        forward_done = true;
284                        break;
285                    }
286                }
287                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
288            }
289
290            let mut backward = None;
291            while !backward_done {
292                let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
293                if new_start < excerpt.range.start {
294                    let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
295                    if new_excerpt.size <= self.options.max_bytes {
296                        backward = Some(new_excerpt);
297                        break;
298                    } else {
299                        log::debug!("halting backward expansion, as it doesn't fit");
300                        backward_done = true;
301                        break;
302                    }
303                }
304                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
305            }
306
307            let go_forward = match (forward, backward) {
308                (Some(forward), Some(backward)) => {
309                    let go_forward = self.is_better_excerpt(&forward, &backward);
310                    if go_forward {
311                        excerpt = forward;
312                    } else {
313                        excerpt = backward;
314                    }
315                    go_forward
316                }
317                (Some(forward), None) => {
318                    log::debug!("expanding forward, since backward expansion has halted");
319                    excerpt = forward;
320                    true
321                }
322                (None, Some(backward)) => {
323                    log::debug!("expanding backward, since forward expansion has halted");
324                    excerpt = backward;
325                    false
326                }
327                (None, None) => break,
328            };
329
330            if go_forward {
331                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
332            } else {
333                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
334            }
335        }
336
337        excerpt
338    }
339
340    fn select_lines(&self) -> Option<EditPredictionExcerpt> {
341        // early return if line containing query_offset is already too large
342        let excerpt = self.make_excerpt(self.query_range.clone());
343        if excerpt.size > self.options.max_bytes {
344            log::debug!(
345                "excerpt for cursor line is {} bytes, which exceeds the window",
346                excerpt.size
347            );
348            return None;
349        }
350        let signatures_size = excerpt.parent_signatures_size();
351        let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
352
353        let before_bytes =
354            (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
355
356        let start_point = {
357            let offset = self.query_offset.saturating_sub(before_bytes);
358            let point = offset.to_point(self.buffer);
359            Point::new(point.row + 1, 0)
360        };
361        let start_offset = start_point.to_offset(&self.buffer);
362        let end_point = {
363            let offset = start_offset + bytes_remaining;
364            let point = offset.to_point(self.buffer);
365            Point::new(point.row, 0)
366        };
367        let end_offset = end_point.to_offset(&self.buffer);
368
369        // this could be expanded further since recalculated `signature_size` may be smaller, but
370        // skipping that for now for simplicity
371        //
372        // TODO: could also consider checking if lines immediately before / after fit.
373        let excerpt = self.make_excerpt(start_offset..end_offset);
374        if excerpt.size > self.options.max_bytes {
375            log::error!(
376                "bug: line-based excerpt selection has size {}, \
377                which is {} bytes larger than the max size",
378                excerpt.size,
379                excerpt.size - self.options.max_bytes
380            );
381        }
382        return Some(excerpt);
383    }
384
385    fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
386        let parent_declarations = self
387            .parent_declarations
388            .iter()
389            .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
390            .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
391            .collect();
392        EditPredictionExcerpt::new(range, parent_declarations)
393    }
394
395    /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
396    fn is_better_excerpt(
397        &self,
398        forward: &EditPredictionExcerpt,
399        backward: &EditPredictionExcerpt,
400    ) -> bool {
401        let forward_ratio = self.excerpt_range_ratio(forward);
402        let backward_ratio = self.excerpt_range_ratio(backward);
403        let forward_delta =
404            (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
405        let backward_delta =
406            (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
407        let forward_is_better = forward_delta <= backward_delta;
408        if forward_is_better {
409            log::debug!(
410                "expanding forward since {} is closer than {} to {}",
411                forward_ratio,
412                backward_ratio,
413                self.options.target_before_cursor_over_total_bytes
414            );
415        } else {
416            log::debug!(
417                "expanding backward since {} is closer than {} to {}",
418                backward_ratio,
419                forward_ratio,
420                self.options.target_before_cursor_over_total_bytes
421            );
422        }
423        forward_is_better
424    }
425
426    /// Returns the ratio of bytes before the cursor over bytes within the range.
427    fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
428        let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
429            log::error!("bug: edit prediction cursor offset is not outside the excerpt");
430            return 0.0;
431        };
432        bytes_before_cursor as f32 / excerpt.range.len() as f32
433    }
434}
435
436fn node_line_start(node: Node) -> Point {
437    Point::new(node.start_position().row as u32, 0)
438}
439
440fn node_line_end(node: Node) -> Point {
441    Point::new(node.end_position().row as u32 + 1, 0)
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{AppContext, TestAppContext};
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use util::test::{generate_marked_text, marked_text_offsets_by};

    /// Creates a local Rust buffer containing `text` and returns its snapshot.
    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
        buffer.read_with(cx, |buffer, _| buffer.snapshot())
    }

    /// Rust language definition using the tree-sitter Rust grammar and the outline query from
    /// the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Strips the markers from `text`. Returns the clean text, the cursor offset (`ˇ`), and the
    /// expected excerpt range (`«` .. `»`).
    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
    }

    /// Selects an excerpt (without a syntax index) at the marked cursor. Asserts that it matches
    /// the marked range, fits within `max_bytes`, and contains the cursor.
    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);

        let buffer = create_buffer(&text, cx);
        let cursor_point = cursor.to_point(&buffer);

        let excerpt =
            EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
                .expect("Should select an excerpt");
        // Compare as marked text so a failure shows the selected range inline.
        pretty_assertions::assert_eq!(
            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
            generate_marked_text(&text, &[expected_excerpt], false)
        );
        assert!(excerpt.size <= options.max_bytes);
        assert!(excerpt.range.contains(&cursor));
    }

    #[gpui::test]
    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // Only the cursor's own statement line fits in 20 bytes; adding either neighboring line
        // would exceed the limit.
        let text = r#"
fn main() {
    let x = 1;
«    let ˇy = 2;
»    let z = 3;
}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 20,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
        zlog::init_test();
        // The enclosing function fits in 65 bytes, but adding either neighboring function (with
        // its separating blank line) would not.
        let text = r#"
fn foo() {}

«fn main() {
    let x = 1;
    let ˇy = 2;
    let z = 3;
}
»
fn bar() {}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 65,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
        zlog::init_test();
        // The whole function exceeds 50 bytes. Starting from the cursor's statement, sibling
        // statements are added on both sides while the result stays within the limit.
        let text = r#"
fn main() {
«    let x = 1;
    let ˇy = 2;
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 50,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_line_based_selection(cx: &mut TestAppContext) {
        zlog::init_test();
        // NOTE(review): by a manual trace, the tree-sitter path already selects this 54-byte
        // range (the `if` statement plus `let z`), which meets `min_bytes: 45`. That means this
        // may not actually exercise `select_lines`; confirm.
        let text = r#"
fn main() {
    let x = 1;
«    if true {
        let ˇy = 2;
    }
    let z = 3;
»}"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 60,
            min_bytes: 45,
            target_before_cursor_over_total_bytes: 0.5,
        };

        check_example(options, text, cx);
    }

    #[gpui::test]
    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
        zlog::init_test();
        // With a 0.6 target ratio, more of the window should precede the cursor than follow it.
        // NOTE(review): by a manual trace, sibling expansion on the tree-sitter path produces
        // this same range (well above `min_bytes: 10`), so `select_lines` may not be reached;
        // confirm.
        let text = r#"
    fn main() {
«        let a = 1;
        let b = 2;
        let c = 3;
        let ˇd = 4;
        let e = 5;
        let f = 6;
»
        let g = 7;
    }"#;

        let options = EditPredictionExcerptOptions {
            max_bytes: 120,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.6,
        };

        check_example(options, text, cx);
    }
}