1use std::{
2 ops::Range,
3 path::Path,
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use gpui::{AsyncApp, Entity, SharedString};
9use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
10use serde::Serialize;
11
12#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
13pub struct EditPredictionId(pub SharedString);
14
15impl From<EditPredictionId> for gpui::ElementId {
16 fn from(value: EditPredictionId) -> Self {
17 gpui::ElementId::Name(value.0)
18 }
19}
20
21impl std::fmt::Display for EditPredictionId {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 write!(f, "{}", self.0)
24 }
25}
26
/// A set of predicted edits for one buffer, plus the snapshot they are relative to,
/// a preview of their result, and timing/input metadata about the request.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// Predicted edits as anchor ranges into `snapshot`, each paired with its replacement text.
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Snapshot that `edits` are expressed against; `interpolate` rebases from here.
    pub snapshot: BufferSnapshot,
    /// Preview of the buffer with `edits` applied.
    pub edit_preview: EditPreview,
    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
    pub buffer: Entity<Buffer>,
    /// Start of the interval reported by `latency`.
    pub buffer_snapshotted_at: Instant,
    /// End of the interval reported by `latency`.
    pub response_received_at: Instant,
    /// Inputs associated with the request that produced this prediction.
    pub inputs: EditPredictionInputs,
}
39
/// Serializable record of the inputs attached to an edit prediction
/// (presumably what was sent to the model — confirm against the request builder).
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
    /// Events included with the request.
    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Files included with the request.
    pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
    /// Cursor position (line/column) at request time.
    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
    /// Path of the file containing the cursor.
    pub cursor_path: Arc<Path>,
}
47
48impl EditPrediction {
49 pub async fn new(
50 id: EditPredictionId,
51 edited_buffer: &Entity<Buffer>,
52 edited_buffer_snapshot: &BufferSnapshot,
53 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
54 buffer_snapshotted_at: Instant,
55 response_received_at: Instant,
56 inputs: EditPredictionInputs,
57 cx: &mut AsyncApp,
58 ) -> Option<Self> {
59 let (edits, snapshot, edit_preview_task) = edited_buffer
60 .read_with(cx, |buffer, cx| {
61 let new_snapshot = buffer.snapshot();
62 let edits: Arc<[_]> =
63 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
64
65 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
66 })
67 .ok()??;
68
69 let edit_preview = edit_preview_task.await;
70
71 Some(EditPrediction {
72 id,
73 edits,
74 snapshot,
75 edit_preview,
76 inputs,
77 buffer: edited_buffer.clone(),
78 buffer_snapshotted_at,
79 response_received_at,
80 })
81 }
82
83 pub fn interpolate(
84 &self,
85 new_snapshot: &TextBufferSnapshot,
86 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
87 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
88 }
89
90 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
91 self.snapshot.remote_id() == buffer.remote_id()
92 }
93
94 pub fn latency(&self) -> Duration {
95 self.response_received_at - self.buffer_snapshotted_at
96 }
97}
98
99impl std::fmt::Debug for EditPrediction {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("EditPrediction")
102 .field("id", &self.id)
103 .field("edits", &self.edits)
104 .finish()
105 }
106}
107
108pub fn interpolate_edits(
109 old_snapshot: &TextBufferSnapshot,
110 new_snapshot: &TextBufferSnapshot,
111 current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
112) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
113 let mut edits = Vec::new();
114
115 let mut model_edits = current_edits.iter().peekable();
116 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
117 while let Some((model_old_range, _)) = model_edits.peek() {
118 let model_old_range = model_old_range.to_offset(old_snapshot);
119 if model_old_range.end < user_edit.old.start {
120 let (model_old_range, model_new_text) = model_edits.next().unwrap();
121 edits.push((model_old_range.clone(), model_new_text.clone()));
122 } else {
123 break;
124 }
125 }
126
127 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
128 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
129 if user_edit.old == model_old_offset_range {
130 let user_new_text = new_snapshot
131 .text_for_range(user_edit.new.clone())
132 .collect::<String>();
133
134 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
135 if !model_suffix.is_empty() {
136 let anchor = old_snapshot.anchor_after(user_edit.old.end);
137 edits.push((anchor..anchor, model_suffix.into()));
138 }
139
140 model_edits.next();
141 continue;
142 }
143 }
144 }
145
146 return None;
147 }
148
149 edits.extend(model_edits.cloned());
150
151 if edits.is_empty() { None } else { Some(edits) }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Runs a prediction through a sequence of user edits and checks that
    /// `EditPrediction::interpolate` rebases, trims, or invalidates it as expected.
    ///
    /// The prediction is made against "Lorem ipsum dolor". It replaces `2..5`
    /// ("rem") with "REM" and deletes `9..11` ("um").
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // No user edits yet: the prediction comes back unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // The user deletes "rem". All of "REM" becomes an insertion, and the
            // deletion shifts left by 3.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // The user replaces "rem" with "R", so only the untyped "EM" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // Typing "E" leaves only "M" to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Typing "M" completes the first edit; only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // Deleting the typed "M" brings the "M" insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // The user performs the predicted deletion of "um", which consumes it.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // Deleting " i" widens the user's edit past the predicted range. It no
            // longer matches any predicted edit, so the prediction is invalidated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchor-based prediction edits for `buffer`.
    ///
    /// Range starts use `anchor_after` and ends use `anchor_before`. Presumably this
    /// keeps insertions at the boundaries outside the range; confirm against the
    /// anchor bias docs.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based prediction edits to offsets in `buffer`'s current
    /// contents so they can be compared in assertions.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}