edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use workspace::Workspace;
  38
  39use std::ops::Range;
  40use std::path::Path;
  41use std::rc::Rc;
  42use std::str::FromStr as _;
  43use std::sync::{Arc, LazyLock};
  44use std::time::{Duration, Instant};
  45use std::{env, mem};
  46use thiserror::Error;
  47use util::{RangeExt as _, ResultExt as _};
  48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  49
  50pub mod cursor_excerpt;
  51pub mod example_spec;
  52mod license_detection;
  53pub mod mercury;
  54mod onboarding_modal;
  55pub mod open_ai_response;
  56mod prediction;
  57pub mod sweep_ai;
  58
  59#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
  60pub mod udiff;
  61
  62mod zed_edit_prediction_delegate;
  63pub mod zeta1;
  64pub mod zeta2;
  65
  66#[cfg(test)]
  67mod edit_prediction_tests;
  68
  69use crate::license_detection::LicenseDetectionWatcher;
  70use crate::mercury::Mercury;
  71use crate::onboarding_modal::ZedPredictModal;
  72pub use crate::prediction::EditPrediction;
  73pub use crate::prediction::EditPredictionId;
  74use crate::prediction::EditPredictionResult;
  75pub use crate::sweep_ai::SweepAi;
  76pub use language_model::ApiKeyState;
  77pub use telemetry_events::EditPredictionRating;
  78pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  79
// Actions registered under the `edit_prediction` namespace via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  89
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): not referenced in this view; presumably the maximum line distance between edits
// that still get grouped into a single event — confirm at the use site.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in this view; presumably the time window within which consecutive
// edits get grouped into a single event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the persistence key for the user's data-collection choice; the
// read/write sites are not in this view — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied before sending rejection reports from
// `handle_rejected_predictions` — confirm at the use site.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  96
/// Feature flag named `sweep-ai`.
// NOTE(review): presumably gates the Sweep AI provider (`SweepAi`) — confirm at the check sites.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}

/// Feature flag named `mercury`.
// NOTE(review): presumably gates the Mercury provider (`Mercury`) — confirm at the check sites.
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 108
/// Default `ZetaOptions`, installed by `EditPredictionStore::new`.
///
/// Excerpt context uses `max_bytes` 512, `min_bytes` 128 and
/// `target_before_cursor_over_total_bytes` 0.5, with the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 117
 118static USE_OLLAMA: LazyLock<bool> =
 119    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 120
 121static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 122    match env::var("ZED_ZETA2_MODEL").as_deref() {
 123        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 124        Ok(model) => model,
 125        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 126        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 127    }
 128    .to_string()
 129});
 130static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 131    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 132        if *USE_OLLAMA {
 133            Some("http://localhost:11434/v1/chat/completions".into())
 134        } else {
 135            None
 136        }
 137    })
 138});
 139
 140pub struct Zeta2FeatureFlag;
 141
 142impl FeatureFlag for Zeta2FeatureFlag {
 143    const NAME: &'static str = "zeta2";
 144
 145    fn enabled_for_staff() -> bool {
 146        true
 147    }
 148}
 149
/// gpui global holding the app-wide `EditPredictionStore` entity.
///
/// It is installed by `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 154
/// Store for edit predictions: per-project edit history and context, the selected prediction
/// model, and provider clients.
///
/// A single instance is shared app-wide through `EditPredictionStore::global`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event (see `new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Toggled via `set_use_context`; `new` starts it as `false`.
    use_context: bool,
    /// Prompt/excerpt options; `new` starts from `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    // NOTE(review): initialized to `false` in `new`; its readers/writers are not in this view.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Active prediction backend; `new` selects `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded once in `new` via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sender for the channel whose receiver is consumed by the background
    /// `handle_rejected_predictions` task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): not used in this view; presumably predictions that were displayed — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): not used in this view; presumably ids of predictions the user rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
 174
/// The backend used to produce edit predictions.
///
/// Note: the derived `Default` is `Zeta1`, but `EditPredictionStore::new` explicitly starts with
/// `Zeta2`. Usage reporting (`EditPredictionStore::usage`) is only available for `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 183
/// Inputs for a single edit prediction request.
// NOTE(review): construction and consumption are not in this view; presumably built by the store
// and handed to the model-specific request code — confirm.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` to predict against.
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is requested for.
    position: Anchor,
    /// Edit history events (see `ProjectState::events`).
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s (see `EditPredictionStore::debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 196
/// Tunable options for Zeta requests: excerpt sizing and prompt format.
///
/// Defaults to `DEFAULT_OPTIONS`; replaceable via `EditPredictionStore::set_options`.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 202
/// Diagnostic events delivered to the receiver returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's `RelatedExcerptStore` starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently always empty when sent from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when a project's `RelatedExcerptStore` finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display, e.g. "Cache Hits", "Max LSP Time", "Mean LSP Time".
    pub metadata: Vec<(&'static str, SharedString)>,
}

// NOTE(review): the sender of this event is not in this view; presumably emitted when a
// prediction request starts — confirm.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

// NOTE(review): the sender of this event is not in this view; presumably emitted when a
// prediction request completes — confirm.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Alias for the v3 predict-edits debug payload.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 240
/// Per-project bookkeeping owned by `EditPredictionStore`, created by `get_or_init_project`.
struct ProjectState {
    /// Finalized edit events, oldest first; `events()` appends the finalized `last_event`.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    /// The in-progress change, finalized lazily when the history is read.
    last_event: Option<LastEvent>,
    /// Recently activated project paths, most recent first, without duplicates
    /// (maintained in `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // NOTE(review): presumably the id assigned to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; capacity is fixed at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): not used in this view; presumably (buffer id, time) of the last refresh.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt context for this project; its events go to `handle_excerpt_store_event`.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; entries are removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Subscription routing project events to `EditPredictionStore::handle_project_event`.
    _subscription: gpui::Subscription,
}
 256
 257impl ProjectState {
 258    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 259        self.events
 260            .iter()
 261            .cloned()
 262            .chain(
 263                self.last_event
 264                    .as_ref()
 265                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 266            )
 267            .collect()
 268    }
 269
 270    pub fn events_split_by_pause(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
 271        self.events
 272            .iter()
 273            .cloned()
 274            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 275                let (one, two) = event.split_by_pause();
 276                let one = one.finalize(&self.license_detection_watchers, cx);
 277                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 278                one.into_iter().chain(two)
 279            }))
 280            .collect()
 281    }
 282
 283    fn cancel_pending_prediction(
 284        &mut self,
 285        pending_prediction: PendingPrediction,
 286        cx: &mut Context<EditPredictionStore>,
 287    ) {
 288        self.cancelled_predictions.insert(pending_prediction.id);
 289
 290        cx.spawn(async move |this, cx| {
 291            let Some(prediction_id) = pending_prediction.task.await else {
 292                return;
 293            };
 294
 295            this.update(cx, |this, _cx| {
 296                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 297            })
 298            .ok();
 299        })
 300        .detach()
 301    }
 302
 303    fn active_buffer(
 304        &self,
 305        project: &Entity<Project>,
 306        cx: &App,
 307    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 308        let project = project.read(cx);
 309        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 310        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 311        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 312        Some((active_buffer, registered_buffer.last_position))
 313    }
 314}
 315
