provider.rs

  1use std::{
  2    cmp,
  3    sync::Arc,
  4    time::{Duration, Instant},
  5};
  6
  7use arrayvec::ArrayVec;
  8use client::{Client, UserStore};
  9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
 10use gpui::{App, Entity, Task, prelude::*};
 11use language::ToPoint as _;
 12use project::Project;
 13use util::ResultExt as _;
 14
 15use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
 16
 17pub struct ZetaEditPredictionProvider {
 18    zeta: Entity<Zeta>,
 19    next_pending_prediction_id: usize,
 20    pending_predictions: ArrayVec<PendingPrediction, 2>,
 21    last_request_timestamp: Instant,
 22    project: Entity<Project>,
 23}
 24
 25impl ZetaEditPredictionProvider {
 26    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 27
 28    pub fn new(
 29        project: Entity<Project>,
 30        client: &Arc<Client>,
 31        user_store: &Entity<UserStore>,
 32        cx: &mut App,
 33    ) -> Self {
 34        let zeta = Zeta::global(client, user_store, cx);
 35        zeta.update(cx, |zeta, cx| {
 36            zeta.register_project(&project, cx);
 37        });
 38
 39        Self {
 40            zeta,
 41            next_pending_prediction_id: 0,
 42            pending_predictions: ArrayVec::new(),
 43            last_request_timestamp: Instant::now(),
 44            project: project,
 45        }
 46    }
 47}
 48
 49struct PendingPrediction {
 50    id: usize,
 51    _task: Task<()>,
 52}
 53
 54impl EditPredictionProvider for ZetaEditPredictionProvider {
 55    fn name() -> &'static str {
 56        "zed-predict2"
 57    }
 58
 59    fn display_name() -> &'static str {
 60        "Zed's Edit Predictions 2"
 61    }
 62
 63    fn show_completions_in_menu() -> bool {
 64        true
 65    }
 66
 67    fn show_tab_accept_marker() -> bool {
 68        true
 69    }
 70
 71    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 72        // TODO [zeta2]
 73        DataCollectionState::Unsupported
 74    }
 75
 76    fn toggle_data_collection(&mut self, _cx: &mut App) {
 77        // TODO [zeta2]
 78    }
 79
 80    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 81        self.zeta.read(cx).usage(cx)
 82    }
 83
 84    fn is_enabled(
 85        &self,
 86        _buffer: &Entity<language::Buffer>,
 87        _cursor_position: language::Anchor,
 88        cx: &App,
 89    ) -> bool {
 90        let zeta = self.zeta.read(cx);
 91        if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
 92            zeta.sweep_api_token.is_some()
 93        } else {
 94            true
 95        }
 96    }
 97
 98    fn is_refreshing(&self) -> bool {
 99        !self.pending_predictions.is_empty()
100    }
101
102    fn refresh(
103        &mut self,
104        buffer: Entity<language::Buffer>,
105        cursor_position: language::Anchor,
106        _debounce: bool,
107        cx: &mut Context<Self>,
108    ) {
109        let zeta = self.zeta.read(cx);
110
111        if zeta.user_store.read_with(cx, |user_store, _cx| {
112            user_store.account_too_young() || user_store.has_overdue_invoices()
113        }) {
114            return;
115        }
116
117        if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
118            && let BufferEditPrediction::Local { prediction } = current
119            && prediction.interpolate(buffer.read(cx)).is_some()
120        {
121            return;
122        }
123
124        self.zeta.update(cx, |zeta, cx| {
125            zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
126        });
127
128        let pending_prediction_id = self.next_pending_prediction_id;
129        self.next_pending_prediction_id += 1;
130        let last_request_timestamp = self.last_request_timestamp;
131
132        let project = self.project.clone();
133        let task = cx.spawn(async move |this, cx| {
134            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
135                .checked_duration_since(Instant::now())
136            {
137                cx.background_executor().timer(timeout).await;
138            }
139
140            let refresh_task = this.update(cx, |this, cx| {
141                this.last_request_timestamp = Instant::now();
142                this.zeta.update(cx, |zeta, cx| {
143                    zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
144                })
145            });
146
147            if let Some(refresh_task) = refresh_task.ok() {
148                refresh_task.await.log_err();
149            }
150
151            this.update(cx, |this, cx| {
152                if this.pending_predictions[0].id == pending_prediction_id {
153                    this.pending_predictions.remove(0);
154                } else {
155                    this.pending_predictions.clear();
156                }
157
158                cx.notify();
159            })
160            .ok();
161        });
162
163        // We always maintain at most two pending predictions. When we already
164        // have two, we replace the newest one.
165        if self.pending_predictions.len() <= 1 {
166            self.pending_predictions.push(PendingPrediction {
167                id: pending_prediction_id,
168                _task: task,
169            });
170        } else if self.pending_predictions.len() == 2 {
171            self.pending_predictions.pop();
172            self.pending_predictions.push(PendingPrediction {
173                id: pending_prediction_id,
174                _task: task,
175            });
176        }
177
178        cx.notify();
179    }
180
181    fn cycle(
182        &mut self,
183        _buffer: Entity<language::Buffer>,
184        _cursor_position: language::Anchor,
185        _direction: Direction,
186        _cx: &mut Context<Self>,
187    ) {
188    }
189
190    fn accept(&mut self, cx: &mut Context<Self>) {
191        self.zeta.update(cx, |zeta, cx| {
192            zeta.accept_current_prediction(&self.project, cx);
193        });
194        self.pending_predictions.clear();
195    }
196
197    fn discard(&mut self, cx: &mut Context<Self>) {
198        self.zeta.update(cx, |zeta, _cx| {
199            zeta.discard_current_prediction(&self.project);
200        });
201        self.pending_predictions.clear();
202    }
203
204    fn suggest(
205        &mut self,
206        buffer: &Entity<language::Buffer>,
207        cursor_position: language::Anchor,
208        cx: &mut Context<Self>,
209    ) -> Option<edit_prediction::EditPrediction> {
210        let prediction =
211            self.zeta
212                .read(cx)
213                .current_prediction_for_buffer(buffer, &self.project, cx)?;
214
215        let prediction = match prediction {
216            BufferEditPrediction::Local { prediction } => prediction,
217            BufferEditPrediction::Jump { prediction } => {
218                return Some(edit_prediction::EditPrediction::Jump {
219                    id: Some(prediction.id.to_string().into()),
220                    snapshot: prediction.snapshot.clone(),
221                    target: prediction.edits.first().unwrap().0.start,
222                });
223            }
224        };
225
226        let buffer = buffer.read(cx);
227        let snapshot = buffer.snapshot();
228
229        let Some(edits) = prediction.interpolate(&snapshot) else {
230            self.zeta.update(cx, |zeta, _cx| {
231                zeta.discard_current_prediction(&self.project);
232            });
233            return None;
234        };
235
236        let cursor_row = cursor_position.to_point(&snapshot).row;
237        let (closest_edit_ix, (closest_edit_range, _)) =
238            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
239                let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
240                let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
241                cmp::min(distance_from_start, distance_from_end)
242            })?;
243
244        let mut edit_start_ix = closest_edit_ix;
245        for (range, _) in edits[..edit_start_ix].iter().rev() {
246            let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
247                - range.end.to_point(&snapshot).row;
248            if distance_from_closest_edit <= 1 {
249                edit_start_ix -= 1;
250            } else {
251                break;
252            }
253        }
254
255        let mut edit_end_ix = closest_edit_ix + 1;
256        for (range, _) in &edits[edit_end_ix..] {
257            let distance_from_closest_edit =
258                range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
259            if distance_from_closest_edit <= 1 {
260                edit_end_ix += 1;
261            } else {
262                break;
263            }
264        }
265
266        Some(edit_prediction::EditPrediction::Local {
267            id: Some(prediction.id.to_string().into()),
268            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
269            edit_preview: Some(prediction.edit_preview.clone()),
270        })
271    }
272}