1use std::{ops::Range, sync::Arc};
2
3use gpui::{AsyncApp, Entity, SharedString};
4use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
5
6#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
7pub struct EditPredictionId(pub SharedString);
8
9impl From<EditPredictionId> for gpui::ElementId {
10 fn from(value: EditPredictionId) -> Self {
11 gpui::ElementId::Name(value.0)
12 }
13}
14
15impl std::fmt::Display for EditPredictionId {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 write!(f, "{}", self.0)
18 }
19}
20
21#[derive(Clone)]
22pub struct EditPrediction {
23 pub id: EditPredictionId,
24 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
25 pub snapshot: BufferSnapshot,
26 pub edit_preview: EditPreview,
27 // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
28 pub buffer: Entity<Buffer>,
29}
30
31impl EditPrediction {
32 pub async fn new(
33 id: EditPredictionId,
34 edited_buffer: &Entity<Buffer>,
35 edited_buffer_snapshot: &BufferSnapshot,
36 edits: Vec<(Range<Anchor>, Arc<str>)>,
37 cx: &mut AsyncApp,
38 ) -> Option<Self> {
39 let (edits, snapshot, edit_preview_task) = edited_buffer
40 .read_with(cx, |buffer, cx| {
41 let new_snapshot = buffer.snapshot();
42 let edits: Arc<[_]> =
43 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
44
45 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
46 })
47 .ok()??;
48
49 let edit_preview = edit_preview_task.await;
50
51 Some(EditPrediction {
52 id,
53 edits,
54 snapshot,
55 edit_preview,
56 buffer: edited_buffer.clone(),
57 })
58 }
59
60 pub fn interpolate(
61 &self,
62 new_snapshot: &TextBufferSnapshot,
63 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
64 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
65 }
66
67 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
68 self.snapshot.remote_id() == buffer.remote_id()
69 }
70}
71
72impl std::fmt::Debug for EditPrediction {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("EditPrediction")
75 .field("id", &self.id)
76 .field("edits", &self.edits)
77 .finish()
78 }
79}
80
81pub fn interpolate_edits(
82 old_snapshot: &TextBufferSnapshot,
83 new_snapshot: &TextBufferSnapshot,
84 current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
85) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
86 let mut edits = Vec::new();
87
88 let mut model_edits = current_edits.iter().peekable();
89 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
90 while let Some((model_old_range, _)) = model_edits.peek() {
91 let model_old_range = model_old_range.to_offset(old_snapshot);
92 if model_old_range.end < user_edit.old.start {
93 let (model_old_range, model_new_text) = model_edits.next().unwrap();
94 edits.push((model_old_range.clone(), model_new_text.clone()));
95 } else {
96 break;
97 }
98 }
99
100 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
101 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
102 if user_edit.old == model_old_offset_range {
103 let user_new_text = new_snapshot
104 .text_for_range(user_edit.new.clone())
105 .collect::<String>();
106
107 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
108 if !model_suffix.is_empty() {
109 let anchor = old_snapshot.anchor_after(user_edit.old.end);
110 edits.push((anchor..anchor, model_suffix.into()));
111 }
112
113 model_edits.next();
114 continue;
115 }
116 }
117 }
118
119 return None;
120 }
121
122 edits.extend(model_edits.cloned());
123
124 if edits.is_empty() { None } else { Some(edits) }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Drives `EditPrediction::interpolate` through a sequence of user edits that
    /// follow, undo, or diverge from the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // Buffer unchanged: the prediction comes back as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // The user deletes the text the model wanted to replace, so the
            // replacement becomes a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the predicted text one character at a time shrinks the remainder.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Once the first edit is fully typed, only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Backspacing over the last typed character brings it back as a prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand removes it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and returns
    /// the result as offset ranges. Panics if the interpolation yields `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`. Each start
    /// is anchored after its offset and each end before its offset.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current state.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}