merge_excerpts.rs

  1use cloud_llm_client::predict_edits_v3::Excerpt;
  2use edit_prediction_context::Line;
  3use language::{BufferSnapshot, Point};
  4use std::ops::Range;
  5
  6pub fn merge_excerpts(
  7    buffer: &BufferSnapshot,
  8    sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
  9) -> Vec<Excerpt> {
 10    let mut output = Vec::new();
 11    let mut merged_ranges = Vec::<Range<Line>>::new();
 12
 13    for line_range in sorted_line_ranges {
 14        if let Some(last_line_range) = merged_ranges.last_mut()
 15            && line_range.start <= last_line_range.end
 16        {
 17            last_line_range.end = last_line_range.end.max(line_range.end);
 18            continue;
 19        }
 20        merged_ranges.push(line_range);
 21    }
 22
 23    let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
 24    let mut outline_items = outline_items.into_iter().peekable();
 25
 26    for range in merged_ranges {
 27        let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
 28
 29        while let Some(outline_item) = outline_items.peek() {
 30            if outline_item.range.start >= point_range.start {
 31                break;
 32            }
 33            if outline_item.range.end > point_range.start {
 34                let mut point_range = outline_item.source_range_for_text.clone();
 35                point_range.start.column = 0;
 36                point_range.end.column = buffer.line_len(point_range.end.row);
 37
 38                output.push(Excerpt {
 39                    start_line: Line(point_range.start.row),
 40                    text: buffer
 41                        .text_for_range(point_range.clone())
 42                        .collect::<String>()
 43                        .into(),
 44                })
 45            }
 46            outline_items.next();
 47        }
 48
 49        output.push(Excerpt {
 50            start_line: Line(point_range.start.row),
 51            text: buffer
 52                .text_for_range(point_range.clone())
 53                .collect::<String>()
 54                .into(),
 55        })
 56    }
 57
 58    output
 59}
 60
 61#[cfg(test)]
 62mod tests {
 63    use std::sync::Arc;
 64
 65    use super::*;
 66    use cloud_llm_client::predict_edits_v3;
 67    use gpui::{TestAppContext, prelude::*};
 68    use indoc::indoc;
 69    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
 70    use pretty_assertions::assert_eq;
 71    use util::test::marked_text_ranges;
 72
 73    #[gpui::test]
 74    fn test_rust(cx: &mut TestAppContext) {
 75        let table = [
 76            (
 77                indoc! {r#"
 78                    struct User {
 79                        first_name: String,
 80                    «    last_name: String,
 81                        ageˇ: u32,
 82                    »    email: String,
 83                        create_at: Instant,
 84                    }
 85
 86                    impl User {
 87                        pub fn first_name(&self) -> String {
 88                            self.first_name.clone()
 89                        }
 90
 91                        pub fn full_name(&self) -> String {
 92                    «        format!("{} {}", self.first_name, self.last_name)
 93                    »    }
 94                    }
 95                "#},
 96                indoc! {r#"
 97                    1|struct User {
 98 99                    3|    last_name: String,
100                    4|    age<|cursor|>: u32,
101102                    9|impl User {
103104                    14|    pub fn full_name(&self) -> String {
105                    15|        format!("{} {}", self.first_name, self.last_name)
106107                "#},
108            ),
109            (
110                indoc! {r#"
111                    struct User {
112                        first_name: String,
113                    «    last_name: String,
114                        age: u32,
115                    }
116                    »"#
117                },
118                indoc! {r#"
119                    1|struct User {
120121                    3|    last_name: String,
122                    4|    age: u32,
123                    5|}
124                "#},
125            ),
126        ];
127
128        for (input, expected_output) in table {
129            let input_without_ranges = input.replace(['«', '»'], "");
130            let input_without_caret = input.replace('ˇ', "");
131            let cursor_offset = input_without_ranges.find('ˇ');
132            let (input, ranges) = marked_text_ranges(&input_without_caret, false);
133            let buffer =
134                cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
135            buffer.read_with(cx, |buffer, _cx| {
136                let insertions = cursor_offset
137                    .map(|offset| {
138                        let point = buffer.offset_to_point(offset);
139                        vec![(
140                            predict_edits_v3::Point {
141                                line: Line(point.row),
142                                column: point.column,
143                            },
144                            "<|cursor|>",
145                        )]
146                    })
147                    .unwrap_or_default();
148                let ranges: Vec<Range<Line>> = ranges
149                    .into_iter()
150                    .map(|range| {
151                        let point_range = range.to_point(&buffer);
152                        Line(point_range.start.row)..Line(point_range.end.row)
153                    })
154                    .collect();
155
156                let mut output = String::new();
157                cloud_zeta2_prompt::write_excerpts(
158                    merge_excerpts(&buffer.snapshot(), ranges).iter(),
159                    &insertions,
160                    Line(buffer.max_point().row),
161                    true,
162                    &mut output,
163                );
164                assert_eq!(output, expected_output);
165            });
166        }
167    }
168
169    fn rust_lang() -> Language {
170        Language::new(
171            LanguageConfig {
172                name: "Rust".into(),
173                matcher: LanguageMatcher {
174                    path_suffixes: vec!["rs".to_string()],
175                    ..Default::default()
176                },
177                ..Default::default()
178            },
179            Some(language::tree_sitter_rust::LANGUAGE.into()),
180        )
181        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
182        .unwrap()
183    }
184}