prediction.rs

  1use std::{
  2    ops::Range,
  3    path::Path,
  4    sync::Arc,
  5    time::{Duration, Instant},
  6};
  7
  8use gpui::{AsyncApp, Entity, SharedString};
  9use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
 10use serde::Serialize;
 11
/// Identifier of a single edit prediction, wrapping a [`SharedString`].
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
 14
 15impl From<EditPredictionId> for gpui::ElementId {
 16    fn from(value: EditPredictionId) -> Self {
 17        gpui::ElementId::Name(value.0)
 18    }
 19}
 20
 21impl std::fmt::Display for EditPredictionId {
 22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 23        write!(f, "{}", self.0)
 24    }
 25}
 26
/// A set of predicted edits for one buffer, together with the buffer snapshot
/// they are expressed against, a preview of their effect, and the request
/// inputs and timestamps that produced them.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// Predicted edits as `(range, replacement text)` pairs, anchored in `snapshot`.
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Snapshot the edits are relative to; `interpolate` uses it as the base version.
    pub snapshot: BufferSnapshot,
    /// Preview of `edits` applied to the buffer.
    pub edit_preview: EditPreview,
    /// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
    pub buffer: Entity<Buffer>,
    /// When the buffer was snapshotted for the request (start point of `latency`).
    pub buffer_snapshotted_at: Instant,
    /// When the response was received (end point of `latency`).
    pub response_received_at: Instant,
    /// Inputs this prediction was produced from.
    pub inputs: EditPredictionInputs,
}
 39
/// Inputs an [`EditPrediction`] was produced from, using the
/// `cloud_llm_client::predict_edits_v3` types. Field order determines the
/// serialized/`Debug` field order.
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
    /// Events passed along with the request (semantics defined by `predict_edits_v3::Event`).
    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Additional files included with the request.
    pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
    /// Cursor position (line/column).
    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
    /// Path associated with the cursor — presumably the edited buffer's path; confirm with callers.
    pub cursor_path: Arc<Path>,
}
 47
 48impl EditPrediction {
 49    pub async fn new(
 50        id: EditPredictionId,
 51        edited_buffer: &Entity<Buffer>,
 52        edited_buffer_snapshot: &BufferSnapshot,
 53        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 54        buffer_snapshotted_at: Instant,
 55        response_received_at: Instant,
 56        inputs: EditPredictionInputs,
 57        cx: &mut AsyncApp,
 58    ) -> Option<Self> {
 59        let (edits, snapshot, edit_preview_task) = edited_buffer
 60            .read_with(cx, |buffer, cx| {
 61                let new_snapshot = buffer.snapshot();
 62                let edits: Arc<[_]> =
 63                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
 64
 65                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 66            })
 67            .ok()??;
 68
 69        let edit_preview = edit_preview_task.await;
 70
 71        Some(EditPrediction {
 72            id,
 73            edits,
 74            snapshot,
 75            edit_preview,
 76            inputs,
 77            buffer: edited_buffer.clone(),
 78            buffer_snapshotted_at,
 79            response_received_at,
 80        })
 81    }
 82
 83    pub fn interpolate(
 84        &self,
 85        new_snapshot: &TextBufferSnapshot,
 86    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
 87        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
 88    }
 89
 90    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
 91        self.snapshot.remote_id() == buffer.remote_id()
 92    }
 93
 94    pub fn latency(&self) -> Duration {
 95        self.response_received_at - self.buffer_snapshotted_at
 96    }
 97}
 98
 99impl std::fmt::Debug for EditPrediction {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("EditPrediction")
102            .field("id", &self.id)
103            .field("edits", &self.edits)
104            .finish()
105    }
106}
107
108pub fn interpolate_edits(
109    old_snapshot: &TextBufferSnapshot,
110    new_snapshot: &TextBufferSnapshot,
111    current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
112) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
113    let mut edits = Vec::new();
114
115    let mut model_edits = current_edits.iter().peekable();
116    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
117        while let Some((model_old_range, _)) = model_edits.peek() {
118            let model_old_range = model_old_range.to_offset(old_snapshot);
119            if model_old_range.end < user_edit.old.start {
120                let (model_old_range, model_new_text) = model_edits.next().unwrap();
121                edits.push((model_old_range.clone(), model_new_text.clone()));
122            } else {
123                break;
124            }
125        }
126
127        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
128            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
129            if user_edit.old == model_old_offset_range {
130                let user_new_text = new_snapshot
131                    .text_for_range(user_edit.new.clone())
132                    .collect::<String>();
133
134                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
135                    if !model_suffix.is_empty() {
136                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
137                        edits.push((anchor..anchor, model_suffix.into()));
138                    }
139
140                    model_edits.next();
141                    continue;
142                }
143            }
144        }
145
146        return None;
147    }
148
149    edits.extend(model_edits.cloned());
150
151    if edits.is_empty() { None } else { Some(edits) }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Checks how a prediction is re-based as the user edits the buffer.
    /// Typing a prefix of the predicted text shrinks the prediction, and
    /// undoing or deleting that text restores it. An edit the prediction does
    /// not account for invalidates it.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // "Lo ipsum dolor": the user deleted "rem", so "REM" becomes an insertion
            // at 2 and the predicted deletion shifts to 6..8.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo restores the original text, so the original prediction applies again.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // "LoR ipsum dolor": "R" is a prefix of "REM", so only "EM" remains to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // "LoRE ipsum dolor": only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // "LoREM ipsum dolor": the first edit is fully typed and drops out.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // "LoRE ipsum dolor": deleting the typed "M" brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // "LoRE ips dolor": the user performed the predicted deletion of "um",
            // so that edit drops out.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // Deleting " i" is not accounted for by the prediction, which is invalidated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor ranges in `buffer`'s current
    /// snapshot (start anchored after, end anchored before its offset).
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current
    /// snapshot so they can be compared in assertions.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}