edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
   7    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
   9};
  10use collections::{HashMap, HashSet};
  11use db::kvp::{Dismissable, KEY_VALUE_STORE};
  12use edit_prediction_context::EditPredictionExcerptOptions;
  13use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  15use futures::{
  16    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  17    channel::mpsc::{self, UnboundedReceiver},
  18    select_biased,
  19};
  20use gpui::BackgroundExecutor;
  21use gpui::http_client::Url;
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  24    http_client::{self, AsyncBody, Method},
  25    prelude::*,
  26};
  27use language::language_settings::all_language_settings;
  28use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  29use language::{BufferSnapshot, OffsetRangeExt};
  30use language_model::{LlmApiToken, RefreshLlmTokenListener};
  31use project::{Project, ProjectPath, WorktreeId};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
  36use std::collections::{VecDeque, hash_map};
  37use text::Edit;
  38use workspace::Workspace;
  39
  40use std::ops::Range;
  41use std::path::Path;
  42use std::rc::Rc;
  43use std::str::FromStr as _;
  44use std::sync::{Arc, LazyLock};
  45use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  46use std::{env, mem};
  47use thiserror::Error;
  48use util::{RangeExt as _, ResultExt as _};
  49use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  50
  51pub mod cursor_excerpt;
  52pub mod example_spec;
  53mod license_detection;
  54pub mod mercury;
  55mod onboarding_modal;
  56pub mod open_ai_response;
  57mod prediction;
  58pub mod sweep_ai;
  59
  60pub mod udiff;
  61
  62mod capture_example;
  63mod zed_edit_prediction_delegate;
  64pub mod zeta1;
  65pub mod zeta2;
  66
  67#[cfg(test)]
  68mod edit_prediction_tests;
  69
  70use crate::capture_example::should_sample_edit_prediction_example_capture;
  71use crate::license_detection::LicenseDetectionWatcher;
  72use crate::mercury::Mercury;
  73use crate::onboarding_modal::ZedPredictModal;
  74pub use crate::prediction::EditPrediction;
  75pub use crate::prediction::EditPredictionId;
  76use crate::prediction::EditPredictionResult;
  77pub use crate::sweep_ai::SweepAi;
  78pub use capture_example::capture_example;
  79pub use language_model::ApiKeyState;
  80pub use telemetry_events::EditPredictionRating;
  81pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  82
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  92
  93/// Maximum number of events to track.
  94const EVENT_COUNT_MAX: usize = 6;
  95const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  96const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
  97const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
  98const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
  99
 100pub struct SweepFeatureFlag;
 101
 102impl FeatureFlag for SweepFeatureFlag {
 103    const NAME: &str = "sweep-ai";
 104}
 105
 106pub struct MercuryFeatureFlag;
 107
 108impl FeatureFlag for MercuryFeatureFlag {
 109    const NAME: &str = "mercury";
 110}
 111
/// Default `ZetaOptions` installed by `EditPredictionStore::new`.
///
/// The excerpt options bound the context excerpt between `min_bytes` and `max_bytes`, with
/// `target_before_cursor_over_total_bytes` presumably splitting the budget evenly around the
/// cursor (NOTE(review): confirm semantics in `edit_prediction_context`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
 120
 121static USE_OLLAMA: LazyLock<bool> =
 122    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 123
 124static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 125    match env::var("ZED_ZETA2_MODEL").as_deref() {
 126        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 127        Ok(model) => model,
 128        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 129        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 130    }
 131    .to_string()
 132});
 133
 134pub struct Zeta2FeatureFlag;
 135
 136impl FeatureFlag for Zeta2FeatureFlag {
 137    const NAME: &'static str = "zeta2";
 138
 139    fn enabled_for_staff() -> bool {
 140        true
 141    }
 142}
 143
 144pub struct EditPredictionExampleCaptureFeatureFlag;
 145
 146impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 147    const NAME: &'static str = "edit-prediction-example-capture";
 148
 149    fn enabled_for_staff() -> bool {
 150        true
 151    }
 152}
 153
/// App-global handle to the shared `EditPredictionStore` entity.
///
/// Installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 158
/// Central store for edit predictions, shared app-wide via `EditPredictionStoreGlobal`.
///
/// Holds per-project state (edit history, registered buffers, pending and current
/// predictions) along with the client, token, and configuration used to request predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits (see `new`).
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Set via `set_use_context`; starts as `false`.
    use_context: bool,
    /// Prompt/excerpt options; starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Model used for predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task spawned in `new` that handles rejected predictions.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override, taken from `ZED_PREDICT_EDITS_URL` or, in Ollama mode, a localhost
    /// Ollama URL (see `new`).
    custom_predict_edits_url: Option<Arc<Url>>,
}
 179
/// Backend model used to produce edit predictions.
///
/// Note: the derived `Default` is `Zeta1`, but `EditPredictionStore::new` explicitly starts
/// with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
 188
/// Everything a prediction model needs for a single request.
/// NOTE(review): constructed and consumed outside this view — field semantics below are
/// inferred from types and names; confirm at the construction site.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is based on.
    snapshot: BufferSnapshot,
    /// Cursor position the prediction is requested at.
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Context files from the project's `RelatedExcerptStore`.
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 202
/// Tunable options for building Zeta prompts; defaults are `DEFAULT_OPTIONS`.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt sizing around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format sent to the model.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
 208
/// Diagnostic events emitted to a project's optional debug channel (`ProjectState::debug_tx`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when the related-excerpt store begins refreshing context for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently always empty when emitted by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when the related-excerpt store finishes refreshing context for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable label/value pairs (e.g. cache hits, LSP latency).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Sent when an edit prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// Sent when an edit prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
 244
/// Alias for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Maximum number of `UserActionRecord`s retained per project; older entries are evicted by
/// `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 248
/// A single recorded user action, kept in `ProjectState::user_actions`.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Presumably milliseconds since the Unix epoch, per the field name — confirm at the
    /// recording site.
    pub timestamp_epoch_ms: u64,
}
 257
/// Kind of user action captured in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 266
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (e.g. a `BufferChange` with its unified diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot the event's diff was computed against.
    pub old_snapshot: TextBufferSnapshot,
}
 273
