prediction.rs

  1use std::{ops::Range, sync::Arc};
  2
  3use cloud_llm_client::EditPredictionRejectReason;
  4use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
  5use gpui::{AsyncApp, Entity, SharedString};
  6use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
  7use zeta_prompt::ZetaPromptInput;
  8
  9#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 10pub struct EditPredictionId(pub SharedString);
 11
 12impl From<EditPredictionId> for gpui::ElementId {
 13    fn from(value: EditPredictionId) -> Self {
 14        gpui::ElementId::Name(value.0)
 15    }
 16}
 17
 18impl std::fmt::Display for EditPredictionId {
 19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 20        write!(f, "{}", self.0)
 21    }
 22}
 23
 24/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
 25pub struct EditPredictionResult {
 26    pub id: EditPredictionId,
 27    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
 28    pub e2e_latency: std::time::Duration,
 29}
 30
 31impl EditPredictionResult {
 32    pub async fn new(
 33        id: EditPredictionId,
 34        edited_buffer: &Entity<Buffer>,
 35        edited_buffer_snapshot: &BufferSnapshot,
 36        edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 37        cursor_position: Option<PredictedCursorPosition>,
 38        inputs: ZetaPromptInput,
 39        model_version: Option<String>,
 40        e2e_latency: std::time::Duration,
 41        cx: &mut AsyncApp,
 42    ) -> Self {
 43        if edits.is_empty() {
 44            return Self {
 45                id,
 46                e2e_latency,
 47                prediction: Err(EditPredictionRejectReason::Empty),
 48            };
 49        }
 50
 51        let Some((edits, snapshot, edit_preview_task)) =
 52            edited_buffer.read_with(cx, |buffer, cx| {
 53                let new_snapshot = buffer.snapshot();
 54                let edits: Arc<[_]> =
 55                    interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
 56
 57                Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 58            })
 59        else {
 60            return Self {
 61                id,
 62                e2e_latency,
 63                prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
 64            };
 65        };
 66
 67        let edit_preview = edit_preview_task.await;
 68
 69        Self {
 70            id: id.clone(),
 71            e2e_latency,
 72            prediction: Ok(EditPrediction {
 73                id,
 74                edits,
 75                cursor_position,
 76                snapshot,
 77                edit_preview,
 78                inputs,
 79                buffer: edited_buffer.clone(),
 80                model_version,
 81            }),
 82        }
 83    }
 84}
 85
 86#[derive(Clone)]
 87pub struct EditPrediction {
 88    pub id: EditPredictionId,
 89    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
 90    pub cursor_position: Option<PredictedCursorPosition>,
 91    pub snapshot: BufferSnapshot,
 92    pub edit_preview: EditPreview,
 93    pub buffer: Entity<Buffer>,
 94    pub inputs: zeta_prompt::ZetaPromptInput,
 95    pub model_version: Option<String>,
 96}
 97
 98impl EditPrediction {
 99    pub fn interpolate(
100        &self,
101        new_snapshot: &TextBufferSnapshot,
102    ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
103        interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
104    }
105
106    pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
107        self.snapshot.remote_id() == buffer.remote_id()
108    }
109}
110
111impl std::fmt::Debug for EditPrediction {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("EditPrediction")
114            .field("id", &self.id)
115            .field("edits", &self.edits)
116            .finish()
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Walks a prediction through a sequence of user edits and checks how its edits are
    /// interpolated after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));

        // Predict replacing "rem" with "REM" and deleting "um".
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });
        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs: ZetaPromptInput {
                events: vec![],
                related_files: Some(vec![]),
                active_buffer_diagnostics: vec![],
                cursor_path: Path::new("path.txt").into(),
                cursor_offset_in_excerpt: 0,
                cursor_excerpt: "".into(),
                excerpt_start_row: None,
                excerpt_ranges: Default::default(),
                syntax_ranges: None,
                experiment: None,
                in_open_source_repo: false,
                can_collect_data: false,
                repo_url: None,
            },
        };

        cx.update(|cx| {
            // Nothing has changed yet, so the prediction comes back as predicted.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Deleting "rem" keeps the "REM" insertion and shifts the deletion left by 3.
            buffer.update(cx, |b, cx| b.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".into()), (6..8, "".into())])
            );

            // Undoing that deletion restores the original prediction.
            buffer.update(cx, |b, cx| b.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Typing the first predicted character leaves the rest of the insertion.
            buffer.update(cx, |b, cx| b.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".into()), (7..9, "".into())])
            );

            // Typing "E" leaves only "M" to insert.
            buffer.update(cx, |b, cx| b.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Typing "M" completes the first edit; only the deletion remains.
            buffer.update(cx, |b, cx| b.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".into())])
            );

            // Deleting the typed "M" brings the "M" insertion back.
            buffer.update(cx, |b, cx| b.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Applying the predicted deletion leaves only the "M" insertion.
            buffer.update(cx, |b, cx| b.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into())])
            );

            // An edit that contradicts the prediction makes interpolation return `None`.
            buffer.update(cx, |b, cx| b.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto `buffer`'s current snapshot and converts the result to
    /// offset ranges. Returns `None` if interpolation fails.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let current = buffer.read(cx).snapshot();
        prediction
            .interpolate(&current)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s current contents.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                let start = buf.anchor_after(range.start);
                let end = buf.anchor_before(range.end);
                (start..end, text)
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset ranges within `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buf);
                let end = range.end.to_offset(buf);
                (start..end, Arc::clone(text))
            })
            .collect()
    }
}