prediction.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use cloud_llm_client::EditPredictionRejectReason;
  4use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
  5use gpui::{AsyncApp, Entity, SharedString};
  6use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
  7use zeta_prompt::ZetaPromptInput;
  8
  9#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 10pub struct EditPredictionId(pub SharedString);
 11
 12impl From<EditPredictionId> for gpui::ElementId {
 13    fn from(value: EditPredictionId) -> Self {
 14        gpui::ElementId::Name(value.0)
 15    }
 16}
 17
 18impl std::fmt::Display for EditPredictionId {
 19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 20        write!(f, "{}", self.0)
 21    }
 22}
 23
 24/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 25pub struct EditPredictionResult {
 26    pub id: EditPredictionId,
 27    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 28    pub model_version: Option<String>,
 29    pub e2e_latency: std::time::Duration,
 30}
 31
 32impl EditPredictionResult {
 33    pub async fn new(
 34        id: EditPredictionId,
 35        edited_buffer: &Entity<Buffer>,
 36        edited_buffer_snapshot: &BufferSnapshot,
 37        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 38        cursor_position: Option<PredictedCursorPosition>,
 39        inputs: ZetaPromptInput,
 40        model_version: Option<String>,
 41        e2e_latency: std::time::Duration,
 42        cx: &mut AsyncApp,
 43    ) -> Self {
 44        if edits.is_empty() {
 45            return Self {
 46                id,
 47                prediction: Err(EditPredictionRejectReason::Empty),
 48                model_version,
 49                e2e_latency,
 50            };
 51        }
 52
 53        let Some((edits, snapshot, edit_preview_task)) =
 54            edited_buffer.read_with(cx, |buffer, cx| {
 55                let new_snapshot = buffer.snapshot();
 56                let edits: Arc<[_]> =
 57                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 58
 59                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 60            })
 61        else {
 62            return Self {
 63                id,
 64                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 65                model_version,
 66                e2e_latency,
 67            };
 68        };
 69
 70        let edit_preview = edit_preview_task.await;
 71
 72        Self {
 73            id: id.clone(),
 74            prediction: Ok(EditPrediction {
 75                id,
 76                edits,
 77                cursor_position,
 78                snapshot,
 79                edit_preview,
 80                inputs,
 81                buffer: edited_buffer.clone(),
 82                model_version: model_version.clone(),
 83            }),
 84            model_version,
 85            e2e_latency,
 86        }
 87    }
 88}
 89
 90#[derive(Clone)]
 91pub struct EditPrediction {
 92    pub id: EditPredictionId,
 93    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 94    pub cursor_position: Option<PredictedCursorPosition>,
 95    pub snapshot: BufferSnapshot,
 96    pub edit_preview: EditPreview,
 97    pub buffer: Entity<Buffer>,
 98    pub inputs: zeta_prompt::ZetaPromptInput,
 99    pub model_version: Option<String>,
100}
101
102impl EditPrediction {
103    pub fn interpolate(
104        &self,
105        new_snapshot: &TextBufferSnapshot,
106    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
107        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
108    }
109
110    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
111        self.snapshot.remote_id() == buffer.remote_id()
112    }
113}
114
115impl std::fmt::Debug for EditPrediction {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        f.debug_struct("EditPrediction")
118            .field("id", &self.id)
119            .field("edits", &self.edits)
120            .finish()
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Walks a prediction through a series of buffer edits and checks what
    /// `EditPrediction::interpolate` reports after each one.
    ///
    /// The buffer starts as "Lorem ipsum dolor" and the prediction replaces
    /// offsets 2..5 with "REM" and deletes offsets 9..11.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let inputs = ZetaPromptInput {
            events: vec![],
            related_files: Some(vec![]),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs,
        };

        cx.update(|cx| {
            // Interpolates the prediction onto the buffer's current contents
            // and converts the anchors back to offsets for easy comparison.
            let remaining_edits = |cx: &App| -> Vec<(Range<usize>, Arc<str>)> {
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx,
                )
            };

            // Untouched buffer: the prediction comes back unchanged.
            assert_eq!(
                remaining_edits(cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the text the first edit would replace turns that edit
            // into an insertion and shifts the second edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                remaining_edits(cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                remaining_edits(cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the predicted text one character at a time shrinks the
            // first edit to what is still missing.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                remaining_edits(cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                remaining_edits(cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Once the whole replacement has been typed, only the deletion
            // remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(remaining_edits(cx), vec![(9..11, "".into())]);

            // Removing the last typed character brings the insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                remaining_edits(cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand leaves only the
            // insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(remaining_edits(cx), vec![(4..4, "M".into())]);

            // After this deletion the prediction can no longer be
            // interpolated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    ///
    /// Each range starts at `anchor_after(start)` and ends at
    /// `anchor_before(end)`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, new_text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored.push((start..end, new_text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset-based edits by resolving
    /// each anchor against the current contents of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::new();
        for (anchors, new_text) in editor_edits {
            let start = anchors.start.to_offset(buffer);
            let end = anchors.end.to_offset(buffer);
            resolved.push((start..end, new_text.clone()));
        }
        resolved
    }
}