1use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
2
3use anyhow::Context as _;
4use cloud_llm_client::predict_edits_v3;
5use gpui::{App, AsyncApp, Entity};
6use language::{
7 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
8};
9use project::Project;
10use util::ResultExt;
11use uuid::Uuid;
12
/// Unique identifier of an [`EditPrediction`].
///
/// A thin newtype over [`Uuid`]; `EditPrediction::from_response` fills it with the
/// `request_id` of the response the prediction was built from.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Uuid(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
/// A set of predicted edits for a single buffer, together with the state needed to
/// preview them and to re-interpolate them after the buffer changes.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// Path of the file the edits apply to; `targets_buffer` compares it with a buffer's
    /// full path.
    pub path: Arc<Path>,
    /// Anchor ranges to replace, each paired with its replacement text.
    pub edits: Arc<[(Range<Anchor>, String)]>,
    /// Snapshot of the target buffer that `edits` were produced for; `interpolate` diffs
    /// newer snapshots against it.
    pub snapshot: BufferSnapshot,
    /// Preview obtained from `Buffer::preview_edits` for `edits`.
    pub edit_preview: EditPreview,
    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
    _buffer: Entity<Buffer>,
}
38
39impl EditPrediction {
40 pub async fn from_response(
41 response: predict_edits_v3::PredictEditsResponse,
42 active_buffer_old_snapshot: &TextBufferSnapshot,
43 active_buffer: &Entity<Buffer>,
44 project: &Entity<Project>,
45 cx: &mut AsyncApp,
46 ) -> Option<Self> {
47 // TODO only allow cloud to return one path
48 let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
49 return None;
50 };
51
52 let is_same_path = active_buffer
53 .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
54 .ok()?;
55
56 let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
57 active_buffer
58 .read_with(cx, |buffer, cx| {
59 let new_snapshot = buffer.snapshot();
60 let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
61 let edits: Arc<[_]> =
62 interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
63
64 Some((
65 active_buffer.clone(),
66 edits.clone(),
67 new_snapshot,
68 buffer.preview_edits(edits, cx),
69 ))
70 })
71 .ok()??
72 } else {
73 let buffer_handle = project
74 .update(cx, |project, cx| {
75 let project_path = project
76 .find_project_path(&path, cx)
77 .context("Failed to find project path for zeta edit")?;
78 anyhow::Ok(project.open_buffer(project_path, cx))
79 })
80 .ok()?
81 .log_err()?
82 .await
83 .context("Failed to open buffer for zeta edit")
84 .log_err()?;
85
86 buffer_handle
87 .read_with(cx, |buffer, cx| {
88 let snapshot = buffer.snapshot();
89 let edits = edits_from_response(&response.edits, &snapshot);
90 if edits.is_empty() {
91 return None;
92 }
93 Some((
94 buffer_handle.clone(),
95 edits.clone(),
96 snapshot,
97 buffer.preview_edits(edits, cx),
98 ))
99 })
100 .ok()??
101 };
102
103 let edit_preview = edit_preview_task.await;
104
105 Some(EditPrediction {
106 id: EditPredictionId(response.request_id),
107 path,
108 edits,
109 snapshot,
110 edit_preview,
111 _buffer: buffer,
112 })
113 }
114
115 pub fn interpolate(
116 &self,
117 new_snapshot: &TextBufferSnapshot,
118 ) -> Option<Vec<(Range<Anchor>, String)>> {
119 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
120 }
121
122 pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
123 buffer_path_eq(buffer, &self.path, cx)
124 }
125}
126
127impl std::fmt::Debug for EditPrediction {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("EditPrediction")
130 .field("id", &self.id)
131 .field("path", &self.path)
132 .field("edits", &self.edits)
133 .finish()
134 }
135}
136
137pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
138 buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
139}
140
141pub fn interpolate_edits(
142 old_snapshot: &TextBufferSnapshot,
143 new_snapshot: &TextBufferSnapshot,
144 current_edits: Arc<[(Range<Anchor>, String)]>,
145) -> Option<Vec<(Range<Anchor>, String)>> {
146 let mut edits = Vec::new();
147
148 let mut model_edits = current_edits.iter().peekable();
149 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
150 while let Some((model_old_range, _)) = model_edits.peek() {
151 let model_old_range = model_old_range.to_offset(old_snapshot);
152 if model_old_range.end < user_edit.old.start {
153 let (model_old_range, model_new_text) = model_edits.next().unwrap();
154 edits.push((model_old_range.clone(), model_new_text.clone()));
155 } else {
156 break;
157 }
158 }
159
160 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
161 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
162 if user_edit.old == model_old_offset_range {
163 let user_new_text = new_snapshot
164 .text_for_range(user_edit.new.clone())
165 .collect::<String>();
166
167 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
168 if !model_suffix.is_empty() {
169 let anchor = old_snapshot.anchor_after(user_edit.old.end);
170 edits.push((anchor..anchor, model_suffix.to_string()));
171 }
172
173 model_edits.next();
174 continue;
175 }
176 }
177 }
178
179 return None;
180 }
181
182 edits.extend(model_edits.cloned());
183
184 if edits.is_empty() { None } else { Some(edits) }
185}
186
187fn edits_from_response(
188 edits: &[predict_edits_v3::Edit],
189 snapshot: &TextBufferSnapshot,
190) -> Arc<[(Range<Anchor>, String)]> {
191 edits
192 .iter()
193 .flat_map(|edit| {
194 let old_text = snapshot.text_for_range(edit.range.clone());
195
196 excerpt_edits_from_response(
197 old_text.collect::<Cow<str>>(),
198 &edit.content,
199 edit.range.start,
200 &snapshot,
201 )
202 })
203 .collect::<Vec<_>>()
204 .into()
205}
206
207fn excerpt_edits_from_response(
208 old_text: Cow<str>,
209 new_text: &str,
210 offset: usize,
211 snapshot: &TextBufferSnapshot,
212) -> impl Iterator<Item = (Range<Anchor>, String)> {
213 text_diff(&old_text, new_text)
214 .into_iter()
215 .map(move |(mut old_range, new_text)| {
216 old_range.start += offset;
217 old_range.end += offset;
218
219 let prefix_len = common_prefix(
220 snapshot.chars_for_range(old_range.clone()),
221 new_text.chars(),
222 );
223 old_range.start += prefix_len;
224
225 let suffix_len = common_prefix(
226 snapshot.reversed_chars_for_range(old_range.clone()),
227 new_text[prefix_len..].chars().rev(),
228 );
229 old_range.end = old_range.end.saturating_sub(suffix_len);
230
231 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
232 let range = if old_range.is_empty() {
233 let anchor = snapshot.anchor_after(old_range.start);
234 anchor..anchor
235 } else {
236 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
237 };
238 (range, new_text)
239 })
240}
241
/// Returns the UTF-8 byte length of the longest run of equal leading characters
/// produced by `a` and `b`.
///
/// Comparison stops at the first differing pair or as soon as either iterator ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
248
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A response that rewrites the whole buffer must be reduced to the two small
    /// insertions that actually differ from the current text.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        // A single edit spanning the entire old text.
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = edits_from_response(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        // First edit: completes the `let args =` line.
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        // Second edit: appends the missing semicolon after `println!(...)`.
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// Walks a prediction through a series of user edits and checks what `interpolate`
    /// reports after each one. Every step builds on the buffer state left by the previous
    /// step, so the order of the statements matters.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The prediction: replace "rem" with "REM" and delete "um".
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            path: Path::new("test.txt").into(),
            _buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user deletes exactly the range of the first predicted edit: the whole
            // replacement text remains as an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user starts typing the predicted text: only the remainder is suggested.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The first predicted edit has been typed completely: only the deletion is left.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Removing the last typed character brings its suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // The user performs the predicted deletion themselves: only "M" is left.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that no longer lines up with the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Test helper: turns offset-based edits into anchored edits on `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Test helper: resolves anchored edits back to offsets in the current state of
    /// `buffer`, so they can be compared against plain ranges.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}