1use std::{
2 ops::Range,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use cloud_llm_client::EditPredictionRejectReason;
8use edit_prediction_types::interpolate_edits;
9use gpui::{AsyncApp, Entity, SharedString};
10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
11use zeta_prompt::ZetaPromptInput;
12
13#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
14pub struct EditPredictionId(pub SharedString);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Name(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
pub struct EditPredictionResult {
    /// Identifier of the provider response; present even when the prediction was rejected.
    pub id: EditPredictionId,
    /// The usable prediction, or the reason it was rejected.
    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
}
33
34impl EditPredictionResult {
35 pub async fn new(
36 id: EditPredictionId,
37 edited_buffer: &Entity<Buffer>,
38 edited_buffer_snapshot: &BufferSnapshot,
39 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
40 buffer_snapshotted_at: Instant,
41 response_received_at: Instant,
42 inputs: ZetaPromptInput,
43 cx: &mut AsyncApp,
44 ) -> Self {
45 if edits.is_empty() {
46 return Self {
47 id,
48 prediction: Err(EditPredictionRejectReason::Empty),
49 };
50 }
51
52 let Some((edits, snapshot, edit_preview_task)) = edited_buffer
53 .read_with(cx, |buffer, cx| {
54 let new_snapshot = buffer.snapshot();
55 let edits: Arc<[_]> =
56 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
57
58 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
59 })
60 .ok()
61 .flatten()
62 else {
63 return Self {
64 id,
65 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
66 };
67 };
68
69 let edit_preview = edit_preview_task.await;
70
71 Self {
72 id: id.clone(),
73 prediction: Ok(EditPrediction {
74 id,
75 edits,
76 snapshot,
77 edit_preview,
78 inputs,
79 buffer: edited_buffer.clone(),
80 buffer_snapshotted_at,
81 response_received_at,
82 }),
83 }
84 }
85}
86
87#[derive(Clone)]
88pub struct EditPrediction {
89 pub id: EditPredictionId,
90 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
91 pub snapshot: BufferSnapshot,
92 pub edit_preview: EditPreview,
93 pub buffer: Entity<Buffer>,
94 pub buffer_snapshotted_at: Instant,
95 pub response_received_at: Instant,
96 pub inputs: zeta_prompt::ZetaPromptInput,
97}
98
99impl EditPrediction {
100 pub fn interpolate(
101 &self,
102 new_snapshot: &TextBufferSnapshot,
103 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
104 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
105 }
106
107 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
108 self.snapshot.remote_id() == buffer.remote_id()
109 }
110
111 pub fn latency(&self) -> Duration {
112 self.response_received_at - self.buffer_snapshotted_at
113 }
114}
115
116impl std::fmt::Debug for EditPrediction {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("EditPrediction")
119 .field("id", &self.id)
120 .field("edits", &self.edits)
121 .finish()
122 }
123}
124
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Applies a series of user edits to the buffer and checks how the prediction's edits are
    /// re-based after each one. The sequence ends with an edit that makes interpolation fail.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Prediction: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            buffer: buffer.clone(),
            edit_preview,
            inputs: ZetaPromptInput {
                events: vec![],
                related_files: vec![].into(),
                cursor_path: Path::new("path.txt").into(),
                cursor_offset_in_excerpt: 0,
                cursor_excerpt: "".into(),
                editable_range_in_excerpt: 0..0,
            },
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Buffer untouched: the edits come back exactly as predicted.
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(2..5, "REM".into()), (9..11, "".into())],
            );

            // Deleting the text being replaced turns the replacement into an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(2..2, "REM".into()), (6..8, "".into())],
            );

            // Undoing that deletion restores the original edits.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(2..5, "REM".into()), (9..11, "".into())],
            );

            // Typing a prefix of the predicted text ("R") leaves only the rest to insert.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(3..3, "EM".into()), (7..9, "".into())],
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(4..4, "M".into()), (8..10, "".into())],
            );

            // With the replacement fully typed, only the deletion remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, vec![(9..11, "".into())]);

            // Removing the typed "M" brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_interpolated(
                &prediction,
                &buffer,
                cx,
                vec![(4..4, "M".into()), (8..10, "".into())],
            );

            // Performing the predicted deletion by hand drops it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_interpolated(&prediction, &buffer, cx, vec![(4..4, "M".into())]);

            // Deleting from the pending insertion point onward makes interpolation yield `None`.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the current contents of `buffer` and asserts that the
    /// result, converted back to offset ranges, equals `expected`.
    #[track_caller]
    fn assert_interpolated(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
        expected: Vec<(Range<usize>, Arc<str>)>,
    ) {
        let interpolated = prediction
            .interpolate(&buffer.read(cx).snapshot())
            .unwrap();
        assert_eq!(from_prediction_edits(&interpolated, buffer, cx), expected);
    }

    /// Converts offset-based edits into anchor-based edits on `buffer`.
    ///
    /// Range starts are anchored with `anchor_after` and range ends with `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buf = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(Range { start, end }, text)| {
                (buf.anchor_after(start)..buf.anchor_before(end), text)
            })
            .collect()
    }

    /// Converts anchor-based edits back into offset ranges on the current contents of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buf = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                let start = range.start.to_offset(buf);
                let end = range.end.to_offset(buf);
                (start..end, Arc::clone(text))
            })
            .collect()
    }
}