edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
   7    EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
   8    MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
   9    ZED_VERSION_HEADER_NAME,
  10};
  11use collections::{HashMap, HashSet};
  12use db::kvp::{Dismissable, KEY_VALUE_STORE};
  13use edit_prediction_context::EditPredictionExcerptOptions;
  14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::{
  17    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  18    channel::mpsc::{self, UnboundedReceiver},
  19    select_biased,
  20};
  21use gpui::BackgroundExecutor;
  22use gpui::http_client::Url;
  23use gpui::{
  24    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  25    http_client::{self, AsyncBody, Method},
  26    prelude::*,
  27};
  28use language::language_settings::all_language_settings;
  29use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToPoint};
  30use language::{BufferSnapshot, OffsetRangeExt};
  31use language_model::{LlmApiToken, RefreshLlmTokenListener};
  32use project::{Project, ProjectPath, WorktreeId};
  33use release_channel::AppVersion;
  34use semver::Version;
  35use serde::de::DeserializeOwned;
  36use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  37use std::collections::{VecDeque, hash_map};
  38use text::Edit;
  39use workspace::Workspace;
  40
  41use std::ops::Range;
  42use std::path::Path;
  43use std::rc::Rc;
  44use std::str::FromStr as _;
  45use std::sync::{Arc, LazyLock};
  46use std::time::{Duration, Instant};
  47use std::{env, mem};
  48use thiserror::Error;
  49use util::{RangeExt as _, ResultExt as _};
  50use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  51
  52pub mod cursor_excerpt;
  53pub mod example_spec;
  54mod license_detection;
  55pub mod mercury;
  56mod onboarding_modal;
  57pub mod open_ai_response;
  58mod prediction;
  59pub mod sweep_ai;
  60
  61pub mod udiff;
  62
  63mod capture_example;
  64mod zed_edit_prediction_delegate;
  65pub mod zeta1;
  66pub mod zeta2;
  67
  68#[cfg(test)]
  69mod edit_prediction_tests;
  70
  71use crate::license_detection::LicenseDetectionWatcher;
  72use crate::mercury::Mercury;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76use crate::prediction::EditPredictionResult;
  77pub use crate::sweep_ai::SweepAi;
  78pub use capture_example::capture_example;
  79pub use language_model::ApiKeyState;
  80pub use telemetry_events::EditPredictionRating;
  81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  82
// Actions registered under the `edit_prediction` namespace. The `///` lines inside the macro
// are part of its input (presumably surfaced as the actions' descriptions), so they are
// intentionally left as-is.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  92
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// NOTE(review): the constants below are consumed outside this chunk; their descriptions are
// inferred from their names and should be confirmed against the call sites.
/// Presumably the maximum line distance for consecutive edits to be grouped into one event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Presumably the maximum time gap for consecutive edits to be grouped into one event.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Presumably the key-value store key under which the data collection choice is persisted.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Presumably how long rejected predictions are batched before being reported.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  99
 100pub struct SweepFeatureFlag;
 101
 102impl FeatureFlag for SweepFeatureFlag {
 103    const NAME: &str = "sweep-ai";
 104}
 105
 106pub struct MercuryFeatureFlag;
 107
 108impl FeatureFlag for MercuryFeatureFlag {
 109    const NAME: &str = "mercury";
 110}
 111
 112pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 113    context: EditPredictionExcerptOptions {
 114        max_bytes: 512,
 115        min_bytes: 128,
 116        target_before_cursor_over_total_bytes: 0.5,
 117    },
 118    prompt_format: PromptFormat::DEFAULT,
 119};
 120
 121static USE_OLLAMA: LazyLock<bool> =
 122    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 123
 124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 125    match env::var("ZED_ZETA2_MODEL").as_deref() {
 126        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 127        Ok(model) => model,
 128        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 129        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 130    }
 131    .to_string()
 132});
 133
 134pub struct Zeta2FeatureFlag;
 135
 136impl FeatureFlag for Zeta2FeatureFlag {
 137    const NAME: &'static str = "zeta2";
 138
 139    fn enabled_for_staff() -> bool {
 140        true
 141    }
 142}
 143
/// Newtype that stores the app-wide [`EditPredictionStore`] entity as a gpui global.
/// It is set by `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 148
/// Central edit prediction state: per-project history and predictions, the selected model,
/// and the plumbing used to talk to the prediction service.
pub struct EditPredictionStore {
    /// Used to refresh `llm_token` and by the background task that reports rejections.
    client: Arc<Client>,
    /// Source of the usage reported by `usage`.
    user_store: Entity<UserStore>,
    /// LLM API token, refreshed whenever the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Toggled via `set_use_context`; starts as `false`.
    use_context: bool,
    /// Zeta prompt/excerpt options; starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    // NOTE(review): starts as `false`; its writers are outside this chunk (presumably set when
    // the server requires a newer client version).
    update_required: bool,
    /// Cache for evaluation runs, installed via `with_eval_cache`.
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Which backend produces predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded once at construction via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sender for the background task spawned in `new` that reports rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): the two fields below are maintained outside this chunk; presumably they
    // back prediction rating.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override taken from `ZED_PREDICT_EDITS_URL`, or the local Ollama endpoint when
    /// `ZED_ZETA2_OLLAMA` is set.
    custom_predict_edits_url: Option<Arc<Url>>,
}
 169
/// Backend used to produce edit predictions.
///
/// `Zeta1` is the `Default` variant, but `EditPredictionStore::new` explicitly starts with
/// `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 178
/// Inputs bundled for a prediction model backend.
// NOTE(review): this is constructed and consumed outside this chunk; the field notes below
// are inferred from the field names and types.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is requested for, plus a snapshot of its contents.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Presumably the cursor position that triggered the request.
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related context gathered by the project's `RelatedExcerptStore`.
    related_files: Arc<[RelatedFile]>,
    /// Recently activated paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Debug channel for the project, when `debug_info` has installed one.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
 191
/// Tunables for Zeta prompt construction (see `DEFAULT_OPTIONS`).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Size limits for the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format used for requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 197
/// Debug events sent to the channel installed by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent by `handle_excerpt_store_event` when a project's related-excerpt store starts
/// refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Left empty by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent by `handle_excerpt_store_event` when a project's related-excerpt store finishes
/// refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Labelled statistics for display, such as cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

// NOTE(review): the producers of the two prediction events below are outside this chunk;
// presumably they bracket one prediction request for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Alias for the v3 API's `DebugInfo` type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 235
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change (for finalized `LastEvent`s, a `BufferChange` with a unified diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
}
 242
