prediction.rs

  1use std::{
  2    ops::Range,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use cloud_llm_client::EditPredictionRejectReason;
  8use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
  9use gpui::{AsyncApp, Entity, SharedString};
 10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
 11use zeta_prompt::ZetaPromptInput;
 12
 13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 14pub struct EditPredictionId(pub SharedString);
 15
 16impl From<EditPredictionId> for gpui::ElementId {
 17    fn from(value: EditPredictionId) -> Self {
 18        gpui::ElementId::Name(value.0)
 19    }
 20}
 21
 22impl std::fmt::Display for EditPredictionId {
 23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 24        write!(f, "{}", self.0)
 25    }
 26}
 27
 28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 29pub struct EditPredictionResult {
 30    pub id: EditPredictionId,
 31    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 32}
 33
 34impl EditPredictionResult {
 35    pub async fn new(
 36        id: EditPredictionId,
 37        edited_buffer: &Entity<Buffer>,
 38        edited_buffer_snapshot: &BufferSnapshot,
 39        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 40        cursor_position: Option<PredictedCursorPosition>,
 41        buffer_snapshotted_at: Instant,
 42        response_received_at: Instant,
 43        inputs: ZetaPromptInput,
 44        model_version: Option<String>,
 45        cx: &mut AsyncApp,
 46    ) -> Self {
 47        if edits.is_empty() {
 48            return Self {
 49                id,
 50                prediction: Err(EditPredictionRejectReason::Empty),
 51            };
 52        }
 53
 54        let Some((edits, snapshot, edit_preview_task)) =
 55            edited_buffer.read_with(cx, |buffer, cx| {
 56                let new_snapshot = buffer.snapshot();
 57                let edits: Arc<[_]> =
 58                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 59
 60                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 61            })
 62        else {
 63            return Self {
 64                id,
 65                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 66            };
 67        };
 68
 69        let edit_preview = edit_preview_task.await;
 70
 71        Self {
 72            id: id.clone(),
 73            prediction: Ok(EditPrediction {
 74                id,
 75                edits,
 76                cursor_position,
 77                snapshot,
 78                edit_preview,
 79                inputs,
 80                buffer: edited_buffer.clone(),
 81                buffer_snapshotted_at,
 82                response_received_at,
 83                model_version,
 84            }),
 85        }
 86    }
 87}
 88
 89#[derive(Clone)]
 90pub struct EditPrediction {
 91    pub id: EditPredictionId,
 92    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 93    pub cursor_position: Option<PredictedCursorPosition>,
 94    pub snapshot: BufferSnapshot,
 95    pub edit_preview: EditPreview,
 96    pub buffer: Entity<Buffer>,
 97    pub buffer_snapshotted_at: Instant,
 98    pub response_received_at: Instant,
 99    pub inputs: zeta_prompt::ZetaPromptInput,
100    pub model_version: Option<String>,
101}
102
103impl EditPrediction {
104    pub fn interpolate(
105        &self,
106        new_snapshot: &TextBufferSnapshot,
107    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
108        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
109    }
110
111    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
112        self.snapshot.remote_id() == buffer.remote_id()
113    }
114
115    pub fn latency(&self) -> Duration {
116        self.response_received_at - self.buffer_snapshotted_at
117    }
118}
119
120impl std::fmt::Debug for EditPrediction {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_struct("EditPrediction")
123            .field("id", &self.id)
124            .field("edits", &self.edits)
125            .finish()
126    }
127}
128
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Checks that a prediction's edits are re-based as the user deletes, undoes, partially
    /// types, and finally contradicts the predicted text.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        // Predicted edits on "Lorem ipsum dolor": replace "rem" (2..5) with "REM" and
        // delete "um" (9..11).
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let predicted: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            offsets_to_anchors([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(predicted.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prompt_inputs = ZetaPromptInput {
            events: vec![],
            related_files: vec![],
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits: predicted,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs: prompt_inputs,
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is reported unchanged.
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // User deletes "rem": "REM" becomes a pure insertion and the deletion shifts left.
            buffer.update(cx, |buf, cx| buf.edit([(2..5, "")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".into()), (6..8, "".into())])
            );

            // Undo restores the original text and therefore the original prediction.
            buffer.update(cx, |buf, cx| buf.undo(cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // User types the first predicted character: only "EM" is left to insert.
            buffer.update(cx, |buf, cx| buf.edit([(2..5, "R")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".into()), (7..9, "".into())])
            );

            // ...and the second: only "M" is left to insert.
            buffer.update(cx, |buf, cx| buf.edit([(3..3, "E")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // ...and the third: the insertion is complete, only the deletion remains.
            buffer.update(cx, |buf, cx| buf.edit([(4..4, "M")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(9..11, "".into())])
            );

            // Removing the typed "M" brings that part of the prediction back.
            buffer.update(cx, |buf, cx| buf.edit([(4..5, "")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // User performs the predicted deletion: only the insertion is left.
            buffer.update(cx, |buf, cx| buf.edit([(8..10, "")], None, cx));
            assert_eq!(
                remaining_edits(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into())])
            );

            // Deleting text the prediction did not expect is expected to invalidate it.
            buffer.update(cx, |buf, cx| buf.edit([(4..6, "")], None, cx));
            assert_eq!(remaining_edits(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and converts the surviving
    /// edits back into offset ranges; `None` if interpolation fails.
    fn remaining_edits(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let current = buffer.read(cx).snapshot();
        let interpolated = prediction.interpolate(&current)?;
        Some(anchors_to_offsets(&interpolated, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits against the buffer's current text.
    /// Range starts are anchored with `anchor_after`, range ends with `anchor_before`.
    fn offsets_to_anchors(
        edits: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (Range { start, end }, text) in edits {
            anchored.push((buf.anchor_after(start)..buf.anchor_before(end), text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges in the buffer's current text.
    fn anchors_to_offsets(
        edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buf);
                let end = range.end.to_offset(buf);
                (start..end, Arc::clone(text))
            })
            .collect()
    }
}