1use std::{ops::Range, sync::Arc};
2
3use cloud_llm_client::EditPredictionRejectReason;
4use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
5use gpui::{AsyncApp, Entity, SharedString};
6use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
7use zeta_prompt::ZetaPromptInput;
8
9#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
10pub struct EditPredictionId(pub SharedString);
11
12impl From<EditPredictionId> for gpui::ElementId {
13 fn from(value: EditPredictionId) -> Self {
14 gpui::ElementId::Name(value.0)
15 }
16}
17
18impl std::fmt::Display for EditPredictionId {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{}", self.0)
21 }
22}
23
24/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
25pub struct EditPredictionResult {
26 pub id: EditPredictionId,
27 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
28 pub e2e_latency: std::time::Duration,
29}
30
31impl EditPredictionResult {
32 pub async fn new(
33 id: EditPredictionId,
34 edited_buffer: &Entity<Buffer>,
35 edited_buffer_snapshot: &BufferSnapshot,
36 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
37 cursor_position: Option<PredictedCursorPosition>,
38 inputs: ZetaPromptInput,
39 model_version: Option<String>,
40 e2e_latency: std::time::Duration,
41 cx: &mut AsyncApp,
42 ) -> Self {
43 if edits.is_empty() {
44 return Self {
45 id,
46 e2e_latency,
47 prediction: Err(EditPredictionRejectReason::Empty),
48 };
49 }
50
51 let Some((edits, snapshot, edit_preview_task)) =
52 edited_buffer.read_with(cx, |buffer, cx| {
53 let new_snapshot = buffer.snapshot();
54 let edits: Arc<[_]> =
55 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
56
57 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
58 })
59 else {
60 return Self {
61 id,
62 e2e_latency,
63 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
64 };
65 };
66
67 let edit_preview = edit_preview_task.await;
68
69 Self {
70 id: id.clone(),
71 e2e_latency,
72 prediction: Ok(EditPrediction {
73 id,
74 edits,
75 cursor_position,
76 snapshot,
77 edit_preview,
78 inputs,
79 buffer: edited_buffer.clone(),
80 model_version,
81 }),
82 }
83 }
84}
85
86#[derive(Clone)]
87pub struct EditPrediction {
88 pub id: EditPredictionId,
89 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
90 pub cursor_position: Option<PredictedCursorPosition>,
91 pub snapshot: BufferSnapshot,
92 pub edit_preview: EditPreview,
93 pub buffer: Entity<Buffer>,
94 pub inputs: zeta_prompt::ZetaPromptInput,
95 pub model_version: Option<String>,
96}
97
98impl EditPrediction {
99 pub fn interpolate(
100 &self,
101 new_snapshot: &TextBufferSnapshot,
102 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
103 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
104 }
105
106 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
107 self.snapshot.remote_id() == buffer.remote_id()
108 }
109}
110
111impl std::fmt::Debug for EditPrediction {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("EditPrediction")
114 .field("id", &self.id)
115 .field("edits", &self.edits)
116 .finish()
117 }
118}
119
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    // Walks a prediction through a series of buffer edits (including an undo)
    // and checks what remains of the predicted edits after each step.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let inputs = ZetaPromptInput {
            events: vec![],
            related_files: Some(vec![]),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            experiment: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is returned as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Deleting the range the first edit replaces.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            // Typing the predicted text one character at a time.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            // Removing the last typed character brings its edit back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            // Performing the predicted deletion by hand.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // After this deletion the prediction no longer interpolates.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    // Interpolates `prediction` against the buffer's current snapshot and
    // returns the result as offset ranges. Panics if interpolation yields
    // `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let current_snapshot = buffer.read(cx).snapshot();
        let interpolated = prediction.interpolate(&current_snapshot).unwrap();
        from_prediction_edits(&interpolated, buffer, cx)
    }

    // Converts offset-based edits into anchor-based edits for `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, new_text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored.push((start..end, new_text));
        }
        anchored
    }

    // Converts anchor-based edits back into offset-based edits, resolving
    // anchors against the current state of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let offsets = range.start.to_offset(buffer)..range.end.to_offset(buffer);
            resolved.push((offsets, text.clone()));
        }
        resolved
    }
}