merge_excerpts.rs

  1use cloud_llm_client::predict_edits_v3::{self, Excerpt};
  2use edit_prediction_context::Line;
  3use language::{BufferSnapshot, Point};
  4use std::ops::Range;
  5
  6pub fn merge_excerpts(
  7    buffer: &BufferSnapshot,
  8    sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
  9) -> Vec<Excerpt> {
 10    let mut output = Vec::new();
 11    let mut merged_ranges = Vec::<Range<Line>>::new();
 12
 13    for line_range in sorted_line_ranges {
 14        if let Some(last_line_range) = merged_ranges.last_mut()
 15            && line_range.start <= last_line_range.end
 16        {
 17            last_line_range.end = last_line_range.end.max(line_range.end);
 18            continue;
 19        }
 20        merged_ranges.push(line_range);
 21    }
 22
 23    let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
 24    let mut outline_items = outline_items.into_iter().peekable();
 25
 26    for range in merged_ranges {
 27        let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
 28
 29        while let Some(outline_item) = outline_items.peek() {
 30            if outline_item.range.start >= point_range.start {
 31                break;
 32            }
 33            if outline_item.range.end > point_range.start {
 34                let mut point_range = outline_item.source_range_for_text.clone();
 35                point_range.start.column = 0;
 36                point_range.end.column = buffer.line_len(point_range.end.row);
 37
 38                output.push(Excerpt {
 39                    start_line: Line(point_range.start.row),
 40                    text: buffer
 41                        .text_for_range(point_range.clone())
 42                        .collect::<String>()
 43                        .into(),
 44                })
 45            }
 46            outline_items.next();
 47        }
 48
 49        output.push(Excerpt {
 50            start_line: Line(point_range.start.row),
 51            text: buffer
 52                .text_for_range(point_range.clone())
 53                .collect::<String>()
 54                .into(),
 55        })
 56    }
 57
 58    output
 59}
 60
 61pub fn write_merged_excerpts(
 62    buffer: &BufferSnapshot,
 63    sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
 64    sorted_insertions: &[(predict_edits_v3::Point, &str)],
 65    output: &mut String,
 66) {
 67    cloud_zeta2_prompt::write_excerpts(
 68        merge_excerpts(buffer, sorted_line_ranges).iter(),
 69        sorted_insertions,
 70        Line(buffer.max_point().row),
 71        true,
 72        output,
 73    );
 74}
 75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
    use pretty_assertions::assert_eq;
    use util::test::marked_text_ranges;

    /// Table-driven check of `write_merged_excerpts` on Rust source.
    ///
    /// Each case pairs an annotated input with the exact expected rendering:
    /// - `«` / `»` mark the line ranges to excerpt;
    /// - an optional `ˇ` marks the cursor.
    #[gpui::test]
    fn test_rust(cx: &mut TestAppContext) {
        // Each entry: (annotated input, expected output). Expected lines are `N|text`, where
        // N is the 1-based row. A `»` at the start of a line excludes that line.
        // NOTE(review): the expected outputs below jump directly between non-adjacent excerpts
        // (e.g. `1|` to `3|`) with no separator line; confirm this matches how
        // `cloud_zeta2_prompt::write_excerpts` renders skipped rows.
        let table = [
            (
                indoc! {r#"
                    struct User {
                        first_name: String,
                    «    last_name: String,
                        ageˇ: u32,
                    »    email: String,
                        create_at: Instant,
                    }

                    impl User {
                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn full_name(&self) -> String {
                    «        format!("{} {}", self.first_name, self.last_name)
                    »    }
                    }
                "#},
                indoc! {r#"
                    1|struct User {
                    3|    last_name: String,
                    4|    age<|cursor|>: u32,
                    9|impl User {
                    14|    pub fn full_name(&self) -> String {
                    15|        format!("{} {}", self.first_name, self.last_name)
                "#},
            ),
            (
                indoc! {r#"
                    struct User {
                        first_name: String,
                    «    last_name: String,
                        age: u32,
                    }
                    »"#
                },
                indoc! {r#"
                    1|struct User {
                    3|    last_name: String,
                    4|    age: u32,
                    5|}
                "#},
            ),
        ];

        for (input, expected_output) in table {
            // Locate the cursor in the text with only the range markers removed. `find`
            // returns the first `ˇ`; with a single caret its byte offset equals its offset in
            // the final, marker-free buffer text.
            let input_without_ranges = input.replace(['«', '»'], "");
            let input_without_caret = input.replace('ˇ', "");
            let cursor_offset = input_without_ranges.find('ˇ');
            // Remove the `«`/`»` markers, yielding the plain text and the offset ranges they
            // delimited.
            let (input, ranges) = marked_text_ranges(&input_without_caret, false);
            let buffer =
                cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
            buffer.read_with(cx, |buffer, _cx| {
                // The cursor, if present, becomes a single `<|cursor|>` insertion at its
                // row/column.
                let insertions = cursor_offset
                    .map(|offset| {
                        let point = buffer.offset_to_point(offset);
                        vec![(
                            predict_edits_v3::Point {
                                line: Line(point.row),
                                column: point.column,
                            },
                            "<|cursor|>",
                        )]
                    })
                    .unwrap_or_default();
                // Reduce each marked offset range to a row range (columns are dropped).
                let ranges: Vec<Range<Line>> = ranges
                    .into_iter()
                    .map(|range| {
                        let point_range = range.to_point(&buffer);
                        Line(point_range.start.row)..Line(point_range.end.row)
                    })
                    .collect();

                let mut output = String::new();
                write_merged_excerpts(&buffer.snapshot(), ranges, &insertions, &mut output);
                assert_eq!(output, expected_output);
            });
        }
    }

    /// Builds a Rust `Language` backed by tree-sitter-rust with the outline query from the
    /// `languages` crate, so the buffer can report outline items in tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(language::tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}