edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
   7    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
   9};
  10use collections::{HashMap, HashSet};
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction_context::EditPredictionExcerptOptions;
  13use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  15use futures::{
  16    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  17    channel::mpsc::{self, UnboundedReceiver},
  18    select_biased,
  19};
  20use gpui::BackgroundExecutor;
  21use gpui::http_client::Url;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use text::Edit;
  38use workspace::Workspace;
  39
  40use std::ops::Range;
  41use std::path::Path;
  42use std::rc::Rc;
  43use std::str::FromStr as _;
  44use std::sync::{Arc, LazyLock};
  45use std::time::{Duration, Instant};
  46use std::{env, mem};
  47use thiserror::Error;
  48use util::{RangeExt as _, ResultExt as _};
  49use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  50
  51pub mod cursor_excerpt;
  52pub mod example_spec;
  53mod license_detection;
  54pub mod mercury;
  55mod onboarding_modal;
  56pub mod open_ai_response;
  57mod prediction;
  58pub mod sweep_ai;
  59
  60pub mod udiff;
  61
  62mod capture_example;
  63mod zed_edit_prediction_delegate;
  64pub mod zeta1;
  65pub mod zeta2;
  66
  67#[cfg(test)]
  68mod edit_prediction_tests;
  69
  70use crate::capture_example::should_sample_edit_prediction_example_capture;
  71use crate::license_detection::LicenseDetectionWatcher;
  72use crate::mercury::Mercury;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76use crate::prediction::EditPredictionResult;
  77pub use crate::sweep_ai::SweepAi;
  78pub use capture_example::capture_example;
  79pub use language_model::ApiKeyState;
  80pub use telemetry_events::EditPredictionRating;
  81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  82
// Declares the `ResetOnboarding` and `ClearHistory` actions in the `edit_prediction`
// action namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  92
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): the constants below are used outside this view; the descriptions are
// inferred from their names — confirm against the call sites.
// Presumably the row distance within which consecutive edits are grouped into one event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// Presumably the time window (1s) for grouping a new change with the last one.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Presumably the persisted key for the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// Presumably the debounce (15s) before a batch of prediction rejections is sent.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  99
/// Feature flag named `sweep-ai`; presumably gates the Sweep AI provider ([`SweepAi`]).
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
 105
/// Feature flag named `mercury`; presumably gates the Mercury provider ([`Mercury`]).
pub struct MercuryFeatureFlag;

impl FeatureFlag for MercuryFeatureFlag {
    const NAME: &str = "mercury";
}
 111
/// Options a newly created [`EditPredictionStore`] starts with: excerpts between 128 and
/// 512 bytes with a 0.5 before-cursor ratio, and the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 120
 121static USE_OLLAMA: LazyLock<bool> =
 122    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 123
 124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 125    match env::var("ZED_ZETA2_MODEL").as_deref() {
 126        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 127        Ok(model) => model,
 128        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 129        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 130    }
 131    .to_string()
 132});
 133
 134pub struct Zeta2FeatureFlag;
 135
 136impl FeatureFlag for Zeta2FeatureFlag {
 137    const NAME: &'static str = "zeta2";
 138
 139    fn enabled_for_staff() -> bool {
 140        true
 141    }
 142}
 143
 144pub struct EditPredictionExampleCaptureFeatureFlag;
 145
 146impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 147    const NAME: &'static str = "edit-prediction-example-capture";
 148
 149    fn enabled_for_staff() -> bool {
 150        true
 151    }
 152}
 153
/// Holds the app-wide [`EditPredictionStore`] entity as a gpui [`Global`]. It is set by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 158
/// Central store for edit predictions. It keeps per-project state, the selected model,
/// and provider configuration. The app-wide instance is obtained via
/// [`EditPredictionStore::global`].
pub struct EditPredictionStore {
    client: Arc<Client>,
    /// Source of the value returned by `usage`.
    user_store: Entity<UserStore>,
    /// Refreshed whenever the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Set through `set_use_context`; starts as `false`.
    use_context: bool,
    /// Starts as [`DEFAULT_OPTIONS`].
    options: ZetaOptions,
    // NOTE(review): initialized to `false`; where it is set is outside this view.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Model used for predictions. `new` starts with `Zeta2`, not the enum's default.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded when the store is created.
    data_collection_choice: DataCollectionChoice,
    /// Rejections sent here are consumed by the background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): the two fields below are not used in the visible code; presumably they
    // track predictions shown to, and rated by, the user.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override from `ZED_PREDICT_EDITS_URL`, or a local Ollama endpoint.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 179
/// Which model/provider the store uses for edit predictions.
///
/// `Zeta1` is the `Default`, but `EditPredictionStore::new` starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 188
/// Inputs gathered for a single edit prediction request.
///
/// NOTE(review): the consumers of this struct are outside this view; the field notes below
/// are inferred from names and types — confirm against the model implementations.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is built from.
    snapshot: BufferSnapshot,
    /// Presumably the cursor position in `buffer`.
    position: Anchor,
    /// Edit history events sent as context.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related excerpts retrieved as context.
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    /// What triggered the request.
    trigger: PredictEditsRequestTrigger,
    /// Presumably the range searched for diagnostics.
    diagnostic_search_range: Range<Point>,
    /// Optional channel for [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 201
/// Tunable options for Zeta predictions (see [`DEFAULT_OPTIONS`]).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Controls excerpt sizing around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format used for requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 207
/// Events delivered on the channel returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 215
/// Emitted when a project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}
 222
/// Emitted when a project's related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs describing the refresh (e.g. cache hits, LSP latency).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 229
/// Debug event for the start of an edit prediction request.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the prompt sent to the model, when one was built.
    pub prompt: Option<String>,
}
 236
/// Debug event for the end of an edit prediction request.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the raw model output, when one was received.
    pub model_output: Option<String>,
}
 243
/// Alias for [`predict_edits_v3::DebugInfo`].
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 245
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The finalized event (built by `LastEvent::finalize` as a `BufferChange`).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change described by `event`.
    pub old_snapshot: TextBufferSnapshot,
}
 252
