1use std::{
2 ops::Range,
3 path::Path,
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use cloud_llm_client::EditPredictionRejectReason;
9use gpui::{AsyncApp, Entity, SharedString};
10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
11use serde::Serialize;
12
13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(pub SharedString);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Name(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
29pub struct EditPredictionResult {
30 pub id: EditPredictionId,
31 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
32}
33
34impl EditPredictionResult {
35 pub async fn new(
36 id: EditPredictionId,
37 edited_buffer: &Entity<Buffer>,
38 edited_buffer_snapshot: &BufferSnapshot,
39 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
40 buffer_snapshotted_at: Instant,
41 response_received_at: Instant,
42 inputs: EditPredictionInputs,
43 cx: &mut AsyncApp,
44 ) -> Self {
45 if edits.is_empty() {
46 return Self {
47 id,
48 prediction: Err(EditPredictionRejectReason::Empty),
49 };
50 }
51
52 let Some((edits, snapshot, edit_preview_task)) = edited_buffer
53 .read_with(cx, |buffer, cx| {
54 let new_snapshot = buffer.snapshot();
55 let edits: Arc<[_]> =
56 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits)?.into();
57
58 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
59 })
60 .ok()
61 .flatten()
62 else {
63 return Self {
64 id,
65 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
66 };
67 };
68
69 let edit_preview = edit_preview_task.await;
70
71 Self {
72 id: id.clone(),
73 prediction: Ok(EditPrediction {
74 id,
75 edits,
76 snapshot,
77 edit_preview,
78 inputs,
79 buffer: edited_buffer.clone(),
80 buffer_snapshotted_at,
81 response_received_at,
82 }),
83 }
84 }
85}
86
87#[derive(Clone)]
88pub struct EditPrediction {
89 pub id: EditPredictionId,
90 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
91 pub snapshot: BufferSnapshot,
92 pub edit_preview: EditPreview,
93 pub buffer: Entity<Buffer>,
94 pub buffer_snapshotted_at: Instant,
95 pub response_received_at: Instant,
96 pub inputs: EditPredictionInputs,
97}
98
99#[derive(Debug, Clone, Serialize)]
100pub struct EditPredictionInputs {
101 pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
102 pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
103 pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
104 pub cursor_path: Arc<Path>,
105}
106
107impl EditPrediction {
108 pub fn interpolate(
109 &self,
110 new_snapshot: &TextBufferSnapshot,
111 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
112 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
113 }
114
115 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
116 self.snapshot.remote_id() == buffer.remote_id()
117 }
118
119 pub fn latency(&self) -> Duration {
120 self.response_received_at - self.buffer_snapshotted_at
121 }
122}
123
124impl std::fmt::Debug for EditPrediction {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("EditPrediction")
127 .field("id", &self.id)
128 .field("edits", &self.edits)
129 .finish()
130 }
131}
132
133pub fn interpolate_edits(
134 old_snapshot: &TextBufferSnapshot,
135 new_snapshot: &TextBufferSnapshot,
136 current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
137) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
138 let mut edits = Vec::new();
139
140 let mut model_edits = current_edits.iter().peekable();
141 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
142 while let Some((model_old_range, _)) = model_edits.peek() {
143 let model_old_range = model_old_range.to_offset(old_snapshot);
144 if model_old_range.end < user_edit.old.start {
145 let (model_old_range, model_new_text) = model_edits.next().unwrap();
146 edits.push((model_old_range.clone(), model_new_text.clone()));
147 } else {
148 break;
149 }
150 }
151
152 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
153 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
154 if user_edit.old == model_old_offset_range {
155 let user_new_text = new_snapshot
156 .text_for_range(user_edit.new.clone())
157 .collect::<String>();
158
159 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
160 if !model_suffix.is_empty() {
161 let anchor = old_snapshot.anchor_after(user_edit.old.end);
162 edits.push((anchor..anchor, model_suffix.into()));
163 }
164
165 model_edits.next();
166 continue;
167 }
168 }
169 }
170
171 return None;
172 }
173
174 edits.extend(model_edits.cloned());
175
176 if edits.is_empty() { None } else { Some(edits) }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    /// Applies a series of user edits to a buffer and checks how one fixed prediction is
    /// re-based after each of them.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction applies as-is.
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting "rem": the empty text is a prefix of "REM", so all of it is re-inserted.
            apply_edit(&buffer, 2..5, "", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the predicted text one character at a time shrinks the remaining suffix.
            apply_edit(&buffer, 2..5, "R", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            apply_edit(&buffer, 3..3, "E", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Completing the first edit leaves only the deletion.
            apply_edit(&buffer, 4..4, "M", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Removing the typed "M" brings it back as a pending insertion.
            apply_edit(&buffer, 4..5, "", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion consumes the second edit.
            apply_edit(&buffer, 8..10, "", cx);
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // An edit that does not line up with the prediction invalidates it.
            apply_edit(&buffer, 4..6, "", cx);
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and returns the
    /// result as offset ranges. Panics if the prediction no longer applies.
    fn interpolated(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction
            .interpolate(&snapshot)
            .expect("prediction should still apply to the edited buffer");
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Replaces `range` in `buffer` with `text`.
    fn apply_edit(buffer: &Entity<Buffer>, range: Range<usize>, text: &str, cx: &mut App) {
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(range, text)], None, cx);
        });
    }

    /// Converts offset-based edits into anchor-based prediction edits.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(offsets, text)| {
                let start = buffer.anchor_after(offsets.start);
                let end = buffer.anchor_before(offsets.end);
                (start..end, text)
            })
            .collect()
    }

    /// Converts anchor-based prediction edits back into offset ranges.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut converted = Vec::with_capacity(editor_edits.len());
        for (anchors, text) in editor_edits {
            let start = anchors.start.to_offset(buffer);
            let end = anchors.end.to_offset(buffer);
            converted.push((start..end, text.clone()));
        }
        converted
    }
}