1use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
  2
  3use anyhow::Context as _;
  4use cloud_llm_client::predict_edits_v3;
  5use gpui::{App, AsyncApp, Entity};
  6use language::{
  7    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
  8};
  9use project::Project;
 10use util::ResultExt;
 11use uuid::Uuid;
 12
 13#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(Uuid);
 15
 16impl Into<Uuid> for EditPredictionId {
 17    fn into(self) -> Uuid {
 18        self.0
 19    }
 20}
 21
 22impl From<EditPredictionId> for gpui::ElementId {
 23    fn from(value: EditPredictionId) -> Self {
 24        gpui::ElementId::Uuid(value.0)
 25    }
 26}
 27
 28impl std::fmt::Display for EditPredictionId {
 29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 30        write!(f, "{}", self.0)
 31    }
 32}
 33
/// A set of predicted edits for a single file, together with the buffer state they were
/// resolved against.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction (the response's request id when built from a response).
    pub id: EditPredictionId,
    /// Full path of the file the edits target; compared against a buffer's file full path by
    /// `targets_buffer`.
    pub path: Arc<Path>,
    /// Predicted edits as anchor ranges plus replacement text; `interpolate` resolves these
    /// ranges against `snapshot`.
    pub edits: Arc<[(Range<Anchor>, String)]>,
    /// Snapshot the edits were resolved against; the base version for interpolation.
    pub snapshot: BufferSnapshot,
    /// Result of `Buffer::preview_edits` for `edits`.
    pub edit_preview: EditPreview,
    /// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
    pub buffer: Entity<Buffer>,
}
 44
 45impl EditPrediction {
 46    pub async fn from_response(
 47        response: predict_edits_v3::PredictEditsResponse,
 48        active_buffer_old_snapshot: &TextBufferSnapshot,
 49        active_buffer: &Entity<Buffer>,
 50        project: &Entity<Project>,
 51        cx: &mut AsyncApp,
 52    ) -> Option<Self> {
 53        // TODO only allow cloud to return one path
 54        let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
 55            return None;
 56        };
 57
 58        let is_same_path = active_buffer
 59            .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
 60            .ok()?;
 61
 62        let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
 63            active_buffer
 64                .read_with(cx, |buffer, cx| {
 65                    let new_snapshot = buffer.snapshot();
 66                    let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
 67                    let edits: Arc<[_]> =
 68                        interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
 69
 70                    Some((
 71                        active_buffer.clone(),
 72                        edits.clone(),
 73                        new_snapshot,
 74                        buffer.preview_edits(edits, cx),
 75                    ))
 76                })
 77                .ok()??
 78        } else {
 79            let buffer_handle = project
 80                .update(cx, |project, cx| {
 81                    let project_path = project
 82                        .find_project_path(&path, cx)
 83                        .context("Failed to find project path for zeta edit")?;
 84                    anyhow::Ok(project.open_buffer(project_path, cx))
 85                })
 86                .ok()?
 87                .log_err()?
 88                .await
 89                .context("Failed to open buffer for zeta edit")
 90                .log_err()?;
 91
 92            buffer_handle
 93                .read_with(cx, |buffer, cx| {
 94                    let snapshot = buffer.snapshot();
 95                    let edits = edits_from_response(&response.edits, &snapshot);
 96                    if edits.is_empty() {
 97                        return None;
 98                    }
 99                    Some((
100                        buffer_handle.clone(),
101                        edits.clone(),
102                        snapshot,
103                        buffer.preview_edits(edits, cx),
104                    ))
105                })
106                .ok()??
107        };
108
109        let edit_preview = edit_preview_task.await;
110
111        Some(EditPrediction {
112            id: EditPredictionId(response.request_id),
113            path,
114            edits,
115            snapshot,
116            edit_preview,
117            buffer,
118        })
119    }
120
121    pub fn interpolate(
122        &self,
123        new_snapshot: &TextBufferSnapshot,
124    ) -> Option<Vec<(Range<Anchor>, String)>> {
125        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
126    }
127
128    pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
129        buffer_path_eq(buffer, &self.path, cx)
130    }
131}
132
133impl std::fmt::Debug for EditPrediction {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        f.debug_struct("EditPrediction")
136            .field("id", &self.id)
137            .field("path", &self.path)
138            .field("edits", &self.edits)
139            .finish()
140    }
141}
142
143pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
144    buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
145}
146
/// Re-bases `current_edits` (resolved against `old_snapshot`) onto `new_snapshot`, taking
/// into account the edits made between the two versions.
///
/// The user edits (`new_snapshot.edits_since(old_snapshot.version)`) are walked alongside
/// the predicted edits:
/// - predicted edits that end strictly before the next user edit starts are kept as-is;
/// - a user edit whose old range is exactly a predicted edit's range, and whose inserted
///   text is a prefix of the predicted text, is treated as the user typing along with the
///   prediction: only the remaining suffix (if non-empty) is kept, as an insertion anchored
///   after the user's edit;
/// - any other user edit conflicts with the prediction and `None` is returned.
///
/// Predicted edits after the last user edit are kept as-is. `None` is also returned when
/// no edits remain.
///
/// NOTE(review): the single forward pass assumes `current_edits` is sorted by position in
/// `old_snapshot` and non-overlapping — confirm at call sites.
pub fn interpolate_edits(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every predicted edit that lies entirely before this user edit. The
        // comparison is strict, so a predicted edit that merely touches the start of the
        // user edit is not skipped and must match it exactly below.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            // The user replaced exactly the range the model wanted to replace...
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // ...and what they typed so far is a prefix of the predicted text, so only
                // the untyped remainder is still suggested.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit does not line up with the prediction: discard it entirely.
        return None;
    }

    // Predicted edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