/// The prediction currently held for a project, together with what requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm at writers.
    pub was_shown: bool,
}
 322
 323impl CurrentEditPrediction {
 324    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 325        let Some(new_edits) = self
 326            .prediction
 327            .interpolate(&self.prediction.buffer.read(cx))
 328        else {
 329            return false;
 330        };
 331
 332        if self.prediction.buffer != old_prediction.prediction.buffer {
 333            return true;
 334        }
 335
 336        let Some(old_edits) = old_prediction
 337            .prediction
 338            .interpolate(&old_prediction.prediction.buffer.read(cx))
 339        else {
 340            return true;
 341        };
 342
 343        let requested_by_buffer_id = self.requested_by.buffer_id();
 344
 345        // This reduces the occurrence of UI thrash from replacing edits
 346        //
 347        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 348        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 349            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 350            && old_edits.len() == 1
 351            && new_edits.len() == 1
 352        {
 353            let (old_range, old_text) = &old_edits[0];
 354            let (new_range, new_text) = &new_edits[0];
 355            new_range == old_range && new_text.starts_with(old_text.as_ref())
 356        } else {
 357            true
 358        }
 359    }
 360}
 361
 362#[derive(Debug, Clone)]
 363enum PredictionRequestedBy {
 364    DiagnosticsUpdate,
 365    Buffer(EntityId),
 366}
 367
 368impl PredictionRequestedBy {
 369    pub fn buffer_id(&self) -> Option<EntityId> {
 370        match self {
 371            PredictionRequestedBy::DiagnosticsUpdate => None,
 372            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 373        }
 374    }
 375}
 376
/// An in-flight prediction request.
///
/// `task` resolves to the produced prediction's id, or `None`. When the request is cancelled
/// (`ProjectState::cancel_pending_prediction`), `id` goes into `cancelled_predictions`, and any
/// resulting prediction is rejected as `Canceled`.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
 382
 383/// A prediction from the perspective of a buffer.
 384#[derive(Debug)]
 385enum BufferEditPrediction<'a> {
 386    Local { prediction: &'a EditPrediction },
 387    Jump { prediction: &'a EditPrediction },
 388}
 389
 390#[cfg(test)]
 391impl std::ops::Deref for BufferEditPrediction<'_> {
 392    type Target = EditPrediction;
 393
 394    fn deref(&self) -> &Self::Target {
 395        match self {
 396            BufferEditPrediction::Local { prediction } => prediction,
 397            BufferEditPrediction::Jump { prediction } => prediction,
 398        }
 399    }
 400}
 401
/// A buffer tracked by a project's `ProjectState`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    // NOTE(review): initialized from the buffer's text snapshot at registration; presumably
    // advanced as edits are reported — confirm in `report_changes_for_buffer`.
    snapshot: TextBufferSnapshot,
    /// Last known cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Buffer edit subscription and release observer, created at registration.
    _subscriptions: [gpui::Subscription; 2],
}
 408