/// Per-project edit-prediction state, created by `EditPredictionStore::get_or_init_project`.
struct ProjectState {
    /// Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    /// The in-progress event; finalized on demand by `events` / `events_split_by_pause`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // NOTE(review): presumably the id handed to the next `PendingPrediction`.
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Set by `EditPredictionStore::debug_info`; receives debug events for this project.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store; its events are forwarded to `handle_excerpt_store_event`.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, consulted when finalizing events.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Keeps the subscription to the project's events alive.
    _subscription: gpui::Subscription,
}
 268
 269impl ProjectState {
 270    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 271        self.events
 272            .iter()
 273            .cloned()
 274            .chain(
 275                self.last_event
 276                    .as_ref()
 277                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 278            )
 279            .collect()
 280    }
 281
 282    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 283        self.events
 284            .iter()
 285            .cloned()
 286            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 287                let (one, two) = event.split_by_pause();
 288                let one = one.finalize(&self.license_detection_watchers, cx);
 289                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 290                one.into_iter().chain(two)
 291            }))
 292            .collect()
 293    }
 294
 295    fn cancel_pending_prediction(
 296        &mut self,
 297        pending_prediction: PendingPrediction,
 298        cx: &mut Context<EditPredictionStore>,
 299    ) {
 300        self.cancelled_predictions.insert(pending_prediction.id);
 301
 302        cx.spawn(async move |this, cx| {
 303            let Some(prediction_id) = pending_prediction.task.await else {
 304                return;
 305            };
 306
 307            this.update(cx, |this, _cx| {
 308                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 309            })
 310            .ok();
 311        })
 312        .detach()
 313    }
 314
 315    fn active_buffer(
 316        &self,
 317        project: &Entity<Project>,
 318        cx: &App,
 319    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 320        let project = project.read(cx);
 321        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 322        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 323        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 324        Some((active_buffer, registered_buffer.last_position))
 325    }
 326}
 327
/// The prediction currently held for a project, plus how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced `prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed, along with the
    // display type used — confirm at the call sites.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 335
 336impl CurrentEditPrediction {
 337    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 338        let Some(new_edits) = self
 339            .prediction
 340            .interpolate(&self.prediction.buffer.read(cx))
 341        else {
 342            return false;
 343        };
 344
 345        if self.prediction.buffer != old_prediction.prediction.buffer {
 346            return true;
 347        }
 348
 349        let Some(old_edits) = old_prediction
 350            .prediction
 351            .interpolate(&old_prediction.prediction.buffer.read(cx))
 352        else {
 353            return true;
 354        };
 355
 356        let requested_by_buffer_id = self.requested_by.buffer_id();
 357
 358        // This reduces the occurrence of UI thrash from replacing edits
 359        //
 360        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 361        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 362            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 363            && old_edits.len() == 1
 364            && new_edits.len() == 1
 365        {
 366            let (old_range, old_text) = &old_edits[0];
 367            let (new_range, new_text) = &new_edits[0];
 368            new_range == old_range && new_text.starts_with(old_text.as_ref())
 369        } else {
 370            true
 371        }
 372    }
 373}
 374
/// What caused a prediction request: a diagnostics update, or a specific buffer.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// Entity id of the requesting buffer.
    Buffer(EntityId),
}
 380
 381impl PredictionRequestedBy {
 382    pub fn buffer_id(&self) -> Option<EntityId> {
 383        match self {
 384            PredictionRequestedBy::DiagnosticsUpdate => None,
 385            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 386        }
 387    }
 388}
 389
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Recorded in `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None`.
    task: Task<Option<EditPredictionId>>,
}
 395
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // NOTE(review): presumably `Local` means the prediction edits this buffer, and `Jump`
    // means it lies elsewhere and requires jumping — confirm at the construction sites.
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 402
 403#[cfg(test)]
 404impl std::ops::Deref for BufferEditPrediction<'_> {
 405    type Target = EditPrediction;
 406
 407    fn deref(&self) -> &Self::Target {
 408        match self {
 409            BufferEditPrediction::Local { prediction } => prediction,
 410            BufferEditPrediction::Jump { prediction } => prediction,
 411        }
 412    }
 413}
 414
/// Bookkeeping for a buffer registered with a project's state.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    /// Last known position in the buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Kept alive for as long as the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
 421
/// An in-progress edit event that has not yet been finalized into a [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Buffer state at the start of the event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state at the latest edit of the event.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Snapshot used by `split_by_pause` as the boundary between the two halves.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 432
 433impl LastEvent {
 434    pub fn finalize(
 435        &self,
 436        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 437        cx: &App,
 438    ) -> Option<StoredEvent> {
 439        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 440        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 441
 442        let in_open_source_repo =
 443            [self.new_file.as_ref(), self.old_file.as_ref()]
 444                .iter()
 445                .all(|file| {
 446                    file.is_some_and(|file| {
 447                        license_detection_watchers
 448                            .get(&file.worktree_id(cx))
 449                            .is_some_and(|watcher| watcher.is_project_open_source())
 450                    })
 451                });
 452
 453        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 454
 455        if path == old_path && diff.is_empty() {
 456            None
 457        } else {
 458            Some(StoredEvent {
 459                event: Arc::new(zeta_prompt::Event::BufferChange {
 460                    old_path,
 461                    path,
 462                    diff,
 463                    in_open_source_repo,
 464                    // TODO: Actually detect if this edit was predicted or not
 465                    predicted: false,
 466                }),
 467                old_snapshot: self.old_snapshot.clone(),
 468            })
 469        }
 470    }
 471
 472    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 473        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 474            return (self.clone(), None);
 475        };
 476
 477        let before = LastEvent {
 478            old_snapshot: self.old_snapshot.clone(),
 479            new_snapshot: boundary_snapshot.clone(),
 480            old_file: self.old_file.clone(),
 481            new_file: self.new_file.clone(),
 482            edit_range: None,
 483            snapshot_after_last_editing_pause: None,
 484            last_edit_time: self.last_edit_time,
 485        };
 486
 487        let after = LastEvent {
 488            old_snapshot: boundary_snapshot.clone(),
 489            new_snapshot: self.new_snapshot.clone(),
 490            old_file: self.old_file.clone(),
 491            new_file: self.new_file.clone(),
 492            edit_range: None,
 493            snapshot_after_last_editing_pause: None,
 494            last_edit_time: self.last_edit_time,
 495        };
 496
 497        (before, Some(after))
 498    }
 499}
 500
