prediction.rs

  1use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
  2
  3use anyhow::Context as _;
  4use cloud_llm_client::predict_edits_v3;
  5use gpui::{App, AsyncApp, Entity};
  6use language::{
  7    Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
  8};
  9use project::Project;
 10use util::ResultExt;
 11use uuid::Uuid;
 12
 13#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(Uuid);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Uuid(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28#[derive(Clone)]
 29pub struct EditPrediction {
 30    pub id: EditPredictionId,
 31    pub path: Arc<Path>,
 32    pub edits: Arc<[(Range<Anchor>, String)]>,
 33    pub snapshot: BufferSnapshot,
 34    pub edit_preview: EditPreview,
 35    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
 36    _buffer: Entity<Buffer>,
 37}
 38
 39impl EditPrediction {
 40    pub async fn from_response(
 41        response: predict_edits_v3::PredictEditsResponse,
 42        active_buffer_old_snapshot: &TextBufferSnapshot,
 43        active_buffer: &Entity<Buffer>,
 44        project: &Entity<Project>,
 45        cx: &mut AsyncApp,
 46    ) -> Option<Self> {
 47        // TODO only allow cloud to return one path
 48        let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
 49            return None;
 50        };
 51
 52        let is_same_path = active_buffer
 53            .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
 54            .ok()?;
 55
 56        let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
 57            active_buffer
 58                .read_with(cx, |buffer, cx| {
 59                    let new_snapshot = buffer.snapshot();
 60                    let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
 61                    let edits: Arc<[_]> =
 62                        interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
 63
 64                    Some((
 65                        active_buffer.clone(),
 66                        edits.clone(),
 67                        new_snapshot,
 68                        buffer.preview_edits(edits, cx),
 69                    ))
 70                })
 71                .ok()??
 72        } else {
 73            let buffer_handle = project
 74                .update(cx, |project, cx| {
 75                    let project_path = project
 76                        .find_project_path(&path, cx)
 77                        .context("Failed to find project path for zeta edit")?;
 78                    anyhow::Ok(project.open_buffer(project_path, cx))
 79                })
 80                .ok()?
 81                .log_err()?
 82                .await
 83                .context("Failed to open buffer for zeta edit")
 84                .log_err()?;
 85
 86            buffer_handle
 87                .read_with(cx, |buffer, cx| {
 88                    let snapshot = buffer.snapshot();
 89                    let edits = edits_from_response(&response.edits, &snapshot);
 90                    if edits.is_empty() {
 91                        return None;
 92                    }
 93                    Some((
 94                        buffer_handle.clone(),
 95                        edits.clone(),
 96                        snapshot,
 97                        buffer.preview_edits(edits, cx),
 98                    ))
 99                })
100                .ok()??
101        };
102
103        let edit_preview = edit_preview_task.await;
104
105        Some(EditPrediction {
106            id: EditPredictionId(response.request_id),
107            path,
108            edits,
109            snapshot,
110            edit_preview,
111            _buffer: buffer,
112        })
113    }
114
115    pub fn interpolate(
116        &self,
117        new_snapshot: &TextBufferSnapshot,
118    ) -> Option<Vec<(Range<Anchor>, String)>> {
119        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
120    }
121
122    pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
123        buffer_path_eq(buffer, &self.path, cx)
124    }
125}
126
127impl std::fmt::Debug for EditPrediction {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("EditPrediction")
130            .field("id", &self.id)
131            .field("path", &self.path)
132            .field("edits", &self.edits)
133            .finish()
134    }
135}
136
137pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
138    buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
139}
140
141pub fn interpolate_edits(
142    old_snapshot: &TextBufferSnapshot,
143    new_snapshot: &TextBufferSnapshot,
144    current_edits: Arc<[(Range<Anchor>, String)]>,
145) -> Option<Vec<(Range<Anchor>, String)>> {
146    let mut edits = Vec::new();
147
148    let mut model_edits = current_edits.iter().peekable();
149    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
150        while let Some((model_old_range, _)) = model_edits.peek() {
151            let model_old_range = model_old_range.to_offset(old_snapshot);
152            if model_old_range.end < user_edit.old.start {
153                let (model_old_range, model_new_text) = model_edits.next().unwrap();
154                edits.push((model_old_range.clone(), model_new_text.clone()));
155            } else {
156                break;
157            }
158        }
159
160        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
161            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
162            if user_edit.old == model_old_offset_range {
163                let user_new_text = new_snapshot
164                    .text_for_range(user_edit.new.clone())
165                    .collect::<String>();
166
167                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
168                    if !model_suffix.is_empty() {
169                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
170                        edits.push((anchor..anchor, model_suffix.to_string()));
171                    }
172
173                    model_edits.next();
174                    continue;
175                }
176            }
177        }
178
179        return None;
180    }
181
182    edits.extend(model_edits.cloned());
183
184    if edits.is_empty() { None } else { Some(edits) }
185}
186
187fn edits_from_response(
188    edits: &[predict_edits_v3::Edit],
189    snapshot: &TextBufferSnapshot,
190) -> Arc<[(Range<Anchor>, String)]> {
191    edits
192        .iter()
193        .flat_map(|edit| {
194            let old_text = snapshot.text_for_range(edit.range.clone());
195
196            excerpt_edits_from_response(
197                old_text.collect::<Cow<str>>(),
198                &edit.content,
199                edit.range.start,
200                &snapshot,
201            )
202        })
203        .collect::<Vec<_>>()
204        .into()
205}
206
207fn excerpt_edits_from_response(
208    old_text: Cow<str>,
209    new_text: &str,
210    offset: usize,
211    snapshot: &TextBufferSnapshot,
212) -> impl Iterator<Item = (Range<Anchor>, String)> {
213    text_diff(&old_text, new_text)
214        .into_iter()
215        .map(move |(mut old_range, new_text)| {
216            old_range.start += offset;
217            old_range.end += offset;
218
219            let prefix_len = common_prefix(
220                snapshot.chars_for_range(old_range.clone()),
221                new_text.chars(),
222            );
223            old_range.start += prefix_len;
224
225            let suffix_len = common_prefix(
226                snapshot.reversed_chars_for_range(old_range.clone()),
227                new_text[prefix_len..].chars().rev(),
228            );
229            old_range.end = old_range.end.saturating_sub(suffix_len);
230
231            let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
232            let range = if old_range.is_empty() {
233                let anchor = snapshot.anchor_after(old_range.start);
234                anchor..anchor
235            } else {
236                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
237            };
238            (range, new_text)
239        })
240}
241
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two char sequences.
///
/// Callers add the result directly to byte offsets. Matched chars are identical, so their
/// byte lengths are the same on both sides.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
248
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A whole-file replacement is reduced to the minimal insertions it implies.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = edits_from_response(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// User typing that matches the prediction shrinks it. Undo restores it, and
    /// diverging input invalidates it.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            path: Path::new("test.txt").into(),
            _buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// `common_prefix` must count UTF-8 bytes, not chars, because its result is added
    /// directly to buffer byte offsets.
    #[test]
    fn test_common_prefix_counts_utf8_bytes() {
        assert_eq!(common_prefix("hello".chars(), "help".chars()), 3);
        // 'h' (1 byte) + 'é' (2 bytes) + 'l' (1 byte)
        assert_eq!(common_prefix("héllo".chars(), "hélp".chars()), 4);
        assert_eq!(common_prefix("".chars(), "abc".chars()), 0);
        assert_eq!(common_prefix("é".chars(), "e".chars()), 0);
        assert_eq!(common_prefix("cba".chars().rev(), "xba".chars().rev()), 2);
    }

    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}