/// An in-progress buffer change spanning `old_snapshot` to `new_snapshot`.
///
/// It is turned into a `zeta_prompt::Event` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    end_edit_anchor: Option<Anchor>,
    /// Snapshot at the last editing pause, if any; `split_by_pause` divides the change here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 419
 420impl LastEvent {
 421    pub fn finalize(
 422        &self,
 423        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 424        cx: &App,
 425    ) -> Option<Arc<zeta_prompt::Event>> {
 426        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 427        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 428
 429        let in_open_source_repo =
 430            [self.new_file.as_ref(), self.old_file.as_ref()]
 431                .iter()
 432                .all(|file| {
 433                    file.is_some_and(|file| {
 434                        license_detection_watchers
 435                            .get(&file.worktree_id(cx))
 436                            .is_some_and(|watcher| watcher.is_project_open_source())
 437                    })
 438                });
 439
 440        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
 441
 442        if path == old_path && diff.is_empty() {
 443            None
 444        } else {
 445            Some(Arc::new(zeta_prompt::Event::BufferChange {
 446                old_path,
 447                path,
 448                diff,
 449                in_open_source_repo,
 450                // TODO: Actually detect if this edit was predicted or not
 451                predicted: false,
 452            }))
 453        }
 454    }
 455
 456    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 457        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 458            return (self.clone(), None);
 459        };
 460
 461        let before = LastEvent {
 462            old_snapshot: self.old_snapshot.clone(),
 463            new_snapshot: boundary_snapshot.clone(),
 464            old_file: self.old_file.clone(),
 465            new_file: self.new_file.clone(),
 466            end_edit_anchor: self.end_edit_anchor,
 467            snapshot_after_last_editing_pause: None,
 468            last_edit_time: self.last_edit_time,
 469        };
 470
 471        let after = LastEvent {
 472            old_snapshot: boundary_snapshot.clone(),
 473            new_snapshot: self.new_snapshot.clone(),
 474            old_file: self.old_file.clone(),
 475            new_file: self.new_file.clone(),
 476            end_edit_anchor: self.end_edit_anchor,
 477            snapshot_after_last_editing_pause: None,
 478            last_edit_time: self.last_edit_time,
 479        };
 480
 481        (before, Some(after))
 482    }
 483}
 484
 485fn buffer_path_with_id_fallback(
 486    file: Option<&Arc<dyn File>>,
 487    snapshot: &TextBufferSnapshot,
 488    cx: &App,
 489) -> Arc<Path> {
 490    if let Some(file) = file {
 491        file.full_path(cx).into()
 492    } else {
 493        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 494    }
 495}
 496
 497impl EditPredictionStore {
 498    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 499        cx.try_global::<EditPredictionStoreGlobal>()
 500            .map(|global| global.0.clone())
 501    }
 502
 503    pub fn global(
 504        client: &Arc<Client>,
 505        user_store: &Entity<UserStore>,
 506        cx: &mut App,
 507    ) -> Entity<Self> {
 508        cx.try_global::<EditPredictionStoreGlobal>()
 509            .map(|global| global.0.clone())
 510            .unwrap_or_else(|| {
 511                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 512                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 513                ep_store
 514            })
 515    }
 516
    /// Creates the store.
    ///
    /// Side effects, in order:
    /// - spawns a detached background task that consumes rejected predictions;
    /// - subscribes to `RefreshLlmTokenListener` so the LLM token is refreshed on its events;
    /// - configures context retrieval immediately, again when feature flags become ready, and
    ///   whenever the `SettingsStore` global changes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are sent through `reject_predictions_tx` and handled off the main thread.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the shared token whenever the listener fires; failures are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Note: differs from `EditPredictionModel::default()` (which is `Zeta1`).
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        this.configure_context_retrieval(cx);
        // Feature flags may arrive after construction; re-apply once they are ready.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 588
 589    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 590        self.edit_prediction_model = model;
 591    }
 592
 593    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 594        self.sweep_ai.api_token.read(cx).has_key()
 595    }
 596
 597    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 598        self.mercury.api_token.read(cx).has_key()
 599    }
 600
 601    #[cfg(feature = "cli-support")]
 602    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 603        self.eval_cache = Some(cache);
 604    }
 605
 606    pub fn options(&self) -> &ZetaOptions {
 607        &self.options
 608    }
 609
 610    pub fn set_options(&mut self, options: ZetaOptions) {
 611        self.options = options;
 612    }
 613
 614    pub fn set_use_context(&mut self, use_context: bool) {
 615        self.use_context = use_context;
 616    }
 617
 618    pub fn clear_history(&mut self) {
 619        for project_state in self.projects.values_mut() {
 620            project_state.events.clear();
 621        }
 622    }
 623
 624    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 625        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 626            project_state.events.clear();
 627        }
 628    }
 629
 630    pub fn edit_history_for_project(
 631        &self,
 632        project: &Entity<Project>,
 633        cx: &App,
 634    ) -> Vec<Arc<zeta_prompt::Event>> {
 635        self.projects
 636            .get(&project.entity_id())
 637            .map(|project_state| project_state.events(cx))
 638            .unwrap_or_default()
 639    }
 640
 641    pub fn edit_history_for_project_with_pause_split_last_event(
 642        &self,
 643        project: &Entity<Project>,
 644        cx: &App,
 645    ) -> Vec<Arc<zeta_prompt::Event>> {
 646        self.projects
 647            .get(&project.entity_id())
 648            .map(|project_state| project_state.events_split_by_pause(cx))
 649            .unwrap_or_default()
 650    }
 651
 652    pub fn context_for_project<'a>(
 653        &'a self,
 654        project: &Entity<Project>,
 655        cx: &'a App,
 656    ) -> Arc<[RelatedFile]> {
 657        self.projects
 658            .get(&project.entity_id())
 659            .map(|project| project.context.read(cx).related_files())
 660            .unwrap_or_else(|| vec![].into())
 661    }
 662
 663    pub fn context_for_project_with_buffers<'a>(
 664        &'a self,
 665        project: &Entity<Project>,
 666        cx: &'a App,
 667    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 668        self.projects
 669            .get(&project.entity_id())
 670            .map(|project| project.context.read(cx).related_files_with_buffers())
 671    }
 672
 673    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 674        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 675            self.user_store.read(cx).edit_prediction_usage()
 676        } else {
 677            None
 678        }
 679    }
 680
 681    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 682        self.get_or_init_project(project, cx);
 683    }
 684
 685    pub fn register_buffer(
 686        &mut self,
 687        buffer: &Entity<Buffer>,
 688        project: &Entity<Project>,
 689        cx: &mut Context<Self>,
 690    ) {
 691        let project_state = self.get_or_init_project(project, cx);
 692        Self::register_buffer_impl(project_state, buffer, project, cx);
 693    }
 694
 695    fn get_or_init_project(
 696        &mut self,
 697        project: &Entity<Project>,
 698        cx: &mut Context<Self>,
 699    ) -> &mut ProjectState {
 700        let entity_id = project.entity_id();
 701        self.projects
 702            .entry(entity_id)
 703            .or_insert_with(|| ProjectState {
 704                context: {
 705                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 706                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 707                        this.handle_excerpt_store_event(entity_id, event);
 708                    })
 709                    .detach();
 710                    related_excerpt_store
 711                },
 712                events: VecDeque::new(),
 713                last_event: None,
 714                recent_paths: VecDeque::new(),
 715                debug_tx: None,
 716                registered_buffers: HashMap::default(),
 717                current_prediction: None,
 718                cancelled_predictions: HashSet::default(),
 719                pending_predictions: ArrayVec::new(),
 720                next_pending_prediction_id: 0,
 721                last_prediction_refresh: None,
 722                license_detection_watchers: HashMap::default(),
 723                _subscription: cx.subscribe(&project, Self::handle_project_event),
 724            })
 725    }
 726
 727    pub fn remove_project(&mut self, project: &Entity<Project>) {
 728        self.projects.remove(&project.entity_id());
 729    }
 730
 731    fn handle_excerpt_store_event(
 732        &mut self,
 733        project_entity_id: EntityId,
 734        event: &RelatedExcerptStoreEvent,
 735    ) {
 736        if let Some(project_state) = self.projects.get(&project_entity_id) {
 737            if let Some(debug_tx) = project_state.debug_tx.clone() {
 738                match event {
 739                    RelatedExcerptStoreEvent::StartedRefresh => {
 740                        debug_tx
 741                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 742                                ContextRetrievalStartedDebugEvent {
 743                                    project_entity_id: project_entity_id,
 744                                    timestamp: Instant::now(),
 745                                    search_prompt: String::new(),
 746                                },
 747                            ))
 748                            .ok();
 749                    }
 750                    RelatedExcerptStoreEvent::FinishedRefresh {
 751                        cache_hit_count,
 752                        cache_miss_count,
 753                        mean_definition_latency,
 754                        max_definition_latency,
 755                    } => {
 756                        debug_tx
 757                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 758                                ContextRetrievalFinishedDebugEvent {
 759                                    project_entity_id: project_entity_id,
 760                                    timestamp: Instant::now(),
 761                                    metadata: vec![
 762                                        (
 763                                            "Cache Hits",
 764                                            format!(
 765                                                "{}/{}",
 766                                                cache_hit_count,
 767                                                cache_hit_count + cache_miss_count
 768                                            )
 769                                            .into(),
 770                                        ),
 771                                        (
 772                                            "Max LSP Time",
 773                                            format!("{} ms", max_definition_latency.as_millis())
 774                                                .into(),
 775                                        ),
 776                                        (
 777                                            "Mean LSP Time",
 778                                            format!("{} ms", mean_definition_latency.as_millis())
 779                                                .into(),
 780                                        ),
 781                                    ],
 782                                },
 783                            ))
 784                            .ok();
 785                    }
 786                }
 787            }
 788        }
 789    }
 790
 791    pub fn debug_info(
 792        &mut self,
 793        project: &Entity<Project>,
 794        cx: &mut Context<Self>,
 795    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 796        let project_state = self.get_or_init_project(project, cx);
 797        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 798        project_state.debug_tx = Some(debug_watch_tx);
 799        debug_watch_rx
 800    }
 801
 802    fn handle_project_event(
 803        &mut self,
 804        project: Entity<Project>,
 805        event: &project::Event,
 806        cx: &mut Context<Self>,
 807    ) {
 808        // TODO [zeta2] init with recent paths
 809        match event {
 810            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 811                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 812                    return;
 813                };
 814                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 815                if let Some(path) = path {
 816                    if let Some(ix) = project_state
 817                        .recent_paths
 818                        .iter()
 819                        .position(|probe| probe == &path)
 820                    {
 821                        project_state.recent_paths.remove(ix);
 822                    }
 823                    project_state.recent_paths.push_front(path);
 824                }
 825            }
 826            project::Event::DiagnosticsUpdated { .. } => {
 827                if cx.has_flag::<Zeta2FeatureFlag>() {
 828                    self.refresh_prediction_from_diagnostics(project, cx);
 829                }
 830            }
 831            _ => (),
 832        }
 833    }
 834
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer has a file whose worktree is found in `project`, a
    /// `LicenseDetectionWatcher` is also created for that worktree (at most one per
    /// worktree id) and removed again when the worktree entity is released.
    ///
    /// A newly registered buffer records its current text snapshot and file, and
    /// stores two subscriptions in its entry: one that forwards
    /// `BufferEvent::Edited` to `report_changes_for_buffer`, and one that removes
    /// the entry when the buffer entity is released. An already-registered
    /// buffer's entry is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree is released. The
                        // project state may already be gone by then, hence the lookup.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            // Already registered: leave the stored snapshot and subscriptions as-is.
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Only a weak project handle is captured; edits are ignored
                        // once the project can no longer be upgraded.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop this registration when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 900
    /// Records the edits made to `buffer` since its last recorded snapshot.
    ///
    /// The buffer is registered first if needed. If its snapshot version has not
    /// changed since the last call, nothing happens. Otherwise the edit is either
    /// coalesced into `project_state.last_event` or starts a new `LastEvent`.
    ///
    /// Coalescing happens when three conditions hold:
    /// - the edit is on the same buffer as the last event;
    /// - it continues directly from that event's snapshot version;
    /// - its last edited row is within `CHANGE_GROUPING_LINE_SPAN` rows of the
    ///   last event's.
    ///
    /// When a new event starts, the previous `last_event` is finalized and appended
    /// to the bounded `project_state.events` history.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the last edited range since the old snapshot; used to decide
        // whether this edit is close enough to the previous one to coalesce.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous
                // edit, remember the pre-edit snapshot as the state after the
                // user's last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Keep the history bounded by EVENT_COUNT_MAX before appending the
        // finalized previous event. NOTE(review): the `+ 1` presumably reserves a
        // slot for the in-progress `last_event`; confirm against `ProjectState::events`.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
 976
 977    fn prediction_at(
 978        &mut self,
 979        buffer: &Entity<Buffer>,
 980        position: Option<language::Anchor>,
 981        project: &Entity<Project>,
 982        cx: &App,
 983    ) -> Option<BufferEditPrediction<'_>> {
 984        let project_state = self.projects.get_mut(&project.entity_id())?;
 985        if let Some(position) = position
 986            && let Some(buffer) = project_state
 987                .registered_buffers
 988                .get_mut(&buffer.entity_id())
 989        {
 990            buffer.last_position = Some(position);
 991        }
 992
 993        let CurrentEditPrediction {
 994            requested_by,
 995            prediction,
 996            ..
 997        } = project_state.current_prediction.as_ref()?;
 998
 999        if prediction.targets_buffer(buffer.read(cx)) {
1000            Some(BufferEditPrediction::Local { prediction })
1001        } else {
1002            let show_jump = match requested_by {
1003                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1004                    requested_by_buffer_id == &buffer.entity_id()
1005                }
1006                PredictionRequestedBy::DiagnosticsUpdate => true,
1007            };
1008
1009            if show_jump {
1010                Some(BufferEditPrediction::Jump { prediction })
1011            } else {
1012                None
1013            }
1014        }
1015    }
1016
1017    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1018        match self.edit_prediction_model {
1019            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1020            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1021        }
1022
1023        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1024            return;
1025        };
1026
1027        let Some(prediction) = project_state.current_prediction.take() else {
1028            return;
1029        };
1030        let request_id = prediction.prediction.id.to_string();
1031        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1032            project_state.cancel_pending_prediction(pending_prediction, cx);
1033        }
1034
1035        let client = self.client.clone();
1036        let llm_token = self.llm_token.clone();
1037        let app_version = AppVersion::global(cx);
1038        cx.spawn(async move |this, cx| {
1039            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
1040                http_client::Url::parse(&predict_edits_url)?
1041            } else {
1042                client
1043                    .http_client()
1044                    .build_zed_llm_url("/predict_edits/accept", &[])?
1045            };
1046
1047            let response = cx
1048                .background_spawn(Self::send_api_request::<()>(
1049                    move |builder| {
1050                        let req = builder.uri(url.as_ref()).body(
1051                            serde_json::to_string(&AcceptEditPredictionBody {
1052                                request_id: request_id.clone(),
1053                            })?
1054                            .into(),
1055                        );
1056                        Ok(req?)
1057                    },
1058                    client,
1059                    llm_token,
1060                    app_version,
1061                ))
1062                .await;
1063
1064            Self::handle_api_response(&this, response, cx)?;
1065            anyhow::Ok(())
1066        })
1067        .detach_and_log_err(cx);
1068    }
1069
1070    async fn handle_rejected_predictions(
1071        rx: UnboundedReceiver<EditPredictionRejection>,
1072        client: Arc<Client>,
1073        llm_token: LlmApiToken,
1074        app_version: Version,
1075        background_executor: BackgroundExecutor,
1076    ) {
1077        let mut rx = std::pin::pin!(rx.peekable());
1078        let mut batched = Vec::new();
1079
1080        while let Some(rejection) = rx.next().await {
1081            batched.push(rejection);
1082
1083            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1084                select_biased! {
1085                    next = rx.as_mut().peek().fuse() => {
1086                        if next.is_some() {
1087                            continue;
1088                        }
1089                    }
1090                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1091                }
1092            }
1093
1094            let url = client
1095                .http_client()
1096                .build_zed_llm_url("/predict_edits/reject", &[])
1097                .unwrap();
1098
1099            let flush_count = batched
1100                .len()
1101                // in case items have accumulated after failure
1102                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1103            let start = batched.len() - flush_count;
1104
1105            let body = RejectEditPredictionsBodyRef {
1106                rejections: &batched[start..],
1107            };
1108
1109            let result = Self::send_api_request::<()>(
1110                |builder| {
1111                    let req = builder
1112                        .uri(url.as_ref())
1113                        .body(serde_json::to_string(&body)?.into());
1114                    anyhow::Ok(req?)
1115                },
1116                client.clone(),
1117                llm_token.clone(),
1118                app_version.clone(),
1119            )
1120            .await;
1121
1122            if result.log_err().is_some() {
1123                batched.drain(start..);
1124            }
1125        }
1126    }
1127
1128    fn reject_current_prediction(
1129        &mut self,
1130        reason: EditPredictionRejectReason,
1131        project: &Entity<Project>,
1132    ) {
1133        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1134            project_state.pending_predictions.clear();
1135            if let Some(prediction) = project_state.current_prediction.take() {
1136                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1137            }
1138        };
1139    }
1140
1141    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1142        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1143            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1144                if !current_prediction.was_shown {
1145                    current_prediction.was_shown = true;
1146                    self.shown_predictions
1147                        .push_front(current_prediction.prediction.clone());
1148                    if self.shown_predictions.len() > 50 {
1149                        let completion = self.shown_predictions.pop_back().unwrap();
1150                        self.rated_predictions.remove(&completion.id);
1151                    }
1152                }
1153            }
1154        }
1155    }
1156
1157    fn reject_prediction(
1158        &mut self,
1159        prediction_id: EditPredictionId,
1160        reason: EditPredictionRejectReason,
1161        was_shown: bool,
1162    ) {
1163        match self.edit_prediction_model {
1164            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1165            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1166        }
1167
1168        self.reject_predictions_tx
1169            .unbounded_send(EditPredictionRejection {
1170                request_id: prediction_id.to_string(),
1171                reason,
1172                was_shown,
1173            })
1174            .log_err();
1175    }
1176
1177    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1178        self.projects
1179            .get(&project.entity_id())
1180            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1181    }
1182
1183    pub fn refresh_prediction_from_buffer(
1184        &mut self,
1185        project: Entity<Project>,
1186        buffer: Entity<Buffer>,
1187        position: language::Anchor,
1188        cx: &mut Context<Self>,
1189    ) {
1190        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1191            let Some(request_task) = this
1192                .update(cx, |this, cx| {
1193                    this.request_prediction(
1194                        &project,
1195                        &buffer,
1196                        position,
1197                        PredictEditsRequestTrigger::Other,
1198                        cx,
1199                    )
1200                })
1201                .log_err()
1202            else {
1203                return Task::ready(anyhow::Ok(None));
1204            };
1205
1206            cx.spawn(async move |_cx| {
1207                request_task.await.map(|prediction_result| {
1208                    prediction_result.map(|prediction_result| {
1209                        (
1210                            prediction_result,
1211                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1212                        )
1213                    })
1214                })
1215            })
1216        })
1217    }
1218
    /// Queues a prediction request triggered by a diagnostics update.
    ///
    /// Does nothing if the project is not tracked or already has a current
    /// prediction. Otherwise, throttled per project, the queued refresh does this:
    /// - resolves the project's active buffer and cursor;
    /// - checks that predictions are enabled there;
    /// - finds the next diagnostic location via `next_diagnostic_location`;
    /// - requests a prediction at that location with
    ///   `PredictEditsRequestTrigger::Diagnostics`.
    ///
    /// If a current prediction appeared while the request was in flight, the
    /// result is reported as `CurrentPreferred` instead of being offered.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer: never displace an existing prediction
        // with a diagnostics-driven one.
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the active buffer and cursor; bail out if predictions are
            // disabled at that location.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // `Default::default()` is an empty search range, so presumably no
                // diagnostics near the cursor are excluded from the jump search.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction may have become current while this request was in
                // flight; if so, report this one as `CurrentPreferred` instead.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1305
