prediction.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use gpui::{AsyncApp, Entity};
  4use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
  5use uuid::Uuid;
  6
  7#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  8pub struct EditPredictionId(pub Uuid);
  9
 10impl Into<Uuid> for EditPredictionId {
 11    fn into(self) -> Uuid {
 12        self.0
 13    }
 14}
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Uuid(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28#[derive(Clone)]
 29pub struct EditPrediction {
 30    pub id: EditPredictionId,
 31    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 32    pub snapshot: BufferSnapshot,
 33    pub edit_preview: EditPreview,
 34    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
 35    pub buffer: Entity<Buffer>,
 36}
 37
 38impl EditPrediction {
 39    pub async fn new(
 40        id: EditPredictionId,
 41        edited_buffer: &Entity<Buffer>,
 42        edited_buffer_snapshot: &BufferSnapshot,
 43        edits: Vec<(Range<Anchor>, Arc<str>)>,
 44        cx: &mut AsyncApp,
 45    ) -> Option<Self> {
 46        let (edits, snapshot, edit_preview_task) = edited_buffer
 47            .read_with(cx, |buffer, cx| {
 48                let new_snapshot = buffer.snapshot();
 49                let edits: Arc<[_]> =
 50                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
 51
 52                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 53            })
 54            .ok()??;
 55
 56        let edit_preview = edit_preview_task.await;
 57
 58        Some(EditPrediction {
 59            id,
 60            edits,
 61            snapshot,
 62            edit_preview,
 63            buffer: edited_buffer.clone(),
 64        })
 65    }
 66
 67    pub fn interpolate(
 68        &self,
 69        new_snapshot: &TextBufferSnapshot,
 70    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 71        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
 72    }
 73
 74    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
 75        self.snapshot.remote_id() == buffer.remote_id()
 76    }
 77}
 78
 79impl std::fmt::Debug for EditPrediction {
 80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 81        f.debug_struct("EditPrediction")
 82            .field("id", &self.id)
 83            .field("edits", &self.edits)
 84            .finish()
 85    }
 86}
 87
 88pub fn interpolate_edits(
 89    old_snapshot: &TextBufferSnapshot,
 90    new_snapshot: &TextBufferSnapshot,
 91    current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 92) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 93    let mut edits = Vec::new();
 94
 95    let mut model_edits = current_edits.iter().peekable();
 96    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 97        while let Some((model_old_range, _)) = model_edits.peek() {
 98            let model_old_range = model_old_range.to_offset(old_snapshot);
 99            if model_old_range.end < user_edit.old.start {
100                let (model_old_range, model_new_text) = model_edits.next().unwrap();
101                edits.push((model_old_range.clone(), model_new_text.clone()));
102            } else {
103                break;
104            }
105        }
106
107        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
108            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
109            if user_edit.old == model_old_offset_range {
110                let user_new_text = new_snapshot
111                    .text_for_range(user_edit.new.clone())
112                    .collect::<String>();
113
114                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
115                    if !model_suffix.is_empty() {
116                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
117                        edits.push((anchor..anchor, model_suffix.into()));
118                    }
119
120                    model_edits.next();
121                    continue;
122                }
123            }
124        }
125
126        return None;
127    }
128
129    edits.extend(model_edits.cloned());
130
131    if edits.is_empty() { None } else { Some(edits) }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the range the model wanted to replace leaves a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing a prefix of the predicted text shrinks the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Typing the rest of the predicted text consumes that edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Deleting the last typed character brings the remainder back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand consumes it as well.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // An edit that does not line up with the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current contents and
    /// converts the surviving edits to offset ranges. Panics if interpolation
    /// returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based ones. Range starts use
    /// `anchor_after` and range ends use `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buf.anchor_after(range.start);
            let end = buf.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offset ranges in the buffer's
    /// current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut resolved = Vec::new();
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buf);
            let end = range.end.to_offset(buf);
            resolved.push((start..end, Arc::clone(text)));
        }
        resolved
    }
}