1use language::{BufferSnapshot, LanguageId};
  2use std::ops::Range;
  3use text::{Point, ToOffset as _, ToPoint as _};
  4use tree_sitter::{Node, TreeCursor};
  5use util::RangeExt;
  6
  7use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
  8
  9// TODO:
 10//
 11// - Test parent signatures
 12//
 13// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt
 14// planning.
 15//
 16// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown
 17// paragraph).
 18//
 19// - Truncation of long lines.
 20//
 21// - Filter outer syntax layers that don't support edit prediction.
 22
/// Size and balance settings used when selecting an excerpt around the cursor.
#[derive(Debug, Clone, PartialEq)]
pub struct EditPredictionExcerptOptions {
    /// Limit for the number of bytes in the window around the cursor. A buffer that is no longer
    /// than this is used in its entirety. Excerpt sizes compared against this limit include the
    /// bytes of parent signatures.
    pub max_bytes: usize,
    /// Minimum number of bytes in the window around the cursor. When syntax tree selection results
    /// in an excerpt smaller than this, it will fall back on line-based selection.
    pub min_bytes: usize,
    /// Target ratio of bytes before the cursor divided by total bytes in the window.
    pub target_before_cursor_over_total_bytes: f32,
}
 33
 34// TODO: consider merging these
/// A region of a buffer chosen around a query position, together with the signatures of the
/// declarations that enclose it.
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
    /// Byte offsets of the excerpt body within the buffer.
    pub range: Range<usize>,
    /// Rows corresponding to `range`.
    pub line_range: Range<Line>,
    /// Enclosing declarations, each paired with the byte range of its signature.
    pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
    /// Length of `range` plus the summed lengths of the ranges in `parent_declarations`.
    pub size: usize,
}
 42
/// Owned text for an [`EditPredictionExcerpt`], as produced by [`EditPredictionExcerpt::text`].
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
    /// Buffer text covered by the excerpt's `range`.
    pub body: String,
    /// Buffer text for each entry of the excerpt's `parent_declarations`, in the same order.
    pub parent_signatures: Vec<String>,
    /// Id of the buffer's language, if the buffer has one.
    pub language_id: Option<LanguageId>,
}
 49
 50impl EditPredictionExcerpt {
 51    pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText {
 52        let body = buffer
 53            .text_for_range(self.range.clone())
 54            .collect::<String>();
 55        let parent_signatures = self
 56            .parent_declarations
 57            .iter()
 58            .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
 59            .collect();
 60        let language_id = buffer.language().map(|l| l.id());
 61        EditPredictionExcerptText {
 62            body,
 63            parent_signatures,
 64            language_id,
 65        }
 66    }
 67
 68    /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
 69    /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the
 70    /// cursor.
 71    ///
 72    /// When `index` is provided, the excerpt will include the signatures of parent outline items.
 73    ///
 74    /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based
 75    /// expansion.
 76    ///
 77    /// Returns `None` if the line around the cursor doesn't fit.
 78    pub fn select_from_buffer(
 79        query_point: Point,
 80        buffer: &BufferSnapshot,
 81        options: &EditPredictionExcerptOptions,
 82        syntax_index: Option<&SyntaxIndexState>,
 83    ) -> Option<Self> {
 84        if buffer.len() <= options.max_bytes {
 85            log::debug!(
 86                "using entire file for excerpt since source length ({}) <= window max bytes ({})",
 87                buffer.len(),
 88                options.max_bytes
 89            );
 90            let offset_range = 0..buffer.len();
 91            let line_range = Line(0)..Line(buffer.max_point().row);
 92            return Some(EditPredictionExcerpt::new(
 93                offset_range,
 94                line_range,
 95                Vec::new(),
 96            ));
 97        }
 98
 99        let query_offset = query_point.to_offset(buffer);
100        let query_line_range = query_point.row..query_point.row + 1;
101        let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
102            ..Point::new(query_line_range.end, 0).to_offset(buffer);
103        if query_range.len() >= options.max_bytes {
104            return None;
105        }
106
107        let parent_declarations = if let Some(syntax_index) = syntax_index {
108            syntax_index
109                .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
110                .collect()
111        } else {
112            Vec::new()
113        };
114
115        let excerpt_selector = ExcerptSelector {
116            query_offset,
117            query_range,
118            query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
119            parent_declarations: &parent_declarations,
120            buffer,
121            options,
122        };
123
124        if let Some(excerpt) = excerpt_selector.select_tree_sitter_nodes() {
125            if excerpt.size >= options.min_bytes {
126                return Some(excerpt);
127            }
128            log::debug!(
129                "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection",
130                excerpt.size,
131                options.min_bytes
132            );
133        } else {
134            log::debug!(
135                "couldn't find excerpt via tree-sitter, falling back on line-based selection"
136            );
137        }
138
139        excerpt_selector.select_lines()
140    }
141
142    fn new(
143        range: Range<usize>,
144        line_range: Range<Line>,
145        parent_declarations: Vec<(DeclarationId, Range<usize>)>,
146    ) -> Self {
147        let size = range.len()
148            + parent_declarations
149                .iter()
150                .map(|(_, range)| range.len())
151                .sum::<usize>();
152        Self {
153            range,
154            parent_declarations,
155            size,
156            line_range,
157        }
158    }
159
160    fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
161        if !new_range.contains_inclusive(&self.range) {
162            // this is an issue because parent_signature_ranges may be incorrect
163            log::error!("bug: with_expanded_range called with disjoint range");
164        }
165        let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
166        for (declaration_id, range) in &self.parent_declarations {
167            if !range.contains_inclusive(&new_range) {
168                break;
169            }
170            parent_declarations.push((*declaration_id, range.clone()));
171        }
172        Self::new(new_range, new_line_range, parent_declarations)
173    }
174
    /// Number of bytes in `size` that come from parent signatures, i.e. `size` minus the length
    /// of the body range.
    fn parent_signatures_size(&self) -> usize {
        self.size - self.range.len()
    }
