1use std::{
2 ops::Range,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use cloud_llm_client::EditPredictionRejectReason;
8use edit_prediction_types::interpolate_edits;
9use gpui::{AsyncApp, Entity, SharedString};
10use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
11use zeta_prompt::ZetaPromptInput;
12
/// Identifier of a single edit prediction, stored as a [`SharedString`].
///
/// The wrapped string is what gets displayed by the `Display` impl and what is
/// used as the element name when converting into a `gpui::ElementId`.
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
15
16impl From<EditPredictionId> for gpui::ElementId {
17 fn from(value: EditPredictionId) -> Self {
18 gpui::ElementId::Name(value.0)
19 }
20}
21
22impl std::fmt::Display for EditPredictionId {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 write!(f, "{}", self.0)
25 }
26}
27
/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
pub struct EditPredictionResult {
    /// Identifier of the prediction this result belongs to.
    pub id: EditPredictionId,
    /// `Ok` with the usable prediction, or `Err` with the reason it was rejected.
    pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
}
33
34impl EditPredictionResult {
35 pub async fn new(
36 id: EditPredictionId,
37 edited_buffer: &Entity<Buffer>,
38 edited_buffer_snapshot: &BufferSnapshot,
39 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
40 buffer_snapshotted_at: Instant,
41 response_received_at: Instant,
42 inputs: ZetaPromptInput,
43 cx: &mut AsyncApp,
44 ) -> Self {
45 if edits.is_empty() {
46 return Self {
47 id,
48 prediction: Err(EditPredictionRejectReason::Empty),
49 };
50 }
51
52 let Some((edits, snapshot, edit_preview_task)) =
53 edited_buffer.read_with(cx, |buffer, cx| {
54 let new_snapshot = buffer.snapshot();
55 let edits: Arc<[_]> =
56 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
57
58 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
59 })
60 else {
61 return Self {
62 id,
63 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
64 };
65 };
66
67 let edit_preview = edit_preview_task.await;
68
69 Self {
70 id: id.clone(),
71 prediction: Ok(EditPrediction {
72 id,
73 edits,
74 snapshot,
75 edit_preview,
76 inputs,
77 buffer: edited_buffer.clone(),
78 buffer_snapshotted_at,
79 response_received_at,
80 }),
81 }
82 }
83}
84
/// An edit prediction for a buffer, together with the data needed to re-target
/// and preview it.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    pub id: EditPredictionId,
    /// The predicted edits: each entry pairs an anchor range with a string
    /// (presumably the replacement text for that range — confirm against callers).
    pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
    /// Buffer snapshot that `edits` are interpolated from in [`EditPrediction::interpolate`].
    pub snapshot: BufferSnapshot,
    /// Preview of the edits (in `EditPredictionResult::new` this is the awaited
    /// result of `Buffer::preview_edits`).
    pub edit_preview: EditPreview,
    /// The buffer entity this prediction was created for.
    pub buffer: Entity<Buffer>,
    /// Start timestamp used by [`EditPrediction::latency`].
    pub buffer_snapshotted_at: Instant,
    /// End timestamp used by [`EditPrediction::latency`].
    pub response_received_at: Instant,
    /// Prompt input stored alongside the prediction (presumably the input that
    /// produced it — confirm against callers).
    pub inputs: zeta_prompt::ZetaPromptInput,
}
96
97impl EditPrediction {
98 pub fn interpolate(
99 &self,
100 new_snapshot: &TextBufferSnapshot,
101 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
102 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
103 }
104
105 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
106 self.snapshot.remote_id() == buffer.remote_id()
107 }
108
109 pub fn latency(&self) -> Duration {
110 self.response_received_at - self.buffer_snapshotted_at
111 }
112}
113
114impl std::fmt::Debug for EditPrediction {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 f.debug_struct("EditPrediction")
117 .field("id", &self.id)
118 .field("edits", &self.edits)
119 .finish()
120 }
121}
122
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    // Walks one prediction ("rem" -> "REM" at 2..5, delete 9..11) through a series
    // of buffer edits and checks what `interpolate` yields after each of them.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());
        let inputs = ZetaPromptInput {
            events: vec![],
            related_files: vec![].into(),
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            editable_range_in_excerpt: 0..0,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            inputs,
            buffer_snapshotted_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the text that the first edit would replace.
            replace_text(&buffer, 2..5, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing that deletion restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the predicted text one character at a time: "R", "E", "M".
            replace_text(&buffer, 2..5, "R", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            replace_text(&buffer, 3..3, "E", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            replace_text(&buffer, 4..4, "M", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(9..11, "".into())]
            );

            // Removing the last typed character brings its insertion back.
            replace_text(&buffer, 4..5, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            replace_text(&buffer, 8..10, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx).unwrap(),
                vec![(4..4, "M".into())]
            );

            // After this deletion the prediction can no longer be interpolated.
            replace_text(&buffer, 4..6, "", cx);
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and, when
    /// that succeeds, converts the resulting anchor ranges into offset ranges.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, Arc<str>)>> {
        let current_snapshot = buffer.read(cx).snapshot();
        prediction
            .interpolate(&current_snapshot)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Replaces `range` in `buffer` with `text` as a single edit.
    fn replace_text(buffer: &Entity<Buffer>, range: Range<usize>, text: &str, cx: &mut App) {
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(range, text)], None, cx);
        });
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`: each
    /// range start becomes `anchor_after` and each range end `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored_edits = Vec::new();
        for (offsets, new_text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored_edits.push((start..end, new_text));
        }
        anchored_edits
    }

    /// Converts anchor-based edits back into offset-based edits, resolving every
    /// anchor against the current contents of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut offset_edits = Vec::new();
        for (anchors, new_text) in editor_edits {
            let start = anchors.start.to_offset(buffer);
            let end = anchors.end.to_offset(buffer);
            offset_edits.push((start..end, new_text.clone()));
        }
        offset_edits
    }
}