edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::http_client::Url;
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::language_settings::all_language_settings;
  29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
  30use language::{BufferSnapshot, OffsetRangeExt};
  31use language_model::{LlmApiToken, RefreshLlmTokenListener};
  32use project::{Project, ProjectPath, WorktreeId};
  33use release_channel::AppVersion;
  34use semver::Version;
  35use serde::de::DeserializeOwned;
  36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  37use std::collections::{VecDeque, hash_map};
  38use text::Edit;
  39use workspace::Workspace;
  40
  41use std::ops::Range;
  42use std::path::Path;
  43use std::rc::Rc;
  44use std::str::FromStr as _;
  45use std::sync::{Arc, LazyLock};
  46use std::time::{Duration, Instant};
  47use std::{env, mem};
  48use thiserror::Error;
  49use util::{RangeExt as _, ResultExt as _};
  50use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  51
  52pub mod cursor_excerpt;
  53pub mod example_spec;
  54mod license_detection;
  55pub mod mercury;
  56mod onboarding_modal;
  57pub mod open_ai_response;
  58mod prediction;
  59pub mod sweep_ai;
  60
  61pub mod udiff;
  62
  63mod capture_example;
  64mod zed_edit_prediction_delegate;
  65pub mod zeta1;
  66pub mod zeta2;
  67
  68#[cfg(test)]
  69mod edit_prediction_tests;
  70
  71use crate::license_detection::LicenseDetectionWatcher;
  72use crate::mercury::Mercury;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76use crate::prediction::EditPredictionResult;
  77pub use crate::sweep_ai::SweepAi;
  78pub use capture_example::capture_example;
  79pub use language_model::ApiKeyState;
  80pub use telemetry_events::EditPredictionRating;
  81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  82
// Declares the actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  92
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): presumably the row distance within which consecutive edits are
// grouped into a single event; the usage is outside this view, so confirm.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the editing pause (1s) used when grouping or splitting
// the most recent change; the usage is outside this view, so confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): looks like a persistence key for the user's data collection
// choice; confirm against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce window (15s) for batching rejection
// reports; confirm against `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  99
 100pub struct SweepFeatureFlag;
 101
 102impl FeatureFlag for SweepFeatureFlag {
 103    const NAME: &str = "sweep-ai";
 104}
 105
 106pub struct MercuryFeatureFlag;
 107
 108impl FeatureFlag for MercuryFeatureFlag {
 109    const NAME: &str = "mercury";
 110}
 111
/// Options a freshly constructed [`EditPredictionStore`] starts with.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // NOTE(review): presumably the fraction of the excerpt placed before the
        // cursor (0.5 = centered); confirm in `edit_prediction_context`.
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 120
 121static USE_OLLAMA: LazyLock<bool> =
 122    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 123
 124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 125    match env::var("ZED_ZETA2_MODEL").as_deref() {
 126        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 127        Ok(model) => model,
 128        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 129        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 130    }
 131    .to_string()
 132});
 133
/// Feature flag registered under the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// This flag reports itself as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
 143
/// Global wrapper holding the app-wide [`EditPredictionStore`] entity; see
/// `EditPredictionStore::global` / `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 148
/// Central state for edit predictions: per-project tracking plus the handles
/// needed to talk to the prediction backends.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the `RefreshLlmTokenListener` emits (see `new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sending side of the channel drained by the background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Set from `ZED_PREDICT_EDITS_URL`, or a localhost Ollama endpoint when
    // `USE_OLLAMA` is enabled (see `new`).
    custom_predict_edits_url: Option<Arc<Url>>,
}
 169
/// Selects which backend produces edit predictions.
///
/// `Zeta1` is the `Default` variant, although `EditPredictionStore::new`
/// explicitly starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 178
/// Bundle of inputs describing a single prediction request.
// NOTE(review): the consumers of this struct are outside this view; field
// semantics below are taken from the names and types only.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 191
 192#[derive(Debug, Clone, PartialEq)]
 193pub struct ZetaOptions {
 194    pub context: EditPredictionExcerptOptions,
 195    pub prompt_format: predict_edits_v3::PromptFormat,
 196}
 197
/// Events delivered over the debug channel returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 205
/// Payload sent when the related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}
 212
/// Payload sent when the related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    /// When the event was created.
    pub timestamp: Instant,
    /// Label/value pairs describing the refresh (cache hits, LSP latencies).
    pub metadata: Vec<(&'static str, SharedString)>,
}
 219
/// Payload describing the start of an edit prediction.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt, when one is available.
    pub prompt: Option<String>,
}
 226
/// Payload describing the end of an edit prediction.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The raw model output, when one is available.
    pub model_output: Option<String>,
}
 233
/// Alias for the debug info type defined in `predict_edits_v3`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 235
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded event.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the event's change was applied.
    pub old_snapshot: TextBufferSnapshot,
}
 242
 243struct ProjectState {
 244    events: VecDeque<StoredEvent>,
 245    last_event: Option<LastEvent>,
 246    recent_paths: VecDeque<ProjectPath>,
 247    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 248    current_prediction: Option<CurrentEditPrediction>,
 249    next_pending_prediction_id: usize,
 250    pending_predictions: ArrayVec<PendingPrediction, 2>,
 251    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 252    last_prediction_refresh: Option<(EntityId, Instant)>,
 253    cancelled_predictions: HashSet<usize>,
 254    context: Entity<RelatedExcerptStore>,
 255    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 256    _subscription: gpui::Subscription,
 257}
 258
 259impl ProjectState {
 260    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 261        self.events
 262            .iter()
 263            .cloned()
 264            .chain(
 265                self.last_event
 266                    .as_ref()
 267                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 268            )
 269            .collect()
 270    }
 271
 272    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 273        self.events
 274            .iter()
 275            .cloned()
 276            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 277                let (one, two) = event.split_by_pause();
 278                let one = one.finalize(&self.license_detection_watchers, cx);
 279                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 280                one.into_iter().chain(two)
 281            }))
 282            .collect()
 283    }
 284
 285    fn cancel_pending_prediction(
 286        &mut self,
 287        pending_prediction: PendingPrediction,
 288        cx: &mut Context<EditPredictionStore>,
 289    ) {
 290        self.cancelled_predictions.insert(pending_prediction.id);
 291
 292        cx.spawn(async move |this, cx| {
 293            let Some(prediction_id) = pending_prediction.task.await else {
 294                return;
 295            };
 296
 297            this.update(cx, |this, _cx| {
 298                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 299            })
 300            .ok();
 301        })
 302        .detach()
 303    }
 304
 305    fn active_buffer(
 306        &self,
 307        project: &Entity<Project>,
 308        cx: &App,
 309    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 310        let project = project.read(cx);
 311        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 312        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 313        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 314        Some((active_buffer, registered_buffer.last_position))
 315    }
 316}
 317
