1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
7 EditPredictionIconSet, SuggestionDisplayType,
8};
9use gpui::{App, Entity, prelude::*};
10use language::{Buffer, ToPoint as _};
11use project::Project;
12
13use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
14
15pub struct ZedEditPredictionDelegate {
16 store: Entity<EditPredictionStore>,
17 project: Entity<Project>,
18 singleton_buffer: Option<Entity<Buffer>>,
19}
20
21impl ZedEditPredictionDelegate {
22 pub fn new(
23 project: Entity<Project>,
24 singleton_buffer: Option<Entity<Buffer>>,
25 client: &Arc<Client>,
26 user_store: &Entity<UserStore>,
27 cx: &mut Context<Self>,
28 ) -> Self {
29 let store = EditPredictionStore::global(client, user_store, cx);
30 store.update(cx, |store, cx| {
31 store.register_project(&project, cx);
32 });
33
34 cx.observe(&store, |_this, _ep_store, cx| {
35 cx.notify();
36 })
37 .detach();
38
39 Self {
40 project: project,
41 store: store,
42 singleton_buffer,
43 }
44 }
45}
46
47impl EditPredictionDelegate for ZedEditPredictionDelegate {
48 fn name() -> &'static str {
49 "zed-predict"
50 }
51
52 fn display_name() -> &'static str {
53 "Zed's Edit Predictions"
54 }
55
56 fn show_predictions_in_menu() -> bool {
57 true
58 }
59
60 fn show_tab_accept_marker() -> bool {
61 true
62 }
63
64 fn icons(&self, cx: &App) -> EditPredictionIconSet {
65 self.store.read(cx).icons(cx)
66 }
67
68 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
69 if let Some(buffer) = &self.singleton_buffer
70 && let Some(file) = buffer.read(cx).file()
71 {
72 let is_project_open_source =
73 self.store
74 .read(cx)
75 .is_file_open_source(&self.project, file, cx);
76 if self.store.read(cx).data_collection_choice.is_enabled(cx) {
77 DataCollectionState::Enabled {
78 is_project_open_source,
79 }
80 } else {
81 DataCollectionState::Disabled {
82 is_project_open_source,
83 }
84 }
85 } else {
86 return DataCollectionState::Disabled {
87 is_project_open_source: false,
88 };
89 }
90 }
91
92 fn toggle_data_collection(&mut self, cx: &mut App) {
93 self.store.update(cx, |store, cx| {
94 store.toggle_data_collection_choice(cx);
95 });
96 }
97
98 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
99 self.store.read(cx).usage(cx)
100 }
101
102 fn is_enabled(
103 &self,
104 _buffer: &Entity<language::Buffer>,
105 _cursor_position: language::Anchor,
106 cx: &App,
107 ) -> bool {
108 let store = self.store.read(cx);
109 if store.edit_prediction_model == EditPredictionModel::Sweep {
110 store.has_sweep_api_token(cx)
111 } else {
112 true
113 }
114 }
115
116 fn is_refreshing(&self, cx: &App) -> bool {
117 self.store.read(cx).is_refreshing(&self.project)
118 }
119
120 fn refresh(
121 &mut self,
122 buffer: Entity<language::Buffer>,
123 cursor_position: language::Anchor,
124 _debounce: bool,
125 cx: &mut Context<Self>,
126 ) {
127 let store = self.store.read(cx);
128
129 if store.user_store.read_with(cx, |user_store, _cx| {
130 user_store.account_too_young() || user_store.has_overdue_invoices()
131 }) {
132 return;
133 }
134
135 self.store.update(cx, |store, cx| {
136 if let Some(current) =
137 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
138 && let BufferEditPrediction::Local { prediction } = current
139 && prediction.interpolate(buffer.read(cx)).is_some()
140 {
141 return;
142 }
143
144 store.refresh_context(&self.project, &buffer, cursor_position, cx);
145 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
146 });
147 }
148
149 fn accept(&mut self, cx: &mut Context<Self>) {
150 self.store.update(cx, |store, cx| {
151 store.accept_current_prediction(&self.project, cx);
152 });
153 }
154
155 fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
156 let reject_reason = match reason {
157 EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
158 EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
159 };
160 self.store.update(cx, |store, cx| {
161 store.reject_current_prediction(reject_reason, &self.project, cx);
162 });
163 }
164
165 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
166 self.store.update(cx, |store, cx| {
167 store.did_show_current_prediction(&self.project, display_type, cx);
168 });
169 }
170
171 fn suggest(
172 &mut self,
173 buffer: &Entity<language::Buffer>,
174 cursor_position: language::Anchor,
175 cx: &mut Context<Self>,
176 ) -> Option<edit_prediction_types::EditPrediction> {
177 self.store.update(cx, |store, cx| {
178 let prediction =
179 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
180
181 let prediction = match prediction {
182 BufferEditPrediction::Local { prediction } => prediction,
183 BufferEditPrediction::Jump { prediction } => {
184 return Some(edit_prediction_types::EditPrediction::Jump {
185 id: Some(prediction.id.to_string().into()),
186 snapshot: prediction.snapshot.clone(),
187 target: prediction.edits.first().unwrap().0.start,
188 });
189 }
190 };
191
192 let buffer = buffer.read(cx);
193 let snapshot = buffer.snapshot();
194
195 let Some(edits) = prediction.interpolate(&snapshot) else {
196 store.reject_current_prediction(
197 EditPredictionRejectReason::InterpolatedEmpty,
198 &self.project,
199 cx,
200 );
201 return None;
202 };
203
204 let cursor_row = cursor_position.to_point(&snapshot).row;
205 let (closest_edit_ix, (closest_edit_range, _)) =
206 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
207 let distance_from_start =
208 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
209 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
210 cmp::min(distance_from_start, distance_from_end)
211 })?;
212
213 let mut edit_start_ix = closest_edit_ix;
214 for (range, _) in edits[..edit_start_ix].iter().rev() {
215 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
216 - range.end.to_point(&snapshot).row;
217 if distance_from_closest_edit <= 1 {
218 edit_start_ix -= 1;
219 } else {
220 break;
221 }
222 }
223
224 let mut edit_end_ix = closest_edit_ix + 1;
225 for (range, _) in &edits[edit_end_ix..] {
226 let distance_from_closest_edit = range.start.to_point(buffer).row
227 - closest_edit_range.end.to_point(&snapshot).row;
228 if distance_from_closest_edit <= 1 {
229 edit_end_ix += 1;
230 } else {
231 break;
232 }
233 }
234
235 Some(edit_prediction_types::EditPrediction::Local {
236 id: Some(prediction.id.to_string().into()),
237 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
238 cursor_position: prediction.cursor_position,
239 edit_preview: Some(prediction.edit_preview.clone()),
240 })
241 })
242 }
243}