1use cloud_llm_client::predict_edits_v3::{self, Signature};
2use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
3use edit_prediction_context::{
4 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
5 SyntaxIndexState,
6};
7use gpui::{App, Entity, EntityId, Task, prelude::*};
8use language::{Anchor, ToPoint};
9use language::{BufferSnapshot, Point};
10use std::collections::HashMap;
11use std::{ops::Range, sync::Arc};
12
13pub struct Zeta2EditPredictionProvider {
14 current: Option<CurrentEditPrediction>,
15 pending: Option<Task<()>>,
16}
17
18impl Zeta2EditPredictionProvider {
19 pub fn new() -> Self {
20 Self {
21 current: None,
22 pending: None,
23 }
24 }
25}
26
27#[derive(Clone)]
28struct CurrentEditPrediction {
29 buffer_id: EntityId,
30 prediction: EditPrediction,
31}
32
33impl EditPredictionProvider for Zeta2EditPredictionProvider {
34 fn name() -> &'static str {
35 // TODO [zeta2]
36 "zed-predict2"
37 }
38
39 fn display_name() -> &'static str {
40 "Zed's Edit Predictions 2"
41 }
42
43 fn show_completions_in_menu() -> bool {
44 true
45 }
46
47 fn show_tab_accept_marker() -> bool {
48 true
49 }
50
51 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
52 // TODO [zeta2]
53 DataCollectionState::Unsupported
54 }
55
56 fn toggle_data_collection(&mut self, _cx: &mut App) {
57 // TODO [zeta2]
58 }
59
60 fn usage(&self, _cx: &App) -> Option<client::EditPredictionUsage> {
61 // TODO [zeta2]
62 None
63 }
64
65 fn is_enabled(
66 &self,
67 _buffer: &Entity<language::Buffer>,
68 _cursor_position: language::Anchor,
69 _cx: &App,
70 ) -> bool {
71 true
72 }
73
74 fn is_refreshing(&self) -> bool {
75 self.pending.is_some()
76 }
77
78 fn refresh(
79 &mut self,
80 _project: Option<Entity<project::Project>>,
81 buffer: Entity<language::Buffer>,
82 cursor_position: language::Anchor,
83 _debounce: bool,
84 cx: &mut Context<Self>,
85 ) {
86 // TODO [zeta2] check account
87 // TODO [zeta2] actually request completion / interpolate
88
89 let snapshot = buffer.read(cx).snapshot();
90 let point = cursor_position.to_point(&snapshot);
91 let end_anchor = snapshot.anchor_before(language::Point::new(
92 point.row,
93 snapshot.line_len(point.row),
94 ));
95
96 let edits: Arc<[(Range<Anchor>, String)]> =
97 vec![(cursor_position..end_anchor, "👻".to_string())].into();
98 let edits_preview_task = buffer.read(cx).preview_edits(edits.clone(), cx);
99
100 // TODO [zeta2] throttle
101 // TODO [zeta2] keep 2 requests
102 self.pending = Some(cx.spawn(async move |this, cx| {
103 let edits_preview = edits_preview_task.await;
104
105 this.update(cx, |this, cx| {
106 this.current = Some(CurrentEditPrediction {
107 buffer_id: buffer.entity_id(),
108 prediction: EditPrediction {
109 // TODO! [zeta2] request id?
110 id: None,
111 edits: edits.to_vec(),
112 edit_preview: Some(edits_preview),
113 },
114 });
115 this.pending.take();
116 cx.notify();
117 })
118 .ok();
119 }));
120 cx.notify();
121 }
122
123 fn cycle(
124 &mut self,
125 _buffer: Entity<language::Buffer>,
126 _cursor_position: language::Anchor,
127 _direction: Direction,
128 _cx: &mut Context<Self>,
129 ) {
130 }
131
132 fn accept(&mut self, _cx: &mut Context<Self>) {
133 // TODO [zeta2] report accept
134 self.current.take();
135 self.pending.take();
136 }
137
138 fn discard(&mut self, _cx: &mut Context<Self>) {
139 self.current.take();
140 self.pending.take();
141 }
142
143 fn suggest(
144 &mut self,
145 buffer: &Entity<language::Buffer>,
146 _cursor_position: language::Anchor,
147 _cx: &mut Context<Self>,
148 ) -> Option<EditPrediction> {
149 let current_prediction = self.current.take()?;
150
151 if current_prediction.buffer_id != buffer.entity_id() {
152 return None;
153 }
154
155 // TODO [zeta2] interpolate
156
157 Some(current_prediction.prediction)
158 }
159}
160
161pub fn make_cloud_request_in_background(
162 cursor_point: Point,
163 buffer: BufferSnapshot,
164 events: Vec<predict_edits_v3::Event>,
165 can_collect_data: bool,
166 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
167 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
168 excerpt_options: EditPredictionExcerptOptions,
169 syntax_index: Entity<SyntaxIndex>,
170 cx: &mut App,
171) -> Task<Option<predict_edits_v3::PredictEditsRequest>> {
172 let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
173 cx.background_spawn(async move {
174 let index_state = index_state.lock().await;
175 EditPredictionContext::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
176 .map(|context| {
177 make_cloud_request(
178 context,
179 events,
180 can_collect_data,
181 diagnostic_groups,
182 git_info,
183 &index_state,
184 )
185 })
186 })
187}
188
189pub fn make_cloud_request(
190 context: EditPredictionContext,
191 events: Vec<predict_edits_v3::Event>,
192 can_collect_data: bool,
193 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
194 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
195 index_state: &SyntaxIndexState,
196) -> predict_edits_v3::PredictEditsRequest {
197 let mut signatures = Vec::new();
198 let mut declaration_to_signature_index = HashMap::default();
199 let mut referenced_declarations = Vec::new();
200 for snippet in context.snippets {
201 let parent_index = snippet.declaration.parent().and_then(|parent| {
202 add_signature(
203 parent,
204 &mut declaration_to_signature_index,
205 &mut signatures,
206 index_state,
207 )
208 });
209 let (text, text_is_truncated) = snippet.declaration.item_text();
210 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
211 text: text.into(),
212 text_is_truncated,
213 signature_range: snippet.declaration.signature_range_in_item_text(),
214 parent_index,
215 score_components: snippet.score_components,
216 signature_score: snippet.scores.signature,
217 declaration_score: snippet.scores.declaration,
218 });
219 }
220
221 let excerpt_parent = context
222 .excerpt
223 .parent_declarations
224 .last()
225 .and_then(|(parent, _)| {
226 add_signature(
227 *parent,
228 &mut declaration_to_signature_index,
229 &mut signatures,
230 index_state,
231 )
232 });
233
234 predict_edits_v3::PredictEditsRequest {
235 excerpt: context.excerpt_text.body,
236 referenced_declarations,
237 signatures,
238 excerpt_parent,
239 // todo!
240 events,
241 can_collect_data,
242 diagnostic_groups,
243 git_info,
244 }
245}
246
247fn add_signature(
248 declaration_id: DeclarationId,
249 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
250 signatures: &mut Vec<Signature>,
251 index: &SyntaxIndexState,
252) -> Option<usize> {
253 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
254 return Some(*signature_index);
255 }
256 let Some(parent_declaration) = index.declaration(declaration_id) else {
257 log::error!("bug: missing parent declaration");
258 return None;
259 };
260 let parent_index = parent_declaration.parent().and_then(|parent| {
261 add_signature(parent, declaration_to_signature_index, signatures, index)
262 });
263 let (text, text_is_truncated) = parent_declaration.signature_text();
264 let signature_index = signatures.len();
265 signatures.push(Signature {
266 text: text.into(),
267 text_is_truncated,
268 parent_index,
269 });
270 declaration_to_signature_index.insert(declaration_id, signature_index);
271 Some(signature_index)
272}