/// Per-project edit prediction state, created lazily by
/// `EditPredictionStore::get_or_init_project`.
struct ProjectState {
    /// Stored edit events, oldest first; `events()` returns these followed by `last_event`.
    events: VecDeque<StoredEvent>,
    /// Most recent event, kept apart from `events` and finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for predictions, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id for the next `PendingPrediction` (assignment site not visible here).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional channel receiving `DebugEvent`s for this project.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context files for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; consulted when finalizing events to set
    /// `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history (`USER_ACTION_HISTORY_SIZE`) of recent user actions.
    user_actions: VecDeque<UserActionRecord>,
    /// Keeps the project event subscription alive.
    _subscription: gpui::Subscription,
}
 290
 291impl ProjectState {
 292    fn record_user_action(&mut self, action: UserActionRecord) {
 293        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 294            self.user_actions.pop_front();
 295        }
 296        self.user_actions.push_back(action);
 297    }
 298
 299    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 300        self.events
 301            .iter()
 302            .cloned()
 303            .chain(
 304                self.last_event
 305                    .as_ref()
 306                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 307            )
 308            .collect()
 309    }
 310
 311    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 312        self.events
 313            .iter()
 314            .cloned()
 315            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 316                let (one, two) = event.split_by_pause();
 317                let one = one.finalize(&self.license_detection_watchers, cx);
 318                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 319                one.into_iter().chain(two)
 320            }))
 321            .collect()
 322    }
 323
 324    fn cancel_pending_prediction(
 325        &mut self,
 326        pending_prediction: PendingPrediction,
 327        cx: &mut Context<EditPredictionStore>,
 328    ) {
 329        self.cancelled_predictions.insert(pending_prediction.id);
 330
 331        cx.spawn(async move |this, cx| {
 332            let Some(prediction_id) = pending_prediction.task.await else {
 333                return;
 334            };
 335
 336            this.update(cx, |this, _cx| {
 337                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 338            })
 339            .ok();
 340        })
 341        .detach()
 342    }
 343
 344    fn active_buffer(
 345        &self,
 346        project: &Entity<Project>,
 347        cx: &App,
 348    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 349        let project = project.read(cx);
 350        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 351        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 352        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 353        Some((active_buffer, registered_buffer.last_position))
 354    }
 355}
 356
/// The prediction currently offered for a project, plus how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a buffer or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown to the user (set outside this view).
    pub was_shown: bool,
    /// How the prediction was displayed, if it was.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 364
 365impl CurrentEditPrediction {
 366    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 367        let Some(new_edits) = self
 368            .prediction
 369            .interpolate(&self.prediction.buffer.read(cx))
 370        else {
 371            return false;
 372        };
 373
 374        if self.prediction.buffer != old_prediction.prediction.buffer {
 375            return true;
 376        }
 377
 378        let Some(old_edits) = old_prediction
 379            .prediction
 380            .interpolate(&old_prediction.prediction.buffer.read(cx))
 381        else {
 382            return true;
 383        };
 384
 385        let requested_by_buffer_id = self.requested_by.buffer_id();
 386
 387        // This reduces the occurrence of UI thrash from replacing edits
 388        //
 389        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 390        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 391            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 392            && old_edits.len() == 1
 393            && new_edits.len() == 1
 394        {
 395            let (old_range, old_text) = &old_edits[0];
 396            let (new_range, new_text) = &new_edits[0];
 397            new_range == old_range && new_text.starts_with(old_text.as_ref())
 398        } else {
 399            true
 400        }
 401    }
 402}
 403
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Requested in response to a diagnostics update (no specific buffer).
    DiagnosticsUpdate,
    /// Requested for the buffer with this entity id.
    Buffer(EntityId),
}
 409
 410impl PredictionRequestedBy {
 411    pub fn buffer_id(&self) -> Option<EntityId> {
 412        match self {
 413            PredictionRequestedBy::DiagnosticsUpdate => None,
 414            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 415        }
 416    }
 417}
 418
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Per-project id; recorded in `cancelled_predictions` when cancelled.
    id: usize,
    /// Resolves to the resulting prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
}
 424
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): presumably `Local` means the edits are in this buffer and `Jump` means the
/// user must navigate elsewhere — confirm at the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
 431
 432#[cfg(test)]
 433impl std::ops::Deref for BufferEditPrediction<'_> {
 434    type Target = EditPrediction;
 435
 436    fn deref(&self) -> &Self::Target {
 437        match self {
 438            BufferEditPrediction::Local { prediction } => prediction,
 439            BufferEditPrediction::Jump { prediction } => prediction,
 440        }
 441    }
 442}
 443
/// Per-buffer tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    /// Last known cursor position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Keeps the buffer's subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
}
 450
