1use std::{
2 ops::Range,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use cloud_llm_client::EditPredictionRejectReason;
8use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
9use gpui::{AsyncApp, Entity, SharedString};
10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
11use zeta_prompt::ZetaPromptInput;
12
13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(pub SharedString);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Name(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
28/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
29pub struct EditPredictionResult {
30 pub id: EditPredictionId,
31 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
32}
33
34impl EditPredictionResult {
35 pub async fn new(
36 id: EditPredictionId,
37 edited_buffer: &Entity<Buffer>,
38 edited_buffer_snapshot: &BufferSnapshot,
39 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
40 cursor_position: Option<PredictedCursorPosition>,
41 buffer_snapshotted_at: Instant,
42 response_received_at: Instant,
43 inputs: ZetaPromptInput,
44 cx: &mut AsyncApp,
45 ) -> Self {
46 if edits.is_empty() {
47 return Self {
48 id,
49 prediction: Err(EditPredictionRejectReason::Empty),
50 };
51 }
52
53 let Some((edits, snapshot, edit_preview_task)) =
54 edited_buffer.read_with(cx, |buffer, cx| {
55 let new_snapshot = buffer.snapshot();
56 let edits: Arc<[_]> =
57 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
58
59 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
60 })
61 else {
62 return Self {
63 id,
64 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
65 };
66 };
67
68 let edit_preview = edit_preview_task.await;
69
70 Self {
71 id: id.clone(),
72 prediction: Ok(EditPrediction {
73 id,
74 edits,
75 cursor_position,
76 snapshot,
77 edit_preview,
78 inputs,
79 buffer: edited_buffer.clone(),
80 buffer_snapshotted_at,
81 response_received_at,
82 }),
83 }
84 }
85}
86
87#[derive(Clone)]
88pub struct EditPrediction {
89 pub id: EditPredictionId,
90 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
91 pub cursor_position: Option<PredictedCursorPosition>,
92 pub snapshot: BufferSnapshot,
93 pub edit_preview: EditPreview,
94 pub buffer: Entity<Buffer>,
95 pub buffer_snapshotted_at: Instant,
96 pub response_received_at: Instant,
97 pub inputs: zeta_prompt::ZetaPromptInput,
98}
99
100impl EditPrediction {
101 pub fn interpolate(
102 &self,
103 new_snapshot: &TextBufferSnapshot,
104 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
105 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
106 }
107
108 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
109 self.snapshot.remote_id() == buffer.remote_id()
110 }
111
112 pub fn latency(&self) -> Duration {
113 self.response_received_at - self.buffer_snapshotted_at
114 }
115}
116
117impl std::fmt::Debug for EditPrediction {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("EditPrediction")
120 .field("id", &self.id)
121 .field("edits", &self.edits)
122 .finish()
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Exercises `EditPrediction::interpolate` across a sequence of user edits. The sequence
    /// includes edits that type out part of the prediction, an undo, and a final edit after
    /// which interpolation fails.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: ZetaPromptInput {
                events: vec![],
                related_files: vec![],
                cursor_path: Path::new("path.txt").into(),
                cursor_offset_in_excerpt: 0,
                cursor_excerpt: "".into(),
                editable_range_in_excerpt: 0..0,
                excerpt_start_row: None,
                excerpt_ranges: None,
                preferred_model: None,
                in_open_source_repo: false,
                can_collect_data: false,
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction interpolates to itself.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Deleting the text the first edit replaces turns it into a pure insertion and
            // shifts the second edit left by three.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".into()), (6..8, "".into())])
            );

            // Undoing that deletion restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".into()), (9..11, "".into())])
            );

            // Typing the first predicted character leaves the rest of the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".into()), (7..9, "".into())])
            );

            // Typing the next predicted character shrinks the insertion further.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Completing the insertion leaves only the predicted deletion.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".into())])
            );

            // Removing the typed "M" brings the remaining insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into()), (8..10, "".into())])
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".into())])
            );

            // An edit that does not match the prediction makes interpolation fail.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto `buffer`'s current snapshot and converts the result
    /// to offset ranges. Returns `None` when interpolation fails.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let snapshot = buffer.read(cx).snapshot();
        prediction
            .interpolate(&snapshot)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based prediction edits for `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (Range { start, end }, text) in iterator {
            let range = buffer.anchor_after(start)..buffer.anchor_before(end);
            anchored.push((range, text));
        }
        anchored
    }

    /// Converts anchor-based prediction edits back into offset ranges within `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut offsets = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            offsets.push((start..end, Arc::clone(text)));
        }
        offsets
    }
}