/// Renders the changes between `old_snapshot` and `new_snapshot` as a unified diff.
///
/// Returns `None` when `new_snapshot` has no edits since `old_snapshot`'s version.
/// Otherwise one region per snapshot is diffed. That region runs from the first edit to the
/// last edit, so unchanged text between distant edits is included, and is padded with
/// surrounding rows. The starting row of each region is passed to
/// `language::unified_diff_with_offsets` (presumably to offset the hunk line numbers).
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits -> no diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Up to CONTEXT_LINES rows before the first edited row.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // NOTE(review): combined with the `+ 1` on the end offsets below, the region reaches
    // CONTEXT_LINES + 1 rows past the last edited row (one more than before the first edit).
    // Confirm whether the extra row is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // The region ends at the start of the row after `*_context_end_row`, clamped to the
    // end of the buffer.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}
 546
 547fn buffer_path_with_id_fallback(
 548    file: Option<&Arc<dyn File>>,
 549    snapshot: &TextBufferSnapshot,
 550    cx: &App,
 551) -> Arc<Path> {
 552    if let Some(file) = file {
 553        file.full_path(cx).into()
 554    } else {
 555        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 556    }
 557}
 558
 559impl EditPredictionStore {
 560    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 561        cx.try_global::<EditPredictionStoreGlobal>()
 562            .map(|global| global.0.clone())
 563    }
 564
 565    pub fn global(
 566        client: &Arc<Client>,
 567        user_store: &Entity<UserStore>,
 568        cx: &mut App,
 569    ) -> Entity<Self> {
 570        cx.try_global::<EditPredictionStoreGlobal>()
 571            .map(|global| global.0.clone())
 572            .unwrap_or_else(|| {
 573                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 574                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 575                ep_store
 576            })
 577    }
 578
    /// Creates the store.
    ///
    /// Besides initializing its fields, this:
    /// - spawns a detached background task that runs `Self::handle_rejected_predictions` on
    ///   everything sent through `reject_predictions_tx`;
    /// - refreshes `llm_token` whenever the global [`RefreshLlmTokenListener`] emits;
    /// - sets `custom_predict_edits_url` from `ZED_PREDICT_EDITS_URL`, falling back to a
    ///   local Ollama endpoint when `USE_OLLAMA` is enabled;
    /// - configures context retrieval now, again when feature flags are ready, and whenever
    ///   the global [`SettingsStore`] changes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejected predictions are forwarded to a background task that owns the receiver.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    // Refresh failures are logged, not propagated.
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Differs from `EditPredictionModel::default()` (`Zeta1`).
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // An unparsable `ZED_PREDICT_EDITS_URL` is logged and yields `None`. The Ollama
            // fallback applies only when the variable is unset or not valid Unicode.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        // Hard-coded, well-formed URL; parsing cannot fail.
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        // Re-run once feature flags are known, since they may affect the configuration.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 664
 665    #[cfg(test)]
 666    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 667        self.custom_predict_edits_url = Some(url.into());
 668    }
 669
    /// Selects the model used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 673
 674    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 675        self.sweep_ai.api_token.read(cx).has_key()
 676    }
 677
 678    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 679        self.mercury.api_token.read(cx).has_key()
 680    }
 681
    /// Installs an evaluation cache (only with the `cli-support` feature).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 686
    /// Current Zeta options (initially [`DEFAULT_OPTIONS`]).
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 690
    /// Replaces the Zeta options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 694
    /// Sets the `use_context` flag (initially `false`).
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 698
 699    pub fn clear_history(&mut self) {
 700        for project_state in self.projects.values_mut() {
 701            project_state.events.clear();
 702            project_state.last_event.take();
 703        }
 704    }
 705
 706    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 707        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 708            project_state.events.clear();
 709            project_state.last_event.take();
 710        }
 711    }
 712
 713    pub fn edit_history_for_project(
 714        &self,
 715        project: &Entity<Project>,
 716        cx: &App,
 717    ) -> Vec<StoredEvent> {
 718        self.projects
 719            .get(&project.entity_id())
 720            .map(|project_state| project_state.events(cx))
 721            .unwrap_or_default()
 722    }
 723
 724    pub fn edit_history_for_project_with_pause_split_last_event(
 725        &self,
 726        project: &Entity<Project>,
 727        cx: &App,
 728    ) -> Vec<StoredEvent> {
 729        self.projects
 730            .get(&project.entity_id())
 731            .map(|project_state| project_state.events_split_by_pause(cx))
 732            .unwrap_or_default()
 733    }
 734
 735    pub fn context_for_project<'a>(
 736        &'a self,
 737        project: &Entity<Project>,
 738        cx: &'a App,
 739    ) -> Arc<[RelatedFile]> {
 740        self.projects
 741            .get(&project.entity_id())
 742            .map(|project| project.context.read(cx).related_files())
 743            .unwrap_or_else(|| vec![].into())
 744    }
 745
 746    pub fn context_for_project_with_buffers<'a>(
 747        &'a self,
 748        project: &Entity<Project>,
 749        cx: &'a App,
 750    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 751        self.projects
 752            .get(&project.entity_id())
 753            .map(|project| project.context.read(cx).related_files_with_buffers())
 754    }
 755
 756    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 757        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 758            self.user_store.read(cx).edit_prediction_usage()
 759        } else {
 760            None
 761        }
 762    }
 763
 764    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 765        self.get_or_init_project(project, cx);
 766    }
 767
 768    pub fn register_buffer(
 769        &mut self,
 770        buffer: &Entity<Buffer>,
 771        project: &Entity<Project>,
 772        cx: &mut Context<Self>,
 773    ) {
 774        let project_state = self.get_or_init_project(project, cx);
 775        Self::register_buffer_impl(project_state, buffer, project, cx);
 776    }
 777
 778    fn get_or_init_project(
 779        &mut self,
 780        project: &Entity<Project>,
 781        cx: &mut Context<Self>,
 782    ) -> &mut ProjectState {
 783        let entity_id = project.entity_id();
 784        self.projects
 785            .entry(entity_id)
 786            .or_insert_with(|| ProjectState {
 787                context: {
 788                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 789                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 790                        this.handle_excerpt_store_event(entity_id, event);
 791                    })
 792                    .detach();
 793                    related_excerpt_store
 794                },
 795                events: VecDeque::new(),
 796                last_event: None,
 797                recent_paths: VecDeque::new(),
 798                debug_tx: None,
 799                registered_buffers: HashMap::default(),
 800                current_prediction: None,
 801                cancelled_predictions: HashSet::default(),
 802                pending_predictions: ArrayVec::new(),
 803                next_pending_prediction_id: 0,
 804                last_prediction_refresh: None,
 805                license_detection_watchers: HashMap::default(),
 806                _subscription: cx.subscribe(&project, Self::handle_project_event),
 807            })
 808    }
 809
 810    pub fn remove_project(&mut self, project: &Entity<Project>) {
 811        self.projects.remove(&project.entity_id());
 812    }
 813
 814    fn handle_excerpt_store_event(
 815        &mut self,
 816        project_entity_id: EntityId,
 817        event: &RelatedExcerptStoreEvent,
 818    ) {
 819        if let Some(project_state) = self.projects.get(&project_entity_id) {
 820            if let Some(debug_tx) = project_state.debug_tx.clone() {
 821                match event {
 822                    RelatedExcerptStoreEvent::StartedRefresh => {
 823                        debug_tx
 824                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 825                                ContextRetrievalStartedDebugEvent {
 826                                    project_entity_id: project_entity_id,
 827                                    timestamp: Instant::now(),
 828                                    search_prompt: String::new(),
 829                                },
 830                            ))
 831                            .ok();
 832                    }
 833                    RelatedExcerptStoreEvent::FinishedRefresh {
 834                        cache_hit_count,
 835                        cache_miss_count,
 836                        mean_definition_latency,
 837                        max_definition_latency,
 838                    } => {
 839                        debug_tx
 840                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 841                                ContextRetrievalFinishedDebugEvent {
 842                                    project_entity_id: project_entity_id,
 843                                    timestamp: Instant::now(),
 844                                    metadata: vec![
 845                                        (
 846                                            "Cache Hits",
 847                                            format!(
 848                                                "{}/{}",
 849                                                cache_hit_count,
 850                                                cache_hit_count + cache_miss_count
 851                                            )
 852                                            .into(),
 853                                        ),
 854                                        (
 855                                            "Max LSP Time",
 856                                            format!("{} ms", max_definition_latency.as_millis())
 857                                                .into(),
 858                                        ),
 859                                        (
 860                                            "Mean LSP Time",
 861                                            format!("{} ms", mean_definition_latency.as_millis())
 862                                                .into(),
 863                                        ),
 864                                    ],
 865                                },
 866                            ))
 867                            .ok();
 868                    }
 869                }
 870            }
 871        }
 872    }
 873
 874    pub fn debug_info(
 875        &mut self,
 876        project: &Entity<Project>,
 877        cx: &mut Context<Self>,
 878    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 879        let project_state = self.get_or_init_project(project, cx);
 880        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 881        project_state.debug_tx = Some(debug_watch_tx);
 882        debug_watch_rx
 883    }
 884
 885    fn handle_project_event(
 886        &mut self,
 887        project: Entity<Project>,
 888        event: &project::Event,
 889        cx: &mut Context<Self>,
 890    ) {
 891        // TODO [zeta2] init with recent paths
 892        match event {
 893            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 894                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 895                    return;
 896                };
 897                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 898                if let Some(path) = path {
 899                    if let Some(ix) = project_state
 900                        .recent_paths
 901                        .iter()
 902                        .position(|probe| probe == &path)
 903                    {
 904                        project_state.recent_paths.remove(ix);
 905                    }
 906                    project_state.recent_paths.push_front(path);
 907                }
 908            }
 909            project::Event::DiagnosticsUpdated { .. } => {
 910                if cx.has_flag::<Zeta2FeatureFlag>() {
 911                    self.refresh_prediction_from_diagnostics(project, cx);
 912                }
 913            }
 914            _ => (),
 915        }
 916    }
 917
    /// Ensures `buffer` is tracked in `project_state.registered_buffers` and returns its
    /// entry.
    ///
    /// If the buffer has a file whose worktree is still present in `project`, a
    /// `LicenseDetectionWatcher` is created for that worktree the first time it is seen.
    /// The watcher is removed from the project state when the worktree is released.
    ///
    /// On first registration, the buffer's current text snapshot and file are recorded,
    /// together with two subscriptions:
    /// - one that forwards `BufferEvent::Edited` to `report_changes_for_buffer`. It holds
    ///   only a weak handle to the project and does nothing once the project is gone.
    /// - one that drops the entry when the buffer is released.
    ///
    /// An already-registered buffer's entry is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // One license watcher per worktree, created lazily for buffers backed by a file.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher together with its worktree.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                // Baseline against which later edits are diffed.
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Record edits; a weak project handle avoids keeping the project alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 983
    /// Records an edit to `buffer` in the project's edit history.
    ///
    /// Registers the buffer if needed. Returns early when the buffer's text version has
    /// not changed since the last recorded snapshot. Otherwise the registered
    /// snapshot/file are replaced with the current ones, and the change is handled in one
    /// of two ways:
    /// - It is coalesced into `last_event` when it directly continues that event's buffer
    ///   snapshot (same buffer id and version) and its edited range overlaps, or lies
    ///   within `CHANGE_GROUPING_LINE_SPAN` rows of, the previous edited range. Both ranges
    ///   must be present for coalescing.
    /// - Otherwise it starts a new `last_event`. The previous `last_event` (if any) is
    ///   finalized first and, when finalization yields an event, appended to `events`,
    ///   evicting the oldest entry when the queue is full.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we already recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let edit_range = edited_range(&old_snapshot, &new_snapshot);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit picks up exactly where the last event's snapshot left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Both ranges are converted with the new snapshot. Overlapping ranges always
            // coalesce; disjoint ranges only when their nearest rows are close enough.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit,
                // remember the snapshot from before this edit as the point of the last
                // editing pause within the coalesced event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Close out the previous event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps at most EVENT_COUNT_MAX - 1 finalized events in `events`
                // (presumably leaving room for `last_event` — TODO confirm).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1063