/// The in-progress edit event for a project, turned into a `StoredEvent` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot the diff is computed from.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot the diff is computed to.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 461
 462impl LastEvent {
 463    pub fn finalize(
 464        &self,
 465        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 466        cx: &App,
 467    ) -> Option<StoredEvent> {
 468        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 469        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 470
 471        let in_open_source_repo =
 472            [self.new_file.as_ref(), self.old_file.as_ref()]
 473                .iter()
 474                .all(|file| {
 475                    file.is_some_and(|file| {
 476                        license_detection_watchers
 477                            .get(&file.worktree_id(cx))
 478                            .is_some_and(|watcher| watcher.is_project_open_source())
 479                    })
 480                });
 481
 482        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 483
 484        if path == old_path && diff.is_empty() {
 485            None
 486        } else {
 487            Some(StoredEvent {
 488                event: Arc::new(zeta_prompt::Event::BufferChange {
 489                    old_path,
 490                    path,
 491                    diff,
 492                    in_open_source_repo,
 493                    // TODO: Actually detect if this edit was predicted or not
 494                    predicted: false,
 495                }),
 496                old_snapshot: self.old_snapshot.clone(),
 497            })
 498        }
 499    }
 500
 501    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 502        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 503            return (self.clone(), None);
 504        };
 505
 506        let before = LastEvent {
 507            old_snapshot: self.old_snapshot.clone(),
 508            new_snapshot: boundary_snapshot.clone(),
 509            old_file: self.old_file.clone(),
 510            new_file: self.new_file.clone(),
 511            edit_range: None,
 512            snapshot_after_last_editing_pause: None,
 513            last_edit_time: self.last_edit_time,
 514        };
 515
 516        let after = LastEvent {
 517            old_snapshot: boundary_snapshot.clone(),
 518            new_snapshot: self.new_snapshot.clone(),
 519            old_file: self.old_file.clone(),
 520            new_file: self.new_file.clone(),
 521            edit_range: None,
 522            snapshot_after_last_editing_pause: None,
 523            last_edit_time: self.last_edit_time,
 524        };
 525
 526        (before, Some(after))
 527    }
 528}
 529
 530pub(crate) fn compute_diff_between_snapshots(
 531    old_snapshot: &TextBufferSnapshot,
 532    new_snapshot: &TextBufferSnapshot,
 533) -> Option<String> {
 534    let edits: Vec<Edit<usize>> = new_snapshot
 535        .edits_since::<usize>(&old_snapshot.version)
 536        .collect();
 537
 538    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 539
 540    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 541    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 542    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 543    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 544
 545    const CONTEXT_LINES: u32 = 3;
 546
 547    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 548    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 549    let old_context_end_row =
 550        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 551    let new_context_end_row =
 552        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 553
 554    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 555    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 556    let old_end_line_offset = old_snapshot
 557        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 558    let new_end_line_offset = new_snapshot
 559        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 560    let old_edit_range = old_start_line_offset..old_end_line_offset;
 561    let new_edit_range = new_start_line_offset..new_end_line_offset;
 562
 563    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 564    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 565
 566    let diff = language::unified_diff_with_offsets(
 567        &old_region_text,
 568        &new_region_text,
 569        old_context_start_row,
 570        new_context_start_row,
 571    );
 572
 573    Some(diff)
 574}
 575
 576fn buffer_path_with_id_fallback(
 577    file: Option<&Arc<dyn File>>,
 578    snapshot: &TextBufferSnapshot,
 579    cx: &App,
 580) -> Arc<Path> {
 581    if let Some(file) = file {
 582        file.full_path(cx).into()
 583    } else {
 584        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 585    }
 586}
 587
 588impl EditPredictionStore {
 589    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 590        cx.try_global::<EditPredictionStoreGlobal>()
 591            .map(|global| global.0.clone())
 592    }
 593
 594    pub fn global(
 595        client: &Arc<Client>,
 596        user_store: &Entity<UserStore>,
 597        cx: &mut App,
 598    ) -> Entity<Self> {
 599        cx.try_global::<EditPredictionStoreGlobal>()
 600            .map(|global| global.0.clone())
 601            .unwrap_or_else(|| {
 602                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 603                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 604                ep_store
 605            })
 606    }
 607
    /// Creates a store with default options and the Zeta2 model.
    ///
    /// Side effects:
    /// - spawns a detached background task that processes rejected predictions sent through
    ///   `reject_predictions_tx`;
    /// - subscribes to `RefreshLlmTokenListener`, refreshing `llm_token` on each event;
    /// - configures context retrieval now, again when feature flags are ready, and on every
    ///   `SettingsStore` change.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are queued here and handled off the main thread.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the shared token whenever the global listener signals it.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // `ZED_PREDICT_EDITS_URL` wins when set (an unparsable value is logged and ignored,
            // leaving `None`); otherwise Ollama mode points at a local Ollama endpoint.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        // Context retrieval depends on flags and settings, so re-apply it when either changes.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
 693
 694    #[cfg(test)]
 695    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 696        self.custom_predict_edits_url = Some(url.into());
 697    }
 698
    /// Selects the backend model used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep API key is currently available.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Whether a Mercury API key is currently available.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 710
    /// Installs an evaluation cache (CLI support builds only).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Current prompt/excerpt options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prompt/excerpt options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Sets the `use_context` flag (initially `false`).
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
 727
 728    pub fn clear_history(&mut self) {
 729        for project_state in self.projects.values_mut() {
 730            project_state.events.clear();
 731            project_state.last_event.take();
 732        }
 733    }
 734
 735    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 736        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 737            project_state.events.clear();
 738            project_state.last_event.take();
 739        }
 740    }
 741
 742    pub fn edit_history_for_project(
 743        &self,
 744        project: &Entity<Project>,
 745        cx: &App,
 746    ) -> Vec<StoredEvent> {
 747        self.projects
 748            .get(&project.entity_id())
 749            .map(|project_state| project_state.events(cx))
 750            .unwrap_or_default()
 751    }
 752
 753    pub fn edit_history_for_project_with_pause_split_last_event(
 754        &self,
 755        project: &Entity<Project>,
 756        cx: &App,
 757    ) -> Vec<StoredEvent> {
 758        self.projects
 759            .get(&project.entity_id())
 760            .map(|project_state| project_state.events_split_by_pause(cx))
 761            .unwrap_or_default()
 762    }
 763
 764    pub fn context_for_project<'a>(
 765        &'a self,
 766        project: &Entity<Project>,
 767        cx: &'a App,
 768    ) -> Arc<[RelatedFile]> {
 769        self.projects
 770            .get(&project.entity_id())
 771            .map(|project| project.context.read(cx).related_files())
 772            .unwrap_or_else(|| vec![].into())
 773    }
 774
 775    pub fn context_for_project_with_buffers<'a>(
 776        &'a self,
 777        project: &Entity<Project>,
 778        cx: &'a App,
 779    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
 780        self.projects
 781            .get(&project.entity_id())
 782            .map(|project| project.context.read(cx).related_files_with_buffers())
 783    }
 784
 785    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 786        if self.edit_prediction_model == EditPredictionModel::Zeta2 {
 787            self.user_store.read(cx).edit_prediction_usage()
 788        } else {
 789            None
 790        }
 791    }
 792
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    /// Registers `buffer` with `project`'s state, creating the project state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
 806
 807    fn get_or_init_project(
 808        &mut self,
 809        project: &Entity<Project>,
 810        cx: &mut Context<Self>,
 811    ) -> &mut ProjectState {
 812        let entity_id = project.entity_id();
 813        self.projects
 814            .entry(entity_id)
 815            .or_insert_with(|| ProjectState {
 816                context: {
 817                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 818                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 819                        this.handle_excerpt_store_event(entity_id, event);
 820                    })
 821                    .detach();
 822                    related_excerpt_store
 823                },
 824                events: VecDeque::new(),
 825                last_event: None,
 826                recent_paths: VecDeque::new(),
 827                debug_tx: None,
 828                registered_buffers: HashMap::default(),
 829                current_prediction: None,
 830                cancelled_predictions: HashSet::default(),
 831                pending_predictions: ArrayVec::new(),
 832                next_pending_prediction_id: 0,
 833                last_prediction_refresh: None,
 834                license_detection_watchers: HashMap::default(),
 835                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 836                _subscription: cx.subscribe(&project, Self::handle_project_event),
 837            })
 838    }
 839
 840    pub fn remove_project(&mut self, project: &Entity<Project>) {
 841        self.projects.remove(&project.entity_id());
 842    }
 843
 844    fn handle_excerpt_store_event(
 845        &mut self,
 846        project_entity_id: EntityId,
 847        event: &RelatedExcerptStoreEvent,
 848    ) {
 849        if let Some(project_state) = self.projects.get(&project_entity_id) {
 850            if let Some(debug_tx) = project_state.debug_tx.clone() {
 851                match event {
 852                    RelatedExcerptStoreEvent::StartedRefresh => {
 853                        debug_tx
 854                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 855                                ContextRetrievalStartedDebugEvent {
 856                                    project_entity_id: project_entity_id,
 857                                    timestamp: Instant::now(),
 858                                    search_prompt: String::new(),
 859                                },
 860                            ))
 861                            .ok();
 862                    }
 863                    RelatedExcerptStoreEvent::FinishedRefresh {
 864                        cache_hit_count,
 865                        cache_miss_count,
 866                        mean_definition_latency,
 867                        max_definition_latency,
 868                    } => {
 869                        debug_tx
 870                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 871                                ContextRetrievalFinishedDebugEvent {
 872                                    project_entity_id: project_entity_id,
 873                                    timestamp: Instant::now(),
 874                                    metadata: vec![
 875                                        (
 876                                            "Cache Hits",
 877                                            format!(
 878                                                "{}/{}",
 879                                                cache_hit_count,
 880                                                cache_hit_count + cache_miss_count
 881                                            )
 882                                            .into(),
 883                                        ),
 884                                        (
 885                                            "Max LSP Time",
 886                                            format!("{} ms", max_definition_latency.as_millis())
 887                                                .into(),
 888                                        ),
 889                                        (
 890                                            "Mean LSP Time",
 891                                            format!("{} ms", mean_definition_latency.as_millis())
 892                                                .into(),
 893                                        ),
 894                                    ],
 895                                },
 896                            ))
 897                            .ok();
 898                    }
 899                }
 900            }
 901        }
 902    }
 903
 904    pub fn debug_info(
 905        &mut self,
 906        project: &Entity<Project>,
 907        cx: &mut Context<Self>,
 908    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 909        let project_state = self.get_or_init_project(project, cx);
 910        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 911        project_state.debug_tx = Some(debug_watch_tx);
 912        debug_watch_rx
 913    }
 914
 915    fn handle_project_event(
 916        &mut self,
 917        project: Entity<Project>,
 918        event: &project::Event,
 919        cx: &mut Context<Self>,
 920    ) {
 921        // TODO [zeta2] init with recent paths
 922        match event {
 923            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 924                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 925                    return;
 926                };
 927                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 928                if let Some(path) = path {
 929                    if let Some(ix) = project_state
 930                        .recent_paths
 931                        .iter()
 932                        .position(|probe| probe == &path)
 933                    {
 934                        project_state.recent_paths.remove(ix);
 935                    }
 936                    project_state.recent_paths.push_front(path);
 937                }
 938            }
 939            project::Event::DiagnosticsUpdated { .. } => {
 940                if cx.has_flag::<Zeta2FeatureFlag>() {
 941                    self.refresh_prediction_from_diagnostics(project, cx);
 942                }
 943            }
 944            _ => (),
 945        }
 946    }
 947
    /// Ensures `buffer` is tracked in `project_state.registered_buffers` and returns its entry.
    ///
    /// If the buffer's file belongs to a worktree of `project`, a `LicenseDetectionWatcher` is
    /// lazily created for that worktree (one per worktree id). The watcher is removed again
    /// when the worktree entity is released.
    ///
    /// On first registration the buffer's current text snapshot and file are captured. Two
    /// subscriptions are stored with the entry:
    /// - buffer `Edited` events call `report_changes_for_buffer`. The project is captured
    ///   weakly, so nothing is reported once it has been dropped.
    /// - releasing the buffer removes the entry from the project's `registered_buffers`.
    ///
    /// An already-registered buffer is returned as-is; its stored snapshot is not refreshed here.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree; created only the first time a buffer from it
                // is registered.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report edits; the weak project handle avoids keeping it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Forget the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1013
    /// Records the changes made to `buffer` since the last snapshot stored for it.
    ///
    /// The buffer is registered with the project if needed. If its text version is unchanged,
    /// this returns early. Otherwise the stored snapshot and file are replaced with the current
    /// ones, and then:
    /// - The edits since the previous snapshot are summarized into a `UserActionRecord`: an
    ///   action type, the row and offset where the last edit ends, and a wall-clock timestamp
    ///   in ms.
    /// - The change is coalesced into the project's in-progress `last_event` when it continues
    ///   from that event's latest snapshot of the same buffer and lies within
    ///   `CHANGE_GROUPING_LINE_SPAN` rows of its edit range.
    /// - If it cannot be coalesced, the previous `last_event` is finalized into `events` and a
    ///   new `last_event` is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last stored.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Anchor range from the first edit's start to the last edit's end.
        let mut edit_range: Option<Range<Anchor>> = None;
        // Offset in the new snapshot where the last edit ends.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Classify by total deleted/inserted length relative to the number of edits. A
            // total equal to the edit count is a "char" action; anything else is a "selection"
            // action.
            // NOTE(review): lengths are `usize` offset deltas (presumably bytes), so typing a
            // multi-byte character may be classified as `InsertSelection` — confirm.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock reads before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This change directly follows the last event's latest snapshot of the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Coalesce when the two edit ranges overlap or are at most
            // CHANGE_GROUPING_LINE_SPAN rows apart, measured in the new snapshot.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME has passed since the previous edit,
                // remember the snapshot as it was right before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // `len + 1 >= EVENT_COUNT_MAX` keeps at most EVENT_COUNT_MAX - 1 finalized
                // events (presumably leaving room for `last_event`) — confirm intent.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1137
