1use cloud_llm_client::predict_edits_v3::{self, Excerpt};
2use edit_prediction_context::Line;
3use language::{BufferSnapshot, Point};
4use std::ops::Range;
5
6pub fn merge_excerpts(
7 buffer: &BufferSnapshot,
8 sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
9) -> Vec<Excerpt> {
10 let mut output = Vec::new();
11 let mut merged_ranges = Vec::<Range<Line>>::new();
12
13 for line_range in sorted_line_ranges {
14 if let Some(last_line_range) = merged_ranges.last_mut()
15 && line_range.start <= last_line_range.end
16 {
17 last_line_range.end = last_line_range.end.max(line_range.end);
18 continue;
19 }
20 merged_ranges.push(line_range);
21 }
22
23 let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
24 let mut outline_items = outline_items.into_iter().peekable();
25
26 for range in merged_ranges {
27 let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
28
29 while let Some(outline_item) = outline_items.peek() {
30 if outline_item.range.start >= point_range.start {
31 break;
32 }
33 if outline_item.range.end > point_range.start {
34 let mut point_range = outline_item.source_range_for_text.clone();
35 point_range.start.column = 0;
36 point_range.end.column = buffer.line_len(point_range.end.row);
37
38 output.push(Excerpt {
39 start_line: Line(point_range.start.row),
40 text: buffer
41 .text_for_range(point_range.clone())
42 .collect::<String>()
43 .into(),
44 })
45 }
46 outline_items.next();
47 }
48
49 output.push(Excerpt {
50 start_line: Line(point_range.start.row),
51 text: buffer
52 .text_for_range(point_range.clone())
53 .collect::<String>()
54 .into(),
55 })
56 }
57
58 output
59}
60
61pub fn write_merged_excerpts(
62 buffer: &BufferSnapshot,
63 sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
64 sorted_insertions: &[(predict_edits_v3::Point, &str)],
65 output: &mut String,
66) {
67 cloud_zeta2_prompt::write_excerpts(
68 merge_excerpts(buffer, sorted_line_ranges).iter(),
69 sorted_insertions,
70 Line(buffer.max_point().row),
71 true,
72 output,
73 );
74}
75
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
    use pretty_assertions::assert_eq;
    use util::test::marked_text_ranges;

    /// Table-driven check of `write_merged_excerpts` on Rust source.
    ///
    /// Each entry is `(input, expected_output)`. In `input`, `«`…`»` delimit the
    /// line ranges to include, and an optional `ˇ` marks where the
    /// `<|cursor|>` insertion is placed. `expected_output` is the rendered text.
    #[gpui::test]
    fn test_rust(cx: &mut TestAppContext) {
        let table = [
            (
                indoc! {r#"
                    struct User {
                        first_name: String,
                    «    last_name: String,
                        ageˇ: u32,
                    »    email: String,
                        create_at: Instant,
                    }

                    impl User {
                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn full_name(&self) -> String {
                    «        format!("{} {}", self.first_name, self.last_name)
                    »    }
                    }
                "#},
                indoc! {r#"
                    1|struct User {
                    …
                    3|    last_name: String,
                    4|    age<|cursor|>: u32,
                    …
                    9|impl User {
                    …
                    14|    pub fn full_name(&self) -> String {
                    15|        format!("{} {}", self.first_name, self.last_name)
                    …
                "#},
            ),
            (
                indoc! {r#"
                    struct User {
                        first_name: String,
                    «    last_name: String,
                        age: u32,
                    }
                    »"#
                },
                indoc! {r#"
                    1|struct User {
                    …
                    3|    last_name: String,
                    4|    age: u32,
                    5|}
                "#},
            ),
        ];

        for (input, expected_output) in table {
            let input_without_ranges = input.replace(['«', '»'], "");
            let input_without_caret = input.replace('ˇ', "");
            // Offset of `ˇ` after removing only the range markers. The text
            // before `ˇ` is the same in the fully stripped buffer text, so this
            // is also the cursor's offset in the buffer.
            let cursor_offset = input_without_ranges.find('ˇ');
            let (input, ranges) = marked_text_ranges(&input_without_caret, false);
            let buffer =
                cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
            buffer.read_with(cx, |buffer, _cx| {
                // At most one `<|cursor|>` insertion, at the `ˇ` position.
                let insertions = cursor_offset
                    .map(|offset| {
                        let point = buffer.offset_to_point(offset);
                        vec![(
                            predict_edits_v3::Point {
                                line: Line(point.row),
                                column: point.column,
                            },
                            "<|cursor|>",
                        )]
                    })
                    .unwrap_or_default();
                // Convert the marked byte ranges into row ranges (start row..end row).
                let ranges: Vec<Range<Line>> = ranges
                    .into_iter()
                    .map(|range| {
                        let point_range = range.to_point(&buffer);
                        Line(point_range.start.row)..Line(point_range.end.row)
                    })
                    .collect();

                let mut output = String::new();
                write_merged_excerpts(&buffer.snapshot(), ranges, &insertions, &mut output);
                assert_eq!(output, expected_output);
            });
        }
    }

    /// Builds a Rust `Language` with the tree-sitter Rust grammar and the outline
    /// query from the languages crate, so the buffer yields outline items.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(language::tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}