prediction.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use gpui::{AsyncApp, Entity, SharedString};
  4use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
  5
  6#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
  7pub struct EditPredictionId(pub SharedString);
  8
  9impl From<EditPredictionId> for gpui::ElementId {
 10    fn from(value: EditPredictionId) -> Self {
 11        gpui::ElementId::Name(value.0)
 12    }
 13}
 14
 15impl std::fmt::Display for EditPredictionId {
 16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 17        write!(f, "{}", self.0)
 18    }
 19}
 20
 21#[derive(Clone)]
 22pub struct EditPrediction {
 23    pub id: EditPredictionId,
 24    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 25    pub snapshot: BufferSnapshot,
 26    pub edit_preview: EditPreview,
 27    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
 28    pub buffer: Entity<Buffer>,
 29}
 30
 31impl EditPrediction {
 32    pub async fn new(
 33        id: EditPredictionId,
 34        edited_buffer: &Entity<Buffer>,
 35        edited_buffer_snapshot: &BufferSnapshot,
 36        edits: Vec<(Range<Anchor>, Arc<str>)>,
 37        cx: &mut AsyncApp,
 38    ) -> Option<Self> {
 39        let (edits, snapshot, edit_preview_task) = edited_buffer
 40            .read_with(cx, |buffer, cx| {
 41                let new_snapshot = buffer.snapshot();
 42                let edits: Arc<[_]> =
 43                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
 44
 45                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 46            })
 47            .ok()??;
 48
 49        let edit_preview = edit_preview_task.await;
 50
 51        Some(EditPrediction {
 52            id,
 53            edits,
 54            snapshot,
 55            edit_preview,
 56            buffer: edited_buffer.clone(),
 57        })
 58    }
 59
 60    pub fn interpolate(
 61        &self,
 62        new_snapshot: &TextBufferSnapshot,
 63    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 64        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
 65    }
 66
 67    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
 68        self.snapshot.remote_id() == buffer.remote_id()
 69    }
 70}
 71
 72impl std::fmt::Debug for EditPrediction {
 73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 74        f.debug_struct("EditPrediction")
 75            .field("id", &self.id)
 76            .field("edits", &self.edits)
 77            .finish()
 78    }
 79}
 80
 81pub fn interpolate_edits(
 82    old_snapshot: &TextBufferSnapshot,
 83    new_snapshot: &TextBufferSnapshot,
 84    current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 85) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 86    let mut edits = Vec::new();
 87
 88    let mut model_edits = current_edits.iter().peekable();
 89    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 90        while let Some((model_old_range, _)) = model_edits.peek() {
 91            let model_old_range = model_old_range.to_offset(old_snapshot);
 92            if model_old_range.end < user_edit.old.start {
 93                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 94                edits.push((model_old_range.clone(), model_new_text.clone()));
 95            } else {
 96                break;
 97            }
 98        }
 99
100        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
101            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
102            if user_edit.old == model_old_offset_range {
103                let user_new_text = new_snapshot
104                    .text_for_range(user_edit.new.clone())
105                    .collect::<String>();
106
107                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
108                    if !model_suffix.is_empty() {
109                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
110                        edits.push((anchor..anchor, model_suffix.into()));
111                    }
112
113                    model_edits.next();
114                    continue;
115                }
116            }
117        }
118
119        return None;
120    }
121
122    edits.extend(model_edits.cloned());
123
124    if edits.is_empty() { None } else { Some(edits) }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        // Prediction against "Lorem ipsum dolor": replace "rem" with "REM" and delete "um".
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // Nothing typed yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the text the prediction replaces leaves "REM" as a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing a prefix of the predicted text keeps only the untyped suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Typing the whole predicted text drops that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Removing the typed "M" brings the pending insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand drops that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and converts the result to
    /// offsets. Panics if the prediction has been invalidated.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction
            .interpolate(&snapshot)
            .expect("prediction should not have been invalidated");
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_prediction_edits(
        edits: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        edits
            .into_iter()
            .map(|(range, text)| {
                let start = buffer.anchor_after(range.start);
                let end = buffer.anchor_before(range.end);
                (start..end, text)
            })
            .collect()
    }

    /// Converts anchor-based edits back into offsets in `buffer`.
    fn from_prediction_edits(
        edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buffer);
                let end = range.end.to_offset(buffer);
                (start..end, text.clone())
            })
            .collect()
    }
}