1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
7 EditPredictionIconSet, SuggestionDisplayType,
8};
9use feature_flags::FeatureFlagAppExt;
10use gpui::{App, Entity, prelude::*};
11use language::{Buffer, ToPoint as _};
12use project::Project;
13
14use crate::{BufferEditPrediction, EditPredictionStore};
15
16pub struct ZedEditPredictionDelegate {
17 store: Entity<EditPredictionStore>,
18 project: Entity<Project>,
19 singleton_buffer: Option<Entity<Buffer>>,
20}
21
22impl ZedEditPredictionDelegate {
23 pub fn new(
24 project: Entity<Project>,
25 singleton_buffer: Option<Entity<Buffer>>,
26 client: &Arc<Client>,
27 user_store: &Entity<UserStore>,
28 cx: &mut Context<Self>,
29 ) -> Self {
30 let store = EditPredictionStore::global(client, user_store, cx);
31 store.update(cx, |store, cx| {
32 store.register_project(&project, cx);
33 });
34
35 cx.observe(&store, |_this, _ep_store, cx| {
36 cx.notify();
37 })
38 .detach();
39
40 Self {
41 project: project,
42 store: store,
43 singleton_buffer,
44 }
45 }
46}
47
48impl EditPredictionDelegate for ZedEditPredictionDelegate {
49 fn name() -> &'static str {
50 "zed-predict"
51 }
52
53 fn display_name() -> &'static str {
54 "Zed's Edit Predictions"
55 }
56
57 fn show_predictions_in_menu() -> bool {
58 true
59 }
60
61 fn show_tab_accept_marker() -> bool {
62 true
63 }
64
65 fn icons(&self, cx: &App) -> EditPredictionIconSet {
66 self.store.read(cx).icons(cx)
67 }
68
69 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
70 if let Some(buffer) = &self.singleton_buffer
71 && let Some(file) = buffer.read(cx).file()
72 {
73 let is_project_open_source =
74 self.store
75 .read(cx)
76 .is_file_open_source(&self.project, file, cx);
77
78 if let Some(organization_configuration) = self
79 .store
80 .read(cx)
81 .user_store
82 .read(cx)
83 .current_organization_configuration()
84 {
85 if !organization_configuration
86 .edit_prediction
87 .is_feedback_enabled
88 {
89 return DataCollectionState::Disabled {
90 is_project_open_source,
91 };
92 }
93 }
94
95 if self.store.read(cx).data_collection_choice.is_enabled(cx) {
96 DataCollectionState::Enabled {
97 is_project_open_source,
98 }
99 } else {
100 DataCollectionState::Disabled {
101 is_project_open_source,
102 }
103 }
104 } else {
105 return DataCollectionState::Disabled {
106 is_project_open_source: false,
107 };
108 }
109 }
110
111 fn can_toggle_data_collection(&self, cx: &App) -> bool {
112 if cx.is_staff() {
113 return false;
114 }
115
116 if let Some(organization_configuration) = self
117 .store
118 .read(cx)
119 .user_store
120 .read(cx)
121 .current_organization_configuration()
122 {
123 if !organization_configuration
124 .edit_prediction
125 .is_feedback_enabled
126 {
127 return false;
128 }
129 }
130
131 true
132 }
133
134 fn toggle_data_collection(&mut self, cx: &mut App) {
135 self.store.update(cx, |store, cx| {
136 store.toggle_data_collection_choice(cx);
137 });
138 }
139
140 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
141 self.store.read(cx).usage(cx)
142 }
143
144 fn is_enabled(
145 &self,
146 _buffer: &Entity<language::Buffer>,
147 _cursor_position: language::Anchor,
148 _cx: &App,
149 ) -> bool {
150 true
151 }
152
153 fn is_refreshing(&self, cx: &App) -> bool {
154 self.store.read(cx).is_refreshing(&self.project)
155 }
156
157 fn refresh(
158 &mut self,
159 buffer: Entity<language::Buffer>,
160 cursor_position: language::Anchor,
161 _debounce: bool,
162 cx: &mut Context<Self>,
163 ) {
164 let store = self.store.read(cx);
165
166 if store.user_store.read_with(cx, |user_store, _cx| {
167 user_store.account_too_young() || user_store.has_overdue_invoices()
168 }) {
169 return;
170 }
171
172 self.store.update(cx, |store, cx| {
173 if let Some(current) =
174 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
175 && let BufferEditPrediction::Local { prediction } = current
176 && prediction.interpolate(buffer.read(cx)).is_some()
177 {
178 return;
179 }
180
181 store.refresh_context(&self.project, &buffer, cursor_position, cx);
182 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
183 });
184 }
185
186 fn accept(&mut self, cx: &mut Context<Self>) {
187 self.store.update(cx, |store, cx| {
188 store.accept_current_prediction(&self.project, cx);
189 });
190 }
191
192 fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
193 let reject_reason = match reason {
194 EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
195 EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
196 };
197 self.store.update(cx, |store, cx| {
198 store.reject_current_prediction(reject_reason, &self.project, cx);
199 });
200 }
201
202 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
203 self.store.update(cx, |store, cx| {
204 store.did_show_current_prediction(&self.project, display_type, cx);
205 });
206 }
207
208 fn suggest(
209 &mut self,
210 buffer: &Entity<language::Buffer>,
211 cursor_position: language::Anchor,
212 cx: &mut Context<Self>,
213 ) -> Option<edit_prediction_types::EditPrediction> {
214 self.store.update(cx, |store, cx| {
215 let prediction =
216 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
217
218 let prediction = match prediction {
219 BufferEditPrediction::Local { prediction } => prediction,
220 BufferEditPrediction::Jump { prediction } => {
221 return Some(edit_prediction_types::EditPrediction::Jump {
222 id: Some(prediction.id.0.clone()),
223 snapshot: prediction.snapshot.clone(),
224 target: prediction.edits.first().unwrap().0.start,
225 });
226 }
227 };
228
229 let buffer = buffer.read(cx);
230 let snapshot = buffer.snapshot();
231
232 let Some(edits) = prediction.interpolate(&snapshot) else {
233 store.reject_current_prediction(
234 EditPredictionRejectReason::InterpolatedEmpty,
235 &self.project,
236 cx,
237 );
238 return None;
239 };
240
241 let cursor_row = cursor_position.to_point(&snapshot).row;
242 let (closest_edit_ix, (closest_edit_range, _)) =
243 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
244 let distance_from_start =
245 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
246 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
247 cmp::min(distance_from_start, distance_from_end)
248 })?;
249
250 let mut edit_start_ix = closest_edit_ix;
251 for (range, _) in edits[..edit_start_ix].iter().rev() {
252 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
253 - range.end.to_point(&snapshot).row;
254 if distance_from_closest_edit <= 1 {
255 edit_start_ix -= 1;
256 } else {
257 break;
258 }
259 }
260
261 let mut edit_end_ix = closest_edit_ix + 1;
262 for (range, _) in &edits[edit_end_ix..] {
263 let distance_from_closest_edit = range.start.to_point(buffer).row
264 - closest_edit_range.end.to_point(&snapshot).row;
265 if distance_from_closest_edit <= 1 {
266 edit_end_ix += 1;
267 } else {
268 break;
269 }
270 }
271
272 Some(edit_prediction_types::EditPrediction::Local {
273 id: Some(prediction.id.0.clone()),
274 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
275 cursor_position: prediction.cursor_position,
276 edit_preview: Some(prediction.edit_preview.clone()),
277 })
278 })
279 }
280}