1138    fn prediction_at(
1139        &mut self,
1140        buffer: &Entity<Buffer>,
1141        position: Option<language::Anchor>,
1142        project: &Entity<Project>,
1143        cx: &App,
1144    ) -> Option<BufferEditPrediction<'_>> {
1145        let project_state = self.projects.get_mut(&project.entity_id())?;
1146        if let Some(position) = position
1147            && let Some(buffer) = project_state
1148                .registered_buffers
1149                .get_mut(&buffer.entity_id())
1150        {
1151            buffer.last_position = Some(position);
1152        }
1153
1154        let CurrentEditPrediction {
1155            requested_by,
1156            prediction,
1157            ..
1158        } = project_state.current_prediction.as_ref()?;
1159
1160        if prediction.targets_buffer(buffer.read(cx)) {
1161            Some(BufferEditPrediction::Local { prediction })
1162        } else {
1163            let show_jump = match requested_by {
1164                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1165                    requested_by_buffer_id == &buffer.entity_id()
1166                }
1167                PredictionRequestedBy::DiagnosticsUpdate => true,
1168            };
1169
1170            if show_jump {
1171                Some(BufferEditPrediction::Jump { prediction })
1172            } else {
1173                None
1174            }
1175        }
1176    }
1177
1178    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1179        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1180            return;
1181        };
1182
1183        let Some(current_prediction) = project_state.current_prediction.take() else {
1184            return;
1185        };
1186
1187        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1188            project_state.cancel_pending_prediction(pending_prediction, cx);
1189        }
1190
1191        match self.edit_prediction_model {
1192            EditPredictionModel::Sweep => {
1193                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1194            }
1195            EditPredictionModel::Mercury => {}
1196            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1197                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1198            }
1199        }
1200    }
1201
1202    async fn handle_rejected_predictions(
1203        rx: UnboundedReceiver<EditPredictionRejection>,
1204        client: Arc<Client>,
1205        llm_token: LlmApiToken,
1206        app_version: Version,
1207        background_executor: BackgroundExecutor,
1208    ) {
1209        let mut rx = std::pin::pin!(rx.peekable());
1210        let mut batched = Vec::new();
1211
1212        while let Some(rejection) = rx.next().await {
1213            batched.push(rejection);
1214
1215            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1216                select_biased! {
1217                    next = rx.as_mut().peek().fuse() => {
1218                        if next.is_some() {
1219                            continue;
1220                        }
1221                    }
1222                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1223                }
1224            }
1225
1226            let url = client
1227                .http_client()
1228                .build_zed_llm_url("/predict_edits/reject", &[])
1229                .unwrap();
1230
1231            let flush_count = batched
1232                .len()
1233                // in case items have accumulated after failure
1234                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1235            let start = batched.len() - flush_count;
1236
1237            let body = RejectEditPredictionsBodyRef {
1238                rejections: &batched[start..],
1239            };
1240
1241            let result = Self::send_api_request::<()>(
1242                |builder| {
1243                    let req = builder
1244                        .uri(url.as_ref())
1245                        .body(serde_json::to_string(&body)?.into());
1246                    anyhow::Ok(req?)
1247                },
1248                client.clone(),
1249                llm_token.clone(),
1250                app_version.clone(),
1251                true,
1252            )
1253            .await;
1254
1255            if result.log_err().is_some() {
1256                batched.drain(start..);
1257            }
1258        }
1259    }
1260
1261    fn reject_current_prediction(
1262        &mut self,
1263        reason: EditPredictionRejectReason,
1264        project: &Entity<Project>,
1265    ) {
1266        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1267            project_state.pending_predictions.clear();
1268            if let Some(prediction) = project_state.current_prediction.take() {
1269                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1270            }
1271        };
1272    }
1273
1274    fn did_show_current_prediction(
1275        &mut self,
1276        project: &Entity<Project>,
1277        display_type: edit_prediction_types::SuggestionDisplayType,
1278        cx: &mut Context<Self>,
1279    ) {
1280        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1281            return;
1282        };
1283
1284        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1285            return;
1286        };
1287
1288        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1289        let previous_shown_with = current_prediction.shown_with;
1290
1291        if previous_shown_with.is_none() || !is_jump {
1292            current_prediction.shown_with = Some(display_type);
1293        }
1294
1295        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1296
1297        if is_first_non_jump_show {
1298            current_prediction.was_shown = true;
1299        }
1300
1301        let display_type_changed = previous_shown_with != Some(display_type);
1302
1303        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1304            sweep_ai::edit_prediction_shown(
1305                &self.sweep_ai,
1306                self.client.clone(),
1307                &current_prediction.prediction,
1308                display_type,
1309                cx,
1310            );
1311        }
1312
1313        if is_first_non_jump_show {
1314            self.shown_predictions
1315                .push_front(current_prediction.prediction.clone());
1316            if self.shown_predictions.len() > 50 {
1317                let completion = self.shown_predictions.pop_back().unwrap();
1318                self.rated_predictions.remove(&completion.id);
1319            }
1320        }
1321    }
1322
1323    fn reject_prediction(
1324        &mut self,
1325        prediction_id: EditPredictionId,
1326        reason: EditPredictionRejectReason,
1327        was_shown: bool,
1328    ) {
1329        match self.edit_prediction_model {
1330            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1331                if self.custom_predict_edits_url.is_some() {
1332                    return;
1333                }
1334            }
1335            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1336        }
1337
1338        self.reject_predictions_tx
1339            .unbounded_send(EditPredictionRejection {
1340                request_id: prediction_id.to_string(),
1341                reason,
1342                was_shown,
1343            })
1344            .log_err();
1345    }
1346
1347    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1348        self.projects
1349            .get(&project.entity_id())
1350            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1351    }
1352
1353    pub fn refresh_prediction_from_buffer(
1354        &mut self,
1355        project: Entity<Project>,
1356        buffer: Entity<Buffer>,
1357        position: language::Anchor,
1358        cx: &mut Context<Self>,
1359    ) {
1360        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1361            let Some(request_task) = this
1362                .update(cx, |this, cx| {
1363                    this.request_prediction(
1364                        &project,
1365                        &buffer,
1366                        position,
1367                        PredictEditsRequestTrigger::Other,
1368                        cx,
1369                    )
1370                })
1371                .log_err()
1372            else {
1373                return Task::ready(anyhow::Ok(None));
1374            };
1375
1376            cx.spawn(async move |_cx| {
1377                request_task.await.map(|prediction_result| {
1378                    prediction_result.map(|prediction_result| {
1379                        (
1380                            prediction_result,
1381                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1382                        )
1383                    })
1384                })
1385            })
1386        })
1387    }
1388
    /// Queues a prediction request at the next diagnostic location, in response to a
    /// diagnostics update for `project`.
    ///
    /// Does nothing if the project is not tracked or already has a current prediction.
    /// Otherwise it queues a refresh, throttled per project, that:
    /// 1. reads the project's active buffer and cursor, and bails if predictions are disabled
    ///    there;
    /// 2. looks up the next diagnostic location via `next_diagnostic_location`;
    /// 3. requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the result is replaced
    /// by `Err(CurrentPreferred)`, so it is rejected rather than installed.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        // Throttled by the project's entity id rather than by a buffer.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, start from the default `Point`.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    // NOTE(review): third argument left at its default — confirm its meaning
                    // in `next_diagnostic_location`.
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // Re-check: a buffer-driven prediction may have arrived while this request
                // was running; if so, mark this result as superseded.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1475
