1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{DataCollectionState, EditPredictionDelegate};
6use gpui::{App, Entity, prelude::*};
7use language::{Buffer, ToPoint as _};
8use project::Project;
9
10use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
11
12pub struct ZedEditPredictionDelegate {
13 store: Entity<EditPredictionStore>,
14 project: Entity<Project>,
15 singleton_buffer: Option<Entity<Buffer>>,
16}
17
18impl ZedEditPredictionDelegate {
19 pub fn new(
20 project: Entity<Project>,
21 singleton_buffer: Option<Entity<Buffer>>,
22 client: &Arc<Client>,
23 user_store: &Entity<UserStore>,
24 cx: &mut Context<Self>,
25 ) -> Self {
26 let store = EditPredictionStore::global(client, user_store, cx);
27 store.update(cx, |store, cx| {
28 store.register_project(&project, cx);
29 });
30
31 cx.observe(&store, |_this, _ep_store, cx| {
32 cx.notify();
33 })
34 .detach();
35
36 Self {
37 project: project,
38 store: store,
39 singleton_buffer,
40 }
41 }
42}
43
44impl EditPredictionDelegate for ZedEditPredictionDelegate {
/// Returns the fixed identifier string of this delegate.
fn name() -> &'static str {
    const NAME: &str = "zed-predict";
    NAME
}
48
/// Returns the fixed human-readable name of this delegate.
fn display_name() -> &'static str {
    const DISPLAY_NAME: &str = "Zed's Edit Predictions";
    DISPLAY_NAME
}
52
/// Always returns `true`.
///
/// NOTE(review): presumably opts this delegate's predictions into the
/// completions menu — confirm against the `EditPredictionDelegate` trait docs.
fn show_predictions_in_menu() -> bool {
    true
}
56
/// Always returns `true`.
///
/// NOTE(review): presumably asks the UI to render a tab-to-accept marker —
/// confirm against the `EditPredictionDelegate` trait docs.
fn show_tab_accept_marker() -> bool {
    true
}
60
61 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
62 if let Some(buffer) = &self.singleton_buffer
63 && let Some(file) = buffer.read(cx).file()
64 {
65 let is_project_open_source =
66 self.store
67 .read(cx)
68 .is_file_open_source(&self.project, file, cx);
69 if self.store.read(cx).data_collection_choice.is_enabled() {
70 DataCollectionState::Enabled {
71 is_project_open_source,
72 }
73 } else {
74 DataCollectionState::Disabled {
75 is_project_open_source,
76 }
77 }
78 } else {
79 return DataCollectionState::Disabled {
80 is_project_open_source: false,
81 };
82 }
83 }
84
85 fn toggle_data_collection(&mut self, cx: &mut App) {
86 self.store.update(cx, |store, cx| {
87 store.toggle_data_collection_choice(cx);
88 });
89 }
90
91 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
92 self.store.read(cx).usage(cx)
93 }
94
95 fn is_enabled(
96 &self,
97 _buffer: &Entity<language::Buffer>,
98 _cursor_position: language::Anchor,
99 cx: &App,
100 ) -> bool {
101 let store = self.store.read(cx);
102 if store.edit_prediction_model == EditPredictionModel::Sweep {
103 store.has_sweep_api_token(cx)
104 } else {
105 true
106 }
107 }
108
109 fn is_refreshing(&self, cx: &App) -> bool {
110 self.store.read(cx).is_refreshing(&self.project)
111 }
112
113 fn refresh(
114 &mut self,
115 buffer: Entity<language::Buffer>,
116 cursor_position: language::Anchor,
117 _debounce: bool,
118 cx: &mut Context<Self>,
119 ) {
120 let store = self.store.read(cx);
121
122 if store.user_store.read_with(cx, |user_store, _cx| {
123 user_store.account_too_young() || user_store.has_overdue_invoices()
124 }) {
125 return;
126 }
127
128 self.store.update(cx, |store, cx| {
129 if let Some(current) =
130 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
131 && let BufferEditPrediction::Local { prediction } = current
132 && prediction.interpolate(buffer.read(cx)).is_some()
133 {
134 return;
135 }
136
137 store.refresh_context(&self.project, &buffer, cursor_position, cx);
138 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
139 });
140 }
141
142 fn accept(&mut self, cx: &mut Context<Self>) {
143 self.store.update(cx, |store, cx| {
144 store.accept_current_prediction(&self.project, cx);
145 });
146 }
147
148 fn discard(&mut self, cx: &mut Context<Self>) {
149 self.store.update(cx, |store, _cx| {
150 store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
151 });
152 }
153
154 fn did_show(&mut self, cx: &mut Context<Self>) {
155 self.store.update(cx, |store, cx| {
156 store.did_show_current_prediction(&self.project, cx);
157 });
158 }
159
160 fn suggest(
161 &mut self,
162 buffer: &Entity<language::Buffer>,
163 cursor_position: language::Anchor,
164 cx: &mut Context<Self>,
165 ) -> Option<edit_prediction_types::EditPrediction> {
166 self.store.update(cx, |store, cx| {
167 let prediction =
168 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
169
170 let prediction = match prediction {
171 BufferEditPrediction::Local { prediction } => prediction,
172 BufferEditPrediction::Jump { prediction } => {
173 return Some(edit_prediction_types::EditPrediction::Jump {
174 id: Some(prediction.id.to_string().into()),
175 snapshot: prediction.snapshot.clone(),
176 target: prediction.edits.first().unwrap().0.start,
177 });
178 }
179 };
180
181 let buffer = buffer.read(cx);
182 let snapshot = buffer.snapshot();
183
184 let Some(edits) = prediction.interpolate(&snapshot) else {
185 store.reject_current_prediction(
186 EditPredictionRejectReason::InterpolatedEmpty,
187 &self.project,
188 );
189 return None;
190 };
191
192 let cursor_row = cursor_position.to_point(&snapshot).row;
193 let (closest_edit_ix, (closest_edit_range, _)) =
194 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
195 let distance_from_start =
196 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
197 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
198 cmp::min(distance_from_start, distance_from_end)
199 })?;
200
201 let mut edit_start_ix = closest_edit_ix;
202 for (range, _) in edits[..edit_start_ix].iter().rev() {
203 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
204 - range.end.to_point(&snapshot).row;
205 if distance_from_closest_edit <= 1 {
206 edit_start_ix -= 1;
207 } else {
208 break;
209 }
210 }
211
212 let mut edit_end_ix = closest_edit_ix + 1;
213 for (range, _) in &edits[edit_end_ix..] {
214 let distance_from_closest_edit = range.start.to_point(buffer).row
215 - closest_edit_range.end.to_point(&snapshot).row;
216 if distance_from_closest_edit <= 1 {
217 edit_end_ix += 1;
218 } else {
219 break;
220 }
221 }
222
223 Some(edit_prediction_types::EditPrediction::Local {
224 id: Some(prediction.id.to_string().into()),
225 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
226 edit_preview: Some(prediction.edit_preview.clone()),
227 })
228 })
229 }
230}