1use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
2
3use anyhow::Context as _;
4use cloud_llm_client::predict_edits_v3;
5use gpui::{App, AsyncApp, Entity};
6use language::{
7 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
8};
9use project::Project;
10use util::ResultExt;
11use uuid::Uuid;
12
13#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(Uuid);
15
16impl Into<Uuid> for EditPredictionId {
17 fn into(self) -> Uuid {
18 self.0
19 }
20}
21
22impl From<EditPredictionId> for gpui::ElementId {
23 fn from(value: EditPredictionId) -> Self {
24 gpui::ElementId::Uuid(value.0)
25 }
26}
27
28impl std::fmt::Display for EditPredictionId {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 write!(f, "{}", self.0)
31 }
32}
33
/// A set of predicted edits for a single buffer, anchored against a known snapshot.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction (built from the response's `request_id` in `from_response`).
    pub id: EditPredictionId,
    /// Path the prediction targets; compared against a buffer's full path by `targets_buffer`.
    pub path: Arc<Path>,
    /// Predicted edits as (range to replace, replacement text), anchored in `snapshot`.
    pub edits: Arc<[(Range<Anchor>, String)]>,
    /// Snapshot the `edits` are expressed against; used as the base when interpolating.
    pub snapshot: BufferSnapshot,
    /// Preview of `edits` as produced by `Buffer::preview_edits`.
    pub edit_preview: EditPreview,
    // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
    pub buffer: Entity<Buffer>,
}
44
45impl EditPrediction {
46 pub async fn from_response(
47 response: predict_edits_v3::PredictEditsResponse,
48 active_buffer_old_snapshot: &TextBufferSnapshot,
49 active_buffer: &Entity<Buffer>,
50 project: &Entity<Project>,
51 cx: &mut AsyncApp,
52 ) -> Option<Self> {
53 // TODO only allow cloud to return one path
54 let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
55 return None;
56 };
57
58 let is_same_path = active_buffer
59 .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
60 .ok()?;
61
62 let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
63 active_buffer
64 .read_with(cx, |buffer, cx| {
65 let new_snapshot = buffer.snapshot();
66 let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
67 let edits: Arc<[_]> =
68 interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
69
70 Some((
71 active_buffer.clone(),
72 edits.clone(),
73 new_snapshot,
74 buffer.preview_edits(edits, cx),
75 ))
76 })
77 .ok()??
78 } else {
79 let buffer_handle = project
80 .update(cx, |project, cx| {
81 let project_path = project
82 .find_project_path(&path, cx)
83 .context("Failed to find project path for zeta edit")?;
84 anyhow::Ok(project.open_buffer(project_path, cx))
85 })
86 .ok()?
87 .log_err()?
88 .await
89 .context("Failed to open buffer for zeta edit")
90 .log_err()?;
91
92 buffer_handle
93 .read_with(cx, |buffer, cx| {
94 let snapshot = buffer.snapshot();
95 let edits = edits_from_response(&response.edits, &snapshot);
96 if edits.is_empty() {
97 return None;
98 }
99 Some((
100 buffer_handle.clone(),
101 edits.clone(),
102 snapshot,
103 buffer.preview_edits(edits, cx),
104 ))
105 })
106 .ok()??
107 };
108
109 let edit_preview = edit_preview_task.await;
110
111 Some(EditPrediction {
112 id: EditPredictionId(response.request_id),
113 path,
114 edits,
115 snapshot,
116 edit_preview,
117 buffer,
118 })
119 }
120
121 pub fn interpolate(
122 &self,
123 new_snapshot: &TextBufferSnapshot,
124 ) -> Option<Vec<(Range<Anchor>, String)>> {
125 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
126 }
127
128 pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
129 buffer_path_eq(buffer, &self.path, cx)
130 }
131}
132
133impl std::fmt::Debug for EditPrediction {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("EditPrediction")
136 .field("id", &self.id)
137 .field("path", &self.path)
138 .field("edits", &self.edits)
139 .finish()
140 }
141}
142
143pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
144 buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
145}
146
147pub fn interpolate_edits(
148 old_snapshot: &TextBufferSnapshot,
149 new_snapshot: &TextBufferSnapshot,
150 current_edits: Arc<[(Range<Anchor>, String)]>,
151) -> Option<Vec<(Range<Anchor>, String)>> {
152 let mut edits = Vec::new();
153
154 let mut model_edits = current_edits.iter().peekable();
155 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
156 while let Some((model_old_range, _)) = model_edits.peek() {
157 let model_old_range = model_old_range.to_offset(old_snapshot);
158 if model_old_range.end < user_edit.old.start {
159 let (model_old_range, model_new_text) = model_edits.next().unwrap();
160 edits.push((model_old_range.clone(), model_new_text.clone()));
161 } else {
162 break;
163 }
164 }
165
166 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
167 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
168 if user_edit.old == model_old_offset_range {
169 let user_new_text = new_snapshot
170 .text_for_range(user_edit.new.clone())
171 .collect::<String>();
172
173 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
174 if !model_suffix.is_empty() {
175 let anchor = old_snapshot.anchor_after(user_edit.old.end);
176 edits.push((anchor..anchor, model_suffix.to_string()));
177 }
178
179 model_edits.next();
180 continue;
181 }
182 }
183 }
184
185 return None;
186 }
187
188 edits.extend(model_edits.cloned());
189
190 if edits.is_empty() { None } else { Some(edits) }
191}
192
193pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
194 language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
195}
196
197fn edits_from_response(
198 edits: &[predict_edits_v3::Edit],
199 snapshot: &TextBufferSnapshot,
200) -> Arc<[(Range<Anchor>, String)]> {
201 edits
202 .iter()
203 .flat_map(|edit| {
204 let point_range = line_range_to_point_range(edit.range.clone());
205 let offset = point_range.to_offset(snapshot).start;
206 let old_text = snapshot.text_for_range(point_range);
207
208 excerpt_edits_from_response(
209 old_text.collect::<Cow<str>>(),
210 &edit.content,
211 offset,
212 &snapshot,
213 )
214 })
215 .collect::<Vec<_>>()
216 .into()
217}
218
219fn excerpt_edits_from_response(
220 old_text: Cow<str>,
221 new_text: &str,
222 offset: usize,
223 snapshot: &TextBufferSnapshot,
224) -> impl Iterator<Item = (Range<Anchor>, String)> {
225 text_diff(&old_text, new_text)
226 .into_iter()
227 .map(move |(mut old_range, new_text)| {
228 old_range.start += offset;
229 old_range.end += offset;
230
231 let prefix_len = common_prefix(
232 snapshot.chars_for_range(old_range.clone()),
233 new_text.chars(),
234 );
235 old_range.start += prefix_len;
236
237 let suffix_len = common_prefix(
238 snapshot.reversed_chars_for_range(old_range.clone()),
239 new_text[prefix_len..].chars().rev(),
240 );
241 old_range.end = old_range.end.saturating_sub(suffix_len);
242
243 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
244 let range = if old_range.is_empty() {
245 let anchor = snapshot.anchor_after(old_range.start);
246 anchor..anchor
247 } else {
248 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
249 };
250 (range, new_text)
251 })
252}
253
/// Returns the length in bytes (UTF-8) of the longest common prefix of two char streams.
///
/// Comparison stops at the first differing pair or when either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
260
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use edit_prediction_context::Line;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A whole-file replacement from the server is reduced to minimal anchored edits:
    /// one insertion completing `let args =` and one appending `;` to the `println!` line.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        // A single edit covering every line of `old`.
        let big_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: Line(0)..Line(old.lines().count() as u32),
            content: new.into(),
        }];

        let edits = edits_from_response(&big_edits, &snapshot);
        assert_eq!(edits.len(), 2);
        // End of `    let args =` (row 1, column 14).
        assert_eq!(
            edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 14)
        );
        assert_eq!(edits[0].1, " std::env::args();");
        // End of `    println!("{}", args[1])` (row 2, column 27).
        assert_eq!(
            edits[1].0.to_point(&snapshot).start,
            language::Point::new(2, 27)
        );
        assert_eq!(edits[1].1, ";");
    }

    /// Walks a prediction through a sequence of user edits and checks how
    /// `EditPrediction::interpolate` rebases it after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            path: Path::new("test.txt").into(),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is returned unchanged.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User deletes the replaced text: the replacement becomes a plain insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undo restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // User types a prefix of the predicted text: only the rest is suggested.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Replacement fully typed: only the pending deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".to_string())]
            );

            // Deleting the typed "M" brings the "M" suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // User performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".to_string())]
            );

            // An edit that does not line up with any predicted edit invalidates the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Converts offset-based edits into anchored edits against the buffer's current
    /// contents (start anchored after, end anchored before).
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchored edits back to byte offsets in the buffer's current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}