edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::http_client::Url;
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::language_settings::all_language_settings;
  29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
  30use language::{BufferSnapshot, OffsetRangeExt};
  31use language_model::{LlmApiToken, RefreshLlmTokenListener};
  32use project::{Project, ProjectPath, WorktreeId};
  33use release_channel::AppVersion;
  34use semver::Version;
  35use serde::de::DeserializeOwned;
  36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  37use std::collections::{VecDeque, hash_map};
  38use workspace::Workspace;
  39
  40use std::ops::Range;
  41use std::path::Path;
  42use std::rc::Rc;
  43use std::str::FromStr as _;
  44use std::sync::{Arc, LazyLock};
  45use std::time::{Duration, Instant};
  46use std::{env, mem};
  47use thiserror::Error;
  48use util::{RangeExt as _, ResultExt as _};
  49use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  50
  51pub mod cursor_excerpt;
  52pub mod example_spec;
  53mod license_detection;
  54pub mod mercury;
  55mod onboarding_modal;
  56pub mod open_ai_response;
  57mod prediction;
  58pub mod sweep_ai;
  59
  60#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
  61pub mod udiff;
  62
  63mod zed_edit_prediction_delegate;
  64pub mod zeta1;
  65pub mod zeta2;
  66
  67#[cfg(test)]
  68mod edit_prediction_tests;
  69
  70use crate::license_detection::LicenseDetectionWatcher;
  71use crate::mercury::Mercury;
  72use crate::onboarding_modal::ZedPredictModal;
  73pub use crate::prediction::EditPrediction;
  74pub use crate::prediction::EditPredictionId;
  75use crate::prediction::EditPredictionResult;
  76pub use crate::sweep_ai::SweepAi;
  77pub use language_model::ApiKeyState;
  78pub use telemetry_events::EditPredictionRating;
  79pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  80
// Action types for the `edit_prediction` namespace, declared via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  90
  91/// Maximum number of events to track.
  92const EVENT_COUNT_MAX: usize = 6;
  93const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  94const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
  95const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  96const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  97
/// Feature flag named `sweep-ai`; presumably gates the Sweep AI prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag named `mercury`; presumably gates the Mercury prediction backend.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 109
 110pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 111    context: EditPredictionExcerptOptions {
 112        max_bytes: 512,
 113        min_bytes: 128,
 114        target_before_cursor_over_total_bytes: 0.5,
 115    },
 116    prompt_format: PromptFormat::DEFAULT,
 117};
 118
 119static USE_OLLAMA: LazyLock<bool> =
 120    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 121
 122static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 123    match env::var("ZED_ZETA2_MODEL").as_deref() {
 124        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 125        Ok(model) => model,
 126        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 127        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 128    }
 129    .to_string()
 130});
 131
 132pub struct Zeta2FeatureFlag;
 133
 134impl FeatureFlag for Zeta2FeatureFlag {
 135    const NAME: &'static str = "zeta2";
 136
 137    fn enabled_for_staff() -> bool {
 138        true
 139    }
 140}
 141
/// GPUI global holding the shared [`EditPredictionStore`] entity.
/// `EditPredictionStore::global` sets it and `EditPredictionStore::try_global` reads it.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 146
/// Store that drives edit predictions across projects.
///
/// `EditPredictionStore::global` lazily creates a shared instance and registers it as a
/// GPUI global. Per-project state lives in `projects`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits (wired up in `new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ProjectState>,
    /// Toggled via `set_use_context`; starts as `false`.
    use_context: bool,
    /// Zeta prompt options; starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    update_required: bool,
    /// Evaluation cache, installed via `with_eval_cache` (cli-support builds only).
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Backend that serves predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender for the channel drained by the background rejection-reporting task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override: `ZED_PREDICT_EDITS_URL` if set, otherwise a local Ollama URL
    /// when `ZED_ZETA2_OLLAMA` is enabled.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 167
/// Backend used to produce edit predictions.
///
/// Note: `Default` yields `Zeta1`, but `EditPredictionStore::new` explicitly starts with
/// `Zeta2`. Change it via `EditPredictionStore::set_edit_prediction_model`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    /// The only model for which `EditPredictionStore::usage` reports usage.
    Zeta2,
    /// Sweep AI backend (see [`SweepAi`]).
    Sweep,
    /// Mercury backend (see [`Mercury`]).
    Mercury,
}
 176
/// Inputs for a single prediction request.
/// NOTE(review): the code that builds this struct is outside this excerpt; field notes
/// below are inferred from names and sibling types — confirm against the constructor.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for, and presumably its snapshot at request time.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Presumably the cursor position within `buffer`.
    position: Anchor,
    /// Edit history, in the shape produced by `ProjectState::events`.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files gathered as context.
    related_files: Arc<[RelatedFile]>,
    /// Recently activated project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Range searched for diagnostics — presumably around `position`; confirm.
    diagnostic_search_range: Range<Point>,
    /// Optional debug listener for this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 189
