1use std::{
2 ops::Range,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use cloud_llm_client::EditPredictionRejectReason;
8use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
9use gpui::{AsyncApp, Entity, SharedString};
10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
11use zeta_prompt::ZetaPromptInput;
12
13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(pub SharedString);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Name(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
29pub struct EditPredictionResult {
30 pub id: EditPredictionId,
31 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
32}
33
34impl EditPredictionResult {
35 pub async fn new(
36 id: EditPredictionId,
37 edited_buffer: &Entity<Buffer>,
38 edited_buffer_snapshot: &BufferSnapshot,
39 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
40 cursor_position: Option<PredictedCursorPosition>,
41 buffer_snapshotted_at: Instant,
42 response_received_at: Instant,
43 inputs: ZetaPromptInput,
44 model_version: Option<String>,
45 cx: &mut AsyncApp,
46 ) -> Self {
47 if edits.is_empty() {
48 return Self {
49 id,
50 prediction: Err(EditPredictionRejectReason::Empty),
51 };
52 }
53
54 let Some((edits, snapshot, edit_preview_task)) =
55 edited_buffer.read_with(cx, |buffer, cx| {
56 let new_snapshot = buffer.snapshot();
57 let edits: Arc<[_]> =
58 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
59
60 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
61 })
62 else {
63 return Self {
64 id,
65 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
66 };
67 };
68
69 let edit_preview = edit_preview_task.await;
70
71 Self {
72 id: id.clone(),
73 prediction: Ok(EditPrediction {
74 id,
75 edits,
76 cursor_position,
77 snapshot,
78 edit_preview,
79 inputs,
80 buffer: edited_buffer.clone(),
81 buffer_snapshotted_at,
82 response_received_at,
83 model_version,
84 }),
85 }
86 }
87}
88
89#[derive(Clone)]
90pub struct EditPrediction {
91 pub id: EditPredictionId,
92 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
93 pub cursor_position: Option<PredictedCursorPosition>,
94 pub snapshot: BufferSnapshot,
95 pub edit_preview: EditPreview,
96 pub buffer: Entity<Buffer>,
97 pub buffer_snapshotted_at: Instant,
98 pub response_received_at: Instant,
99 pub inputs: zeta_prompt::ZetaPromptInput,
100 pub model_version: Option<String>,
101}
102
103impl EditPrediction {
104 pub fn interpolate(
105 &self,
106 new_snapshot: &TextBufferSnapshot,
107 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
108 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
109 }
110
111 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
112 self.snapshot.remote_id() == buffer.remote_id()
113 }
114
115 pub fn latency(&self) -> Duration {
116 self.response_received_at - self.buffer_snapshotted_at
117 }
118}
119
120impl std::fmt::Debug for EditPrediction {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("EditPrediction")
123 .field("id", &self.id)
124 .field("edits", &self.edits)
125 .finish()
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Checks how a prediction's edits are re-anchored as the user edits the buffer.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });
        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs: empty_prompt_input(),
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the edits come back exactly as predicted.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Deleting the replaced text turns the first edit into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".into()), (6..8, "".into())])
            );

            // Undo brings back the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Typing a prefix of the predicted text shrinks the pending insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".into()), (7..9, "".into())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Completing the insertion leaves only the deletion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".into())])
            );

            // Removing the just-typed character restores the pending insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Performing the predicted deletion consumes the second edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into())])
            );

            // Deleting text the prediction did not anticipate leaves nothing to interpolate.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Prompt input with placeholder contents; `EditPrediction::interpolate` never reads it.
    fn empty_prompt_input() -> ZetaPromptInput {
        ZetaPromptInput {
            events: vec![],
            related_files: vec![],
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
            excerpt_start_row: None,
            excerpt_ranges: None,
            preferred_model: None,
            in_open_source_repo: false,
            can_collect_data: false,
        }
    }

    /// Interpolates `prediction` onto `buffer`'s current contents, then converts the result to
    /// offset ranges so assertions can use plain `usize` positions.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let current = buffer.read(cx).snapshot();
        prediction
            .interpolate(&current)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    fn to_prediction_edits(
        offset_edits: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        offset_edits
            .into_iter()
            .map(|(range, text)| {
                let start = buffer.anchor_after(range.start);
                let end = buffer.anchor_before(range.end);
                (start..end, text)
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset-based edits on `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buffer);
                let end = range.end.to_offset(buffer);
                (start..end, Arc::clone(text))
            })
            .collect()
    }
}