178}
179
/// Borrowed inputs shared by the tree-sitter based and line-based selection strategies.
struct ExcerptSelector<'a> {
    /// Byte offset of the query point in the buffer.
    query_offset: usize,
    /// Byte range from the start of the query point's row to the start of the following row.
    query_range: Range<usize>,
    /// The query point's row up to (not including) the following row.
    query_line_range: Range<Line>,
    /// Declarations reported by the syntax index as containing `query_range`; empty when no
    /// index was supplied.
    parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
    /// Snapshot the excerpt is selected from.
    buffer: &'a BufferSnapshot,
    /// Size limits and target cursor ratio.
    options: &'a EditPredictionExcerptOptions,
}
188
189impl<'a> ExcerptSelector<'a> {
190    /// Finds the largest node that is smaller than the window size and contains `query_range`.
191    fn select_tree_sitter_nodes(&self) -> Option<EditPredictionExcerpt> {
192        let selected_layer_root = self.select_syntax_layer()?;
193        let mut cursor = selected_layer_root.walk();
194
195        loop {
196            let line_start = node_line_start(cursor.node());
197            let line_end = node_line_end(cursor.node());
198            let line_range = Line(line_start.row)..Line(line_end.row);
199            let excerpt_range =
200                line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
201            if excerpt_range.contains_inclusive(&self.query_range) {
202                let excerpt = self.make_excerpt(excerpt_range, line_range);
203                if excerpt.size <= self.options.max_bytes {
204                    return Some(self.expand_to_siblings(&mut cursor, excerpt));
205                }
206            } else {
207                // TODO: Should still be able to handle this case via AST nodes. For example, this
208                // can happen if the cursor is between two methods in a large class file.
209                return None;
210            }
211
212            if cursor
213                .goto_first_child_for_byte(self.query_range.start)
214                .is_none()
215            {
216                return None;
217            }
218        }
219    }
220
221    /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len.
222    fn select_syntax_layer(&self) -> Option<Node<'_>> {
223        let mut smallest_exceeding_max_len: Option<Node<'_>> = None;
224        let mut largest: Option<Node<'_>> = None;
225        for layer in self
226            .buffer
227            .syntax_layers_for_range(self.query_range.start..self.query_range.start, true)
228        {
229            let layer_range = layer.node().byte_range();
230            if !layer_range.contains_inclusive(&self.query_range) {
231                continue;
232            }
233
234            if layer_range.len() > self.options.max_bytes {
235                match &smallest_exceeding_max_len {
236                    None => smallest_exceeding_max_len = Some(layer.node()),
237                    Some(existing) => {
238                        if layer_range.len() < existing.byte_range().len() {
239                            smallest_exceeding_max_len = Some(layer.node());
240                        }
241                    }
242                }
243            } else {
244                match &largest {
245                    None => largest = Some(layer.node()),
246                    Some(existing) if layer_range.len() > existing.byte_range().len() => {
247                        largest = Some(layer.node())
248                    }
249                    _ => {}
250                }
251            }
252        }
253
254        smallest_exceeding_max_len.or(largest)
255    }
256
257    // motivation for this and `goto_previous_named_sibling` is to avoid including things like
258    // trailing unnamed "}" in body nodes
259    fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool {
260        while cursor.goto_next_sibling() {
261            if cursor.node().is_named() {
262                return true;
263            }
264        }
265        false
266    }
267
268    fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool {
269        while cursor.goto_previous_sibling() {
270            if cursor.node().is_named() {
271                return true;
272            }
273        }
274        false
275    }
276
    /// Grows `excerpt` over the named siblings of the node under `cursor`, one sibling at a time,
    /// for as long as the result stays within `max_bytes`.
    ///
    /// Two cursors walk away from the starting node: a clone moving forward and the passed-in
    /// cursor moving backward. Each round builds at most one candidate per direction. When both
    /// exist, `is_better_excerpt` picks the one whose before-cursor ratio is closer to the target.
    /// A direction stops for good once it runs out of named siblings or its next expansion would
    /// not fit.
    ///
    /// `cursor` is left wherever the backward walk ended.
    fn expand_to_siblings(
        &self,
        cursor: &mut TreeCursor,
        mut excerpt: EditPredictionExcerpt,
    ) -> EditPredictionExcerpt {
        let mut forward_cursor = cursor.clone();
        let backward_cursor = cursor;
        // Step both cursors off the starting node; a direction is "done" when it has no named
        // sibling to offer.
        let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
        let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor);
        loop {
            if backward_done && forward_done {
                break;
            }

            // Forward candidate: skip siblings that end within the lines already covered, then
            // try the first one that actually extends the excerpt's end.
            let mut forward = None;
            while !forward_done {
                let new_end_point = node_line_end(forward_cursor.node());
                let new_end = new_end_point.to_offset(&self.buffer);
                if new_end > excerpt.range.end {
                    let new_excerpt = excerpt.with_expanded_range(
                        excerpt.range.start..new_end,
                        excerpt.line_range.start..Line(new_end_point.row),
                    );
                    if new_excerpt.size <= self.options.max_bytes {
                        forward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting forward expansion, as it doesn't fit");
                        forward_done = true;
                        break;
                    }
                }
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            }

            // Backward candidate: mirror image of the forward search, extending the start.
            let mut backward = None;
            while !backward_done {
                let new_start_point = node_line_start(backward_cursor.node());
                let new_start = new_start_point.to_offset(&self.buffer);
                if new_start < excerpt.range.start {
                    let new_excerpt = excerpt.with_expanded_range(
                        new_start..excerpt.range.end,
                        Line(new_start_point.row)..excerpt.line_range.end,
                    );
                    if new_excerpt.size <= self.options.max_bytes {
                        backward = Some(new_excerpt);
                        break;
                    } else {
                        log::debug!("halting backward expansion, as it doesn't fit");
                        backward_done = true;
                        break;
                    }
                }
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }

            // Commit exactly one candidate per round. The other direction's cursor stays where it
            // is, so its candidate is rebuilt against the updated excerpt on the next round.
            let go_forward = match (forward, backward) {
                (Some(forward), Some(backward)) => {
                    let go_forward = self.is_better_excerpt(&forward, &backward);
                    if go_forward {
                        excerpt = forward;
                    } else {
                        excerpt = backward;
                    }
                    go_forward
                }
                (Some(forward), None) => {
                    log::debug!("expanding forward, since backward expansion has halted");
                    excerpt = forward;
                    true
                }
                (None, Some(backward)) => {
                    log::debug!("expanding backward, since forward expansion has halted");
                    excerpt = backward;
                    false
                }
                (None, None) => break,
            };

            // Only the cursor whose sibling was just absorbed moves on.
            if go_forward {
                forward_done = !Self::goto_next_named_sibling(&mut forward_cursor);
            } else {
                backward_done = !Self::goto_previous_named_sibling(backward_cursor);
            }
        }

        excerpt
    }
