prediction.rs

  1use std::{
  2    ops::Range,
  3    path::Path,
  4    sync::Arc,
  5    time::{Duration, Instant},
  6};
  7
  8use cloud_llm_client::EditPredictionRejectReason;
  9use gpui::{AsyncApp, Entity, SharedString};
 10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
 11use serde::Serialize;
 12
 13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(pub SharedString);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Name(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 29pub struct EditPredictionResult {
 30    pub id: EditPredictionId,
 31    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 32}
 33
 34impl EditPredictionResult {
 35    pub async fn new(
 36        id: EditPredictionId,
 37        edited_buffer: &Entity<Buffer>,
 38        edited_buffer_snapshot: &BufferSnapshot,
 39        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40        buffer_snapshotted_at: Instant,
 41        response_received_at: Instant,
 42        inputs: EditPredictionInputs,
 43        cx: &mut AsyncApp,
 44    ) -> Self {
 45        if edits.is_empty() {
 46            return Self {
 47                id,
 48                prediction: Err(EditPredictionRejectReason::Empty),
 49            };
 50        }
 51
 52        let Some((edits, snapshot, edit_preview_task)) = edited_buffer
 53            .read_with(cx, |buffer, cx| {
 54                let new_snapshot = buffer.snapshot();
 55                let edits: Arc<[_]> =
 56                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
 57
 58                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 59            })
 60            .ok()
 61            .flatten()
 62        else {
 63            return Self {
 64                id,
 65                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 66            };
 67        };
 68
 69        let edit_preview = edit_preview_task.await;
 70
 71        Self {
 72            id: id.clone(),
 73            prediction: Ok(EditPrediction {
 74                id,
 75                edits,
 76                snapshot,
 77                edit_preview,
 78                inputs,
 79                buffer: edited_buffer.clone(),
 80                buffer_snapshotted_at,
 81                response_received_at,
 82            }),
 83        }
 84    }
 85}
 86
 87#[derive(Clone)]
 88pub struct EditPrediction {
 89    pub id: EditPredictionId,
 90    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 91    pub snapshot: BufferSnapshot,
 92    pub edit_preview: EditPreview,
 93    pub buffer: Entity<Buffer>,
 94    pub buffer_snapshotted_at: Instant,
 95    pub response_received_at: Instant,
 96    pub inputs: EditPredictionInputs,
 97}
 98
 99#[derive(Debug, Clone, Serialize)]
100pub struct EditPredictionInputs {
101    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
102    pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
103    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
104    pub cursor_path: Arc<Path>,
105}
106
107impl EditPrediction {
108    pub fn interpolate(
109        &self,
110        new_snapshot: &TextBufferSnapshot,
111    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
112        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
113    }
114
115    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
116        self.snapshot.remote_id() == buffer.remote_id()
117    }
118
119    pub fn latency(&self) -> Duration {
120        self.response_received_at - self.buffer_snapshotted_at
121    }
122}
123
124impl std::fmt::Debug for EditPrediction {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("EditPrediction")
127            .field("id", &self.id)
128            .field("edits", &self.edits)
129            .finish()
130    }
131}
132
133pub fn interpolate_edits(
134    old_snapshot: &TextBufferSnapshot,
135    new_snapshot: &TextBufferSnapshot,
136    current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
137) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
138    let mut edits = Vec::new();
139
140    let mut model_edits = current_edits.iter().peekable();
141    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
142        while let Some((model_old_range, _)) = model_edits.peek() {
143            let model_old_range = model_old_range.to_offset(old_snapshot);
144            if model_old_range.end < user_edit.old.start {
145                let (model_old_range, model_new_text) = model_edits.next().unwrap();
146                edits.push((model_old_range.clone(), model_new_text.clone()));
147            } else {
148                break;
149            }
150        }
151
152        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
153            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
154            if user_edit.old == model_old_offset_range {
155                let user_new_text = new_snapshot
156                    .text_for_range(user_edit.new.clone())
157                    .collect::<String>();
158
159                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
160                    if !model_suffix.is_empty() {
161                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
162                        edits.push((anchor..anchor, model_suffix.into()));
163                    }
164
165                    model_edits.next();
166                    continue;
167                }
168            }
169        }
170
171        return None;
172    }
173
174    edits.extend(model_edits.cloned());
175
176    if edits.is_empty() { None } else { Some(edits) }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    // Walks one prediction through a sequence of user edits (typing, undo, deletions) and
    // checks how it is re-based after each step, ending with a conflicting edit.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the predicted range turns the model text into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing a prefix of the prediction leaves only the untyped remainder.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Finishing the first predicted edit leaves only the second one.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Backspacing over the last typed character brings its remainder back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand consumes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // An edit that no longer lines up with the prediction discards it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert!(
                prediction
                    .interpolate(&buffer.read(cx).snapshot())
                    .is_none()
            );
        });
    }

    /// Interpolates `prediction` against `buffer`'s current contents and converts the
    /// result to offset ranges; panics if the prediction was discarded.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let current = buffer.read(cx).snapshot();
        let edits = prediction
            .interpolate(&current)
            .expect("prediction should still apply to the current buffer");
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Anchors offset-based edits in `buffer`: starts bias after, ends bias before.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let contents = buffer.read(cx);
        let mut anchored = Vec::new();
        for (Range { start, end }, text) in iterator {
            anchored.push((contents.anchor_after(start)..contents.anchor_before(end), text));
        }
        anchored
    }

    /// Resolves anchored edits back to offset ranges in `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let contents = buffer.read(cx);
        let to_offsets =
            |range: &Range<Anchor>| range.start.to_offset(contents)..range.end.to_offset(contents);
        editor_edits
            .iter()
            .map(|(range, text)| (to_offsets(range), Arc::clone(text)))
            .collect()
    }
}