zed_edit_prediction_delegate.rs

  1use std::{cmp, sync::Arc};
  2
  3use client::{Client, UserStore};
  4use cloud_llm_client::EditPredictionRejectReason;
  5use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
  6use gpui::{App, Entity, prelude::*};
  7use language::{Buffer, ToPoint as _};
  8use project::Project;
  9
 10use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
 11
 12pub struct ZedEditPredictionDelegate {
 13    store: Entity<EditPredictionStore>,
 14    project: Entity<Project>,
 15    singleton_buffer: Option<Entity<Buffer>>,
 16}
 17
 18impl ZedEditPredictionDelegate {
 19    pub fn new(
 20        project: Entity<Project>,
 21        singleton_buffer: Option<Entity<Buffer>>,
 22        client: &Arc<Client>,
 23        user_store: &Entity<UserStore>,
 24        cx: &mut Context<Self>,
 25    ) -> Self {
 26        let store = EditPredictionStore::global(client, user_store, cx);
 27        store.update(cx, |store, cx| {
 28            store.register_project(&project, cx);
 29        });
 30
 31        cx.observe(&store, |_this, _ep_store, cx| {
 32            cx.notify();
 33        })
 34        .detach();
 35
 36        Self {
 37            project: project,
 38            store: store,
 39            singleton_buffer,
 40        }
 41    }
 42}
 43
 44impl EditPredictionDelegate for ZedEditPredictionDelegate {
 45    fn name() -> &'static str {
 46        "zed-predict"
 47    }
 48
 49    fn display_name() -> &'static str {
 50        "Zed's Edit Predictions"
 51    }
 52
 53    fn show_predictions_in_menu() -> bool {
 54        true
 55    }
 56
 57    fn show_tab_accept_marker() -> bool {
 58        true
 59    }
 60
 61    fn data_collection_state(&self, cx: &App) -> DataCollectionState {
 62        if let Some(buffer) = &self.singleton_buffer
 63            && let Some(file) = buffer.read(cx).file()
 64        {
 65            let is_project_open_source =
 66                self.store
 67                    .read(cx)
 68                    .is_file_open_source(&self.project, file, cx);
 69            if self.store.read(cx).data_collection_choice.is_enabled() {
 70                DataCollectionState::Enabled {
 71                    is_project_open_source,
 72                }
 73            } else {
 74                DataCollectionState::Disabled {
 75                    is_project_open_source,
 76                }
 77            }
 78        } else {
 79            return DataCollectionState::Disabled {
 80                is_project_open_source: false,
 81            };
 82        }
 83    }
 84
 85    fn toggle_data_collection(&mut self, cx: &mut App) {
 86        self.store.update(cx, |store, cx| {
 87            store.toggle_data_collection_choice(cx);
 88        });
 89    }
 90
 91    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 92        self.store.read(cx).usage(cx)
 93    }
 94
 95    fn is_enabled(
 96        &self,
 97        _buffer: &Entity<language::Buffer>,
 98        _cursor_position: language::Anchor,
 99        cx: &App,
100    ) -> bool {
101        let store = self.store.read(cx);
102        if store.edit_prediction_model == EditPredictionModel::Sweep {
103            store.has_sweep_api_token(cx)
104        } else {
105            true
106        }
107    }
108
109    fn is_refreshing(&self, cx: &App) -> bool {
110        self.store.read(cx).is_refreshing(&self.project)
111    }
112
113    fn refresh(
114        &mut self,
115        buffer: Entity<language::Buffer>,
116        cursor_position: language::Anchor,
117        _debounce: bool,
118        cx: &mut Context<Self>,
119    ) {
120        let store = self.store.read(cx);
121
122        if store.user_store.read_with(cx, |user_store, _cx| {
123            user_store.account_too_young() || user_store.has_overdue_invoices()
124        }) {
125            return;
126        }
127
128        self.store.update(cx, |store, cx| {
129            if let Some(current) =
130                store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
131                && let BufferEditPrediction::Local { prediction } = current
132                && prediction.interpolate(buffer.read(cx)).is_some()
133            {
134                return;
135            }
136
137            store.refresh_context(&self.project, &buffer, cursor_position, cx);
138            store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
139        });
140    }
141
142    fn cycle(
143        &mut self,
144        _buffer: Entity<language::Buffer>,
145        _cursor_position: language::Anchor,
146        _direction: Direction,
147        _cx: &mut Context<Self>,
148    ) {
149    }
150
151    fn accept(&mut self, cx: &mut Context<Self>) {
152        self.store.update(cx, |store, cx| {
153            store.accept_current_prediction(&self.project, cx);
154        });
155    }
156
157    fn discard(&mut self, cx: &mut Context<Self>) {
158        self.store.update(cx, |store, _cx| {
159            store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
160        });
161    }
162
163    fn did_show(&mut self, cx: &mut Context<Self>) {
164        self.store.update(cx, |store, cx| {
165            store.did_show_current_prediction(&self.project, cx);
166        });
167    }
168
169    fn suggest(
170        &mut self,
171        buffer: &Entity<language::Buffer>,
172        cursor_position: language::Anchor,
173        cx: &mut Context<Self>,
174    ) -> Option<edit_prediction_types::EditPrediction> {
175        self.store.update(cx, |store, cx| {
176            let prediction =
177                store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
178
179            let prediction = match prediction {
180                BufferEditPrediction::Local { prediction } => prediction,
181                BufferEditPrediction::Jump { prediction } => {
182                    return Some(edit_prediction_types::EditPrediction::Jump {
183                        id: Some(prediction.id.to_string().into()),
184                        snapshot: prediction.snapshot.clone(),
185                        target: prediction.edits.first().unwrap().0.start,
186                    });
187                }
188            };
189
190            let buffer = buffer.read(cx);
191            let snapshot = buffer.snapshot();
192
193            let Some(edits) = prediction.interpolate(&snapshot) else {
194                store.reject_current_prediction(
195                    EditPredictionRejectReason::InterpolatedEmpty,
196                    &self.project,
197                );
198                return None;
199            };
200
201            let cursor_row = cursor_position.to_point(&snapshot).row;
202            let (closest_edit_ix, (closest_edit_range, _)) =
203                edits.iter().enumerate().min_by_key(|(_, (range, _))| {
204                    let distance_from_start =
205                        cursor_row.abs_diff(range.start.to_point(&snapshot).row);
206                    let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
207                    cmp::min(distance_from_start, distance_from_end)
208                })?;
209
210            let mut edit_start_ix = closest_edit_ix;
211            for (range, _) in edits[..edit_start_ix].iter().rev() {
212                let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
213                    - range.end.to_point(&snapshot).row;
214                if distance_from_closest_edit <= 1 {
215                    edit_start_ix -= 1;
216                } else {
217                    break;
218                }
219            }
220
221            let mut edit_end_ix = closest_edit_ix + 1;
222            for (range, _) in &edits[edit_end_ix..] {
223                let distance_from_closest_edit = range.start.to_point(buffer).row
224                    - closest_edit_range.end.to_point(&snapshot).row;
225                if distance_from_closest_edit <= 1 {
226                    edit_end_ix += 1;
227                } else {
228                    break;
229                }
230            }
231
232            Some(edit_prediction_types::EditPrediction::Local {
233                id: Some(prediction.id.to_string().into()),
234                edits: edits[edit_start_ix..edit_end_ix].to_vec(),
235                edit_preview: Some(prediction.edit_preview.clone()),
236            })
237        })
238    }
239}