zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{
  6    DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
  7    EditPredictionIconSet, SuggestionDisplayType,
  8};
  9use feature_flags::FeatureFlagAppExt;
 10use gpui::{App, Entity, prelude::*};
 11use language::{Buffer, ToPoint as _};
 12use project::Project;
 13
 14use crate::{BufferEditPrediction, EditPredictionStore};
 15
 16pub struct ZedEditPredictionDelegate {
 17    store: Entity<EditPredictionStore>,
 18    project: Entity<Project>,
 19    singleton_buffer: Option<Entity<Buffer>>,
 20}
 21
 22impl ZedEditPredictionDelegate {
 23    pub fn new(
 24        project: Entity<Project>,
 25        singleton_buffer: Option<Entity<Buffer>>,
 26        client: &Arc<Client>,
 27        user_store: &Entity<UserStore>,
 28        cx: &mut Context<Self>,
 29    ) -> Self {
 30        let store = EditPredictionStore::global(client, user_store, cx);
 31        store.update(cx, |store, cx| {
 32            store.register_project(&project, cx);
 33        });
 34
 35        cx.observe(&store, |_this, _ep_store, cx| {
 36            cx.notify();
 37        })
 38        .detach();
 39
 40        Self {
 41            project: project,
 42            store: store,
 43            singleton_buffer,
 44        }
 45    }
 46}
 47
 48impl EditPredictionDelegate for ZedEditPredictionDelegate {
 49    fn name() -> &'static str {
 50        "zed-predict"
 51    }
 52
 53    fn display_name() -> &'static str {
 54        "Zed's Edit Predictions"
 55    }
 56
 57    fn show_predictions_in_menu() -> bool {
 58        true
 59    }
 60
 61    fn show_tab_accept_marker() -> bool {
 62        true
 63    }
 64
 65    fn icons(&self, cx: &App) -> EditPredictionIconSet {
 66        self.store.read(cx).icons(cx)
 67    }
 68
 69    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 70        if let Some(buffer) = &self.singleton_buffer
 71            && let Some(file) = buffer.read(cx).file()
 72        {
 73            let is_project_open_source =
 74                self.store
 75                    .read(cx)
 76                    .is_file_open_source(&self.project, file, cx);
 77
 78            if let Some(organization_configuration) = self
 79                .store
 80                .read(cx)
 81                .user_store
 82                .read(cx)
 83                .current_organization_configuration()
 84            {
 85                if !organization_configuration
 86                    .edit_prediction
 87                    .is_feedback_enabled
 88                {
 89                    return DataCollectionState::Disabled {
 90                        is_project_open_source,
 91                    };
 92                }
 93            }
 94
 95            if self.store.read(cx).data_collection_choice.is_enabled(cx) {
 96                DataCollectionState::Enabled {
 97                    is_project_open_source,
 98                }
 99            } else {
100                DataCollectionState::Disabled {
101                    is_project_open_source,
102                }
103            }
104        } else {
105            return DataCollectionState::Disabled {
106                is_project_open_source: false,
107            };
108        }
109    }
110
111    fn can_toggle_data_collection(&self, cx: &App) -> bool {
112        if cx.is_staff() {
113            return false;
114        }
115
116        if let Some(organization_configuration) = self
117            .store
118            .read(cx)
119            .user_store
120            .read(cx)
121            .current_organization_configuration()
122        {
123            if !organization_configuration
124                .edit_prediction
125                .is_feedback_enabled
126            {
127                return false;
128            }
129        }
130
131        true
132    }
133
134    fn toggle_data_collection(&mut self, cx: &mut App) {
135        self.store.update(cx, |store, cx| {
136            store.toggle_data_collection_choice(cx);
137        });
138    }
139
140    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
141        self.store.read(cx).usage(cx)
142    }
143
144    fn is_enabled(
145        &self,
146        _buffer: &Entity<language::Buffer>,
147        _cursor_position: language::Anchor,
148        _cx: &App,
149    ) -> bool {
150        true
151    }
152
153    fn is_refreshing(&self, cx: &App) -> bool {
154        self.store.read(cx).is_refreshing(&self.project)
155    }
156
157    fn refresh(
158        &mut self,
159        buffer: Entity<language::Buffer>,
160        cursor_position: language::Anchor,
161        _debounce: bool,
162        cx: &mut Context<Self>,
163    ) {
164        let store = self.store.read(cx);
165
166        if store.user_store.read_with(cx, |user_store, _cx| {
167            user_store.account_too_young() || user_store.has_overdue_invoices()
168        }) {
169            return;
170        }
171
172        self.store.update(cx, |store, cx| {
173            if let Some(current) =
174                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
175                && let BufferEditPrediction::Local { prediction } = current
176                && prediction.interpolate(buffer.read(cx)).is_some()
177            {
178                return;
179            }
180
181            store.refresh_context(&self.project, &buffer, cursor_position, cx);
182            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
183        });
184    }
185
186    fn accept(&mut self, cx: &mut Context<Self>) {
187        self.store.update(cx, |store, cx| {
188            store.accept_current_prediction(&self.project, cx);
189        });
190    }
191
192    fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
193        let reject_reason = match reason {
194            EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
195            EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
196        };
197        self.store.update(cx, |store, cx| {
198            store.reject_current_prediction(reject_reason, &self.project, cx);
199        });
200    }
201
202    fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
203        self.store.update(cx, |store, cx| {
204            store.did_show_current_prediction(&self.project, display_type, cx);
205        });
206    }
207
208    fn suggest(
209        &mut self,
210        buffer: &Entity<language::Buffer>,
211        cursor_position: language::Anchor,
212        cx: &mut Context<Self>,
213    ) -> Option<edit_prediction_types::EditPrediction> {
214        self.store.update(cx, |store, cx| {
215            let prediction =
216                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
217
218            let prediction = match prediction {
219                BufferEditPrediction::Local { prediction } => prediction,
220                BufferEditPrediction::Jump { prediction } => {
221                    return Some(edit_prediction_types::EditPrediction::Jump {
222                        id: Some(prediction.id.0.clone()),
223                        snapshot: prediction.snapshot.clone(),
224                        target: prediction.edits.first().unwrap().0.start,
225                    });
226                }
227            };
228
229            let buffer = buffer.read(cx);
230            let snapshot = buffer.snapshot();
231
232            let Some(edits) = prediction.interpolate(&snapshot) else {
233                store.reject_current_prediction(
234                    EditPredictionRejectReason::InterpolatedEmpty,
235                    &self.project,
236                    cx,
237                );
238                return None;
239            };
240
241            let cursor_row = cursor_position.to_point(&snapshot).row;
242            let (closest_edit_ix, (closest_edit_range, _)) =
243                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
244                    let distance_from_start =
245                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
246                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
247                    cmp::min(distance_from_start, distance_from_end)
248                })?;
249
250            let mut edit_start_ix = closest_edit_ix;
251            for (range, _) in edits[..edit_start_ix].iter().rev() {
252                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
253                    - range.end.to_point(&snapshot).row;
254                if distance_from_closest_edit <= 1 {
255                    edit_start_ix -= 1;
256                } else {
257                    break;
258                }
259            }
260
261            let mut edit_end_ix = closest_edit_ix + 1;
262            for (range, _) in &edits[edit_end_ix..] {
263                let distance_from_closest_edit = range.start.to_point(buffer).row
264                    - closest_edit_range.end.to_point(&snapshot).row;
265                if distance_from_closest_edit <= 1 {
266                    edit_end_ix += 1;
267                } else {
268                    break;
269                }
270            }
271
272            Some(edit_prediction_types::EditPrediction::Local {
273                id: Some(prediction.id.0.clone()),
274                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
275                cursor_position: prediction.cursor_position,
276                edit_preview: Some(prediction.edit_preview.clone()),
277            })
278        })
279    }
280}