/// Tunables for Zeta prompt construction; [`DEFAULT_OPTIONS`] holds the defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt sizing: min/max bytes and the target share placed before the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format used for requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 195
/// Debug events streamed to the receiver returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Always empty when sent from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when a project's related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display (cache hits, max/mean LSP latency).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Presumably sent when a prediction request starts (the sender is outside this excerpt).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Presumably sent when a prediction request completes (the sender is outside this excerpt).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Alias for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 233
/// Per-project prediction state, created lazily by `EditPredictionStore::get_or_init_project`.
struct ProjectState {
    /// Completed edit events; `events()` appends the in-progress `last_event` after these.
    /// Presumably capped at `EVENT_COUNT_MAX` by the code that pushes here (not in view).
    events: VecDeque<Arc<zeta_prompt::Event>>,
    /// In-progress edit group; `events()` finalizes it on read without removing it.
    last_event: Option<LastEvent>,
    /// Recently activated paths, most recent first, without duplicates
    /// (maintained by `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Presumably the id assigned to the next `PendingPrediction`.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two (the `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug listener installed by `debug_info`, if any.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Buffer id and time of the last prediction refresh — presumably for throttling; confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store supplying context files (`context_for_project`).
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; an entry is removed when its worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Subscription routing project events to `handle_project_event`.
    _subscription: gpui::Subscription,
}
 249
 250impl ProjectState {
 251    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 252        self.events
 253            .iter()
 254            .cloned()
 255            .chain(
 256                self.last_event
 257                    .as_ref()
 258                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 259            )
 260            .collect()
 261    }
 262
 263    pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 264        self.events
 265            .iter()
 266            .cloned()
 267            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 268                let (one, two) = event.split_by_pause();
 269                let one = one.finalize(&self.license_detection_watchers, cx);
 270                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 271                one.into_iter().chain(two)
 272            }))
 273            .collect()
 274    }
 275
 276    fn cancel_pending_prediction(
 277        &mut self,
 278        pending_prediction: PendingPrediction,
 279        cx: &mut Context<EditPredictionStore>,
 280    ) {
 281        self.cancelled_predictions.insert(pending_prediction.id);
 282
 283        cx.spawn(async move |this, cx| {
 284            let Some(prediction_id) = pending_prediction.task.await else {
 285                return;
 286            };
 287
 288            this.update(cx, |this, _cx| {
 289                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 290            })
 291            .ok();
 292        })
 293        .detach()
 294    }
 295
 296    fn active_buffer(
 297        &self,
 298        project: &Entity<Project>,
 299        cx: &App,
 300    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 301        let project = project.read(cx);
 302        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 303        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 304        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 305        Some((active_buffer, registered_buffer.last_position))
 306    }
 307}
 308
