edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{
   5    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   6};
   7use cloud_llm_client::{
   8    EditPredictionRejectReason, EditPredictionRejection,
   9    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  10    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  11};
  12use collections::{HashMap, HashSet};
  13use copilot::{Copilot, Reinstall, SignIn, SignOut};
  14use db::kvp::{Dismissable, KEY_VALUE_STORE};
  15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::{
  18    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  19    channel::mpsc::{self, UnboundedReceiver},
  20    select_biased,
  21};
  22use gpui::BackgroundExecutor;
  23use gpui::http_client::Url;
  24use gpui::{
  25    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  26    http_client::{self, AsyncBody, Method},
  27    prelude::*,
  28};
  29use language::language_settings::all_language_settings;
  30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  31use language::{BufferSnapshot, OffsetRangeExt};
  32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  33use project::{Project, ProjectPath, WorktreeId};
  34use release_channel::AppVersion;
  35use semver::Version;
  36use serde::de::DeserializeOwned;
  37use settings::{EditPredictionProvider, update_settings_file};
  38use std::collections::{VecDeque, hash_map};
  39use text::Edit;
  40use workspace::Workspace;
  41use zeta_prompt::ZetaPromptInput;
  42use zeta_prompt::ZetaVersion;
  43
  44use std::ops::Range;
  45use std::path::Path;
  46use std::rc::Rc;
  47use std::str::FromStr as _;
  48use std::sync::{Arc, LazyLock};
  49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  50use std::{env, mem};
  51use thiserror::Error;
  52use util::{RangeExt as _, ResultExt as _};
  53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  54
  55pub mod cursor_excerpt;
  56pub mod example_spec;
  57mod license_detection;
  58pub mod mercury;
  59mod onboarding_modal;
  60pub mod open_ai_response;
  61mod prediction;
  62pub mod sweep_ai;
  63
  64pub mod udiff;
  65
  66mod capture_example;
  67mod zed_edit_prediction_delegate;
  68pub mod zeta1;
  69pub mod zeta2;
  70
  71#[cfg(test)]
  72mod edit_prediction_tests;
  73
  74use crate::capture_example::{
  75    should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
  76};
  77use crate::license_detection::LicenseDetectionWatcher;
  78use crate::mercury::Mercury;
  79use crate::onboarding_modal::ZedPredictModal;
  80pub use crate::prediction::EditPrediction;
  81pub use crate::prediction::EditPredictionId;
  82use crate::prediction::EditPredictionResult;
  83pub use crate::sweep_ai::SweepAi;
  84pub use capture_example::capture_example;
  85pub use language_model::ApiKeyState;
  86pub use telemetry_events::EditPredictionRating;
  87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  88
  89actions!(
  90    edit_prediction,
  91    [
  92        /// Resets the edit prediction onboarding state.
  93        ResetOnboarding,
  94        /// Clears the edit prediction history.
  95        ClearHistory,
  96    ]
  97);
  98
  99/// Maximum number of events to track.
 100const EVENT_COUNT_MAX: usize = 6;
 101const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
 102const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
 103const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
 104const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
 105
 106pub struct SweepFeatureFlag;
 107
 108impl FeatureFlag for SweepFeatureFlag {
 109    const NAME: &str = "sweep-ai";
 110}
 111
 112pub struct MercuryFeatureFlag;
 113
 114impl FeatureFlag for MercuryFeatureFlag {
 115    const NAME: &str = "mercury";
 116}
 117
 118static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
 119    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
 120
 121pub struct Zeta2FeatureFlag;
 122
 123impl FeatureFlag for Zeta2FeatureFlag {
 124    const NAME: &'static str = "zeta2";
 125
 126    fn enabled_for_staff() -> bool {
 127        true
 128    }
 129}
 130
 131pub struct EditPredictionExampleCaptureFeatureFlag;
 132
 133impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
 134    const NAME: &'static str = "edit-prediction-example-capture";
 135
 136    fn enabled_for_staff() -> bool {
 137        true
 138    }
 139}
 140
 141#[derive(Clone)]
 142struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
 143
 144impl Global for EditPredictionStoreGlobal {}
 145
 146pub struct EditPredictionStore {
 147    client: Arc<Client>,
 148    user_store: Entity<UserStore>,
 149    llm_token: LlmApiToken,
 150    _llm_token_subscription: Subscription,
 151    projects: HashMap<EntityId, ProjectState>,
 152    update_required: bool,
 153    edit_prediction_model: EditPredictionModel,
 154    pub sweep_ai: SweepAi,
 155    pub mercury: Mercury,
 156    data_collection_choice: DataCollectionChoice,
 157    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
 158    shown_predictions: VecDeque<EditPrediction>,
 159    rated_predictions: HashSet<EditPredictionId>,
 160    custom_predict_edits_url: Option<Arc<Url>>,
 161}
 162
 163#[derive(Copy, Clone, Default, PartialEq, Eq)]
 164pub enum EditPredictionModel {
 165    #[default]
 166    Zeta1,
 167    Zeta2 {
 168        version: ZetaVersion,
 169    },
 170    Sweep,
 171    Mercury,
 172}
 173
 174#[derive(Clone)]
 175pub struct EditPredictionModelInput {
 176    project: Entity<Project>,
 177    buffer: Entity<Buffer>,
 178    snapshot: BufferSnapshot,
 179    position: Anchor,
 180    events: Vec<Arc<zeta_prompt::Event>>,
 181    related_files: Vec<RelatedFile>,
 182    recent_paths: VecDeque<ProjectPath>,
 183    trigger: PredictEditsRequestTrigger,
 184    diagnostic_search_range: Range<Point>,
 185    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 186    pub user_actions: Vec<UserActionRecord>,
 187}
 188
 189#[derive(Debug)]
 190pub enum DebugEvent {
 191    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
 192    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
 193    EditPredictionStarted(EditPredictionStartedDebugEvent),
 194    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 195}
 196
 197#[derive(Debug)]
 198pub struct ContextRetrievalStartedDebugEvent {
 199    pub project_entity_id: EntityId,
 200    pub timestamp: Instant,
 201    pub search_prompt: String,
 202}
 203
 204#[derive(Debug)]
 205pub struct ContextRetrievalFinishedDebugEvent {
 206    pub project_entity_id: EntityId,
 207    pub timestamp: Instant,
 208    pub metadata: Vec<(&'static str, SharedString)>,
 209}
 210
 211#[derive(Debug)]
 212pub struct EditPredictionStartedDebugEvent {
 213    pub buffer: WeakEntity<Buffer>,
 214    pub position: Anchor,
 215    pub prompt: Option<String>,
 216}
 217
 218#[derive(Debug)]
 219pub struct EditPredictionFinishedDebugEvent {
 220    pub buffer: WeakEntity<Buffer>,
 221    pub position: Anchor,
 222    pub model_output: Option<String>,
 223}
 224
 225const USER_ACTION_HISTORY_SIZE: usize = 16;
 226
 227#[derive(Clone, Debug)]
 228pub struct UserActionRecord {
 229    pub action_type: UserActionType,
 230    pub buffer_id: EntityId,
 231    pub line_number: u32,
 232    pub offset: usize,
 233    pub timestamp_epoch_ms: u64,
 234}
 235
 236#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 237pub enum UserActionType {
 238    InsertChar,
 239    InsertSelection,
 240    DeleteChar,
 241    DeleteSelection,
 242    CursorMovement,
 243}
 244
 245/// An event with associated metadata for reconstructing buffer state.
 246#[derive(Clone)]
 247pub struct StoredEvent {
 248    pub event: Arc<zeta_prompt::Event>,
 249    pub old_snapshot: TextBufferSnapshot,
 250}
 251
 252struct ProjectState {
 253    events: VecDeque<StoredEvent>,
 254    last_event: Option<LastEvent>,
 255    recent_paths: VecDeque<ProjectPath>,
 256    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 257    current_prediction: Option<CurrentEditPrediction>,
 258    next_pending_prediction_id: usize,
 259    pending_predictions: ArrayVec<PendingPrediction, 2>,
 260    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
 261    last_prediction_refresh: Option<(EntityId, Instant)>,
 262    cancelled_predictions: HashSet<usize>,
 263    context: Entity<RelatedExcerptStore>,
 264    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 265    user_actions: VecDeque<UserActionRecord>,
 266    _subscription: gpui::Subscription,
 267    copilot: Option<Entity<Copilot>>,
 268}
 269
 270impl ProjectState {
 271    fn record_user_action(&mut self, action: UserActionRecord) {
 272        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 273            self.user_actions.pop_front();
 274        }
 275        self.user_actions.push_back(action);
 276    }
 277
 278    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 279        self.events
 280            .iter()
 281            .cloned()
 282            .chain(
 283                self.last_event
 284                    .as_ref()
 285                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
 286            )
 287            .collect()
 288    }
 289
 290    pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
 291        self.events
 292            .iter()
 293            .cloned()
 294            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 295                let (one, two) = event.split_by_pause();
 296                let one = one.finalize(&self.license_detection_watchers, cx);
 297                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 298                one.into_iter().chain(two)
 299            }))
 300            .collect()
 301    }
 302
 303    fn cancel_pending_prediction(
 304        &mut self,
 305        pending_prediction: PendingPrediction,
 306        cx: &mut Context<EditPredictionStore>,
 307    ) {
 308        self.cancelled_predictions.insert(pending_prediction.id);
 309
 310        cx.spawn(async move |this, cx| {
 311            let Some(prediction_id) = pending_prediction.task.await else {
 312                return;
 313            };
 314
 315            this.update(cx, |this, _cx| {
 316                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
 317            })
 318            .ok();
 319        })
 320        .detach()
 321    }
 322
 323    fn active_buffer(
 324        &self,
 325        project: &Entity<Project>,
 326        cx: &App,
 327    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 328        let project = project.read(cx);
 329        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 330        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 331        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 332        Some((active_buffer, registered_buffer.last_position))
 333    }
 334}
 335
 336#[derive(Debug, Clone)]
 337struct CurrentEditPrediction {
 338    pub requested_by: PredictionRequestedBy,
 339    pub prediction: EditPrediction,
 340    pub was_shown: bool,
 341    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
 342}
 343
 344impl CurrentEditPrediction {
 345    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 346        let Some(new_edits) = self
 347            .prediction
 348            .interpolate(&self.prediction.buffer.read(cx))
 349        else {
 350            return false;
 351        };
 352
 353        if self.prediction.buffer != old_prediction.prediction.buffer {
 354            return true;
 355        }
 356
 357        let Some(old_edits) = old_prediction
 358            .prediction
 359            .interpolate(&old_prediction.prediction.buffer.read(cx))
 360        else {
 361            return true;
 362        };
 363
 364        let requested_by_buffer_id = self.requested_by.buffer_id();
 365
 366        // This reduces the occurrence of UI thrash from replacing edits
 367        //
 368        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 369        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 370            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 371            && old_edits.len() == 1
 372            && new_edits.len() == 1
 373        {
 374            let (old_range, old_text) = &old_edits[0];
 375            let (new_range, new_text) = &new_edits[0];
 376            new_range == old_range && new_text.starts_with(old_text.as_ref())
 377        } else {
 378            true
 379        }
 380    }
 381}
 382
 383#[derive(Debug, Clone)]
 384enum PredictionRequestedBy {
 385    DiagnosticsUpdate,
 386    Buffer(EntityId),
 387}
 388
 389impl PredictionRequestedBy {
 390    pub fn buffer_id(&self) -> Option<EntityId> {
 391        match self {
 392            PredictionRequestedBy::DiagnosticsUpdate => None,
 393            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 394        }
 395    }
 396}
 397
 398#[derive(Debug)]
 399struct PendingPrediction {
 400    id: usize,
 401    task: Task<Option<EditPredictionId>>,
 402}
 403
 404/// A prediction from the perspective of a buffer.
 405#[derive(Debug)]
 406enum BufferEditPrediction<'a> {
 407    Local { prediction: &'a EditPrediction },
 408    Jump { prediction: &'a EditPrediction },
 409}
 410
 411#[cfg(test)]
 412impl std::ops::Deref for BufferEditPrediction<'_> {
 413    type Target = EditPrediction;
 414
 415    fn deref(&self) -> &Self::Target {
 416        match self {
 417            BufferEditPrediction::Local { prediction } => prediction,
 418            BufferEditPrediction::Jump { prediction } => prediction,
 419        }
 420    }
 421}
 422
 423struct RegisteredBuffer {
 424    file: Option<Arc<dyn File>>,
 425    snapshot: TextBufferSnapshot,
 426    last_position: Option<Anchor>,
 427    _subscriptions: [gpui::Subscription; 2],
 428}
 429
 430#[derive(Clone)]
 431struct LastEvent {
 432    old_snapshot: TextBufferSnapshot,
 433    new_snapshot: TextBufferSnapshot,
 434    old_file: Option<Arc<dyn File>>,
 435    new_file: Option<Arc<dyn File>>,
 436    edit_range: Option<Range<Anchor>>,
 437    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
 438    last_edit_time: Option<Instant>,
 439}
 440
 441impl LastEvent {
 442    pub fn finalize(
 443        &self,
 444        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 445        cx: &App,
 446    ) -> Option<StoredEvent> {
 447        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 448        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 449
 450        let in_open_source_repo =
 451            [self.new_file.as_ref(), self.old_file.as_ref()]
 452                .iter()
 453                .all(|file| {
 454                    file.is_some_and(|file| {
 455                        license_detection_watchers
 456                            .get(&file.worktree_id(cx))
 457                            .is_some_and(|watcher| watcher.is_project_open_source())
 458                    })
 459                });
 460
 461        let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 462
 463        if path == old_path && diff.is_empty() {
 464            None
 465        } else {
 466            Some(StoredEvent {
 467                event: Arc::new(zeta_prompt::Event::BufferChange {
 468                    old_path,
 469                    path,
 470                    diff,
 471                    in_open_source_repo,
 472                    // TODO: Actually detect if this edit was predicted or not
 473                    predicted: false,
 474                }),
 475                old_snapshot: self.old_snapshot.clone(),
 476            })
 477        }
 478    }
 479
 480    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 481        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 482            return (self.clone(), None);
 483        };
 484
 485        let before = LastEvent {
 486            old_snapshot: self.old_snapshot.clone(),
 487            new_snapshot: boundary_snapshot.clone(),
 488            old_file: self.old_file.clone(),
 489            new_file: self.new_file.clone(),
 490            edit_range: None,
 491            snapshot_after_last_editing_pause: None,
 492            last_edit_time: self.last_edit_time,
 493        };
 494
 495        let after = LastEvent {
 496            old_snapshot: boundary_snapshot.clone(),
 497            new_snapshot: self.new_snapshot.clone(),
 498            old_file: self.old_file.clone(),
 499            new_file: self.new_file.clone(),
 500            edit_range: None,
 501            snapshot_after_last_editing_pause: None,
 502            last_edit_time: self.last_edit_time,
 503        };
 504
 505        (before, Some(after))
 506    }
 507}
 508
 509pub(crate) fn compute_diff_between_snapshots(
 510    old_snapshot: &TextBufferSnapshot,
 511    new_snapshot: &TextBufferSnapshot,
 512) -> Option<String> {
 513    let edits: Vec<Edit<usize>> = new_snapshot
 514        .edits_since::<usize>(&old_snapshot.version)
 515        .collect();
 516
 517    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 518
 519    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 520    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 521    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 522    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 523
 524    const CONTEXT_LINES: u32 = 3;
 525
 526    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 527    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 528    let old_context_end_row =
 529        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 530    let new_context_end_row =
 531        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 532
 533    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 534    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 535    let old_end_line_offset = old_snapshot
 536        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 537    let new_end_line_offset = new_snapshot
 538        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 539    let old_edit_range = old_start_line_offset..old_end_line_offset;
 540    let new_edit_range = new_start_line_offset..new_end_line_offset;
 541
 542    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 543    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 544
 545    let diff = language::unified_diff_with_offsets(
 546        &old_region_text,
 547        &new_region_text,
 548        old_context_start_row,
 549        new_context_start_row,
 550    );
 551
 552    Some(diff)
 553}
 554
 555fn buffer_path_with_id_fallback(
 556    file: Option<&Arc<dyn File>>,
 557    snapshot: &TextBufferSnapshot,
 558    cx: &App,
 559) -> Arc<Path> {
 560    if let Some(file) = file {
 561        file.full_path(cx).into()
 562    } else {
 563        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 564    }
 565}
 566
 567impl EditPredictionStore {
 568    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 569        cx.try_global::<EditPredictionStoreGlobal>()
 570            .map(|global| global.0.clone())
 571    }
 572
 573    pub fn global(
 574        client: &Arc<Client>,
 575        user_store: &Entity<UserStore>,
 576        cx: &mut App,
 577    ) -> Entity<Self> {
 578        cx.try_global::<EditPredictionStoreGlobal>()
 579            .map(|global| global.0.clone())
 580            .unwrap_or_else(|| {
 581                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 582                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 583                ep_store
 584            })
 585    }
 586
 587    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 588        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 589        let data_collection_choice = Self::load_data_collection_choice();
 590
 591        let llm_token = LlmApiToken::default();
 592
 593        let (reject_tx, reject_rx) = mpsc::unbounded();
 594        cx.background_spawn({
 595            let client = client.clone();
 596            let llm_token = llm_token.clone();
 597            let app_version = AppVersion::global(cx);
 598            let background_executor = cx.background_executor().clone();
 599            async move {
 600                Self::handle_rejected_predictions(
 601                    reject_rx,
 602                    client,
 603                    llm_token,
 604                    app_version,
 605                    background_executor,
 606                )
 607                .await
 608            }
 609        })
 610        .detach();
 611
 612        let this = Self {
 613            projects: HashMap::default(),
 614            client,
 615            user_store,
 616            llm_token,
 617            _llm_token_subscription: cx.subscribe(
 618                &refresh_llm_token_listener,
 619                |this, _listener, _event, cx| {
 620                    let client = this.client.clone();
 621                    let llm_token = this.llm_token.clone();
 622                    cx.spawn(async move |_this, _cx| {
 623                        llm_token.refresh(&client).await?;
 624                        anyhow::Ok(())
 625                    })
 626                    .detach_and_log_err(cx);
 627                },
 628            ),
 629            update_required: false,
 630            edit_prediction_model: EditPredictionModel::Zeta2 {
 631                version: Default::default(),
 632            },
 633            sweep_ai: SweepAi::new(cx),
 634            mercury: Mercury::new(cx),
 635
 636            data_collection_choice,
 637            reject_predictions_tx: reject_tx,
 638            rated_predictions: Default::default(),
 639            shown_predictions: Default::default(),
 640            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
 641                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
 642                Err(_) => None,
 643            },
 644        };
 645
 646        this
 647    }
 648
 649    #[cfg(test)]
 650    pub fn set_custom_predict_edits_url(&mut self, url: Url) {
 651        self.custom_predict_edits_url = Some(url.into());
 652    }
 653
 654    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 655        self.edit_prediction_model = model;
 656    }
 657
 658    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 659        self.sweep_ai.api_token.read(cx).has_key()
 660    }
 661
 662    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 663        self.mercury.api_token.read(cx).has_key()
 664    }
 665
 666    pub fn clear_history(&mut self) {
 667        for project_state in self.projects.values_mut() {
 668            project_state.events.clear();
 669            project_state.last_event.take();
 670        }
 671    }
 672
 673    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 674        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 675            project_state.events.clear();
 676            project_state.last_event.take();
 677        }
 678    }
 679
 680    pub fn edit_history_for_project(
 681        &self,
 682        project: &Entity<Project>,
 683        cx: &App,
 684    ) -> Vec<StoredEvent> {
 685        self.projects
 686            .get(&project.entity_id())
 687            .map(|project_state| project_state.events(cx))
 688            .unwrap_or_default()
 689    }
 690
 691    pub fn edit_history_for_project_with_pause_split_last_event(
 692        &self,
 693        project: &Entity<Project>,
 694        cx: &App,
 695    ) -> Vec<StoredEvent> {
 696        self.projects
 697            .get(&project.entity_id())
 698            .map(|project_state| project_state.events_split_by_pause(cx))
 699            .unwrap_or_default()
 700    }
 701
 702    pub fn context_for_project<'a>(
 703        &'a self,
 704        project: &Entity<Project>,
 705        cx: &'a mut App,
 706    ) -> Vec<RelatedFile> {
 707        self.projects
 708            .get(&project.entity_id())
 709            .map(|project| {
 710                project
 711                    .context
 712                    .update(cx, |context, cx| context.related_files(cx))
 713            })
 714            .unwrap_or_default()
 715    }
 716
 717    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 718        self.projects
 719            .get(&project.entity_id())
 720            .and_then(|project| project.copilot.clone())
 721    }
 722
 723    pub fn start_copilot_for_project(
 724        &mut self,
 725        project: &Entity<Project>,
 726        cx: &mut Context<Self>,
 727    ) -> Option<Entity<Copilot>> {
 728        let state = self.get_or_init_project(project, cx);
 729
 730        if state.copilot.is_some() {
 731            return state.copilot.clone();
 732        }
 733        let _project = project.clone();
 734        let project = project.read(cx);
 735
 736        let node = project.node_runtime().cloned();
 737        if let Some(node) = node {
 738            let next_id = project.languages().next_language_server_id();
 739            let fs = project.fs().clone();
 740
 741            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 742            state.copilot = Some(copilot.clone());
 743            Some(copilot)
 744        } else {
 745            None
 746        }
 747    }
 748
 749    pub fn context_for_project_with_buffers<'a>(
 750        &'a self,
 751        project: &Entity<Project>,
 752        cx: &'a mut App,
 753    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 754        self.projects
 755            .get(&project.entity_id())
 756            .map(|project| {
 757                project
 758                    .context
 759                    .update(cx, |context, cx| context.related_files_with_buffers(cx))
 760            })
 761            .unwrap_or_default()
 762    }
 763
 764    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 765        if matches!(
 766            self.edit_prediction_model,
 767            EditPredictionModel::Zeta2 { .. }
 768        ) {
 769            self.user_store.read(cx).edit_prediction_usage()
 770        } else {
 771            None
 772        }
 773    }
 774
 775    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 776        self.get_or_init_project(project, cx);
 777    }
 778
 779    pub fn register_buffer(
 780        &mut self,
 781        buffer: &Entity<Buffer>,
 782        project: &Entity<Project>,
 783        cx: &mut Context<Self>,
 784    ) {
 785        let project_state = self.get_or_init_project(project, cx);
 786        Self::register_buffer_impl(project_state, buffer, project, cx);
 787    }
 788
 789    fn get_or_init_project(
 790        &mut self,
 791        project: &Entity<Project>,
 792        cx: &mut Context<Self>,
 793    ) -> &mut ProjectState {
 794        let entity_id = project.entity_id();
 795        self.projects
 796            .entry(entity_id)
 797            .or_insert_with(|| ProjectState {
 798                context: {
 799                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 800                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 801                        this.handle_excerpt_store_event(entity_id, event);
 802                    })
 803                    .detach();
 804                    related_excerpt_store
 805                },
 806                events: VecDeque::new(),
 807                last_event: None,
 808                recent_paths: VecDeque::new(),
 809                debug_tx: None,
 810                registered_buffers: HashMap::default(),
 811                current_prediction: None,
 812                cancelled_predictions: HashSet::default(),
 813                pending_predictions: ArrayVec::new(),
 814                next_pending_prediction_id: 0,
 815                last_prediction_refresh: None,
 816                license_detection_watchers: HashMap::default(),
 817                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 818                _subscription: cx.subscribe(&project, Self::handle_project_event),
 819                copilot: None,
 820            })
 821    }
 822
 823    pub fn remove_project(&mut self, project: &Entity<Project>) {
 824        self.projects.remove(&project.entity_id());
 825    }
 826
 827    fn handle_excerpt_store_event(
 828        &mut self,
 829        project_entity_id: EntityId,
 830        event: &RelatedExcerptStoreEvent,
 831    ) {
 832        if let Some(project_state) = self.projects.get(&project_entity_id) {
 833            if let Some(debug_tx) = project_state.debug_tx.clone() {
 834                match event {
 835                    RelatedExcerptStoreEvent::StartedRefresh => {
 836                        debug_tx
 837                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 838                                ContextRetrievalStartedDebugEvent {
 839                                    project_entity_id: project_entity_id,
 840                                    timestamp: Instant::now(),
 841                                    search_prompt: String::new(),
 842                                },
 843                            ))
 844                            .ok();
 845                    }
 846                    RelatedExcerptStoreEvent::FinishedRefresh {
 847                        cache_hit_count,
 848                        cache_miss_count,
 849                        mean_definition_latency,
 850                        max_definition_latency,
 851                    } => {
 852                        debug_tx
 853                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 854                                ContextRetrievalFinishedDebugEvent {
 855                                    project_entity_id: project_entity_id,
 856                                    timestamp: Instant::now(),
 857                                    metadata: vec![
 858                                        (
 859                                            "Cache Hits",
 860                                            format!(
 861                                                "{}/{}",
 862                                                cache_hit_count,
 863                                                cache_hit_count + cache_miss_count
 864                                            )
 865                                            .into(),
 866                                        ),
 867                                        (
 868                                            "Max LSP Time",
 869                                            format!("{} ms", max_definition_latency.as_millis())
 870                                                .into(),
 871                                        ),
 872                                        (
 873                                            "Mean LSP Time",
 874                                            format!("{} ms", mean_definition_latency.as_millis())
 875                                                .into(),
 876                                        ),
 877                                    ],
 878                                },
 879                            ))
 880                            .ok();
 881                    }
 882                }
 883            }
 884        }
 885    }
 886
 887    pub fn debug_info(
 888        &mut self,
 889        project: &Entity<Project>,
 890        cx: &mut Context<Self>,
 891    ) -> mpsc::UnboundedReceiver<DebugEvent> {
 892        let project_state = self.get_or_init_project(project, cx);
 893        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 894        project_state.debug_tx = Some(debug_watch_tx);
 895        debug_watch_rx
 896    }
 897
 898    fn handle_project_event(
 899        &mut self,
 900        project: Entity<Project>,
 901        event: &project::Event,
 902        cx: &mut Context<Self>,
 903    ) {
 904        // TODO [zeta2] init with recent paths
 905        match event {
 906            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
 907                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 908                    return;
 909                };
 910                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
 911                if let Some(path) = path {
 912                    if let Some(ix) = project_state
 913                        .recent_paths
 914                        .iter()
 915                        .position(|probe| probe == &path)
 916                    {
 917                        project_state.recent_paths.remove(ix);
 918                    }
 919                    project_state.recent_paths.push_front(path);
 920                }
 921            }
 922            project::Event::DiagnosticsUpdated { .. } => {
 923                if cx.has_flag::<Zeta2FeatureFlag>() {
 924                    self.refresh_prediction_from_diagnostics(project, cx);
 925                }
 926            }
 927            _ => (),
 928        }
 929    }
 930
    /// Returns the [`RegisteredBuffer`] entry for `buffer` in `project_state`,
    /// creating it on first use.
    ///
    /// When the buffer has a file whose worktree is found in `project`, this
    /// also ensures a [`LicenseDetectionWatcher`] is stored for that worktree.
    /// That watcher entry is removed again once the worktree entity is released.
    ///
    /// A newly created entry starts from the buffer's current text snapshot and
    /// file, with no last cursor position. It holds two subscriptions:
    /// - `BufferEvent::Edited` calls `report_changes_for_buffer`, as long as the
    ///   project can still be upgraded.
    /// - Releasing the buffer removes it from `registered_buffers`.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // One watcher per worktree. It is only created the first time a
                // buffer from this worktree is registered.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree itself is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Feed every edit of this buffer into `report_changes_for_buffer`.
                        // The project is held weakly, so the subscription does not keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 996
    /// Records the changes made to `buffer` since it was last reported,
    /// registering the buffer first if needed.
    ///
    /// Does nothing if the buffer's version is unchanged. Otherwise it:
    /// - replaces the stored snapshot and file with the current ones;
    /// - classifies the edits into a [`UserActionType`] and records a
    ///   [`UserActionRecord`] at the end offset of the last edit;
    /// - merges the change into the project's pending `last_event` when
    ///   coalescing applies: the change must be in the same buffer, continue
    ///   from that event's latest snapshot, and fall within
    ///   `CHANGE_GROUPING_LINE_SPAN` rows of its edit range;
    /// - otherwise finalizes the pending event into `events` and starts a new
    ///   `LastEvent` for this change.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Spans from the start of the first edit to the end of the last one.
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // An edit counts as a "Char" action when exactly one offset unit per
            // edit was inserted or deleted; larger edits count as "Selection"
            // actions. Edits that both delete and insert are classified by the
            // length of what they insert.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            // `last_offset` is always set here, because `num_edits > 0`.
            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // True when this change picks up exactly where the pending event
            // left off in the same buffer.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Coalesce only if the new edit range overlaps the previous one or
            // lies within CHANGE_GROUPING_LINE_SPAN rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If the user paused for at least LAST_CHANGE_GROUPING_TIME,
                // remember the snapshot from before this burst of edits.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                // NOTE(review): the range is replaced by the newest edit's range,
                // not merged with the previous one.
                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps at most EVENT_COUNT_MAX - 1 finalized events, presumably
                // leaving room for the pending `last_event` (TODO confirm).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1120
