1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
7 EditPredictionIconSet, SuggestionDisplayType,
8};
9use gpui::{App, Entity, prelude::*};
10use language::{Buffer, ToPoint as _};
11use project::Project;
12use ui::prelude::*;
13
14use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
15
/// Edit-prediction delegate that forwards editor requests (refresh, suggest,
/// accept, discard, did_show) to the shared [`EditPredictionStore`] on behalf
/// of a single project.
pub struct ZedEditPredictionDelegate {
    /// Store that predictions are requested from and whose accept/reject/show
    /// events are reported to.
    store: Entity<EditPredictionStore>,
    /// Project passed to every store call this delegate makes.
    project: Entity<Project>,
    /// Optional single buffer; when present and backed by a file, it is used
    /// to compute the data-collection state.
    singleton_buffer: Option<Entity<Buffer>>,
}
21
22impl ZedEditPredictionDelegate {
23 pub fn new(
24 project: Entity<Project>,
25 singleton_buffer: Option<Entity<Buffer>>,
26 client: &Arc<Client>,
27 user_store: &Entity<UserStore>,
28 cx: &mut Context<Self>,
29 ) -> Self {
30 let store = EditPredictionStore::global(client, user_store, cx);
31 store.update(cx, |store, cx| {
32 store.register_project(&project, cx);
33 });
34
35 cx.observe(&store, |_this, _ep_store, cx| {
36 cx.notify();
37 })
38 .detach();
39
40 Self {
41 project: project,
42 store: store,
43 singleton_buffer,
44 }
45 }
46}
47
48impl EditPredictionDelegate for ZedEditPredictionDelegate {
49 fn name() -> &'static str {
50 "zed-predict"
51 }
52
53 fn display_name() -> &'static str {
54 "Zed's Edit Predictions"
55 }
56
57 fn show_predictions_in_menu() -> bool {
58 true
59 }
60
61 fn show_tab_accept_marker() -> bool {
62 true
63 }
64
65 fn icons(&self, cx: &App) -> EditPredictionIconSet {
66 match self.store.read(cx).edit_prediction_model {
67 EditPredictionModel::Sweep => EditPredictionIconSet::new(IconName::SweepAi)
68 .with_disabled(IconName::SweepAiDisabled)
69 .with_up(IconName::SweepAiUp)
70 .with_down(IconName::SweepAiDown)
71 .with_error(IconName::SweepAiError),
72 EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
73 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
74 EditPredictionIconSet::new(IconName::ZedPredict)
75 .with_disabled(IconName::ZedPredictDisabled)
76 .with_up(IconName::ZedPredictUp)
77 .with_down(IconName::ZedPredictDown)
78 .with_error(IconName::ZedPredictError)
79 }
80 EditPredictionModel::Ollama => EditPredictionIconSet::new(IconName::AiOllama),
81 }
82 }
83
84 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
85 if let Some(buffer) = &self.singleton_buffer
86 && let Some(file) = buffer.read(cx).file()
87 {
88 let is_project_open_source =
89 self.store
90 .read(cx)
91 .is_file_open_source(&self.project, file, cx);
92 if self.store.read(cx).data_collection_choice.is_enabled(cx) {
93 DataCollectionState::Enabled {
94 is_project_open_source,
95 }
96 } else {
97 DataCollectionState::Disabled {
98 is_project_open_source,
99 }
100 }
101 } else {
102 return DataCollectionState::Disabled {
103 is_project_open_source: false,
104 };
105 }
106 }
107
108 fn toggle_data_collection(&mut self, cx: &mut App) {
109 self.store.update(cx, |store, cx| {
110 store.toggle_data_collection_choice(cx);
111 });
112 }
113
114 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
115 self.store.read(cx).usage(cx)
116 }
117
118 fn is_enabled(
119 &self,
120 _buffer: &Entity<language::Buffer>,
121 _cursor_position: language::Anchor,
122 cx: &App,
123 ) -> bool {
124 let store = self.store.read(cx);
125 if store.edit_prediction_model == EditPredictionModel::Sweep {
126 store.has_sweep_api_token(cx)
127 } else {
128 true
129 }
130 }
131
132 fn is_refreshing(&self, cx: &App) -> bool {
133 self.store.read(cx).is_refreshing(&self.project)
134 }
135
136 fn refresh(
137 &mut self,
138 buffer: Entity<language::Buffer>,
139 cursor_position: language::Anchor,
140 _debounce: bool,
141 cx: &mut Context<Self>,
142 ) {
143 let store = self.store.read(cx);
144
145 if store.user_store.read_with(cx, |user_store, _cx| {
146 user_store.account_too_young() || user_store.has_overdue_invoices()
147 }) {
148 return;
149 }
150
151 self.store.update(cx, |store, cx| {
152 if let Some(current) =
153 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
154 && let BufferEditPrediction::Local { prediction } = current
155 && prediction.interpolate(buffer.read(cx)).is_some()
156 {
157 return;
158 }
159
160 store.refresh_context(&self.project, &buffer, cursor_position, cx);
161 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
162 });
163 }
164
165 fn accept(&mut self, cx: &mut Context<Self>) {
166 self.store.update(cx, |store, cx| {
167 store.accept_current_prediction(&self.project, cx);
168 });
169 }
170
171 fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
172 let reject_reason = match reason {
173 EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
174 EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
175 };
176 self.store.update(cx, |store, cx| {
177 store.reject_current_prediction(reject_reason, &self.project, cx);
178 });
179 }
180
181 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
182 self.store.update(cx, |store, cx| {
183 store.did_show_current_prediction(&self.project, display_type, cx);
184 });
185 }
186
187 fn suggest(
188 &mut self,
189 buffer: &Entity<language::Buffer>,
190 cursor_position: language::Anchor,
191 cx: &mut Context<Self>,
192 ) -> Option<edit_prediction_types::EditPrediction> {
193 self.store.update(cx, |store, cx| {
194 let prediction =
195 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
196
197 let prediction = match prediction {
198 BufferEditPrediction::Local { prediction } => prediction,
199 BufferEditPrediction::Jump { prediction } => {
200 return Some(edit_prediction_types::EditPrediction::Jump {
201 id: Some(prediction.id.to_string().into()),
202 snapshot: prediction.snapshot.clone(),
203 target: prediction.edits.first().unwrap().0.start,
204 });
205 }
206 };
207
208 let buffer = buffer.read(cx);
209 let snapshot = buffer.snapshot();
210
211 let Some(edits) = prediction.interpolate(&snapshot) else {
212 store.reject_current_prediction(
213 EditPredictionRejectReason::InterpolatedEmpty,
214 &self.project,
215 cx,
216 );
217 return None;
218 };
219
220 let cursor_row = cursor_position.to_point(&snapshot).row;
221 let (closest_edit_ix, (closest_edit_range, _)) =
222 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
223 let distance_from_start =
224 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
225 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
226 cmp::min(distance_from_start, distance_from_end)
227 })?;
228
229 let mut edit_start_ix = closest_edit_ix;
230 for (range, _) in edits[..edit_start_ix].iter().rev() {
231 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
232 - range.end.to_point(&snapshot).row;
233 if distance_from_closest_edit <= 1 {
234 edit_start_ix -= 1;
235 } else {
236 break;
237 }
238 }
239
240 let mut edit_end_ix = closest_edit_ix + 1;
241 for (range, _) in &edits[edit_end_ix..] {
242 let distance_from_closest_edit = range.start.to_point(buffer).row
243 - closest_edit_range.end.to_point(&snapshot).row;
244 if distance_from_closest_edit <= 1 {
245 edit_end_ix += 1;
246 } else {
247 break;
248 }
249 }
250
251 Some(edit_prediction_types::EditPrediction::Local {
252 id: Some(prediction.id.to_string().into()),
253 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
254 cursor_position: prediction.cursor_position,
255 edit_preview: Some(prediction.edit_preview.clone()),
256 })
257 })
258 }
259}