zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{
  6    DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
  7    EditPredictionIconSet, SuggestionDisplayType,
  8};
  9use gpui::{App, Entity, prelude::*};
 10use language::{Buffer, ToPoint as _};
 11use project::Project;
 12
 13use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
 14
 15pub struct ZedEditPredictionDelegate {
 16    store: Entity<EditPredictionStore>,
 17    project: Entity<Project>,
 18    singleton_buffer: Option<Entity<Buffer>>,
 19}
 20
 21impl ZedEditPredictionDelegate {
 22    pub fn new(
 23        project: Entity<Project>,
 24        singleton_buffer: Option<Entity<Buffer>>,
 25        client: &Arc<Client>,
 26        user_store: &Entity<UserStore>,
 27        cx: &mut Context<Self>,
 28    ) -> Self {
 29        let store = EditPredictionStore::global(client, user_store, cx);
 30        store.update(cx, |store, cx| {
 31            store.register_project(&project, cx);
 32        });
 33
 34        cx.observe(&store, |_this, _ep_store, cx| {
 35            cx.notify();
 36        })
 37        .detach();
 38
 39        Self {
 40            project: project,
 41            store: store,
 42            singleton_buffer,
 43        }
 44    }
 45}
 46
 47impl EditPredictionDelegate for ZedEditPredictionDelegate {
 48    fn name() -> &'static str {
 49        "zed-predict"
 50    }
 51
 52    fn display_name() -> &'static str {
 53        "Zed's Edit Predictions"
 54    }
 55
 56    fn show_predictions_in_menu() -> bool {
 57        true
 58    }
 59
 60    fn show_tab_accept_marker() -> bool {
 61        true
 62    }
 63
 64    fn icons(&self, cx: &App) -> EditPredictionIconSet {
 65        self.store.read(cx).icons(cx)
 66    }
 67
 68    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 69        if let Some(buffer) = &self.singleton_buffer
 70            && let Some(file) = buffer.read(cx).file()
 71        {
 72            let is_project_open_source =
 73                self.store
 74                    .read(cx)
 75                    .is_file_open_source(&self.project, file, cx);
 76            if self.store.read(cx).data_collection_choice.is_enabled(cx) {
 77                DataCollectionState::Enabled {
 78                    is_project_open_source,
 79                }
 80            } else {
 81                DataCollectionState::Disabled {
 82                    is_project_open_source,
 83                }
 84            }
 85        } else {
 86            return DataCollectionState::Disabled {
 87                is_project_open_source: false,
 88            };
 89        }
 90    }
 91
 92    fn toggle_data_collection(&mut self, cx: &mut App) {
 93        self.store.update(cx, |store, cx| {
 94            store.toggle_data_collection_choice(cx);
 95        });
 96    }
 97
 98    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 99        self.store.read(cx).usage(cx)
100    }
101
102    fn is_enabled(
103        &self,
104        _buffer: &Entity<language::Buffer>,
105        _cursor_position: language::Anchor,
106        cx: &App,
107    ) -> bool {
108        let store = self.store.read(cx);
109        if store.edit_prediction_model == EditPredictionModel::Sweep {
110            store.has_sweep_api_token(cx)
111        } else {
112            true
113        }
114    }
115
116    fn is_refreshing(&self, cx: &App) -> bool {
117        self.store.read(cx).is_refreshing(&self.project)
118    }
119
120    fn refresh(
121        &mut self,
122        buffer: Entity<language::Buffer>,
123        cursor_position: language::Anchor,
124        _debounce: bool,
125        cx: &mut Context<Self>,
126    ) {
127        let store = self.store.read(cx);
128
129        if store.user_store.read_with(cx, |user_store, _cx| {
130            user_store.account_too_young() || user_store.has_overdue_invoices()
131        }) {
132            return;
133        }
134
135        self.store.update(cx, |store, cx| {
136            if let Some(current) =
137                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
138                && let BufferEditPrediction::Local { prediction } = current
139                && prediction.interpolate(buffer.read(cx)).is_some()
140            {
141                return;
142            }
143
144            store.refresh_context(&self.project, &buffer, cursor_position, cx);
145            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
146        });
147    }
148
149    fn accept(&mut self, cx: &mut Context<Self>) {
150        self.store.update(cx, |store, cx| {
151            store.accept_current_prediction(&self.project, cx);
152        });
153    }
154
155    fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
156        let reject_reason = match reason {
157            EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
158            EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
159        };
160        self.store.update(cx, |store, cx| {
161            store.reject_current_prediction(reject_reason, &self.project, cx);
162        });
163    }
164
165    fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
166        self.store.update(cx, |store, cx| {
167            store.did_show_current_prediction(&self.project, display_type, cx);
168        });
169    }
170
171    fn suggest(
172        &mut self,
173        buffer: &Entity<language::Buffer>,
174        cursor_position: language::Anchor,
175        cx: &mut Context<Self>,
176    ) -> Option<edit_prediction_types::EditPrediction> {
177        self.store.update(cx, |store, cx| {
178            let prediction =
179                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
180
181            let prediction = match prediction {
182                BufferEditPrediction::Local { prediction } => prediction,
183                BufferEditPrediction::Jump { prediction } => {
184                    return Some(edit_prediction_types::EditPrediction::Jump {
185                        id: Some(prediction.id.to_string().into()),
186                        snapshot: prediction.snapshot.clone(),
187                        target: prediction.edits.first().unwrap().0.start,
188                    });
189                }
190            };
191
192            let buffer = buffer.read(cx);
193            let snapshot = buffer.snapshot();
194
195            let Some(edits) = prediction.interpolate(&snapshot) else {
196                store.reject_current_prediction(
197                    EditPredictionRejectReason::InterpolatedEmpty,
198                    &self.project,
199                    cx,
200                );
201                return None;
202            };
203
204            let cursor_row = cursor_position.to_point(&snapshot).row;
205            let (closest_edit_ix, (closest_edit_range, _)) =
206                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
207                    let distance_from_start =
208                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
209                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
210                    cmp::min(distance_from_start, distance_from_end)
211                })?;
212
213            let mut edit_start_ix = closest_edit_ix;
214            for (range, _) in edits[..edit_start_ix].iter().rev() {
215                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
216                    - range.end.to_point(&snapshot).row;
217                if distance_from_closest_edit <= 1 {
218                    edit_start_ix -= 1;
219                } else {
220                    break;
221                }
222            }
223
224            let mut edit_end_ix = closest_edit_ix + 1;
225            for (range, _) in &edits[edit_end_ix..] {
226                let distance_from_closest_edit = range.start.to_point(buffer).row
227                    - closest_edit_range.end.to_point(&snapshot).row;
228                if distance_from_closest_edit <= 1 {
229                    edit_end_ix += 1;
230                } else {
231                    break;
232                }
233            }
234
235            Some(edit_prediction_types::EditPrediction::Local {
236                id: Some(prediction.id.to_string().into()),
237                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
238                cursor_position: prediction.cursor_position,
239                edit_preview: Some(prediction.edit_preview.clone()),
240            })
241        })
242    }
243}