prediction.rs

  1use std::{
  2    ops::Range,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use cloud_llm_client::EditPredictionRejectReason;
  8use edit_prediction_types::interpolate_edits;
  9use gpui::{AsyncApp, Entity, SharedString};
 10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
 11use zeta_prompt::ZetaPromptInput;
 12
 13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(pub SharedString);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Name(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 29pub struct EditPredictionResult {
 30    pub id: EditPredictionId,
 31    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 32}
 33
 34impl EditPredictionResult {
 35    pub async fn new(
 36        id: EditPredictionId,
 37        edited_buffer: &Entity<Buffer>,
 38        edited_buffer_snapshot: &BufferSnapshot,
 39        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40        buffer_snapshotted_at: Instant,
 41        response_received_at: Instant,
 42        inputs: ZetaPromptInput,
 43        cx: &mut AsyncApp,
 44    ) -> Self {
 45        if edits.is_empty() {
 46            return Self {
 47                id,
 48                prediction: Err(EditPredictionRejectReason::Empty),
 49            };
 50        }
 51
 52        let Some((edits, snapshot, edit_preview_task)) = edited_buffer
 53            .read_with(cx, |buffer, cx| {
 54                let new_snapshot = buffer.snapshot();
 55                let edits: Arc<[_]> =
 56                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 57
 58                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 59            })
 60            .ok()
 61            .flatten()
 62        else {
 63            return Self {
 64                id,
 65                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 66            };
 67        };
 68
 69        let edit_preview = edit_preview_task.await;
 70
 71        Self {
 72            id: id.clone(),
 73            prediction: Ok(EditPrediction {
 74                id,
 75                edits,
 76                snapshot,
 77                edit_preview,
 78                inputs,
 79                buffer: edited_buffer.clone(),
 80                buffer_snapshotted_at,
 81                response_received_at,
 82            }),
 83        }
 84    }
 85}
 86
 87#[derive(Clone)]
 88pub struct EditPrediction {
 89    pub id: EditPredictionId,
 90    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 91    pub snapshot: BufferSnapshot,
 92    pub edit_preview: EditPreview,
 93    pub buffer: Entity<Buffer>,
 94    pub buffer_snapshotted_at: Instant,
 95    pub response_received_at: Instant,
 96    pub inputs: zeta_prompt::ZetaPromptInput,
 97}
 98
 99impl EditPrediction {
100    pub fn interpolate(
101        &self,
102        new_snapshot: &TextBufferSnapshot,
103    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
104        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
105    }
106
107    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
108        self.snapshot.remote_id() == buffer.remote_id()
109    }
110
111    pub fn latency(&self) -> Duration {
112        self.response_received_at - self.buffer_snapshotted_at
113    }
114}
115
116impl std::fmt::Debug for EditPrediction {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        f.debug_struct("EditPrediction")
119            .field("id", &self.id)
120            .field("edits", &self.edits)
121            .finish()
122    }
123}
124
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Checks that `EditPrediction::interpolate` tracks the user's edits to the buffer.
    /// Typing text the prediction suggested shrinks the remaining insertion, undo restores
    /// the original edits, and a conflicting edit makes interpolation return `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: ZetaPromptInput {
                events: vec![],
                related_files: vec![].into(),
                cursor_path: Path::new("path.txt").into(),
                cursor_offset_in_excerpt: 0,
                cursor_excerpt: "".into(),
                editable_range_in_excerpt: 0..0,
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer unchanged: the edits interpolate to themselves.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // User deletes "rem" ("Lo ipsum dolor"): the replacement becomes a pure insertion
            // of "REM" at 2, and the deletion shifts left by 3 to 6..8.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo restores the original text, so the original edits come back.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // User replaces "rem" with "R" ("LoR ipsum dolor"), the start of the predicted
            // "REM"; only "EM" remains to insert, and the deletion moves to 7..9.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // User types "E" ("LoRE ipsum dolor"); only "M" remains to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // User types "M" ("LoREM ipsum dolor"): the first edit is fully applied and only
            // the deletion of "um" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // User deletes that "M" again ("LoRE ipsum dolor"); the "M" insertion reappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // User deletes "um" (8..10) themselves ("LoRE ips dolor"); only "M" is left.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // User deletes " i" (4..6), which diverges from the prediction: no interpolation.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits in `buffer`'s current state.
    /// Range starts use `anchor_after` and ends use `anchor_before` (presumably so text
    /// inserted exactly at either boundary falls outside the range).
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current state, so the
    /// assertions can compare plain offset ranges.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}