1use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
2
3use anyhow::Context as _;
4use cloud_llm_client::predict_edits_v3;
5use gpui::{App, AsyncApp, Entity};
6use language::{
7 Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
8};
9use project::Project;
10use util::ResultExt;
11use uuid::Uuid;
12
13#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(Uuid);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Uuid(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
28#[derive(Clone)]
29pub struct EditPrediction {
30 pub id: EditPredictionId,
31 pub path: Arc<Path>,
32 pub edits: Arc<[(Range<Anchor>, String)]>,
33 pub snapshot: BufferSnapshot,
34 pub edit_preview: EditPreview,
35 // We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
36 pub buffer: Entity<Buffer>,
37}
38
39impl EditPrediction {
40 pub async fn from_response(
41 response: predict_edits_v3::PredictEditsResponse,
42 active_buffer_old_snapshot: &TextBufferSnapshot,
43 active_buffer: &Entity<Buffer>,
44 project: &Entity<Project>,
45 cx: &mut AsyncApp,
46 ) -> Option<Self> {
47 // TODO only allow cloud to return one path
48 let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
49 return None;
50 };
51
52 let is_same_path = active_buffer
53 .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
54 .ok()?;
55
56 let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
57 active_buffer
58 .read_with(cx, |buffer, cx| {
59 let new_snapshot = buffer.snapshot();
60 let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
61 let edits: Arc<[_]> =
62 interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
63
64 Some((
65 active_buffer.clone(),
66 edits.clone(),
67 new_snapshot,
68 buffer.preview_edits(edits, cx),
69 ))
70 })
71 .ok()??
72 } else {
73 let buffer_handle = project
74 .update(cx, |project, cx| {
75 let project_path = project
76 .find_project_path(&path, cx)
77 .context("Failed to find project path for zeta edit")?;
78 anyhow::Ok(project.open_buffer(project_path, cx))
79 })
80 .ok()?
81 .log_err()?
82 .await
83 .context("Failed to open buffer for zeta edit")
84 .log_err()?;
85
86 buffer_handle
87 .read_with(cx, |buffer, cx| {
88 let snapshot = buffer.snapshot();
89 let edits = edits_from_response(&response.edits, &snapshot);
90 if edits.is_empty() {
91 return None;
92 }
93 Some((
94 buffer_handle.clone(),
95 edits.clone(),
96 snapshot,
97 buffer.preview_edits(edits, cx),
98 ))
99 })
100 .ok()??
101 };
102
103 let edit_preview = edit_preview_task.await;
104
105 Some(EditPrediction {
106 id: EditPredictionId(response.request_id),
107 path,
108 edits,
109 snapshot,
110 edit_preview,
111 buffer,
112 })
113 }
114
115 pub fn interpolate(
116 &self,
117 new_snapshot: &TextBufferSnapshot,
118 ) -> Option<Vec<(Range<Anchor>, String)>> {
119 interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
120 }
121
122 pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
123 buffer_path_eq(buffer, &self.path, cx)
124 }
125}
126
127impl std::fmt::Debug for EditPrediction {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.debug_struct("EditPrediction")
130 .field("id", &self.id)
131 .field("path", &self.path)
132 .field("edits", &self.edits)
133 .finish()
134 }
135}
136
137pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
138 buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
139}
140
141pub fn interpolate_edits(
142 old_snapshot: &TextBufferSnapshot,
143 new_snapshot: &TextBufferSnapshot,
144 current_edits: Arc<[(Range<Anchor>, String)]>,
145) -> Option<Vec<(Range<Anchor>, String)>> {
146 let mut edits = Vec::new();
147
148 let mut model_edits = current_edits.iter().peekable();
149 for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
150 while let Some((model_old_range, _)) = model_edits.peek() {
151 let model_old_range = model_old_range.to_offset(old_snapshot);
152 if model_old_range.end < user_edit.old.start {
153 let (model_old_range, model_new_text) = model_edits.next().unwrap();
154 edits.push((model_old_range.clone(), model_new_text.clone()));
155 } else {
156 break;
157 }
158 }
159
160 if let Some((model_old_range, model_new_text)) = model_edits.peek() {
161 let model_old_offset_range = model_old_range.to_offset(old_snapshot);
162 if user_edit.old == model_old_offset_range {
163 let user_new_text = new_snapshot
164 .text_for_range(user_edit.new.clone())
165 .collect::<String>();
166
167 if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
168 if !model_suffix.is_empty() {
169 let anchor = old_snapshot.anchor_after(user_edit.old.end);
170 edits.push((anchor..anchor, model_suffix.to_string()));
171 }
172
173 model_edits.next();
174 continue;
175 }
176 }
177 }
178
179 return None;
180 }
181
182 edits.extend(model_edits.cloned());
183
184 if edits.is_empty() { None } else { Some(edits) }
185}
186
187pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
188 language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
189}
190
191fn edits_from_response(
192 edits: &[predict_edits_v3::Edit],
193 snapshot: &TextBufferSnapshot,
194) -> Arc<[(Range<Anchor>, String)]> {
195 edits
196 .iter()
197 .flat_map(|edit| {
198 let point_range = line_range_to_point_range(edit.range.clone());
199 let offset = point_range.to_offset(snapshot).start;
200 let old_text = snapshot.text_for_range(point_range);
201
202 excerpt_edits_from_response(
203 old_text.collect::<Cow<str>>(),
204 &edit.content,
205 offset,
206 &snapshot,
207 )
208 })
209 .collect::<Vec<_>>()
210 .into()
211}
212
213fn excerpt_edits_from_response(
214 old_text: Cow<str>,
215 new_text: &str,
216 offset: usize,
217 snapshot: &TextBufferSnapshot,
218) -> impl Iterator<Item = (Range<Anchor>, String)> {
219 text_diff(&old_text, new_text)
220 .into_iter()
221 .map(move |(mut old_range, new_text)| {
222 old_range.start += offset;
223 old_range.end += offset;
224
225 let prefix_len = common_prefix(
226 snapshot.chars_for_range(old_range.clone()),
227 new_text.chars(),
228 );
229 old_range.start += prefix_len;
230
231 let suffix_len = common_prefix(
232 snapshot.reversed_chars_for_range(old_range.clone()),
233 new_text[prefix_len..].chars().rev(),
234 );
235 old_range.end = old_range.end.saturating_sub(suffix_len);
236
237 let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
238 let range = if old_range.is_empty() {
239 let anchor = snapshot.anchor_after(old_range.start);
240 anchor..anchor
241 } else {
242 snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
243 };
244 (range, new_text)
245 })
246}
247
/// Returns the length, in UTF-8 bytes, of the longest run of equal characters
/// at the start of both iterators.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
254
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::*;
    use cloud_llm_client::predict_edits_v3;
    use edit_prediction_context::Line;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Buffer, ToOffset as _};

    /// A whole-file replacement from the response is reduced to the minimal
    /// insertions that turn the old text into the new text.
    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // TODO cover more cases when multi-file is supported
        let whole_file_edit = [predict_edits_v3::Edit {
            path: PathBuf::from("test.txt").into(),
            range: Line(0)..Line(old.lines().count() as u32),
            content: new.into(),
        }];

        let edits = edits_from_response(&whole_file_edit, &snapshot);
        let starts_and_texts: Vec<_> = edits
            .iter()
            .map(|(range, text)| (range.to_point(&snapshot).start, text.as_str()))
            .collect();

        assert_eq!(
            starts_and_texts,
            [
                (language::Point::new(1, 14), " std::env::args();"),
                (language::Point::new(2, 27), ";"),
            ]
        );
    }

    /// Walks a prediction through a series of user edits and checks what is
    /// left of it after each one.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            path: Path::new("test.txt").into(),
            buffer: buffer.clone(),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is returned as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // The user deletes exactly the range the first edit replaces.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // The user types the predicted text one character at a time.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            // Removing the last typed character brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // The user performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current contents and
    /// converts the resulting anchor ranges to offsets in that buffer.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let current = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&current)?;
        Some(from_prediction_edits(&edits, buffer, cx))
    }

    /// Anchors offset-based edits in `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchored edits to offsets in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(anchors, text)| {
                let start = anchors.start.to_offset(buffer);
                let end = anchors.end.to_offset(buffer);
                (start..end, text.clone())
            })
            .collect()
    }
}