1064    fn prediction_at(
1065        &mut self,
1066        buffer: &Entity<Buffer>,
1067        position: Option<language::Anchor>,
1068        project: &Entity<Project>,
1069        cx: &App,
1070    ) -> Option<BufferEditPrediction<'_>> {
1071        let project_state = self.projects.get_mut(&project.entity_id())?;
1072        if let Some(position) = position
1073            && let Some(buffer) = project_state
1074                .registered_buffers
1075                .get_mut(&buffer.entity_id())
1076        {
1077            buffer.last_position = Some(position);
1078        }
1079
1080        let CurrentEditPrediction {
1081            requested_by,
1082            prediction,
1083            ..
1084        } = project_state.current_prediction.as_ref()?;
1085
1086        if prediction.targets_buffer(buffer.read(cx)) {
1087            Some(BufferEditPrediction::Local { prediction })
1088        } else {
1089            let show_jump = match requested_by {
1090                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1091                    requested_by_buffer_id == &buffer.entity_id()
1092                }
1093                PredictionRequestedBy::DiagnosticsUpdate => true,
1094            };
1095
1096            if show_jump {
1097                Some(BufferEditPrediction::Jump { prediction })
1098            } else {
1099                None
1100            }
1101        }
1102    }
1103
1104    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1105        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1106            return;
1107        };
1108
1109        let Some(current_prediction) = project_state.current_prediction.take() else {
1110            return;
1111        };
1112
1113        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1114            project_state.cancel_pending_prediction(pending_prediction, cx);
1115        }
1116
1117        match self.edit_prediction_model {
1118            EditPredictionModel::Sweep => {
1119                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1120            }
1121            EditPredictionModel::Mercury => {}
1122            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1123                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1124            }
1125        }
1126    }
1127
    /// Background loop that batches prediction rejections from `rx` and posts them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections are
    /// buffered, the loop waits up to `REJECT_REQUEST_DEBOUNCE` after each received
    /// rejection. If another one arrives first, it joins the batch. If the channel closes
    /// during the wait, the batch is flushed right away.
    ///
    /// Each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    /// recently buffered rejections. They are removed from the buffer only when the request
    /// succeeds, so failed ones are retried on a later flush. The loop ends once the
    /// channel is closed and drained.
    ///
    /// Panics if the rejection URL cannot be built (`unwrap` below).
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // Peekable so the debounce can detect a waiting rejection without consuming it.
        let mut rx = std::pin::pin!(rx.peekable());
        // NOTE(review): this only shrinks on successful requests, so it can grow without
        // bound while requests keep failing.
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Small batch: debounce. The peek branch is polled first (biased); if another
            // rejection is ready, loop around to receive it instead of flushing.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Send only the newest rejections, capped at the per-request maximum.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Errors are logged; the unsent rejections stay buffered for a retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1186