/// The prediction currently offered for a project (`ProjectState::current_prediction`).
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request; consulted by `should_replace_prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed — presumably set by the UI path; confirm.
    pub was_shown: bool,
}
 315
 316impl CurrentEditPrediction {
 317    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 318        let Some(new_edits) = self
 319            .prediction
 320            .interpolate(&self.prediction.buffer.read(cx))
 321        else {
 322            return false;
 323        };
 324
 325        if self.prediction.buffer != old_prediction.prediction.buffer {
 326            return true;
 327        }
 328
 329        let Some(old_edits) = old_prediction
 330            .prediction
 331            .interpolate(&old_prediction.prediction.buffer.read(cx))
 332        else {
 333            return true;
 334        };
 335
 336        let requested_by_buffer_id = self.requested_by.buffer_id();
 337
 338        // This reduces the occurrence of UI thrash from replacing edits
 339        //
 340        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 341        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 342            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 343            && old_edits.len() == 1
 344            && new_edits.len() == 1
 345        {
 346            let (old_range, old_text) = &old_edits[0];
 347            let (new_range, new_text) = &new_edits[0];
 348            new_range == old_range && new_text.starts_with(old_text.as_ref())
 349        } else {
 350            true
 351        }
 352    }
 353}
 354
 355#[derive(Debug, Clone)]
 356enum PredictionRequestedBy {
 357    DiagnosticsUpdate,
 358    Buffer(EntityId),
 359}
 360
 361impl PredictionRequestedBy {
 362    pub fn buffer_id(&self) -> Option<EntityId> {
 363        match self {
 364            PredictionRequestedBy::DiagnosticsUpdate => None,
 365            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 366        }
 367    }
 368}
 369
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the produced prediction's id. `cancel_pending_prediction` only reports
    /// a rejection when this is `Some` (presumably `None` means nothing was produced).
    task: Task<Option<EditPredictionId>>,
}
 375
 376/// A prediction from the perspective of a buffer.
 377#[derive(Debug)]
 378enum BufferEditPrediction<'a> {
 379    Local { prediction: &'a EditPrediction },
 380    Jump { prediction: &'a EditPrediction },
 381}
 382
 383#[cfg(test)]
 384impl std::ops::Deref for BufferEditPrediction<'_> {
 385    type Target = EditPrediction;
 386
 387    fn deref(&self) -> &Self::Target {
 388        match self {
 389            BufferEditPrediction::Local { prediction } => prediction,
 390            BufferEditPrediction::Jump { prediction } => prediction,
 391        }
 392    }
 393}
 394
/// Per-buffer bookkeeping for a buffer registered in a project's `ProjectState`.
struct RegisteredBuffer {
    /// The buffer's file as of registration (presumably refreshed later — confirm).
    file: Option<Arc<dyn File>>,
    /// Text snapshot captured at registration (presumably updated as edits are reported).
    snapshot: TextBufferSnapshot,
    /// Last recorded position in this buffer, if any; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Subscriptions kept alive with this entry; the first listens to the buffer's events.
    _subscriptions: [gpui::Subscription; 2],
}
 401
/// An in-progress edit event spanning `old_snapshot` to `new_snapshot`.
/// Presumably not yet moved into `ProjectState::events`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer text at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer text at the end of the event.
    new_snapshot: TextBufferSnapshot,
    /// Files at the start and end of the event. `finalize` emits an event when their
    /// paths differ, even if the text is unchanged.
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    end_edit_anchor: Option<Anchor>,
    /// Snapshot at the last editing pause, if any; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 412
 413impl LastEvent {
 414    pub fn finalize(
 415        &self,
 416        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 417        cx: &App,
 418    ) -> Option<Arc<zeta_prompt::Event>> {
 419        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 420        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 421
 422        let in_open_source_repo =
 423            [self.new_file.as_ref(), self.old_file.as_ref()]
 424                .iter()
 425                .all(|file| {
 426                    file.is_some_and(|file| {
 427                        license_detection_watchers
 428                            .get(&file.worktree_id(cx))
 429                            .is_some_and(|watcher| watcher.is_project_open_source())
 430                    })
 431                });
 432
 433        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 434
 435        if path == old_path && diff.is_empty() {
 436            None
 437        } else {
 438            Some(Arc::new(zeta_prompt::Event::BufferChange {
 439                old_path,
 440                path,
 441                diff,
 442                in_open_source_repo,
 443                // TODO: Actually detect if this edit was predicted or not
 444                predicted: false,
 445            }))
 446        }
 447    }
 448
 449    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 450        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 451            return (self.clone(), None);
 452        };
 453
 454        let before = LastEvent {
 455            old_snapshot: self.old_snapshot.clone(),
 456            new_snapshot: boundary_snapshot.clone(),
 457            old_file: self.old_file.clone(),
 458            new_file: self.new_file.clone(),
 459            end_edit_anchor: self.end_edit_anchor,
 460            snapshot_after_last_editing_pause: None,
 461            last_edit_time: self.last_edit_time,
 462        };
 463
 464        let after = LastEvent {
 465            old_snapshot: boundary_snapshot.clone(),
 466            new_snapshot: self.new_snapshot.clone(),
 467            old_file: self.old_file.clone(),
 468            new_file: self.new_file.clone(),
 469            end_edit_anchor: self.end_edit_anchor,
 470            snapshot_after_last_editing_pause: None,
 471            last_edit_time: self.last_edit_time,
 472        };
 473
 474        (before, Some(after))
 475    }
 476}
 477
 478fn buffer_path_with_id_fallback(
 479    file: Option<&Arc<dyn File>>,
 480    snapshot: &TextBufferSnapshot,
 481    cx: &App,
 482) -> Arc<Path> {
 483    if let Some(file) = file {
 484        file.full_path(cx).into()
 485    } else {
 486        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 487    }
 488}
 489
 490impl EditPredictionStore {
 491    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 492        cx.try_global::<EditPredictionStoreGlobal>()
 493            .map(|global| global.0.clone())
 494    }
 495
 496    pub fn global(
 497        client: &Arc<Client>,
 498        user_store: &Entity<UserStore>,
 499        cx: &mut App,
 500    ) -> Entity<Self> {
 501        cx.try_global::<EditPredictionStoreGlobal>()
 502            .map(|global| global.0.clone())
 503            .unwrap_or_else(|| {
 504                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 505                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 506                ep_store
 507            })
 508    }
 509
    /// Creates a store.
    ///
    /// Besides building the initial state, this:
    /// - spawns a detached background task that drains rejected predictions sent through
    ///   `reject_predictions_tx` (`handle_rejected_predictions`);
    /// - subscribes to the global `RefreshLlmTokenListener` to refresh `llm_token`;
    /// - resolves `custom_predict_edits_url` from `ZED_PREDICT_EDITS_URL`, falling back to
    ///   a local Ollama endpoint when `ZED_ZETA2_OLLAMA` is enabled;
    /// - configures context retrieval now, once feature flags are ready, and on every
    ///   `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are queued here and reported by a long-lived background task.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the shared LLM token whenever the global listener fires; failures are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // An explicit URL wins. If it fails to parse, the error is logged and no override
            // is used (the Ollama fallback applies only when the variable is unset).
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        // Apply the context-retrieval configuration now, then again whenever feature flags
        // become ready or settings change.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 595
 596    #[cfg(test)]
 597    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 598        self.custom_predict_edits_url = Some(url.into());
 599    }
 600
 601    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 602        self.edit_prediction_model = model;
 603    }
 604
 605    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 606        self.sweep_ai.api_token.read(cx).has_key()
 607    }
 608
 609    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 610        self.mercury.api_token.read(cx).has_key()
 611    }
 612
 613    #[cfg(feature = "cli-support")]
 614    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 615        self.eval_cache = Some(cache);
 616    }
 617
 618    pub fn options(&self) -> &ZetaOptions {
 619        &self.options
 620    }
 621
 622    pub fn set_options(&mut self, options: ZetaOptions) {
 623        self.options = options;
 624    }
 625
 626    pub fn set_use_context(&mut self, use_context: bool) {
 627        self.use_context = use_context;
 628    }
 629
 630    pub fn clear_history(&mut self) {
 631        for project_state in self.projects.values_mut() {
 632            project_state.events.clear();
 633        }
 634    }
 635
 636    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 637        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 638            project_state.events.clear();
 639        }
 640    }
 641
 642    pub fn edit_history_for_project(
 643        &self,
 644        project: &Entity<Project>,
 645        cx: &App,
 646    ) -> Vec<Arc<zeta_prompt::Event>> {
 647        self.projects
 648            .get(&project.entity_id())
 649            .map(|project_state| project_state.events(cx))
 650            .unwrap_or_default()
 651    }
 652
 653    pub fn edit_history_for_project_with_pause_split_last_event(
 654        &self,
 655        project: &Entity<Project>,
 656        cx: &App,
 657    ) -> Vec<Arc<zeta_prompt::Event>> {
 658        self.projects
 659            .get(&project.entity_id())
 660            .map(|project_state| project_state.events_split_by_pause(cx))
 661            .unwrap_or_default()
 662    }
 663
 664    pub fn context_for_project<'a>(
 665        &'a self,
 666        project: &Entity<Project>,
 667        cx: &'a App,
 668    ) -> Arc<[RelatedFile]> {
 669        self.projects
 670            .get(&project.entity_id())
 671            .map(|project| project.context.read(cx).related_files())
 672            .unwrap_or_else(|| vec![].into())
 673    }
 674
 675    pub fn context_for_project_with_buffers<'a>(
 676        &'a self,
 677        project: &Entity<Project>,
 678        cx: &'a App,
 679    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 680        self.projects
 681            .get(&project.entity_id())
 682            .map(|project| project.context.read(cx).related_files_with_buffers())
 683    }
 684
 685    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 686        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 687            self.user_store.read(cx).edit_prediction_usage()
 688        } else {
 689            None
 690        }
 691    }
 692
 693    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 694        self.get_or_init_project(project, cx);
 695    }
 696
 697    pub fn register_buffer(
 698        &mut self,
 699        buffer: &Entity<Buffer>,
 700        project: &Entity<Project>,
 701        cx: &mut Context<Self>,
 702    ) {
 703        let project_state = self.get_or_init_project(project, cx);
 704        Self::register_buffer_impl(project_state, buffer, project, cx);
 705    }
 706
 707    fn get_or_init_project(
 708        &mut self,
 709        project: &Entity<Project>,
 710        cx: &mut Context<Self>,
 711    ) -> &mut ProjectState {
 712        let entity_id = project.entity_id();
 713        self.projects
 714            .entry(entity_id)
 715            .or_insert_with(|| ProjectState {
 716                context: {
 717                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 718                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 719                        this.handle_excerpt_store_event(entity_id, event);
 720                    })
 721                    .detach();
 722                    related_excerpt_store
 723                },
 724                events: VecDeque::new(),
 725                last_event: None,
 726                recent_paths: VecDeque::new(),
 727                debug_tx: None,
 728                registered_buffers: HashMap::default(),
 729                current_prediction: None,
 730                cancelled_predictions: HashSet::default(),
 731                pending_predictions: ArrayVec::new(),
 732                next_pending_prediction_id: 0,
 733                last_prediction_refresh: None,
 734                license_detection_watchers: HashMap::default(),
 735                _subscription: cx.subscribe(&project, Self::handle_project_event),
 736            })
 737    }
 738
 739    pub fn remove_project(&mut self, project: &Entity<Project>) {
 740        self.projects.remove(&project.entity_id());
 741    }
 742
 743    fn handle_excerpt_store_event(
 744        &mut self,
 745        project_entity_id: EntityId,
 746        event: &RelatedExcerptStoreEvent,
 747    ) {
 748        if let Some(project_state) = self.projects.get(&project_entity_id) {
 749            if let Some(debug_tx) = project_state.debug_tx.clone() {
 750                match event {
 751                    RelatedExcerptStoreEvent::StartedRefresh => {
 752                        debug_tx
 753                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 754                                ContextRetrievalStartedDebugEvent {
 755                                    project_entity_id: project_entity_id,
 756                                    timestamp: Instant::now(),
 757                                    search_prompt: String::new(),
 758                                },
 759                            ))
 760                            .ok();
 761                    }
 762                    RelatedExcerptStoreEvent::FinishedRefresh {
 763                        cache_hit_count,
 764                        cache_miss_count,
 765                        mean_definition_latency,
 766                        max_definition_latency,
 767                    } => {
 768                        debug_tx
 769                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 770                                ContextRetrievalFinishedDebugEvent {
 771                                    project_entity_id: project_entity_id,
 772                                    timestamp: Instant::now(),
 773                                    metadata: vec![
 774                                        (
 775                                            "Cache Hits",
 776                                            format!(
 777                                                "{}/{}",
 778                                                cache_hit_count,
 779                                                cache_hit_count + cache_miss_count
 780                                            )
 781                                            .into(),
 782                                        ),
 783                                        (
 784                                            "Max LSP Time",
 785                                            format!("{} ms", max_definition_latency.as_millis())
 786                                                .into(),
 787                                        ),
 788                                        (
 789                                            "Mean LSP Time",
 790                                            format!("{} ms", mean_definition_latency.as_millis())
 791                                                .into(),
 792                                        ),
 793                                    ],
 794                                },
 795                            ))
 796                            .ok();
 797                    }
 798                }
 799            }
 800        }
 801    }
 802
 803    pub fn debug_info(
 804        &mut self,
 805        project: &Entity<Project>,
 806        cx: &mut Context<Self>,
 807    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 808        let project_state = self.get_or_init_project(project, cx);
 809        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 810        project_state.debug_tx = Some(debug_watch_tx);
 811        debug_watch_rx
 812    }
 813
 814    fn handle_project_event(
 815        &mut self,
 816        project: Entity<Project>,
 817        event: &project::Event,
 818        cx: &mut Context<Self>,
 819    ) {
 820        // TODO [zeta2] init with recent paths
 821        match event {
 822            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 823                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 824                    return;
 825                };
 826                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 827                if let Some(path) = path {
 828                    if let Some(ix) = project_state
 829                        .recent_paths
 830                        .iter()
 831                        .position(|probe| probe == &path)
 832                    {
 833                        project_state.recent_paths.remove(ix);
 834                    }
 835                    project_state.recent_paths.push_front(path);
 836                }
 837            }
 838            project::Event::DiagnosticsUpdated { .. } => {
 839                if cx.has_flag::<Zeta2FeatureFlag>() {
 840                    self.refresh_prediction_from_diagnostics(project, cx);
 841                }
 842            }
 843            _ => (),
 844        }
 845    }
 846
    /// Ensures `buffer` is tracked in `project_state` and returns its registration.
    ///
    /// If the buffer has a file whose worktree can be found in `project`, a license detection
    /// watcher is created for that worktree (at most one per worktree id). The watcher is removed
    /// from the project's state again when the worktree entity is released.
    ///
    /// On first registration a `RegisteredBuffer` is inserted holding the buffer's current text
    /// snapshot and file, plus two subscriptions:
    /// - one that calls `report_changes_for_buffer` on every `BufferEvent::Edited`, as long as
    ///   the project entity can still be upgraded;
    /// - one that removes the registration when the buffer entity is released.
    ///
    /// Later calls return the existing entry unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher once its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Record edits; hold the project weakly so this subscription does not
                        // keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 912
    /// Records an edit to `buffer` in the project's event history.
    ///
    /// Registers the buffer if needed and returns early when its text version has not changed
    /// since the last recorded snapshot. Otherwise the stored file and snapshot are replaced.
    ///
    /// If the edit directly continues the project's last event, the edit is coalesced into that
    /// event. "Directly continues" means: same buffer, the old snapshot's version equals the
    /// event's new snapshot, and the last edited positions are at most
    /// `CHANGE_GROUPING_LINE_SPAN` rows apart.
    ///
    /// Otherwise, the oldest history entry is evicted when the history is near
    /// `EVENT_COUNT_MAX`, the previous last event is finalized into the history, and a new
    /// `LastEvent` is started for this edit.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we already recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit made between the old and new snapshots, if any.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit picks up exactly where the last event left off in the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Both edit end positions are resolved against the newest snapshot before
            // comparing their row distance.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit, remember
                // the snapshot from just before this edit as the point where editing paused.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Make room before the previous last event is finalized into the history.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // `finalize`'s output is appended via `extend`, so it may contribute nothing.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
 988
 989    fn prediction_at(
 990        &mut self,
 991        buffer: &Entity<Buffer>,
 992        position: Option<language::Anchor>,
 993        project: &Entity<Project>,
 994        cx: &App,
 995    ) -> Option<BufferEditPrediction<'_>> {
 996        let project_state = self.projects.get_mut(&project.entity_id())?;
 997        if let Some(position) = position
 998            && let Some(buffer) = project_state
 999                .registered_buffers