365
366    fn select_lines(&self) -> Option<EditPredictionExcerpt> {
367        // early return if line containing query_offset is already too large
368        let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
369        if excerpt.size > self.options.max_bytes {
370            log::debug!(
371                "excerpt for cursor line is {} bytes, which exceeds the window",
372                excerpt.size
373            );
374            return None;
375        }
376        let signatures_size = excerpt.parent_signatures_size();
377        let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size);
378
379        let before_bytes =
380            (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
381
382        let start_line = {
383            let offset = self.query_offset.saturating_sub(before_bytes);
384            let point = offset.to_point(self.buffer);
385            Line(point.row + 1)
386        };
387        let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
388        let end_line = {
389            let offset = start_offset + bytes_remaining;
390            let point = offset.to_point(self.buffer);
391            Line(point.row)
392        };
393        let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
394
395        // this could be expanded further since recalculated `signature_size` may be smaller, but
396        // skipping that for now for simplicity
397        //
398        // TODO: could also consider checking if lines immediately before / after fit.
399        let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
400        if excerpt.size > self.options.max_bytes {
401            log::error!(
402                "bug: line-based excerpt selection has size {}, \
403                which is {} bytes larger than the max size",
404                excerpt.size,
405                excerpt.size - self.options.max_bytes
406            );
407        }
408        return Some(excerpt);
409    }
410
411    fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
412        let parent_declarations = self
413            .parent_declarations
414            .iter()
415            .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
416            .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
417            .collect();
418        EditPredictionExcerpt::new(range, line_range, parent_declarations)
419    }
420
421    /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
422    fn is_better_excerpt(
423        &self,
424        forward: &EditPredictionExcerpt,
425        backward: &EditPredictionExcerpt,
426    ) -> bool {
427        let forward_ratio = self.excerpt_range_ratio(forward);
428        let backward_ratio = self.excerpt_range_ratio(backward);
429        let forward_delta =
430            (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
431        let backward_delta =
432            (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs();
433        let forward_is_better = forward_delta <= backward_delta;
434        if forward_is_better {
435            log::debug!(
436                "expanding forward since {} is closer than {} to {}",
437                forward_ratio,
438                backward_ratio,
439                self.options.target_before_cursor_over_total_bytes
440            );
441        } else {
442            log::debug!(
443                "expanding backward since {} is closer than {} to {}",
444                backward_ratio,
445                forward_ratio,
446                self.options.target_before_cursor_over_total_bytes
447            );
448        }
449        forward_is_better
450    }
451
452    /// Returns the ratio of bytes before the cursor over bytes within the range.
453    fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 {
454        let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else {
455            log::error!("bug: edit prediction cursor offset is not outside the excerpt");
456            return 0.0;
457        };
458        bytes_before_cursor as f32 / excerpt.range.len() as f32
459    }
460}
461
462fn node_line_start(node: Node) -> Point {
463    Point::new(node.start_position().row as u32, 0)
464}
465
466fn node_line_end(node: Node) -> Point {
467    Point::new(node.end_position().row as u32 + 1, 0)
468}
469
470#[cfg(test)]
471mod tests {
472    use super::*;
473    use gpui::{AppContext, TestAppContext};
474    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
475    use util::test::{generate_marked_text, marked_text_offsets_by};
476
477    fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
478        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
479        buffer.read_with(cx, |buffer, _| buffer.snapshot())
480    }
481
482    fn rust_lang() -> Language {
483        Language::new(
484            LanguageConfig {
485                name: "Rust".into(),
486                matcher: LanguageMatcher {
487                    path_suffixes: vec!["rs".to_string()],
488                    ..Default::default()
489                },
490                ..Default::default()
491            },
492            Some(tree_sitter_rust::LANGUAGE.into()),
493        )
494        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
495        .unwrap()
496    }
497
498    fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
499        let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']);
500        (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0])
501    }
502
503    fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) {
504        let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text);
505
506        let buffer = create_buffer(&text, cx);
507        let cursor_point = cursor.to_point(&buffer);
508
509        let excerpt =
510            EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
511                .expect("Should select an excerpt");
512        pretty_assertions::assert_eq!(
513            generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
514            generate_marked_text(&text, &[expected_excerpt], false)
515        );
516        assert!(excerpt.size <= options.max_bytes);
517        assert!(excerpt.range.contains(&cursor));
518    }
519
520    #[gpui::test]
521    fn test_ast_based_selection_current_node(cx: &mut TestAppContext) {
522        zlog::init_test();
523        let text = r#"
524fn main() {
525    let x = 1;
526«    let ˇy = 2;
527»    let z = 3;
528}"#;
529
530        let options = EditPredictionExcerptOptions {
531            max_bytes: 20,
532            min_bytes: 10,
533            target_before_cursor_over_total_bytes: 0.5,
534        };
535
536        check_example(options, text, cx);
537    }
538
539    #[gpui::test]
540    fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) {
541        zlog::init_test();
542        let text = r#"
543fn foo() {}
544
545«fn main() {
546    let x = 1;
547    let ˇy = 2;
548    let z = 3;
549}
550»
551fn bar() {}"#;
552
553        let options = EditPredictionExcerptOptions {
554            max_bytes: 65,
555            min_bytes: 10,
556            target_before_cursor_over_total_bytes: 0.5,
557        };
558
559        check_example(options, text, cx);
560    }
561
562    #[gpui::test]
563    fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) {
564        zlog::init_test();
565        let text = r#"
566fn main() {
567«    let x = 1;
568    let ˇy = 2;
569    let z = 3;
570»}"#;
571
572        let options = EditPredictionExcerptOptions {
573            max_bytes: 50,
574            min_bytes: 10,
575            target_before_cursor_over_total_bytes: 0.5,
576        };
577
578        check_example(options, text, cx);
579    }
580
581    #[gpui::test]
582    fn test_line_based_selection(cx: &mut TestAppContext) {
583        zlog::init_test();
584        let text = r#"
585fn main() {
586    let x = 1;
587«    if true {
588        let ˇy = 2;
589    }
590    let z = 3;
591»}"#;
592
593        let options = EditPredictionExcerptOptions {
594            max_bytes: 60,
595            min_bytes: 45,
596            target_before_cursor_over_total_bytes: 0.5,
597        };
598
599        check_example(options, text, cx);
600    }
601
602    #[gpui::test]
603    fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) {
604        zlog::init_test();
605        let text = r#"
606    fn main() {
607«        let a = 1;
608        let b = 2;
609        let c = 3;
610        let ˇd = 4;
611        let e = 5;
612        let f = 6;
613»
614        let g = 7;
615    }"#;
616
617        let options = EditPredictionExcerptOptions {
618            max_bytes: 120,
619            min_bytes: 10,
620            target_before_cursor_over_total_bytes: 0.6,
621        };
622
623        check_example(options, text, cx);
624    }
625}