/// The prediction currently held for a project, along with what requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request for this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed;
    // the writer is outside this view.
    pub was_shown: bool,
}
 324
 325impl CurrentEditPrediction {
 326    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 327        let Some(new_edits) = self
 328            .prediction
 329            .interpolate(&self.prediction.buffer.read(cx))
 330        else {
 331            return false;
 332        };
 333
 334        if self.prediction.buffer != old_prediction.prediction.buffer {
 335            return true;
 336        }
 337
 338        let Some(old_edits) = old_prediction
 339            .prediction
 340            .interpolate(&old_prediction.prediction.buffer.read(cx))
 341        else {
 342            return true;
 343        };
 344
 345        let requested_by_buffer_id = self.requested_by.buffer_id();
 346
 347        // This reduces the occurrence of UI thrash from replacing edits
 348        //
 349        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 350        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 351            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 352            && old_edits.len() == 1
 353            && new_edits.len() == 1
 354        {
 355            let (old_range, old_text) = &old_edits[0];
 356            let (new_range, new_text) = &new_edits[0];
 357            new_range == old_range && new_text.starts_with(old_text.as_ref())
 358        } else {
 359            true
 360        }
 361    }
 362}
 363
 364#[derive(Debug, Clone)]
 365enum PredictionRequestedBy {
 366    DiagnosticsUpdate,
 367    Buffer(EntityId),
 368}
 369
 370impl PredictionRequestedBy {
 371    pub fn buffer_id(&self) -> Option<EntityId> {
 372        match self {
 373            PredictionRequestedBy::DiagnosticsUpdate => None,
 374            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 375        }
 376    }
 377}
 378
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier used to record cancellation in `ProjectState::cancelled_predictions`.
    id: usize,
    /// Resolves to the id of the produced prediction, or `None` if there is none.
    task: Task<Option<EditPredictionId>>,
}
 384
 385/// A prediction from the perspective of a buffer.
 386#[derive(Debug)]
 387enum BufferEditPrediction<'a> {
 388    Local { prediction: &'a EditPrediction },
 389    Jump { prediction: &'a EditPrediction },
 390}
 391
 392#[cfg(test)]
 393impl std::ops::Deref for BufferEditPrediction<'_> {
 394    type Target = EditPrediction;
 395
 396    fn deref(&self) -> &Self::Target {
 397        match self {
 398            BufferEditPrediction::Local { prediction } => prediction,
 399            BufferEditPrediction::Jump { prediction } => prediction,
 400        }
 401    }
 402}
 403
 404struct RegisteredBuffer {
 405    file: Option<Arc<dyn File>>,
 406    snapshot: TextBufferSnapshot,
 407    last_position: Option<Anchor>,
 408    _subscriptions: [gpui::Subscription; 2],
 409}
 410