/// Per-project edit prediction state owned by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit events, reported before `last_event` (presumably capped at
    /// `EVENT_COUNT_MAX`; the push/trim logic is outside this chunk).
    events: VecDeque<StoredEvent>,
    /// The in-progress edit. `events()` finalizes it and reports it last.
    last_event: Option<LastEvent>,
    /// Paths of recently activated entries, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered via `register_buffer`, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Presumably the id handed to the next `PendingPrediction`.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; the `ArrayVec` caps this at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Set by `debug_info`; when present, debug events are forwarded here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt context store for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree, used to mark events from open-source projects.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Keeps the project event subscription alive.
    _subscription: gpui::Subscription,
}
 258
 259impl ProjectState {
 260    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 261        self.events
 262            .iter()
 263            .cloned()
 264            .chain(
 265                self.last_event
 266                    .as_ref()
 267                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 268            )
 269            .collect()
 270    }
 271
 272    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 273        self.events
 274            .iter()
 275            .cloned()
 276            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 277                let (one, two) = event.split_by_pause();
 278                let one = one.finalize(&self.license_detection_watchers, cx);
 279                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 280                one.into_iter().chain(two)
 281            }))
 282            .collect()
 283    }
 284
 285    fn cancel_pending_prediction(
 286        &mut self,
 287        pending_prediction: PendingPrediction,
 288        cx: &mut Context<EditPredictionStore>,
 289    ) {
 290        self.cancelled_predictions.insert(pending_prediction.id);
 291
 292        cx.spawn(async move |this, cx| {
 293            let Some(prediction_id) = pending_prediction.task.await else {
 294                return;
 295            };
 296
 297            this.update(cx, |this, _cx| {
 298                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 299            })
 300            .ok();
 301        })
 302        .detach()
 303    }
 304
 305    fn active_buffer(
 306        &self,
 307        project: &Entity<Project>,
 308        cx: &App,
 309    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 310        let project = project.read(cx);
 311        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 312        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 313        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 314        Some((active_buffer, registered_buffer.last_position))
 315    }
 316}
 317
 318#[derive(Debug, Clone)]
 319struct CurrentEditPrediction {
 320    pub requested_by: PredictionRequestedBy,
 321    pub prediction: EditPrediction,
 322    pub was_shown: bool,
 323}
 324
 325impl CurrentEditPrediction {
 326    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 327        let Some(new_edits) = self
 328            .prediction
 329            .interpolate(&self.prediction.buffer.read(cx))
 330        else {
 331            return false;
 332        };
 333
 334        if self.prediction.buffer != old_prediction.prediction.buffer {
 335            return true;
 336        }
 337
 338        let Some(old_edits) = old_prediction
 339            .prediction
 340            .interpolate(&old_prediction.prediction.buffer.read(cx))
 341        else {
 342            return true;
 343        };
 344
 345        let requested_by_buffer_id = self.requested_by.buffer_id();
 346
 347        // This reduces the occurrence of UI thrash from replacing edits
 348        //
 349        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 350        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 351            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 352            && old_edits.len() == 1
 353            && new_edits.len() == 1
 354        {
 355            let (old_range, old_text) = &old_edits[0];
 356            let (new_range, new_text) = &new_edits[0];
 357            new_range == old_range && new_text.starts_with(old_text.as_ref())
 358        } else {
 359            true
 360        }
 361    }
 362}
 363
 364#[derive(Debug, Clone)]
 365enum PredictionRequestedBy {
 366    DiagnosticsUpdate,
 367    Buffer(EntityId),
 368}
 369
 370impl PredictionRequestedBy {
 371    pub fn buffer_id(&self) -> Option<EntityId> {
 372        match self {
 373            PredictionRequestedBy::DiagnosticsUpdate => None,
 374            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 375        }
 376    }
 377}
 378
/// A prediction request that has not completed yet.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` when this is cancelled.
    id: usize,
    /// Resolves to the resulting prediction's id, or `None` if no prediction was produced.
    task: Task<Option<EditPredictionId>>,
}
 384
 385/// A prediction from the perspective of a buffer.
 386#[derive(Debug)]
 387enum BufferEditPrediction<'a> {
 388    Local { prediction: &'a EditPrediction },
 389    Jump { prediction: &'a EditPrediction },
 390}
 391
 392#[cfg(test)]
 393impl std::ops::Deref for BufferEditPrediction<'_> {
 394    type Target = EditPrediction;
 395
 396    fn deref(&self) -> &Self::Target {
 397        match self {
 398            BufferEditPrediction::Local { prediction } => prediction,
 399            BufferEditPrediction::Jump { prediction } => prediction,
 400        }
 401    }
 402}
 403
/// Bookkeeping for a buffer registered with the store.
struct RegisteredBuffer {
    /// The buffer's file, or `None` for buffers without one.
    file: Option<Arc<dyn File>>,
    // NOTE(review): maintained outside this chunk; presumably the last snapshot observed.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in this buffer (returned by `ProjectState::active_buffer`).
    last_position: Option<Anchor>,
    /// Keeps this buffer's two subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
}
 410