1476    fn predictions_enabled_at(
1477        snapshot: &BufferSnapshot,
1478        position: Option<language::Anchor>,
1479        cx: &App,
1480    ) -> bool {
1481        let file = snapshot.file();
1482        let all_settings = all_language_settings(file, cx);
1483        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1484            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1485        {
1486            return false;
1487        }
1488
1489        if let Some(last_position) = position {
1490            let settings = snapshot.settings_at(last_position, cx);
1491
1492            if !settings.edit_predictions_disabled_in.is_empty()
1493                && let Some(scope) = snapshot.language_scope_at(last_position)
1494                && let Some(scope_name) = scope.override_name()
1495                && settings
1496                    .edit_predictions_disabled_in
1497                    .iter()
1498                    .any(|s| s == scope_name)
1499            {
1500                return false;
1501            }
1502        }
1503
1504        true
1505    }
1506
1507    #[cfg(not(test))]
1508    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1509    #[cfg(test)]
1510    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1511
1512    fn queue_prediction_refresh(
1513        &mut self,
1514        project: Entity<Project>,
1515        throttle_entity: EntityId,
1516        cx: &mut Context<Self>,
1517        do_refresh: impl FnOnce(
1518            WeakEntity<Self>,
1519            &mut AsyncApp,
1520        )
1521            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1522        + 'static,
1523    ) {
1524        let project_state = self.get_or_init_project(&project, cx);
1525        let pending_prediction_id = project_state.next_pending_prediction_id;
1526        project_state.next_pending_prediction_id += 1;
1527        let last_request = project_state.last_prediction_refresh;
1528
1529        let task = cx.spawn(async move |this, cx| {
1530            if let Some((last_entity, last_timestamp)) = last_request
1531                && throttle_entity == last_entity
1532                && let Some(timeout) =
1533                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1534            {
1535                cx.background_executor().timer(timeout).await;
1536            }
1537
1538            // If this task was cancelled before the throttle timeout expired,
1539            // do not perform a request.
1540            let mut is_cancelled = true;
1541            this.update(cx, |this, cx| {
1542                let project_state = this.get_or_init_project(&project, cx);
1543                if !project_state
1544                    .cancelled_predictions
1545                    .remove(&pending_prediction_id)
1546                {
1547                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1548                    is_cancelled = false;
1549                }
1550            })
1551            .ok();
1552            if is_cancelled {
1553                return None;
1554            }
1555
1556            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1557            let new_prediction_id = new_prediction_result
1558                .as_ref()
1559                .map(|(prediction, _)| prediction.id.clone());
1560
1561            // When a prediction completes, remove it from the pending list, and cancel
1562            // any pending predictions that were enqueued before it.
1563            this.update(cx, |this, cx| {
1564                let project_state = this.get_or_init_project(&project, cx);
1565
1566                let is_cancelled = project_state
1567                    .cancelled_predictions
1568                    .remove(&pending_prediction_id);
1569
1570                let new_current_prediction = if !is_cancelled
1571                    && let Some((prediction_result, requested_by)) = new_prediction_result
1572                {
1573                    match prediction_result.prediction {
1574                        Ok(prediction) => {
1575                            let new_prediction = CurrentEditPrediction {
1576                                requested_by,
1577                                prediction,
1578                                was_shown: false,
1579                                shown_with: None,
1580                            };
1581
1582                            if let Some(current_prediction) =
1583                                project_state.current_prediction.as_ref()
1584                            {
1585                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1586                                {
1587                                    this.reject_current_prediction(
1588                                        EditPredictionRejectReason::Replaced,
1589                                        &project,
1590                                    );
1591
1592                                    Some(new_prediction)
1593                                } else {
1594                                    this.reject_prediction(
1595                                        new_prediction.prediction.id,
1596                                        EditPredictionRejectReason::CurrentPreferred,
1597                                        false,
1598                                    );
1599                                    None
1600                                }
1601                            } else {
1602                                Some(new_prediction)
1603                            }
1604                        }
1605                        Err(reject_reason) => {
1606                            this.reject_prediction(prediction_result.id, reject_reason, false);
1607                            None
1608                        }
1609                    }
1610                } else {
1611                    None
1612                };
1613
1614                let project_state = this.get_or_init_project(&project, cx);
1615
1616                if let Some(new_prediction) = new_current_prediction {
1617                    project_state.current_prediction = Some(new_prediction);
1618                }
1619
1620                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1621                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1622                    if pending_prediction.id == pending_prediction_id {
1623                        pending_predictions.remove(ix);
1624                        for pending_prediction in pending_predictions.drain(0..ix) {
1625                            project_state.cancel_pending_prediction(pending_prediction, cx)
1626                        }
1627                        break;
1628                    }
1629                }
1630                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1631                cx.notify();
1632            })
1633            .ok();
1634
1635            new_prediction_id
1636        });
1637
1638        if project_state.pending_predictions.len() <= 1 {
1639            project_state.pending_predictions.push(PendingPrediction {
1640                id: pending_prediction_id,
1641                task,
1642            });
1643        } else if project_state.pending_predictions.len() == 2 {
1644            let pending_prediction = project_state.pending_predictions.pop().unwrap();
1645            project_state.pending_predictions.push(PendingPrediction {
1646                id: pending_prediction_id,
1647                task,
1648            });
1649            project_state.cancel_pending_prediction(pending_prediction, cx);
1650        }
1651    }
1652
1653    pub fn request_prediction(
1654        &mut self,
1655        project: &Entity<Project>,
1656        active_buffer: &Entity<Buffer>,
1657        position: language::Anchor,
1658        trigger: PredictEditsRequestTrigger,
1659        cx: &mut Context<Self>,
1660    ) -> Task<Result<Option<EditPredictionResult>>> {
1661        self.request_prediction_internal(
1662            project.clone(),
1663            active_buffer.clone(),
1664            position,
1665            trigger,
1666            cx.has_flag::<Zeta2FeatureFlag>(),
1667            cx,
1668        )
1669    }
1670
    /// Requests an edit prediction at `position` in `active_buffer` from the
    /// configured `edit_prediction_model`.
    ///
    /// Gathers the project's recorded events, a copy of its recent user actions,
    /// and (when `use_context` is set) related files into an
    /// `EditPredictionModelInput`. If data collection is permitted for this file
    /// and for every event, and `should_sample_edit_prediction_example_capture`
    /// returns true, an example is also captured and reported via telemetry in a
    /// detached task.
    ///
    /// If the model returns no prediction, `allow_jump` is true and the project
    /// has recorded events, this looks for a diagnostic outside the
    /// ±`DIAGNOSTIC_LINES_RANGE`-row window around the cursor (see
    /// `next_diagnostic_location`). If one is found, it requests a prediction
    /// there once more with `allow_jump = false`, so at most one jump happens.
    /// Otherwise it resolves to `Ok(None)`.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor that form the diagnostic search
        // window; diagnostics overlapping it are not used as jump targets.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Ensure per-project state exists. The `unwrap` below relies on this
        // (presumably `get_or_init_project` inserts under `project.entity_id()`).
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        // Captured before `stored_events` is consumed; gates the diagnostic jump.
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        // Work on a copy, so the synthetic record appended below does not
        // modify `project_state.user_actions`.
        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the cursor moved within the same buffer since the last recorded
        // action, append a synthetic cursor-movement record for this request.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                // Falls back to 0 if the system clock reads before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        // Row window [row - RANGE, row + RANGE] around the cursor (column 0 at
        // both ends). NOTE(review): the `+` is unchecked, unlike the
        // `saturating_sub`; it would overflow only for rows near u32::MAX.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // Example capture requires both the file and every event to be
        // collectable under the user's data-collection choice.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                false,
                cx,
            ) {
                // Fire-and-forget: failures are logged, not propagated.
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        // Dispatch to the backend selected by `edit_prediction_model`.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                // Same value as the outer `cursor_point`, recomputed from the
                // snapshot that was moved into this task.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry once at the jump target; `allow_jump = false`
                    // prevents chaining further jumps.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1806
    /// Picks a diagnostic location to jump to after a request produced no
    /// prediction.
    ///
    /// First, it looks in the active buffer for the primary entry of each
    /// diagnostic group that does not overlap
    /// `active_buffer_diagnostic_search_range`. It takes the one whose start
    /// row is closest to the cursor row.
    ///
    /// If there is none, it picks another project path from
    /// `diagnostic_summaries` that shares the most leading path components
    /// with the active buffer's path (the active path itself is excluded).
    /// It opens that buffer and returns the start of its diagnostic with the
    /// smallest `severity` value.
    ///
    /// Returns `Ok(None)` when no candidate is found. Returns an error if
    /// opening the fallback buffer fails.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the diagnostic closest to the cursor that lies outside the search
        // window, i.e. one that was not close enough to be covered by the last request.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Distance is measured in rows only; columns are ignored.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` when the active buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Prefer the path sharing the most leading components with the
                        // active buffer's path (0 for all paths if it has none).
                        // NOTE(review): summaries are not filtered by error count here.
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // Smallest severity value wins; presumably LSP ordering,
                            // where errors sort first. TODO confirm.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1885
1886    async fn send_raw_llm_request(
1887        request: open_ai::Request,
1888        client: Arc<Client>,
1889        llm_token: LlmApiToken,
1890        app_version: Version,
1891        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1892        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1893    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1894        let url = client
1895            .http_client()
1896            .build_zed_llm_url("/predict_edits/raw", &[])?;
1897
1898        #[cfg(feature = "cli-support")]
1899        let cache_key = if let Some(cache) = eval_cache {
1900            use collections::FxHasher;
1901            use std::hash::{Hash, Hasher};
1902
1903            let mut hasher = FxHasher::default();
1904            url.hash(&mut hasher);
1905            let request_str = serde_json::to_string_pretty(&request)?;
1906            request_str.hash(&mut hasher);
1907            let hash = hasher.finish();
1908
1909            let key = (eval_cache_kind, hash);
1910            if let Some(response_str) = cache.read(key) {
1911                return Ok((serde_json::from_str(&response_str)?, None));
1912            }
1913
1914            Some((cache, request_str, key))
1915        } else {
1916            None
1917        };
1918
1919        let (response, usage) = Self::send_api_request(
1920            |builder| {
1921                let req = builder
1922                    .uri(url.as_ref())
1923                    .body(serde_json::to_string(&request)?.into());
1924                Ok(req?)
1925            },
1926            client,
1927            llm_token,
1928            app_version,
1929            true,
1930        )
1931        .await?;
1932
1933        #[cfg(feature = "cli-support")]
1934        if let Some((cache, request, key)) = cache_key {
1935            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1936        }
1937
1938        Ok((response, usage))
1939    }
1940
1941    fn handle_api_response<T>(
1942        this: &WeakEntity<Self>,
1943        response: Result<(T, Option<EditPredictionUsage>)>,
1944        cx: &mut gpui::AsyncApp,
1945    ) -> Result<T> {
1946        match response {
1947            Ok((data, usage)) => {
1948                if let Some(usage) = usage {
1949                    this.update(cx, |this, cx| {
1950                        this.user_store.update(cx, |user_store, cx| {
1951                            user_store.update_edit_prediction_usage(usage, cx);
1952                        });
1953                    })
1954                    .ok();
1955                }
1956                Ok(data)
1957            }
1958            Err(err) => {
1959                if err.is::<ZedUpdateRequiredError>() {
1960                    cx.update(|cx| {
1961                        this.update(cx, |this, _cx| {
1962                            this.update_required = true;
1963                        })
1964                        .ok();
1965
1966                        let error_message: SharedString = err.to_string().into();
1967                        show_app_notification(
1968                            NotificationId::unique::<ZedUpdateRequiredError>(),
1969                            cx,
1970                            move |cx| {
1971                                cx.new(|cx| {
1972                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1973                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1974                                })
1975                            },
1976                        );
1977                    });
1978                }
1979                Err(err)
1980            }
1981        }
1982    }
1983
    /// Sends a JSON POST request to the Zed LLM service and deserializes a
    /// successful response body as `Res`.
    ///
    /// `build` receives a builder that already has the method, the content
    /// type, the app version header and (when a token is held) the
    /// `Authorization` header. It is called once per attempt.
    ///
    /// - If `require_auth` is true, failing to acquire an LLM token is an error.
    ///   Otherwise the request is sent without `Authorization`.
    /// - If a response of any status carries a parseable minimum-required
    ///   version newer than `app_version`, returns [`ZedUpdateRequiredError`].
    /// - If a non-success response carries the expired-token header and a token
    ///   was sent, the token is refreshed and the request is retried once.
    /// - Any other non-success response becomes an error containing the status
    ///   and body.
    ///
    /// On success, also returns the usage parsed from the response headers
    /// (`None` if parsing fails).
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        // Guards the single token-refresh retry.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Checked before the status; an unparseable header value is ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Refresh the expired token and loop around for one more attempt.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2059
2060    pub fn refresh_context(
2061        &mut self,
2062        project: &Entity<Project>,
2063        buffer: &Entity<language::Buffer>,
2064        cursor_position: language::Anchor,
2065        cx: &mut Context<Self>,
2066    ) {
2067        if self.use_context {
2068            self.get_or_init_project(project, cx)
2069                .context
2070                .update(cx, |store, cx| {
2071                    store.refresh(buffer.clone(), cursor_position, cx);
2072                });
2073        }
2074    }
2075
2076    #[cfg(feature = "cli-support")]
2077    pub fn set_context_for_buffer(
2078        &mut self,
2079        project: &Entity<Project>,
2080        related_files: Vec<RelatedFile>,
2081        cx: &mut Context<Self>,
2082    ) {
2083        self.get_or_init_project(project, cx)
2084            .context
2085            .update(cx, |store, _| {
2086                store.set_related_files(related_files);
2087            });
2088    }
2089
2090    fn is_file_open_source(
2091        &self,
2092        project: &Entity<Project>,
2093        file: &Arc<dyn File>,
2094        cx: &App,
2095    ) -> bool {
2096        if !file.is_local() || file.is_private() {
2097            return false;
2098        }
2099        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2100            return false;
2101        };
2102        project_state
2103            .license_detection_watchers
2104            .get(&file.worktree_id(cx))
2105            .as_ref()
2106            .is_some_and(|watcher| watcher.is_project_open_source())
2107    }
2108
2109    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2110        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2111    }
2112
2113    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2114        if !self.data_collection_choice.is_enabled(cx) {
2115            return false;
2116        }
2117        events.iter().all(|event| {
2118            matches!(
2119                event.as_ref(),
2120                zeta_prompt::Event::BufferChange {
2121                    in_open_source_repo: true,
2122                    ..
2123                }
2124            )
2125        })
2126    }
2127
2128    fn load_data_collection_choice() -> DataCollectionChoice {
2129        let choice = KEY_VALUE_STORE
2130            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2131            .log_err()
2132            .flatten();
2133
2134        match choice.as_deref() {
2135            Some("true") => DataCollectionChoice::Enabled,
2136            Some("false") => DataCollectionChoice::Disabled,
2137            Some(_) => {
2138                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2139                DataCollectionChoice::NotAnswered
2140            }
2141            None => DataCollectionChoice::NotAnswered,
2142        }
2143    }
2144
2145    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2146        self.data_collection_choice = self.data_collection_choice.toggle();
2147        let new_choice = self.data_collection_choice;
2148        let is_enabled = new_choice.is_enabled(cx);
2149        db::write_and_log(cx, move || {
2150            KEY_VALUE_STORE.write_kvp(
2151                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2152                is_enabled.to_string(),
2153            )
2154        });
2155    }
2156
    /// Iterates over the entries stored in `shown_predictions`, in either direction.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2160
    /// Number of entries in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2164
    /// Whether `id` is in `rated_predictions`, which `rate_prediction` populates.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2168
