1use std::{
2 cmp,
3 sync::Arc,
4 time::{Duration, Instant},
5};
6
7use arrayvec::ArrayVec;
8use client::{Client, UserStore};
9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
10use gpui::{App, Entity, Task, prelude::*};
11use language::ToPoint as _;
12use project::Project;
13use util::ResultExt as _;
14
15use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
16
17pub struct ZetaEditPredictionProvider {
18 zeta: Entity<Zeta>,
19 next_pending_prediction_id: usize,
20 pending_predictions: ArrayVec<PendingPrediction, 2>,
21 last_request_timestamp: Instant,
22 project: Entity<Project>,
23}
24
25impl ZetaEditPredictionProvider {
26 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
27
28 pub fn new(
29 project: Entity<Project>,
30 client: &Arc<Client>,
31 user_store: &Entity<UserStore>,
32 cx: &mut App,
33 ) -> Self {
34 let zeta = Zeta::global(client, user_store, cx);
35 zeta.update(cx, |zeta, cx| {
36 zeta.register_project(&project, cx);
37 });
38
39 Self {
40 zeta,
41 next_pending_prediction_id: 0,
42 pending_predictions: ArrayVec::new(),
43 last_request_timestamp: Instant::now(),
44 project: project,
45 }
46 }
47}
48
49struct PendingPrediction {
50 id: usize,
51 _task: Task<()>,
52}
53
54impl EditPredictionProvider for ZetaEditPredictionProvider {
55 fn name() -> &'static str {
56 "zed-predict2"
57 }
58
59 fn display_name() -> &'static str {
60 "Zed's Edit Predictions 2"
61 }
62
63 fn show_completions_in_menu() -> bool {
64 true
65 }
66
67 fn show_tab_accept_marker() -> bool {
68 true
69 }
70
71 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
72 // TODO [zeta2]
73 DataCollectionState::Unsupported
74 }
75
76 fn toggle_data_collection(&mut self, _cx: &mut App) {
77 // TODO [zeta2]
78 }
79
80 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
81 self.zeta.read(cx).usage(cx)
82 }
83
84 fn is_enabled(
85 &self,
86 _buffer: &Entity<language::Buffer>,
87 _cursor_position: language::Anchor,
88 cx: &App,
89 ) -> bool {
90 let zeta = self.zeta.read(cx);
91 if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
92 zeta.sweep_api_token.is_some()
93 } else {
94 true
95 }
96 }
97
98 fn is_refreshing(&self) -> bool {
99 !self.pending_predictions.is_empty()
100 }
101
102 fn refresh(
103 &mut self,
104 buffer: Entity<language::Buffer>,
105 cursor_position: language::Anchor,
106 _debounce: bool,
107 cx: &mut Context<Self>,
108 ) {
109 let zeta = self.zeta.read(cx);
110
111 if zeta.user_store.read_with(cx, |user_store, _cx| {
112 user_store.account_too_young() || user_store.has_overdue_invoices()
113 }) {
114 return;
115 }
116
117 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
118 && let BufferEditPrediction::Local { prediction } = current
119 && prediction.interpolate(buffer.read(cx)).is_some()
120 {
121 return;
122 }
123
124 self.zeta.update(cx, |zeta, cx| {
125 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
126 });
127
128 let pending_prediction_id = self.next_pending_prediction_id;
129 self.next_pending_prediction_id += 1;
130 let last_request_timestamp = self.last_request_timestamp;
131
132 let project = self.project.clone();
133 let task = cx.spawn(async move |this, cx| {
134 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
135 .checked_duration_since(Instant::now())
136 {
137 cx.background_executor().timer(timeout).await;
138 }
139
140 let refresh_task = this.update(cx, |this, cx| {
141 this.last_request_timestamp = Instant::now();
142 this.zeta.update(cx, |zeta, cx| {
143 zeta.refresh_prediction(&project, &buffer, cursor_position, cx)
144 })
145 });
146
147 if let Some(refresh_task) = refresh_task.ok() {
148 refresh_task.await.log_err();
149 }
150
151 this.update(cx, |this, cx| {
152 if this.pending_predictions[0].id == pending_prediction_id {
153 this.pending_predictions.remove(0);
154 } else {
155 this.pending_predictions.clear();
156 }
157
158 cx.notify();
159 })
160 .ok();
161 });
162
163 // We always maintain at most two pending predictions. When we already
164 // have two, we replace the newest one.
165 if self.pending_predictions.len() <= 1 {
166 self.pending_predictions.push(PendingPrediction {
167 id: pending_prediction_id,
168 _task: task,
169 });
170 } else if self.pending_predictions.len() == 2 {
171 self.pending_predictions.pop();
172 self.pending_predictions.push(PendingPrediction {
173 id: pending_prediction_id,
174 _task: task,
175 });
176 }
177
178 cx.notify();
179 }
180
181 fn cycle(
182 &mut self,
183 _buffer: Entity<language::Buffer>,
184 _cursor_position: language::Anchor,
185 _direction: Direction,
186 _cx: &mut Context<Self>,
187 ) {
188 }
189
190 fn accept(&mut self, cx: &mut Context<Self>) {
191 self.zeta.update(cx, |zeta, cx| {
192 zeta.accept_current_prediction(&self.project, cx);
193 });
194 self.pending_predictions.clear();
195 }
196
197 fn discard(&mut self, cx: &mut Context<Self>) {
198 self.zeta.update(cx, |zeta, _cx| {
199 zeta.discard_current_prediction(&self.project);
200 });
201 self.pending_predictions.clear();
202 }
203
204 fn suggest(
205 &mut self,
206 buffer: &Entity<language::Buffer>,
207 cursor_position: language::Anchor,
208 cx: &mut Context<Self>,
209 ) -> Option<edit_prediction::EditPrediction> {
210 let prediction =
211 self.zeta
212 .read(cx)
213 .current_prediction_for_buffer(buffer, &self.project, cx)?;
214
215 let prediction = match prediction {
216 BufferEditPrediction::Local { prediction } => prediction,
217 BufferEditPrediction::Jump { prediction } => {
218 return Some(edit_prediction::EditPrediction::Jump {
219 id: Some(prediction.id.to_string().into()),
220 snapshot: prediction.snapshot.clone(),
221 target: prediction.edits.first().unwrap().0.start,
222 });
223 }
224 };
225
226 let buffer = buffer.read(cx);
227 let snapshot = buffer.snapshot();
228
229 let Some(edits) = prediction.interpolate(&snapshot) else {
230 self.zeta.update(cx, |zeta, _cx| {
231 zeta.discard_current_prediction(&self.project);
232 });
233 return None;
234 };
235
236 let cursor_row = cursor_position.to_point(&snapshot).row;
237 let (closest_edit_ix, (closest_edit_range, _)) =
238 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
239 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
240 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
241 cmp::min(distance_from_start, distance_from_end)
242 })?;
243
244 let mut edit_start_ix = closest_edit_ix;
245 for (range, _) in edits[..edit_start_ix].iter().rev() {
246 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
247 - range.end.to_point(&snapshot).row;
248 if distance_from_closest_edit <= 1 {
249 edit_start_ix -= 1;
250 } else {
251 break;
252 }
253 }
254
255 let mut edit_end_ix = closest_edit_ix + 1;
256 for (range, _) in &edits[edit_end_ix..] {
257 let distance_from_closest_edit =
258 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
259 if distance_from_closest_edit <= 1 {
260 edit_end_ix += 1;
261 } else {
262 break;
263 }
264 }
265
266 Some(edit_prediction::EditPrediction::Local {
267 id: Some(prediction.id.to_string().into()),
268 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
269 edit_preview: Some(prediction.edit_preview.clone()),
270 })
271 }
272}