zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use arrayvec::ArrayVec;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use gpui::http_client::Method;
 15use gpui::{
 16    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
 17    prelude::*,
 18};
 19use language::BufferSnapshot;
 20use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
 21use language_model::{LlmApiToken, RefreshLlmTokenListener};
 22use project::Project;
 23use release_channel::AppVersion;
 24use std::collections::HashMap;
 25use std::path::PathBuf;
 26use std::str::FromStr as _;
 27use std::time::{Duration, Instant};
 28use std::{ops::Range, sync::Arc};
 29use thiserror::Error;
 30use util::ResultExt as _;
 31use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 32
/// Wrapper used to store the single shared [`Zeta`] entity as a gpui global.
///
/// Set by [`Zeta::global`] the first time it is called; later calls return the stored entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 37
/// Zeta2 edit-prediction engine: gathers context for a cursor position and asks the
/// cloud `predict_edits/v3` endpoint for edits.
pub struct Zeta {
    /// Client used to acquire/refresh the LLM token and to send HTTP requests.
    client: Arc<Client>,
    /// Read for edit-prediction usage and updated with the usage reported by responses.
    user_store: Entity<UserStore>,
    /// Bearer token sent with prediction requests; refreshed when the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Held only to keep the token-refresh subscription alive; never read.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id (see `register_project`).
    projects: HashMap<EntityId, RegisteredProject>,
    /// Excerpt sizing options passed to context gathering.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set to `true` once the server reports that a newer Zed version is required.
    update_required: bool,
}

/// State kept for each project registered with [`Zeta`].
struct RegisteredProject {
    /// Syntax index created for the project; its state is locked while gathering context.
    syntax_index: Entity<SyntaxIndex>,
}
 51
 52impl Zeta {
 53    pub fn global(
 54        client: &Arc<Client>,
 55        user_store: &Entity<UserStore>,
 56        cx: &mut App,
 57    ) -> Entity<Self> {
 58        cx.try_global::<ZetaGlobal>()
 59            .map(|global| global.0.clone())
 60            .unwrap_or_else(|| {
 61                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 62                cx.set_global(ZetaGlobal(zeta.clone()));
 63                zeta
 64            })
 65    }
 66
 67    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 68        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 69
 70        Self {
 71            projects: HashMap::new(),
 72            client,
 73            user_store,
 74            excerpt_options: EditPredictionExcerptOptions {
 75                max_bytes: 512,
 76                min_bytes: 128,
 77                target_before_cursor_over_total_bytes: 0.5,
 78            },
 79            llm_token: LlmApiToken::default(),
 80            _llm_token_subscription: cx.subscribe(
 81                &refresh_llm_token_listener,
 82                |this, _listener, _event, cx| {
 83                    let client = this.client.clone();
 84                    let llm_token = this.llm_token.clone();
 85                    cx.spawn(async move |_this, _cx| {
 86                        llm_token.refresh(&client).await?;
 87                        anyhow::Ok(())
 88                    })
 89                    .detach_and_log_err(cx);
 90                },
 91            ),
 92            update_required: false,
 93        }
 94    }
 95
 96    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 97        self.user_store.read(cx).edit_prediction_usage()
 98    }
 99
