1use std::{cmp, sync::Arc};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction_types::{
6 DataCollectionState, EditPredictionDelegate, EditPredictionDiscardReason,
7 EditPredictionIconSet, SuggestionDisplayType,
8};
9use gpui::{App, Entity, prelude::*};
10use language::{Buffer, ToPoint as _};
11use project::Project;
12
13use crate::{BufferEditPrediction, EditPredictionStore};
14
15pub struct ZedEditPredictionDelegate {
16 store: Entity<EditPredictionStore>,
17 project: Entity<Project>,
18 singleton_buffer: Option<Entity<Buffer>>,
19}
20
21impl ZedEditPredictionDelegate {
22 pub fn new(
23 project: Entity<Project>,
24 singleton_buffer: Option<Entity<Buffer>>,
25 client: &Arc<Client>,
26 user_store: &Entity<UserStore>,
27 cx: &mut Context<Self>,
28 ) -> Self {
29 let store = EditPredictionStore::global(client, user_store, cx);
30 store.update(cx, |store, cx| {
31 store.register_project(&project, cx);
32 });
33
34 cx.observe(&store, |_this, _ep_store, cx| {
35 cx.notify();
36 })
37 .detach();
38
39 Self {
40 project: project,
41 store: store,
42 singleton_buffer,
43 }
44 }
45}
46
47impl EditPredictionDelegate for ZedEditPredictionDelegate {
48 fn name() -> &'static str {
49 "zed-predict"
50 }
51
52 fn display_name() -> &'static str {
53 "Zed's Edit Predictions"
54 }
55
56 fn show_predictions_in_menu() -> bool {
57 true
58 }
59
60 fn show_tab_accept_marker() -> bool {
61 true
62 }
63
64 fn icons(&self, cx: &App) -> EditPredictionIconSet {
65 self.store.read(cx).icons(cx)
66 }
67
68 fn data_collection_state(&self, cx: &App) -> DataCollectionState {
69 if let Some(buffer) = &self.singleton_buffer
70 && let Some(file) = buffer.read(cx).file()
71 {
72 let is_project_open_source =
73 self.store
74 .read(cx)
75 .is_file_open_source(&self.project, file, cx);
76 if self.store.read(cx).data_collection_choice.is_enabled(cx) {
77 DataCollectionState::Enabled {
78 is_project_open_source,
79 }
80 } else {
81 DataCollectionState::Disabled {
82 is_project_open_source,
83 }
84 }
85 } else {
86 return DataCollectionState::Disabled {
87 is_project_open_source: false,
88 };
89 }
90 }
91
92 fn toggle_data_collection(&mut self, cx: &mut App) {
93 self.store.update(cx, |store, cx| {
94 store.toggle_data_collection_choice(cx);
95 });
96 }
97
98 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
99 self.store.read(cx).usage(cx)
100 }
101
102 fn is_enabled(
103 &self,
104 _buffer: &Entity<language::Buffer>,
105 _cursor_position: language::Anchor,
106 _cx: &App,
107 ) -> bool {
108 true
109 }
110
111 fn is_refreshing(&self, cx: &App) -> bool {
112 self.store.read(cx).is_refreshing(&self.project)
113 }
114
115 fn refresh(
116 &mut self,
117 buffer: Entity<language::Buffer>,
118 cursor_position: language::Anchor,
119 _debounce: bool,
120 cx: &mut Context<Self>,
121 ) {
122 let store = self.store.read(cx);
123
124 if store.user_store.read_with(cx, |user_store, _cx| {
125 user_store.account_too_young() || user_store.has_overdue_invoices()
126 }) {
127 return;
128 }
129
130 self.store.update(cx, |store, cx| {
131 if let Some(current) =
132 store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
133 && let BufferEditPrediction::Local { prediction } = current
134 && prediction.interpolate(buffer.read(cx)).is_some()
135 {
136 return;
137 }
138
139 store.refresh_context(&self.project, &buffer, cursor_position, cx);
140 store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
141 });
142 }
143
144 fn accept(&mut self, cx: &mut Context<Self>) {
145 self.store.update(cx, |store, cx| {
146 store.accept_current_prediction(&self.project, cx);
147 });
148 }
149
150 fn discard(&mut self, reason: EditPredictionDiscardReason, cx: &mut Context<Self>) {
151 let reject_reason = match reason {
152 EditPredictionDiscardReason::Rejected => EditPredictionRejectReason::Rejected,
153 EditPredictionDiscardReason::Ignored => EditPredictionRejectReason::Discarded,
154 };
155 self.store.update(cx, |store, cx| {
156 store.reject_current_prediction(reject_reason, &self.project, cx);
157 });
158 }
159
160 fn did_show(&mut self, display_type: SuggestionDisplayType, cx: &mut Context<Self>) {
161 self.store.update(cx, |store, cx| {
162 store.did_show_current_prediction(&self.project, display_type, cx);
163 });
164 }
165
166 fn suggest(
167 &mut self,
168 buffer: &Entity<language::Buffer>,
169 cursor_position: language::Anchor,
170 cx: &mut Context<Self>,
171 ) -> Option<edit_prediction_types::EditPrediction> {
172 self.store.update(cx, |store, cx| {
173 let prediction =
174 store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
175
176 let prediction = match prediction {
177 BufferEditPrediction::Local { prediction } => prediction,
178 BufferEditPrediction::Jump { prediction } => {
179 return Some(edit_prediction_types::EditPrediction::Jump {
180 id: Some(prediction.id.to_string().into()),
181 snapshot: prediction.snapshot.clone(),
182 target: prediction.edits.first().unwrap().0.start,
183 });
184 }
185 };
186
187 let buffer = buffer.read(cx);
188 let snapshot = buffer.snapshot();
189
190 let Some(edits) = prediction.interpolate(&snapshot) else {
191 store.reject_current_prediction(
192 EditPredictionRejectReason::InterpolatedEmpty,
193 &self.project,
194 cx,
195 );
196 return None;
197 };
198
199 let cursor_row = cursor_position.to_point(&snapshot).row;
200 let (closest_edit_ix, (closest_edit_range, _)) =
201 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
202 let distance_from_start =
203 cursor_row.abs_diff(range.start.to_point(&snapshot).row);
204 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
205 cmp::min(distance_from_start, distance_from_end)
206 })?;
207
208 let mut edit_start_ix = closest_edit_ix;
209 for (range, _) in edits[..edit_start_ix].iter().rev() {
210 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
211 - range.end.to_point(&snapshot).row;
212 if distance_from_closest_edit <= 1 {
213 edit_start_ix -= 1;
214 } else {
215 break;
216 }
217 }
218
219 let mut edit_end_ix = closest_edit_ix + 1;
220 for (range, _) in &edits[edit_end_ix..] {
221 let distance_from_closest_edit = range.start.to_point(buffer).row
222 - closest_edit_range.end.to_point(&snapshot).row;
223 if distance_from_closest_edit <= 1 {
224 edit_end_ix += 1;
225 } else {
226 break;
227 }
228 }
229
230 Some(edit_prediction_types::EditPrediction::Local {
231 id: Some(prediction.id.to_string().into()),
232 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
233 cursor_position: prediction.cursor_position,
234 edit_preview: Some(prediction.edit_preview.clone()),
235 })
236 })
237 }
238}