1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
6use gpui::{App, Entity, prelude::*};
7use language::{Buffer, ToPoint as _};
8use project::Project;
9
10use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
11
/// [`EditPredictionDelegate`] implementation registered under the name
/// `"zed-predict"`, which forwards its requests to an [`EditPredictionStore`].
pub struct ZedEditPredictionDelegate {
    /// Store obtained from `EditPredictionStore::global` in `new`; every
    /// delegate operation is forwarded to it.
    store: Entity<EditPredictionStore>,
    /// Project registered with `store` in `new`; passed to the store on each call.
    project: Entity<Project>,
    /// Optional buffer whose file is used by `data_collection_state`; when
    /// `None`, data collection is reported as disabled.
    singleton_buffer: Option<Entity<Buffer>>,
}
17
18impl ZedEditPredictionDelegate {
19 pub fn new(
20 project: Entity<Project>,
21 singleton_buffer: Option<Entity<Buffer>>,
22 client: &Arc<Client>,
23 user_store: &Entity<UserStore>,
24 cx: &mut Context<Self>,
25 ) -> Self {
26 let store = EditPredictionStore::global(client, user_store, cx);
27 store.update(cx, |store, cx| {
28 store.register_project(&project, cx);
29 });
30
31 cx.observe(&store, |_this, _ep_store, cx| {
32 cx.notify();
33 })
34 .detach();
35
36 Self {
37 project: project,
38 store: store,
39 singleton_buffer,
40 }
41 }
42}
43
44impl EditPredictionDelegate for ZedEditPredictionDelegate {
45 fn name() -> &'static str {
46 "zed-predict"
47 }
48
49 fn display_name() -> &'static str {
50 "Zed's Edit Predictions"
51 }
52
53 fn show_predictions_in_menu() -> bool {
54 true
55 }
56
57 fn show_tab_accept_marker() -> bool {
58 true
59 }
60
61 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
62 if let Some(buffer) = &self.singleton_buffer
63 && let Some(file) = buffer.read(cx).file()
64 {
65 let is_project_open_source =
66 self.store
67 .read(cx)
68 .is_file_open_source(&self.project, file, cx);
69 if self.store.read(cx).data_collection_choice.is_enabled() {
70 DataCollectionState::Enabled {
71 is_project_open_source,
72 }
73 } else {
74 DataCollectionState::Disabled {
75 is_project_open_source,
76 }
77 }
78 } else {
79 return DataCollectionState::Disabled {
80 is_project_open_source: false,
81 };
82 }
83 }
84
85 fn toggle_data_collection(&mut self, cx: &mut App) {
86 self.store.update(cx, |store, cx| {
87 store.toggle_data_collection_choice(cx);
88 });
89 }
90
91 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
92 self.store.read(cx).usage(cx)
93 }
94
95 fn is_enabled(
96 &self,
97 _buffer: &Entity<language::Buffer>,
98 _cursor_position: language::Anchor,
99 cx: &App,
100 ) -> bool {
101 let store = self.store.read(cx);
102 if store.edit_prediction_model == EditPredictionModel::Sweep {
103 store.has_sweep_api_token()
104 } else {
105 true
106 }
107 }
108
109 fn is_refreshing(&self, cx: &App) -> bool {
110 self.store.read(cx).is_refreshing(&self.project)
111 }
112
113 fn refresh(
114 &mut self,
115 buffer: Entity<language::Buffer>,
116 cursor_position: language::Anchor,
117 _debounce: bool,
118 cx: &mut Context<Self>,
119 ) {
120 let store = self.store.read(cx);
121
122 if store.user_store.read_with(cx, |user_store, _cx| {
123 user_store.account_too_young() || user_store.has_overdue_invoices()
124 }) {
125 return;
126 }
127
128 if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
129 && let BufferEditPrediction::Local { prediction } = current
130 && prediction.interpolate(buffer.read(cx)).is_some()
131 {
132 return;
133 }
134
135 self.store.update(cx, |store, cx| {
136 store.refresh_context(&self.project, &buffer, cursor_position, cx);
137 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
138 });
139 }
140
141 fn cycle(
142 &mut self,
143 _buffer: Entity<language::Buffer>,
144 _cursor_position: language::Anchor,
145 _direction: Direction,
146 _cx: &mut Context<Self>,
147 ) {
148 }
149
150 fn accept(&mut self, cx: &mut Context<Self>) {
151 self.store.update(cx, |store, cx| {
152 store.accept_current_prediction(&self.project, cx);
153 });
154 }
155
156 fn discard(&mut self, cx: &mut Context<Self>) {
157 self.store.update(cx, |store, _cx| {
158 store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
159 });
160 }
161
162 fn did_show(&mut self, cx: &mut Context<Self>) {
163 self.store.update(cx, |store, cx| {
164 store.did_show_current_prediction(&self.project, cx);
165 });
166 }
167
168 fn suggest(
169 &mut self,
170 buffer: &Entity<language::Buffer>,
171 cursor_position: language::Anchor,
172 cx: &mut Context<Self>,
173 ) -> Option<edit_prediction_types::EditPrediction> {
174 let prediction =
175 self.store
176 .read(cx)
177 .current_prediction_for_buffer(buffer, &self.project, cx)?;
178
179 let prediction = match prediction {
180 BufferEditPrediction::Local { prediction } => prediction,
181 BufferEditPrediction::Jump { prediction } => {
182 return Some(edit_prediction_types::EditPrediction::Jump {
183 id: Some(prediction.id.to_string().into()),
184 snapshot: prediction.snapshot.clone(),
185 target: prediction.edits.first().unwrap().0.start,
186 });
187 }
188 };
189
190 let buffer = buffer.read(cx);
191 let snapshot = buffer.snapshot();
192
193 let Some(edits) = prediction.interpolate(&snapshot) else {
194 self.store.update(cx, |store, _cx| {
195 store.reject_current_prediction(
196 EditPredictionRejectReason::InterpolatedEmpty,
197 &self.project,
198 );
199 });
200 return None;
201 };
202
203 let cursor_row = cursor_position.to_point(&snapshot).row;
204 let (closest_edit_ix, (closest_edit_range, _)) =
205 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
206 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
207 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
208 cmp::min(distance_from_start, distance_from_end)
209 })?;
210
211 let mut edit_start_ix = closest_edit_ix;
212 for (range, _) in edits[..edit_start_ix].iter().rev() {
213 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
214 - range.end.to_point(&snapshot).row;
215 if distance_from_closest_edit <= 1 {
216 edit_start_ix -= 1;
217 } else {
218 break;
219 }
220 }
221
222 let mut edit_end_ix = closest_edit_ix + 1;
223 for (range, _) in &edits[edit_end_ix..] {
224 let distance_from_closest_edit =
225 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
226 if distance_from_closest_edit <= 1 {
227 edit_end_ix += 1;
228 } else {
229 break;
230 }
231 }
232
233 Some(edit_prediction_types::EditPrediction::Local {
234 id: Some(prediction.id.to_string().into()),
235 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
236 edit_preview: Some(prediction.edit_preview.clone()),
237 })
238 }
239}