1use std::{cmp, sync::Arc, time::Duration};
2
3use client::{Client, UserStore};
4use cloud_llm_client::EditPredictionRejectReason;
5use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
6use gpui::{App, Entity, prelude::*};
7use language::ToPoint as _;
8use project::Project;
9
10use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
11
12pub struct ZetaEditPredictionProvider {
13 zeta: Entity<Zeta>,
14 project: Entity<Project>,
15}
16
17impl ZetaEditPredictionProvider {
18 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
19
20 pub fn new(
21 project: Entity<Project>,
22 client: &Arc<Client>,
23 user_store: &Entity<UserStore>,
24 cx: &mut Context<Self>,
25 ) -> Self {
26 let zeta = Zeta::global(client, user_store, cx);
27 zeta.update(cx, |zeta, cx| {
28 zeta.register_project(&project, cx);
29 });
30
31 cx.observe(&zeta, |_this, _zeta, cx| {
32 cx.notify();
33 })
34 .detach();
35
36 Self {
37 project: project,
38 zeta,
39 }
40 }
41}
42
43impl EditPredictionProvider for ZetaEditPredictionProvider {
44 fn name() -> &'static str {
45 "zed-predict2"
46 }
47
48 fn display_name() -> &'static str {
49 "Zed's Edit Predictions 2"
50 }
51
52 fn show_completions_in_menu() -> bool {
53 true
54 }
55
56 fn show_tab_accept_marker() -> bool {
57 true
58 }
59
60 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
61 // TODO [zeta2]
62 DataCollectionState::Unsupported
63 }
64
65 fn toggle_data_collection(&mut self, _cx: &mut App) {
66 // TODO [zeta2]
67 }
68
69 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
70 self.zeta.read(cx).usage(cx)
71 }
72
73 fn is_enabled(
74 &self,
75 _buffer: &Entity<language::Buffer>,
76 _cursor_position: language::Anchor,
77 cx: &App,
78 ) -> bool {
79 let zeta = self.zeta.read(cx);
80 if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
81 zeta.has_sweep_api_token()
82 } else {
83 true
84 }
85 }
86
87 fn is_refreshing(&self, cx: &App) -> bool {
88 self.zeta.read(cx).is_refreshing(&self.project)
89 }
90
91 fn refresh(
92 &mut self,
93 buffer: Entity<language::Buffer>,
94 cursor_position: language::Anchor,
95 _debounce: bool,
96 cx: &mut Context<Self>,
97 ) {
98 let zeta = self.zeta.read(cx);
99
100 if zeta.user_store.read_with(cx, |user_store, _cx| {
101 user_store.account_too_young() || user_store.has_overdue_invoices()
102 }) {
103 return;
104 }
105
106 if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
107 && let BufferEditPrediction::Local { prediction } = current
108 && prediction.interpolate(buffer.read(cx)).is_some()
109 {
110 return;
111 }
112
113 self.zeta.update(cx, |zeta, cx| {
114 zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
115 zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
116 });
117 }
118
119 fn cycle(
120 &mut self,
121 _buffer: Entity<language::Buffer>,
122 _cursor_position: language::Anchor,
123 _direction: Direction,
124 _cx: &mut Context<Self>,
125 ) {
126 }
127
128 fn accept(&mut self, cx: &mut Context<Self>) {
129 self.zeta.update(cx, |zeta, cx| {
130 zeta.accept_current_prediction(&self.project, cx);
131 });
132 }
133
134 fn discard(&mut self, cx: &mut Context<Self>) {
135 self.zeta.update(cx, |zeta, _cx| {
136 zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
137 });
138 }
139
140 fn did_show(&mut self, cx: &mut Context<Self>) {
141 self.zeta.update(cx, |zeta, cx| {
142 zeta.did_show_current_prediction(&self.project, cx);
143 });
144 }
145
146 fn suggest(
147 &mut self,
148 buffer: &Entity<language::Buffer>,
149 cursor_position: language::Anchor,
150 cx: &mut Context<Self>,
151 ) -> Option<edit_prediction::EditPrediction> {
152 let prediction =
153 self.zeta
154 .read(cx)
155 .current_prediction_for_buffer(buffer, &self.project, cx)?;
156
157 let prediction = match prediction {
158 BufferEditPrediction::Local { prediction } => prediction,
159 BufferEditPrediction::Jump { prediction } => {
160 return Some(edit_prediction::EditPrediction::Jump {
161 id: Some(prediction.id.to_string().into()),
162 snapshot: prediction.snapshot.clone(),
163 target: prediction.edits.first().unwrap().0.start,
164 });
165 }
166 };
167
168 let buffer = buffer.read(cx);
169 let snapshot = buffer.snapshot();
170
171 let Some(edits) = prediction.interpolate(&snapshot) else {
172 self.zeta.update(cx, |zeta, _cx| {
173 zeta.reject_current_prediction(
174 EditPredictionRejectReason::InterpolatedEmpty,
175 &self.project,
176 );
177 });
178 return None;
179 };
180
181 let cursor_row = cursor_position.to_point(&snapshot).row;
182 let (closest_edit_ix, (closest_edit_range, _)) =
183 edits.iter().enumerate().min_by_key(|(_, (range, _))| {
184 let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
185 let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
186 cmp::min(distance_from_start, distance_from_end)
187 })?;
188
189 let mut edit_start_ix = closest_edit_ix;
190 for (range, _) in edits[..edit_start_ix].iter().rev() {
191 let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
192 - range.end.to_point(&snapshot).row;
193 if distance_from_closest_edit <= 1 {
194 edit_start_ix -= 1;
195 } else {
196 break;
197 }
198 }
199
200 let mut edit_end_ix = closest_edit_ix + 1;
201 for (range, _) in &edits[edit_end_ix..] {
202 let distance_from_closest_edit =
203 range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
204 if distance_from_closest_edit <= 1 {
205 edit_end_ix += 1;
206 } else {
207 break;
208 }
209 }
210
211 Some(edit_prediction::EditPrediction::Local {
212 id: Some(prediction.id.to_string().into()),
213 edits: edits[edit_start_ix..edit_end_ix].to_vec(),
214 edit_preview: Some(prediction.edit_preview.clone()),
215 })
216 }
217}