1000                .get_mut(&buffer.entity_id())
1001        {
1002            buffer.last_position = Some(position);
1003        }
1004
1005        let CurrentEditPrediction {
1006            requested_by,
1007            prediction,
1008            ..
1009        } = project_state.current_prediction.as_ref()?;
1010
1011        if prediction.targets_buffer(buffer.read(cx)) {
1012            Some(BufferEditPrediction::Local { prediction })
1013        } else {
1014            let show_jump = match requested_by {
1015                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1016                    requested_by_buffer_id == &buffer.entity_id()
1017                }
1018                PredictionRequestedBy::DiagnosticsUpdate => true,
1019            };
1020
1021            if show_jump {
1022                Some(BufferEditPrediction::Jump { prediction })
1023            } else {
1024                None
1025            }
1026        }
1027    }
1028
1029    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1030        let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1031        match self.edit_prediction_model {
1032            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1033                if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1034                    return;
1035                }
1036            }
1037            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1038        }
1039
1040        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1041            return;
1042        };
1043
1044        let Some(prediction) = project_state.current_prediction.take() else {
1045            return;
1046        };
1047        let request_id = prediction.prediction.id.to_string();
1048        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1049            project_state.cancel_pending_prediction(pending_prediction, cx);
1050        }
1051
1052        let client = self.client.clone();
1053        let llm_token = self.llm_token.clone();
1054        let app_version = AppVersion::global(cx);
1055        cx.spawn(async move |this, cx| {
1056            let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1057                (http_client::Url::parse(&accept_edits_url)?, false)
1058            } else {
1059                (
1060                    client
1061                        .http_client()
1062                        .build_zed_llm_url("/predict_edits/accept", &[])?,
1063                    true,
1064                )
1065            };
1066
1067            let response = cx
1068                .background_spawn(Self::send_api_request::<()>(
1069                    move |builder| {
1070                        let req = builder.uri(url.as_ref()).body(
1071                            serde_json::to_string(&AcceptEditPredictionBody {
1072                                request_id: request_id.clone(),
1073                            })?
1074                            .into(),
1075                        );
1076                        Ok(req?)
1077                    },
1078                    client,
1079                    llm_token,
1080                    app_version,
1081                    require_auth,
1082                ))
1083                .await;
1084
1085            Self::handle_api_response(&this, response, cx)?;
1086            anyhow::Ok(())
1087        })
1088        .detach_and_log_err(cx);
1089    }
1090
1091    async fn handle_rejected_predictions(
1092        rx: UnboundedReceiver<EditPredictionRejection>,
1093        client: Arc<Client>,
1094        llm_token: LlmApiToken,
1095        app_version: Version,
1096        background_executor: BackgroundExecutor,
1097    ) {
1098        let mut rx = std::pin::pin!(rx.peekable());
1099        let mut batched = Vec::new();
1100
1101        while let Some(rejection) = rx.next().await {
1102            batched.push(rejection);
1103
1104            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1105                select_biased! {
1106                    next = rx.as_mut().peek().fuse() => {
1107                        if next.is_some() {
1108                            continue;
1109                        }
1110                    }
1111                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1112                }
1113            }
1114
1115            let url = client
1116                .http_client()
1117                .build_zed_llm_url("/predict_edits/reject", &[])
1118                .unwrap();
1119
1120            let flush_count = batched
1121                .len()
1122                // in case items have accumulated after failure
1123                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1124            let start = batched.len() - flush_count;
1125
1126            let body = RejectEditPredictionsBodyRef {
1127                rejections: &batched[start..],
1128            };
1129
1130            let result = Self::send_api_request::<()>(
1131                |builder| {
1132                    let req = builder
1133                        .uri(url.as_ref())
1134                        .body(serde_json::to_string(&body)?.into());
1135                    anyhow::Ok(req?)
1136                },
1137                client.clone(),
1138                llm_token.clone(),
1139                app_version.clone(),
1140                true,
1141            )
1142            .await;
1143
1144            if result.log_err().is_some() {
1145                batched.drain(start..);
1146            }
1147        }
1148    }
1149
1150    fn reject_current_prediction(
1151        &mut self,
1152        reason: EditPredictionRejectReason,
1153        project: &Entity<Project>,
1154    ) {
1155        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1156            project_state.pending_predictions.clear();
1157            if let Some(prediction) = project_state.current_prediction.take() {
1158                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1159            }
1160        };
1161    }
1162
1163    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1164        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1165            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1166                if !current_prediction.was_shown {
1167                    current_prediction.was_shown = true;
1168                    self.shown_predictions
1169                        .push_front(current_prediction.prediction.clone());
1170                    if self.shown_predictions.len() > 50 {
1171                        let completion = self.shown_predictions.pop_back().unwrap();
1172                        self.rated_predictions.remove(&completion.id);
1173                    }
1174                }
1175            }
1176        }
1177    }
1178
1179    fn reject_prediction(
1180        &mut self,
1181        prediction_id: EditPredictionId,
1182        reason: EditPredictionRejectReason,
1183        was_shown: bool,
1184    ) {
1185        match self.edit_prediction_model {
1186            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1187                if self.custom_predict_edits_url.is_some() {
1188                    return;
1189                }
1190            }
1191            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1192        }
1193
1194        self.reject_predictions_tx
1195            .unbounded_send(EditPredictionRejection {
1196                request_id: prediction_id.to_string(),
1197                reason,
1198                was_shown,
1199            })
1200            .log_err();
1201    }
1202
1203    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1204        self.projects
1205            .get(&project.entity_id())
1206            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1207    }
1208
1209    pub fn refresh_prediction_from_buffer(
1210        &mut self,
1211        project: Entity<Project>,
1212        buffer: Entity<Buffer>,
1213        position: language::Anchor,
1214        cx: &mut Context<Self>,
1215    ) {
1216        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1217            let Some(request_task) = this
1218                .update(cx, |this, cx| {
1219                    this.request_prediction(
1220                        &project,
1221                        &buffer,
1222                        position,
1223                        PredictEditsRequestTrigger::Other,
1224                        cx,
1225                    )
1226                })
1227                .log_err()
1228            else {
1229                return Task::ready(anyhow::Ok(None));
1230            };
1231
1232            cx.spawn(async move |_cx| {
1233                request_task.await.map(|prediction_result| {
1234                    prediction_result.map(|prediction_result| {
1235                        (
1236                            prediction_result,
1237                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1238                        )
1239                    })
1240                })
1241            })
1242        })
1243    }
1244
    /// Queues a prediction refresh triggered by a diagnostics update in `project`.
    ///
    /// Does nothing if the project is not tracked yet or already has a current prediction. The
    /// refresh is throttled per project (the project's entity id is the throttle key).
    ///
    /// When the refresh runs, it:
    /// 1. reads the active buffer and cursor position, bailing out if predictions are disabled
    ///    there;
    /// 2. asks `next_diagnostic_location` for a location to jump to;
    /// 3. requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the result is turned
    /// into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to the default point.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // NOTE(review): a default (empty) search range is passed, presumably so that no
                // diagnostic near the cursor is excluded. Confirm against
                // `next_diagnostic_location`.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-driven prediction may have landed meanwhile; it takes precedence.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1331
