prediction.rs

  1use std::{
  2    ops::Range,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use cloud_llm_client::EditPredictionRejectReason;
  8use edit_prediction_types::interpolate_edits;
  9use gpui::{AsyncApp, Entity, SharedString};
 10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
 11use zeta_prompt::ZetaPromptInput;
 12
 13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(pub SharedString);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Name(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 29pub struct EditPredictionResult {
 30    pub id: EditPredictionId,
 31    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 32}
 33
 34impl EditPredictionResult {
 35    pub async fn new(
 36        id: EditPredictionId,
 37        edited_buffer: &Entity<Buffer>,
 38        edited_buffer_snapshot: &BufferSnapshot,
 39        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40        buffer_snapshotted_at: Instant,
 41        response_received_at: Instant,
 42        inputs: ZetaPromptInput,
 43        cx: &mut AsyncApp,
 44    ) -> Self {
 45        if edits.is_empty() {
 46            return Self {
 47                id,
 48                prediction: Err(EditPredictionRejectReason::Empty),
 49            };
 50        }
 51
 52        let Some((edits, snapshot, edit_preview_task)) =
 53            edited_buffer.read_with(cx, |buffer, cx| {
 54                let new_snapshot = buffer.snapshot();
 55                let edits: Arc<[_]> =
 56                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 57
 58                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 59            })
 60        else {
 61            return Self {
 62                id,
 63                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 64            };
 65        };
 66
 67        let edit_preview = edit_preview_task.await;
 68
 69        Self {
 70            id: id.clone(),
 71            prediction: Ok(EditPrediction {
 72                id,
 73                edits,
 74                snapshot,
 75                edit_preview,
 76                inputs,
 77                buffer: edited_buffer.clone(),
 78                buffer_snapshotted_at,
 79                response_received_at,
 80            }),
 81        }
 82    }
 83}
 84
 85#[derive(Clone)]
 86pub struct EditPrediction {
 87    pub id: EditPredictionId,
 88    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 89    pub snapshot: BufferSnapshot,
 90    pub edit_preview: EditPreview,
 91    pub buffer: Entity<Buffer>,
 92    pub buffer_snapshotted_at: Instant,
 93    pub response_received_at: Instant,
 94    pub inputs: zeta_prompt::ZetaPromptInput,
 95}
 96
 97impl EditPrediction {
 98    pub fn interpolate(
 99        &self,
100        new_snapshot: &TextBufferSnapshot,
101    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
102        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
103    }
104
105    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
106        self.snapshot.remote_id() == buffer.remote_id()
107    }
108
109    pub fn latency(&self) -> Duration {
110        self.response_received_at - self.buffer_snapshotted_at
111    }
112}
113
114impl std::fmt::Debug for EditPrediction {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("EditPrediction")
117            .field("id", &self.id)
118            .field("edits", &self.edits)
119            .finish()
120    }
121}
122
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    // Applies a series of user edits (including an undo) to the buffer behind one prediction
    // and checks what `EditPrediction::interpolate` reports after each of them.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        // Prediction: "Lorem ipsum dolor" -> "LoREM ips dolor".
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());
        let inputs = ZetaPromptInput {
            events: vec![],
            related_files: vec![].into(),
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            inputs,
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the whole prediction is still pending.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // "Lo ipsum dolor": the replaced text is gone, so "REM" becomes a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo brings back "Lorem ipsum dolor" and the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // "LoR ipsum dolor": the user typed the first predicted character.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // "LoRE ipsum dolor": one more predicted character typed.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // "LoREM ipsum dolor": the insertion is complete, only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // "LoRE ipsum dolor": removing the typed "M" makes it pending again.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // "LoRE ips dolor": the user performed the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // "LoREps dolor": a deletion that does not match the prediction; it no longer applies.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto `buffer`'s current snapshot and converts the result to
    /// offset ranges. Panics if interpolation returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let current = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&current).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based ones for `buffer`'s current contents.
    /// Range starts use `anchor_after`; range ends use `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, text) in iterator {
            let start = buf.anchor_after(offsets.start);
            let end = buf.anchor_before(offsets.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offset ranges in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (anchors, text) in editor_edits {
            let offsets = anchors.start.to_offset(buf)..anchors.end.to_offset(buf);
            resolved.push((offsets, Arc::clone(text)));
        }
        resolved
    }
}