1use std::{borrow::Cow, ops::Range, sync::Arc};
2
3use cloud_llm_client::predict_edits_v3;
4use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
5use uuid::Uuid;
6
/// A set of edits suggested for a buffer, together with the buffer snapshot
/// they were computed against.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// The suggested replacements: each entry replaces the text in the anchor
    /// range with the given string. Anchors are resolved against `snapshot`
    /// when interpolating.
    pub edits: Arc<[(Range<Anchor>, String)]>,
    /// Snapshot of the buffer the edits were computed against; used as the
    /// "old" snapshot by [`EditPrediction::interpolate`].
    pub snapshot: BufferSnapshot,
    /// Precomputed preview of applying `edits` to the buffer.
    pub edit_preview: EditPreview,
}
14
15impl EditPrediction {
16 pub fn interpolate(
17 &self,
18 new_snapshot: &BufferSnapshot,
19 ) -> Option<Vec<(Range<Anchor>, String)>> {
20 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
21 }
22}
23
/// Identifier of an [`EditPrediction`], backed by a [`Uuid`].
///
/// It converts to a `gpui::ElementId` and displays as the wrapped UUID.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
26
27impl From<Uuid> for EditPredictionId {
28 fn from(value: Uuid) -> Self {
29 EditPredictionId(value)
30 }
31}
32
33impl From<EditPredictionId> for gpui::ElementId {
34 fn from(value: EditPredictionId) -> Self {
35 gpui::ElementId::Uuid(value.0)
36 }
37}
38
39impl std::fmt::Display for EditPredictionId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
/// Rebases `current_edits`, which were computed against `old_snapshot`, onto
/// `new_snapshot` by walking the user edits made since `old_snapshot.version`.
///
/// NOTE(review): `current_edits` is presumably sorted by position and
/// non-overlapping. The single forward pass below relies on that; confirm
/// with callers.
///
/// For every user edit, in order:
/// - Model edits that end strictly before the user edit starts are kept
///   unchanged.
/// - Suppose the next model edit replaces exactly the same old range as the
///   user edit, and the text the user typed is a prefix of the model's
///   replacement. Then only the untyped remainder is kept, as an insertion at
///   the end of that range. If the user typed all of it, nothing is kept.
/// - In every other case the prediction is treated as stale and `None` is
///   returned. This covers:
///   - a user edit that overlaps or merely touches a model edit without
///     matching its range exactly;
///   - a user edit that diverges from the predicted text;
///   - a user edit located after the last model edit.
///
/// Model edits after the last user edit are kept unchanged. Returns `None` if
/// no edits remain at all.
pub fn interpolate_edits(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over model edits lying entirely before this user edit; their
        // anchors are reused as-is.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            // The user edited exactly the range this model edit targets.
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // The user typed a prefix of the predicted text: keep only the
                // part that has not been typed yet.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        // Right-biased anchor at the end of the old range, so
                        // the remainder lands after what the user typed.
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // This user edit does not line up with the prediction, so the
        // prediction can no longer be interpolated.
        return None;
    }

    // Model edits after the last user edit are untouched.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
90
91pub fn edits_from_response(
92 edits: &[predict_edits_v3::Edit],
93 snapshot: &BufferSnapshot,
94) -> Arc<[(Range<Anchor>, String)]> {
95 edits
96 .iter()
97 .flat_map(|edit| {
98 // TODO multi-file edits
99 let old_text = snapshot.text_for_range(edit.range.clone());
100
101 excerpt_edits_from_response(
102 old_text.collect::<Cow<str>>(),
103 &edit.content,
104 edit.range.start,
105 &snapshot,
106 )
107 })
108 .collect::<Vec<_>>()
109 .into()
110}
111
112fn excerpt_edits_from_response(
113 old_text: Cow<str>,
114 new_text: &str,
115 offset: usize,
116 snapshot: &BufferSnapshot,
117) -> impl Iterator<Item = (Range<Anchor>, String)> {
118 text_diff(&old_text, new_text)
119 .into_iter()
120 .map(move |(mut old_range, new_text)| {
121 old_range.start += offset;
122 old_range.end += offset;
123
124 let prefix_len = common_prefix(
125 snapshot.chars_for_range(old_range.clone()),
126 new_text.chars(),
127 );
128 old_range.start += prefix_len;
129
130 let suffix_len = common_prefix(
131 snapshot.reversed_chars_for_range(old_range.clone()),
132 new_text[prefix_len..].chars().rev(),
133 );
134 old_range.end = old_range.end.saturating_sub(suffix_len);
135
136 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
137 let range = if old_range.is_empty() {
138 let anchor = snapshot.anchor_after(old_range.start);
139 anchor..anchor
140 } else {
141 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
142 };
143 (range, new_text)
144 })
145}
146
/// Returns the UTF-8 byte length of the longest common prefix of two `char`
/// sequences.
///
/// Comparison stops at the first mismatch or when either sequence runs out.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
153
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A single server edit that rewrites the whole buffer should be reduced to
    /// the two minimal insertions that actually change the text.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = edits_from_response(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        // Insertion at the end of `let args =` (line 1).
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        // Insertion of the missing semicolon after the `println!` call (line 2).
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// Walks a prediction through a sequence of user edits and checks how
    /// `EditPrediction::interpolate` rebases it after each step.
    ///
    /// The prediction replaces 2..5 with "REM" and deletes 9..11 in
    /// "Lorem ipsum dolor".
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is returned unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user deletes the predicted range: "REM" becomes a pure
            // insertion and the later edit shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user starts typing the prediction ("R"): only "EM" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // ...then "E": only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // ...then "M": the first edit is fully typed and disappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The user performs the predicted deletion themselves: it is dropped.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not line up with the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    ///
    /// Each range starts with `anchor_after` and ends with `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in the current state of
    /// `buffer`, so tests can compare them against plain ranges.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}