1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
6use gpui::{App, Entity, prelude::*};
7use language::{Buffer, ToPoint as _};
8use project::Project;
9
10use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
11
/// Adapter that exposes the shared [`EditPredictionStore`] through the
/// [`EditPredictionDelegate`] interface for a single project.
pub struct ZedEditPredictionDelegate {
    /// Prediction store obtained via `EditPredictionStore::global` when the delegate is created.
    store: Entity<EditPredictionStore>,
    /// Project whose predictions this delegate refreshes, accepts, rejects and reports.
    project: Entity<Project>,
    /// Optional buffer consulted only when computing the data-collection state.
    /// NOTE(review): presumably `Some` only when the editor hosts a single buffer — confirm
    /// with callers.
    singleton_buffer: Option<Entity<Buffer>>,
}
17
18impl ZedEditPredictionDelegate {
19 pub fn new(
20 project: Entity<Project>,
21 singleton_buffer: Option<Entity<Buffer>>,
22 client: &Arc<Client>,
23 user_store: &Entity<UserStore>,
24 cx: &mut Context<Self>,
25 ) -> Self {
26 let store = EditPredictionStore::global(client, user_store, cx);
27 store.update(cx, |store, cx| {
28 store.register_project(&project, cx);
29 });
30
31 cx.observe(&store, |_this, _ep_store, cx| {
32 cx.notify();
33 })
34 .detach();
35
36 Self {
37 project: project,
38 store: store,
39 singleton_buffer,
40 }
41 }
42}
43
44impl EditPredictionDelegate for ZedEditPredictionDelegate {
45 fn name() -> &'static str {
46 "zed-predict"
47 }
48
49 fn display_name() -> &'static str {
50 "Zed's Edit Predictions"
51 }
52
53 fn show_predictions_in_menu() -> bool {
54 true
55 }
56
57 fn show_tab_accept_marker() -> bool {
58 true
59 }
60
61 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
62 if let Some(buffer) = &self.singleton_buffer
63 && let Some(file) = buffer.read(cx).file()
64 {
65 let is_project_open_source =
66 self.store
67 .read(cx)
68 .is_file_open_source(&self.project, file, cx);
69 if self.store.read(cx).data_collection_choice.is_enabled() {
70 DataCollectionState::Enabled {
71 is_project_open_source,
72 }
73 } else {
74 DataCollectionState::Disabled {
75 is_project_open_source,
76 }
77 }
78 } else {
79 return DataCollectionState::Disabled {
80 is_project_open_source: false,
81 };
82 }
83 }
84
85 fn toggle_data_collection(&mut self, cx: &mut App) {
86 self.store.update(cx, |store, cx| {
87 store.toggle_data_collection_choice(cx);
88 });
89 }
90
91 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
92 self.store.read(cx).usage(cx)
93 }
94
95 fn is_enabled(
96 &self,
97 _buffer: &Entity<language::Buffer>,
98 _cursor_position: language::Anchor,
99 cx: &App,
100 ) -> bool {
101 let store = self.store.read(cx);
102 if store.edit_prediction_model == EditPredictionModel::Sweep {
103 store.has_sweep_api_token(cx)
104 } else {
105 true
106 }
107 }
108
109 fn is_refreshing(&self, cx: &App) -> bool {
110 self.store.read(cx).is_refreshing(&self.project)
111 }
112
113 fn refresh(
114 &mut self,
115 buffer: Entity<language::Buffer>,
116 cursor_position: language::Anchor,
117 _debounce: bool,
118 cx: &mut Context<Self>,
119 ) {
120 let store = self.store.read(cx);
121
122 if store.user_store.read_with(cx, |user_store, _cx| {
123 user_store.account_too_young() || user_store.has_overdue_invoices()
124 }) {
125 return;
126 }
127
128 self.store.update(cx, |store, cx| {
129 if let Some(current) =
130 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
131 && let BufferEditPrediction::Local { prediction } = current
132 && prediction.interpolate(buffer.read(cx)).is_some()
133 {
134 return;
135 }
136
137 store.refresh_context(&self.project, &buffer, cursor_position, cx);
138 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
139 });
140 }
141
142 fn cycle(
143 &mut self,
144 _buffer: Entity<language::Buffer>,
145 _cursor_position: language::Anchor,
146 _direction: Direction,
147 _cx: &mut Context<Self>,
148 ) {
149 }
150
151 fn accept(&mut self, cx: &mut Context<Self>) {
152 self.store.update(cx, |store, cx| {
153 store.accept_current_prediction(&self.project, cx);
154 });
155 }
156
157 fn discard(&mut self, cx: &mut Context<Self>) {
158 self.store.update(cx, |store, _cx| {
159 store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
160 });
161 }
162
163 fn did_show(&mut self, cx: &mut Context<Self>) {
164 self.store.update(cx, |store, cx| {
165 store.did_show_current_prediction(&self.project, cx);
166 });
167 }
168
169 fn suggest(
170 &mut self,
171 buffer: &Entity<language::Buffer>,
172 cursor_position: language::Anchor,
173 cx: &mut Context<Self>,
174 ) -> Option<edit_prediction_types::EditPrediction> {
175 self.store.update(cx, |store, cx| {
176 let prediction =
177 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
178
179 let prediction = match prediction {
180 BufferEditPrediction::Local { prediction } => prediction,
181 BufferEditPrediction::Jump { prediction } => {
182 return Some(edit_prediction_types::EditPrediction::Jump {
183 id: Some(prediction.id.to_string().into()),
184 snapshot: prediction.snapshot.clone(),
185 target: prediction.edits.first().unwrap().0.start,
186 });
187 }
188 };
189
190 let buffer = buffer.read(cx);
191 let snapshot = buffer.snapshot();
192
193 let Some(edits) = prediction.interpolate(&snapshot) else {
194 store.reject_current_prediction(
195 EditPredictionRejectReason::InterpolatedEmpty,
196 &self.project,
197 );
198 return None;
199 };
200
201 let cursor_row = cursor_position.to_point(&snapshot).row;
202 let (closest_edit_ix, (closest_edit_range, _)) =
203 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
204 let distance_from_start =
205 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
206 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
207 cmp::min(distance_from_start, distance_from_end)
208 })?;
209
210 let mut edit_start_ix = closest_edit_ix;
211 for (range, _) in edits[..edit_start_ix].iter().rev() {
212 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
213 - range.end.to_point(&snapshot).row;
214 if distance_from_closest_edit <= 1 {
215 edit_start_ix -= 1;
216 } else {
217 break;
218 }
219 }
220
221 let mut edit_end_ix = closest_edit_ix + 1;
222 for (range, _) in &edits[edit_end_ix..] {
223 let distance_from_closest_edit = range.start.to_point(buffer).row
224 - closest_edit_range.end.to_point(&snapshot).row;
225 if distance_from_closest_edit <= 1 {
226 edit_end_ix += 1;
227 } else {
228 break;
229 }
230 }
231
232 Some(edit_prediction_types::EditPrediction::Local {
233 id: Some(prediction.id.to_string().into()),
234 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
235 edit_preview: Some(prediction.edit_preview.clone()),
236 })
237 })
238 }
239}