1306    fn predictions_enabled_at(
1307        snapshot: &BufferSnapshot,
1308        position: Option<language::Anchor>,
1309        cx: &App,
1310    ) -> bool {
1311        let file = snapshot.file();
1312        let all_settings = all_language_settings(file, cx);
1313        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1314            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1315        {
1316            return false;
1317        }
1318
1319        if let Some(last_position) = position {
1320            let settings = snapshot.settings_at(last_position, cx);
1321
1322            if !settings.edit_predictions_disabled_in.is_empty()
1323                && let Some(scope) = snapshot.language_scope_at(last_position)
1324                && let Some(scope_name) = scope.override_name()
1325                && settings
1326                    .edit_predictions_disabled_in
1327                    .iter()
1328                    .any(|s| s == scope_name)
1329            {
1330                return false;
1331            }
1332        }
1333
1334        true
1335    }
1336
    /// Minimum spacing between two prediction refreshes for the same throttle
    /// entity. `queue_prediction_refresh` waits out the remainder before requesting.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Under `cfg(test)` the timeout is zero, so throttling adds no delay.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1341
    /// Queues a throttled prediction refresh for `project` and tracks it as pending.
    ///
    /// The spawned task runs these steps:
    /// - Wait until `THROTTLE_TIMEOUT` has passed since the last refresh for the
    ///   same `throttle_entity`.
    /// - Skip the request if this pending prediction was cancelled in the meantime.
    /// - Otherwise run `do_refresh`.
    /// - Unless cancelled during the request, install the result as the current
    ///   prediction. Whichever of the old and new predictions loses
    ///   `should_replace_prediction` is rejected; error results are rejected with
    ///   their reason.
    /// - Remove this entry from `pending_predictions` and cancel any entries that
    ///   were queued before it.
    ///
    /// At most two predictions are tracked as pending: with two already pending,
    /// the newer one is cancelled and replaced by this one.
    ///
    /// The task resolves to the id of the prediction returned by `do_refresh`, if any.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only against the previous refresh for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request. (Also treated as cancelled if the store
            // entity can no longer be updated.)
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Cancellation may also have happened while the request was in flight.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // NOTE(review): `reject_current_prediction` also
                                    // clears `pending_predictions`, so the pending-list
                                    // cleanup below then finds nothing to remove or
                                    // cancel. Confirm that dropping the other pending
                                    // requests on replacement is intended.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this entry and cancel every entry queued before it.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending predictions. With two already pending, the newer
        // one is replaced by this request and cancelled. NOTE(review): if more than
        // two were ever pending, `task` would be dropped here without being tracked.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1481