/// The most recent, still-open change to a buffer, described by the snapshots
/// (and files) on either side of it.
#[derive(Clone)]
struct LastEvent {
    /// Buffer text before the change.
    old_snapshot: TextBufferSnapshot,
    /// Buffer text after the change.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    end_edit_anchor: Option<Anchor>,
    /// Boundary snapshot used by `split_by_pause` to split this event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 421
 422impl LastEvent {
 423    pub fn finalize(
 424        &self,
 425        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 426        cx: &App,
 427    ) -> Option<StoredEvent> {
 428        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 429        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 430
 431        let in_open_source_repo =
 432            [self.new_file.as_ref(), self.old_file.as_ref()]
 433                .iter()
 434                .all(|file| {
 435                    file.is_some_and(|file| {
 436                        license_detection_watchers
 437                            .get(&file.worktree_id(cx))
 438                            .is_some_and(|watcher| watcher.is_project_open_source())
 439                    })
 440                });
 441
 442        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 443
 444        if path == old_path && diff.is_empty() {
 445            None
 446        } else {
 447            Some(StoredEvent {
 448                event: Arc::new(zeta_prompt::Event::BufferChange {
 449                    old_path,
 450                    path,
 451                    diff,
 452                    in_open_source_repo,
 453                    // TODO: Actually detect if this edit was predicted or not
 454                    predicted: false,
 455                }),
 456                old_snapshot: self.old_snapshot.clone(),
 457            })
 458        }
 459    }
 460
 461    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 462        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 463            return (self.clone(), None);
 464        };
 465
 466        let before = LastEvent {
 467            old_snapshot: self.old_snapshot.clone(),
 468            new_snapshot: boundary_snapshot.clone(),
 469            old_file: self.old_file.clone(),
 470            new_file: self.new_file.clone(),
 471            end_edit_anchor: self.end_edit_anchor,
 472            snapshot_after_last_editing_pause: None,
 473            last_edit_time: self.last_edit_time,
 474        };
 475
 476        let after = LastEvent {
 477            old_snapshot: boundary_snapshot.clone(),
 478            new_snapshot: self.new_snapshot.clone(),
 479            old_file: self.old_file.clone(),
 480            new_file: self.new_file.clone(),
 481            end_edit_anchor: self.end_edit_anchor,
 482            snapshot_after_last_editing_pause: None,
 483            last_edit_time: self.last_edit_time,
 484        };
 485
 486        (before, Some(after))
 487    }
 488}
 489
 490pub(crate) fn compute_diff_between_snapshots(
 491    old_snapshot: &TextBufferSnapshot,
 492    new_snapshot: &TextBufferSnapshot,
 493) -> Option<String> {
 494    let edits: Vec<Edit<usize>> = new_snapshot
 495        .edits_since::<usize>(&old_snapshot.version)
 496        .collect();
 497
 498    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 499
 500    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 501    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 502    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 503    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 504
 505    const CONTEXT_LINES: u32 = 3;
 506
 507    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 508    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 509    let old_context_end_row =
 510        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 511    let new_context_end_row =
 512        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 513
 514    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 515    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 516    let old_end_line_offset = old_snapshot
 517        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 518    let new_end_line_offset = new_snapshot
 519        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 520    let old_edit_range = old_start_line_offset..old_end_line_offset;
 521    let new_edit_range = new_start_line_offset..new_end_line_offset;
 522
 523    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 524    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 525
 526    let diff = language::unified_diff_with_offsets(
 527        &old_region_text,
 528        &new_region_text,
 529        old_context_start_row,
 530        new_context_start_row,
 531    );
 532
 533    Some(diff)
 534}
 535
 536fn buffer_path_with_id_fallback(
 537    file: Option<&Arc<dyn File>>,
 538    snapshot: &TextBufferSnapshot,
 539    cx: &App,
 540) -> Arc<Path> {
 541    if let Some(file) = file {
 542        file.full_path(cx).into()
 543    } else {
 544        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 545    }
 546}
 547
 548impl EditPredictionStore {
 549    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 550        cx.try_global::<EditPredictionStoreGlobal>()
 551            .map(|global| global.0.clone())
 552    }
 553
 554    pub fn global(
 555        client: &Arc<Client>,
 556        user_store: &Entity<UserStore>,
 557        cx: &mut App,
 558    ) -> Entity<Self> {
 559        cx.try_global::<EditPredictionStoreGlobal>()
 560            .map(|global| global.0.clone())
 561            .unwrap_or_else(|| {
 562                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 563                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 564                ep_store
 565            })
 566    }
 567
 568    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 569        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 570        let data_collection_choice = Self::load_data_collection_choice();
 571
 572        let llm_token = LlmApiToken::default();
 573
 574        let (reject_tx, reject_rx) = mpsc::unbounded();
 575        cx.background_spawn({
 576            let client = client.clone();
 577            let llm_token = llm_token.clone();
 578            let app_version = AppVersion::global(cx);
 579            let background_executor = cx.background_executor().clone();
 580            async move {
 581                Self::handle_rejected_predictions(
 582                    reject_rx,
 583                    client,
 584                    llm_token,
 585                    app_version,
 586                    background_executor,
 587                )
 588                .await
 589            }
 590        })
 591        .detach();
 592
 593        let mut this = Self {
 594            projects: HashMap::default(),
 595            client,
 596            user_store,
 597            options: DEFAULT_OPTIONS,
 598            use_context: false,
 599            llm_token,
 600            _llm_token_subscription: cx.subscribe(
 601                &refresh_llm_token_listener,
 602                |this, _listener, _event, cx| {
 603                    let client = this.client.clone();
 604                    let llm_token = this.llm_token.clone();
 605                    cx.spawn(async move |_this, _cx| {
 606                        llm_token.refresh(&client).await?;
 607                        anyhow::Ok(())
 608                    })
 609                    .detach_and_log_err(cx);
 610                },
 611            ),
 612            update_required: false,
 613            #[cfg(feature = "cli-support")]
 614            eval_cache: None,
 615            edit_prediction_model: EditPredictionModel::Zeta2,
 616            sweep_ai: SweepAi::new(cx),
 617            mercury: Mercury::new(cx),
 618            data_collection_choice,
 619            reject_predictions_tx: reject_tx,
 620            rated_predictions: Default::default(),
 621            shown_predictions: Default::default(),
 622            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 623                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 624                Err(_) => {
 625                    if *USE_OLLAMA {
 626                        Some(
 627                            Url::parse("http://localhost:11434/v1/chat/completions")
 628                                .unwrap()
 629                                .into(),
 630                        )
 631                    } else {
 632                        None
 633                    }
 634                }
 635            },
 636        };
 637
 638        this.configure_context_retrieval(cx);
 639        let weak_this = cx.weak_entity();
 640        cx.on_flags_ready(move |_, cx| {
 641            weak_this
 642                .update(cx, |this, cx| this.configure_context_retrieval(cx))
 643                .ok();
 644        })
 645        .detach();
 646        cx.observe_global::<SettingsStore>(|this, cx| {
 647            this.configure_context_retrieval(cx);
 648        })
 649        .detach();
 650
 651        this
 652    }
 653
 654    #[cfg(test)]
 655    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 656        self.custom_predict_edits_url = Some(url.into());
 657    }
 658
    /// Selects which model backend the store uses.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 662
 663    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 664        self.sweep_ai.api_token.read(cx).has_key()
 665    }
 666
 667    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 668        self.mercury.api_token.read(cx).has_key()
 669    }
 670
    /// Installs an evaluation cache (only with the `cli-support` feature).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 675
    /// Returns the current options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 679
    /// Replaces the current options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 683
    /// Sets the `use_context` flag.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 687
 688    pub fn clear_history(&mut self) {
 689        for project_state in self.projects.values_mut() {
 690            project_state.events.clear();
 691            project_state.last_event.take();
 692        }
 693    }
 694
 695    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 696        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 697            project_state.events.clear();
 698            project_state.last_event.take();
 699        }
 700    }
 701
 702    pub fn edit_history_for_project(
 703        &self,
 704        project: &Entity<Project>,
 705        cx: &App,
 706    ) -> Vec<StoredEvent> {
 707        self.projects
 708            .get(&project.entity_id())
 709            .map(|project_state| project_state.events(cx))
 710            .unwrap_or_default()
 711    }
 712
 713    pub fn edit_history_for_project_with_pause_split_last_event(
 714        &self,
 715        project: &Entity<Project>,
 716        cx: &App,
 717    ) -> Vec<StoredEvent> {
 718        self.projects
 719            .get(&project.entity_id())
 720            .map(|project_state| project_state.events_split_by_pause(cx))
 721            .unwrap_or_default()
 722    }
 723
 724    pub fn context_for_project<'a>(
 725        &'a self,
 726        project: &Entity<Project>,
 727        cx: &'a App,
 728    ) -> Arc<[RelatedFile]> {
 729        self.projects
 730            .get(&project.entity_id())
 731            .map(|project| project.context.read(cx).related_files())
 732            .unwrap_or_else(|| vec![].into())
 733    }
 734
 735    pub fn context_for_project_with_buffers<'a>(
 736        &'a self,
 737        project: &Entity<Project>,
 738        cx: &'a App,
 739    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 740        self.projects
 741            .get(&project.entity_id())
 742            .map(|project| project.context.read(cx).related_files_with_buffers())
 743    }
 744
 745    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 746        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 747            self.user_store.read(cx).edit_prediction_usage()
 748        } else {
 749            None
 750        }
 751    }
 752
    /// Ensures per-project state exists for `project`, creating it if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
 756
    /// Registers `buffer` with the state of `project`, creating the project state
    /// first if it does not exist yet.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 766
 767    fn get_or_init_project(
 768        &mut self,
 769        project: &Entity<Project>,
 770        cx: &mut Context<Self>,
 771    ) -> &mut ProjectState {
 772        let entity_id = project.entity_id();
 773        self.projects
 774            .entry(entity_id)
 775            .or_insert_with(|| ProjectState {
 776                context: {
 777                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 778                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 779                        this.handle_excerpt_store_event(entity_id, event);
 780                    })
 781                    .detach();
 782                    related_excerpt_store
 783                },
 784                events: VecDeque::new(),
 785                last_event: None,
 786                recent_paths: VecDeque::new(),
 787                debug_tx: None,
 788                registered_buffers: HashMap::default(),
 789                current_prediction: None,
 790                cancelled_predictions: HashSet::default(),
 791                pending_predictions: ArrayVec::new(),
 792                next_pending_prediction_id: 0,
 793                last_prediction_refresh: None,
 794                license_detection_watchers: HashMap::default(),
 795                _subscription: cx.subscribe(&project, Self::handle_project_event),
 796            })
 797    }
 798
    /// Drops all state tracked for `project`, if any.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
 802
 803    fn handle_excerpt_store_event(
 804        &mut self,
 805        project_entity_id: EntityId,
 806        event: &RelatedExcerptStoreEvent,
 807    ) {
 808        if let Some(project_state) = self.projects.get(&project_entity_id) {
 809            if let Some(debug_tx) = project_state.debug_tx.clone() {
 810                match event {
 811                    RelatedExcerptStoreEvent::StartedRefresh => {
 812                        debug_tx
 813                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 814                                ContextRetrievalStartedDebugEvent {
 815                                    project_entity_id: project_entity_id,
 816                                    timestamp: Instant::now(),
 817                                    search_prompt: String::new(),
 818                                },
 819                            ))
 820                            .ok();
 821                    }
 822                    RelatedExcerptStoreEvent::FinishedRefresh {
 823                        cache_hit_count,
 824                        cache_miss_count,
 825                        mean_definition_latency,
 826                        max_definition_latency,
 827                    } => {
 828                        debug_tx
 829                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 830                                ContextRetrievalFinishedDebugEvent {
 831                                    project_entity_id: project_entity_id,
 832                                    timestamp: Instant::now(),
 833                                    metadata: vec![
 834                                        (
 835                                            "Cache Hits",
 836                                            format!(
 837                                                "{}/{}",
 838                                                cache_hit_count,
 839                                                cache_hit_count + cache_miss_count
 840                                            )
 841                                            .into(),
 842                                        ),
 843                                        (
 844                                            "Max LSP Time",
 845                                            format!("{} ms", max_definition_latency.as_millis())
 846                                                .into(),
 847                                        ),
 848                                        (
 849                                            "Mean LSP Time",
 850                                            format!("{} ms", mean_definition_latency.as_millis())
 851                                                .into(),
 852                                        ),
 853                                    ],
 854                                },
 855                            ))
 856                            .ok();
 857                    }
 858                }
 859            }
 860        }
 861    }
 862
 863    pub fn debug_info(
 864        &mut self,
 865        project: &Entity<Project>,
 866        cx: &mut Context<Self>,
 867    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 868        let project_state = self.get_or_init_project(project, cx);
 869        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 870        project_state.debug_tx = Some(debug_watch_tx);
 871        debug_watch_rx
 872    }
 873
 874    fn handle_project_event(
 875        &mut self,
 876        project: Entity<Project>,
 877        event: &project::Event,
 878        cx: &mut Context<Self>,
 879    ) {
 880        // TODO [zeta2] init with recent paths
 881        match event {
 882            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 883                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 884                    return;
 885                };
 886                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 887                if let Some(path) = path {
 888                    if let Some(ix) = project_state
 889                        .recent_paths
 890                        .iter()
 891                        .position(|probe| probe == &path)
 892                    {
 893                        project_state.recent_paths.remove(ix);
 894                    }
 895                    project_state.recent_paths.push_front(path);
 896                }
 897            }
 898            project::Event::DiagnosticsUpdated { .. } => {
 899                if cx.has_flag::<Zeta2FeatureFlag>() {
 900                    self.refresh_prediction_from_diagnostics(project, cx);
 901                }
 902            }
 903            _ => (),
 904        }
 905    }
 906
 907    fn register_buffer_impl<'a>(
 908        project_state: &'a mut ProjectState,
 909        buffer: &Entity<Buffer>,
 910        project: &Entity<Project>,
 911        cx: &mut Context<Self>,
 912    ) -> &'a mut RegisteredBuffer {
 913        let buffer_id = buffer.entity_id();
 914
 915        if let Some(file) = buffer.read(cx).file() {
 916            let worktree_id = file.worktree_id(cx);
 917            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 918                project_state
 919                    .license_detection_watchers
 920                    .entry(worktree_id)
 921                    .or_insert_with(|| {
 922                        let project_entity_id = project.entity_id();
 923                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 924                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 925                            else {
 926                                return;
 927                            };
 928                            project_state
 929                                .license_detection_watchers
 930                                .remove(&worktree_id);
 931                        })
 932                        .detach();
 933                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 934                    });
 935            }
 936        }
 937
 938        match project_state.registered_buffers.entry(buffer_id) {
 939            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 940            hash_map::Entry::Vacant(entry) => {
 941                let buf = buffer.read(cx);
 942                let snapshot = buf.text_snapshot();
 943                let file = buf.file().cloned();
 944                let project_entity_id = project.entity_id();
 945                entry.insert(RegisteredBuffer {
 946                    snapshot,
 947                    file,
 948                    last_position: None,
 949                    _subscriptions: [
 950                        cx.subscribe(buffer, {
 951                            let project = project.downgrade();
 952                            move |this, buffer, event, cx| {
 953                                if let language::BufferEvent::Edited = event
 954                                    && let Some(project) = project.upgrade()
 955                                {
 956                                    this.report_changes_for_buffer(&buffer, &project, cx);
 957                                }
 958                            }
 959                        }),
 960                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 961                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 962                            else {
 963                                return;
 964                            };
 965                            project_state.registered_buffers.remove(&buffer_id);
 966                        }),
 967                    ],
 968                })
 969            }
 970        }
 971    }
 972
 973    fn report_changes_for_buffer(
 974        &mut self,
 975        buffer: &Entity<Buffer>,
 976        project: &Entity<Project>,
 977        cx: &mut Context<Self>,
 978    ) {
 979        let project_state = self.get_or_init_project(project, cx);
 980        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
 981
 982        let buf = buffer.read(cx);
 983        let new_file = buf.file().cloned();
 984        let new_snapshot = buf.text_snapshot();
 985        if new_snapshot.version == registered_buffer.snapshot.version {
 986            return;
 987        }
 988
 989        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
 990        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 991        let end_edit_anchor = new_snapshot
 992            .anchored_edits_since::<Point>(&old_snapshot.version)
 993            .last()
 994            .map(|(_, range)| range.end);
 995        let events = &mut project_state.events;
 996
 997        let now = cx.background_executor().now();
 998        if let Some(last_event) = project_state.last_event.as_mut() {
 999            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1000                == last_event.new_snapshot.remote_id()
1001                && old_snapshot.version == last_event.new_snapshot.version;
1002
1003            let should_coalesce = is_next_snapshot_of_same_buffer
1004                && end_edit_anchor
1005                    .as_ref()
1006                    .zip(last_event.end_edit_anchor.as_ref())
1007                    .is_some_and(|(a, b)| {
1008                        let a = a.to_point(&new_snapshot);
1009                        let b = b.to_point(&new_snapshot);
1010                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
1011                    });
1012
1013            if should_coalesce {
1014                let pause_elapsed = last_event
1015                    .last_edit_time
1016                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1017                    .unwrap_or(false);
1018                if pause_elapsed {
1019                    last_event.snapshot_after_last_editing_pause =
1020                        Some(last_event.new_snapshot.clone());
1021                }
1022
1023                last_event.end_edit_anchor = end_edit_anchor;
1024                last_event.new_snapshot = new_snapshot;
1025                last_event.last_edit_time = Some(now);
1026                return;
1027            }
1028        }
1029
1030        if events.len() + 1 >= EVENT_COUNT_MAX {
1031            events.pop_front();
1032        }
1033
1034        if let Some(event) = project_state.last_event.take() {
1035            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
1036        }
1037
1038        project_state.last_event = Some(LastEvent {
1039            old_file,
1040            new_file,
1041            old_snapshot,
1042            new_snapshot,
1043            end_edit_anchor,
1044            snapshot_after_last_editing_pause: None,
1045            last_edit_time: Some(now),
1046        });
1047    }
1048
1049    fn prediction_at(
1050        &mut self,
1051        buffer: &Entity<Buffer>,
1052        position: Option<language::Anchor>,
1053        project: &Entity<Project>,
1054        cx: &App,
1055    ) -> Option<BufferEditPrediction<'_>> {
1056        let project_state = self.projects.get_mut(&project.entity_id())?;
1057        if let Some(position) = position
1058            && let Some(buffer) = project_state
1059                .registered_buffers
1060                .get_mut(&buffer.entity_id())
1061        {
1062            buffer.last_position = Some(position);
1063        }
1064
1065        let CurrentEditPrediction {
1066            requested_by,
1067            prediction,
1068            ..
1069        } = project_state.current_prediction.as_ref()?;
1070
1071        if prediction.targets_buffer(buffer.read(cx)) {
1072            Some(BufferEditPrediction::Local { prediction })
1073        } else {
1074            let show_jump = match requested_by {
1075                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1076                    requested_by_buffer_id == &buffer.entity_id()
1077                }
1078                PredictionRequestedBy::DiagnosticsUpdate => true,
1079            };
1080
1081            if show_jump {
1082                Some(BufferEditPrediction::Jump { prediction })
1083            } else {
1084                None
1085            }
1086        }
1087    }
1088
1089    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1090        let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1091        match self.edit_prediction_model {
1092            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1093                if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1094                    return;
1095                }
1096            }
1097            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1098        }
1099
1100        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1101            return;
1102        };
1103
1104        let Some(prediction) = project_state.current_prediction.take() else {
1105            return;
1106        };
1107        let request_id = prediction.prediction.id.to_string();
1108        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1109            project_state.cancel_pending_prediction(pending_prediction, cx);
1110        }
1111
1112        let client = self.client.clone();
1113        let llm_token = self.llm_token.clone();
1114        let app_version = AppVersion::global(cx);
1115        cx.spawn(async move |this, cx| {
1116            let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1117                (http_client::Url::parse(&accept_edits_url)?, false)
1118            } else {
1119                (
1120                    client
1121                        .http_client()
1122                        .build_zed_llm_url("/predict_edits/accept", &[])?,
1123                    true,
1124                )
1125            };
1126
1127            let response = cx
1128                .background_spawn(Self::send_api_request::<()>(
1129                    move |builder| {
1130                        let req = builder.uri(url.as_ref()).body(
1131                            serde_json::to_string(&AcceptEditPredictionBody {
1132                                request_id: request_id.clone(),
1133                            })?
1134                            .into(),
1135                        );
1136                        Ok(req?)
1137                    },
1138                    client,
1139                    llm_token,
1140                    app_version,
1141                    require_auth,
1142                ))
1143                .await;
1144
1145            Self::handle_api_response(&this, response, cx)?;
1146            anyhow::Ok(())
1147        })
1148        .detach_and_log_err(cx);
1149    }
1150
1151    async fn handle_rejected_predictions(
1152        rx: UnboundedReceiver<EditPredictionRejection>,
1153        client: Arc<Client>,
1154        llm_token: LlmApiToken,
1155        app_version: Version,
1156        background_executor: BackgroundExecutor,
1157    ) {
1158        let mut rx = std::pin::pin!(rx.peekable());
1159        let mut batched = Vec::new();
1160
1161        while let Some(rejection) = rx.next().await {
1162            batched.push(rejection);
1163
1164            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1165                select_biased! {
1166                    next = rx.as_mut().peek().fuse() => {
1167                        if next.is_some() {
1168                            continue;
1169                        }
1170                    }
1171                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1172                }
1173            }
1174
1175            let url = client
1176                .http_client()
1177                .build_zed_llm_url("/predict_edits/reject", &[])
1178                .unwrap();
1179
1180            let flush_count = batched
1181                .len()
1182                // in case items have accumulated after failure
1183                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1184            let start = batched.len() - flush_count;
1185
1186            let body = RejectEditPredictionsBodyRef {
1187                rejections: &batched[start..],
1188            };
1189
1190            let result = Self::send_api_request::<()>(
1191                |builder| {
1192                    let req = builder
1193                        .uri(url.as_ref())
1194                        .body(serde_json::to_string(&body)?.into());
1195                    anyhow::Ok(req?)
1196                },
1197                client.clone(),
1198                llm_token.clone(),
1199                app_version.clone(),
1200                true,
1201            )
1202            .await;
1203
1204            if result.log_err().is_some() {
1205                batched.drain(start..);
1206            }
1207        }
1208    }
1209
1210    fn reject_current_prediction(
1211        &mut self,
1212        reason: EditPredictionRejectReason,
1213        project: &Entity<Project>,
1214    ) {
1215        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1216            project_state.pending_predictions.clear();
1217            if let Some(prediction) = project_state.current_prediction.take() {
1218                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1219            }
1220        };
1221    }
1222
1223    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1224        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1225            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1226                if !current_prediction.was_shown {
1227                    current_prediction.was_shown = true;
1228                    self.shown_predictions
1229                        .push_front(current_prediction.prediction.clone());
1230                    if self.shown_predictions.len() > 50 {
1231                        let completion = self.shown_predictions.pop_back().unwrap();
1232                        self.rated_predictions.remove(&completion.id);
1233                    }
1234                }
1235            }
1236        }
1237    }
1238
1239    fn reject_prediction(
1240        &mut self,
1241        prediction_id: EditPredictionId,
1242        reason: EditPredictionRejectReason,
1243        was_shown: bool,
1244    ) {
1245        match self.edit_prediction_model {
1246            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1247                if self.custom_predict_edits_url.is_some() {
1248                    return;
1249                }
1250            }
1251            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1252        }
1253
1254        self.reject_predictions_tx
1255            .unbounded_send(EditPredictionRejection {
1256                request_id: prediction_id.to_string(),
1257                reason,
1258                was_shown,
1259            })
1260            .log_err();
1261    }
1262
1263    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1264        self.projects
1265            .get(&project.entity_id())
1266            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1267    }
1268
1269    pub fn refresh_prediction_from_buffer(
1270        &mut self,
1271        project: Entity<Project>,
1272        buffer: Entity<Buffer>,
1273        position: language::Anchor,
1274        cx: &mut Context<Self>,
1275    ) {
1276        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1277            let Some(request_task) = this
1278                .update(cx, |this, cx| {
1279                    this.request_prediction(
1280                        &project,
1281                        &buffer,
1282                        position,
1283                        PredictEditsRequestTrigger::Other,
1284                        cx,
1285                    )
1286                })
1287                .log_err()
1288            else {
1289                return Task::ready(anyhow::Ok(None));
1290            };
1291
1292            cx.spawn(async move |_cx| {
1293                request_task.await.map(|prediction_result| {
1294                    prediction_result.map(|prediction_result| {
1295                        (
1296                            prediction_result,
1297                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1298                        )
1299                    })
1300                })
1301            })
1302        })
1303    }
1304
1305    pub fn refresh_prediction_from_diagnostics(
1306        &mut self,
1307        project: Entity<Project>,
1308        cx: &mut Context<Self>,
1309    ) {
1310        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1311            return;
1312        };
1313
1314        // Prefer predictions from buffer
1315        if project_state.current_prediction.is_some() {
1316            return;
1317        };
1318
1319        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1320            let Some((active_buffer, snapshot, cursor_point)) = this
1321                .read_with(cx, |this, cx| {
1322                    let project_state = this.projects.get(&project.entity_id())?;
1323                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1324                    let snapshot = buffer.read(cx).snapshot();
1325
1326                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1327                        return None;
1328                    }
1329
1330                    let cursor_point = position
1331                        .map(|pos| pos.to_point(&snapshot))
1332                        .unwrap_or_default();
1333
1334                    Some((buffer, snapshot, cursor_point))
1335                })
1336                .log_err()
1337                .flatten()
1338            else {
1339                return Task::ready(anyhow::Ok(None));
1340            };
1341
1342            cx.spawn(async move |cx| {
1343                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1344                    active_buffer,
1345                    &snapshot,
1346                    Default::default(),
1347                    cursor_point,
1348                    &project,
1349                    cx,
1350                )
1351                .await?
1352                else {
1353                    return anyhow::Ok(None);
1354                };
1355
1356                let Some(prediction_result) = this
1357                    .update(cx, |this, cx| {
1358                        this.request_prediction(
1359                            &project,
1360                            &jump_buffer,
1361                            jump_position,
1362                            PredictEditsRequestTrigger::Diagnostics,
1363                            cx,
1364                        )
1365                    })?
1366                    .await?
1367                else {
1368                    return anyhow::Ok(None);
1369                };
1370
1371                this.update(cx, |this, cx| {
1372                    Some((
1373                        if this
1374                            .get_or_init_project(&project, cx)
1375                            .current_prediction
1376                            .is_none()
1377                        {
1378                            prediction_result
1379                        } else {
1380                            EditPredictionResult {
1381                                id: prediction_result.id,
1382                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1383                            }
1384                        },
1385                        PredictionRequestedBy::DiagnosticsUpdate,
1386                    ))
1387                })
1388            })
1389        });
1390    }
1391
1392    fn predictions_enabled_at(
1393        snapshot: &BufferSnapshot,
1394        position: Option<language::Anchor>,
1395        cx: &App,
1396    ) -> bool {
1397        let file = snapshot.file();
1398        let all_settings = all_language_settings(file, cx);
1399        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1400            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1401        {
1402            return false;
1403        }
1404
1405        if let Some(last_position) = position {
1406            let settings = snapshot.settings_at(last_position, cx);
1407
1408            if !settings.edit_predictions_disabled_in.is_empty()
1409                && let Some(scope) = snapshot.language_scope_at(last_position)
1410                && let Some(scope_name) = scope.override_name()
1411                && settings
1412                    .edit_predictions_disabled_in
1413                    .iter()
1414                    .any(|s| s == scope_name)
1415            {
1416                return false;
1417            }
1418        }
1419
1420        true
1421    }
1422
1423    #[cfg(not(test))]
1424    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1425    #[cfg(test)]
1426    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1427
1428    fn queue_prediction_refresh(
1429        &mut self,
1430        project: Entity<Project>,
1431        throttle_entity: EntityId,
1432        cx: &mut Context<Self>,
1433        do_refresh: impl FnOnce(
1434            WeakEntity<Self>,
1435            &mut AsyncApp,
1436        )
1437            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1438        + 'static,
1439    ) {
1440        let project_state = self.get_or_init_project(&project, cx);
1441        let pending_prediction_id = project_state.next_pending_prediction_id;
1442        project_state.next_pending_prediction_id += 1;
1443        let last_request = project_state.last_prediction_refresh;
1444
1445        let task = cx.spawn(async move |this, cx| {
1446            if let Some((last_entity, last_timestamp)) = last_request
1447                && throttle_entity == last_entity
1448                && let Some(timeout) =
1449                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1450            {
1451                cx.background_executor().timer(timeout).await;
1452            }
1453
1454            // If this task was cancelled before the throttle timeout expired,
1455            // do not perform a request.
1456            let mut is_cancelled = true;
1457            this.update(cx, |this, cx| {
1458                let project_state = this.get_or_init_project(&project, cx);
1459                if !project_state
1460                    .cancelled_predictions
1461                    .remove(&pending_prediction_id)
1462                {
1463                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1464                    is_cancelled = false;
1465                }
1466            })
1467            .ok();
1468            if is_cancelled {
1469                return None;
1470            }
1471
1472            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1473            let new_prediction_id = new_prediction_result
1474                .as_ref()
1475                .map(|(prediction, _)| prediction.id.clone());
1476
1477            // When a prediction completes, remove it from the pending list, and cancel
1478            // any pending predictions that were enqueued before it.
1479            this.update(cx, |this, cx| {
1480                let project_state = this.get_or_init_project(&project, cx);
1481
1482                let is_cancelled = project_state
1483                    .cancelled_predictions
1484                    .remove(&pending_prediction_id);
1485
1486                let new_current_prediction = if !is_cancelled
1487                    && let Some((prediction_result, requested_by)) = new_prediction_result
1488                {
1489                    match prediction_result.prediction {
1490                        Ok(prediction) => {
1491                            let new_prediction = CurrentEditPrediction {
1492                                requested_by,
1493                                prediction,
1494                                was_shown: false,
1495                            };
1496
1497                            if let Some(current_prediction) =
1498                                project_state.current_prediction.as_ref()
1499                            {
1500                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1501                                {
1502                                    this.reject_current_prediction(
1503                                        EditPredictionRejectReason::Replaced,
1504                                        &project,
1505                                    );
1506
1507                                    Some(new_prediction)
1508                                } else {
1509                                    this.reject_prediction(
1510                                        new_prediction.prediction.id,
1511                                        EditPredictionRejectReason::CurrentPreferred,
1512                                        false,
1513                                    );
1514                                    None
1515                                }
1516                            } else {
1517                                Some(new_prediction)
1518                            }
1519                        }
1520                        Err(reject_reason) => {
1521                            this.reject_prediction(prediction_result.id, reject_reason, false);
1522                            None
1523                        }
1524                    }
1525                } else {
1526                    None
1527                };
1528
1529                let project_state = this.get_or_init_project(&project, cx);
1530
1531                if let Some(new_prediction) = new_current_prediction {
1532                    project_state.current_prediction = Some(new_prediction);
1533                }
1534
1535                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1536                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1537                    if pending_prediction.id == pending_prediction_id {
1538                        pending_predictions.remove(ix);
1539                        for pending_prediction in pending_predictions.drain(0..ix) {
1540                            project_state.cancel_pending_prediction(pending_prediction, cx)
1541                        }
1542                        break;
1543                    }
1544                }
1545                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1546                cx.notify();
1547            })
1548            .ok();
1549
1550            new_prediction_id
1551        });
1552
1553        if project_state.pending_predictions.len() <= 1 {
1554            project_state.pending_predictions.push(PendingPrediction {
1555                id: pending_prediction_id,
1556                task,
1557            });
1558        } else if project_state.pending_predictions.len() == 2 {
1559            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1560            project_state.pending_predictions.push(PendingPrediction {
1561                id: pending_prediction_id,
1562                task,
1563            });
1564            project_state.cancel_pending_prediction(pending_prediction, cx);
1565        }
1566    }
1567
1568    pub fn request_prediction(
1569        &mut self,
1570        project: &Entity<Project>,
1571        active_buffer: &Entity<Buffer>,
1572        position: language::Anchor,
1573        trigger: PredictEditsRequestTrigger,
1574        cx: &mut Context<Self>,
1575    ) -> Task<Result<Option<EditPredictionResult>>> {
1576        self.request_prediction_internal(
1577            project.clone(),
1578            active_buffer.clone(),
1579            position,
1580            trigger,
1581            cx.has_flag::<Zeta2FeatureFlag>(),
1582            cx,
1583        )
1584    }
1585
1586    fn request_prediction_internal(
1587        &mut self,
1588        project: Entity<Project>,
1589        active_buffer: Entity<Buffer>,
1590        position: language::Anchor,
1591        trigger: PredictEditsRequestTrigger,
1592        allow_jump: bool,
1593        cx: &mut Context<Self>,
1594    ) -> Task<Result<Option<EditPredictionResult>>> {
1595        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1596
1597        self.get_or_init_project(&project, cx);
1598        let project_state = self.projects.get(&project.entity_id()).unwrap();
1599        let stored_events = project_state.events(cx);
1600        let has_events = !stored_events.is_empty();
1601        let events: Vec<Arc<zeta_prompt::Event>> =
1602            stored_events.into_iter().map(|e| e.event).collect();
1603        let debug_tx = project_state.debug_tx.clone();
1604
1605        let snapshot = active_buffer.read(cx).snapshot();
1606        let cursor_point = position.to_point(&snapshot);
1607        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1608        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1609        let diagnostic_search_range =
1610            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1611
1612        let related_files = if self.use_context {
1613            self.context_for_project(&project, cx)
1614        } else {
1615            Vec::new().into()
1616        };
1617
1618        let inputs = EditPredictionModelInput {
1619            project: project.clone(),
1620            buffer: active_buffer.clone(),
1621            snapshot: snapshot.clone(),
1622            position,
1623            events,
1624            related_files,
1625            recent_paths: project_state.recent_paths.clone(),
1626            trigger,
1627            diagnostic_search_range: diagnostic_search_range.clone(),
1628            debug_tx,
1629        };
1630
1631        let task = match self.edit_prediction_model {
1632            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1633            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1634            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1635            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1636        };
1637
1638        cx.spawn(async move |this, cx| {
1639            let prediction = task.await?;
1640
1641            if prediction.is_none() && allow_jump {
1642                let cursor_point = position.to_point(&snapshot);
1643                if has_events
1644                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1645                        active_buffer.clone(),
1646                        &snapshot,
1647                        diagnostic_search_range,
1648                        cursor_point,
1649                        &project,
1650                        cx,
1651                    )
1652                    .await?
1653                {
1654                    return this
1655                        .update(cx, |this, cx| {
1656                            this.request_prediction_internal(
1657                                project,
1658                                jump_buffer,
1659                                jump_position,
1660                                trigger,
1661                                false,
1662                                cx,
1663                            )
1664                        })?
1665                        .await;
1666                }
1667
1668                return anyhow::Ok(None);
1669            }
1670
1671            Ok(prediction)
1672        })
1673    }
1674
1675    async fn next_diagnostic_location(
1676        active_buffer: Entity<Buffer>,
1677        active_buffer_snapshot: &BufferSnapshot,
1678        active_buffer_diagnostic_search_range: Range<Point>,
1679        active_buffer_cursor_point: Point,
1680        project: &Entity<Project>,
1681        cx: &mut AsyncApp,
1682    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1683        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1684        let mut jump_location = active_buffer_snapshot
1685            .diagnostic_groups(None)
1686            .into_iter()
1687            .filter_map(|(_, group)| {
1688                let range = &group.entries[group.primary_ix]
1689                    .range
1690                    .to_point(&active_buffer_snapshot);
1691                if range.overlaps(&active_buffer_diagnostic_search_range) {
1692                    None
1693                } else {
1694                    Some(range.start)
1695                }
1696            })
1697            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1698            .map(|position| {
1699                (
1700                    active_buffer.clone(),
1701                    active_buffer_snapshot.anchor_before(position),
1702                )
1703            });
1704
1705        if jump_location.is_none() {
1706            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1707                let file = buffer.file()?;
1708
1709                Some(ProjectPath {
1710                    worktree_id: file.worktree_id(cx),
1711                    path: file.path().clone(),
1712                })
1713            })?;
1714
1715            let buffer_task = project.update(cx, |project, cx| {
1716                let (path, _, _) = project
1717                    .diagnostic_summaries(false, cx)
1718                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1719                    .max_by_key(|(path, _, _)| {
1720                        // find the buffer with errors that shares most parent directories
1721                        path.path
1722                            .components()
1723                            .zip(
1724                                active_buffer_path
1725                                    .as_ref()
1726                                    .map(|p| p.path.components())
1727                                    .unwrap_or_default(),
1728                            )
1729                            .take_while(|(a, b)| a == b)
1730                            .count()
1731                    })?;
1732
1733                Some(project.open_buffer(path, cx))
1734            })?;
1735
1736            if let Some(buffer_task) = buffer_task {
1737                let closest_buffer = buffer_task.await?;
1738
1739                jump_location = closest_buffer
1740                    .read_with(cx, |buffer, _cx| {
1741                        buffer
1742                            .buffer_diagnostics(None)
1743                            .into_iter()
1744                            .min_by_key(|entry| entry.diagnostic.severity)
1745                            .map(|entry| entry.range.start)
1746                    })?
1747                    .map(|position| (closest_buffer, position));
1748            }
1749        }
1750
1751        anyhow::Ok(jump_location)
1752    }
1753
1754    async fn send_raw_llm_request(
1755        request: open_ai::Request,
1756        client: Arc<Client>,
1757        llm_token: LlmApiToken,
1758        app_version: Version,
1759        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1760        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1761    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1762        let url = client
1763            .http_client()
1764            .build_zed_llm_url("/predict_edits/raw", &[])?;
1765
1766        #[cfg(feature = "cli-support")]
1767        let cache_key = if let Some(cache) = eval_cache {
1768            use collections::FxHasher;
1769            use std::hash::{Hash, Hasher};
1770
1771            let mut hasher = FxHasher::default();
1772            url.hash(&mut hasher);
1773            let request_str = serde_json::to_string_pretty(&request)?;
1774            request_str.hash(&mut hasher);
1775            let hash = hasher.finish();
1776
1777            let key = (eval_cache_kind, hash);
1778            if let Some(response_str) = cache.read(key) {
1779                return Ok((serde_json::from_str(&response_str)?, None));
1780            }
1781
1782            Some((cache, request_str, key))
1783        } else {
1784            None
1785        };
1786
1787        let (response, usage) = Self::send_api_request(
1788            |builder| {
1789                let req = builder
1790                    .uri(url.as_ref())
1791                    .body(serde_json::to_string(&request)?.into());
1792                Ok(req?)
1793            },
1794            client,
1795            llm_token,
1796            app_version,
1797            true,
1798        )
1799        .await?;
1800
1801        #[cfg(feature = "cli-support")]
1802        if let Some((cache, request, key)) = cache_key {
1803            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1804        }
1805
1806        Ok((response, usage))
1807    }
1808
1809    fn handle_api_response<T>(
1810        this: &WeakEntity<Self>,
1811        response: Result<(T, Option<EditPredictionUsage>)>,
1812        cx: &mut gpui::AsyncApp,
1813    ) -> Result<T> {
1814        match response {
1815            Ok((data, usage)) => {
1816                if let Some(usage) = usage {
1817                    this.update(cx, |this, cx| {
1818                        this.user_store.update(cx, |user_store, cx| {
1819                            user_store.update_edit_prediction_usage(usage, cx);
1820                        });
1821                    })
1822                    .ok();
1823                }
1824                Ok(data)
1825            }
1826            Err(err) => {
1827                if err.is::<ZedUpdateRequiredError>() {
1828                    cx.update(|cx| {
1829                        this.update(cx, |this, _cx| {
1830                            this.update_required = true;
1831                        })
1832                        .ok();
1833
1834                        let error_message: SharedString = err.to_string().into();
1835                        show_app_notification(
1836                            NotificationId::unique::<ZedUpdateRequiredError>(),
1837                            cx,
1838                            move |cx| {
1839                                cx.new(|cx| {
1840                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1841                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1842                                })
1843                            },
1844                        );
1845                    })
1846                    .ok();
1847                }
1848                Err(err)
1849            }
1850        }
1851    }
1852
1853    async fn send_api_request<Res>(
1854        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1855        client: Arc<Client>,
1856        llm_token: LlmApiToken,
1857        app_version: Version,
1858        require_auth: bool,
1859    ) -> Result<(Res, Option<EditPredictionUsage>)>
1860    where
1861        Res: DeserializeOwned,
1862    {
1863        let http_client = client.http_client();
1864
1865        let mut token = if require_auth {
1866            Some(llm_token.acquire(&client).await?)
1867        } else {
1868            llm_token.acquire(&client).await.ok()
1869        };
1870        let mut did_retry = false;
1871
1872        loop {
1873            let request_builder = http_client::Request::builder().method(Method::POST);
1874
1875            let mut request_builder = request_builder
1876                .header("Content-Type", "application/json")
1877                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1878
1879            // Only add Authorization header if we have a token
1880            if let Some(ref token_value) = token {
1881                request_builder =
1882                    request_builder.header("Authorization", format!("Bearer {}", token_value));
1883            }
1884
1885            let request = build(request_builder)?;
1886
1887            let mut response = http_client.send(request).await?;
1888
1889            if let Some(minimum_required_version) = response
1890                .headers()
1891                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1892                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1893            {
1894                anyhow::ensure!(
1895                    app_version >= minimum_required_version,
1896                    ZedUpdateRequiredError {
1897                        minimum_version: minimum_required_version
1898                    }
1899                );
1900            }
1901
1902            if response.status().is_success() {
1903                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1904
1905                let mut body = Vec::new();
1906                response.body_mut().read_to_end(&mut body).await?;
1907                return Ok((serde_json::from_slice(&body)?, usage));
1908            } else if !did_retry
1909                && token.is_some()
1910                && response
1911                    .headers()
1912                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1913                    .is_some()
1914            {
1915                did_retry = true;
1916                token = Some(llm_token.refresh(&client).await?);
1917            } else {
1918                let mut body = String::new();
1919                response.body_mut().read_to_string(&mut body).await?;
1920                anyhow::bail!(
1921                    "Request failed with status: {:?}\nBody: {}",
1922                    response.status(),
1923                    body
1924                );
1925            }
1926        }
1927    }
1928
1929    pub fn refresh_context(
1930        &mut self,
1931        project: &Entity<Project>,
1932        buffer: &Entity<language::Buffer>,
1933        cursor_position: language::Anchor,
1934        cx: &mut Context<Self>,
1935    ) {
1936        if self.use_context {
1937            self.get_or_init_project(project, cx)
1938                .context
1939                .update(cx, |store, cx| {
1940                    store.refresh(buffer.clone(), cursor_position, cx);
1941                });
1942        }
1943    }
1944
1945    #[cfg(feature = "cli-support")]
1946    pub fn set_context_for_buffer(
1947        &mut self,
1948        project: &Entity<Project>,
1949        related_files: Vec<RelatedFile>,
1950        cx: &mut Context<Self>,
1951    ) {
1952        self.get_or_init_project(project, cx)
1953            .context
1954            .update(cx, |store, _| {
1955                store.set_related_files(related_files);
1956            });
1957    }
1958
1959    fn is_file_open_source(
1960        &self,
1961        project: &Entity<Project>,
1962        file: &Arc<dyn File>,
1963        cx: &App,
1964    ) -> bool {
1965        if !file.is_local() || file.is_private() {
1966            return false;
1967        }
1968        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1969            return false;
1970        };
1971        project_state
1972            .license_detection_watchers
1973            .get(&file.worktree_id(cx))
1974            .as_ref()
1975            .is_some_and(|watcher| watcher.is_project_open_source())
1976    }
1977
1978    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1979        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1980    }
1981
1982    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1983        if !self.data_collection_choice.is_enabled() {
1984            return false;
1985        }
1986        events.iter().all(|event| {
1987            matches!(
1988                event.as_ref(),
1989                zeta_prompt::Event::BufferChange {
1990                    in_open_source_repo: true,
1991                    ..
1992                }
1993            )
1994        })
1995    }
1996
1997    fn load_data_collection_choice() -> DataCollectionChoice {
1998        let choice = KEY_VALUE_STORE
1999            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2000            .log_err()
2001            .flatten();
2002
2003        match choice.as_deref() {
2004            Some("true") => DataCollectionChoice::Enabled,
2005            Some("false") => DataCollectionChoice::Disabled,
2006            Some(_) => {
2007                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2008                DataCollectionChoice::NotAnswered
2009            }
2010            None => DataCollectionChoice::NotAnswered,
2011        }
2012    }
2013
2014    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2015        self.data_collection_choice = self.data_collection_choice.toggle();
2016        let new_choice = self.data_collection_choice;
2017        db::write_and_log(cx, move || {
2018            KEY_VALUE_STORE.write_kvp(
2019                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2020                new_choice.is_enabled().to_string(),
2021            )
2022        });
2023    }
2024
2025    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2026        self.shown_predictions.iter()
2027    }
2028
2029    pub fn shown_completions_len(&self) -> usize {
2030        self.shown_predictions.len()
2031    }
2032
2033    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2034        self.rated_predictions.contains(id)
2035    }
2036
2037    pub fn rate_prediction(
2038        &mut self,
2039        prediction: &EditPrediction,
2040        rating: EditPredictionRating,
2041        feedback: String,
2042        cx: &mut Context<Self>,
2043    ) {
2044        self.rated_predictions.insert(prediction.id.clone());
2045        telemetry::event!(
2046            "Edit Prediction Rated",
2047            rating,
2048            inputs = prediction.inputs,
2049            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
2050            feedback
2051        );
2052        self.client.telemetry().flush_events().detach();
2053        cx.notify();
2054    }
2055
2056    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2057        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2058            && all_language_settings(None, cx).edit_predictions.use_context;
2059    }
2060}
2061
2062#[derive(Error, Debug)]
2063#[error(
2064    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2065)]
2066pub struct ZedUpdateRequiredError {
2067    minimum_version: Version,
2068}
2069
/// Key identifying an eval cache entry: the entry kind paired with a 64-bit hash.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2072
/// Category of an entry stored in the eval cache.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