1332    fn predictions_enabled_at(
1333        snapshot: &BufferSnapshot,
1334        position: Option<language::Anchor>,
1335        cx: &App,
1336    ) -> bool {
1337        let file = snapshot.file();
1338        let all_settings = all_language_settings(file, cx);
1339        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1340            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1341        {
1342            return false;
1343        }
1344
1345        if let Some(last_position) = position {
1346            let settings = snapshot.settings_at(last_position, cx);
1347
1348            if !settings.edit_predictions_disabled_in.is_empty()
1349                && let Some(scope) = snapshot.language_scope_at(last_position)
1350                && let Some(scope_name) = scope.override_name()
1351                && settings
1352                    .edit_predictions_disabled_in
1353                    .iter()
1354                    .any(|s| s == scope_name)
1355            {
1356                return false;
1357            }
1358        }
1359
1360        true
1361    }
1362
    /// Minimum delay between two consecutive prediction refreshes keyed by the same throttle
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Throttling is disabled in tests.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1367
    /// Queues a prediction refresh for `project` whose result is produced by `do_refresh`.
    ///
    /// Each call gets a fresh pending-prediction id. The spawned task:
    /// 1. Waits out the rest of `THROTTLE_TIMEOUT` if the previous refresh used the same
    ///    `throttle_entity` and happened less than that long ago.
    /// 2. Skips the request if this id was cancelled in the meantime, or if the store could not
    ///    be updated. Otherwise it records `(throttle_entity, now)` as the last refresh.
    /// 3. Runs `do_refresh`; errors are logged and treated as "no prediction".
    /// 4. Unless this id was cancelled while `do_refresh` ran, handles the result:
    ///    - A successful prediction becomes the current one if there is none.
    ///    - If there is one and `should_replace_prediction` says so, the old one is rejected as
    ///      `Replaced` and the new one becomes current.
    ///    - Otherwise the new prediction is rejected as `CurrentPreferred`.
    ///    - A result carrying a reject reason is reported as rejected.
    /// 5. Removes its own entry from the pending list and cancels every entry queued before it.
    ///
    /// The task resolves to the id of the result returned by `do_refresh`, whether or not that
    /// result was adopted. At most two predictions are kept pending: when two are already
    /// queued, the most recently queued one is cancelled and replaced by this one.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only against the previous refresh for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // `is_cancelled` starts as `true` so that a failed update (store dropped) also
            // skips the request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // Note: `reject_current_prediction` also clears this
                                    // project's `pending_predictions`.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Drop this task's own entry and cancel every entry queued before it; entries
                // queued after it stay pending.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending predictions: the oldest keeps running, while a newer queued
        // one is superseded (cancelled) by this request.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1507
