zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{
  6    DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
  7    EditPredictionIconSet, SuggestionDisplayType,
  8};
  9use feature_flags::FeatureFlagAppExt;
 10use fs::Fs;
 11use gpui::{App, Entity, prelude::*};
 12use language::{Buffer, ToPoint as _};
 13use project::Project;
 14use settings::{EditPredictionDataCollectionChoice, update_settings_file};
 15
 16use crate::{BufferEditPrediction, EditPredictionStore};
 17
 18pub struct ZedEditPredictionDelegate {
 19    store: Entity<EditPredictionStore>,
 20    project: Entity<Project>,
 21    singleton_buffer: Option<Entity<Buffer>>,
 22}
 23
 24impl ZedEditPredictionDelegate {
 25    pub fn new(
 26        project: Entity<Project>,
 27        singleton_buffer: Option<Entity<Buffer>>,
 28        client: &Arc<Client>,
 29        user_store: &Entity<UserStore>,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let store = EditPredictionStore::global(client, user_store, cx);
 33        store.update(cx, |store, cx| {
 34            store.register_project(&project, cx);
 35        });
 36
 37        cx.observe(&store, |_this, _ep_store, cx| {
 38            cx.notify();
 39        })
 40        .detach();
 41
 42        Self {
 43            project: project,
 44            store: store,
 45            singleton_buffer,
 46        }
 47    }
 48}
 49
 50impl EditPredictionDelegate for ZedEditPredictionDelegate {
 51    fn name() -> &'static str {
 52        "zed-predict"
 53    }
 54
 55    fn display_name() -> &'static str {
 56        "Zed's Edit Predictions"
 57    }
 58
 59    fn show_predictions_in_menu() -> bool {
 60        true
 61    }
 62
 63    fn show_tab_accept_marker() -> bool {
 64        true
 65    }
 66
 67    fn icons(&self, cx: &App) -> EditPredictionIconSet {
 68        self.store.read(cx).icons(cx)
 69    }
 70
 71    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 72        if let Some(buffer) = &self.singleton_buffer
 73            && let Some(file) = buffer.read(cx).file()
 74        {
 75            let is_project_open_source =
 76                self.store
 77                    .read(cx)
 78                    .is_file_open_source(&self.project, file, cx);
 79
 80            if self.store.read(cx).is_data_collection_enabled(cx) {
 81                DataCollectionState::Enabled {
 82                    is_project_open_source,
 83                }
 84            } else {
 85                DataCollectionState::Disabled {
 86                    is_project_open_source,
 87                }
 88            }
 89        } else {
 90            DataCollectionState::Disabled {
 91                is_project_open_source: false,
 92            }
 93        }
 94    }
 95
 96    fn can_toggle_data_collection(&self, cx: &App) -> bool {
 97        if cx.is_staff() {
 98            return false;
 99        }
100
101        self.store
102            .read(cx)
103            .is_data_collection_allowed_by_organization(cx)
104    }
105
106    fn toggle_data_collection(&mut self, cx: &mut App) {
107        let fs = <dyn Fs>::global(cx);
108        let is_currently_enabled = self.store.read(cx).is_data_collection_enabled(cx);
109        update_settings_file(fs, cx, move |settings, _| {
110            let edit_predictions = settings
111                .project
112                .all_languages
113                .edit_predictions
114                .get_or_insert_default();
115
116            edit_predictions.allow_data_collection = Some(if is_currently_enabled {
117                EditPredictionDataCollectionChoice::No
118            } else {
119                EditPredictionDataCollectionChoice::Yes
120            });
121        });
122    }
123
124    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
125        self.store.read(cx).usage(cx)
126    }
127
128    fn is_enabled(
129        &self,
130        _buffer: &Entity<language::Buffer>,
131        _cursor_position: language::Anchor,
132        _cx: &App,
133    ) -> bool {
134        true
135    }
136
137    fn is_refreshing(&self, cx: &App) -> bool {
138        self.store.read(cx).is_refreshing(&self.project)
139    }
140
141    fn refresh(
142        &mut self,
143        buffer: Entity<language::Buffer>,
144        cursor_position: language::Anchor,
145        _debounce: bool,
146        cx: &mut Context<Self>,
147    ) {
148        let store = self.store.read(cx);
149
150        if store.user_store.read_with(cx, |user_store, _cx| {
151            user_store.account_too_young() || user_store.has_overdue_invoices()
152        }) {
153            return;
154        }
155
156        self.store.update(cx, |store, cx| {
157            if let Some(current) =
158                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
159                && let BufferEditPrediction::Local { prediction } = current
160                && prediction.interpolate(buffer.read(cx)).is_some()
161            {
162                return;
163            }
164
165            store.refresh_context(&self.project, &buffer, cursor_position, cx);
166            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
167        });
168    }
169
170    fn accept(&mut self, cx: &mut Context<Self>) {
171        self.store.update(cx, |store, cx| {
172            store.accept_current_prediction(&self.project, cx);
173        });
174    }
175
176    fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
177        let reject_reason = match reason {
178            EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
179            EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
180        };
181        self.store.update(cx, |store, cx| {
182            store.reject_current_prediction(reject_reason, &self.project, cx);
183        });
184    }
185
186    fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
187        self.store.update(cx, |store, cx| {
188            store.did_show_current_prediction(&self.project, display_type, cx);
189        });
190    }
191
192    fn suggest(
193        &mut self,
194        buffer: &Entity<language::Buffer>,
195        cursor_position: language::Anchor,
196        cx: &mut Context<Self>,
197    ) -> Option<edit_prediction_types::EditPrediction> {
198        self.store.update(cx, |store, cx| {
199            let prediction =
200                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
201
202            let prediction = match prediction {
203                BufferEditPrediction::Local { prediction } => prediction,
204                BufferEditPrediction::Jump { prediction } => {
205                    return Some(edit_prediction_types::EditPrediction::Jump {
206                        id: Some(prediction.id.0.clone()),
207                        snapshot: prediction.snapshot.clone(),
208                        target: prediction.edits.first().unwrap().0.start,
209                    });
210                }
211            };
212
213            let buffer = buffer.read(cx);
214            let snapshot = buffer.snapshot();
215
216            let Some(edits) = prediction.interpolate(&snapshot) else {
217                store.reject_current_prediction(
218                    EditPredictionRejectReason::InterpolatedEmpty,
219                    &self.project,
220                    cx,
221                );
222                return None;
223            };
224
225            let cursor_row = cursor_position.to_point(&snapshot).row;
226            let (closest_edit_ix, (closest_edit_range, _)) =
227                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
228                    let distance_from_start =
229                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
230                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
231                    cmp::min(distance_from_start, distance_from_end)
232                })?;
233
234            let mut edit_start_ix = closest_edit_ix;
235            for (range, _) in edits[..edit_start_ix].iter().rev() {
236                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
237                    - range.end.to_point(&snapshot).row;
238                if distance_from_closest_edit <= 1 {
239                    edit_start_ix -= 1;
240                } else {
241                    break;
242                }
243            }
244
245            let mut edit_end_ix = closest_edit_ix + 1;
246            for (range, _) in &edits[edit_end_ix..] {
247                let distance_from_closest_edit = range.start.to_point(buffer).row
248                    - closest_edit_range.end.to_point(&snapshot).row;
249                if distance_from_closest_edit <= 1 {
250                    edit_end_ix += 1;
251                } else {
252                    break;
253                }
254            }
255
256            Some(edit_prediction_types::EditPrediction::Local {
257                id: Some(prediction.id.0.clone()),
258                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
259                cursor_position: prediction.cursor_position,
260                edit_preview: Some(prediction.edit_preview.clone()),
261            })
262        })
263    }
264}