zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{
  6    DataCollectionState, EditPredictionDelegate, EditPredictionIconSet, SuggestionDisplayType,
  7};
  8use gpui::{App, Entity, prelude::*};
  9use language::{Buffer, ToPoint as _};
 10use project::Project;
 11use ui::prelude::*;
 12
 13use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
 14
 15pub struct ZedEditPredictionDelegate {
 16    store: Entity<EditPredictionStore>,
 17    project: Entity<Project>,
 18    singleton_buffer: Option<Entity<Buffer>>,
 19}
 20
 21impl ZedEditPredictionDelegate {
 22    pub fn new(
 23        project: Entity<Project>,
 24        singleton_buffer: Option<Entity<Buffer>>,
 25        client: &Arc<Client>,
 26        user_store: &Entity<UserStore>,
 27        cx: &mut Context<Self>,
 28    ) -> Self {
 29        let store = EditPredictionStore::global(client, user_store, cx);
 30        store.update(cx, |store, cx| {
 31            store.register_project(&project, cx);
 32        });
 33
 34        cx.observe(&store, |_this, _ep_store, cx| {
 35            cx.notify();
 36        })
 37        .detach();
 38
 39        Self {
 40            project: project,
 41            store: store,
 42            singleton_buffer,
 43        }
 44    }
 45}
 46
 47impl EditPredictionDelegate for ZedEditPredictionDelegate {
 48    fn name() -> &'static str {
 49        "zed-predict"
 50    }
 51
 52    fn display_name() -> &'static str {
 53        "Zed's Edit Predictions"
 54    }
 55
 56    fn show_predictions_in_menu() -> bool {
 57        true
 58    }
 59
 60    fn show_tab_accept_marker() -> bool {
 61        true
 62    }
 63
 64    fn icons(&self, cx: &App) -> EditPredictionIconSet {
 65        match self.store.read(cx).edit_prediction_model {
 66            EditPredictionModel::Ollama => EditPredictionIconSet::new(IconName::AiOllama),
 67            EditPredictionModel::Sweep => EditPredictionIconSet::new(IconName::SweepAi)
 68                .with_disabled(IconName::SweepAiDisabled)
 69                .with_up(IconName::SweepAiUp)
 70                .with_down(IconName::SweepAiDown)
 71                .with_error(IconName::SweepAiError),
 72            EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
 73            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
 74                EditPredictionIconSet::new(IconName::ZedPredict)
 75                    .with_disabled(IconName::ZedPredictDisabled)
 76                    .with_up(IconName::ZedPredictUp)
 77                    .with_down(IconName::ZedPredictDown)
 78                    .with_error(IconName::ZedPredictError)
 79            }
 80        }
 81    }
 82
 83    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 84        if let Some(buffer) = &self.singleton_buffer
 85            && let Some(file) = buffer.read(cx).file()
 86        {
 87            let is_project_open_source =
 88                self.store
 89                    .read(cx)
 90                    .is_file_open_source(&self.project, file, cx);
 91            if self.store.read(cx).data_collection_choice.is_enabled(cx) {
 92                DataCollectionState::Enabled {
 93                    is_project_open_source,
 94                }
 95            } else {
 96                DataCollectionState::Disabled {
 97                    is_project_open_source,
 98                }
 99            }
100        } else {
101            return DataCollectionState::Disabled {
102                is_project_open_source: false,
103            };
104        }
105    }
106
107    fn toggle_data_collection(&mut self, cx: &mut App) {
108        self.store.update(cx, |store, cx| {
109            store.toggle_data_collection_choice(cx);
110        });
111    }
112
113    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
114        self.store.read(cx).usage(cx)
115    }
116
117    fn is_enabled(
118        &self,
119        _buffer: &Entity<language::Buffer>,
120        _cursor_position: language::Anchor,
121        cx: &App,
122    ) -> bool {
123        let store = self.store.read(cx);
124        if store.edit_prediction_model == EditPredictionModel::Sweep {
125            store.has_sweep_api_token(cx)
126        } else {
127            true
128        }
129    }
130
131    fn is_refreshing(&self, cx: &App) -> bool {
132        self.store.read(cx).is_refreshing(&self.project)
133    }
134
135    fn refresh(
136        &mut self,
137        buffer: Entity<language::Buffer>,
138        cursor_position: language::Anchor,
139        _debounce: bool,
140        cx: &mut Context<Self>,
141    ) {
142        let store = self.store.read(cx);
143
144        if store.user_store.read_with(cx, |user_store, _cx| {
145            user_store.account_too_young() || user_store.has_overdue_invoices()
146        }) {
147            return;
148        }
149
150        self.store.update(cx, |store, cx| {
151            if let Some(current) =
152                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
153                && let BufferEditPrediction::Local { prediction } = current
154                && prediction.interpolate(buffer.read(cx)).is_some()
155            {
156                return;
157            }
158
159            store.refresh_context(&self.project, &buffer, cursor_position, cx);
160            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
161        });
162    }
163
164    fn accept(&mut self, cx: &mut Context<Self>) {
165        self.store.update(cx, |store, cx| {
166            store.accept_current_prediction(&self.project, cx);
167        });
168    }
169
170    fn discard(&mut self, cx: &mut Context<Self>) {
171        self.store.update(cx, |store, _cx| {
172            store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
173        });
174    }
175
176    fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
177        self.store.update(cx, |store, cx| {
178            store.did_show_current_prediction(&self.project, display_type, cx);
179        });
180    }
181
182    fn suggest(
183        &mut self,
184        buffer: &Entity<language::Buffer>,
185        cursor_position: language::Anchor,
186        cx: &mut Context<Self>,
187    ) -> Option<edit_prediction_types::EditPrediction> {
188        self.store.update(cx, |store, cx| {
189            let prediction =
190                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
191
192            let prediction = match prediction {
193                BufferEditPrediction::Local { prediction } => prediction,
194                BufferEditPrediction::Jump { prediction } => {
195                    return Some(edit_prediction_types::EditPrediction::Jump {
196                        id: Some(prediction.id.to_string().into()),
197                        snapshot: prediction.snapshot.clone(),
198                        target: prediction.edits.first().unwrap().0.start,
199                    });
200                }
201            };
202
203            let buffer = buffer.read(cx);
204            let snapshot = buffer.snapshot();
205
206            let Some(edits) = prediction.interpolate(&snapshot) else {
207                store.reject_current_prediction(
208                    EditPredictionRejectReason::InterpolatedEmpty,
209                    &self.project,
210                );
211                return None;
212            };
213
214            let cursor_row = cursor_position.to_point(&snapshot).row;
215            let (closest_edit_ix, (closest_edit_range, _)) =
216                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
217                    let distance_from_start =
218                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
219                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
220                    cmp::min(distance_from_start, distance_from_end)
221                })?;
222
223            let mut edit_start_ix = closest_edit_ix;
224            for (range, _) in edits[..edit_start_ix].iter().rev() {
225                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
226                    - range.end.to_point(&snapshot).row;
227                if distance_from_closest_edit <= 1 {
228                    edit_start_ix -= 1;
229                } else {
230                    break;
231                }
232            }
233
234            let mut edit_end_ix = closest_edit_ix + 1;
235            for (range, _) in &edits[edit_end_ix..] {
236                let distance_from_closest_edit = range.start.to_point(buffer).row
237                    - closest_edit_range.end.to_point(&snapshot).row;
238                if distance_from_closest_edit <= 1 {
239                    edit_end_ix += 1;
240                } else {
241                    break;
242                }
243            }
244
245            Some(edit_prediction_types::EditPrediction::Local {
246                id: Some(prediction.id.to_string().into()),
247                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
248                edit_preview: Some(prediction.edit_preview.clone()),
249            })
250        })
251    }
252}