1508    pub fn request_prediction(
1509        &mut self,
1510        project: &Entity<Project>,
1511        active_buffer: &Entity<Buffer>,
1512        position: language::Anchor,
1513        trigger: PredictEditsRequestTrigger,
1514        cx: &mut Context<Self>,
1515    ) -> Task<Result<Option<EditPredictionResult>>> {
1516        self.request_prediction_internal(
1517            project.clone(),
1518            active_buffer.clone(),
1519            position,
1520            trigger,
1521            cx.has_flag::<Zeta2FeatureFlag>(),
1522            cx,
1523        )
1524    }
1525
1526    fn request_prediction_internal(
1527        &mut self,
1528        project: Entity<Project>,
1529        active_buffer: Entity<Buffer>,
1530        position: language::Anchor,
1531        trigger: PredictEditsRequestTrigger,
1532        allow_jump: bool,
1533        cx: &mut Context<Self>,
1534    ) -> Task<Result<Option<EditPredictionResult>>> {
1535        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1536
1537        self.get_or_init_project(&project, cx);
1538        let project_state = self.projects.get(&project.entity_id()).unwrap();
1539        let events = project_state.events(cx);
1540        let has_events = !events.is_empty();
1541        let debug_tx = project_state.debug_tx.clone();
1542
1543        let snapshot = active_buffer.read(cx).snapshot();
1544        let cursor_point = position.to_point(&snapshot);
1545        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1546        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1547        let diagnostic_search_range =
1548            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1549
1550        let related_files = if self.use_context {
1551            self.context_for_project(&project, cx)
1552        } else {
1553            Vec::new().into()
1554        };
1555
1556        let inputs = EditPredictionModelInput {
1557            project: project.clone(),
1558            buffer: active_buffer.clone(),
1559            snapshot: snapshot.clone(),
1560            position,
1561            events,
1562            related_files,
1563            recent_paths: project_state.recent_paths.clone(),
1564            trigger,
1565            diagnostic_search_range: diagnostic_search_range.clone(),
1566            debug_tx,
1567        };
1568
1569        let task = match self.edit_prediction_model {
1570            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1571            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1572            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1573            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1574        };
1575
1576        cx.spawn(async move |this, cx| {
1577            let prediction = task.await?;
1578
1579            if prediction.is_none() && allow_jump {
1580                let cursor_point = position.to_point(&snapshot);
1581                if has_events
1582                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1583                        active_buffer.clone(),
1584                        &snapshot,
1585                        diagnostic_search_range,
1586                        cursor_point,
1587                        &project,
1588                        cx,
1589                    )
1590                    .await?
1591                {
1592                    return this
1593                        .update(cx, |this, cx| {
1594                            this.request_prediction_internal(
1595                                project,
1596                                jump_buffer,
1597                                jump_position,
1598                                trigger,
1599                                false,
1600                                cx,
1601                            )
1602                        })?
1603                        .await;
1604                }
1605
1606                return anyhow::Ok(None);
1607            }
1608
1609            Ok(prediction)
1610        })
1611    }
1612
    /// Picks a location with a diagnostic to jump to when no prediction was produced.
    ///
    /// Search order:
    /// 1. In the active buffer, consider each diagnostic group's primary entry whose
    ///    range does NOT overlap `active_buffer_diagnostic_search_range`. Among them,
    ///    choose the one whose start row is nearest to `active_buffer_cursor_point`'s
    ///    row. Distance is measured in rows only, and ties keep the first candidate
    ///    (`min_by_key`).
    /// 2. Otherwise, among the project's diagnostic summaries for paths other than
    ///    the active buffer's, pick the path sharing the longest leading run of path
    ///    components with the active buffer's path. Ties keep the last candidate
    ///    (`max_by_key`). Then open that buffer and take the start of its diagnostic
    ///    with the smallest `severity` value. NOTE(review): presumably this is the
    ///    most severe one, as in LSP where ERROR = 1; confirm.
    ///
    /// Returns the target buffer with an anchor in it, or `None` if neither step
    /// finds anything. Errors from `read_with`, `update` or opening the buffer are
    /// propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Only the group's primary entry is considered.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // "Closest" = smallest row distance; columns are ignored.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        // NOTE(review): only path components are compared; worktree ids
                        // are ignored. With no active path every score is 0, so the
                        // last summary wins.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1691