2169    pub fn rate_prediction(
2170        &mut self,
2171        prediction: &EditPrediction,
2172        rating: EditPredictionRating,
2173        feedback: String,
2174        cx: &mut Context<Self>,
2175    ) {
2176        self.rated_predictions.insert(prediction.id.clone());
2177        telemetry::event!(
2178            "Edit Prediction Rated",
2179            rating,
2180            inputs = prediction.inputs,
2181            output = prediction
2182                .edit_preview
2183                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2184            feedback
2185        );
2186        self.client.telemetry().flush_events().detach();
2187        cx.notify();
2188    }
2189
2190    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2191        self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2192            && all_language_settings(None, cx).edit_predictions.use_context;
2193    }
2194}
2195
/// Returned when a service response advertises a minimum required app version
/// (via the minimum-required-version header) that is newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest app version the service will accept.
    minimum_version: Version,
}
2203
/// Key for [`EvalCache`] lookups: the entry kind plus a 64-bit hash. In
/// `send_raw_llm_request`, the hash covers the request URL and pretty-printed request.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2206
/// Kind of entry stored in an [`EvalCache`]; the first component of an
/// [`EvalCacheKey`].
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2214
/// Formats the entry kind as its lowercase variant name.
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2225
/// Cache of serialized LLM responses, available with the `cli-support`
/// feature (presumably for evaluation runs; confirm with callers).
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. `send_raw_llm_request` passes the
    /// pretty-printed JSON request as `input` and the pretty-printed response
    /// as `value`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2231
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted in the key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`
/// as `"true"` or `"false"`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No recognized choice is stored.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2238
2239impl DataCollectionChoice {
2240    pub fn is_enabled(self, cx: &App) -> bool {
2241        if cx.is_staff() {
2242            return true;
2243        }
2244        match self {
2245            Self::Enabled => true,
2246            Self::NotAnswered | Self::Disabled => false,
2247        }
2248    }
2249
2250    #[must_use]
2251    pub fn toggle(&self) -> DataCollectionChoice {
2252        match self {
2253            Self::Enabled => Self::Disabled,
2254            Self::Disabled => Self::Enabled,
2255            Self::NotAnswered => Self::Enabled,
2256        }
2257    }
2258}
2259
2260impl From<bool> for DataCollectionChoice {
2261    fn from(value: bool) -> Self {
2262        match value {
2263            true => DataCollectionChoice::Enabled,
2264            false => DataCollectionChoice::Disabled,
2265        }
2266    }
2267}
2268
/// Marker type carrying the dismissal state of the edit-prediction upsell.
struct ZedPredictUpsell;
2270
2271impl Dismissable for ZedPredictUpsell {
2272    const KEY: &'static str = "dismissed-edit-predict-upsell";
2273
2274    fn dismissed() -> bool {
2275        // To make this backwards compatible with older versions of Zed, we
2276        // check if the user has seen the previous Edit Prediction Onboarding
2277        // before, by checking the data collection choice which was written to
2278        // the database once the user clicked on "Accept and Enable"
2279        if KEY_VALUE_STORE
2280            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2281            .log_err()
2282            .is_some_and(|s| s.is_some())
2283        {
2284            return true;
2285        }
2286
2287        KEY_VALUE_STORE
2288            .read_kvp(Self::KEY)
2289            .log_err()
2290            .is_some_and(|s| s.is_some())
2291    }
2292}
2293
2294pub fn should_show_upsell_modal() -> bool {
2295    !ZedPredictUpsell::dismissed()
2296}
2297
2298pub fn init(cx: &mut App) {
2299    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2300        workspace.register_action(
2301            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2302                ZedPredictModal::toggle(
2303                    workspace,
2304                    workspace.user_store().clone(),
2305                    workspace.client().clone(),
2306                    window,
2307                    cx,
2308                )
2309            },
2310        );
2311
2312        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2313            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2314                settings
2315                    .project
2316                    .all_languages
2317                    .features
2318                    .get_or_insert_default()
2319                    .edit_prediction_provider = Some(EditPredictionProvider::None)
2320            });
2321        });
2322    })
2323    .detach();
2324}