1482    pub fn request_prediction(
1483        &mut self,
1484        project: &Entity<Project>,
1485        active_buffer: &Entity<Buffer>,
1486        position: language::Anchor,
1487        trigger: PredictEditsRequestTrigger,
1488        cx: &mut Context<Self>,
1489    ) -> Task<Result<Option<EditPredictionResult>>> {
1490        self.request_prediction_internal(
1491            project.clone(),
1492            active_buffer.clone(),
1493            position,
1494            trigger,
1495            cx.has_flag::<Zeta2FeatureFlag>(),
1496            cx,
1497        )
1498    }
1499
1500    fn request_prediction_internal(
1501        &mut self,
1502        project: Entity<Project>,
1503        active_buffer: Entity<Buffer>,
1504        position: language::Anchor,
1505        trigger: PredictEditsRequestTrigger,
1506        allow_jump: bool,
1507        cx: &mut Context<Self>,
1508    ) -> Task<Result<Option<EditPredictionResult>>> {
1509        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1510
1511        self.get_or_init_project(&project, cx);
1512        let project_state = self.projects.get(&project.entity_id()).unwrap();
1513        let events = project_state.events(cx);
1514        let has_events = !events.is_empty();
1515        let debug_tx = project_state.debug_tx.clone();
1516
1517        let snapshot = active_buffer.read(cx).snapshot();
1518        let cursor_point = position.to_point(&snapshot);
1519        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1520        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1521        let diagnostic_search_range =
1522            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1523
1524        let related_files = if self.use_context {
1525            self.context_for_project(&project, cx)
1526        } else {
1527            Vec::new().into()
1528        };
1529
1530        let inputs = EditPredictionModelInput {
1531            project: project.clone(),
1532            buffer: active_buffer.clone(),
1533            snapshot: snapshot.clone(),
1534            position,
1535            events,
1536            related_files,
1537            recent_paths: project_state.recent_paths.clone(),
1538            trigger,
1539            diagnostic_search_range: diagnostic_search_range.clone(),
1540            debug_tx,
1541        };
1542
1543        let task = match self.edit_prediction_model {
1544            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1545            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1546            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1547            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1548        };
1549
1550        cx.spawn(async move |this, cx| {
1551            let prediction = task.await?;
1552
1553            if prediction.is_none() && allow_jump {
1554                let cursor_point = position.to_point(&snapshot);
1555                if has_events
1556                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1557                        active_buffer.clone(),
1558                        &snapshot,
1559                        diagnostic_search_range,
1560                        cursor_point,
1561                        &project,
1562                        cx,
1563                    )
1564                    .await?
1565                {
1566                    return this
1567                        .update(cx, |this, cx| {
1568                            this.request_prediction_internal(
1569                                project,
1570                                jump_buffer,
1571                                jump_position,
1572                                trigger,
1573                                false,
1574                                cx,
1575                            )
1576                        })?
1577                        .await;
1578                }
1579
1580                return anyhow::Ok(None);
1581            }
1582
1583            Ok(prediction)
1584        })
1585    }
1586
    /// Finds a diagnostic location to jump to, or `Ok(None)` if there is no candidate.
    ///
    /// Search order:
    /// 1. The active buffer. Every diagnostic group's primary entry is considered, except
    ///    those whose range overlaps `active_buffer_diagnostic_search_range`. The one whose
    ///    start row is closest to `active_buffer_cursor_point` wins, and it is returned as an
    ///    anchor (`anchor_before`) into `active_buffer`.
    /// 2. Otherwise, another file from `project.diagnostic_summaries`. The chosen file is the
    ///    one whose path shares the most leading components with the active buffer's path.
    ///    That buffer is opened, and the start of its diagnostic entry with the minimum
    ///    `severity` value is returned.
    ///
    /// Errors are propagated from reading or updating entities through `cx` and from
    /// awaiting the task that opens the fallback buffer.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Only the group's primary entry is used to position the jump.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Distance is measured in rows only; columns are ignored.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Pick the file whose path shares the most leading components with
                        // the active buffer's path. With no active path the score is 0 for
                        // every file, and `max_by_key` returns the last one.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        // NOTE(review): this takes the minimum severity value. That is the
                        // most severe entry if severities use LSP ordering (ERROR = 1);
                        // confirm. `entry.range.start` is presumably already an `Anchor`,
                        // given the return type.
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1665
1666    async fn send_raw_llm_request(
1667        request: open_ai::Request,
1668        client: Arc<Client>,
1669        llm_token: LlmApiToken,
1670        app_version: Version,
1671        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1672        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1673    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1674        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1675            http_client::Url::parse(&predict_edits_url)?
1676        } else {
1677            client
1678                .http_client()
1679                .build_zed_llm_url("/predict_edits/raw", &[])?
1680        };
1681
1682        #[cfg(feature = "cli-support")]
1683        let cache_key = if let Some(cache) = eval_cache {
1684            use collections::FxHasher;
1685            use std::hash::{Hash, Hasher};
1686
1687            let mut hasher = FxHasher::default();
1688            url.hash(&mut hasher);
1689            let request_str = serde_json::to_string_pretty(&request)?;
1690            request_str.hash(&mut hasher);
1691            let hash = hasher.finish();
1692
1693            let key = (eval_cache_kind, hash);
1694            if let Some(response_str) = cache.read(key) {
1695                return Ok((serde_json::from_str(&response_str)?, None));
1696            }
1697
1698            Some((cache, request_str, key))
1699        } else {
1700            None
1701        };
1702
1703        let (response, usage) = Self::send_api_request(
1704            |builder| {
1705                let req = builder
1706                    .uri(url.as_ref())
1707                    .body(serde_json::to_string(&request)?.into());
1708                Ok(req?)
1709            },
1710            client,
1711            llm_token,
1712            app_version,
1713        )
1714        .await?;
1715
1716        #[cfg(feature = "cli-support")]
1717        if let Some((cache, request, key)) = cache_key {
1718            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1719        }
1720
1721        Ok((response, usage))
1722    }
1723
1724    fn handle_api_response<T>(
1725        this: &WeakEntity<Self>,
1726        response: Result<(T, Option<EditPredictionUsage>)>,
1727        cx: &mut gpui::AsyncApp,
1728    ) -> Result<T> {
1729        match response {
1730            Ok((data, usage)) => {
1731                if let Some(usage) = usage {
1732                    this.update(cx, |this, cx| {
1733                        this.user_store.update(cx, |user_store, cx| {
1734                            user_store.update_edit_prediction_usage(usage, cx);
1735                        });
1736                    })
1737                    .ok();
1738                }
1739                Ok(data)
1740            }
1741            Err(err) => {
1742                if err.is::<ZedUpdateRequiredError>() {
1743                    cx.update(|cx| {
1744                        this.update(cx, |this, _cx| {
1745                            this.update_required = true;
1746                        })
1747                        .ok();
1748
1749                        let error_message: SharedString = err.to_string().into();
1750                        show_app_notification(
1751                            NotificationId::unique::<ZedUpdateRequiredError>(),
1752                            cx,
1753                            move |cx| {
1754                                cx.new(|cx| {
1755                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1756                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1757                                })
1758                            },
1759                        );
1760                    })
1761                    .ok();
1762                }
1763                Err(err)
1764            }
1765        }
1766    }
1767
1768    async fn send_api_request<Res>(
1769        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1770        client: Arc<Client>,
1771        llm_token: LlmApiToken,
1772        app_version: Version,
1773    ) -> Result<(Res, Option<EditPredictionUsage>)>
1774    where
1775        Res: DeserializeOwned,
1776    {
1777        let http_client = client.http_client();
1778        let mut token = llm_token.acquire(&client).await?;
1779        let mut did_retry = false;
1780
1781        loop {
1782            let request_builder = http_client::Request::builder().method(Method::POST);
1783
1784            let request = build(
1785                request_builder
1786                    .header("Content-Type", "application/json")
1787                    .header("Authorization", format!("Bearer {}", token))
1788                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1789            )?;
1790
1791            let mut response = http_client.send(request).await?;
1792
1793            if let Some(minimum_required_version) = response
1794                .headers()
1795                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1796                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1797            {
1798                anyhow::ensure!(
1799                    app_version >= minimum_required_version,
1800                    ZedUpdateRequiredError {
1801                        minimum_version: minimum_required_version
1802                    }
1803                );
1804            }
1805
1806            if response.status().is_success() {
1807                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1808
1809                let mut body = Vec::new();
1810                response.body_mut().read_to_end(&mut body).await?;
1811                return Ok((serde_json::from_slice(&body)?, usage));
1812            } else if !did_retry
1813                && response
1814                    .headers()
1815                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1816                    .is_some()
1817            {
1818                did_retry = true;
1819                token = llm_token.refresh(&client).await?;
1820            } else {
1821                let mut body = String::new();
1822                response.body_mut().read_to_string(&mut body).await?;
1823                anyhow::bail!(
1824                    "Request failed with status: {:?}\nBody: {}",
1825                    response.status(),
1826                    body
1827                );
1828            }
1829        }
1830    }
1831
1832    pub fn refresh_context(
1833        &mut self,
1834        project: &Entity<Project>,
1835        buffer: &Entity<language::Buffer>,
1836        cursor_position: language::Anchor,
1837        cx: &mut Context<Self>,
1838    ) {
1839        if self.use_context {
1840            self.get_or_init_project(project, cx)
1841                .context
1842                .update(cx, |store, cx| {
1843                    store.refresh(buffer.clone(), cursor_position, cx);
1844                });
1845        }
1846    }
1847
1848    #[cfg(feature = "cli-support")]
1849    pub fn set_context_for_buffer(
1850        &mut self,
1851        project: &Entity<Project>,
1852        related_files: Vec<RelatedFile>,
1853        cx: &mut Context<Self>,
1854    ) {
1855        self.get_or_init_project(project, cx)
1856            .context
1857            .update(cx, |store, _| {
1858                store.set_related_files(related_files);
1859            });
1860    }
1861
1862    fn is_file_open_source(
1863        &self,
1864        project: &Entity<Project>,
1865        file: &Arc<dyn File>,
1866        cx: &App,
1867    ) -> bool {
1868        if !file.is_local() || file.is_private() {
1869            return false;
1870        }
1871        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1872            return false;
1873        };
1874        project_state
1875            .license_detection_watchers
1876            .get(&file.worktree_id(cx))
1877            .as_ref()
1878            .is_some_and(|watcher| watcher.is_project_open_source())
1879    }
1880
1881    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1882        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1883    }
1884
1885    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1886        if !self.data_collection_choice.is_enabled() {
1887            return false;
1888        }
1889        events.iter().all(|event| {
1890            matches!(
1891                event.as_ref(),
1892                zeta_prompt::Event::BufferChange {
1893                    in_open_source_repo: true,
1894                    ..
1895                }
1896            )
1897        })
1898    }
1899
1900    fn load_data_collection_choice() -> DataCollectionChoice {
1901        let choice = KEY_VALUE_STORE
1902            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1903            .log_err()
1904            .flatten();
1905
1906        match choice.as_deref() {
1907            Some("true") => DataCollectionChoice::Enabled,
1908            Some("false") => DataCollectionChoice::Disabled,
1909            Some(_) => {
1910                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1911                DataCollectionChoice::NotAnswered
1912            }
1913            None => DataCollectionChoice::NotAnswered,
1914        }
1915    }
1916
1917    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1918        self.data_collection_choice = self.data_collection_choice.toggle();
1919        let new_choice = self.data_collection_choice;
1920        db::write_and_log(cx, move || {
1921            KEY_VALUE_STORE.write_kvp(
1922                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1923                new_choice.is_enabled().to_string(),
1924            )
1925        });
1926    }
1927
1928    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
1929        self.shown_predictions.iter()
1930    }
1931
1932    pub fn shown_completions_len(&self) -> usize {
1933        self.shown_predictions.len()
1934    }
1935
1936    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
1937        self.rated_predictions.contains(id)
1938    }
1939
1940    pub fn rate_prediction(
1941        &mut self,
1942        prediction: &EditPrediction,
1943        rating: EditPredictionRating,
1944        feedback: String,
1945        cx: &mut Context<Self>,
1946    ) {
1947        self.rated_predictions.insert(prediction.id.clone());
1948        telemetry::event!(
1949            "Edit Prediction Rated",
1950            rating,
1951            inputs = prediction.inputs,
1952            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
1953            feedback
1954        );
1955        self.client.telemetry().flush_events().detach();
1956        cx.notify();
1957    }
1958
1959    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1960        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1961            && all_language_settings(None, cx).edit_predictions.use_context;
1962    }
1963}
1964
1965#[derive(Error, Debug)]
1966#[error(
1967    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1968)]
1969pub struct ZedUpdateRequiredError {
1970    minimum_version: Version,
1971}
1972
/// Key of an [`EvalCache`] entry: the entry kind plus a 64-bit hash. In
/// `send_raw_llm_request`, the hash covers the request URL and the pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1975
/// Category of a cached eval request. It forms the first half of an [`EvalCacheKey`].
///
/// `Eq` and `Hash` are derived so the key tuple can be used directly as a `HashMap` or
/// `HashSet` key by cache implementations.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1983
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the kind, e.g. `prediction`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
1994
/// Pluggable response cache, compiled only with the `cli-support` feature (presumably for
/// eval tooling). `send_raw_llm_request` reads from it before sending a request and writes
/// the response back afterwards.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response text stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the response text) under `key`. `input` is the serialized request
    /// that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2000