1187    fn reject_current_prediction(
1188        &mut self,
1189        reason: EditPredictionRejectReason,
1190        project: &Entity<Project>,
1191    ) {
1192        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1193            project_state.pending_predictions.clear();
1194            if let Some(prediction) = project_state.current_prediction.take() {
1195                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1196            }
1197        };
1198    }
1199
1200    fn did_show_current_prediction(
1201        &mut self,
1202        project: &Entity<Project>,
1203        display_type: edit_prediction_types::SuggestionDisplayType,
1204        cx: &mut Context<Self>,
1205    ) {
1206        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1207            return;
1208        };
1209
1210        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1211            return;
1212        };
1213
1214        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1215        let previous_shown_with = current_prediction.shown_with;
1216
1217        if previous_shown_with.is_none() || !is_jump {
1218            current_prediction.shown_with = Some(display_type);
1219        }
1220
1221        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1222
1223        if is_first_non_jump_show {
1224            current_prediction.was_shown = true;
1225        }
1226
1227        let display_type_changed = previous_shown_with != Some(display_type);
1228
1229        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1230            sweep_ai::edit_prediction_shown(
1231                &self.sweep_ai,
1232                self.client.clone(),
1233                &current_prediction.prediction,
1234                display_type,
1235                cx,
1236            );
1237        }
1238
1239        if is_first_non_jump_show {
1240            self.shown_predictions
1241                .push_front(current_prediction.prediction.clone());
1242            if self.shown_predictions.len() > 50 {
1243                let completion = self.shown_predictions.pop_back().unwrap();
1244                self.rated_predictions.remove(&completion.id);
1245            }
1246        }
1247    }
1248
1249    fn reject_prediction(
1250        &mut self,
1251        prediction_id: EditPredictionId,
1252        reason: EditPredictionRejectReason,
1253        was_shown: bool,
1254    ) {
1255        match self.edit_prediction_model {
1256            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1257                if self.custom_predict_edits_url.is_some() {
1258                    return;
1259                }
1260            }
1261            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1262        }
1263
1264        self.reject_predictions_tx
1265            .unbounded_send(EditPredictionRejection {
1266                request_id: prediction_id.to_string(),
1267                reason,
1268                was_shown,
1269            })
1270            .log_err();
1271    }
1272
1273    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1274        self.projects
1275            .get(&project.entity_id())
1276            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1277    }
1278
1279    pub fn refresh_prediction_from_buffer(
1280        &mut self,
1281        project: Entity<Project>,
1282        buffer: Entity<Buffer>,
1283        position: language::Anchor,
1284        cx: &mut Context<Self>,
1285    ) {
1286        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1287            let Some(request_task) = this
1288                .update(cx, |this, cx| {
1289                    this.request_prediction(
1290                        &project,
1291                        &buffer,
1292                        position,
1293                        PredictEditsRequestTrigger::Other,
1294                        cx,
1295                    )
1296                })
1297                .log_err()
1298            else {
1299                return Task::ready(anyhow::Ok(None));
1300            };
1301
1302            cx.spawn(async move |_cx| {
1303                request_task.await.map(|prediction_result| {
1304                    prediction_result.map(|prediction_result| {
1305                        (
1306                            prediction_result,
1307                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1308                        )
1309                    })
1310                })
1311            })
1312        })
1313    }
1314
1315    pub fn refresh_prediction_from_diagnostics(
1316        &mut self,
1317        project: Entity<Project>,
1318        cx: &mut Context<Self>,
1319    ) {
1320        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1321            return;
1322        };
1323
1324        // Prefer predictions from buffer
1325        if project_state.current_prediction.is_some() {
1326            return;
1327        };
1328
1329        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1330            let Some((active_buffer, snapshot, cursor_point)) = this
1331                .read_with(cx, |this, cx| {
1332                    let project_state = this.projects.get(&project.entity_id())?;
1333                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1334                    let snapshot = buffer.read(cx).snapshot();
1335
1336                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1337                        return None;
1338                    }
1339
1340                    let cursor_point = position
1341                        .map(|pos| pos.to_point(&snapshot))
1342                        .unwrap_or_default();
1343
1344                    Some((buffer, snapshot, cursor_point))
1345                })
1346                .log_err()
1347                .flatten()
1348            else {
1349                return Task::ready(anyhow::Ok(None));
1350            };
1351
1352            cx.spawn(async move |cx| {
1353                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1354                    active_buffer,
1355                    &snapshot,
1356                    Default::default(),
1357                    cursor_point,
1358                    &project,
1359                    cx,
1360                )
1361                .await?
1362                else {
1363                    return anyhow::Ok(None);
1364                };
1365
1366                let Some(prediction_result) = this
1367                    .update(cx, |this, cx| {
1368                        this.request_prediction(
1369                            &project,
1370                            &jump_buffer,
1371                            jump_position,
1372                            PredictEditsRequestTrigger::Diagnostics,
1373                            cx,
1374                        )
1375                    })?
1376                    .await?
1377                else {
1378                    return anyhow::Ok(None);
1379                };
1380
1381                this.update(cx, |this, cx| {
1382                    Some((
1383                        if this
1384                            .get_or_init_project(&project, cx)
1385                            .current_prediction
1386                            .is_none()
1387                        {
1388                            prediction_result
1389                        } else {
1390                            EditPredictionResult {
1391                                id: prediction_result.id,
1392                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1393                            }
1394                        },
1395                        PredictionRequestedBy::DiagnosticsUpdate,
1396                    ))
1397                })
1398            })
1399        });
1400    }
1401
1402    fn predictions_enabled_at(
1403        snapshot: &BufferSnapshot,
1404        position: Option<language::Anchor>,
1405        cx: &App,
1406    ) -> bool {
1407        let file = snapshot.file();
1408        let all_settings = all_language_settings(file, cx);
1409        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1410            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1411        {
1412            return false;
1413        }
1414
1415        if let Some(last_position) = position {
1416            let settings = snapshot.settings_at(last_position, cx);
1417
1418            if !settings.edit_predictions_disabled_in.is_empty()
1419                && let Some(scope) = snapshot.language_scope_at(last_position)
1420                && let Some(scope_name) = scope.override_name()
1421                && settings
1422                    .edit_predictions_disabled_in
1423                    .iter()
1424                    .any(|s| s == scope_name)
1425            {
1426                return false;
1427            }
1428        }
1429
1430        true
1431    }
1432
1433    #[cfg(not(test))]
1434    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1435    #[cfg(test)]
1436    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1437
    /// Schedules a prediction refresh for `project`, throttled per `throttle_entity`.
    ///
    /// The spawned task:
    /// 1. If the last recorded refresh was for the same `throttle_entity`, waits until
    ///    `THROTTLE_TIMEOUT` has elapsed since it.
    /// 2. Aborts (yielding `None`) if this request was marked cancelled in the meantime,
    ///    or if the store entity is gone. Otherwise it records this refresh as the latest
    ///    one and runs `do_refresh`.
    /// 3. On completion, if not cancelled:
    ///    - installs a successful result as the current prediction when there is none;
    ///    - if there is one, installs it only when `should_replace_prediction` says so,
    ///      rejecting the loser;
    ///    - reports a failed result as rejected with its reason.
    /// 4. Removes its own entry from `pending_predictions` and cancels the pending
    ///    predictions queued before it.
    ///
    /// The task yields the id of the result produced by `do_refresh`, if any. This
    /// function never grows `pending_predictions` beyond two entries. When two are
    /// already pending, the newer of them is cancelled and replaced by this one.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only against the previous refresh of the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // Also treated as cancelled if the store entity can no longer be updated.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // A cancellation may also have been recorded while the request was in flight.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // NOTE(review): reject_current_prediction also clears
                                    // pending_predictions, so this task's own entry (and any
                                    // newer pending one) is gone before the cleanup below —
                                    // confirm this is intended.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-borrow: the rejection calls above needed `this` mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this task's entry and cancel everything queued before it. If the
                // entry is no longer present, the list is put back unchanged.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Track the new task. With two already pending, the newest one is replaced and
        // cancelled.
        // NOTE(review): if more than two were ever pending, `task` would be dropped here
        // untracked; the visible code never lets that happen.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1578
