zeta2.rs

  1use cloud_llm_client::predict_edits_v3::{self, Signature};
  2use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
  3use edit_prediction_context::{
  4    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  5    SyntaxIndexState,
  6};
  7use gpui::{App, Entity, EntityId, Task, prelude::*};
  8use language::{Anchor, ToPoint};
  9use language::{BufferSnapshot, Point};
 10use std::collections::HashMap;
 11use std::{ops::Range, sync::Arc};
 12
 13pub struct Zeta2EditPredictionProvider {
 14    current: Option<CurrentEditPrediction>,
 15    pending: Option<Task<()>>,
 16}
 17
 18impl Zeta2EditPredictionProvider {
 19    pub fn new() -> Self {
 20        Self {
 21            current: None,
 22            pending: None,
 23        }
 24    }
 25}
 26
 27#[derive(Clone)]
 28struct CurrentEditPrediction {
 29    buffer_id: EntityId,
 30    prediction: EditPrediction,
 31}
 32
 33impl EditPredictionProvider for Zeta2EditPredictionProvider {
 34    fn name() -> &'static str {
 35        // TODO [zeta2]
 36        "zed-predict2"
 37    }
 38
 39    fn display_name() -> &'static str {
 40        "Zed's Edit Predictions 2"
 41    }
 42
 43    fn show_completions_in_menu() -> bool {
 44        true
 45    }
 46
 47    fn show_tab_accept_marker() -> bool {
 48        true
 49    }
 50
 51    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 52        // TODO [zeta2]
 53        DataCollectionState::Unsupported
 54    }
 55
 56    fn toggle_data_collection(&mut self, _cx: &mut App) {
 57        // TODO [zeta2]
 58    }
 59
 60    fn usage(&self, _cx: &App) -> Option<client::EditPredictionUsage> {
 61        // TODO [zeta2]
 62        None
 63    }
 64
 65    fn is_enabled(
 66        &self,
 67        _buffer: &Entity<language::Buffer>,
 68        _cursor_position: language::Anchor,
 69        _cx: &App,
 70    ) -> bool {
 71        true
 72    }
 73
 74    fn is_refreshing(&self) -> bool {
 75        self.pending.is_some()
 76    }
 77
 78    fn refresh(
 79        &mut self,
 80        _project: Option<Entity<project::Project>>,
 81        buffer: Entity<language::Buffer>,
 82        cursor_position: language::Anchor,
 83        _debounce: bool,
 84        cx: &mut Context<Self>,
 85    ) {
 86        // TODO [zeta2] check account
 87        // TODO [zeta2] actually request completion / interpolate
 88
 89        let snapshot = buffer.read(cx).snapshot();
 90        let point = cursor_position.to_point(&snapshot);
 91        let end_anchor = snapshot.anchor_before(language::Point::new(
 92            point.row,
 93            snapshot.line_len(point.row),
 94        ));
 95
 96        let edits: Arc<[(Range<Anchor>, String)]> =
 97            vec![(cursor_position..end_anchor, "👻".to_string())].into();
 98        let edits_preview_task = buffer.read(cx).preview_edits(edits.clone(), cx);
 99
100        // TODO [zeta2] throttle
101        // TODO [zeta2] keep 2 requests
102        self.pending = Some(cx.spawn(async move |this, cx| {
103            let edits_preview = edits_preview_task.await;
104
105            this.update(cx, |this, cx| {
106                this.current = Some(CurrentEditPrediction {
107                    buffer_id: buffer.entity_id(),
108                    prediction: EditPrediction {
109                        // TODO! [zeta2] request id?
110                        id: None,
111                        edits: edits.to_vec(),
112                        edit_preview: Some(edits_preview),
113                    },
114                });
115                this.pending.take();
116                cx.notify();
117            })
118            .ok();
119        }));
120        cx.notify();
121    }
122
123    fn cycle(
124        &mut self,
125        _buffer: Entity<language::Buffer>,
126        _cursor_position: language::Anchor,
127        _direction: Direction,
128        _cx: &mut Context<Self>,
129    ) {
130    }
131
132    fn accept(&mut self, _cx: &mut Context<Self>) {
133        // TODO [zeta2] report accept
134        self.current.take();
135        self.pending.take();
136    }
137
138    fn discard(&mut self, _cx: &mut Context<Self>) {
139        self.current.take();
140        self.pending.take();
141    }
142
143    fn suggest(
144        &mut self,
145        buffer: &Entity<language::Buffer>,
146        _cursor_position: language::Anchor,
147        _cx: &mut Context<Self>,
148    ) -> Option<EditPrediction> {
149        let current_prediction = self.current.take()?;
150
151        if current_prediction.buffer_id != buffer.entity_id() {
152            return None;
153        }
154
155        // TODO [zeta2] interpolate
156
157        Some(current_prediction.prediction)
158    }
159}
160
161pub fn make_cloud_request_in_background(
162    cursor_point: Point,
163    buffer: BufferSnapshot,
164    events: Vec<predict_edits_v3::Event>,
165    can_collect_data: bool,
166    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
167    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
168    excerpt_options: EditPredictionExcerptOptions,
169    syntax_index: Entity<SyntaxIndex>,
170    cx: &mut App,
171) -> Task<Option<predict_edits_v3::PredictEditsRequest>> {
172    let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
173    cx.background_spawn(async move {
174        let index_state = index_state.lock().await;
175        EditPredictionContext::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
176            .map(|context| {
177                make_cloud_request(
178                    context,
179                    events,
180                    can_collect_data,
181                    diagnostic_groups,
182                    git_info,
183                    &index_state,
184                )
185            })
186    })
187}
188
189pub fn make_cloud_request(
190    context: EditPredictionContext,
191    events: Vec<predict_edits_v3::Event>,
192    can_collect_data: bool,
193    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
194    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
195    index_state: &SyntaxIndexState,
196) -> predict_edits_v3::PredictEditsRequest {
197    let mut signatures = Vec::new();
198    let mut declaration_to_signature_index = HashMap::default();
199    let mut referenced_declarations = Vec::new();
200    for snippet in context.snippets {
201        let parent_index = snippet.declaration.parent().and_then(|parent| {
202            add_signature(
203                parent,
204                &mut declaration_to_signature_index,
205                &mut signatures,
206                index_state,
207            )
208        });
209        let (text, text_is_truncated) = snippet.declaration.item_text();
210        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
211            text: text.into(),
212            text_is_truncated,
213            signature_range: snippet.declaration.signature_range_in_item_text(),
214            parent_index,
215            score_components: snippet.score_components,
216            signature_score: snippet.scores.signature,
217            declaration_score: snippet.scores.declaration,
218        });
219    }
220
221    let excerpt_parent = context
222        .excerpt
223        .parent_declarations
224        .last()
225        .and_then(|(parent, _)| {
226            add_signature(
227                *parent,
228                &mut declaration_to_signature_index,
229                &mut signatures,
230                index_state,
231            )
232        });
233
234    predict_edits_v3::PredictEditsRequest {
235        excerpt: context.excerpt_text.body,
236        referenced_declarations,
237        signatures,
238        excerpt_parent,
239        // todo!
240        events,
241        can_collect_data,
242        diagnostic_groups,
243        git_info,
244    }
245}
246
247fn add_signature(
248    declaration_id: DeclarationId,
249    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
250    signatures: &mut Vec<Signature>,
251    index: &SyntaxIndexState,
252) -> Option<usize> {
253    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
254        return Some(*signature_index);
255    }
256    let Some(parent_declaration) = index.declaration(declaration_id) else {
257        log::error!("bug: missing parent declaration");
258        return None;
259    };
260    let parent_index = parent_declaration.parent().and_then(|parent| {
261        add_signature(parent, declaration_to_signature_index, signatures, index)
262    });
263    let (text, text_is_truncated) = parent_declaration.signature_text();
264    let signature_index = signatures.len();
265    signatures.push(Signature {
266        text: text.into(),
267        text_is_truncated,
268        parent_index,
269    });
270    declaration_to_signature_index.insert(declaration_id, signature_index);
271    Some(signature_index)
272}