1121    fn prediction_at(
1122        &mut self,
1123        buffer: &Entity<Buffer>,
1124        position: Option<language::Anchor>,
1125        project: &Entity<Project>,
1126        cx: &App,
1127    ) -> Option<BufferEditPrediction<'_>> {
1128        let project_state = self.projects.get_mut(&project.entity_id())?;
1129        if let Some(position) = position
1130            && let Some(buffer) = project_state
1131                .registered_buffers
1132                .get_mut(&buffer.entity_id())
1133        {
1134            buffer.last_position = Some(position);
1135        }
1136
1137        let CurrentEditPrediction {
1138            requested_by,
1139            prediction,
1140            ..
1141        } = project_state.current_prediction.as_ref()?;
1142
1143        if prediction.targets_buffer(buffer.read(cx)) {
1144            Some(BufferEditPrediction::Local { prediction })
1145        } else {
1146            let show_jump = match requested_by {
1147                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1148                    requested_by_buffer_id == &buffer.entity_id()
1149                }
1150                PredictionRequestedBy::DiagnosticsUpdate => true,
1151            };
1152
1153            if show_jump {
1154                Some(BufferEditPrediction::Jump { prediction })
1155            } else {
1156                None
1157            }
1158        }
1159    }
1160
1161    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1162        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1163            return;
1164        };
1165
1166        let Some(current_prediction) = project_state.current_prediction.take() else {
1167            return;
1168        };
1169
1170        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1171            project_state.cancel_pending_prediction(pending_prediction, cx);
1172        }
1173
1174        match self.edit_prediction_model {
1175            EditPredictionModel::Sweep => {
1176                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1177            }
1178            EditPredictionModel::Mercury => {}
1179            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1180                zeta2::edit_prediction_accepted(self, current_prediction, cx)
1181            }
1182        }
1183    }
1184
    /// Long-running loop that batches prediction rejections received on `rx`
    /// and reports them to the `/predict_edits/reject` LLM endpoint.
    ///
    /// Batching works as follows:
    /// - While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    ///   rejections are buffered, the loop waits for either another rejection or
    ///   `REJECT_REQUEST_DEBOUNCE` to elapse.
    /// - If another rejection arrives first, the loop keeps accumulating.
    /// - If the timer fires first, or `rx` closes, the batch is flushed. A fresh
    ///   timer is started for every received rejection.
    ///
    /// Each request carries at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST`
    /// of the most recently buffered rejections. Those are removed from the
    /// buffer only if the request succeeds. On failure the error is logged and
    /// they stay buffered for a later flush. The loop ends once `rx` is closed
    /// and drained.
    ///
    /// NOTE(review): if requests keep failing, `batched` grows without bound.
    /// Consider capping it.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Biased: an already-queued rejection wins over the debounce timer.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // Panics if the endpoint URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections (the tail of `batched`).
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only drop what was actually delivered; failures are retried on the
            // next flush.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1243
