provider.rs

  1use std::{cmp, sync::Arc, time::Duration};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  6use gpui::{App, Entity, prelude::*};
  7use language::ToPoint as _;
  8use project::Project;
  9
 10use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
 11
 12pub struct ZetaEditPredictionProvider {
 13    zeta: Entity<Zeta>,
 14    project: Entity<Project>,
 15}
 16
 17impl ZetaEditPredictionProvider {
 18    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 19
 20    pub fn new(
 21        project: Entity<Project>,
 22        client: &Arc<Client>,
 23        user_store: &Entity<UserStore>,
 24        cx: &mut Context<Self>,
 25    ) -> Self {
 26        let zeta = Zeta::global(client, user_store, cx);
 27        zeta.update(cx, |zeta, cx| {
 28            zeta.register_project(&project, cx);
 29        });
 30
 31        cx.observe(&zeta, |_this, _zeta, cx| {
 32            cx.notify();
 33        })
 34        .detach();
 35
 36        Self {
 37            project: project,
 38            zeta,
 39        }
 40    }
 41}
 42
 43impl EditPredictionProvider for ZetaEditPredictionProvider {
 44    fn name() -> &'static str {
 45        "zed-predict2"
 46    }
 47
 48    fn display_name() -> &'static str {
 49        "Zed's Edit Predictions 2"
 50    }
 51
 52    fn show_completions_in_menu() -> bool {
 53        true
 54    }
 55
 56    fn show_tab_accept_marker() -> bool {
 57        true
 58    }
 59
 60    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 61        // TODO [zeta2]
 62        DataCollectionState::Unsupported
 63    }
 64
 65    fn toggle_data_collection(&mut self, _cx: &mut App) {
 66        // TODO [zeta2]
 67    }
 68
 69    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 70        self.zeta.read(cx).usage(cx)
 71    }
 72
 73    fn is_enabled(
 74        &self,
 75        _buffer: &Entity<language::Buffer>,
 76        _cursor_position: language::Anchor,
 77        cx: &App,
 78    ) -> bool {
 79        let zeta = self.zeta.read(cx);
 80        if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
 81            zeta.sweep_ai.api_token.is_some()
 82        } else {
 83            true
 84        }
 85    }
 86
 87    fn is_refreshing(&self, cx: &App) -> bool {
 88        self.zeta.read(cx).is_refreshing(&self.project)
 89    }
 90
 91    fn refresh(
 92        &mut self,
 93        buffer: Entity<language::Buffer>,
 94        cursor_position: language::Anchor,
 95        _debounce: bool,
 96        cx: &mut Context<Self>,
 97    ) {
 98        let zeta = self.zeta.read(cx);
 99
100        if zeta.user_store.read_with(cx, |user_store, _cx| {
101            user_store.account_too_young() || user_store.has_overdue_invoices()
102        }) {
103            return;
104        }
105
106        if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
107            && let BufferEditPrediction::Local { prediction } = current
108            && prediction.interpolate(buffer.read(cx)).is_some()
109        {
110            return;
111        }
112
113        self.zeta.update(cx, |zeta, cx| {
114            zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
115            zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
116        });
117    }
118
119    fn cycle(
120        &mut self,
121        _buffer: Entity<language::Buffer>,
122        _cursor_position: language::Anchor,
123        _direction: Direction,
124        _cx: &mut Context<Self>,
125    ) {
126    }
127
128    fn accept(&mut self, cx: &mut Context<Self>) {
129        self.zeta.update(cx, |zeta, cx| {
130            zeta.accept_current_prediction(&self.project, cx);
131        });
132    }
133
134    fn discard(&mut self, cx: &mut Context<Self>) {
135        self.zeta.update(cx, |zeta, cx| {
136            zeta.reject_current_prediction(
137                EditPredictionRejectReason::Discarded,
138                &self.project,
139                cx,
140            );
141        });
142    }
143
144    fn did_show(&mut self, cx: &mut Context<Self>) {
145        self.zeta.update(cx, |zeta, cx| {
146            zeta.did_show_current_prediction(&self.project, cx);
147        });
148    }
149
150    fn suggest(
151        &mut self,
152        buffer: &Entity<language::Buffer>,
153        cursor_position: language::Anchor,
154        cx: &mut Context<Self>,
155    ) -> Option<edit_prediction::EditPrediction> {
156        let prediction =
157            self.zeta
158                .read(cx)
159                .current_prediction_for_buffer(buffer, &self.project, cx)?;
160
161        let prediction = match prediction {
162            BufferEditPrediction::Local { prediction } => prediction,
163            BufferEditPrediction::Jump { prediction } => {
164                return Some(edit_prediction::EditPrediction::Jump {
165                    id: Some(prediction.id.to_string().into()),
166                    snapshot: prediction.snapshot.clone(),
167                    target: prediction.edits.first().unwrap().0.start,
168                });
169            }
170        };
171
172        let buffer = buffer.read(cx);
173        let snapshot = buffer.snapshot();
174
175        let Some(edits) = prediction.interpolate(&snapshot) else {
176            self.zeta.update(cx, |zeta, cx| {
177                zeta.reject_current_prediction(
178                    EditPredictionRejectReason::InterpolatedEmpty,
179                    &self.project,
180                    cx,
181                );
182            });
183            return None;
184        };
185
186        let cursor_row = cursor_position.to_point(&snapshot).row;
187        let (closest_edit_ix, (closest_edit_range, _)) =
188            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
189                let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
190                let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
191                cmp::min(distance_from_start, distance_from_end)
192            })?;
193
194        let mut edit_start_ix = closest_edit_ix;
195        for (range, _) in edits[..edit_start_ix].iter().rev() {
196            let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
197                - range.end.to_point(&snapshot).row;
198            if distance_from_closest_edit <= 1 {
199                edit_start_ix -= 1;
200            } else {
201                break;
202            }
203        }
204
205        let mut edit_end_ix = closest_edit_ix + 1;
206        for (range, _) in &edits[edit_end_ix..] {
207            let distance_from_closest_edit =
208                range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
209            if distance_from_closest_edit <= 1 {
210                edit_end_ix += 1;
211            } else {
212                break;
213            }
214        }
215
216        Some(edit_prediction::EditPrediction::Local {
217            id: Some(prediction.id.to_string().into()),
218            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
219            edit_preview: Some(prediction.edit_preview.clone()),
220        })
221    }
222}