1use std::{cmp, sync::Arc, time::Duration};
2
3use client::{Client, UserStore};
4use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
5use gpui::{App, Entity, prelude::*};
6use language::ToPoint as _;
7use project::Project;
8
9use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
10
/// Edit prediction provider for a single project, backed by a shared [`Zeta`] entity.
///
/// All prediction state lives in `Zeta`; this type only forwards requests for its
/// `project` to it.
pub struct ZetaEditPredictionProvider {
    /// Handle to the Zeta instance, obtained via `Zeta::global` in `new`.
    zeta: Entity<Zeta>,
    /// The project this provider was created for; it is passed to every `Zeta` call.
    project: Entity<Project>,
}
15
16impl ZetaEditPredictionProvider {
17 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
18
19 pub fn new(
20 project: Entity<Project>,
21 client: &Arc<Client>,
22 user_store: &Entity<UserStore>,
23 cx: &mut Context<Self>,
24 ) -> Self {
25 let zeta = Zeta::global(client, user_store, cx);
26 zeta.update(cx, |zeta, cx| {
27 zeta.register_project(&project, cx);
28 });
29
30 cx.observe(&zeta, |_this, _zeta, cx| {
31 cx.notify();
32 })
33 .detach();
34
35 Self {
36 project: project,
37 zeta,
38 }
39 }
40}
41
42impl EditPredictionProvider for ZetaEditPredictionProvider {
43 fn name() -> &'static str {
44 "zed-predict2"
45 }
46
47 fn display_name() -> &'static str {
48 "Zed's Edit Predictions 2"
49 }
50
51 fn show_completions_in_menu() -> bool {
52 true
53 }
54
55 fn show_tab_accept_marker() -> bool {
56 true
57 }
58
59 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
60 // TODO [zeta2]
61 DataCollectionState::Unsupported
62 }
63
64 fn toggle_data_collection(&mut self, _cx: &mut App) {
65 // TODO [zeta2]
66 }
67
68 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
69 self.zeta.read(cx).usage(cx)
70 }
71
72 fn is_enabled(
73 &self,
74 _buffer: &Entity<language::Buffer>,
75 _cursor_position: language::Anchor,
76 cx: &App,
77 ) -> bool {
78 let zeta = self.zeta.read(cx);
79 if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
80 zeta.sweep_api_token.is_some()
81 } else {
82 true
83 }
84 }
85
86 fn is_refreshing(&self, cx: &App) -> bool {
87 self.zeta.read(cx).is_refreshing(&self.project)
88 }
89
90 fn refresh(
91 &mut self,
92 buffer: Entity<language::Buffer>,
93 cursor_position: language::Anchor,
94 _debounce: bool,
95 cx: &mut Context<Self>,
96 ) {
97 let zeta = self.zeta.read(cx);
98
99 if zeta.user_store.read_with(cx, |user_store, _cx| {
100 user_store.account_too_young() || user_store.has_overdue_invoices()
101 }) {
102 return;
103 }
104
105 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
106 && let BufferEditPrediction::Local { prediction } = current
107 && prediction.interpolate(buffer.read(cx)).is_some()
108 {
109 return;
110 }
111
112 self.zeta.update(cx, |zeta, cx| {
113 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
114 zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
115 });
116 }
117
118 fn cycle(
119 &mut self,
120 _buffer: Entity<language::Buffer>,
121 _cursor_position: language::Anchor,
122 _direction: Direction,
123 _cx: &mut Context<Self>,
124 ) {
125 }
126
127 fn accept(&mut self, cx: &mut Context<Self>) {
128 self.zeta.update(cx, |zeta, cx| {
129 zeta.accept_current_prediction(&self.project, cx);
130 });
131 }
132
133 fn discard(&mut self, cx: &mut Context<Self>) {
134 self.zeta.update(cx, |zeta, _cx| {
135 zeta.discard_current_prediction(&self.project);
136 });
137 }
138
139 fn suggest(
140 &mut self,
141 buffer: &Entity<language::Buffer>,
142 cursor_position: language::Anchor,
143 cx: &mut Context<Self>,
144 ) -> Option<edit_prediction::EditPrediction> {
145 let prediction =
146 self.zeta
147 .read(cx)
148 .current_prediction_for_buffer(buffer, &self.project, cx)?;
149
150 let prediction = match prediction {
151 BufferEditPrediction::Local { prediction } => prediction,
152 BufferEditPrediction::Jump { prediction } => {
153 return Some(edit_prediction::EditPrediction::Jump {
154 id: Some(prediction.id.to_string().into()),
155 snapshot: prediction.snapshot.clone(),
156 target: prediction.edits.first().unwrap().0.start,
157 });
158 }
159 };
160
161 let buffer = buffer.read(cx);
162 let snapshot = buffer.snapshot();
163
164 let Some(edits) = prediction.interpolate(&snapshot) else {
165 self.zeta.update(cx, |zeta, _cx| {
166 zeta.discard_current_prediction(&self.project);
167 });
168 return None;
169 };
170
171 let cursor_row = cursor_position.to_point(&snapshot).row;
172 let (closest_edit_ix, (closest_edit_range, _)) =
173 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
174 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
175 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
176 cmp::min(distance_from_start, distance_from_end)
177 })?;
178
179 let mut edit_start_ix = closest_edit_ix;
180 for (range, _) in edits[..edit_start_ix].iter().rev() {
181 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
182 - range.end.to_point(&snapshot).row;
183 if distance_from_closest_edit <= 1 {
184 edit_start_ix -= 1;
185 } else {
186 break;
187 }
188 }
189
190 let mut edit_end_ix = closest_edit_ix + 1;
191 for (range, _) in &edits[edit_end_ix..] {
192 let distance_from_closest_edit =
193 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
194 if distance_from_closest_edit <= 1 {
195 edit_end_ix += 1;
196 } else {
197 break;
198 }
199 }
200
201 Some(edit_prediction::EditPrediction::Local {
202 id: Some(prediction.id.to_string().into()),
203 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
204 edit_preview: Some(prediction.edit_preview.clone()),
205 })
206 }
207}