/// The user's answer to the edit-prediction data-collection prompt.
///
/// `PartialEq`/`Eq` are derived so choices can be compared directly, not only through
/// the `is_*` helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet (or the stored value was not recognized).
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2007
2008impl DataCollectionChoice {
2009    pub fn is_enabled(self) -> bool {
2010        match self {
2011            Self::Enabled => true,
2012            Self::NotAnswered | Self::Disabled => false,
2013        }
2014    }
2015
2016    pub fn is_answered(self) -> bool {
2017        match self {
2018            Self::Enabled | Self::Disabled => true,
2019            Self::NotAnswered => false,
2020        }
2021    }
2022
2023    #[must_use]
2024    pub fn toggle(&self) -> DataCollectionChoice {
2025        match self {
2026            Self::Enabled => Self::Disabled,
2027            Self::Disabled => Self::Enabled,
2028            Self::NotAnswered => Self::Enabled,
2029        }
2030    }
2031}
2032
2033impl From<bool> for DataCollectionChoice {
2034    fn from(value: bool) -> Self {
2035        match value {
2036            true => DataCollectionChoice::Enabled,
2037            false => DataCollectionChoice::Disabled,
2038        }
2039    }
2040}
2041
2042struct ZedPredictUpsell;
2043
2044impl Dismissable for ZedPredictUpsell {
2045    const KEY: &'static str = "dismissed-edit-predict-upsell";
2046
2047    fn dismissed() -> bool {
2048        // To make this backwards compatible with older versions of Zed, we
2049        // check if the user has seen the previous Edit Prediction Onboarding
2050        // before, by checking the data collection choice which was written to
2051        // the database once the user clicked on "Accept and Enable"
2052        if KEY_VALUE_STORE
2053            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2054            .log_err()
2055            .is_some_and(|s| s.is_some())
2056        {
2057            return true;
2058        }
2059
2060        KEY_VALUE_STORE
2061            .read_kvp(Self::KEY)
2062            .log_err()
2063            .is_some_and(|s| s.is_some())
2064    }
2065}
2066
2067pub fn should_show_upsell_modal() -> bool {
2068    !ZedPredictUpsell::dismissed()
2069}
2070
2071pub fn init(cx: &mut App) {
2072    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2073        workspace.register_action(
2074            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2075                ZedPredictModal::toggle(
2076                    workspace,
2077                    workspace.user_store().clone(),
2078                    workspace.client().clone(),
2079                    window,
2080                    cx,
2081                )
2082            },
2083        );
2084
2085        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2086            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2087                settings
2088                    .project
2089                    .all_languages
2090                    .features
2091                    .get_or_insert_default()
2092                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2093            });
2094        });
2095    })
2096    .detach();
2097}