1use std::{
2 cmp,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use anyhow::Context as _;
8use arrayvec::ArrayVec;
9use client::{Client, UserStore};
10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
11use gpui::{App, Entity, EntityId, Task, prelude::*};
12use language::{BufferSnapshot, ToPoint as _};
13use project::Project;
14use util::ResultExt as _;
15
16use crate::{Zeta, prediction::EditPrediction};
17
18pub struct ZetaEditPredictionProvider {
19 zeta: Entity<Zeta>,
20 current_prediction: Option<CurrentEditPrediction>,
21 next_pending_prediction_id: usize,
22 pending_predictions: ArrayVec<PendingPrediction, 2>,
23 last_request_timestamp: Instant,
24}
25
26impl ZetaEditPredictionProvider {
27 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
28
29 pub fn new(
30 project: Option<&Entity<Project>>,
31 client: &Arc<Client>,
32 user_store: &Entity<UserStore>,
33 cx: &mut App,
34 ) -> Self {
35 let zeta = Zeta::global(client, user_store, cx);
36 if let Some(project) = project {
37 zeta.update(cx, |zeta, cx| {
38 zeta.register_project(project, cx);
39 });
40 }
41
42 Self {
43 zeta,
44 current_prediction: None,
45 next_pending_prediction_id: 0,
46 pending_predictions: ArrayVec::new(),
47 last_request_timestamp: Instant::now(),
48 }
49 }
50}
51
52#[derive(Clone)]
53struct CurrentEditPrediction {
54 buffer_id: EntityId,
55 prediction: EditPrediction,
56}
57
58impl CurrentEditPrediction {
59 fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
60 if self.buffer_id != old_prediction.buffer_id {
61 return true;
62 }
63
64 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
65 return true;
66 };
67 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
68 return false;
69 };
70
71 if old_edits.len() == 1 && new_edits.len() == 1 {
72 let (old_range, old_text) = &old_edits[0];
73 let (new_range, new_text) = &new_edits[0];
74 new_range == old_range && new_text.starts_with(old_text)
75 } else {
76 true
77 }
78 }
79}
80
81struct PendingPrediction {
82 id: usize,
83 _task: Task<()>,
84}
85
86impl EditPredictionProvider for ZetaEditPredictionProvider {
87 fn name() -> &'static str {
88 "zed-predict2"
89 }
90
91 fn display_name() -> &'static str {
92 "Zed's Edit Predictions 2"
93 }
94
95 fn show_completions_in_menu() -> bool {
96 true
97 }
98
99 fn show_tab_accept_marker() -> bool {
100 true
101 }
102
103 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
104 // TODO [zeta2]
105 DataCollectionState::Unsupported
106 }
107
108 fn toggle_data_collection(&mut self, _cx: &mut App) {
109 // TODO [zeta2]
110 }
111
112 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
113 self.zeta.read(cx).usage(cx)
114 }
115
116 fn is_enabled(
117 &self,
118 _buffer: &Entity<language::Buffer>,
119 _cursor_position: language::Anchor,
120 _cx: &App,
121 ) -> bool {
122 true
123 }
124
125 fn is_refreshing(&self) -> bool {
126 !self.pending_predictions.is_empty()
127 }
128
129 fn refresh(
130 &mut self,
131 project: Option<Entity<project::Project>>,
132 buffer: Entity<language::Buffer>,
133 cursor_position: language::Anchor,
134 _debounce: bool,
135 cx: &mut Context<Self>,
136 ) {
137 let Some(project) = project else {
138 return;
139 };
140
141 if self
142 .zeta
143 .read(cx)
144 .user_store
145 .read_with(cx, |user_store, _cx| {
146 user_store.account_too_young() || user_store.has_overdue_invoices()
147 })
148 {
149 return;
150 }
151
152 if let Some(current_prediction) = self.current_prediction.as_ref() {
153 let snapshot = buffer.read(cx).snapshot();
154 if current_prediction
155 .prediction
156 .interpolate(&snapshot)
157 .is_some()
158 {
159 return;
160 }
161 }
162
163 let pending_prediction_id = self.next_pending_prediction_id;
164 self.next_pending_prediction_id += 1;
165 let last_request_timestamp = self.last_request_timestamp;
166
167 let task = cx.spawn(async move |this, cx| {
168 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
169 .checked_duration_since(Instant::now())
170 {
171 cx.background_executor().timer(timeout).await;
172 }
173
174 let prediction_request = this.update(cx, |this, cx| {
175 this.last_request_timestamp = Instant::now();
176 this.zeta.update(cx, |zeta, cx| {
177 zeta.request_prediction(&project, &buffer, cursor_position, cx)
178 })
179 });
180
181 let prediction = match prediction_request {
182 Ok(prediction_request) => {
183 let prediction_request = prediction_request.await;
184 prediction_request.map(|c| {
185 c.map(|prediction| CurrentEditPrediction {
186 buffer_id: buffer.entity_id(),
187 prediction,
188 })
189 })
190 }
191 Err(error) => Err(error),
192 };
193
194 this.update(cx, |this, cx| {
195 if this.pending_predictions[0].id == pending_prediction_id {
196 this.pending_predictions.remove(0);
197 } else {
198 this.pending_predictions.clear();
199 }
200
201 let Some(new_prediction) = prediction
202 .context("edit prediction failed")
203 .log_err()
204 .flatten()
205 else {
206 cx.notify();
207 return;
208 };
209
210 if let Some(old_prediction) = this.current_prediction.as_ref() {
211 let snapshot = buffer.read(cx).snapshot();
212 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
213 this.current_prediction = Some(new_prediction);
214 }
215 } else {
216 this.current_prediction = Some(new_prediction);
217 }
218
219 cx.notify();
220 })
221 .ok();
222 });
223
224 // We always maintain at most two pending predictions. When we already
225 // have two, we replace the newest one.
226 if self.pending_predictions.len() <= 1 {
227 self.pending_predictions.push(PendingPrediction {
228 id: pending_prediction_id,
229 _task: task,
230 });
231 } else if self.pending_predictions.len() == 2 {
232 self.pending_predictions.pop();
233 self.pending_predictions.push(PendingPrediction {
234 id: pending_prediction_id,
235 _task: task,
236 });
237 }
238
239 cx.notify();
240 }
241
242 fn cycle(
243 &mut self,
244 _buffer: Entity<language::Buffer>,
245 _cursor_position: language::Anchor,
246 _direction: Direction,
247 _cx: &mut Context<Self>,
248 ) {
249 }
250
251 fn accept(&mut self, _cx: &mut Context<Self>) {
252 // TODO [zeta2] report accept
253 self.current_prediction.take();
254 self.pending_predictions.clear();
255 }
256
257 fn discard(&mut self, _cx: &mut Context<Self>) {
258 self.pending_predictions.clear();
259 self.current_prediction.take();
260 }
261
262 fn suggest(
263 &mut self,
264 buffer: &Entity<language::Buffer>,
265 cursor_position: language::Anchor,
266 cx: &mut Context<Self>,
267 ) -> Option<edit_prediction::EditPrediction> {
268 let CurrentEditPrediction {
269 buffer_id,
270 prediction,
271 ..
272 } = self.current_prediction.as_mut()?;
273
274 // Invalidate previous prediction if it was generated for a different buffer.
275 if *buffer_id != buffer.entity_id() {
276 self.current_prediction.take();
277 return None;
278 }
279
280 let buffer = buffer.read(cx);
281 let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
282 self.current_prediction.take();
283 return None;
284 };
285
286 let cursor_row = cursor_position.to_point(buffer).row;
287 let (closest_edit_ix, (closest_edit_range, _)) =
288 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
289 let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
290 let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
291 cmp::min(distance_from_start, distance_from_end)
292 })?;
293
294 let mut edit_start_ix = closest_edit_ix;
295 for (range, _) in edits[..edit_start_ix].iter().rev() {
296 let distance_from_closest_edit =
297 closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
298 if distance_from_closest_edit <= 1 {
299 edit_start_ix -= 1;
300 } else {
301 break;
302 }
303 }
304
305 let mut edit_end_ix = closest_edit_ix + 1;
306 for (range, _) in &edits[edit_end_ix..] {
307 let distance_from_closest_edit =
308 range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
309 if distance_from_closest_edit <= 1 {
310 edit_end_ix += 1;
311 } else {
312 break;
313 }
314 }
315
316 Some(edit_prediction::EditPrediction {
317 id: Some(prediction.id.to_string().into()),
318 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
319 edit_preview: Some(prediction.edit_preview.clone()),
320 })
321 }
322}