1use std::{ops::Range, sync::Arc};
2
3use cloud_llm_client::EditPredictionRejectReason;
4use edit_prediction_types::{PredictedCursorPosition, interpolate_edits};
5use gpui::{AsyncApp, Entity, SharedString};
6use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
7use zeta_prompt::ZetaPromptInput;
8
9#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
10pub struct EditPredictionId(pub SharedString);
11
12impl From<EditPredictionId> for gpui::ElementId {
13 fn from(value: EditPredictionId) -> Self {
14 gpui::ElementId::Name(value.0)
15 }
16}
17
18impl std::fmt::Display for EditPredictionId {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 write!(f, "{}", self.0)
21 }
22}
23
24/// A prediction response that was returned from the provider, whether it was ultimately valid or not.
25pub struct EditPredictionResult {
26 pub id: EditPredictionId,
27 pub prediction: Result<EditPrediction, EditPredictionRejectReason>,
28 pub model_version: Option<String>,
29 pub e2e_latency: std::time::Duration,
30}
31
32impl EditPredictionResult {
33 pub async fn new(
34 id: EditPredictionId,
35 edited_buffer: &Entity<Buffer>,
36 edited_buffer_snapshot: &BufferSnapshot,
37 edits: Arc<[(Range<Anchor>, Arc<str>)]>,
38 cursor_position: Option<PredictedCursorPosition>,
39 inputs: ZetaPromptInput,
40 model_version: Option<String>,
41 e2e_latency: std::time::Duration,
42 cx: &mut AsyncApp,
43 ) -> Self {
44 if edits.is_empty() {
45 return Self {
46 id,
47 prediction: Err(EditPredictionRejectReason::Empty),
48 model_version,
49 e2e_latency,
50 };
51 }
52
53 let Some((edits, snapshot, edit_preview_task)) =
54 edited_buffer.read_with(cx, |buffer, cx| {
55 let new_snapshot = buffer.snapshot();
56 let edits: Arc<[_]> =
57 interpolate_edits(&edited_buffer_snapshot, &new_snapshot, &edits)?.into();
58
59 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
60 })
61 else {
62 return Self {
63 id,
64 prediction: Err(EditPredictionRejectReason::InterpolatedEmpty),
65 model_version,
66 e2e_latency,
67 };
68 };
69
70 let edit_preview = edit_preview_task.await;
71
72 Self {
73 id: id.clone(),
74 prediction: Ok(EditPrediction {
75 id,
76 edits,
77 cursor_position,
78 snapshot,
79 edit_preview,
80 inputs,
81 buffer: edited_buffer.clone(),
82 model_version: model_version.clone(),
83 }),
84 model_version,
85 e2e_latency,
86 }
87 }
88}
89
90#[derive(Clone)]
91pub struct EditPrediction {
92 pub id: EditPredictionId,
93 pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
94 pub cursor_position: Option<PredictedCursorPosition>,
95 pub snapshot: BufferSnapshot,
96 pub edit_preview: EditPreview,
97 pub buffer: Entity<Buffer>,
98 pub inputs: zeta_prompt::ZetaPromptInput,
99 pub model_version: Option<String>,
100}
101
102impl EditPrediction {
103 pub fn interpolate(
104 &self,
105 new_snapshot: &TextBufferSnapshot,
106 ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
107 interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
108 }
109
110 pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
111 self.snapshot.remote_id() == buffer.remote_id()
112 }
113}
114
115impl std::fmt::Debug for EditPrediction {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 f.debug_struct("EditPrediction")
118 .field("id", &self.id)
119 .field("edits", &self.edits)
120 .finish()
121 }
122}
123
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;
    use gpui::{App, Entity, TestAppContext, prelude::*};
    use language::{Buffer, ToOffset as _};
    use zeta_prompt::ZetaPromptInput;

    /// Walks a prediction through a series of buffer edits (and one undo) and checks
    /// the interpolated edits after each step, ending with an edit after which
    /// interpolation returns `None`.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
            to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let inputs = ZetaPromptInput {
            events: vec![],
            related_files: Some(vec![]),
            active_buffer_diagnostics: vec![],
            cursor_path: Path::new("path.txt").into(),
            cursor_offset_in_excerpt: 0,
            cursor_excerpt: "".into(),
            excerpt_start_row: None,
            excerpt_ranges: Default::default(),
            syntax_ranges: None,
            in_open_source_repo: false,
            can_collect_data: false,
            repo_url: None,
        };

        let prediction = EditPrediction {
            id: EditPredictionId("prediction-1".into()),
            edits,
            cursor_position: None,
            snapshot,
            buffer: buffer.clone(),
            edit_preview,
            model_version: None,
            inputs,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".into()), (6..8, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".into()), (9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".into()), (7..9, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into()), (8..10, "".into())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".into())]
            );

            // After this edit the prediction can no longer be interpolated.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and converts the
    /// result to offset ranges. Panics if interpolation returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let interpolated = prediction
            .interpolate(&buffer.read(cx).snapshot())
            .unwrap();
        from_prediction_edits(&interpolated, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`: each range
    /// starts at `anchor_after(start)` and ends at `anchor_before(end)`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges resolved against the
    /// current contents of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, Arc<str>)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, Arc<str>)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::new();
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}