100    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
101        self.projects
102            .entry(project.entity_id())
103            .or_insert_with(|| RegisteredProject {
104                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
105            });
106    }
107
108    pub fn request_prediction(
109        &mut self,
110        project: &Entity<Project>,
111        buffer: &Entity<Buffer>,
112        position: language::Anchor,
113        cx: &mut Context<Self>,
114    ) -> Task<Result<Option<EditPrediction>>> {
115        let project_state = self.projects.get(&project.entity_id());
116
117        let index_state = project_state.map(|state| {
118            state
119                .syntax_index
120                .read_with(cx, |index, _cx| index.state().clone())
121        });
122        let excerpt_options = self.excerpt_options.clone();
123        let snapshot = buffer.read(cx).snapshot();
124        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
125            return Task::ready(Err(anyhow!("No file path for excerpt")));
126        };
127        let client = self.client.clone();
128        let llm_token = self.llm_token.clone();
129        let app_version = AppVersion::global(cx);
130        let worktree_snapshots = project
131            .read(cx)
132            .worktrees(cx)
133            .map(|worktree| worktree.read(cx).snapshot())
134            .collect::<Vec<_>>();
135
136        let request_task = cx.background_spawn({
137            let snapshot = snapshot.clone();
138            async move {
139                let index_state = if let Some(index_state) = index_state {
140                    Some(index_state.lock_owned().await)
141                } else {
142                    None
143                };
144
145                let cursor_point = position.to_point(&snapshot);
146
147                // TODO: make this only true if debug view is open
148                let debug_info = true;
149
150                let Some(request) = EditPredictionContext::gather_context(
151                    cursor_point,
152                    &snapshot,
153                    &excerpt_options,
154                    index_state.as_deref(),
155                )
156                .map(|context| {
157                    make_cloud_request(
158                        excerpt_path.clone(),
159                        context,
160                        // TODO pass everything
161                        Vec::new(),
162                        false,
163                        Vec::new(),
164                        None,
165                        debug_info,
166                        &worktree_snapshots,
167                        index_state.as_deref(),
168                    )
169                }) else {
170                    return Ok(None);
171                };
172
173                anyhow::Ok(Some(
174                    Self::perform_request(client, llm_token, app_version, request).await?,
175                ))
176            }
177        });
178
179        let buffer = buffer.clone();
180
181        cx.spawn(async move |this, cx| {
182            match request_task.await {
183                Ok(Some((response, usage))) => {
184                    log::debug!("predicted edits: {:?}", &response.edits);
185
186                    if let Some(usage) = usage {
187                        this.update(cx, |this, cx| {
188                            this.user_store.update(cx, |user_store, cx| {
189                                user_store.update_edit_prediction_usage(usage, cx);
190                            });
191                        })
192                        .ok();
193                    }
194
195                    // TODO telemetry: duration, etc
196
197                    let edits = response
198                        .edits
199                        .into_iter()
200                        .map(|edit| {
201                            // TODO edits to different files
202                            (
203                                snapshot.anchor_before(edit.range.start)
204                                    ..snapshot.anchor_before(edit.range.end),
205                                edit.content,
206                            )
207                        })
208                        .collect::<Vec<_>>()
209                        .into();
210
211                    let Some((edits, edit_preview_task)) = buffer.read_with(cx, |buffer, cx| {
212                        let new_snapshot = buffer.snapshot();
213                        let edits: Arc<[_]> = interpolate(&snapshot, &new_snapshot, edits)?.into();
214                        Some((edits.clone().to_vec(), buffer.preview_edits(edits, cx)))
215                    })?
216                    else {
217                        return Ok(None);
218                    };
219
220                    Ok(Some(EditPrediction {
221                        // todo!
222                        id: None,
223                        edits,
224                        edit_preview: Some(edit_preview_task.await),
225                    }))
226                }
227                Ok(None) => Ok(None),
228                Err(err) => {
229                    if err.is::<ZedUpdateRequiredError>() {
230                        cx.update(|cx| {
231                            this.update(cx, |this, _cx| {
232                                this.update_required = true;
233                            })
234                            .ok();
235
236                            let error_message: SharedString = err.to_string().into();
237                            show_app_notification(
238                                NotificationId::unique::<ZedUpdateRequiredError>(),
239                                cx,
240                                move |cx| {
241                                    cx.new(|cx| {
242                                        ErrorMessagePrompt::new(error_message.clone(), cx)
243                                            .with_link_button(
244                                                "Update Zed",
245                                                "https://zed.dev/releases",
246                                            )
247                                    })
248                                },
249                            );
250                        })
251                        .ok();
252                    }
253
254                    Err(err)
255                }
256            }
257        })
258    }
259
260    async fn perform_request(
261        client: Arc<Client>,
262        llm_token: LlmApiToken,
263        app_version: SemanticVersion,
264        request: predict_edits_v3::PredictEditsRequest,
265    ) -> Result<(
266        predict_edits_v3::PredictEditsResponse,
267        Option<EditPredictionUsage>,
268    )> {
269        let http_client = client.http_client();
270        let mut token = llm_token.acquire(&client).await?;
271        let mut did_retry = false;
272
273        loop {
274            let request_builder = http_client::Request::builder().method(Method::POST);
275            let request_builder =
276                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
277                    request_builder.uri(predict_edits_url)
278                } else {
279                    request_builder.uri(
280                        http_client
281                            .build_zed_llm_url("/predict_edits/v3", &[])?
282                            .as_ref(),
283                    )
284                };
285            let request = request_builder
286                .header("Content-Type", "application/json")
287                .header("Authorization", format!("Bearer {}", token))
288                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
289                .body(serde_json::to_string(&request)?.into())?;
290
291            let mut response = http_client.send(request).await?;
292
293            if let Some(minimum_required_version) = response
294                .headers()
295                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
296                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
297            {
298                anyhow::ensure!(
299                    app_version >= minimum_required_version,
300                    ZedUpdateRequiredError {
301                        minimum_version: minimum_required_version
302                    }
303                );
304            }
305
306            if response.status().is_success() {
307                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
308
309                let mut body = Vec::new();
310                response.body_mut().read_to_end(&mut body).await?;
311                return Ok((serde_json::from_slice(&body)?, usage));
312            } else if !did_retry
313                && response
314                    .headers()
315                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
316                    .is_some()
317            {
318                did_retry = true;
319                token = llm_token.refresh(&client).await?;
320            } else {
321                let mut body = String::new();
322                response.body_mut().read_to_string(&mut body).await?;
323                anyhow::bail!(
324                    "error predicting edits.\nStatus: {:?}\nBody: {}",
325                    response.status(),
326                    body
327                );
328            }
329        }
330    }
331}
332
/// Error returned by a prediction request when the server's minimum-required-version
/// header is newer than the running Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
340
/// [`EditPredictionProvider`] implementation backed by the global [`Zeta`] engine.
pub struct ZetaEditPredictionProvider {
    /// Shared engine that performs the requests.
    zeta: Entity<Zeta>,
    /// Most recent prediction, handed out (and cleared) by `suggest`.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending prediction; incremented on every `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight prediction tasks, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
348
349impl ZetaEditPredictionProvider {
350    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
351
352    pub fn new(
353        project: Option<&Entity<Project>>,
354        client: &Arc<Client>,
355        user_store: &Entity<UserStore>,
356        cx: &mut App,
357    ) -> Self {
358        let zeta = Zeta::global(client, user_store, cx);
359        if let Some(project) = project {
360            zeta.update(cx, |zeta, cx| {
361                zeta.register_project(project, cx);
362            });
363        }
364
365        Self {
366            zeta,
367            current_prediction: None,
368            next_pending_prediction_id: 0,
369            pending_predictions: ArrayVec::new(),
370            last_request_timestamp: Instant::now(),
371        }
372    }
373}
374
/// A completed prediction together with the buffer it was requested for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction applies to.
    buffer_id: EntityId,
    /// The prediction returned by [`Zeta::request_prediction`].
    prediction: EditPrediction,
}
380
381impl CurrentEditPrediction {
382    fn should_replace_prediction(
383        &self,
384        _old_completion: &Self,
385        _snapshot: &BufferSnapshot,
386    ) -> bool {
387        true
388        // TODO
389        // if self.buffer_id != old_completion.buffer_id {
390        //     return true;
391        // }
392
393        // let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
394        //     return true;
395        // };
396        // let Some(new_edits) = self.completion.interpolate(snapshot) else {
397        //     return false;
398        // };
399
400        // if old_edits.len() == 1 && new_edits.len() == 1 {
401        //     let (old_range, old_text) = &old_edits[0];
402        //     let (new_range, new_text) = &new_edits[0];
403        //     new_range == old_range && new_text.starts_with(old_text)
404        // } else {
405        //     true
406        // }
407    }
408}
409
/// An in-flight prediction request started by `refresh`.
struct PendingPrediction {
    /// Id used to tell whether this is still the oldest pending request when it completes.
    id: usize,
    /// Handle keeping the request task alive; dropping it (e.g. when evicted from
    /// `pending_predictions`) presumably cancels the task, per gpui `Task` semantics — confirm.
    _task: Task<()>,
}
414
415impl EditPredictionProvider for ZetaEditPredictionProvider {
416    fn name() -> &'static str {
417        // TODO [zeta2]
418        "zed-predict2"
419    }
420
421    fn display_name() -> &'static str {
422        "Zed's Edit Predictions 2"
423    }
424
425    fn show_completions_in_menu() -> bool {
426        true
427    }
428
429    fn show_tab_accept_marker() -> bool {
430        true
431    }
432
433    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
434        // TODO [zeta2]
435        DataCollectionState::Unsupported
436    }
437
438    fn toggle_data_collection(&mut self, _cx: &mut App) {
439        // TODO [zeta2]
440    }
441
442    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
443        self.zeta.read(cx).usage(cx)
444    }
445
446    fn is_enabled(
447        &self,
448        _buffer: &Entity<language::Buffer>,
449        _cursor_position: language::Anchor,
450        _cx: &App,
451    ) -> bool {
452        true
453    }
454
455    fn is_refreshing(&self) -> bool {
456        !self.pending_predictions.is_empty()
457    }
458
459    fn refresh(
460        &mut self,
461        project: Option<Entity<project::Project>>,
462        buffer: Entity<language::Buffer>,
463        cursor_position: language::Anchor,
464        _debounce: bool,
465        cx: &mut Context<Self>,
466    ) {
467        let Some(project) = project else {
468            return;
469        };
470
471        // TODO [zeta2] check account
472        // if self
473        //     .zeta
474        //     .read(cx)
475        //     .user_store
476        //     .read_with(cx, |user_store, _cx| {
477        //         user_store.account_too_young() || user_store.has_overdue_invoices()
478        //     })
479        // {
480        //     return;
481        // }
482
483        // TODO [zeta2] try to interpolate current request
484
485        let pending_prediction_id = self.next_pending_prediction_id;
486        self.next_pending_prediction_id += 1;
487        let last_request_timestamp = self.last_request_timestamp;
488
489        let task = cx.spawn(async move |this, cx| {
490            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
491                .checked_duration_since(Instant::now())
492            {
493                cx.background_executor().timer(timeout).await;
494            }
495
496            let prediction_request = this.update(cx, |this, cx| {
497                this.last_request_timestamp = Instant::now();
498                this.zeta.update(cx, |zeta, cx| {
499                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
500                })
501            });
502
503            let prediction = match prediction_request {
504                Ok(prediction_request) => {
505                    let prediction_request = prediction_request.await;
506                    prediction_request.map(|c| {
507                        c.map(|prediction| CurrentEditPrediction {
508                            buffer_id: buffer.entity_id(),
509                            prediction,
510                        })
511                    })
512                }
513                Err(error) => Err(error),
514            };
515
516            this.update(cx, |this, cx| {
517                if this.pending_predictions[0].id == pending_prediction_id {
518                    this.pending_predictions.remove(0);
519                } else {
520                    this.pending_predictions.clear();
521                }
522
523                let Some(new_prediction) = prediction
524                    .context("edit prediction failed")
525                    .log_err()
526                    .flatten()
527                else {
528                    cx.notify();
529                    return;
530                };
531
532                if let Some(old_prediction) = this.current_prediction.as_ref() {
533                    let snapshot = buffer.read(cx).snapshot();
534                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
535                        this.current_prediction = Some(new_prediction);
536                    }
537                } else {
538                    this.current_prediction = Some(new_prediction);
539                }
540
541                cx.notify();
542            })
543            .ok();
544        });
545
546        // We always maintain at most two pending predictions. When we already
547        // have two, we replace the newest one.
548        if self.pending_predictions.len() <= 1 {
549            self.pending_predictions.push(PendingPrediction {
550                id: pending_prediction_id,
551                _task: task,
552            });
553        } else if self.pending_predictions.len() == 2 {
554            self.pending_predictions.pop();
555            self.pending_predictions.push(PendingPrediction {
556                id: pending_prediction_id,
557                _task: task,
558            });
559        }
560
561        cx.notify();
562    }
563
564    fn cycle(
565        &mut self,
566        _buffer: Entity<language::Buffer>,
567        _cursor_position: language::Anchor,
568        _direction: Direction,
569        _cx: &mut Context<Self>,
570    ) {
571    }
572
573    fn accept(&mut self, _cx: &mut Context<Self>) {
574        // TODO [zeta2] report accept
575        self.current_prediction.take();
576        self.pending_predictions.clear();
577    }
578
579    fn discard(&mut self, _cx: &mut Context<Self>) {
580        self.pending_predictions.clear();
581        self.current_prediction.take();
582    }
583
584    fn suggest(
585        &mut self,
586        buffer: &Entity<language::Buffer>,
587        _cursor_position: language::Anchor,
588        _cx: &mut Context<Self>,
589    ) -> Option<EditPrediction> {
590        let current_prediction = self.current_prediction.take()?;
591
592        if current_prediction.buffer_id != buffer.entity_id() {
593            return None;
594        }
595
596        // TODO [zeta2] interpolate
597
598        Some(current_prediction.prediction)
599    }
600}
601
602fn make_cloud_request(
603    excerpt_path: PathBuf,
604    context: EditPredictionContext,
605    events: Vec<predict_edits_v3::Event>,
606    can_collect_data: bool,
607    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
608    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
609    debug_info: bool,
610    worktrees: &Vec<worktree::Snapshot>,
611    index_state: Option<&SyntaxIndexState>,
612) -> predict_edits_v3::PredictEditsRequest {
613    let mut signatures = Vec::new();
614    let mut declaration_to_signature_index = HashMap::default();
615    let mut referenced_declarations = Vec::new();
616
617    for snippet in context.snippets {
618        let project_entry_id = snippet.declaration.project_entry_id();
619        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
620        // Note that currently full_path is currently being used for excerpt_path.
621        let Some(path) = worktrees.iter().find_map(|worktree| {
622            let abs_path = worktree.abs_path();
623            worktree
624                .entry_for_id(project_entry_id)
625                .map(|e| abs_path.join(&e.path))
626        }) else {
627            continue;
628        };
629
630        let parent_index = index_state.and_then(|index_state| {
631            snippet.declaration.parent().and_then(|parent| {
632                add_signature(
633                    parent,
634                    &mut declaration_to_signature_index,
635                    &mut signatures,
636                    index_state,
637                )
638            })
639        });
640
641        let (text, text_is_truncated) = snippet.declaration.item_text();
642        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
643            path,
644            text: text.into(),
645            range: snippet.declaration.item_range(),
646            text_is_truncated,
647            signature_range: snippet.declaration.signature_range_in_item_text(),
648            parent_index,
649            score_components: snippet.score_components,
650            signature_score: snippet.scores.signature,
651            declaration_score: snippet.scores.declaration,
652        });
653    }
654
655    let excerpt_parent = index_state.and_then(|index_state| {
656        context
657            .excerpt
658            .parent_declarations
659            .last()
660            .and_then(|(parent, _)| {
661                add_signature(
662                    *parent,
663                    &mut declaration_to_signature_index,
664                    &mut signatures,
665                    index_state,
666                )
667            })
668    });
669
670    predict_edits_v3::PredictEditsRequest {
671        excerpt_path,
672        excerpt: context.excerpt_text.body,
673        excerpt_range: context.excerpt.range,
674        cursor_offset: context.cursor_offset_in_excerpt,
675        referenced_declarations,
676        signatures,
677        excerpt_parent,
678        // todo!
679        events,
680        can_collect_data,
681        diagnostic_groups,
682        git_info,
683        debug_info,
684    }
685}
686
687fn add_signature(
688    declaration_id: DeclarationId,
689    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
690    signatures: &mut Vec<Signature>,
691    index: &SyntaxIndexState,
692) -> Option<usize> {
693    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
694        return Some(*signature_index);
695    }
696    let Some(parent_declaration) = index.declaration(declaration_id) else {
697        log::error!("bug: missing parent declaration");
698        return None;
699    };
700    let parent_index = parent_declaration.parent().and_then(|parent| {
701        add_signature(parent, declaration_to_signature_index, signatures, index)
702    });
703    let (text, text_is_truncated) = parent_declaration.signature_text();
704    let signature_index = signatures.len();
705    signatures.push(Signature {
706        text: text.into(),
707        text_is_truncated,
708        parent_index,
709    });
710    declaration_to_signature_index.insert(declaration_id, signature_index);
711    Some(signature_index)
712}
713
/// Re-bases the model's predicted edits, expressed against `old_snapshot`, onto
/// `new_snapshot`, accounting for edits the user made in between.
///
/// The user's edits since `old_snapshot.version` are walked alongside the model edits.
/// The walk appears to assume `current_edits` is sorted by position — TODO confirm that
/// callers guarantee this.
/// - Model edits whose old range ends strictly before the next user edit starts are kept
///   unchanged.
/// - A user edit whose old range equals the next model edit's old range, and whose inserted
///   text is a prefix of the model's text, consumes that model edit. Any non-empty remaining
///   suffix of the model text is kept as an insertion at the end of the user edit's old range.
/// - Any other user edit invalidates the prediction, and `None` is returned.
///
/// Model edits left over after the last user edit are kept unchanged.
///
/// Returns `None` if the prediction was invalidated or if no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    // User edits are reported as offset ranges: `old` in `old_snapshot`, `new` in `new_snapshot`.
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep every model edit that lies entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is only compatible if it replaces exactly the model's range with a
        // prefix of the model's text (i.e. the user typed part of the suggestion).
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Suggest only the part the user has not typed yet.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit conflicts with (or lies outside) the prediction: discard it.
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}