1use std::{cmp, sync::Arc, time::Duration};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
6use gpui::{App, Entity, prelude::*};
7use language::ToPoint as _;
8use project::Project;
9
10use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
11
12pub struct ZetaEditPredictionProvider {
13 zeta: Entity<Zeta>,
14 project: Entity<Project>,
15}
16
17impl ZetaEditPredictionProvider {
18 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
19
20 pub fn new(
21 project: Entity<Project>,
22 client: &Arc<Client>,
23 user_store: &Entity<UserStore>,
24 cx: &mut Context<Self>,
25 ) -> Self {
26 let zeta = Zeta::global(client, user_store, cx);
27 zeta.update(cx, |zeta, cx| {
28 zeta.register_project(&project, cx);
29 });
30
31 cx.observe(&zeta, |_this, _zeta, cx| {
32 cx.notify();
33 })
34 .detach();
35
36 Self {
37 project: project,
38 zeta,
39 }
40 }
41}
42
43impl EditPredictionProvider for ZetaEditPredictionProvider {
44 fn name() -> &'static str {
45 "zed-predict2"
46 }
47
48 fn display_name() -> &'static str {
49 "Zed's Edit Predictions 2"
50 }
51
52 fn show_completions_in_menu() -> bool {
53 true
54 }
55
56 fn show_tab_accept_marker() -> bool {
57 true
58 }
59
60 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
61 // TODO [zeta2]
62 DataCollectionState::Unsupported
63 }
64
65 fn toggle_data_collection(&mut self, _cx: &mut App) {
66 // TODO [zeta2]
67 }
68
69 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
70 self.zeta.read(cx).usage(cx)
71 }
72
73 fn is_enabled(
74 &self,
75 _buffer: &Entity<language::Buffer>,
76 _cursor_position: language::Anchor,
77 cx: &App,
78 ) -> bool {
79 let zeta = self.zeta.read(cx);
80 if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
81 zeta.sweep_ai.api_token.is_some()
82 } else {
83 true
84 }
85 }
86
87 fn is_refreshing(&self, cx: &App) -> bool {
88 self.zeta.read(cx).is_refreshing(&self.project)
89 }
90
91 fn refresh(
92 &mut self,
93 buffer: Entity<language::Buffer>,
94 cursor_position: language::Anchor,
95 _debounce: bool,
96 cx: &mut Context<Self>,
97 ) {
98 let zeta = self.zeta.read(cx);
99
100 if zeta.user_store.read_with(cx, |user_store, _cx| {
101 user_store.account_too_young() || user_store.has_overdue_invoices()
102 }) {
103 return;
104 }
105
106 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
107 && let BufferEditPrediction::Local { prediction } = current
108 && prediction.interpolate(buffer.read(cx)).is_some()
109 {
110 return;
111 }
112
113 self.zeta.update(cx, |zeta, cx| {
114 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
115 zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
116 });
117 }
118
119 fn cycle(
120 &mut self,
121 _buffer: Entity<language::Buffer>,
122 _cursor_position: language::Anchor,
123 _direction: Direction,
124 _cx: &mut Context<Self>,
125 ) {
126 }
127
128 fn accept(&mut self, cx: &mut Context<Self>) {
129 self.zeta.update(cx, |zeta, cx| {
130 zeta.accept_current_prediction(&self.project, cx);
131 });
132 }
133
134 fn discard(&mut self, cx: &mut Context<Self>) {
135 self.zeta.update(cx, |zeta, cx| {
136 zeta.reject_current_prediction(
137 EditPredictionRejectReason::Discarded,
138 &self.project,
139 cx,
140 );
141 });
142 }
143
144 fn did_show(&mut self, cx: &mut Context<Self>) {
145 self.zeta.update(cx, |zeta, cx| {
146 zeta.did_show_current_prediction(&self.project, cx);
147 });
148 }
149
150 fn suggest(
151 &mut self,
152 buffer: &Entity<language::Buffer>,
153 cursor_position: language::Anchor,
154 cx: &mut Context<Self>,
155 ) -> Option<edit_prediction::EditPrediction> {
156 let prediction =
157 self.zeta
158 .read(cx)
159 .current_prediction_for_buffer(buffer, &self.project, cx)?;
160
161 let prediction = match prediction {
162 BufferEditPrediction::Local { prediction } => prediction,
163 BufferEditPrediction::Jump { prediction } => {
164 return Some(edit_prediction::EditPrediction::Jump {
165 id: Some(prediction.id.to_string().into()),
166 snapshot: prediction.snapshot.clone(),
167 target: prediction.edits.first().unwrap().0.start,
168 });
169 }
170 };
171
172 let buffer = buffer.read(cx);
173 let snapshot = buffer.snapshot();
174
175 let Some(edits) = prediction.interpolate(&snapshot) else {
176 self.zeta.update(cx, |zeta, cx| {
177 zeta.reject_current_prediction(
178 EditPredictionRejectReason::InterpolatedEmpty,
179 &self.project,
180 cx,
181 );
182 });
183 return None;
184 };
185
186 let cursor_row = cursor_position.to_point(&snapshot).row;
187 let (closest_edit_ix, (closest_edit_range, _)) =
188 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
189 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
190 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
191 cmp::min(distance_from_start, distance_from_end)
192 })?;
193
194 let mut edit_start_ix = closest_edit_ix;
195 for (range, _) in edits[..edit_start_ix].iter().rev() {
196 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
197 - range.end.to_point(&snapshot).row;
198 if distance_from_closest_edit <= 1 {
199 edit_start_ix -= 1;
200 } else {
201 break;
202 }
203 }
204
205 let mut edit_end_ix = closest_edit_ix + 1;
206 for (range, _) in &edits[edit_end_ix..] {
207 let distance_from_closest_edit =
208 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
209 if distance_from_closest_edit <= 1 {
210 edit_end_ix += 1;
211 } else {
212 break;
213 }
214 }
215
216 Some(edit_prediction::EditPrediction::Local {
217 id: Some(prediction.id.to_string().into()),
218 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
219 edit_preview: Some(prediction.edit_preview.clone()),
220 })
221 }
222}