1579    pub fn request_prediction(
1580        &mut self,
1581        project: &Entity<Project>,
1582        active_buffer: &Entity<Buffer>,
1583        position: language::Anchor,
1584        trigger: PredictEditsRequestTrigger,
1585        cx: &mut Context<Self>,
1586    ) -> Task<Result<Option<EditPredictionResult>>> {
1587        self.request_prediction_internal(
1588            project.clone(),
1589            active_buffer.clone(),
1590            position,
1591            trigger,
1592            cx.has_flag::<Zeta2FeatureFlag>(),
1593            cx,
1594        )
1595    }
1596
    /// Requests an edit prediction from the configured model at `position` in
    /// `active_buffer`.
    ///
    /// Assembles the model input (recent edit events, related files when
    /// `use_context` is set, a diagnostic search window around the cursor) and
    /// dispatches to the active `edit_prediction_model`. May also spawn a
    /// sampled example capture for telemetry when data collection permits it.
    ///
    /// If the model returns no prediction, `allow_jump` is true and the project
    /// has recorded edit events, a diagnostic outside the search window is
    /// looked up via `next_diagnostic_location`. The request is then retried
    /// once at that location with jumping disabled.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor that form the diagnostic search window.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        // The call above is expected to have created this project's entry.
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        // Examples may only be collected when both the file and every recorded
        // event are eligible for data collection.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                cx,
            ) {
                // Capture runs in the background; failures are only logged.
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry at the jump target with `allow_jump = false` so at
                    // most one jump is attempted per request.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1708
    /// Picks a location to jump to after no prediction was produced at the cursor.
    ///
    /// First considers the primary entry of each diagnostic group in the active
    /// buffer. Groups whose range overlaps `active_buffer_diagnostic_search_range`
    /// (the window already sent with the request) are excluded. Among the rest,
    /// the one whose start row is closest to the cursor row wins.
    ///
    /// If there is no such diagnostic, picks a different project path reported
    /// by `diagnostic_summaries`. The chosen path is the one sharing the longest
    /// run of leading path components with the active buffer's path. That
    /// buffer is opened and the start of its diagnostic with the minimum
    /// `severity` value is returned (presumably the most severe one; confirm
    /// the ordering of the severity type).
    ///
    /// Returns `Ok(None)` when no candidate is found.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1787
    /// Sends an OpenAI-format `request` to Zed's `/predict_edits/raw` LLM
    /// endpoint (authentication required).
    ///
    /// Returns the parsed response and any usage info parsed from the response
    /// headers.
    ///
    /// With the `cli-support` feature and an `eval_cache`, responses are cached.
    /// The key combines `eval_cache_kind` with an FxHash of the URL and the
    /// pretty-printed request JSON. A cache hit returns without a network call
    /// and with `None` usage.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/raw", &[])?;

        // Compute the cache key up front so a hit can short-circuit the request;
        // the key and request text are kept for writing the response back later.
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1842
1843    fn handle_api_response<T>(
1844        this: &WeakEntity<Self>,
1845        response: Result<(T, Option<EditPredictionUsage>)>,
1846        cx: &mut gpui::AsyncApp,
1847    ) -> Result<T> {
1848        match response {
1849            Ok((data, usage)) => {
1850                if let Some(usage) = usage {
1851                    this.update(cx, |this, cx| {
1852                        this.user_store.update(cx, |user_store, cx| {
1853                            user_store.update_edit_prediction_usage(usage, cx);
1854                        });
1855                    })
1856                    .ok();
1857                }
1858                Ok(data)
1859            }
1860            Err(err) => {
1861                if err.is::<ZedUpdateRequiredError>() {
1862                    cx.update(|cx| {
1863                        this.update(cx, |this, _cx| {
1864                            this.update_required = true;
1865                        })
1866                        .ok();
1867
1868                        let error_message: SharedString = err.to_string().into();
1869                        show_app_notification(
1870                            NotificationId::unique::<ZedUpdateRequiredError>(),
1871                            cx,
1872                            move |cx| {
1873                                cx.new(|cx| {
1874                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1875                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1876                                })
1877                            },
1878                        );
1879                    })
1880                    .ok();
1881                }
1882                Err(err)
1883            }
1884        }
1885    }
1886
1887    async fn send_api_request<Res>(
1888        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1889        client: Arc<Client>,
1890        llm_token: LlmApiToken,
1891        app_version: Version,
1892        require_auth: bool,
1893    ) -> Result<(Res, Option<EditPredictionUsage>)>
1894    where
1895        Res: DeserializeOwned,
1896    {
1897        let http_client = client.http_client();
1898
1899        let mut token = if require_auth {
1900            Some(llm_token.acquire(&client).await?)
1901        } else {
1902            llm_token.acquire(&client).await.ok()
1903        };
1904        let mut did_retry = false;
1905
1906        loop {
1907            let request_builder = http_client::Request::builder().method(Method::POST);
1908
1909            let mut request_builder = request_builder
1910                .header("Content-Type", "application/json")
1911                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1912
1913            // Only add Authorization header if we have a token
1914            if let Some(ref token_value) = token {
1915                request_builder =
1916                    request_builder.header("Authorization", format!("Bearer {}", token_value));
1917            }
1918
1919            let request = build(request_builder)?;
1920
1921            let mut response = http_client.send(request).await?;
1922
1923            if let Some(minimum_required_version) = response
1924                .headers()
1925                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1926                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1927            {
1928                anyhow::ensure!(
1929                    app_version >= minimum_required_version,
1930                    ZedUpdateRequiredError {
1931                        minimum_version: minimum_required_version
1932                    }
1933                );
1934            }
1935
1936            if response.status().is_success() {
1937                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1938
1939                let mut body = Vec::new();
1940                response.body_mut().read_to_end(&mut body).await?;
1941                return Ok((serde_json::from_slice(&body)?, usage));
1942            } else if !did_retry
1943                && token.is_some()
1944                && response
1945                    .headers()
1946                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1947                    .is_some()
1948            {
1949                did_retry = true;
1950                token = Some(llm_token.refresh(&client).await?);
1951            } else {
1952                let mut body = String::new();
1953                response.body_mut().read_to_string(&mut body).await?;
1954                anyhow::bail!(
1955                    "Request failed with status: {:?}\nBody: {}",
1956                    response.status(),
1957                    body
1958                );
1959            }
1960        }
1961    }
1962
1963    pub fn refresh_context(
1964        &mut self,
1965        project: &Entity<Project>,
1966        buffer: &Entity<language::Buffer>,
1967        cursor_position: language::Anchor,
1968        cx: &mut Context<Self>,
1969    ) {
1970        if self.use_context {
1971            self.get_or_init_project(project, cx)
1972                .context
1973                .update(cx, |store, cx| {
1974                    store.refresh(buffer.clone(), cursor_position, cx);
1975                });
1976        }
1977    }
1978
1979    #[cfg(feature = "cli-support")]
1980    pub fn set_context_for_buffer(
1981        &mut self,
1982        project: &Entity<Project>,
1983        related_files: Vec<RelatedFile>,
1984        cx: &mut Context<Self>,
1985    ) {
1986        self.get_or_init_project(project, cx)
1987            .context
1988            .update(cx, |store, _| {
1989                store.set_related_files(related_files);
1990            });
1991    }
1992
1993    fn is_file_open_source(
1994        &self,
1995        project: &Entity<Project>,
1996        file: &Arc<dyn File>,
1997        cx: &App,
1998    ) -> bool {
1999        if !file.is_local() || file.is_private() {
2000            return false;
2001        }
2002        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2003            return false;
2004        };
2005        project_state
2006            .license_detection_watchers
2007            .get(&file.worktree_id(cx))
2008            .as_ref()
2009            .is_some_and(|watcher| watcher.is_project_open_source())
2010    }
2011
2012    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2013        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2014    }
2015
2016    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2017        if !self.data_collection_choice.is_enabled(cx) {
2018            return false;
2019        }
2020        events.iter().all(|event| {
2021            matches!(
2022                event.as_ref(),
2023                zeta_prompt::Event::BufferChange {
2024                    in_open_source_repo: true,
2025                    ..
2026                }
2027            )
2028        })
2029    }
2030
2031    fn load_data_collection_choice() -> DataCollectionChoice {
2032        let choice = KEY_VALUE_STORE
2033            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2034            .log_err()
2035            .flatten();
2036
2037        match choice.as_deref() {
2038            Some("true") => DataCollectionChoice::Enabled,
2039            Some("false") => DataCollectionChoice::Disabled,
2040            Some(_) => {
2041                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2042                DataCollectionChoice::NotAnswered
2043            }
2044            None => DataCollectionChoice::NotAnswered,
2045        }
2046    }
2047
2048    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2049        self.data_collection_choice = self.data_collection_choice.toggle();
2050        let new_choice = self.data_collection_choice;
2051        let is_enabled = new_choice.is_enabled(cx);
2052        db::write_and_log(cx, move || {
2053            KEY_VALUE_STORE.write_kvp(
2054                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2055                is_enabled.to_string(),
2056            )
2057        });
2058    }
2059
2060    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2061        self.shown_predictions.iter()
2062    }
2063
2064    pub fn shown_completions_len(&self) -> usize {
2065        self.shown_predictions.len()
2066    }
2067
2068    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2069        self.rated_predictions.contains(id)
2070    }
2071
    /// Records a user rating for `prediction`.
    ///
    /// Marks the prediction as rated and emits an "Edit Prediction Rated"
    /// telemetry event. The event carries the rating, the prediction inputs, a
    /// unified diff of the predicted edits, and the free-form `feedback`.
    /// Telemetry is then flushed and observers are notified.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        // Property names for `rating` and `feedback` come from the identifiers
        // themselves, so renaming these locals would change the event schema.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
            feedback
        );
        // Send immediately rather than waiting for the next batch.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2092
