1use std::{
2 cmp,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use arrayvec::ArrayVec;
8use client::{Client, UserStore};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use gpui::{App, Entity, Task, prelude::*};
11use language::ToPoint as _;
12use project::Project;
13use util::ResultExt as _;
14
15use crate::{BufferEditPrediction, Zeta};
16
/// Edit prediction provider backed by a shared [`Zeta`] entity.
///
/// Keeps at most two in-flight refresh tasks and throttles how often it
/// issues prediction requests to `Zeta`.
pub struct ZetaEditPredictionProvider {
    /// Handle to the `Zeta` instance that produces and stores predictions.
    zeta: Entity<Zeta>,
    /// Counter used to tag each spawned refresh with a distinct id.
    next_pending_prediction_id: usize,
    /// In-flight refresh tasks (capacity two); the oldest request is at index 0.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the most recent request to `Zeta` was started; used for throttling.
    last_request_timestamp: Instant,
    /// Project registered with `Zeta` and passed along with every request.
    project: Entity<Project>,
}
24
25impl ZetaEditPredictionProvider {
26 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
27
28 pub fn new(
29 project: Entity<Project>,
30 client: &Arc<Client>,
31 user_store: &Entity<UserStore>,
32 cx: &mut App,
33 ) -> Self {
34 let zeta = Zeta::global(client, user_store, cx);
35 zeta.update(cx, |zeta, cx| {
36 zeta.register_project(&project, cx);
37 });
38
39 Self {
40 zeta,
41 next_pending_prediction_id: 0,
42 pending_predictions: ArrayVec::new(),
43 last_request_timestamp: Instant::now(),
44 project: project,
45 }
46 }
47}
48
49struct PendingPrediction {
50 id: usize,
51 _task: Task<()>,
52}
53
54impl EditPredictionProvider for ZetaEditPredictionProvider {
55 fn name() -> &'static str {
56 "zed-predict2"
57 }
58
59 fn display_name() -> &'static str {
60 "Zed's Edit Predictions 2"
61 }
62
63 fn show_completions_in_menu() -> bool {
64 true
65 }
66
67 fn show_tab_accept_marker() -> bool {
68 true
69 }
70
71 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
72 // TODO [zeta2]
73 DataCollectionState::Unsupported
74 }
75
76 fn toggle_data_collection(&mut self, _cx: &mut App) {
77 // TODO [zeta2]
78 }
79
80 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
81 self.zeta.read(cx).usage(cx)
82 }
83
84 fn is_enabled(
85 &self,
86 _buffer: &Entity<language::Buffer>,
87 _cursor_position: language::Anchor,
88 _cx: &App,
89 ) -> bool {
90 true
91 }
92
93 fn is_refreshing(&self) -> bool {
94 !self.pending_predictions.is_empty()
95 }
96
97 fn refresh(
98 &mut self,
99 buffer: Entity<language::Buffer>,
100 cursor_position: language::Anchor,
101 _debounce: bool,
102 cx: &mut Context<Self>,
103 ) {
104 let zeta = self.zeta.read(cx);
105
106 if zeta.user_store.read_with(cx, |user_store, _cx| {
107 user_store.account_too_young() || user_store.has_overdue_invoices()
108 }) {
109 return;
110 }
111
112 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
113 && let BufferEditPrediction::Local { prediction } = current
114 && prediction.interpolate(buffer.read(cx)).is_some()
115 {
116 return;
117 }
118
119 let pending_prediction_id = self.next_pending_prediction_id;
120 self.next_pending_prediction_id += 1;
121 let last_request_timestamp = self.last_request_timestamp;
122
123 let project = self.project.clone();
124 let task = cx.spawn(async move |this, cx| {
125 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
126 .checked_duration_since(Instant::now())
127 {
128 cx.background_executor().timer(timeout).await;
129 }
130
131 let refresh_task = this.update(cx, |this, cx| {
132 this.last_request_timestamp = Instant::now();
133 this.zeta.update(cx, |zeta, cx| {
134 zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
135 })
136 });
137
138 if let Some(refresh_task) = refresh_task.ok() {
139 refresh_task.await.log_err();
140 }
141
142 this.update(cx, |this, cx| {
143 if this.pending_predictions[0].id == pending_prediction_id {
144 this.pending_predictions.remove(0);
145 } else {
146 this.pending_predictions.clear();
147 }
148
149 cx.notify();
150 })
151 .ok();
152 });
153
154 // We always maintain at most two pending predictions. When we already
155 // have two, we replace the newest one.
156 if self.pending_predictions.len() <= 1 {
157 self.pending_predictions.push(PendingPrediction {
158 id: pending_prediction_id,
159 _task: task,
160 });
161 } else if self.pending_predictions.len() == 2 {
162 self.pending_predictions.pop();
163 self.pending_predictions.push(PendingPrediction {
164 id: pending_prediction_id,
165 _task: task,
166 });
167 }
168
169 cx.notify();
170 }
171
172 fn cycle(
173 &mut self,
174 _buffer: Entity<language::Buffer>,
175 _cursor_position: language::Anchor,
176 _direction: Direction,
177 _cx: &mut Context<Self>,
178 ) {
179 }
180
181 fn accept(&mut self, cx: &mut Context<Self>) {
182 self.zeta.update(cx, |zeta, _cx| {
183 zeta.accept_current_prediction(&self.project);
184 });
185 self.pending_predictions.clear();
186 }
187
188 fn discard(&mut self, cx: &mut Context<Self>) {
189 self.zeta.update(cx, |zeta, _cx| {
190 zeta.discard_current_prediction(&self.project);
191 });
192 self.pending_predictions.clear();
193 }
194
195 fn suggest(
196 &mut self,
197 buffer: &Entity<language::Buffer>,
198 cursor_position: language::Anchor,
199 cx: &mut Context<Self>,
200 ) -> Option<edit_prediction::EditPrediction> {
201 let prediction =
202 self.zeta
203 .read(cx)
204 .current_prediction_for_buffer(buffer, &self.project, cx)?;
205
206 let prediction = match prediction {
207 BufferEditPrediction::Local { prediction } => prediction,
208 BufferEditPrediction::Jump { prediction } => {
209 return Some(edit_prediction::EditPrediction::Jump {
210 id: Some(prediction.id.to_string().into()),
211 snapshot: prediction.snapshot.clone(),
212 target: prediction.edits.first().unwrap().0.start,
213 });
214 }
215 };
216
217 let buffer = buffer.read(cx);
218 let snapshot = buffer.snapshot();
219
220 let Some(edits) = prediction.interpolate(&snapshot) else {
221 self.zeta.update(cx, |zeta, _cx| {
222 zeta.discard_current_prediction(&self.project);
223 });
224 return None;
225 };
226
227 let cursor_row = cursor_position.to_point(&snapshot).row;
228 let (closest_edit_ix, (closest_edit_range, _)) =
229 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
230 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
231 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
232 cmp::min(distance_from_start, distance_from_end)
233 })?;
234
235 let mut edit_start_ix = closest_edit_ix;
236 for (range, _) in edits[..edit_start_ix].iter().rev() {
237 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
238 - range.end.to_point(&snapshot).row;
239 if distance_from_closest_edit <= 1 {
240 edit_start_ix -= 1;
241 } else {
242 break;
243 }
244 }
245
246 let mut edit_end_ix = closest_edit_ix + 1;
247 for (range, _) in &edits[edit_end_ix..] {
248 let distance_from_closest_edit =
249 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
250 if distance_from_closest_edit <= 1 {
251 edit_end_ix += 1;
252 } else {
253 break;
254 }
255 }
256
257 Some(edit_prediction::EditPrediction::Local {
258 id: Some(prediction.id.to_string().into()),
259 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
260 edit_preview: Some(prediction.edit_preview.clone()),
261 })
262 }
263}