1use std::{cmp, sync::Arc, time::Duration};
2
3use client::{Client, UserStore};
4use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
5use gpui::{App, Entity, prelude::*};
6use language::ToPoint as _;
7use project::Project;
8
9use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
10
11pub struct ZetaEditPredictionProvider {
12 zeta: Entity<Zeta>,
13 project: Entity<Project>,
14}
15
16impl ZetaEditPredictionProvider {
17 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
18
19 pub fn new(
20 project: Entity<Project>,
21 client: &Arc<Client>,
22 user_store: &Entity<UserStore>,
23 cx: &mut Context<Self>,
24 ) -> Self {
25 let zeta = Zeta::global(client, user_store, cx);
26 zeta.update(cx, |zeta, cx| {
27 zeta.register_project(&project, cx);
28 });
29
30 cx.observe(&zeta, |_this, _zeta, cx| {
31 cx.notify();
32 })
33 .detach();
34
35 Self {
36 project: project,
37 zeta,
38 }
39 }
40}
41
42impl EditPredictionProvider for ZetaEditPredictionProvider {
43 fn name() -> &'static str {
44 "zed-predict2"
45 }
46
47 fn display_name() -> &'static str {
48 "Zed's Edit Predictions 2"
49 }
50
51 fn show_completions_in_menu() -> bool {
52 true
53 }
54
55 fn show_tab_accept_marker() -> bool {
56 true
57 }
58
59 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
60 // TODO [zeta2]
61 DataCollectionState::Unsupported
62 }
63
64 fn toggle_data_collection(&mut self, _cx: &mut App) {
65 // TODO [zeta2]
66 }
67
68 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
69 self.zeta.read(cx).usage(cx)
70 }
71
72 fn is_enabled(
73 &self,
74 _buffer: &Entity<language::Buffer>,
75 _cursor_position: language::Anchor,
76 cx: &App,
77 ) -> bool {
78 let zeta = self.zeta.read(cx);
79 if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
80 zeta.sweep_ai.api_token.is_some()
81 } else {
82 true
83 }
84 }
85
86 fn is_refreshing(&self, cx: &App) -> bool {
87 self.zeta.read(cx).is_refreshing(&self.project)
88 }
89
90 fn refresh(
91 &mut self,
92 buffer: Entity<language::Buffer>,
93 cursor_position: language::Anchor,
94 _debounce: bool,
95 cx: &mut Context<Self>,
96 ) {
97 let zeta = self.zeta.read(cx);
98
99 if zeta.user_store.read_with(cx, |user_store, _cx| {
100 user_store.account_too_young() || user_store.has_overdue_invoices()
101 }) {
102 return;
103 }
104
105 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
106 && let BufferEditPrediction::Local { prediction } = current
107 && prediction.interpolate(buffer.read(cx)).is_some()
108 {
109 return;
110 }
111
112 self.zeta.update(cx, |zeta, cx| {
113 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
114 zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
115 });
116 }
117
118 fn cycle(
119 &mut self,
120 _buffer: Entity<language::Buffer>,
121 _cursor_position: language::Anchor,
122 _direction: Direction,
123 _cx: &mut Context<Self>,
124 ) {
125 }
126
127 fn accept(&mut self, cx: &mut Context<Self>) {
128 self.zeta.update(cx, |zeta, cx| {
129 zeta.accept_current_prediction(&self.project, cx);
130 });
131 }
132
133 fn discard(&mut self, cx: &mut Context<Self>) {
134 self.zeta.update(cx, |zeta, cx| {
135 zeta.discard_current_prediction(&self.project, cx);
136 });
137 }
138
139 fn did_show(&mut self, cx: &mut Context<Self>) {
140 self.zeta.update(cx, |zeta, cx| {
141 zeta.did_show_current_prediction(&self.project, cx);
142 });
143 }
144
145 fn suggest(
146 &mut self,
147 buffer: &Entity<language::Buffer>,
148 cursor_position: language::Anchor,
149 cx: &mut Context<Self>,
150 ) -> Option<edit_prediction::EditPrediction> {
151 let prediction =
152 self.zeta
153 .read(cx)
154 .current_prediction_for_buffer(buffer, &self.project, cx)?;
155
156 let prediction = match prediction {
157 BufferEditPrediction::Local { prediction } => prediction,
158 BufferEditPrediction::Jump { prediction } => {
159 return Some(edit_prediction::EditPrediction::Jump {
160 id: Some(prediction.id.to_string().into()),
161 snapshot: prediction.snapshot.clone(),
162 target: prediction.edits.first().unwrap().0.start,
163 });
164 }
165 };
166
167 let buffer = buffer.read(cx);
168 let snapshot = buffer.snapshot();
169
170 let Some(edits) = prediction.interpolate(&snapshot) else {
171 self.zeta.update(cx, |zeta, cx| {
172 zeta.discard_current_prediction(&self.project, cx);
173 });
174 return None;
175 };
176
177 let cursor_row = cursor_position.to_point(&snapshot).row;
178 let (closest_edit_ix, (closest_edit_range, _)) =
179 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
180 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
181 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
182 cmp::min(distance_from_start, distance_from_end)
183 })?;
184
185 let mut edit_start_ix = closest_edit_ix;
186 for (range, _) in edits[..edit_start_ix].iter().rev() {
187 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
188 - range.end.to_point(&snapshot).row;
189 if distance_from_closest_edit <= 1 {
190 edit_start_ix -= 1;
191 } else {
192 break;
193 }
194 }
195
196 let mut edit_end_ix = closest_edit_ix + 1;
197 for (range, _) in &edits[edit_end_ix..] {
198 let distance_from_closest_edit =
199 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
200 if distance_from_closest_edit <= 1 {
201 edit_end_ix += 1;
202 } else {
203 break;
204 }
205 }
206
207 Some(edit_prediction::EditPrediction::Local {
208 id: Some(prediction.id.to_string().into()),
209 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
210 edit_preview: Some(prediction.edit_preview.clone()),
211 })
212 }
213}