prediction.rs

  1use std::{borrow::Cow, ops::Range, sync::Arc};
  2
  3use cloud_llm_client::predict_edits_v3;
  4use language::{Anchor, BufferSnapshot, EditPreview, OffsetRangeExt, text_diff};
  5use uuid::Uuid;
  6
  7#[derive(Clone)]
  8pub struct EditPrediction {
  9    pub id: EditPredictionId,
 10    pub edits: Arc<[(Range<Anchor>, String)]>,
 11    pub snapshot: BufferSnapshot,
 12    pub edit_preview: EditPreview,
 13}
 14
 15impl EditPrediction {
 16    pub fn interpolate(
 17        &self,
 18        new_snapshot: &BufferSnapshot,
 19    ) -> Option<Vec<(Range<Anchor>, String)>> {
 20        interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
 21    }
 22}
 23
 24#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 25pub struct EditPredictionId(Uuid);
 26
 27impl From<Uuid> for EditPredictionId {
 28    fn from(value: Uuid) -> Self {
 29        EditPredictionId(value)
 30    }
 31}
 32
 33impl From<EditPredictionId> for gpui::ElementId {
 34    fn from(value: EditPredictionId) -> Self {
 35        gpui::ElementId::Uuid(value.0)
 36    }
 37}
 38
 39impl std::fmt::Display for EditPredictionId {
 40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 41        write!(f, "{}", self.0)
 42    }
 43}
 44
 45pub fn interpolate_edits(
 46    old_snapshot: &BufferSnapshot,
 47    new_snapshot: &BufferSnapshot,
 48    current_edits: Arc<[(Range<Anchor>, String)]>,
 49) -> Option<Vec<(Range<Anchor>, String)>> {
 50    let mut edits = Vec::new();
 51
 52    let mut model_edits = current_edits.iter().peekable();
 53    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 54        while let Some((model_old_range, _)) = model_edits.peek() {
 55            let model_old_range = model_old_range.to_offset(old_snapshot);
 56            if model_old_range.end < user_edit.old.start {
 57                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 58                edits.push((model_old_range.clone(), model_new_text.clone()));
 59            } else {
 60                break;
 61            }
 62        }
 63
 64        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 65            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 66            if user_edit.old == model_old_offset_range {
 67                let user_new_text = new_snapshot
 68                    .text_for_range(user_edit.new.clone())
 69                    .collect::<String>();
 70
 71                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
 72                    if !model_suffix.is_empty() {
 73                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
 74                        edits.push((anchor..anchor, model_suffix.to_string()));
 75                    }
 76
 77                    model_edits.next();
 78                    continue;
 79                }
 80            }
 81        }
 82
 83        return None;
 84    }
 85
 86    edits.extend(model_edits.cloned());
 87
 88    if edits.is_empty() { None } else { Some(edits) }
 89}
 90
 91pub fn edits_from_response(
 92    edits: &[predict_edits_v3::Edit],
 93    snapshot: &BufferSnapshot,
 94) -> Arc<[(Range<Anchor>, String)]> {
 95    edits
 96        .iter()
 97        .flat_map(|edit| {
 98            // TODO multi-file edits
 99            let old_text = snapshot.text_for_range(edit.range.clone());
100
101            excerpt_edits_from_response(
102                old_text.collect::<Cow<str>>(),
103                &edit.content,
104                edit.range.start,
105                &snapshot,
106            )
107        })
108        .collect::<Vec<_>>()
109        .into()
110}
111
112fn excerpt_edits_from_response(
113    old_text: Cow<str>,
114    new_text: &str,
115    offset: usize,
116    snapshot: &BufferSnapshot,
117) -> impl Iterator<Item = (Range<Anchor>, String)> {
118    text_diff(&old_text, new_text)
119        .into_iter()
120        .map(move |(mut old_range, new_text)| {
121            old_range.start += offset;
122            old_range.end += offset;
123
124            let prefix_len = common_prefix(
125                snapshot.chars_for_range(old_range.clone()),
126                new_text.chars(),
127            );
128            old_range.start += prefix_len;
129
130            let suffix_len = common_prefix(
131                snapshot.reversed_chars_for_range(old_range.clone()),
132                new_text[prefix_len..].chars().rev(),
133            );
134            old_range.end = old_range.end.saturating_sub(suffix_len);
135
136            let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
137            let range = if old_range.is_empty() {
138                let anchor = snapshot.anchor_after(old_range.start);
139                anchor..anchor
140            } else {
141                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
142            };
143            (range, new_text)
144        })
145}
146
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// Characters are compared pairwise until the first mismatch or until either stream ends.
/// Feed `.rev()` iterators to measure a common suffix instead.
fn common_prefix<T1, T2>(a: T1, b: T2) -> usize
where
    T1: Iterator<Item = char>,
    T2: Iterator<Item = char>,
{
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
153
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A single whole-buffer edit from the server is reduced to the two minimal insertions.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        let response_edits = [predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = edits_from_response(&response_edits, &snapshot);
        assert_eq!(edits.len(), 2);

        // Completes the `let args =` line...
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");

        // ...and terminates the `println!` statement.
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// A prediction survives user edits that follow it and is dropped once the user diverges.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // The user deletes the predicted range: "REM" remains to be inserted.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the predicted text leaves only the rest of it.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Finishing the predicted text removes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Deleting the last typed character brings its suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion removes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not match the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and converts the result
    /// to offset ranges; panics if the prediction was invalidated.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx).snapshot();
        let interpolated = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&interpolated, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges in `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buffer);
                let end = range.end.to_offset(buffer);
                (start..end, text.clone())
            })
            .collect()
    }
}