prediction.rs

  1use std::{
  2    ops::Range,
  3    path::Path,
  4    sync::Arc,
  5    time::{Duration, Instant},
  6};
  7
  8use cloud_llm_client::EditPredictionRejectReason;
  9use edit_prediction_types::interpolate_edits;
 10use gpui::{AsyncApp, Entity, SharedString};
 11use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
 12use serde::Serialize;
 13
 14#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 15pub struct EditPredictionId(pub SharedString);
 16
 17impl From<EditPredictionId> for gpui::ElementId {
 18    fn from(value: EditPredictionId) -> Self {
 19        gpui::ElementId::Name(value.0)
 20    }
 21}
 22
 23impl std::fmt::Display for EditPredictionId {
 24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 25        write!(f, "{}", self.0)
 26    }
 27}
 28
/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
///
/// The `id` is kept even for rejected responses, so a rejection can still be attributed
/// to the prediction it belongs to.
pub struct EditPredictionResult {
    /// Identifier of the prediction; present for both accepted and rejected results.
    pub id: EditPredictionId,
    /// The usable prediction, or the reason it was rejected.
    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
}
 34
 35impl EditPredictionResult {
 36    pub async fn new(
 37        id: EditPredictionId,
 38        edited_buffer: &Entity<Buffer>,
 39        edited_buffer_snapshot: &BufferSnapshot,
 40        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 41        buffer_snapshotted_at: Instant,
 42        response_received_at: Instant,
 43        inputs: EditPredictionInputs,
 44        cx: &mut AsyncApp,
 45    ) -> Self {
 46        if edits.is_empty() {
 47            return Self {
 48                id,
 49                prediction: Err(EditPredictionRejectReason::Empty),
 50            };
 51        }
 52
 53        let Some((edits, snapshot, edit_preview_task)) = edited_buffer
 54            .read_with(cx, |buffer, cx| {
 55                let new_snapshot = buffer.snapshot();
 56                let edits: Arc<[_]> =
 57                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 58
 59                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 60            })
 61            .ok()
 62            .flatten()
 63        else {
 64            return Self {
 65                id,
 66                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 67            };
 68        };
 69
 70        let edit_preview = edit_preview_task.await;
 71
 72        Self {
 73            id: id.clone(),
 74            prediction: Ok(EditPrediction {
 75                id,
 76                edits,
 77                snapshot,
 78                edit_preview,
 79                inputs,
 80                buffer: edited_buffer.clone(),
 81                buffer_snapshotted_at,
 82                response_received_at,
 83            }),
 84        }
 85    }
 86}
 87
/// An edit prediction together with the buffer state its edits are expressed against.
///
/// `EditPredictionResult::new` builds one after the provider's edits have been
/// interpolated onto `snapshot`.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// Predicted edits: anchor ranges in `buffer` paired with their replacement text.
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Snapshot that `edits` are expressed against; `interpolate` uses it as the base.
    pub snapshot: BufferSnapshot,
    /// Preview of applying `edits` to the buffer.
    pub edit_preview: EditPreview,
    /// The buffer the prediction targets.
    pub buffer: Entity<Buffer>,
    /// Time at which the buffer was snapshotted; `latency` measures from here.
    pub buffer_snapshotted_at: Instant,
    /// Time at which the provider's response arrived; `latency` measures up to here.
    pub response_received_at: Instant,
    /// Inputs associated with the request that produced this prediction.
    pub inputs: EditPredictionInputs,
}
 99
/// Inputs associated with an edit prediction request, stored alongside the prediction.
///
/// NOTE(review): presumably these mirror what was sent to the provider; confirm at the
/// site that constructs them. Derives `Serialize`, so field order here is the serialized
/// field order.
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
    /// Events included with the request.
    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Related files included with the request.
    pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
    /// Cursor position (line and column) in the file at `cursor_path`.
    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
    /// Path of the file containing the cursor.
    pub cursor_path: Arc<Path>,
}
107
108impl EditPrediction {
109    pub fn interpolate(
110        &self,
111        new_snapshot: &TextBufferSnapshot,
112    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
113        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
114    }
115
116    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
117        self.snapshot.remote_id() == buffer.remote_id()
118    }
119
120    pub fn latency(&self) -> Duration {
121        self.response_received_at - self.buffer_snapshotted_at
122    }
123}
124
125impl std::fmt::Debug for EditPrediction {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("EditPrediction")
128            .field("id", &self.id)
129            .field("edits", &self.edits)
130            .finish()
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Checks that `EditPrediction::interpolate` tracks user edits made after the
    /// prediction was created.
    ///
    /// Starting text is "Lorem ipsum dolor". The prediction replaces "rem" (2..5) with
    /// "REM" and deletes "um" (9..11). Each step below edits the buffer and asserts the
    /// interpolated edits as plain offsets into the current buffer text.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer unchanged: the edits come back exactly as predicted.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Delete "rem" -> "Lo ipsum dolor": "REM" becomes a pure insertion at 2 and
            // the "um" deletion shifts left by 3.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo restores "Lorem ipsum dolor", and with it the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Replace "rem" with "R" -> "LoR ipsum dolor": only "EM" is left to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // Type "E" -> "LoRE ipsum dolor": only "M" is left to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Type "M" -> "LoREM ipsum dolor": the replacement is fully typed, so only the
            // "um" deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // Remove the typed "M" -> "LoRE ipsum dolor": the "M" insertion reappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Delete "um" by hand -> "LoRE ips dolor": only the "M" insertion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // Delete " i" at the pending insertion point -> "LoREps dolor": interpolation
            // now yields `None`.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    ///
    /// Each range's start is anchored with `anchor_after` and its end with
    /// `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current contents, so that
    /// test assertions can compare plain ranges.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}