192
193pub const fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
194    language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
195}
196
197fn edits_from_response(
198    edits: &[predict_edits_v3::Edit],
199    snapshot: &TextBufferSnapshot,
200) -> Arc<[(Range<Anchor>, String)]> {
201    edits
202        .iter()
203        .flat_map(|edit| {
204            let point_range = line_range_to_point_range(edit.range.clone());
205            let offset = point_range.to_offset(snapshot).start;
206            let old_text = snapshot.text_for_range(point_range);
207
208            excerpt_edits_from_response(
209                old_text.collect::<Cow<str>>(),
210                &edit.content,
211                offset,
212                &snapshot,
213            )
214        })
215        .collect::<Vec<_>>()
216        .into()
217}
218
219fn excerpt_edits_from_response(
220    old_text: Cow<str>,
221    new_text: &str,
222    offset: usize,
223    snapshot: &TextBufferSnapshot,
224) -> impl Iterator<Item = (Range<Anchor>, String)> {
225    text_diff(&old_text, new_text)
226        .into_iter()
227        .map(move |(mut old_range, new_text)| {
228            old_range.start += offset;
229            old_range.end += offset;
230
231            let prefix_len = common_prefix(
232                snapshot.chars_for_range(old_range.clone()),
233                new_text.chars(),
234            );
235            old_range.start += prefix_len;
236
237            let suffix_len = common_prefix(
238                snapshot.reversed_chars_for_range(old_range.clone()),
239                new_text[prefix_len..].chars().rev(),
240            );
241            old_range.end = old_range.end.saturating_sub(suffix_len);
242
243            let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
244            let range = if old_range.is_empty() {
245                let anchor = snapshot.anchor_after(old_range.start);
246                anchor..anchor
247            } else {
248                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
249            };
250            (range, new_text)
251        })
252}
253
/// Returns the UTF-8 byte length of the longest common prefix of two character sequences.
///
/// Comparison stops at the first mismatch or when either sequence ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
260
// Tests for turning line-based responses into minimal edits and for re-basing
// predictions over subsequent user edits.
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use edit_prediction_context::Line;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    // A response that rewrites the whole file should be reduced to the two small
    // insertions that actually differ from the buffer contents.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        // One edit spanning every line of `old` (end line is exclusive).
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: Line(0)..Line(old.lines().count() as u32),
            content: new.into(),
        }];

        let edits = edits_from_response(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        // Insertion after "    let args =".
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        // Insertion of the missing semicolon after `println!(...)`.
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    // Walks a prediction through a series of user edits and checks how
    // `EditPrediction::interpolate` re-bases it after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace 2..5 with "REM" and delete 9..11.
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            path: Path::new("test.txt").into(),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes the predicted range: "REM" becomes a pure insertion and the
            // later deletion shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User types "R" over the predicted range: only "EM" is left to suggest.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            // Typing "E" leaves only "M".
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing "M" completes the first edit; only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" brings the "M" suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User performs the predicted deletion themselves; only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not line up with the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Converts offset-range edits into anchor-range edits for `buffer`
    // (start anchored after, end anchored before).
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    // Resolves anchor-range edits back to offset ranges in the current buffer.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}