1use std::{
2 ops::Range,
3 path::Path,
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use cloud_llm_client::EditPredictionRejectReason;
9use edit_prediction_types::interpolate_edits;
10use gpui::{AsyncApp, Entity, SharedString};
11use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
12use serde::Serialize;
13
14#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
15pub struct EditPredictionId(pub SharedString);
16
17impl From<EditPredictionId> for gpui::ElementId {
18 fn from(value: EditPredictionId) -> Self {
19 gpui::ElementId::Name(value.0)
20 }
21}
22
23impl std::fmt::Display for EditPredictionId {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 write!(f, "{}", self.0)
26 }
27}
28
29/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
30pub struct EditPredictionResult {
31 pub id: EditPredictionId,
32 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
33}
34
35impl EditPredictionResult {
36 pub async fn new(
37 id: EditPredictionId,
38 edited_buffer: &Entity<Buffer>,
39 edited_buffer_snapshot: &BufferSnapshot,
40 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
41 buffer_snapshotted_at: Instant,
42 response_received_at: Instant,
43 inputs: EditPredictionInputs,
44 cx: &mut AsyncApp,
45 ) -> Self {
46 if edits.is_empty() {
47 return Self {
48 id,
49 prediction: Err(EditPredictionRejectReason::Empty),
50 };
51 }
52
53 let Some((edits, snapshot, edit_preview_task)) = edited_buffer
54 .read_with(cx, |buffer, cx| {
55 let new_snapshot = buffer.snapshot();
56 let edits: Arc<[_]> =
57 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
58
59 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
60 })
61 .ok()
62 .flatten()
63 else {
64 return Self {
65 id,
66 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
67 };
68 };
69
70 let edit_preview = edit_preview_task.await;
71
72 Self {
73 id: id.clone(),
74 prediction: Ok(EditPrediction {
75 id,
76 edits,
77 snapshot,
78 edit_preview,
79 inputs,
80 buffer: edited_buffer.clone(),
81 buffer_snapshotted_at,
82 response_received_at,
83 }),
84 }
85 }
86}
87
/// A validated edit prediction for a single buffer.
///
/// `edits` are anchored in `snapshot`; [`EditPrediction::interpolate`] re-bases them onto a
/// newer snapshot of the same buffer.
// Field order is also drop order; keep it stable.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of the provider response this prediction came from.
    pub id: EditPredictionId,
    /// Predicted edits: each anchor range is replaced by the paired text (empty text deletes).
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot that `edits` are anchored in.
    pub snapshot: BufferSnapshot,
    /// Preview produced by `Buffer::preview_edits` for `edits`.
    pub edit_preview: EditPreview,
    /// The buffer this prediction applies to.
    pub buffer: Entity<Buffer>,
    /// When the buffer snapshot behind the request was taken; start point of `latency()`.
    pub buffer_snapshotted_at: Instant,
    /// When the provider's response was received; end point of `latency()`.
    pub response_received_at: Instant,
    /// Inputs associated with the request (presumably what was sent to the provider — confirm).
    pub inputs: EditPredictionInputs,
}
99
/// Context captured alongside a prediction request.
///
/// Derives `Serialize`; serde emits fields in declaration order, so reordering the fields
/// changes the serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
    /// Editing events associated with the request.
    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Related files included with the request.
    pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
    /// Cursor position as a line/column point.
    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
    /// Path of the file containing the cursor.
    pub cursor_path: Arc<Path>,
}
107
108impl EditPrediction {
109 pub fn interpolate(
110 &self,
111 new_snapshot: &TextBufferSnapshot,
112 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
113 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
114 }
115
116 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
117 self.snapshot.remote_id() == buffer.remote_id()
118 }
119
120 pub fn latency(&self) -> Duration {
121 self.response_received_at - self.buffer_snapshotted_at
122 }
123}
124
125impl std::fmt::Debug for EditPrediction {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("EditPrediction")
128 .field("id", &self.id)
129 .field("edits", &self.edits)
130 .finish()
131 }
132}
133
#[cfg(test)]
mod tests {
    // Tests for `EditPrediction::interpolate`, driven by offset-based edits that are converted
    // to and from anchors with the helpers at the bottom of this module.
    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};

    // A prediction is made against "Lorem ipsum dolor": replace "rem" (2..5) with "REM" and
    // delete "um" (9..11). The buffer is then edited step by step, and each step checks how the
    // prediction re-bases onto the new contents.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: EditPredictionInputs {
                events: vec![],
                included_files: vec![],
                cursor_point: cloud_llm_client::predict_edits_v3::Point {
                    line: cloud_llm_client::predict_edits_v3::Line(0),
                    column: 0,
                },
                cursor_path: Path::new("path.txt").into(),
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer unchanged: the edits keep their original offsets.
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Delete "rem": the replacement becomes a pure insertion and the deletion shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undo the deletion: the original offsets come back.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Replace "rem" with "R", a prefix of the predicted text: only "EM" remains to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            // Type "E" as well: only "M" remains to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Type "M": the first edit is fully applied, leaving only the deletion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(9..11, "".into())]
            );

            // Delete the typed "M": the "M" insertion is predicted again.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Perform the predicted deletion of "um" by hand: only the "M" insertion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                from_prediction_edits(
                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                    &buffer,
                    cx
                ),
                vec![(4..4, "M".into())]
            );

            // Delete text at the pending insertion point: the prediction no longer applies.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Converts offset-based edits into anchor ranges in `buffer`'s current contents. Range
    // starts use `anchor_after` and range ends use `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    // Resolves anchor-based edits back to offsets in `buffer`'s current contents, for easy
    // comparison in assertions.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}