1692    async fn send_raw_llm_request(
1693        request: open_ai::Request,
1694        client: Arc<Client>,
1695        llm_token: LlmApiToken,
1696        app_version: Version,
1697        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1698        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1699    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1700        let url = client
1701            .http_client()
1702            .build_zed_llm_url("/predict_edits/raw", &[])?;
1703
1704        #[cfg(feature = "cli-support")]
1705        let cache_key = if let Some(cache) = eval_cache {
1706            use collections::FxHasher;
1707            use std::hash::{Hash, Hasher};
1708
1709            let mut hasher = FxHasher::default();
1710            url.hash(&mut hasher);
1711            let request_str = serde_json::to_string_pretty(&request)?;
1712            request_str.hash(&mut hasher);
1713            let hash = hasher.finish();
1714
1715            let key = (eval_cache_kind, hash);
1716            if let Some(response_str) = cache.read(key) {
1717                return Ok((serde_json::from_str(&response_str)?, None));
1718            }
1719
1720            Some((cache, request_str, key))
1721        } else {
1722            None
1723        };
1724
1725        let (response, usage) = Self::send_api_request(
1726            |builder| {
1727                let req = builder
1728                    .uri(url.as_ref())
1729                    .body(serde_json::to_string(&request)?.into());
1730                Ok(req?)
1731            },
1732            client,
1733            llm_token,
1734            app_version,
1735            true,
1736        )
1737        .await?;
1738
1739        #[cfg(feature = "cli-support")]
1740        if let Some((cache, request, key)) = cache_key {
1741            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1742        }
1743
1744        Ok((response, usage))
1745    }
1746
1747    fn handle_api_response<T>(
1748        this: &WeakEntity<Self>,
1749        response: Result<(T, Option<EditPredictionUsage>)>,
1750        cx: &mut gpui::AsyncApp,
1751    ) -> Result<T> {
1752        match response {
1753            Ok((data, usage)) => {
1754                if let Some(usage) = usage {
1755                    this.update(cx, |this, cx| {
1756                        this.user_store.update(cx, |user_store, cx| {
1757                            user_store.update_edit_prediction_usage(usage, cx);
1758                        });
1759                    })
1760                    .ok();
1761                }
1762                Ok(data)
1763            }
1764            Err(err) => {
1765                if err.is::<ZedUpdateRequiredError>() {
1766                    cx.update(|cx| {
1767                        this.update(cx, |this, _cx| {
1768                            this.update_required = true;
1769                        })
1770                        .ok();
1771
1772                        let error_message: SharedString = err.to_string().into();
1773                        show_app_notification(
1774                            NotificationId::unique::<ZedUpdateRequiredError>(),
1775                            cx,
1776                            move |cx| {
1777                                cx.new(|cx| {
1778                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1779                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1780                                })
1781                            },
1782                        );
1783                    })
1784                    .ok();
1785                }
1786                Err(err)
1787            }
1788        }
1789    }
1790
1791    async fn send_api_request<Res>(
1792        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1793        client: Arc<Client>,
1794        llm_token: LlmApiToken,
1795        app_version: Version,
1796        require_auth: bool,
1797    ) -> Result<(Res, Option<EditPredictionUsage>)>
1798    where
1799        Res: DeserializeOwned,
1800    {
1801        let http_client = client.http_client();
1802
1803        let mut token = if require_auth {
1804            Some(llm_token.acquire(&client).await?)
1805        } else {
1806            llm_token.acquire(&client).await.ok()
1807        };
1808        let mut did_retry = false;
1809
1810        loop {
1811            let request_builder = http_client::Request::builder().method(Method::POST);
1812
1813            let mut request_builder = request_builder
1814                .header("Content-Type", "application/json")
1815                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1816
1817            // Only add Authorization header if we have a token
1818            if let Some(ref token_value) = token {
1819                request_builder =
1820                    request_builder.header("Authorization", format!("Bearer {}", token_value));
1821            }
1822
1823            let request = build(request_builder)?;
1824
1825            let mut response = http_client.send(request).await?;
1826
1827            if let Some(minimum_required_version) = response
1828                .headers()
1829                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1830                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1831            {
1832                anyhow::ensure!(
1833                    app_version >= minimum_required_version,
1834                    ZedUpdateRequiredError {
1835                        minimum_version: minimum_required_version
1836                    }
1837                );
1838            }
1839
1840            if response.status().is_success() {
1841                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1842
1843                let mut body = Vec::new();
1844                response.body_mut().read_to_end(&mut body).await?;
1845                return Ok((serde_json::from_slice(&body)?, usage));
1846            } else if !did_retry
1847                && token.is_some()
1848                && response
1849                    .headers()
1850                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1851                    .is_some()
1852            {
1853                did_retry = true;
1854                token = Some(llm_token.refresh(&client).await?);
1855            } else {
1856                let mut body = String::new();
1857                response.body_mut().read_to_string(&mut body).await?;
1858                anyhow::bail!(
1859                    "Request failed with status: {:?}\nBody: {}",
1860                    response.status(),
1861                    body
1862                );
1863            }
1864        }
1865    }
1866
1867    pub fn refresh_context(
1868        &mut self,
1869        project: &Entity<Project>,
1870        buffer: &Entity<language::Buffer>,
1871        cursor_position: language::Anchor,
1872        cx: &mut Context<Self>,
1873    ) {
1874        if self.use_context {
1875            self.get_or_init_project(project, cx)
1876                .context
1877                .update(cx, |store, cx| {
1878                    store.refresh(buffer.clone(), cursor_position, cx);
1879                });
1880        }
1881    }
1882
1883    #[cfg(feature = "cli-support")]
1884    pub fn set_context_for_buffer(
1885        &mut self,
1886        project: &Entity<Project>,
1887        related_files: Vec<RelatedFile>,
1888        cx: &mut Context<Self>,
1889    ) {
1890        self.get_or_init_project(project, cx)
1891            .context
1892            .update(cx, |store, _| {
1893                store.set_related_files(related_files);
1894            });
1895    }
1896
1897    fn is_file_open_source(
1898        &self,
1899        project: &Entity<Project>,
1900        file: &Arc<dyn File>,
1901        cx: &App,
1902    ) -> bool {
1903        if !file.is_local() || file.is_private() {
1904            return false;
1905        }
1906        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1907            return false;
1908        };
1909        project_state
1910            .license_detection_watchers
1911            .get(&file.worktree_id(cx))
1912            .as_ref()
1913            .is_some_and(|watcher| watcher.is_project_open_source())
1914    }
1915
1916    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1917        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1918    }
1919
1920    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1921        if !self.data_collection_choice.is_enabled() {
1922            return false;
1923        }
1924        events.iter().all(|event| {
1925            matches!(
1926                event.as_ref(),
1927                zeta_prompt::Event::BufferChange {
1928                    in_open_source_repo: true,
1929                    ..
1930                }
1931            )
1932        })
1933    }
1934
1935    fn load_data_collection_choice() -> DataCollectionChoice {
1936        let choice = KEY_VALUE_STORE
1937            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1938            .log_err()
1939            .flatten();
1940
1941        match choice.as_deref() {
1942            Some("true") => DataCollectionChoice::Enabled,
1943            Some("false") => DataCollectionChoice::Disabled,
1944            Some(_) => {
1945                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1946                DataCollectionChoice::NotAnswered
1947            }
1948            None => DataCollectionChoice::NotAnswered,
1949        }
1950    }
1951
1952    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1953        self.data_collection_choice = self.data_collection_choice.toggle();
1954        let new_choice = self.data_collection_choice;
1955        db::write_and_log(cx, move || {
1956            KEY_VALUE_STORE.write_kvp(
1957                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1958                new_choice.is_enabled().to_string(),
1959            )
1960        });
1961    }
1962
1963    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1964        self.shown_predictions.iter()
1965    }
1966
1967    pub fn shown_completions_len(&self) -> usize {
1968        self.shown_predictions.len()
1969    }
1970
1971    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1972        self.rated_predictions.contains(id)
1973    }
1974
1975    pub fn rate_prediction(
1976        &mut self,
1977        prediction: &EditPrediction,
1978        rating: EditPredictionRating,
1979        feedback: String,
1980        cx: &mut Context<Self>,
1981    ) {
1982        self.rated_predictions.insert(prediction.id.clone());
1983        telemetry::event!(
1984            "Edit Prediction Rated",
1985            rating,
1986            inputs = prediction.inputs,
1987            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1988            feedback
1989        );
1990        self.client.telemetry().flush_events().detach();
1991        cx.notify();
1992    }
1993
1994    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1995        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1996            && all_language_settings(None, cx).edit_predictions.use_context;
1997    }
1998}
1999
/// Error produced when an API response advertises, through the minimum-required
/// version header, a Zed version newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The oldest Zed version the server will accept.
    minimum_version: Version,
}
2007
/// Key for [`EvalCache`] entries: the entry kind plus a 64-bit hash identifying the
/// request. For example, `send_raw_llm_request` hashes the URL and the
/// pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2010
/// Category of a memoized [`EvalCache`] entry. It forms part of [`EvalCacheKey`] and
/// displays as its lowercase name.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2018
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the kind's lowercase name: `context`, `search` or `prediction`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2029
/// Storage for memoized request/response pairs, used with the `cli-support` feature
/// to replay LLM responses.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response text for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the response text) under `key`. `input` is the request text
    /// that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2035
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once the user has made either choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice. An unanswered choice becomes `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2076
/// Marker type for the edit-prediction upsell, whose dismissal state is tracked
/// through its `Dismissable` implementation.
struct ZedPredictUpsell;
2078
2079impl Dismissable for ZedPredictUpsell {
2080    const KEY: &'static str = "dismissed-edit-predict-upsell";
2081
2082    fn dismissed() -> bool {
2083        // To make this backwards compatible with older versions of Zed, we
2084        // check if the user has seen the previous Edit Prediction Onboarding
2085        // before, by checking the data collection choice which was written to
2086        // the database once the user clicked on "Accept and Enable"
2087        if KEY_VALUE_STORE
2088            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2089            .log_err()
2090            .is_some_and(|s| s.is_some())
2091        {
2092            return true;
2093        }
2094
2095        KEY_VALUE_STORE
2096            .read_kvp(Self::KEY)
2097            .log_err()
2098            .is_some_and(|s| s.is_some())
2099    }
2100}
2101
2102pub fn should_show_upsell_modal() -> bool {
2103    !ZedPredictUpsell::dismissed()
2104}
2105
/// Registers edit-prediction actions on every newly created `Workspace`:
/// - `zed_actions::OpenZedPredictOnboarding` toggles the `ZedPredictModal`.
/// - `ResetOnboarding` uses `update_settings_file` to set the default
///   `edit_prediction_provider` to `EditPredictionProvider::None`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            // Persist the reset to the settings file so it survives restarts.
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}