1use std::{
2 cmp,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use arrayvec::ArrayVec;
8use client::{Client, UserStore};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use gpui::{App, Entity, Task, prelude::*};
11use language::ToPoint as _;
12use project::Project;
13use util::ResultExt as _;
14
15use crate::{BufferEditPrediction, Zeta};
16
17pub struct ZetaEditPredictionProvider {
18 zeta: Entity<Zeta>,
19 next_pending_prediction_id: usize,
20 pending_predictions: ArrayVec<PendingPrediction, 2>,
21 last_request_timestamp: Instant,
22 project: Entity<Project>,
23}
24
25impl ZetaEditPredictionProvider {
26 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
27
28 pub fn new(
29 project: Entity<Project>,
30 client: &Arc<Client>,
31 user_store: &Entity<UserStore>,
32 cx: &mut App,
33 ) -> Self {
34 let zeta = Zeta::global(client, user_store, cx);
35 zeta.update(cx, |zeta, cx| {
36 zeta.register_project(&project, cx);
37 });
38
39 Self {
40 zeta,
41 next_pending_prediction_id: 0,
42 pending_predictions: ArrayVec::new(),
43 last_request_timestamp: Instant::now(),
44 project: project,
45 }
46 }
47}
48
49struct PendingPrediction {
50 id: usize,
51 _task: Task<()>,
52}
53
54impl EditPredictionProvider for ZetaEditPredictionProvider {
55 fn name() -> &'static str {
56 "zed-predict2"
57 }
58
59 fn display_name() -> &'static str {
60 "Zed's Edit Predictions 2"
61 }
62
63 fn show_completions_in_menu() -> bool {
64 true
65 }
66
67 fn show_tab_accept_marker() -> bool {
68 true
69 }
70
71 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
72 // TODO [zeta2]
73 DataCollectionState::Unsupported
74 }
75
76 fn toggle_data_collection(&mut self, _cx: &mut App) {
77 // TODO [zeta2]
78 }
79
80 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
81 self.zeta.read(cx).usage(cx)
82 }
83
84 fn is_enabled(
85 &self,
86 _buffer: &Entity<language::Buffer>,
87 _cursor_position: language::Anchor,
88 _cx: &App,
89 ) -> bool {
90 true
91 }
92
93 fn is_refreshing(&self) -> bool {
94 !self.pending_predictions.is_empty()
95 }
96
97 fn refresh(
98 &mut self,
99 buffer: Entity<language::Buffer>,
100 cursor_position: language::Anchor,
101 _debounce: bool,
102 cx: &mut Context<Self>,
103 ) {
104 let zeta = self.zeta.read(cx);
105
106 if zeta.user_store.read_with(cx, |user_store, _cx| {
107 user_store.account_too_young() || user_store.has_overdue_invoices()
108 }) {
109 return;
110 }
111
112 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
113 && let BufferEditPrediction::Local { prediction } = current
114 && prediction.interpolate(buffer.read(cx)).is_some()
115 {
116 return;
117 }
118
119 self.zeta.update(cx, |zeta, cx| {
120 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
121 });
122
123 let pending_prediction_id = self.next_pending_prediction_id;
124 self.next_pending_prediction_id += 1;
125 let last_request_timestamp = self.last_request_timestamp;
126
127 let project = self.project.clone();
128 let task = cx.spawn(async move |this, cx| {
129 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
130 .checked_duration_since(Instant::now())
131 {
132 cx.background_executor().timer(timeout).await;
133 }
134
135 let refresh_task = this.update(cx, |this, cx| {
136 this.last_request_timestamp = Instant::now();
137 this.zeta.update(cx, |zeta, cx| {
138 zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
139 })
140 });
141
142 if let Some(refresh_task) = refresh_task.ok() {
143 refresh_task.await.log_err();
144 }
145
146 this.update(cx, |this, cx| {
147 if this.pending_predictions[0].id == pending_prediction_id {
148 this.pending_predictions.remove(0);
149 } else {
150 this.pending_predictions.clear();
151 }
152
153 cx.notify();
154 })
155 .ok();
156 });
157
158 // We always maintain at most two pending predictions. When we already
159 // have two, we replace the newest one.
160 if self.pending_predictions.len() <= 1 {
161 self.pending_predictions.push(PendingPrediction {
162 id: pending_prediction_id,
163 _task: task,
164 });
165 } else if self.pending_predictions.len() == 2 {
166 self.pending_predictions.pop();
167 self.pending_predictions.push(PendingPrediction {
168 id: pending_prediction_id,
169 _task: task,
170 });
171 }
172
173 cx.notify();
174 }
175
176 fn cycle(
177 &mut self,
178 _buffer: Entity<language::Buffer>,
179 _cursor_position: language::Anchor,
180 _direction: Direction,
181 _cx: &mut Context<Self>,
182 ) {
183 }
184
185 fn accept(&mut self, cx: &mut Context<Self>) {
186 self.zeta.update(cx, |zeta, cx| {
187 zeta.accept_current_prediction(&self.project, cx);
188 });
189 self.pending_predictions.clear();
190 }
191
192 fn discard(&mut self, cx: &mut Context<Self>) {
193 self.zeta.update(cx, |zeta, _cx| {
194 zeta.discard_current_prediction(&self.project);
195 });
196 self.pending_predictions.clear();
197 }
198
199 fn suggest(
200 &mut self,
201 buffer: &Entity<language::Buffer>,
202 cursor_position: language::Anchor,
203 cx: &mut Context<Self>,
204 ) -> Option<edit_prediction::EditPrediction> {
205 let prediction =
206 self.zeta
207 .read(cx)
208 .current_prediction_for_buffer(buffer, &self.project, cx)?;
209
210 let prediction = match prediction {
211 BufferEditPrediction::Local { prediction } => prediction,
212 BufferEditPrediction::Jump { prediction } => {
213 return Some(edit_prediction::EditPrediction::Jump {
214 id: Some(prediction.id.to_string().into()),
215 snapshot: prediction.snapshot.clone(),
216 target: prediction.edits.first().unwrap().0.start,
217 });
218 }
219 };
220
221 let buffer = buffer.read(cx);
222 let snapshot = buffer.snapshot();
223
224 let Some(edits) = prediction.interpolate(&snapshot) else {
225 self.zeta.update(cx, |zeta, _cx| {
226 zeta.discard_current_prediction(&self.project);
227 });
228 return None;
229 };
230
231 let cursor_row = cursor_position.to_point(&snapshot).row;
232 let (closest_edit_ix, (closest_edit_range, _)) =
233 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
234 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
235 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
236 cmp::min(distance_from_start, distance_from_end)
237 })?;
238
239 let mut edit_start_ix = closest_edit_ix;
240 for (range, _) in edits[..edit_start_ix].iter().rev() {
241 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
242 - range.end.to_point(&snapshot).row;
243 if distance_from_closest_edit <= 1 {
244 edit_start_ix -= 1;
245 } else {
246 break;
247 }
248 }
249
250 let mut edit_end_ix = closest_edit_ix + 1;
251 for (range, _) in &edits[edit_end_ix..] {
252 let distance_from_closest_edit =
253 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
254 if distance_from_closest_edit <= 1 {
255 edit_end_ix += 1;
256 } else {
257 break;
258 }
259 }
260
261 Some(edit_prediction::EditPrediction::Local {
262 id: Some(prediction.id.to_string().into()),
263 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
264 edit_preview: Some(prediction.edit_preview.clone()),
265 })
266 }
267}