1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionIconSet, SuggestionDisplayType,
7};
8use gpui::{App, Entity, prelude::*};
9use language::{Buffer, ToPoint as _};
10use project::Project;
11use ui::prelude::*;
12
13use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
14
/// Edit prediction delegate that forwards every request to an
/// [`EditPredictionStore`] on behalf of a single project.
pub struct ZedEditPredictionDelegate {
    /// Store that all trait methods of this delegate read from and update.
    store: Entity<EditPredictionStore>,
    /// Project handed to the store on every call made by this delegate.
    project: Entity<Project>,
    /// Optional buffer whose file is consulted when reporting the data
    /// collection state; `None` means no such buffer was supplied.
    singleton_buffer: Option<Entity<Buffer>>,
}
20
21impl ZedEditPredictionDelegate {
22 pub fn new(
23 project: Entity<Project>,
24 singleton_buffer: Option<Entity<Buffer>>,
25 client: &Arc<Client>,
26 user_store: &Entity<UserStore>,
27 cx: &mut Context<Self>,
28 ) -> Self {
29 let store = EditPredictionStore::global(client, user_store, cx);
30 store.update(cx, |store, cx| {
31 store.register_project(&project, cx);
32 });
33
34 cx.observe(&store, |_this, _ep_store, cx| {
35 cx.notify();
36 })
37 .detach();
38
39 Self {
40 project: project,
41 store: store,
42 singleton_buffer,
43 }
44 }
45}
46
47impl EditPredictionDelegate for ZedEditPredictionDelegate {
48 fn name() -> &'static str {
49 "zed-predict"
50 }
51
52 fn display_name() -> &'static str {
53 "Zed's Edit Predictions"
54 }
55
56 fn show_predictions_in_menu() -> bool {
57 true
58 }
59
60 fn show_tab_accept_marker() -> bool {
61 true
62 }
63
64 fn icons(&self, cx: &App) -> EditPredictionIconSet {
65 match self.store.read(cx).edit_prediction_model {
66 EditPredictionModel::Ollama => EditPredictionIconSet::new(IconName::AiOllama),
67 EditPredictionModel::Sweep => EditPredictionIconSet::new(IconName::SweepAi)
68 .with_disabled(IconName::SweepAiDisabled)
69 .with_up(IconName::SweepAiUp)
70 .with_down(IconName::SweepAiDown)
71 .with_error(IconName::SweepAiError),
72 EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
73 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
74 EditPredictionIconSet::new(IconName::ZedPredict)
75 .with_disabled(IconName::ZedPredictDisabled)
76 .with_up(IconName::ZedPredictUp)
77 .with_down(IconName::ZedPredictDown)
78 .with_error(IconName::ZedPredictError)
79 }
80 }
81 }
82
83 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
84 if let Some(buffer) = &self.singleton_buffer
85 && let Some(file) = buffer.read(cx).file()
86 {
87 let is_project_open_source =
88 self.store
89 .read(cx)
90 .is_file_open_source(&self.project, file, cx);
91 if self.store.read(cx).data_collection_choice.is_enabled(cx) {
92 DataCollectionState::Enabled {
93 is_project_open_source,
94 }
95 } else {
96 DataCollectionState::Disabled {
97 is_project_open_source,
98 }
99 }
100 } else {
101 return DataCollectionState::Disabled {
102 is_project_open_source: false,
103 };
104 }
105 }
106
107 fn toggle_data_collection(&mut self, cx: &mut App) {
108 self.store.update(cx, |store, cx| {
109 store.toggle_data_collection_choice(cx);
110 });
111 }
112
113 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
114 self.store.read(cx).usage(cx)
115 }
116
117 fn is_enabled(
118 &self,
119 _buffer: &Entity<language::Buffer>,
120 _cursor_position: language::Anchor,
121 cx: &App,
122 ) -> bool {
123 let store = self.store.read(cx);
124 if store.edit_prediction_model == EditPredictionModel::Sweep {
125 store.has_sweep_api_token(cx)
126 } else {
127 true
128 }
129 }
130
131 fn is_refreshing(&self, cx: &App) -> bool {
132 self.store.read(cx).is_refreshing(&self.project)
133 }
134
135 fn refresh(
136 &mut self,
137 buffer: Entity<language::Buffer>,
138 cursor_position: language::Anchor,
139 _debounce: bool,
140 cx: &mut Context<Self>,
141 ) {
142 let store = self.store.read(cx);
143
144 if store.user_store.read_with(cx, |user_store, _cx| {
145 user_store.account_too_young() || user_store.has_overdue_invoices()
146 }) {
147 return;
148 }
149
150 self.store.update(cx, |store, cx| {
151 if let Some(current) =
152 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
153 && let BufferEditPrediction::Local { prediction } = current
154 && prediction.interpolate(buffer.read(cx)).is_some()
155 {
156 return;
157 }
158
159 store.refresh_context(&self.project, &buffer, cursor_position, cx);
160 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
161 });
162 }
163
164 fn accept(&mut self, cx: &mut Context<Self>) {
165 self.store.update(cx, |store, cx| {
166 store.accept_current_prediction(&self.project, cx);
167 });
168 }
169
170 fn discard(&mut self, cx: &mut Context<Self>) {
171 self.store.update(cx, |store, _cx| {
172 store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
173 });
174 }
175
176 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
177 self.store.update(cx, |store, cx| {
178 store.did_show_current_prediction(&self.project, display_type, cx);
179 });
180 }
181
182 fn suggest(
183 &mut self,
184 buffer: &Entity<language::Buffer>,
185 cursor_position: language::Anchor,
186 cx: &mut Context<Self>,
187 ) -> Option<edit_prediction_types::EditPrediction> {
188 self.store.update(cx, |store, cx| {
189 let prediction =
190 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
191
192 let prediction = match prediction {
193 BufferEditPrediction::Local { prediction } => prediction,
194 BufferEditPrediction::Jump { prediction } => {
195 return Some(edit_prediction_types::EditPrediction::Jump {
196 id: Some(prediction.id.to_string().into()),
197 snapshot: prediction.snapshot.clone(),
198 target: prediction.edits.first().unwrap().0.start,
199 });
200 }
201 };
202
203 let buffer = buffer.read(cx);
204 let snapshot = buffer.snapshot();
205
206 let Some(edits) = prediction.interpolate(&snapshot) else {
207 store.reject_current_prediction(
208 EditPredictionRejectReason::InterpolatedEmpty,
209 &self.project,
210 );
211 return None;
212 };
213
214 let cursor_row = cursor_position.to_point(&snapshot).row;
215 let (closest_edit_ix, (closest_edit_range, _)) =
216 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
217 let distance_from_start =
218 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
219 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
220 cmp::min(distance_from_start, distance_from_end)
221 })?;
222
223 let mut edit_start_ix = closest_edit_ix;
224 for (range, _) in edits[..edit_start_ix].iter().rev() {
225 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
226 - range.end.to_point(&snapshot).row;
227 if distance_from_closest_edit <= 1 {
228 edit_start_ix -= 1;
229 } else {
230 break;
231 }
232 }
233
234 let mut edit_end_ix = closest_edit_ix + 1;
235 for (range, _) in &edits[edit_end_ix..] {
236 let distance_from_closest_edit = range.start.to_point(buffer).row
237 - closest_edit_range.end.to_point(&snapshot).row;
238 if distance_from_closest_edit <= 1 {
239 edit_end_ix += 1;
240 } else {
241 break;
242 }
243 }
244
245 Some(edit_prediction_types::EditPrediction::Local {
246 id: Some(prediction.id.to_string().into()),
247 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
248 edit_preview: Some(prediction.edit_preview.clone()),
249 })
250 })
251 }
252}