/// The in-progress edit for a project, not yet moved into `ProjectState::events`.
#[derive(Clone)]
struct LastEvent {
    /// Buffer state at the start of this event.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state after the most recent edit in this event.
    new_snapshot: TextBufferSnapshot,
    /// Files matching the two snapshots; `None` yields an `untitled-<id>` path.
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    end_edit_anchor: Option<Anchor>,
    /// Snapshot taken after the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 421
 422impl LastEvent {
 423    pub fn finalize(
 424        &self,
 425        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 426        cx: &App,
 427    ) -> Option<StoredEvent> {
 428        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 429        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 430
 431        let in_open_source_repo =
 432            [self.new_file.as_ref(), self.old_file.as_ref()]
 433                .iter()
 434                .all(|file| {
 435                    file.is_some_and(|file| {
 436                        license_detection_watchers
 437                            .get(&file.worktree_id(cx))
 438                            .is_some_and(|watcher| watcher.is_project_open_source())
 439                    })
 440                });
 441
 442        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 443
 444        if path == old_path && diff.is_empty() {
 445            None
 446        } else {
 447            Some(StoredEvent {
 448                event: Arc::new(zeta_prompt::Event::BufferChange {
 449                    old_path,
 450                    path,
 451                    diff,
 452                    in_open_source_repo,
 453                    // TODO: Actually detect if this edit was predicted or not
 454                    predicted: false,
 455                }),
 456                old_snapshot: self.old_snapshot.clone(),
 457            })
 458        }
 459    }
 460
 461    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 462        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 463            return (self.clone(), None);
 464        };
 465
 466        let before = LastEvent {
 467            old_snapshot: self.old_snapshot.clone(),
 468            new_snapshot: boundary_snapshot.clone(),
 469            old_file: self.old_file.clone(),
 470            new_file: self.new_file.clone(),
 471            end_edit_anchor: self.end_edit_anchor,
 472            snapshot_after_last_editing_pause: None,
 473            last_edit_time: self.last_edit_time,
 474        };
 475
 476        let after = LastEvent {
 477            old_snapshot: boundary_snapshot.clone(),
 478            new_snapshot: self.new_snapshot.clone(),
 479            old_file: self.old_file.clone(),
 480            new_file: self.new_file.clone(),
 481            end_edit_anchor: self.end_edit_anchor,
 482            snapshot_after_last_editing_pause: None,
 483            last_edit_time: self.last_edit_time,
 484        };
 485
 486        (before, Some(after))
 487    }
 488}
 489
 490pub(crate) fn compute_diff_between_snapshots(
 491    old_snapshot: &TextBufferSnapshot,
 492    new_snapshot: &TextBufferSnapshot,
 493) -> Option<String> {
 494    let edits: Vec<Edit<usize>> = new_snapshot
 495        .edits_since::<usize>(&old_snapshot.version)
 496        .collect();
 497
 498    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 499
 500    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 501    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 502    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 503    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 504
 505    const CONTEXT_LINES: u32 = 3;
 506
 507    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 508    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 509    let old_context_end_row =
 510        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 511    let new_context_end_row =
 512        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 513
 514    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 515    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 516    let old_end_line_offset = old_snapshot
 517        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 518    let new_end_line_offset = new_snapshot
 519        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 520    let old_edit_range = old_start_line_offset..old_end_line_offset;
 521    let new_edit_range = new_start_line_offset..new_end_line_offset;
 522
 523    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 524    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 525
 526    let diff = language::unified_diff_with_offsets(
 527        &old_region_text,
 528        &new_region_text,
 529        old_context_start_row,
 530        new_context_start_row,
 531    );
 532
 533    Some(diff)
 534}
 535
 536fn buffer_path_with_id_fallback(
 537    file: Option<&Arc<dyn File>>,
 538    snapshot: &TextBufferSnapshot,
 539    cx: &App,
 540) -> Arc<Path> {
 541    if let Some(file) = file {
 542        file.full_path(cx).into()
 543    } else {
 544        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 545    }
 546}
 547
 548impl EditPredictionStore {
 549    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 550        cx.try_global::<EditPredictionStoreGlobal>()
 551            .map(|global| global.0.clone())
 552    }
 553
 554    pub fn global(
 555        client: &Arc<Client>,
 556        user_store: &Entity<UserStore>,
 557        cx: &mut App,
 558    ) -> Entity<Self> {
 559        cx.try_global::<EditPredictionStoreGlobal>()
 560            .map(|global| global.0.clone())
 561            .unwrap_or_else(|| {
 562                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 563                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 564                ep_store
 565            })
 566    }
 567
    /// Creates the store and wires up its background work.
    ///
    /// In addition to building the initial state, this:
    /// - spawns a detached background task (`handle_rejected_predictions`) that receives from
    ///   the channel behind `reject_predictions_tx`;
    /// - refreshes `llm_token` whenever the global `RefreshLlmTokenListener` emits an event;
    /// - calls `configure_context_retrieval` immediately, again once feature flags are ready,
    ///   and on every `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejected predictions are sent over this channel and reported by a background task.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the shared token whenever the listener fires; failures are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            // Note: this differs from `EditPredictionModel::default()`, which is `Zeta1`.
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // `ZED_PREDICT_EDITS_URL` takes precedence (an unparsable value goes through
            // `log_err` and becomes `None`). Otherwise, use the local Ollama endpoint when
            // `ZED_ZETA2_OLLAMA` is set.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            // Constant URL literal, so this `unwrap` cannot fail.
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        // Re-run the configuration once feature flags are ready.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        // ...and whenever settings change.
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 653
 654    #[cfg(test)]
 655    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 656        self.custom_predict_edits_url = Some(url.into());
 657    }
 658
 659    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 660        self.edit_prediction_model = model;
 661    }
 662
 663    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 664        self.sweep_ai.api_token.read(cx).has_key()
 665    }
 666
 667    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 668        self.mercury.api_token.read(cx).has_key()
 669    }
 670
 671    #[cfg(feature = "cli-support")]
 672    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
 673        self.eval_cache = Some(cache);
 674    }
 675
 676    pub fn options(&self) -> &ZetaOptions {
 677        &self.options
 678    }
 679
 680    pub fn set_options(&mut self, options: ZetaOptions) {
 681        self.options = options;
 682    }
 683
 684    pub fn set_use_context(&mut self, use_context: bool) {
 685        self.use_context = use_context;
 686    }
 687
 688    pub fn clear_history(&mut self) {
 689        for project_state in self.projects.values_mut() {
 690            project_state.events.clear();
 691        }
 692    }
 693
 694    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 695        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 696            project_state.events.clear();
 697        }
 698    }
 699
 700    pub fn edit_history_for_project(
 701        &self,
 702        project: &Entity<Project>,
 703        cx: &App,
 704    ) -> Vec<StoredEvent> {
 705        self.projects
 706            .get(&project.entity_id())
 707            .map(|project_state| project_state.events(cx))
 708            .unwrap_or_default()
 709    }
 710
 711    pub fn edit_history_for_project_with_pause_split_last_event(
 712        &self,
 713        project: &Entity<Project>,
 714        cx: &App,
 715    ) -> Vec<StoredEvent> {
 716        self.projects
 717            .get(&project.entity_id())
 718            .map(|project_state| project_state.events_split_by_pause(cx))
 719            .unwrap_or_default()
 720    }
 721
 722    pub fn context_for_project<'a>(
 723        &'a self,
 724        project: &Entity<Project>,
 725        cx: &'a App,
 726    ) -> Arc<[RelatedFile]> {
 727        self.projects
 728            .get(&project.entity_id())
 729            .map(|project| project.context.read(cx).related_files())
 730            .unwrap_or_else(|| vec![].into())
 731    }
 732
 733    pub fn context_for_project_with_buffers<'a>(
 734        &'a self,
 735        project: &Entity<Project>,
 736        cx: &'a App,
 737    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 738        self.projects
 739            .get(&project.entity_id())
 740            .map(|project| project.context.read(cx).related_files_with_buffers())
 741    }
 742
 743    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 744        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 745            self.user_store.read(cx).edit_prediction_usage()
 746        } else {
 747            None
 748        }
 749    }
 750
 751    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 752        self.get_or_init_project(project, cx);
 753    }
 754
 755    pub fn register_buffer(
 756        &mut self,
 757        buffer: &Entity<Buffer>,
 758        project: &Entity<Project>,
 759        cx: &mut Context<Self>,
 760    ) {
 761        let project_state = self.get_or_init_project(project, cx);
 762        Self::register_buffer_impl(project_state, buffer, project, cx);
 763    }
 764
 765    fn get_or_init_project(
 766        &mut self,
 767        project: &Entity<Project>,
 768        cx: &mut Context<Self>,
 769    ) -> &mut ProjectState {
 770        let entity_id = project.entity_id();
 771        self.projects
 772            .entry(entity_id)
 773            .or_insert_with(|| ProjectState {
 774                context: {
 775                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 776                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 777                        this.handle_excerpt_store_event(entity_id, event);
 778                    })
 779                    .detach();
 780                    related_excerpt_store
 781                },
 782                events: VecDeque::new(),
 783                last_event: None,
 784                recent_paths: VecDeque::new(),
 785                debug_tx: None,
 786                registered_buffers: HashMap::default(),
 787                current_prediction: None,
 788                cancelled_predictions: HashSet::default(),
 789                pending_predictions: ArrayVec::new(),
 790                next_pending_prediction_id: 0,
 791                last_prediction_refresh: None,
 792                license_detection_watchers: HashMap::default(),
 793                _subscription: cx.subscribe(&project, Self::handle_project_event),
 794            })
 795    }
 796
 797    pub fn remove_project(&mut self, project: &Entity<Project>) {
 798        self.projects.remove(&project.entity_id());
 799    }
 800
 801    fn handle_excerpt_store_event(
 802        &mut self,
 803        project_entity_id: EntityId,
 804        event: &RelatedExcerptStoreEvent,
 805    ) {
 806        if let Some(project_state) = self.projects.get(&project_entity_id) {
 807            if let Some(debug_tx) = project_state.debug_tx.clone() {
 808                match event {
 809                    RelatedExcerptStoreEvent::StartedRefresh => {
 810                        debug_tx
 811                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 812                                ContextRetrievalStartedDebugEvent {
 813                                    project_entity_id: project_entity_id,
 814                                    timestamp: Instant::now(),
 815                                    search_prompt: String::new(),
 816                                },
 817                            ))
 818                            .ok();
 819                    }
 820                    RelatedExcerptStoreEvent::FinishedRefresh {
 821                        cache_hit_count,
 822                        cache_miss_count,
 823                        mean_definition_latency,
 824                        max_definition_latency,
 825                    } => {
 826                        debug_tx
 827                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 828                                ContextRetrievalFinishedDebugEvent {
 829                                    project_entity_id: project_entity_id,
 830                                    timestamp: Instant::now(),
 831                                    metadata: vec![
 832                                        (
 833                                            "Cache Hits",
 834                                            format!(
 835                                                "{}/{}",
 836                                                cache_hit_count,
 837                                                cache_hit_count + cache_miss_count
 838                                            )
 839                                            .into(),
 840                                        ),
 841                                        (
 842                                            "Max LSP Time",
 843                                            format!("{} ms", max_definition_latency.as_millis())
 844                                                .into(),
 845                                        ),
 846                                        (
 847                                            "Mean LSP Time",
 848                                            format!("{} ms", mean_definition_latency.as_millis())
 849                                                .into(),
 850                                        ),
 851                                    ],
 852                                },
 853                            ))
 854                            .ok();
 855                    }
 856                }
 857            }
 858        }
 859    }
 860
 861    pub fn debug_info(
 862        &mut self,
 863        project: &Entity<Project>,
 864        cx: &mut Context<Self>,
 865    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 866        let project_state = self.get_or_init_project(project, cx);
 867        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 868        project_state.debug_tx = Some(debug_watch_tx);
 869        debug_watch_rx
 870    }
 871
 872    fn handle_project_event(
 873        &mut self,
 874        project: Entity<Project>,
 875        event: &project::Event,
 876        cx: &mut Context<Self>,
 877    ) {
 878        // TODO [zeta2] init with recent paths
 879        match event {
 880            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 881                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 882                    return;
 883                };
 884                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 885                if let Some(path) = path {
 886                    if let Some(ix) = project_state
 887                        .recent_paths
 888                        .iter()
 889                        .position(|probe| probe == &path)
 890                    {
 891                        project_state.recent_paths.remove(ix);
 892                    }
 893                    project_state.recent_paths.push_front(path);
 894                }
 895            }
 896            project::Event::DiagnosticsUpdated { .. } => {
 897                if cx.has_flag::<Zeta2FeatureFlag>() {
 898                    self.refresh_prediction_from_diagnostics(project, cx);
 899                }
 900            }
 901            _ => (),
 902        }
 903    }
 904
 905    fn register_buffer_impl<'a>(
 906        project_state: &'a mut ProjectState,
 907        buffer: &Entity<Buffer>,
 908        project: &Entity<Project>,
 909        cx: &mut Context<Self>,
 910    ) -> &'a mut RegisteredBuffer {
 911        let buffer_id = buffer.entity_id();
 912
 913        if let Some(file) = buffer.read(cx).file() {
 914            let worktree_id = file.worktree_id(cx);
 915            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
 916                project_state
 917                    .license_detection_watchers
 918                    .entry(worktree_id)
 919                    .or_insert_with(|| {
 920                        let project_entity_id = project.entity_id();
 921                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
 922                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 923                            else {
 924                                return;
 925                            };
 926                            project_state
 927                                .license_detection_watchers
 928                                .remove(&worktree_id);
 929                        })
 930                        .detach();
 931                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
 932                    });
 933            }
 934        }
 935
 936        match project_state.registered_buffers.entry(buffer_id) {
 937            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 938            hash_map::Entry::Vacant(entry) => {
 939                let buf = buffer.read(cx);
 940                let snapshot = buf.text_snapshot();
 941                let file = buf.file().cloned();
 942                let project_entity_id = project.entity_id();
 943                entry.insert(RegisteredBuffer {
 944                    snapshot,
 945                    file,
 946                    last_position: None,
 947                    _subscriptions: [
 948                        cx.subscribe(buffer, {
 949                            let project = project.downgrade();
 950                            move |this, buffer, event, cx| {
 951                                if let language::BufferEvent::Edited = event
 952                                    && let Some(project) = project.upgrade()
 953                                {
 954                                    this.report_changes_for_buffer(&buffer, &project, cx);
 955                                }
 956                            }
 957                        }),
 958                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 959                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
 960                            else {
 961                                return;
 962                            };
 963                            project_state.registered_buffers.remove(&buffer_id);
 964                        }),
 965                    ],
 966                })
 967            }
 968        }
 969    }
 970
    /// Records a change to `buffer` in the project's edit history.
    ///
    /// The buffer is registered first if needed. If its text version has not changed since the
    /// last recorded snapshot, nothing happens. Otherwise the new snapshot and file replace the
    /// registered ones, and the change is either:
    /// - coalesced into the pending `last_event`. This happens when it continues that event's
    ///   snapshot of the same buffer and its last edit ends within `CHANGE_GROUPING_LINE_SPAN`
    ///   rows of the previous edit's end; or
    /// - started as a new `last_event`. In that case the previous `last_event` is finalized into
    ///   `events`, and the oldest event is dropped once the history is near `EVENT_COUNT_MAX`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Same version as the last recorded snapshot: nothing new to report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End anchor of the last edit made between the old and new snapshots, if any.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change directly continues the pending event: same buffer, and the pending
            // event ended exactly at the snapshot this change starts from.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Only merge edits whose end positions are close together (in rows), so unrelated
            // edits in distant parts of the buffer stay separate events.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_event.end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If the user paused for at least LAST_CHANGE_GROUPING_TIME, remember the
                // snapshot from before this edit as the state "after the last editing pause".
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.end_edit_anchor = end_edit_anchor;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Make room for the event finalized below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1046
1047    fn prediction_at(
1048        &mut self,
1049        buffer: &Entity<Buffer>,
1050        position: Option<language::Anchor>,
1051        project: &Entity<Project>,
1052        cx: &App,
1053    ) -> Option<BufferEditPrediction<'_>> {
1054        let project_state = self.projects.get_mut(&project.entity_id())?;
1055        if let Some(position) = position
1056            && let Some(buffer) = project_state
1057                .registered_buffers
1058                .get_mut(&buffer.entity_id())
1059        {
1060            buffer.last_position = Some(position);
1061        }
1062
1063        let CurrentEditPrediction {
1064            requested_by,
1065            prediction,
1066            ..
1067        } = project_state.current_prediction.as_ref()?;
1068
1069        if prediction.targets_buffer(buffer.read(cx)) {
1070            Some(BufferEditPrediction::Local { prediction })
1071        } else {
1072            let show_jump = match requested_by {
1073                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1074                    requested_by_buffer_id == &buffer.entity_id()
1075                }
1076                PredictionRequestedBy::DiagnosticsUpdate => true,
1077            };
1078
1079            if show_jump {
1080                Some(BufferEditPrediction::Jump { prediction })
1081            } else {
1082                None
1083            }
1084        }
1085    }
1086
1087    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1088        let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
1089        match self.edit_prediction_model {
1090            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1091                if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
1092                    return;
1093                }
1094            }
1095            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1096        }
1097
1098        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1099            return;
1100        };
1101
1102        let Some(prediction) = project_state.current_prediction.take() else {
1103            return;
1104        };
1105        let request_id = prediction.prediction.id.to_string();
1106        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1107            project_state.cancel_pending_prediction(pending_prediction, cx);
1108        }
1109
1110        let client = self.client.clone();
1111        let llm_token = self.llm_token.clone();
1112        let app_version = AppVersion::global(cx);
1113        cx.spawn(async move |this, cx| {
1114            let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
1115                (http_client::Url::parse(&accept_edits_url)?, false)
1116            } else {
1117                (
1118                    client
1119                        .http_client()
1120                        .build_zed_llm_url("/predict_edits/accept", &[])?,
1121                    true,
1122                )
1123            };
1124
1125            let response = cx
1126                .background_spawn(Self::send_api_request::<()>(
1127                    move |builder| {
1128                        let req = builder.uri(url.as_ref()).body(
1129                            serde_json::to_string(&AcceptEditPredictionBody {
1130                                request_id: request_id.clone(),
1131                            })?
1132                            .into(),
1133                        );
1134                        Ok(req?)
1135                    },
1136                    client,
1137                    llm_token,
1138                    app_version,
1139                    require_auth,
1140                ))
1141                .await;
1142
1143            Self::handle_api_response(&this, response, cx)?;
1144            anyhow::Ok(())
1145        })
1146        .detach_and_log_err(cx);
1147    }
1148
1149    async fn handle_rejected_predictions(
1150        rx: UnboundedReceiver<EditPredictionRejection>,
1151        client: Arc<Client>,
1152        llm_token: LlmApiToken,
1153        app_version: Version,
1154        background_executor: BackgroundExecutor,
1155    ) {
1156        let mut rx = std::pin::pin!(rx.peekable());
1157        let mut batched = Vec::new();
1158
1159        while let Some(rejection) = rx.next().await {
1160            batched.push(rejection);
1161
1162            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1163                select_biased! {
1164                    next = rx.as_mut().peek().fuse() => {
1165                        if next.is_some() {
1166                            continue;
1167                        }
1168                    }
1169                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1170                }
1171            }
1172
1173            let url = client
1174                .http_client()
1175                .build_zed_llm_url("/predict_edits/reject", &[])
1176                .unwrap();
1177
1178            let flush_count = batched
1179                .len()
1180                // in case items have accumulated after failure
1181                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1182            let start = batched.len() - flush_count;
1183
1184            let body = RejectEditPredictionsBodyRef {
1185                rejections: &batched[start..],
1186            };
1187
1188            let result = Self::send_api_request::<()>(
1189                |builder| {
1190                    let req = builder
1191                        .uri(url.as_ref())
1192                        .body(serde_json::to_string(&body)?.into());
1193                    anyhow::Ok(req?)
1194                },
1195                client.clone(),
1196                llm_token.clone(),
1197                app_version.clone(),
1198                true,
1199            )
1200            .await;
1201
1202            if result.log_err().is_some() {
1203                batched.drain(start..);
1204            }
1205        }
1206    }
1207
1208    fn reject_current_prediction(
1209        &mut self,
1210        reason: EditPredictionRejectReason,
1211        project: &Entity<Project>,
1212    ) {
1213        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1214            project_state.pending_predictions.clear();
1215            if let Some(prediction) = project_state.current_prediction.take() {
1216                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1217            }
1218        };
1219    }
1220
1221    fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1222        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1223            if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1224                if !current_prediction.was_shown {
1225                    current_prediction.was_shown = true;
1226                    self.shown_predictions
1227                        .push_front(current_prediction.prediction.clone());
1228                    if self.shown_predictions.len() > 50 {
1229                        let completion = self.shown_predictions.pop_back().unwrap();
1230                        self.rated_predictions.remove(&completion.id);
1231                    }
1232                }
1233            }
1234        }
1235    }
1236
1237    fn reject_prediction(
1238        &mut self,
1239        prediction_id: EditPredictionId,
1240        reason: EditPredictionRejectReason,
1241        was_shown: bool,
1242    ) {
1243        match self.edit_prediction_model {
1244            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1245                if self.custom_predict_edits_url.is_some() {
1246                    return;
1247                }
1248            }
1249            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1250        }
1251
1252        self.reject_predictions_tx
1253            .unbounded_send(EditPredictionRejection {
1254                request_id: prediction_id.to_string(),
1255                reason,
1256                was_shown,
1257            })
1258            .log_err();
1259    }
1260
1261    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1262        self.projects
1263            .get(&project.entity_id())
1264            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1265    }
1266
1267    pub fn refresh_prediction_from_buffer(
1268        &mut self,
1269        project: Entity<Project>,
1270        buffer: Entity<Buffer>,
1271        position: language::Anchor,
1272        cx: &mut Context<Self>,
1273    ) {
1274        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1275            let Some(request_task) = this
1276                .update(cx, |this, cx| {
1277                    this.request_prediction(
1278                        &project,
1279                        &buffer,
1280                        position,
1281                        PredictEditsRequestTrigger::Other,
1282                        cx,
1283                    )
1284                })
1285                .log_err()
1286            else {
1287                return Task::ready(anyhow::Ok(None));
1288            };
1289
1290            cx.spawn(async move |_cx| {
1291                request_task.await.map(|prediction_result| {
1292                    prediction_result.map(|prediction_result| {
1293                        (
1294                            prediction_result,
1295                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1296                        )
1297                    })
1298                })
1299            })
1300        })
1301    }
1302
    /// Queues a prediction refresh driven by a diagnostics update in `project`.
    ///
    /// Does nothing if the project is not tracked yet, or if it already has a current
    /// prediction (buffer-driven predictions are preferred). Otherwise, throttled per project,
    /// it looks up the project's active buffer and cursor. It then bails out if predictions are
    /// disabled there, finds the next diagnostic location, and requests a prediction at that
    /// location with the `Diagnostics` trigger. If a current prediction appeared while the
    /// request was in flight, the result is turned into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, search from the start of the buffer
                    // (`Point::default()`).
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // NOTE(review): an empty (default) search range is passed here; its meaning is
                // defined by `next_diagnostic_location`, which is outside this view — confirm.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A prediction may have become current while this request was in
                        // flight; if so, keep it and report this one as not preferred.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1389