1244    fn reject_current_prediction(
1245        &mut self,
1246        reason: EditPredictionRejectReason,
1247        project: &Entity<Project>,
1248    ) {
1249        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1250            project_state.pending_predictions.clear();
1251            if let Some(prediction) = project_state.current_prediction.take() {
1252                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1253            }
1254        };
1255    }
1256
1257    fn did_show_current_prediction(
1258        &mut self,
1259        project: &Entity<Project>,
1260        display_type: edit_prediction_types::SuggestionDisplayType,
1261        cx: &mut Context<Self>,
1262    ) {
1263        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1264            return;
1265        };
1266
1267        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1268            return;
1269        };
1270
1271        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1272        let previous_shown_with = current_prediction.shown_with;
1273
1274        if previous_shown_with.is_none() || !is_jump {
1275            current_prediction.shown_with = Some(display_type);
1276        }
1277
1278        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1279
1280        if is_first_non_jump_show {
1281            current_prediction.was_shown = true;
1282        }
1283
1284        let display_type_changed = previous_shown_with != Some(display_type);
1285
1286        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1287            sweep_ai::edit_prediction_shown(
1288                &self.sweep_ai,
1289                self.client.clone(),
1290                &current_prediction.prediction,
1291                display_type,
1292                cx,
1293            );
1294        }
1295
1296        if is_first_non_jump_show {
1297            self.shown_predictions
1298                .push_front(current_prediction.prediction.clone());
1299            if self.shown_predictions.len() > 50 {
1300                let completion = self.shown_predictions.pop_back().unwrap();
1301                self.rated_predictions.remove(&completion.id);
1302            }
1303        }
1304    }
1305
1306    fn reject_prediction(
1307        &mut self,
1308        prediction_id: EditPredictionId,
1309        reason: EditPredictionRejectReason,
1310        was_shown: bool,
1311    ) {
1312        match self.edit_prediction_model {
1313            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1314                if self.custom_predict_edits_url.is_some() {
1315                    return;
1316                }
1317            }
1318            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1319        }
1320
1321        self.reject_predictions_tx
1322            .unbounded_send(EditPredictionRejection {
1323                request_id: prediction_id.to_string(),
1324                reason,
1325                was_shown,
1326            })
1327            .log_err();
1328    }
1329
1330    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1331        self.projects
1332            .get(&project.entity_id())
1333            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1334    }
1335
1336    pub fn refresh_prediction_from_buffer(
1337        &mut self,
1338        project: Entity<Project>,
1339        buffer: Entity<Buffer>,
1340        position: language::Anchor,
1341        cx: &mut Context<Self>,
1342    ) {
1343        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1344            let Some(request_task) = this
1345                .update(cx, |this, cx| {
1346                    this.request_prediction(
1347                        &project,
1348                        &buffer,
1349                        position,
1350                        PredictEditsRequestTrigger::Other,
1351                        cx,
1352                    )
1353                })
1354                .log_err()
1355            else {
1356                return Task::ready(anyhow::Ok(None));
1357            };
1358
1359            cx.spawn(async move |_cx| {
1360                request_task.await.map(|prediction_result| {
1361                    prediction_result.map(|prediction_result| {
1362                        (
1363                            prediction_result,
1364                            PredictionRequestedBy::Buffer(buffer.entity_id()),
1365                        )
1366                    })
1367                })
1368            })
1369        })
1370    }
1371
1372    pub fn refresh_prediction_from_diagnostics(
1373        &mut self,
1374        project: Entity<Project>,
1375        cx: &mut Context<Self>,
1376    ) {
1377        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1378            return;
1379        };
1380
1381        // Prefer predictions from buffer
1382        if project_state.current_prediction.is_some() {
1383            return;
1384        };
1385
1386        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1387            let Some((active_buffer, snapshot, cursor_point)) = this
1388                .read_with(cx, |this, cx| {
1389                    let project_state = this.projects.get(&project.entity_id())?;
1390                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
1391                    let snapshot = buffer.read(cx).snapshot();
1392
1393                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
1394                        return None;
1395                    }
1396
1397                    let cursor_point = position
1398                        .map(|pos| pos.to_point(&snapshot))
1399                        .unwrap_or_default();
1400
1401                    Some((buffer, snapshot, cursor_point))
1402                })
1403                .log_err()
1404                .flatten()
1405            else {
1406                return Task::ready(anyhow::Ok(None));
1407            };
1408
1409            cx.spawn(async move |cx| {
1410                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1411                    active_buffer,
1412                    &snapshot,
1413                    Default::default(),
1414                    cursor_point,
1415                    &project,
1416                    cx,
1417                )
1418                .await?
1419                else {
1420                    return anyhow::Ok(None);
1421                };
1422
1423                let Some(prediction_result) = this
1424                    .update(cx, |this, cx| {
1425                        this.request_prediction(
1426                            &project,
1427                            &jump_buffer,
1428                            jump_position,
1429                            PredictEditsRequestTrigger::Diagnostics,
1430                            cx,
1431                        )
1432                    })?
1433                    .await?
1434                else {
1435                    return anyhow::Ok(None);
1436                };
1437
1438                this.update(cx, |this, cx| {
1439                    Some((
1440                        if this
1441                            .get_or_init_project(&project, cx)
1442                            .current_prediction
1443                            .is_none()
1444                        {
1445                            prediction_result
1446                        } else {
1447                            EditPredictionResult {
1448                                id: prediction_result.id,
1449                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1450                            }
1451                        },
1452                        PredictionRequestedBy::DiagnosticsUpdate,
1453                    ))
1454                })
1455            })
1456        });
1457    }
1458
1459    fn predictions_enabled_at(
1460        snapshot: &BufferSnapshot,
1461        position: Option<language::Anchor>,
1462        cx: &App,
1463    ) -> bool {
1464        let file = snapshot.file();
1465        let all_settings = all_language_settings(file, cx);
1466        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1467            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1468        {
1469            return false;
1470        }
1471
1472        if let Some(last_position) = position {
1473            let settings = snapshot.settings_at(last_position, cx);
1474
1475            if !settings.edit_predictions_disabled_in.is_empty()
1476                && let Some(scope) = snapshot.language_scope_at(last_position)
1477                && let Some(scope_name) = scope.override_name()
1478                && settings
1479                    .edit_predictions_disabled_in
1480                    .iter()
1481                    .any(|s| s == scope_name)
1482            {
1483                return false;
1484            }
1485        }
1486
1487        true
1488    }
1489
1490    #[cfg(not(test))]
1491    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1492    #[cfg(test)]
1493    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1494
    /// Queues an asynchronous prediction refresh for `project`.
    ///
    /// The spawned task runs these steps:
    /// 1. If the project's previous refresh was for the same `throttle_entity`,
    ///    it waits until `THROTTLE_TIMEOUT` has elapsed since that refresh.
    /// 2. If this pending prediction was cancelled in the meantime, or the store
    ///    entity is gone, it stops. Otherwise it records
    ///    `(throttle_entity, now)` as the last refresh and runs `do_refresh`.
    /// 3. Unless the prediction was cancelled while `do_refresh` ran, it handles
    ///    the result:
    ///    - With no current prediction, a successful result becomes current.
    ///    - If a current prediction exists, `should_replace_prediction` picks the
    ///      winner and the loser is rejected. Replacing goes through
    ///      `reject_current_prediction`, which also clears `pending_predictions`.
    ///    - An error result is reported with its rejection reason.
    /// 4. It removes itself from `pending_predictions` and cancels every pending
    ///    prediction queued before it.
    ///
    /// The task resolves to the new prediction's id, or `None` if it stopped early.
    ///
    /// At most two predictions are queued here. When two are already pending,
    /// the newer one is cancelled and replaced by this one, so the oldest
    /// request keeps running.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        )
            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
        + 'static,
    ) {
        let project_state = self.get_or_init_project(&project, cx);
        let pending_prediction_id = project_state.next_pending_prediction_id;
        project_state.next_pending_prediction_id += 1;
        let last_request = project_state.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle only repeated refreshes for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            // `is_cancelled` also stays true if the update fails (entity dropped).
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
            let new_prediction_id = new_prediction_result
                .as_ref()
                .map(|(prediction, _)| prediction.id.clone());

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_project(&project, cx);

                // Cancelled while the request was in flight: discard the result.
                let is_cancelled = project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                let new_current_prediction = if !is_cancelled
                    && let Some((prediction_result, requested_by)) = new_prediction_result
                {
                    match prediction_result.prediction {
                        Ok(prediction) => {
                            let new_prediction = CurrentEditPrediction {
                                requested_by,
                                prediction,
                                was_shown: false,
                                shown_with: None,
                            };

                            if let Some(current_prediction) =
                                project_state.current_prediction.as_ref()
                            {
                                if new_prediction.should_replace_prediction(&current_prediction, cx)
                                {
                                    // Note: this also clears `pending_predictions`.
                                    this.reject_current_prediction(
                                        EditPredictionRejectReason::Replaced,
                                        &project,
                                    );

                                    Some(new_prediction)
                                } else {
                                    this.reject_prediction(
                                        new_prediction.prediction.id,
                                        EditPredictionRejectReason::CurrentPreferred,
                                        false,
                                    );
                                    None
                                }
                            } else {
                                Some(new_prediction)
                            }
                        }
                        Err(reject_reason) => {
                            this.reject_prediction(prediction_result.id, reject_reason, false);
                            None
                        }
                    }
                } else {
                    None
                };

                // Re-fetch: the borrow above ended when `this` was used mutably.
                let project_state = this.get_or_init_project(&project, cx);

                if let Some(new_prediction) = new_current_prediction {
                    project_state.current_prediction = Some(new_prediction);
                }

                // Remove this prediction from the pending list and cancel the
                // older ones queued ahead of it.
                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            project_state.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            new_prediction_id
        });

        // Keep at most two pending predictions. When two are already pending,
        // cancel the newer one and queue this one in its place. There is no
        // branch for more than two, so `task` would simply be dropped then.
        if project_state.pending_predictions.len() <= 1 {
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if project_state.pending_predictions.len() == 2 {
            let pending_prediction = project_state.pending_predictions.pop().unwrap();
            project_state.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1635
1636    pub fn request_prediction(
1637        &mut self,
1638        project: &Entity<Project>,
1639        active_buffer: &Entity<Buffer>,
1640        position: language::Anchor,
1641        trigger: PredictEditsRequestTrigger,
1642        cx: &mut Context<Self>,
1643    ) -> Task<Result<Option<EditPredictionResult>>> {
1644        self.request_prediction_internal(
1645            project.clone(),
1646            active_buffer.clone(),
1647            position,
1648            trigger,
1649            cx.has_flag::<Zeta2FeatureFlag>(),
1650            cx,
1651        )
1652    }
1653
    /// Builds the model input for `position` in `active_buffer`, dispatches it to the
    /// configured `edit_prediction_model`, and returns a task resolving to the prediction.
    ///
    /// A retry happens only when all of the following hold:
    /// - the model returns no prediction;
    /// - `allow_jump` is set;
    /// - the project had recorded edit events when the request was built;
    /// - `next_diagnostic_location` finds a target.
    ///
    /// The retry is a single recursive call at that target, made with `allow_jump = false`,
    /// so at most one jump is attempted.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor that form the diagnostic search window.
        // Diagnostics overlapping this window are skipped as jump targets.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Called for its side effect only. The `unwrap` below relies on it having created
        // the state for this project.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        // Remembered for the jump fallback, which only runs when there were recent events.
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        // Copy of the recorded user actions. It is sent to the model without mutating
        // the project state.
        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the most recent action happened in this buffer at a different offset, the
        // cursor has moved since then. Append a synthetic cursor-movement record for it.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                // Falls back to 0 if the system clock is before the UNIX epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        // Window of whole rows around the cursor: DIAGNOSTIC_LINES_RANGE rows either side,
        // clamped at row 0.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // An example may only be captured when each input is collectable:
        // - the active file,
        // - every edit event,
        // - every related file.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx)
            && self.can_collect_related_files(&project, cx);

        // Capture a telemetry example only for requests selected by the sampling check.
        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            let related_files_for_capture = inputs.related_files.clone();
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                related_files_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally also send a Zeta2 request tagged with the `Testing` trigger.
                // It is detached, so its result is discarded.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(
                        self,
                        zeta2_inputs,
                        Default::default(),
                        cx,
                    )
                    .detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Jump fallback: nothing was predicted at the cursor, so retry once at a
            // diagnostic location.
            if prediction.is_none() && allow_jump {
                // Recomputed from the captured snapshot; equals the outer `cursor_point`.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // `allow_jump = false` prevents chaining further jumps.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1803
1804    async fn next_diagnostic_location(
1805        active_buffer: Entity<Buffer>,
1806        active_buffer_snapshot: &BufferSnapshot,
1807        active_buffer_diagnostic_search_range: Range<Point>,
1808        active_buffer_cursor_point: Point,
1809        project: &Entity<Project>,
1810        cx: &mut AsyncApp,
1811    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1812        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1813        let mut jump_location = active_buffer_snapshot
1814            .diagnostic_groups(None)
1815            .into_iter()
1816            .filter_map(|(_, group)| {
1817                let range = &group.entries[group.primary_ix]
1818                    .range
1819                    .to_point(&active_buffer_snapshot);
1820                if range.overlaps(&active_buffer_diagnostic_search_range) {
1821                    None
1822                } else {
1823                    Some(range.start)
1824                }
1825            })
1826            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1827            .map(|position| {
1828                (
1829                    active_buffer.clone(),
1830                    active_buffer_snapshot.anchor_before(position),
1831                )
1832            });
1833
1834        if jump_location.is_none() {
1835            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1836                let file = buffer.file()?;
1837
1838                Some(ProjectPath {
1839                    worktree_id: file.worktree_id(cx),
1840                    path: file.path().clone(),
1841                })
1842            });
1843
1844            let buffer_task = project.update(cx, |project, cx| {
1845                let (path, _, _) = project
1846                    .diagnostic_summaries(false, cx)
1847                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1848                    .max_by_key(|(path, _, _)| {
1849                        // find the buffer with errors that shares most parent directories
1850                        path.path
1851                            .components()
1852                            .zip(
1853                                active_buffer_path
1854                                    .as_ref()
1855                                    .map(|p| p.path.components())
1856                                    .unwrap_or_default(),
1857                            )
1858                            .take_while(|(a, b)| a == b)
1859                            .count()
1860                    })?;
1861
1862                Some(project.open_buffer(path, cx))
1863            });
1864
1865            if let Some(buffer_task) = buffer_task {
1866                let closest_buffer = buffer_task.await?;
1867
1868                jump_location = closest_buffer
1869                    .read_with(cx, |buffer, _cx| {
1870                        buffer
1871                            .buffer_diagnostics(None)
1872                            .into_iter()
1873                            .min_by_key(|entry| entry.diagnostic.severity)
1874                            .map(|entry| entry.range.start)
1875                    })
1876                    .map(|position| (closest_buffer, position));
1877            }
1878        }
1879
1880        anyhow::Ok(jump_location)
1881    }
1882
1883    async fn send_raw_llm_request(
1884        request: RawCompletionRequest,
1885        client: Arc<Client>,
1886        custom_url: Option<Arc<Url>>,
1887        llm_token: LlmApiToken,
1888        app_version: Version,
1889    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1890        let url = if let Some(custom_url) = custom_url {
1891            custom_url.as_ref().clone()
1892        } else {
1893            client
1894                .http_client()
1895                .build_zed_llm_url("/predict_edits/raw", &[])?
1896        };
1897
1898        Self::send_api_request(
1899            |builder| {
1900                let req = builder
1901                    .uri(url.as_ref())
1902                    .body(serde_json::to_string(&request)?.into());
1903                Ok(req?)
1904            },
1905            client,
1906            llm_token,
1907            app_version,
1908            true,
1909        )
1910        .await
1911    }
1912
1913    pub(crate) async fn send_v3_request(
1914        input: ZetaPromptInput,
1915        prompt_version: ZetaVersion,
1916        client: Arc<Client>,
1917        llm_token: LlmApiToken,
1918        app_version: Version,
1919        trigger: PredictEditsRequestTrigger,
1920    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1921        let url = client
1922            .http_client()
1923            .build_zed_llm_url("/predict_edits/v3", &[])?;
1924
1925        let request = PredictEditsV3Request {
1926            input,
1927            model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1928            prompt_version,
1929            trigger,
1930        };
1931
1932        Self::send_api_request(
1933            |builder| {
1934                let req = builder
1935                    .uri(url.as_ref())
1936                    .body(serde_json::to_string(&request)?.into());
1937                Ok(req?)
1938            },
1939            client,
1940            llm_token,
1941            app_version,
1942            true,
1943        )
1944        .await
1945    }
1946
1947    fn handle_api_response<T>(
1948        this: &WeakEntity<Self>,
1949        response: Result<(T, Option<EditPredictionUsage>)>,
1950        cx: &mut gpui::AsyncApp,
1951    ) -> Result<T> {
1952        match response {
1953            Ok((data, usage)) => {
1954                if let Some(usage) = usage {
1955                    this.update(cx, |this, cx| {
1956                        this.user_store.update(cx, |user_store, cx| {
1957                            user_store.update_edit_prediction_usage(usage, cx);
1958                        });
1959                    })
1960                    .ok();
1961                }
1962                Ok(data)
1963            }
1964            Err(err) => {
1965                if err.is::<ZedUpdateRequiredError>() {
1966                    cx.update(|cx| {
1967                        this.update(cx, |this, _cx| {
1968                            this.update_required = true;
1969                        })
1970                        .ok();
1971
1972                        let error_message: SharedString = err.to_string().into();
1973                        show_app_notification(
1974                            NotificationId::unique::<ZedUpdateRequiredError>(),
1975                            cx,
1976                            move |cx| {
1977                                cx.new(|cx| {
1978                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1979                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1980                                })
1981                            },
1982                        );
1983                    });
1984                }
1985                Err(err)
1986            }
1987        }
1988    }
1989
1990    async fn send_api_request<Res>(
1991        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1992        client: Arc<Client>,
1993        llm_token: LlmApiToken,
1994        app_version: Version,
1995        require_auth: bool,
1996    ) -> Result<(Res, Option<EditPredictionUsage>)>
1997    where
1998        Res: DeserializeOwned,
1999    {
2000        let http_client = client.http_client();
2001
2002        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2003            Some(custom_token)
2004        } else if require_auth {
2005            Some(llm_token.acquire(&client).await?)
2006        } else {
2007            llm_token.acquire(&client).await.ok()
2008        };
2009        let mut did_retry = false;
2010
2011        loop {
2012            let request_builder = http_client::Request::builder().method(Method::POST);
2013
2014            let mut request_builder = request_builder
2015                .header("Content-Type", "application/json")
2016                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2017
2018            // Only add Authorization header if we have a token
2019            if let Some(ref token_value) = token {
2020                request_builder =
2021                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2022            }
2023
2024            let request = build(request_builder)?;
2025
2026            let mut response = http_client.send(request).await?;
2027
2028            if let Some(minimum_required_version) = response
2029                .headers()
2030                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2031                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2032            {
2033                anyhow::ensure!(
2034                    app_version >= minimum_required_version,
2035                    ZedUpdateRequiredError {
2036                        minimum_version: minimum_required_version
2037                    }
2038                );
2039            }
2040
2041            if response.status().is_success() {
2042                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2043
2044                let mut body = Vec::new();
2045                response.body_mut().read_to_end(&mut body).await?;
2046                return Ok((serde_json::from_slice(&body)?, usage));
2047            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2048                did_retry = true;
2049                token = Some(llm_token.refresh(&client).await?);
2050            } else {
2051                let mut body = String::new();
2052                response.body_mut().read_to_string(&mut body).await?;
2053                anyhow::bail!(
2054                    "Request failed with status: {:?}\nBody: {}",
2055                    response.status(),
2056                    body
2057                );
2058            }
2059        }
2060    }
2061
2062    pub fn refresh_context(
2063        &mut self,
2064        project: &Entity<Project>,
2065        buffer: &Entity<language::Buffer>,
2066        cursor_position: language::Anchor,
2067        cx: &mut Context<Self>,
2068    ) {
2069        self.get_or_init_project(project, cx)
2070            .context
2071            .update(cx, |store, cx| {
2072                store.refresh(buffer.clone(), cursor_position, cx);
2073            });
2074    }
2075
2076    #[cfg(feature = "cli-support")]
2077    pub fn set_context_for_buffer(
2078        &mut self,
2079        project: &Entity<Project>,
2080        related_files: Vec<RelatedFile>,
2081        cx: &mut Context<Self>,
2082    ) {
2083        self.get_or_init_project(project, cx)
2084            .context
2085            .update(cx, |store, cx| {
2086                store.set_related_files(related_files, cx);
2087            });
2088    }
2089
2090    #[cfg(feature = "cli-support")]
2091    pub fn set_recent_paths_for_project(
2092        &mut self,
2093        project: &Entity<Project>,
2094        paths: impl IntoIterator<Item = project::ProjectPath>,
2095        cx: &mut Context<Self>,
2096    ) {
2097        let project_state = self.get_or_init_project(project, cx);
2098        project_state.recent_paths = paths.into_iter().collect();
2099    }
2100
2101    fn is_file_open_source(
2102        &self,
2103        project: &Entity<Project>,
2104        file: &Arc<dyn File>,
2105        cx: &App,
2106    ) -> bool {
2107        if !file.is_local() || file.is_private() {
2108            return false;
2109        }
2110        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2111            return false;
2112        };
2113        project_state
2114            .license_detection_watchers
2115            .get(&file.worktree_id(cx))
2116            .as_ref()
2117            .is_some_and(|watcher| watcher.is_project_open_source())
2118    }
2119
2120    fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2121        self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2122    }
2123
2124    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2125        if !self.data_collection_choice.is_enabled(cx) {
2126            return false;
2127        }
2128        events.iter().all(|event| {
2129            matches!(
2130                event.as_ref(),
2131                zeta_prompt::Event::BufferChange {
2132                    in_open_source_repo: true,
2133                    ..
2134                }
2135            )
2136        })
2137    }
2138
2139    fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2140        if !self.data_collection_choice.is_enabled(cx) {
2141            return false;
2142        }
2143
2144        let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2145
2146        related_with_buffers.iter().all(|(_, buffer)| {
2147            buffer
2148                .read(cx)
2149                .file()
2150                .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2151        })
2152    }
2153
2154    fn load_data_collection_choice() -> DataCollectionChoice {
2155        let choice = KEY_VALUE_STORE
2156            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2157            .log_err()
2158            .flatten();
2159
2160        match choice.as_deref() {
2161            Some("true") => DataCollectionChoice::Enabled,
2162            Some("false") => DataCollectionChoice::Disabled,
2163            Some(_) => {
2164                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2165                DataCollectionChoice::NotAnswered
2166            }
2167            None => DataCollectionChoice::NotAnswered,
2168        }
2169    }
2170
2171    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2172        self.data_collection_choice = self.data_collection_choice.toggle();
2173        let new_choice = self.data_collection_choice;
2174        let is_enabled = new_choice.is_enabled(cx);
2175        db::write_and_log(cx, move || {
2176            KEY_VALUE_STORE.write_kvp(
2177                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2178                is_enabled.to_string(),
2179            )
2180        });
2181    }
2182
2183    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2184        self.shown_predictions.iter()
2185    }
2186
2187    pub fn shown_completions_len(&self) -> usize {
2188        self.shown_predictions.len()
2189    }
2190
2191    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2192        self.rated_predictions.contains(id)
2193    }
2194
2195    pub fn rate_prediction(
2196        &mut self,
2197        prediction: &EditPrediction,
2198        rating: EditPredictionRating,
2199        feedback: String,
2200        cx: &mut Context<Self>,
2201    ) {
2202        self.rated_predictions.insert(prediction.id.clone());
2203        telemetry::event!(
2204            "Edit Prediction Rated",
2205            rating,
2206            inputs = prediction.inputs,
2207            output = prediction
2208                .edit_preview
2209                .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2210            feedback
2211        );
2212        self.client.telemetry().flush_events().detach();
2213        cx.notify();
2214    }
2215}
2216
2217pub(crate) fn filter_redundant_excerpts(
2218    mut related_files: Vec<RelatedFile>,
2219    cursor_path: &Path,
2220    cursor_row_range: Range<u32>,
2221) -> Vec<RelatedFile> {
2222    for file in &mut related_files {
2223        if file.path.as_ref() == cursor_path {
2224            file.excerpts.retain(|excerpt| {
2225                excerpt.row_range.start < cursor_row_range.start
2226                    || excerpt.row_range.end > cursor_row_range.end
2227            });
2228        }
2229    }
2230    related_files.retain(|file| !file.excerpts.is_empty());
2231    related_files
2232}
2233
/// The prediction server requires a newer Zed than the one running.
///
/// Raised by `EditPredictionStore::send_api_request` when the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header names a version greater than the
/// app's version. `handle_api_response` checks for this type to show an update notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: Version,
}
2241
2242#[derive(Debug, Clone, Copy)]
2243pub enum DataCollectionChoice {
2244    NotAnswered,
2245    Enabled,
2246    Disabled,
2247}
2248
2249impl DataCollectionChoice {
2250    pub fn is_enabled(self, cx: &App) -> bool {
2251        if cx.is_staff() {
2252            return true;
2253        }
2254        match self {
2255            Self::Enabled => true,
2256            Self::NotAnswered | Self::Disabled => false,
2257        }
2258    }
2259
2260    #[must_use]
2261    pub fn toggle(&self) -> DataCollectionChoice {
2262        match self {
2263            Self::Enabled => Self::Disabled,
2264            Self::Disabled => Self::Enabled,
2265            Self::NotAnswered => Self::Enabled,
2266        }
2267    }
2268}
2269
2270impl From<bool> for DataCollectionChoice {
2271    fn from(value: bool) -> Self {
2272        match value {
2273            true => DataCollectionChoice::Enabled,
2274            false => DataCollectionChoice::Disabled,
2275        }
2276    }
2277}
2278
2279struct ZedPredictUpsell;
2280
2281impl Dismissable for ZedPredictUpsell {
2282    const KEY: &'static str = "dismissed-edit-predict-upsell";
2283
2284    fn dismissed() -> bool {
2285        // To make this backwards compatible with older versions of Zed, we
2286        // check if the user has seen the previous Edit Prediction Onboarding
2287        // before, by checking the data collection choice which was written to
2288        // the database once the user clicked on "Accept and Enable"
2289        if KEY_VALUE_STORE
2290            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2291            .log_err()
2292            .is_some_and(|s| s.is_some())
2293        {
2294            return true;
2295        }
2296
2297        KEY_VALUE_STORE
2298            .read_kvp(Self::KEY)
2299            .log_err()
2300            .is_some_and(|s| s.is_some())
2301    }
2302}
2303
2304pub fn should_show_upsell_modal() -> bool {
2305    !ZedPredictUpsell::dismissed()
2306}
2307
/// Registers edit-prediction workspace actions on every new `Workspace`:
/// - opening the Zed Predict onboarding modal;
/// - resetting onboarding;
/// - the Copilot sign-in, reinstall, and sign-out actions.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding writes `edit_prediction_provider = None` into the settings
        // file, creating the `features` section if it is absent.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
        // Returns `None` when no global `EditPredictionStore` exists. Otherwise asks the store
        // for a Copilot instance for `project` via `start_copilot_for_project`, which
        // presumably starts one if needed.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        // The Copilot actions below are no-ops when no Copilot instance is available.
        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}