zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{
  6    DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
  7    EditPredictionIconSet, SuggestionDisplayType,
  8};
  9use gpui::{App, Entity, prelude::*};
 10use language::{Buffer, ToPoint as _};
 11use project::Project;
 12use ui::prelude::*;
 13
 14use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
 15
 16pub struct ZedEditPredictionDelegate {
 17    store: Entity<EditPredictionStore>,
 18    project: Entity<Project>,
 19    singleton_buffer: Option<Entity<Buffer>>,
 20}
 21
 22impl ZedEditPredictionDelegate {
 23    pub fn new(
 24        project: Entity<Project>,
 25        singleton_buffer: Option<Entity<Buffer>>,
 26        client: &Arc<Client>,
 27        user_store: &Entity<UserStore>,
 28        cx: &mut Context<Self>,
 29    ) -> Self {
 30        let store = EditPredictionStore::global(client, user_store, cx);
 31        store.update(cx, |store, cx| {
 32            store.register_project(&project, cx);
 33        });
 34
 35        cx.observe(&store, |_this, _ep_store, cx| {
 36            cx.notify();
 37        })
 38        .detach();
 39
 40        Self {
 41            project: project,
 42            store: store,
 43            singleton_buffer,
 44        }
 45    }
 46}
 47
 48impl EditPredictionDelegate for ZedEditPredictionDelegate {
 49    fn name() -> &'static str {
 50        "zed-predict"
 51    }
 52
 53    fn display_name() -> &'static str {
 54        "Zed's Edit Predictions"
 55    }
 56
 57    fn show_predictions_in_menu() -> bool {
 58        true
 59    }
 60
 61    fn show_tab_accept_marker() -> bool {
 62        true
 63    }
 64
 65    fn icons(&self, cx: &App) -> EditPredictionIconSet {
 66        match self.store.read(cx).edit_prediction_model {
 67            EditPredictionModel::Sweep => EditPredictionIconSet::new(IconName::SweepAi)
 68                .with_disabled(IconName::SweepAiDisabled)
 69                .with_up(IconName::SweepAiUp)
 70                .with_down(IconName::SweepAiDown)
 71                .with_error(IconName::SweepAiError),
 72            EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
 73            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 74                EditPredictionIconSet::new(IconName::ZedPredict)
 75                    .with_disabled(IconName::ZedPredictDisabled)
 76                    .with_up(IconName::ZedPredictUp)
 77                    .with_down(IconName::ZedPredictDown)
 78                    .with_error(IconName::ZedPredictError)
 79            }
 80            EditPredictionModel::Ollama => EditPredictionIconSet::new(IconName::AiOllama),
 81        }
 82    }
 83
 84    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 85        if let Some(buffer) = &self.singleton_buffer
 86            && let Some(file) = buffer.read(cx).file()
 87        {
 88            let is_project_open_source =
 89                self.store
 90                    .read(cx)
 91                    .is_file_open_source(&self.project, file, cx);
 92            if self.store.read(cx).data_collection_choice.is_enabled(cx) {
 93                DataCollectionState::Enabled {
 94                    is_project_open_source,
 95                }
 96            } else {
 97                DataCollectionState::Disabled {
 98                    is_project_open_source,
 99                }
100            }
101        } else {
102            return DataCollectionState::Disabled {
103                is_project_open_source: false,
104            };
105        }
106    }
107
108    fn toggle_data_collection(&mut self, cx: &mut App) {
109        self.store.update(cx, |store, cx| {
110            store.toggle_data_collection_choice(cx);
111        });
112    }
113
114    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
115        self.store.read(cx).usage(cx)
116    }
117
118    fn is_enabled(
119        &self,
120        _buffer: &Entity<language::Buffer>,
121        _cursor_position: language::Anchor,
122        cx: &App,
123    ) -> bool {
124        let store = self.store.read(cx);
125        if store.edit_prediction_model == EditPredictionModel::Sweep {
126            store.has_sweep_api_token(cx)
127        } else {
128            true
129        }
130    }
131
132    fn is_refreshing(&self, cx: &App) -> bool {
133        self.store.read(cx).is_refreshing(&self.project)
134    }
135
136    fn refresh(
137        &mut self,
138        buffer: Entity<language::Buffer>,
139        cursor_position: language::Anchor,
140        _debounce: bool,
141        cx: &mut Context<Self>,
142    ) {
143        let store = self.store.read(cx);
144
145        if store.user_store.read_with(cx, |user_store, _cx| {
146            user_store.account_too_young() || user_store.has_overdue_invoices()
147        }) {
148            return;
149        }
150
151        self.store.update(cx, |store, cx| {
152            if let Some(current) =
153                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
154                && let BufferEditPrediction::Local { prediction } = current
155                && prediction.interpolate(buffer.read(cx)).is_some()
156            {
157                return;
158            }
159
160            store.refresh_context(&self.project, &buffer, cursor_position, cx);
161            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
162        });
163    }
164
165    fn accept(&mut self, cx: &mut Context<Self>) {
166        self.store.update(cx, |store, cx| {
167            store.accept_current_prediction(&self.project, cx);
168        });
169    }
170
171    fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
172        let reject_reason = match reason {
173            EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
174            EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
175        };
176        self.store.update(cx, |store, cx| {
177            store.reject_current_prediction(reject_reason, &self.project, cx);
178        });
179    }
180
181    fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
182        self.store.update(cx, |store, cx| {
183            store.did_show_current_prediction(&self.project, display_type, cx);
184        });
185    }
186
187    fn suggest(
188        &mut self,
189        buffer: &Entity<language::Buffer>,
190        cursor_position: language::Anchor,
191        cx: &mut Context<Self>,
192    ) -> Option<edit_prediction_types::EditPrediction> {
193        self.store.update(cx, |store, cx| {
194            let prediction =
195                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
196
197            let prediction = match prediction {
198                BufferEditPrediction::Local { prediction } => prediction,
199                BufferEditPrediction::Jump { prediction } => {
200                    return Some(edit_prediction_types::EditPrediction::Jump {
201                        id: Some(prediction.id.to_string().into()),
202                        snapshot: prediction.snapshot.clone(),
203                        target: prediction.edits.first().unwrap().0.start,
204                    });
205                }
206            };
207
208            let buffer = buffer.read(cx);
209            let snapshot = buffer.snapshot();
210
211            let Some(edits) = prediction.interpolate(&snapshot) else {
212                store.reject_current_prediction(
213                    EditPredictionRejectReason::InterpolatedEmpty,
214                    &self.project,
215                    cx,
216                );
217                return None;
218            };
219
220            let cursor_row = cursor_position.to_point(&snapshot).row;
221            let (closest_edit_ix, (closest_edit_range, _)) =
222                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
223                    let distance_from_start =
224                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
225                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
226                    cmp::min(distance_from_start, distance_from_end)
227                })?;
228
229            let mut edit_start_ix = closest_edit_ix;
230            for (range, _) in edits[..edit_start_ix].iter().rev() {
231                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
232                    - range.end.to_point(&snapshot).row;
233                if distance_from_closest_edit <= 1 {
234                    edit_start_ix -= 1;
235                } else {
236                    break;
237                }
238            }
239
240            let mut edit_end_ix = closest_edit_ix + 1;
241            for (range, _) in &edits[edit_end_ix..] {
242                let distance_from_closest_edit = range.start.to_point(buffer).row
243                    - closest_edit_range.end.to_point(&snapshot).row;
244                if distance_from_closest_edit <= 1 {
245                    edit_end_ix += 1;
246                } else {
247                    break;
248                }
249            }
250
251            Some(edit_prediction_types::EditPrediction::Local {
252                id: Some(prediction.id.to_string().into()),
253                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
254                cursor_position: prediction.cursor_position,
255                edit_preview: Some(prediction.edit_preview.clone()),
256            })
257        })
258    }
259}