1use std::{ops::Range, sync::Arc};
2
3use gpui::{AsyncApp, Entity};
4use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
5use uuid::Uuid;
6
7#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
8pub struct EditPredictionId(pub Uuid);
9
10impl Into<Uuid> for EditPredictionId {
11 fn into(self) -> Uuid {
12 self.0
13 }
14}
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Uuid(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
28#[derive(Clone)]
29pub struct EditPrediction {
30 pub id: EditPredictionId,
31 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
32 pub snapshot: BufferSnapshot,
33 pub edit_preview: EditPreview,
34 // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
35 pub buffer: Entity<Buffer>,
36}
37
38impl EditPrediction {
39 pub async fn new(
40 id: EditPredictionId,
41 edited_buffer: &Entity<Buffer>,
42 edited_buffer_snapshot: &BufferSnapshot,
43 edits: Vec<(Range<Anchor>, Arc<str>)>,
44 cx: &mut AsyncApp,
45 ) -> Option<Self> {
46 let (edits, snapshot, edit_preview_task) = edited_buffer
47 .read_with(cx, |buffer, cx| {
48 let new_snapshot = buffer.snapshot();
49 let edits: Arc<[_]> =
50 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
51
52 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
53 })
54 .ok()??;
55
56 let edit_preview = edit_preview_task.await;
57
58 Some(EditPrediction {
59 id,
60 edits,
61 snapshot,
62 edit_preview,
63 buffer: edited_buffer.clone(),
64 })
65 }
66
67 pub fn interpolate(
68 &self,
69 new_snapshot: &TextBufferSnapshot,
70 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
71 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
72 }
73
74 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
75 self.snapshot.remote_id() == buffer.remote_id()
76 }
77}
78
79impl std::fmt::Debug for EditPrediction {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("EditPrediction")
82 .field("id", &self.id)
83 .field("edits", &self.edits)
84 .finish()
85 }
86}
87
88pub fn interpolate_edits(
89 old_snapshot: &TextBufferSnapshot,
90 new_snapshot: &TextBufferSnapshot,
91 current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
92) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
93 let mut edits = Vec::new();
94
95 let mut model_edits = current_edits.iter().peekable();
96 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
97 while let Some((model_old_range, _)) = model_edits.peek() {
98 let model_old_range = model_old_range.to_offset(old_snapshot);
99 if model_old_range.end < user_edit.old.start {
100 let (model_old_range, model_new_text) = model_edits.next().unwrap();
101 edits.push((model_old_range.clone(), model_new_text.clone()));
102 } else {
103 break;
104 }
105 }
106
107 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
108 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
109 if user_edit.old == model_old_offset_range {
110 let user_new_text = new_snapshot
111 .text_for_range(user_edit.new.clone())
112 .collect::<String>();
113
114 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
115 if !model_suffix.is_empty() {
116 let anchor = old_snapshot.anchor_after(user_edit.old.end);
117 edits.push((anchor..anchor, model_suffix.into()));
118 }
119
120 model_edits.next();
121 continue;
122 }
123 }
124 }
125
126 return None;
127 }
128
129 edits.extend(model_edits.cloned());
130
131 if edits.is_empty() { None } else { Some(edits) }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Rebases a two-edit prediction on "Lorem ipsum dolor" through a sequence of user edits
    /// and an undo. The sequence ends with a conflicting edit that invalidates the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            anchor_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Deleting the first predicted range turns that edit into a pure insertion and
            // shifts the second edit left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".into()), (6..8, "".into())])
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Typing a prefix of the predicted text leaves only the remaining suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".into()), (7..9, "".into())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Typing the whole predicted text consumes the first edit entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".into())])
            );

            // Deleting a character that was just typed brings its suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Performing the second predicted deletion by hand consumes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                rebased_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into())])
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(rebased_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto the current contents of `buffer` and converts the
    /// result to offset ranges, so expectations can be written as plain `usize` ranges.
    fn rebased_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot)?;
        Some(offset_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits against the buffer's current
    /// contents. Each start is anchored after its offset and each end before its offset.
    fn anchor_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits to offset ranges in the buffer's current contents.
    fn offset_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}