2093    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2094        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2095            && all_language_settings(None, cx).edit_predictions.use_context;
2096    }
2097}
2098
2099fn edited_range(
2100    old_snapshot: &TextBufferSnapshot,
2101    new_snapshot: &TextBufferSnapshot,
2102) -> Option<Range<Anchor>> {
2103    new_snapshot
2104        .anchored_edits_since::<usize>(&old_snapshot.version)
2105        .fold(None, |acc, (_, range)| {
2106            Some(match acc {
2107                None => range,
2108                Some(acc) => acc.start..range.end,
2109            })
2110        })
2111}
2112
/// Error returned when the server's minimum-required-version header names a
/// version newer than the running Zed.
///
/// `handle_api_response` detects it and prompts the user to update.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest client version the server still accepts.
    minimum_version: Version,
}
2120
/// Evaluation-cache key: the entry kind plus a 64-bit hash of the request
/// (in `send_raw_llm_request`, the URL and pretty-printed request JSON).
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2123
/// Category of an evaluation-cache entry; the first component of [`EvalCacheKey`].
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2131
/// Formats the kind as its lowercase name (`"context"`, `"search"`, `"prediction"`).
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2142
2143#[cfg(feature = "cli-support")]
2144pub trait EvalCache: Send + Sync {
2145    fn read(&self, key: EvalCacheKey) -> Option<String>;
2146    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
2147}
2148
/// The user's choice about edit-prediction data collection, persisted in the
/// key-value store as `"true"`/`"false"`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice stored yet, or the stored value was unrecognized.
    NotAnswered,
    Enabled,
    Disabled,
}
2155
2156impl DataCollectionChoice {
2157    pub fn is_enabled(self, cx: &App) -> bool {
2158        if cx.is_staff() {
2159            return true;
2160        }
2161        match self {
2162            Self::Enabled => true,
2163            Self::NotAnswered | Self::Disabled => false,
2164        }
2165    }
2166
2167    #[must_use]
2168    pub fn toggle(&self) -> DataCollectionChoice {
2169        match self {
2170            Self::Enabled => Self::Disabled,
2171            Self::Disabled => Self::Enabled,
2172            Self::NotAnswered => Self::Enabled,
2173        }
2174    }
2175}
2176
2177impl From<bool> for DataCollectionChoice {
2178    fn from(value: bool) -> Self {
2179        match value {
2180            true => DataCollectionChoice::Enabled,
2181            false => DataCollectionChoice::Disabled,
2182        }
2183    }
2184}
2185
2186struct ZedPredictUpsell;
2187
2188impl Dismissable for ZedPredictUpsell {
2189    const KEY: &'static str = "dismissed-edit-predict-upsell";
2190
2191    fn dismissed() -> bool {
2192        // To make this backwards compatible with older versions of Zed, we
2193        // check if the user has seen the previous Edit Prediction Onboarding
2194        // before, by checking the data collection choice which was written to
2195        // the database once the user clicked on "Accept and Enable"
2196        if KEY_VALUE_STORE
2197            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2198            .log_err()
2199            .is_some_and(|s| s.is_some())
2200        {
2201            return true;
2202        }
2203
2204        KEY_VALUE_STORE
2205            .read_kvp(Self::KEY)
2206            .log_err()
2207            .is_some_and(|s| s.is_some())
2208    }
2209}
2210
2211pub fn should_show_upsell_modal() -> bool {
2212    !ZedPredictUpsell::dismissed()
2213}
2214
/// Registers edit-prediction workspace actions on every new workspace.
///
/// - `OpenZedPredictOnboarding` toggles the Zed Predict onboarding modal.
/// - `ResetOnboarding` writes `EditPredictionProvider::None` as the edit
///   prediction provider in the settings file.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}