provider.rs

  1use std::{cmp, sync::Arc, time::Duration};
  2
  3use client::{Client, UserStore};
  4use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  5use gpui::{App, Entity, prelude::*};
  6use language::ToPoint as _;
  7use project::Project;
  8
  9use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
 10
 11pub struct ZetaEditPredictionProvider {
 12    zeta: Entity<Zeta>,
 13    project: Entity<Project>,
 14}
 15
 16impl ZetaEditPredictionProvider {
 17    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 18
 19    pub fn new(
 20        project: Entity<Project>,
 21        client: &Arc<Client>,
 22        user_store: &Entity<UserStore>,
 23        cx: &mut Context<Self>,
 24    ) -> Self {
 25        let zeta = Zeta::global(client, user_store, cx);
 26        zeta.update(cx, |zeta, cx| {
 27            zeta.register_project(&project, cx);
 28        });
 29
 30        cx.observe(&zeta, |_this, _zeta, cx| {
 31            cx.notify();
 32        })
 33        .detach();
 34
 35        Self {
 36            project: project,
 37            zeta,
 38        }
 39    }
 40}
 41
 42impl EditPredictionProvider for ZetaEditPredictionProvider {
 43    fn name() -> &'static str {
 44        "zed-predict2"
 45    }
 46
 47    fn display_name() -> &'static str {
 48        "Zed's Edit Predictions 2"
 49    }
 50
 51    fn show_completions_in_menu() -> bool {
 52        true
 53    }
 54
 55    fn show_tab_accept_marker() -> bool {
 56        true
 57    }
 58
 59    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 60        // TODO [zeta2]
 61        DataCollectionState::Unsupported
 62    }
 63
 64    fn toggle_data_collection(&mut self, _cx: &mut App) {
 65        // TODO [zeta2]
 66    }
 67
 68    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 69        self.zeta.read(cx).usage(cx)
 70    }
 71
 72    fn is_enabled(
 73        &self,
 74        _buffer: &Entity<language::Buffer>,
 75        _cursor_position: language::Anchor,
 76        cx: &App,
 77    ) -> bool {
 78        let zeta = self.zeta.read(cx);
 79        if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
 80            zeta.sweep_ai.api_token.is_some()
 81        } else {
 82            true
 83        }
 84    }
 85
 86    fn is_refreshing(&self, cx: &App) -> bool {
 87        self.zeta.read(cx).is_refreshing(&self.project)
 88    }
 89
 90    fn refresh(
 91        &mut self,
 92        buffer: Entity<language::Buffer>,
 93        cursor_position: language::Anchor,
 94        _debounce: bool,
 95        cx: &mut Context<Self>,
 96    ) {
 97        let zeta = self.zeta.read(cx);
 98
 99        if zeta.user_store.read_with(cx, |user_store, _cx| {
100            user_store.account_too_young() || user_store.has_overdue_invoices()
101        }) {
102            return;
103        }
104
105        if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
106            && let BufferEditPrediction::Local { prediction } = current
107            && prediction.interpolate(buffer.read(cx)).is_some()
108        {
109            return;
110        }
111
112        self.zeta.update(cx, |zeta, cx| {
113            zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
114            zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
115        });
116    }
117
118    fn cycle(
119        &mut self,
120        _buffer: Entity<language::Buffer>,
121        _cursor_position: language::Anchor,
122        _direction: Direction,
123        _cx: &mut Context<Self>,
124    ) {
125    }
126
127    fn accept(&mut self, cx: &mut Context<Self>) {
128        self.zeta.update(cx, |zeta, cx| {
129            zeta.accept_current_prediction(&self.project, cx);
130        });
131    }
132
133    fn discard(&mut self, cx: &mut Context<Self>) {
134        self.zeta.update(cx, |zeta, cx| {
135            zeta.discard_current_prediction(&self.project, cx);
136        });
137    }
138
139    fn did_show(&mut self, cx: &mut Context<Self>) {
140        self.zeta.update(cx, |zeta, cx| {
141            zeta.did_show_current_prediction(&self.project, cx);
142        });
143    }
144
145    fn suggest(
146        &mut self,
147        buffer: &Entity<language::Buffer>,
148        cursor_position: language::Anchor,
149        cx: &mut Context<Self>,
150    ) -> Option<edit_prediction::EditPrediction> {
151        let prediction =
152            self.zeta
153                .read(cx)
154                .current_prediction_for_buffer(buffer, &self.project, cx)?;
155
156        let prediction = match prediction {
157            BufferEditPrediction::Local { prediction } => prediction,
158            BufferEditPrediction::Jump { prediction } => {
159                return Some(edit_prediction::EditPrediction::Jump {
160                    id: Some(prediction.id.to_string().into()),
161                    snapshot: prediction.snapshot.clone(),
162                    target: prediction.edits.first().unwrap().0.start,
163                });
164            }
165        };
166
167        let buffer = buffer.read(cx);
168        let snapshot = buffer.snapshot();
169
170        let Some(edits) = prediction.interpolate(&snapshot) else {
171            self.zeta.update(cx, |zeta, cx| {
172                zeta.discard_current_prediction(&self.project, cx);
173            });
174            return None;
175        };
176
177        let cursor_row = cursor_position.to_point(&snapshot).row;
178        let (closest_edit_ix, (closest_edit_range, _)) =
179            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
180                let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
181                let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
182                cmp::min(distance_from_start, distance_from_end)
183            })?;
184
185        let mut edit_start_ix = closest_edit_ix;
186        for (range, _) in edits[..edit_start_ix].iter().rev() {
187            let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
188                - range.end.to_point(&snapshot).row;
189            if distance_from_closest_edit <= 1 {
190                edit_start_ix -= 1;
191            } else {
192                break;
193            }
194        }
195
196        let mut edit_end_ix = closest_edit_ix + 1;
197        for (range, _) in &edits[edit_end_ix..] {
198            let distance_from_closest_edit =
199                range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
200            if distance_from_closest_edit <= 1 {
201                edit_end_ix += 1;
202            } else {
203                break;
204            }
205        }
206
207        Some(edit_prediction::EditPrediction::Local {
208            id: Some(prediction.id.to_string().into()),
209            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
210            edit_preview: Some(prediction.edit_preview.clone()),
211        })
212    }
213}