/// Formats the kind as its lowercase name (`context`, `search`, `prediction`).
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2091
/// Storage backend for cached eval responses, shareable across threads.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;

    /// Stores `value` under `key`, alongside the `input` that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2097
/// The user's answer to the data collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}

impl DataCollectionChoice {
    /// Returns `true` only for `Enabled`.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Returns `true` once an answer (`Enabled` or `Disabled`) has been given.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; `NotAnswered` toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

/// Maps `true` to `Enabled` and `false` to `Disabled`.
impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2138
2139struct ZedPredictUpsell;
2140
2141impl Dismissable for ZedPredictUpsell {
2142    const KEY: &'static str = "dismissed-edit-predict-upsell";
2143
2144    fn dismissed() -> bool {
2145        // To make this backwards compatible with older versions of Zed, we
2146        // check if the user has seen the previous Edit Prediction Onboarding
2147        // before, by checking the data collection choice which was written to
2148        // the database once the user clicked on "Accept and Enable"
2149        if KEY_VALUE_STORE
2150            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2151            .log_err()
2152            .is_some_and(|s| s.is_some())
2153        {
2154            return true;
2155        }
2156
2157        KEY_VALUE_STORE
2158            .read_kvp(Self::KEY)
2159            .log_err()
2160            .is_some_and(|s| s.is_some())
2161    }
2162}
2163
2164pub fn should_show_upsell_modal() -> bool {
2165    !ZedPredictUpsell::dismissed()
2166}
2167
2168pub fn init(cx: &mut App) {
2169    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2170        workspace.register_action(
2171            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2172                ZedPredictModal::toggle(
2173                    workspace,
2174                    workspace.user_store().clone(),
2175                    workspace.client().clone(),
2176                    window,
2177                    cx,
2178                )
2179            },
2180        );
2181
2182        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2183            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2184                settings
2185                    .project
2186                    .all_languages
2187                    .features
2188                    .get_or_insert_default()
2189                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2190            });
2191        });
2192    })
2193    .detach();
2194}