provider.rs

  1use std::{
  2    cmp,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use anyhow::Context as _;
  8use arrayvec::ArrayVec;
  9use client::{Client, UserStore};
 10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
 11use gpui::{App, Entity, EntityId, Task, prelude::*};
 12use language::{BufferSnapshot, ToPoint as _};
 13use project::Project;
 14use util::ResultExt as _;
 15
 16use crate::{Zeta, prediction::EditPrediction};
 17
 18pub struct ZetaEditPredictionProvider {
 19    zeta: Entity<Zeta>,
 20    current_prediction: Option<CurrentEditPrediction>,
 21    next_pending_prediction_id: usize,
 22    pending_predictions: ArrayVec<PendingPrediction, 2>,
 23    last_request_timestamp: Instant,
 24}
 25
 26impl ZetaEditPredictionProvider {
 27    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 28
 29    pub fn new(
 30        project: Option<&Entity<Project>>,
 31        client: &Arc<Client>,
 32        user_store: &Entity<UserStore>,
 33        cx: &mut App,
 34    ) -> Self {
 35        let zeta = Zeta::global(client, user_store, cx);
 36        if let Some(project) = project {
 37            zeta.update(cx, |zeta, cx| {
 38                zeta.register_project(project, cx);
 39            });
 40        }
 41
 42        Self {
 43            zeta,
 44            current_prediction: None,
 45            next_pending_prediction_id: 0,
 46            pending_predictions: ArrayVec::new(),
 47            last_request_timestamp: Instant::now(),
 48        }
 49    }
 50}
 51
 52#[derive(Clone)]
 53struct CurrentEditPrediction {
 54    buffer_id: EntityId,
 55    prediction: EditPrediction,
 56}
 57
 58impl CurrentEditPrediction {
 59    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 60        if self.buffer_id != old_prediction.buffer_id {
 61            return true;
 62        }
 63
 64        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 65            return true;
 66        };
 67        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 68            return false;
 69        };
 70
 71        if old_edits.len() == 1 && new_edits.len() == 1 {
 72            let (old_range, old_text) = &old_edits[0];
 73            let (new_range, new_text) = &new_edits[0];
 74            new_range == old_range && new_text.starts_with(old_text)
 75        } else {
 76            true
 77        }
 78    }
 79}
 80
 81struct PendingPrediction {
 82    id: usize,
 83    _task: Task<()>,
 84}
 85
 86impl EditPredictionProvider for ZetaEditPredictionProvider {
 87    fn name() -> &'static str {
 88        "zed-predict2"
 89    }
 90
 91    fn display_name() -> &'static str {
 92        "Zed's Edit Predictions 2"
 93    }
 94
 95    fn show_completions_in_menu() -> bool {
 96        true
 97    }
 98
 99    fn show_tab_accept_marker() -> bool {
100        true
101    }
102
103    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
104        // TODO [zeta2]
105        DataCollectionState::Unsupported
106    }
107
108    fn toggle_data_collection(&mut self, _cx: &mut App) {
109        // TODO [zeta2]
110    }
111
112    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
113        self.zeta.read(cx).usage(cx)
114    }
115
116    fn is_enabled(
117        &self,
118        _buffer: &Entity<language::Buffer>,
119        _cursor_position: language::Anchor,
120        _cx: &App,
121    ) -> bool {
122        true
123    }
124
125    fn is_refreshing(&self) -> bool {
126        !self.pending_predictions.is_empty()
127    }
128
129    fn refresh(
130        &mut self,
131        project: Option<Entity<project::Project>>,
132        buffer: Entity<language::Buffer>,
133        cursor_position: language::Anchor,
134        _debounce: bool,
135        cx: &mut Context<Self>,
136    ) {
137        let Some(project) = project else {
138            return;
139        };
140
141        if self
142            .zeta
143            .read(cx)
144            .user_store
145            .read_with(cx, |user_store, _cx| {
146                user_store.account_too_young() || user_store.has_overdue_invoices()
147            })
148        {
149            return;
150        }
151
152        if let Some(current_prediction) = self.current_prediction.as_ref() {
153            let snapshot = buffer.read(cx).snapshot();
154            if current_prediction
155                .prediction
156                .interpolate(&snapshot)
157                .is_some()
158            {
159                return;
160            }
161        }
162
163        let pending_prediction_id = self.next_pending_prediction_id;
164        self.next_pending_prediction_id += 1;
165        let last_request_timestamp = self.last_request_timestamp;
166
167        let task = cx.spawn(async move |this, cx| {
168            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
169                .checked_duration_since(Instant::now())
170            {
171                cx.background_executor().timer(timeout).await;
172            }
173
174            let prediction_request = this.update(cx, |this, cx| {
175                this.last_request_timestamp = Instant::now();
176                this.zeta.update(cx, |zeta, cx| {
177                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
178                })
179            });
180
181            let prediction = match prediction_request {
182                Ok(prediction_request) => {
183                    let prediction_request = prediction_request.await;
184                    prediction_request.map(|c| {
185                        c.map(|prediction| CurrentEditPrediction {
186                            buffer_id: buffer.entity_id(),
187                            prediction,
188                        })
189                    })
190                }
191                Err(error) => Err(error),
192            };
193
194            this.update(cx, |this, cx| {
195                if this.pending_predictions[0].id == pending_prediction_id {
196                    this.pending_predictions.remove(0);
197                } else {
198                    this.pending_predictions.clear();
199                }
200
201                let Some(new_prediction) = prediction
202                    .context("edit prediction failed")
203                    .log_err()
204                    .flatten()
205                else {
206                    cx.notify();
207                    return;
208                };
209
210                if let Some(old_prediction) = this.current_prediction.as_ref() {
211                    let snapshot = buffer.read(cx).snapshot();
212                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
213                        this.current_prediction = Some(new_prediction);
214                    }
215                } else {
216                    this.current_prediction = Some(new_prediction);
217                }
218
219                cx.notify();
220            })
221            .ok();
222        });
223
224        // We always maintain at most two pending predictions. When we already
225        // have two, we replace the newest one.
226        if self.pending_predictions.len() <= 1 {
227            self.pending_predictions.push(PendingPrediction {
228                id: pending_prediction_id,
229                _task: task,
230            });
231        } else if self.pending_predictions.len() == 2 {
232            self.pending_predictions.pop();
233            self.pending_predictions.push(PendingPrediction {
234                id: pending_prediction_id,
235                _task: task,
236            });
237        }
238
239        cx.notify();
240    }
241
242    fn cycle(
243        &mut self,
244        _buffer: Entity<language::Buffer>,
245        _cursor_position: language::Anchor,
246        _direction: Direction,
247        _cx: &mut Context<Self>,
248    ) {
249    }
250
251    fn accept(&mut self, _cx: &mut Context<Self>) {
252        // TODO [zeta2] report accept
253        self.current_prediction.take();
254        self.pending_predictions.clear();
255    }
256
257    fn discard(&mut self, _cx: &mut Context<Self>) {
258        self.pending_predictions.clear();
259        self.current_prediction.take();
260    }
261
262    fn suggest(
263        &mut self,
264        buffer: &Entity<language::Buffer>,
265        cursor_position: language::Anchor,
266        cx: &mut Context<Self>,
267    ) -> Option<edit_prediction::EditPrediction> {
268        let CurrentEditPrediction {
269            buffer_id,
270            prediction,
271            ..
272        } = self.current_prediction.as_mut()?;
273
274        // Invalidate previous prediction if it was generated for a different buffer.
275        if *buffer_id != buffer.entity_id() {
276            self.current_prediction.take();
277            return None;
278        }
279
280        let buffer = buffer.read(cx);
281        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
282            self.current_prediction.take();
283            return None;
284        };
285
286        let cursor_row = cursor_position.to_point(buffer).row;
287        let (closest_edit_ix, (closest_edit_range, _)) =
288            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
289                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
290                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
291                cmp::min(distance_from_start, distance_from_end)
292            })?;
293
294        let mut edit_start_ix = closest_edit_ix;
295        for (range, _) in edits[..edit_start_ix].iter().rev() {
296            let distance_from_closest_edit =
297                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
298            if distance_from_closest_edit <= 1 {
299                edit_start_ix -= 1;
300            } else {
301                break;
302            }
303        }
304
305        let mut edit_end_ix = closest_edit_ix + 1;
306        for (range, _) in &edits[edit_end_ix..] {
307            let distance_from_closest_edit =
308                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
309            if distance_from_closest_edit <= 1 {
310                edit_end_ix += 1;
311            } else {
312                break;
313            }
314        }
315
316        Some(edit_prediction::EditPrediction {
317            id: Some(prediction.id.to_string().into()),
318            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
319            edit_preview: Some(prediction.edit_preview.clone()),
320        })
321    }
322}