1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
7 EditPredictionIconSet, SuggestionDisplayType,
8};
9use feature_flags::FeatureFlagAppExt;
10use fs::Fs;
11use gpui::{App, Entity, prelude::*};
12use language::{Buffer, ToPoint as _};
13use project::Project;
14use settings::{EditPredictionDataCollectionChoice, update_settings_file};
15
16use crate::{BufferEditPrediction, EditPredictionStore};
17
18pub struct ZedEditPredictionDelegate {
19 store: Entity<EditPredictionStore>,
20 project: Entity<Project>,
21 singleton_buffer: Option<Entity<Buffer>>,
22}
23
24impl ZedEditPredictionDelegate {
25 pub fn new(
26 project: Entity<Project>,
27 singleton_buffer: Option<Entity<Buffer>>,
28 client: &Arc<Client>,
29 user_store: &Entity<UserStore>,
30 cx: &mut Context<Self>,
31 ) -> Self {
32 let store = EditPredictionStore::global(client, user_store, cx);
33 store.update(cx, |store, cx| {
34 store.register_project(&project, cx);
35 });
36
37 cx.observe(&store, |_this, _ep_store, cx| {
38 cx.notify();
39 })
40 .detach();
41
42 Self {
43 project: project,
44 store: store,
45 singleton_buffer,
46 }
47 }
48}
49
50impl EditPredictionDelegate for ZedEditPredictionDelegate {
51 fn name() -> &'static str {
52 "zed-predict"
53 }
54
55 fn display_name() -> &'static str {
56 "Zed's Edit Predictions"
57 }
58
59 fn show_predictions_in_menu() -> bool {
60 true
61 }
62
63 fn show_tab_accept_marker() -> bool {
64 true
65 }
66
67 fn icons(&self, cx: &App) -> EditPredictionIconSet {
68 self.store.read(cx).icons(cx)
69 }
70
71 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
72 if let Some(buffer) = &self.singleton_buffer
73 && let Some(file) = buffer.read(cx).file()
74 {
75 let is_project_open_source =
76 self.store
77 .read(cx)
78 .is_file_open_source(&self.project, file, cx);
79
80 if self.store.read(cx).is_data_collection_enabled(cx) {
81 DataCollectionState::Enabled {
82 is_project_open_source,
83 }
84 } else {
85 DataCollectionState::Disabled {
86 is_project_open_source,
87 }
88 }
89 } else {
90 DataCollectionState::Disabled {
91 is_project_open_source: false,
92 }
93 }
94 }
95
96 fn can_toggle_data_collection(&self, cx: &App) -> bool {
97 if cx.is_staff() {
98 return false;
99 }
100
101 self.store
102 .read(cx)
103 .is_data_collection_allowed_by_organization(cx)
104 }
105
106 fn toggle_data_collection(&mut self, cx: &mut App) {
107 let fs = <dyn Fs>::global(cx);
108 let is_currently_enabled = self.store.read(cx).is_data_collection_enabled(cx);
109 update_settings_file(fs, cx, move |settings, _| {
110 let edit_predictions = settings
111 .project
112 .all_languages
113 .edit_predictions
114 .get_or_insert_default();
115
116 edit_predictions.allow_data_collection = Some(if is_currently_enabled {
117 EditPredictionDataCollectionChoice::No
118 } else {
119 EditPredictionDataCollectionChoice::Yes
120 });
121 });
122 }
123
124 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
125 self.store.read(cx).usage(cx)
126 }
127
128 fn is_enabled(
129 &self,
130 _buffer: &Entity<language::Buffer>,
131 _cursor_position: language::Anchor,
132 _cx: &App,
133 ) -> bool {
134 true
135 }
136
137 fn is_refreshing(&self, cx: &App) -> bool {
138 self.store.read(cx).is_refreshing(&self.project)
139 }
140
141 fn refresh(
142 &mut self,
143 buffer: Entity<language::Buffer>,
144 cursor_position: language::Anchor,
145 _debounce: bool,
146 cx: &mut Context<Self>,
147 ) {
148 let store = self.store.read(cx);
149
150 if store.user_store.read_with(cx, |user_store, _cx| {
151 user_store.account_too_young() || user_store.has_overdue_invoices()
152 }) {
153 return;
154 }
155
156 self.store.update(cx, |store, cx| {
157 if let Some(current) =
158 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
159 && let BufferEditPrediction::Local { prediction } = current
160 && prediction.interpolate(buffer.read(cx)).is_some()
161 {
162 return;
163 }
164
165 store.refresh_context(&self.project, &buffer, cursor_position, cx);
166 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
167 });
168 }
169
170 fn accept(&mut self, cx: &mut Context<Self>) {
171 self.store.update(cx, |store, cx| {
172 store.accept_current_prediction(&self.project, cx);
173 });
174 }
175
176 fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
177 let reject_reason = match reason {
178 EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
179 EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
180 };
181 self.store.update(cx, |store, cx| {
182 store.reject_current_prediction(reject_reason, &self.project, cx);
183 });
184 }
185
186 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
187 self.store.update(cx, |store, cx| {
188 store.did_show_current_prediction(&self.project, display_type, cx);
189 });
190 }
191
192 fn suggest(
193 &mut self,
194 buffer: &Entity<language::Buffer>,
195 cursor_position: language::Anchor,
196 cx: &mut Context<Self>,
197 ) -> Option<edit_prediction_types::EditPrediction> {
198 self.store.update(cx, |store, cx| {
199 let prediction =
200 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
201
202 let prediction = match prediction {
203 BufferEditPrediction::Local { prediction } => prediction,
204 BufferEditPrediction::Jump { prediction } => {
205 return Some(edit_prediction_types::EditPrediction::Jump {
206 id: Some(prediction.id.0.clone()),
207 snapshot: prediction.snapshot.clone(),
208 target: prediction.edits.first().unwrap().0.start,
209 });
210 }
211 };
212
213 let buffer = buffer.read(cx);
214 let snapshot = buffer.snapshot();
215
216 let Some(edits) = prediction.interpolate(&snapshot) else {
217 store.reject_current_prediction(
218 EditPredictionRejectReason::InterpolatedEmpty,
219 &self.project,
220 cx,
221 );
222 return None;
223 };
224
225 let cursor_row = cursor_position.to_point(&snapshot).row;
226 let (closest_edit_ix, (closest_edit_range, _)) =
227 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
228 let distance_from_start =
229 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
230 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
231 cmp::min(distance_from_start, distance_from_end)
232 })?;
233
234 let mut edit_start_ix = closest_edit_ix;
235 for (range, _) in edits[..edit_start_ix].iter().rev() {
236 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
237 - range.end.to_point(&snapshot).row;
238 if distance_from_closest_edit <= 1 {
239 edit_start_ix -= 1;
240 } else {
241 break;
242 }
243 }
244
245 let mut edit_end_ix = closest_edit_ix + 1;
246 for (range, _) in &edits[edit_end_ix..] {
247 let distance_from_closest_edit = range.start.to_point(buffer).row
248 - closest_edit_range.end.to_point(&snapshot).row;
249 if distance_from_closest_edit <= 1 {
250 edit_end_ix += 1;
251 } else {
252 break;
253 }
254 }
255
256 Some(edit_prediction_types::EditPrediction::Local {
257 id: Some(prediction.id.0.clone()),
258 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
259 cursor_position: prediction.cursor_position,
260 edit_preview: Some(prediction.edit_preview.clone()),
261 })
262 })
263 }
264}