prediction.rs

  1use std::{
  2    ops::Range,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use cloud_llm_client::EditPredictionRejectReason;
  8use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
  9use gpui::{AsyncApp, Entity, SharedString};
 10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
 11use zeta_prompt::ZetaPromptInput;
 12
 13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(pub SharedString);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Name(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 29pub struct EditPredictionResult {
 30    pub id: EditPredictionId,
 31    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 32}
 33
 34impl EditPredictionResult {
 35    pub async fn new(
 36        id: EditPredictionId,
 37        edited_buffer: &Entity<Buffer>,
 38        edited_buffer_snapshot: &BufferSnapshot,
 39        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40        cursor_position: Option<PredictedCursorPosition>,
 41        buffer_snapshotted_at: Instant,
 42        response_received_at: Instant,
 43        inputs: ZetaPromptInput,
 44        cx: &mut AsyncApp,
 45    ) -> Self {
 46        if edits.is_empty() {
 47            return Self {
 48                id,
 49                prediction: Err(EditPredictionRejectReason::Empty),
 50            };
 51        }
 52
 53        let Some((edits, snapshot, edit_preview_task)) =
 54            edited_buffer.read_with(cx, |buffer, cx| {
 55                let new_snapshot = buffer.snapshot();
 56                let edits: Arc<[_]> =
 57                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 58
 59                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 60            })
 61        else {
 62            return Self {
 63                id,
 64                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 65            };
 66        };
 67
 68        let edit_preview = edit_preview_task.await;
 69
 70        Self {
 71            id: id.clone(),
 72            prediction: Ok(EditPrediction {
 73                id,
 74                edits,
 75                cursor_position,
 76                snapshot,
 77                edit_preview,
 78                inputs,
 79                buffer: edited_buffer.clone(),
 80                buffer_snapshotted_at,
 81                response_received_at,
 82            }),
 83        }
 84    }
 85}
 86
 87#[derive(Clone)]
 88pub struct EditPrediction {
 89    pub id: EditPredictionId,
 90    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 91    pub cursor_position: Option<PredictedCursorPosition>,
 92    pub snapshot: BufferSnapshot,
 93    pub edit_preview: EditPreview,
 94    pub buffer: Entity<Buffer>,
 95    pub buffer_snapshotted_at: Instant,
 96    pub response_received_at: Instant,
 97    pub inputs: zeta_prompt::ZetaPromptInput,
 98}
 99
100impl EditPrediction {
101    pub fn interpolate(
102        &self,
103        new_snapshot: &TextBufferSnapshot,
104    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
105        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
106    }
107
108    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
109        self.snapshot.remote_id() == buffer.remote_id()
110    }
111
112    pub fn latency(&self) -> Duration {
113        self.response_received_at - self.buffer_snapshotted_at
114    }
115}
116
117impl std::fmt::Debug for EditPrediction {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("EditPrediction")
120            .field("id", &self.id)
121            .field("edits", &self.edits)
122            .finish()
123    }
124}
125
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Checks how a prediction's edits are re-based as the buffer is edited toward
    /// (and away from) the predicted text, ending with an edit that makes
    /// interpolation return `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: ZetaPromptInput {
                events: vec![],
                related_files: vec![],
                cursor_path: Path::new("path.txt").into(),
                cursor_offset_in_excerpt: 0,
                cursor_excerpt: "".into(),
                editable_range_in_excerpt: 0..0,
                excerpt_start_row: None,
                excerpt_ranges: None,
                preferred_model: None,
                in_open_source_repo: false,
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer unchanged since the prediction was made.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the replaced text turns the first edit into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing a prefix of the predicted text consumes it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Typing the rest of the replacement leaves only the deletion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Removing typed text brings back the corresponding insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // A deletion the prediction did not anticipate makes interpolation return `None`.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto `buffer`'s current snapshot and converts the
    /// resulting edits to offset ranges. Panics if interpolation returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let current = buffer.read(cx).snapshot();
        let anchored = prediction
            .interpolate(&current)
            .expect("prediction should still interpolate");
        from_prediction_edits(&anchored, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buf.anchor_after(range.start);
            let end = buf.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges on `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut offsets = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buf);
            let end = range.end.to_offset(buf);
            offsets.push((start..end, Arc::clone(text)));
        }
        offsets
    }
}