1390    fn predictions_enabled_at(
1391        snapshot: &BufferSnapshot,
1392        position: Option<language::Anchor>,
1393        cx: &App,
1394    ) -> bool {
1395        let file = snapshot.file();
1396        let all_settings = all_language_settings(file, cx);
1397        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1398            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1399        {
1400            return false;
1401        }
1402
1403        if let Some(last_position) = position {
1404            let settings = snapshot.settings_at(last_position, cx);
1405
1406            if !settings.edit_predictions_disabled_in.is_empty()
1407                && let Some(scope) = snapshot.language_scope_at(last_position)
1408                && let Some(scope_name) = scope.override_name()
1409                && settings
1410                    .edit_predictions_disabled_in
1411                    .iter()
1412                    .any(|s| s == scope_name)
1413            {
1414                return false;
1415            }
1416        }
1417
1418        true
1419    }
1420
1421    #[cfg(not(test))]
1422    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1423    #[cfg(test)]
1424    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1425
1426    fn queue_prediction_refresh(
1427        &mut self,
1428        project: Entity<Project>,
1429        throttle_entity: EntityId,
1430        cx: &mut Context<Self>,
1431        do_refresh: impl FnOnce(
1432            WeakEntity<Self>,
1433            &mut AsyncApp,
1434        )
1435            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1436        + 'static,
1437    ) {
1438        let project_state = self.get_or_init_project(&project, cx);
1439        let pending_prediction_id = project_state.next_pending_prediction_id;
1440        project_state.next_pending_prediction_id += 1;
1441        let last_request = project_state.last_prediction_refresh;
1442
1443        let task = cx.spawn(async move |this, cx| {
1444            if let Some((last_entity, last_timestamp)) = last_request
1445                && throttle_entity == last_entity
1446                && let Some(timeout) =
1447                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1448            {
1449                cx.background_executor().timer(timeout).await;
1450            }
1451
1452            // If this task was cancelled before the throttle timeout expired,
1453            // do not perform a request.
1454            let mut is_cancelled = true;
1455            this.update(cx, |this, cx| {
1456                let project_state = this.get_or_init_project(&project, cx);
1457                if !project_state
1458                    .cancelled_predictions
1459                    .remove(&pending_prediction_id)
1460                {
1461                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1462                    is_cancelled = false;
1463                }
1464            })
1465            .ok();
1466            if is_cancelled {
1467                return None;
1468            }
1469
1470            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1471            let new_prediction_id = new_prediction_result
1472                .as_ref()
1473                .map(|(prediction, _)| prediction.id.clone());
1474
1475            // When a prediction completes, remove it from the pending list, and cancel
1476            // any pending predictions that were enqueued before it.
1477            this.update(cx, |this, cx| {
1478                let project_state = this.get_or_init_project(&project, cx);
1479
1480                let is_cancelled = project_state
1481                    .cancelled_predictions
1482                    .remove(&pending_prediction_id);
1483
1484                let new_current_prediction = if !is_cancelled
1485                    && let Some((prediction_result, requested_by)) = new_prediction_result
1486                {
1487                    match prediction_result.prediction {
1488                        Ok(prediction) => {
1489                            let new_prediction = CurrentEditPrediction {
1490                                requested_by,
1491                                prediction,
1492                                was_shown: false,
1493                            };
1494
1495                            if let Some(current_prediction) =
1496                                project_state.current_prediction.as_ref()
1497                            {
1498                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1499                                {
1500                                    this.reject_current_prediction(
1501                                        EditPredictionRejectReason::Replaced,
1502                                        &project,
1503                                    );
1504
1505                                    Some(new_prediction)
1506                                } else {
1507                                    this.reject_prediction(
1508                                        new_prediction.prediction.id,
1509                                        EditPredictionRejectReason::CurrentPreferred,
1510                                        false,
1511                                    );
1512                                    None
1513                                }
1514                            } else {
1515                                Some(new_prediction)
1516                            }
1517                        }
1518                        Err(reject_reason) => {
1519                            this.reject_prediction(prediction_result.id, reject_reason, false);
1520                            None
1521                        }
1522                    }
1523                } else {
1524                    None
1525                };
1526
1527                let project_state = this.get_or_init_project(&project, cx);
1528
1529                if let Some(new_prediction) = new_current_prediction {
1530                    project_state.current_prediction = Some(new_prediction);
1531                }
1532
1533                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1534                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1535                    if pending_prediction.id == pending_prediction_id {
1536                        pending_predictions.remove(ix);
1537                        for pending_prediction in pending_predictions.drain(0..ix) {
1538                            project_state.cancel_pending_prediction(pending_prediction, cx)
1539                        }
1540                        break;
1541                    }
1542                }
1543                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1544                cx.notify();
1545            })
1546            .ok();
1547
1548            new_prediction_id
1549        });
1550
1551        if project_state.pending_predictions.len() <= 1 {
1552            project_state.pending_predictions.push(PendingPrediction {
1553                id: pending_prediction_id,
1554                task,
1555            });
1556        } else if project_state.pending_predictions.len() == 2 {
1557            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1558            project_state.pending_predictions.push(PendingPrediction {
1559                id: pending_prediction_id,
1560                task,
1561            });
1562            project_state.cancel_pending_prediction(pending_prediction, cx);
1563        }
1564    }
1565
1566    pub fn request_prediction(
1567        &mut self,
1568        project: &Entity<Project>,
1569        active_buffer: &Entity<Buffer>,
1570        position: language::Anchor,
1571        trigger: PredictEditsRequestTrigger,
1572        cx: &mut Context<Self>,
1573    ) -> Task<Result<Option<EditPredictionResult>>> {
1574        self.request_prediction_internal(
1575            project.clone(),
1576            active_buffer.clone(),
1577            position,
1578            trigger,
1579            cx.has_flag::<Zeta2FeatureFlag>(),
1580            cx,
1581        )
1582    }
1583
1584    fn request_prediction_internal(
1585        &mut self,
1586        project: Entity<Project>,
1587        active_buffer: Entity<Buffer>,
1588        position: language::Anchor,
1589        trigger: PredictEditsRequestTrigger,
1590        allow_jump: bool,
1591        cx: &mut Context<Self>,
1592    ) -> Task<Result<Option<EditPredictionResult>>> {
1593        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1594
1595        self.get_or_init_project(&project, cx);
1596        let project_state = self.projects.get(&project.entity_id()).unwrap();
1597        let stored_events = project_state.events(cx);
1598        let has_events = !stored_events.is_empty();
1599        let events: Vec<Arc<zeta_prompt::Event>> =
1600            stored_events.into_iter().map(|e| e.event).collect();
1601        let debug_tx = project_state.debug_tx.clone();
1602
1603        let snapshot = active_buffer.read(cx).snapshot();
1604        let cursor_point = position.to_point(&snapshot);
1605        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1606        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1607        let diagnostic_search_range =
1608            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1609
1610        let related_files = if self.use_context {
1611            self.context_for_project(&project, cx)
1612        } else {
1613            Vec::new().into()
1614        };
1615
1616        let inputs = EditPredictionModelInput {
1617            project: project.clone(),
1618            buffer: active_buffer.clone(),
1619            snapshot: snapshot.clone(),
1620            position,
1621            events,
1622            related_files,
1623            recent_paths: project_state.recent_paths.clone(),
1624            trigger,
1625            diagnostic_search_range: diagnostic_search_range.clone(),
1626            debug_tx,
1627        };
1628
1629        let task = match self.edit_prediction_model {
1630            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1631            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1632            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1633            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1634        };
1635
1636        cx.spawn(async move |this, cx| {
1637            let prediction = task.await?;
1638
1639            if prediction.is_none() && allow_jump {
1640                let cursor_point = position.to_point(&snapshot);
1641                if has_events
1642                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1643                        active_buffer.clone(),
1644                        &snapshot,
1645                        diagnostic_search_range,
1646                        cursor_point,
1647                        &project,
1648                        cx,
1649                    )
1650                    .await?
1651                {
1652                    return this
1653                        .update(cx, |this, cx| {
1654                            this.request_prediction_internal(
1655                                project,
1656                                jump_buffer,
1657                                jump_position,
1658                                trigger,
1659                                false,
1660                                cx,
1661                            )
1662                        })?
1663                        .await;
1664                }
1665
1666                return anyhow::Ok(None);
1667            }
1668
1669            Ok(prediction)
1670        })
1671    }
1672
    /// Picks a location to retry an edit prediction at. `request_prediction_internal`
    /// uses it when the request at the cursor produced no prediction.
    ///
    /// Strategy, in order:
    /// 1. In the active buffer, consider the primary entry of each diagnostic
    ///    group whose range does NOT overlap `active_buffer_diagnostic_search_range`.
    ///    Pick the start of the one whose row is closest to
    ///    `active_buffer_cursor_point`.
    /// 2. Otherwise, consider the project files reported by `diagnostic_summaries`,
    ///    excluding the active buffer's own file. Open the one whose path shares
    ///    the most leading components with the active buffer's path, and use the
    ///    start of its diagnostic with the smallest severity value.
    ///
    /// Returns `Ok(None)` when neither step finds a candidate. Errors from reading
    /// entities or opening the buffer are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                // Only the group's primary entry is used as a candidate position.
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Distance is measured in rows only. On ties, `min_by_key` keeps the
            // first candidate.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                // NOTE(review): the meaning of the `false` argument is defined by
                // `Project::diagnostic_summaries` (possibly whether to include ignored
                // entries) — confirm.
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    // On ties `max_by_key` returns the last candidate. With no active
                    // path, `unwrap_or_default()` presumably yields no components, so
                    // every candidate scores 0.
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        // Smallest severity value wins. Presumably this is LSP ordering,
                        // where ERROR (1) sorts before WARNING (2) — confirm.
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1751
    /// Sends an OpenAI-format request to the Zed LLM service's
    /// `/predict_edits/raw` endpoint. Returns the parsed response and any usage
    /// information parsed from the response headers.
    ///
    /// The request is sent with `require_auth = true`, so failing to acquire an
    /// LLM token is an error.
    ///
    /// With the `cli-support` feature and a provided `eval_cache`, responses are
    /// cached:
    /// - The key is `eval_cache_kind` plus an `FxHasher` hash of the URL and the
    ///   pretty-printed request JSON.
    /// - A cache hit returns the stored response with no usage and sends nothing.
    /// - A miss stores the pretty-printed request and response after a successful
    ///   call.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = client
            .http_client()
            .build_zed_llm_url("/predict_edits/raw", &[])?;

        // Consult the eval cache before touching the network. On a miss, keep the
        // cache handle, the request text, and the key so the response can be
        // stored after the call.
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                // The wire body is compact JSON. Only the cache key uses the
                // pretty-printed form.
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1806
1807    fn handle_api_response<T>(
1808        this: &WeakEntity<Self>,
1809        response: Result<(T, Option<EditPredictionUsage>)>,
1810        cx: &mut gpui::AsyncApp,
1811    ) -> Result<T> {
1812        match response {
1813            Ok((data, usage)) => {
1814                if let Some(usage) = usage {
1815                    this.update(cx, |this, cx| {
1816                        this.user_store.update(cx, |user_store, cx| {
1817                            user_store.update_edit_prediction_usage(usage, cx);
1818                        });
1819                    })
1820                    .ok();
1821                }
1822                Ok(data)
1823            }
1824            Err(err) => {
1825                if err.is::<ZedUpdateRequiredError>() {
1826                    cx.update(|cx| {
1827                        this.update(cx, |this, _cx| {
1828                            this.update_required = true;
1829                        })
1830                        .ok();
1831
1832                        let error_message: SharedString = err.to_string().into();
1833                        show_app_notification(
1834                            NotificationId::unique::<ZedUpdateRequiredError>(),
1835                            cx,
1836                            move |cx| {
1837                                cx.new(|cx| {
1838                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1839                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1840                                })
1841                            },
1842                        );
1843                    })
1844                    .ok();
1845                }
1846                Err(err)
1847            }
1848        }
1849    }
1850
1851    async fn send_api_request<Res>(
1852        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1853        client: Arc<Client>,
1854        llm_token: LlmApiToken,
1855        app_version: Version,
1856        require_auth: bool,
1857    ) -> Result<(Res, Option<EditPredictionUsage>)>
1858    where
1859        Res: DeserializeOwned,
1860    {
1861        let http_client = client.http_client();
1862
1863        let mut token = if require_auth {
1864            Some(llm_token.acquire(&client).await?)
1865        } else {
1866            llm_token.acquire(&client).await.ok()
1867        };
1868        let mut did_retry = false;
1869
1870        loop {
1871            let request_builder = http_client::Request::builder().method(Method::POST);
1872
1873            let mut request_builder = request_builder
1874                .header("Content-Type", "application/json")
1875                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
1876
1877            // Only add Authorization header if we have a token
1878            if let Some(ref token_value) = token {
1879                request_builder =
1880                    request_builder.header("Authorization", format!("Bearer {}", token_value));
1881            }
1882
1883            let request = build(request_builder)?;
1884
1885            let mut response = http_client.send(request).await?;
1886
1887            if let Some(minimum_required_version) = response
1888                .headers()
1889                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1890                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1891            {
1892                anyhow::ensure!(
1893                    app_version >= minimum_required_version,
1894                    ZedUpdateRequiredError {
1895                        minimum_version: minimum_required_version
1896                    }
1897                );
1898            }
1899
1900            if response.status().is_success() {
1901                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1902
1903                let mut body = Vec::new();
1904                response.body_mut().read_to_end(&mut body).await?;
1905                return Ok((serde_json::from_slice(&body)?, usage));
1906            } else if !did_retry
1907                && token.is_some()
1908                && response
1909                    .headers()
1910                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1911                    .is_some()
1912            {
1913                did_retry = true;
1914                token = Some(llm_token.refresh(&client).await?);
1915            } else {
1916                let mut body = String::new();
1917                response.body_mut().read_to_string(&mut body).await?;
1918                anyhow::bail!(
1919                    "Request failed with status: {:?}\nBody: {}",
1920                    response.status(),
1921                    body
1922                );
1923            }
1924        }
1925    }
1926
1927    pub fn refresh_context(
1928        &mut self,
1929        project: &Entity<Project>,
1930        buffer: &Entity<language::Buffer>,
1931        cursor_position: language::Anchor,
1932        cx: &mut Context<Self>,
1933    ) {
1934        if self.use_context {
1935            self.get_or_init_project(project, cx)
1936                .context
1937                .update(cx, |store, cx| {
1938                    store.refresh(buffer.clone(), cursor_position, cx);
1939                });
1940        }
1941    }
1942
1943    #[cfg(feature = "cli-support")]
1944    pub fn set_context_for_buffer(
1945        &mut self,
1946        project: &Entity<Project>,
1947        related_files: Vec<RelatedFile>,
1948        cx: &mut Context<Self>,
1949    ) {
1950        self.get_or_init_project(project, cx)
1951            .context
1952            .update(cx, |store, _| {
1953                store.set_related_files(related_files);
1954            });
1955    }
1956
1957    fn is_file_open_source(
1958        &self,
1959        project: &Entity<Project>,
1960        file: &Arc<dyn File>,
1961        cx: &App,
1962    ) -> bool {
1963        if !file.is_local() || file.is_private() {
1964            return false;
1965        }
1966        let Some(project_state) = self.projects.get(&project.entity_id()) else {
1967            return false;
1968        };
1969        project_state
1970            .license_detection_watchers
1971            .get(&file.worktree_id(cx))
1972            .as_ref()
1973            .is_some_and(|watcher| watcher.is_project_open_source())
1974    }
1975
1976    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1977        self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1978    }
1979
1980    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1981        if !self.data_collection_choice.is_enabled() {
1982            return false;
1983        }
1984        events.iter().all(|event| {
1985            matches!(
1986                event.as_ref(),
1987                zeta_prompt::Event::BufferChange {
1988                    in_open_source_repo: true,
1989                    ..
1990                }
1991            )
1992        })
1993    }
1994
1995    fn load_data_collection_choice() -> DataCollectionChoice {
1996        let choice = KEY_VALUE_STORE
1997            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1998            .log_err()
1999            .flatten();
2000
2001        match choice.as_deref() {
2002            Some("true") => DataCollectionChoice::Enabled,
2003            Some("false") => DataCollectionChoice::Disabled,
2004            Some(_) => {
2005                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2006                DataCollectionChoice::NotAnswered
2007            }
2008            None => DataCollectionChoice::NotAnswered,
2009        }
2010    }
2011
2012    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2013        self.data_collection_choice = self.data_collection_choice.toggle();
2014        let new_choice = self.data_collection_choice;
2015        db::write_and_log(cx, move || {
2016            KEY_VALUE_STORE.write_kvp(
2017                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2018                new_choice.is_enabled().to_string(),
2019            )
2020        });
2021    }
2022
2023    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2024        self.shown_predictions.iter()
2025    }
2026
2027    pub fn shown_completions_len(&self) -> usize {
2028        self.shown_predictions.len()
2029    }
2030
2031    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2032        self.rated_predictions.contains(id)
2033    }
2034
    /// Records the user's rating of `prediction` and reports it through telemetry.
    ///
    /// In order, this:
    /// 1. marks the prediction as rated (see `is_prediction_rated`);
    /// 2. emits an "Edit Prediction Rated" telemetry event carrying the rating,
    ///    the prediction's inputs, its edits rendered as a unified diff, and the
    ///    free-form `feedback`;
    /// 3. requests a telemetry flush;
    /// 4. notifies observers.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        // The macro derives property names from these identifiers, so renaming
        // `rating` or `feedback` would change the reported event schema.
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Request an immediate flush. The task is detached, so this call does not wait for it.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2053
2054    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2055        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2056            && all_language_settings(None, cx).edit_predictions.use_context;
2057    }
2058}
2059
/// Returned by `send_api_request` when a response advertises a minimum required
/// Zed version (via `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) newer than the
/// running one.
///
/// `handle_api_response` checks for this error type. It then flags the store
/// as `update_required` and shows an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version parsed from the server's response header.
    minimum_version: Version,
}
2067
/// Key of an eval cache entry: the entry kind plus a 64-bit request hash.
///
/// In `send_raw_llm_request`, the hash covers the request URL and the
/// pretty-printed request JSON.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2070
/// Category of an eval cache entry.
///
/// It is part of [`EvalCacheKey`], so equal request hashes of different kinds
/// are stored as distinct entries.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // NOTE(review): presumably context-retrieval requests — confirm with callers.
    Context,
    // NOTE(review): presumably search requests — confirm with callers.
    Search,
    // NOTE(review): presumably edit-prediction model requests — confirm with callers.
    Prediction,
}
2078
/// Renders the kind as its lowercase name: `context`, `search`, or `prediction`.
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2089
/// Storage for cached LLM responses. Only compiled with the `cli-support` feature.
///
/// `send_raw_llm_request` calls `read` before sending a request. After a
/// successful response it calls `write`, passing the pretty-printed request as
/// `input` and the pretty-printed response JSON as `value`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value stored under `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. `input` is the request text that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2095
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted in the key-value store as `"true"`/`"false"`. A missing or
/// unrecognized stored value loads as `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2102
2103impl DataCollectionChoice {
2104    pub fn is_enabled(self) -> bool {
2105        match self {
2106            Self::Enabled => true,
2107            Self::NotAnswered | Self::Disabled => false,
2108        }
2109    }
2110
2111    pub fn is_answered(self) -> bool {
2112        match self {
2113            Self::Enabled | Self::Disabled => true,
2114            Self::NotAnswered => false,
2115        }
2116    }
2117
2118    #[must_use]
2119    pub fn toggle(&self) -> DataCollectionChoice {
2120        match self {
2121            Self::Enabled => Self::Disabled,
2122            Self::Disabled => Self::Enabled,
2123            Self::NotAnswered => Self::Enabled,
2124        }
2125    }
2126}
2127
2128impl From<bool> for DataCollectionChoice {
2129    fn from(value: bool) -> Self {
2130        match value {
2131            true => DataCollectionChoice::Enabled,
2132            false => DataCollectionChoice::Disabled,
2133        }
2134    }
2135}
2136
/// Marker type. Its `Dismissable` impl tracks whether the edit-prediction
/// upsell has been dismissed.
struct ZedPredictUpsell;
2138
2139impl Dismissable for ZedPredictUpsell {
2140    const KEY: &'static str = "dismissed-edit-predict-upsell";
2141
2142    fn dismissed() -> bool {
2143        // To make this backwards compatible with older versions of Zed, we
2144        // check if the user has seen the previous Edit Prediction Onboarding
2145        // before, by checking the data collection choice which was written to
2146        // the database once the user clicked on "Accept and Enable"
2147        if KEY_VALUE_STORE
2148            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2149            .log_err()
2150            .is_some_and(|s| s.is_some())
2151        {
2152            return true;
2153        }
2154
2155        KEY_VALUE_STORE
2156            .read_kvp(Self::KEY)
2157            .log_err()
2158            .is_some_and(|s| s.is_some())
2159    }
2160}
2161
2162pub fn should_show_upsell_modal() -> bool {
2163    !ZedPredictUpsell::dismissed()
2164}
2165
/// Registers edit-prediction workspace actions on every new [`Workspace`].
///
/// - `zed_actions::OpenZedPredictOnboarding` toggles the `ZedPredictModal`.
/// - `ResetOnboarding` updates the settings file to set the edit prediction
///   provider to `EditPredictionProvider::None`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding turns edit predictions off by writing the provider
        // setting back to `None`.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}