edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
// Actions declared under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row-distance threshold compared against `lines_between_ranges` in
/// `StoredEvent::can_merge` when deciding whether edits are close to each other.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the pause after which consecutive edits stop being grouped
// into a single change; not referenced in this part of the file — confirm at use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): looks like a persisted-settings key for the data collection choice —
// confirm against the code that reads and writes it.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied before sending batched rejection
// reports — confirm against `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// NOTE(review): the three constants below appear to configure the "settled prediction"
// telemetry worker (event name, maximum tracking time, required quiet period) — confirm.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 109
/// Marker type for the feature flag named `"zeta2"`.
pub struct Zeta2FeatureFlag;
/// Marker type for the feature flag named `"edit_prediction_jumps"`.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 120
/// Newtype around the shared [`EditPredictionStore`] entity so it can be stored as an
/// app-wide global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): this struct has no `version` field; the sentence above presumably
/// refers to `format` — confirm and reword.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; populated from the `ZED_ZETA_MODEL` environment
    /// variable when the config is built from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from the `ZED_ZETA_FORMAT` environment variable when the
    /// config is built from the environment.
    pub format: ZetaFormat,
}
 134
/// Shared state for edit predictions: client handles, per-project state, the selected
/// prediction model, and the channels feeding the workers spawned in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Subscription created in `new` that refreshes `llm_token` whenever the refresh
    /// listener emits an event.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false` in `new`.
    update_required: bool,
    /// Currently selected prediction model; `new` starts with `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw-endpoint configuration; initialized from the environment in `new` and
    /// replaceable through `set_zeta2_raw_config`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel whose receiver is handed to
    /// `handle_rejected_predictions` on a background task in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Sending half of the channel whose receiver is handed to
    /// `run_settled_predictions_worker` in `new`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; `None` by default.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 154
/// The prediction backend selected on an [`EditPredictionStore`].
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle model using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 163
/// Bundle of inputs describing a single prediction request.
///
/// NOTE(review): presumably assembled by the store and handed to the selected model
/// backend — the construction site is outside this part of the file; confirm.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 180
/// Debug notifications sent over the `debug_tx` channels; each variant wraps its own
/// payload struct.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 188
/// Payload of [`DebugEvent::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 195
/// Payload of [`DebugEvent::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 202
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
 209
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 216
/// Maximum number of user actions kept per project; `ProjectState::record_user_action`
/// drops the oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;
 218
/// A single recorded user action, as stored in the per-project action history.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): name suggests milliseconds since the Unix epoch — confirm at the
    // construction site.
    pub timestamp_epoch_ms: u64,
}
 227
/// Category of a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 236
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event itself, shared via `Arc`.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot taken before the change described by `event`.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchored range covered by the change; `LastEvent::finalize` anchors it in the
    /// post-change snapshot.
    pub edit_range: Range<Anchor>,
}
 244
 245impl StoredEvent {
 246    fn can_merge(
 247        &self,
 248        next_old_event: &&&StoredEvent,
 249        new_snapshot: &TextBufferSnapshot,
 250        last_edit_range: &Range<Anchor>,
 251    ) -> bool {
 252        // Events must be for the same buffer
 253        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 254            return false;
 255        }
 256        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 257            return false;
 258        }
 259
 260        let a_is_predicted = matches!(
 261            self.event.as_ref(),
 262            zeta_prompt::Event::BufferChange {
 263                predicted: true,
 264                ..
 265            }
 266        );
 267        let b_is_predicted = matches!(
 268            next_old_event.event.as_ref(),
 269            zeta_prompt::Event::BufferChange {
 270                predicted: true,
 271                ..
 272            }
 273        );
 274
 275        // If events come from the same source (both predicted or both manual) then
 276        // we would have coalesced them already.
 277        if a_is_predicted == b_is_predicted {
 278            return false;
 279        }
 280
 281        let left_range = self.edit_range.to_point(new_snapshot);
 282        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 283        let latest_range = last_edit_range.to_point(&new_snapshot);
 284
 285        // Events near to the latest edit are not merged if their sources differ.
 286        if lines_between_ranges(&left_range, &latest_range)
 287            .min(lines_between_ranges(&right_range, &latest_range))
 288            <= CHANGE_GROUPING_LINE_SPAN
 289        {
 290            return false;
 291        }
 292
 293        // Events that are distant from each other are not merged.
 294        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 295            return false;
 296        }
 297
 298        true
 299    }
 300}
 301
 302fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 303    if left.start > right.end {
 304        return left.start.row - right.end.row;
 305    }
 306    if right.start > left.end {
 307        return right.start.row - left.end.row;
 308    }
 309    0
 310}
 311
/// Per-project bookkeeping held by [`EditPredictionStore`].
struct ProjectState {
    /// Finalized edit events.
    events: VecDeque<StoredEvent>,
    /// Most recent event, not yet finalized; `events()` finalizes a copy on demand.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered with this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// Pending prediction requests; capacity is fixed at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Watchers consulted to decide whether a worktree belongs to an open source project.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Recent user actions, capped at `USER_ACTION_HISTORY_SIZE` entries.
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance, created on demand by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
 330
 331impl ProjectState {
 332    fn record_user_action(&mut self, action: UserActionRecord) {
 333        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 334            self.user_actions.pop_front();
 335        }
 336        self.user_actions.push_back(action);
 337    }
 338
 339    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 340        self.events
 341            .iter()
 342            .cloned()
 343            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 344                let (one, two) = event.split_by_pause();
 345                let one = one.finalize(&self.license_detection_watchers, cx);
 346                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 347                one.into_iter().chain(two)
 348            }))
 349            .collect()
 350    }
 351
 352    fn cancel_pending_prediction(
 353        &mut self,
 354        pending_prediction: PendingPrediction,
 355        cx: &mut Context<EditPredictionStore>,
 356    ) {
 357        self.cancelled_predictions.insert(pending_prediction.id);
 358
 359        if pending_prediction.drop_on_cancel {
 360            drop(pending_prediction.task);
 361        } else {
 362            cx.spawn(async move |this, cx| {
 363                let Some(prediction_id) = pending_prediction.task.await else {
 364                    return;
 365                };
 366
 367                this.update(cx, |this, cx| {
 368                    this.reject_prediction(
 369                        prediction_id,
 370                        EditPredictionRejectReason::Canceled,
 371                        false,
 372                        None,
 373                        cx,
 374                    );
 375                })
 376                .ok();
 377            })
 378            .detach()
 379        }
 380    }
 381
 382    fn active_buffer(
 383        &self,
 384        project: &Entity<Project>,
 385        cx: &App,
 386    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 387        let project = project.read(cx);
 388        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 389        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 390        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 391        Some((active_buffer, registered_buffer.last_position))
 392    }
 393}
 394
/// The prediction currently held for a project, along with how it was requested and
/// whether/how it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 402
 403impl CurrentEditPrediction {
 404    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 405        let Some(new_edits) = self
 406            .prediction
 407            .interpolate(&self.prediction.buffer.read(cx))
 408        else {
 409            return false;
 410        };
 411
 412        if self.prediction.buffer != old_prediction.prediction.buffer {
 413            return true;
 414        }
 415
 416        let Some(old_edits) = old_prediction
 417            .prediction
 418            .interpolate(&old_prediction.prediction.buffer.read(cx))
 419        else {
 420            return true;
 421        };
 422
 423        let requested_by_buffer_id = self.requested_by.buffer_id();
 424
 425        // This reduces the occurrence of UI thrash from replacing edits
 426        //
 427        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 428        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 429            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 430            && old_edits.len() == 1
 431            && new_edits.len() == 1
 432        {
 433            let (old_range, old_text) = &old_edits[0];
 434            let (new_range, new_text) = &new_edits[0];
 435            new_range == old_range && new_text.starts_with(old_text.as_ref())
 436        } else {
 437            true
 438        }
 439    }
 440}
 441
 442#[derive(Debug, Clone)]
 443enum PredictionRequestedBy {
 444    DiagnosticsUpdate,
 445    Buffer(EntityId),
 446}
 447
 448impl PredictionRequestedBy {
 449    pub fn buffer_id(&self) -> Option<EntityId> {
 450        match self {
 451            PredictionRequestedBy::DiagnosticsUpdate => None,
 452            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 453        }
 454    }
 455}
 456
// NOTE(review): presumably the number of lines around the cursor searched for
// diagnostics; not referenced in this part of the file — confirm at use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 458
/// Scope of a diagnostics search.
///
/// NOTE(review): presumably `Local` restricts the search to the area near the cursor
/// and `Global` widens it — the consuming code is outside this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 464
/// An in-flight prediction request tracked by a project.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Task resolving to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 473
 474/// A prediction from the perspective of a buffer.
 475#[derive(Debug)]
 476enum BufferEditPrediction<'a> {
 477    Local { prediction: &'a EditPrediction },
 478    Jump { prediction: &'a EditPrediction },
 479}
 480
 481#[cfg(test)]
 482impl std::ops::Deref for BufferEditPrediction<'_> {
 483    type Target = EditPrediction;
 484
 485    fn deref(&self) -> &Self::Target {
 486        match self {
 487            BufferEditPrediction::Local { prediction } => prediction,
 488            BufferEditPrediction::Jump { prediction } => prediction,
 489        }
 490    }
 491}
 492
/// A prediction tracked on a registered buffer while waiting to be reported as settled.
///
/// NOTE(review): the worker that consumes these entries is outside this part of the
/// file; field meanings below are taken from their names — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    enqueued_at: Instant,
    last_edit_at: Instant,
}
 500
/// State kept for each buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 508
/// The most recent, still-accumulating change to a buffer, described by the snapshots
/// and files before and after it. Converted into a [`StoredEvent`] by `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Copied into the `predicted` field of the resulting buffer-change event.
    predicted: bool,
    /// Boundary snapshot used by `split_by_pause` to split this event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 520
 521impl LastEvent {
 522    pub fn finalize(
 523        &self,
 524        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 525        cx: &App,
 526    ) -> Option<StoredEvent> {
 527        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 528        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 529
 530        let in_open_source_repo =
 531            [self.new_file.as_ref(), self.old_file.as_ref()]
 532                .iter()
 533                .all(|file| {
 534                    file.is_some_and(|file| {
 535                        license_detection_watchers
 536                            .get(&file.worktree_id(cx))
 537                            .is_some_and(|watcher| watcher.is_project_open_source())
 538                    })
 539                });
 540
 541        let (diff, edit_range) =
 542            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 543
 544        if path == old_path && diff.is_empty() {
 545            None
 546        } else {
 547            Some(StoredEvent {
 548                event: Arc::new(zeta_prompt::Event::BufferChange {
 549                    old_path,
 550                    path,
 551                    diff,
 552                    in_open_source_repo,
 553                    predicted: self.predicted,
 554                }),
 555                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 556                    ..self.new_snapshot.anchor_before(edit_range.end),
 557                old_snapshot: self.old_snapshot.clone(),
 558            })
 559        }
 560    }
 561
 562    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 563        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 564            return (self.clone(), None);
 565        };
 566
 567        let before = LastEvent {
 568            old_snapshot: self.old_snapshot.clone(),
 569            new_snapshot: boundary_snapshot.clone(),
 570            old_file: self.old_file.clone(),
 571            new_file: self.new_file.clone(),
 572            edit_range: None,
 573            predicted: self.predicted,
 574            snapshot_after_last_editing_pause: None,
 575            last_edit_time: self.last_edit_time,
 576        };
 577
 578        let after = LastEvent {
 579            old_snapshot: boundary_snapshot.clone(),
 580            new_snapshot: self.new_snapshot.clone(),
 581            old_file: self.old_file.clone(),
 582            new_file: self.new_file.clone(),
 583            edit_range: None,
 584            predicted: self.predicted,
 585            snapshot_after_last_editing_pause: None,
 586            last_edit_time: self.last_edit_time,
 587        };
 588
 589        (before, Some(after))
 590    }
 591}
 592
 593pub(crate) fn compute_diff_between_snapshots(
 594    old_snapshot: &TextBufferSnapshot,
 595    new_snapshot: &TextBufferSnapshot,
 596) -> Option<(String, Range<Point>)> {
 597    let edits: Vec<Edit<usize>> = new_snapshot
 598        .edits_since::<usize>(&old_snapshot.version)
 599        .collect();
 600
 601    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 602
 603    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 604    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 605    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 606    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 607
 608    const CONTEXT_LINES: u32 = 3;
 609
 610    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 611    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 612    let old_context_end_row =
 613        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 614    let new_context_end_row =
 615        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 616
 617    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 618    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 619    let old_end_line_offset = old_snapshot
 620        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 621    let new_end_line_offset = new_snapshot
 622        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 623    let old_edit_range = old_start_line_offset..old_end_line_offset;
 624    let new_edit_range = new_start_line_offset..new_end_line_offset;
 625
 626    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 627    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 628
 629    let diff = language::unified_diff_with_offsets(
 630        &old_region_text,
 631        &new_region_text,
 632        old_context_start_row,
 633        new_context_start_row,
 634    );
 635
 636    Some((diff, new_start_point..new_end_point))
 637}
 638
 639fn buffer_path_with_id_fallback(
 640    file: Option<&Arc<dyn File>>,
 641    snapshot: &TextBufferSnapshot,
 642    cx: &App,
 643) -> Arc<Path> {
 644    if let Some(file) = file {
 645        file.full_path(cx).into()
 646    } else {
 647        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 648    }
 649}
 650
 651impl EditPredictionStore {
 652    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 653        cx.try_global::<EditPredictionStoreGlobal>()
 654            .map(|global| global.0.clone())
 655    }
 656
 657    pub fn global(
 658        client: &Arc<Client>,
 659        user_store: &Entity<UserStore>,
 660        cx: &mut App,
 661    ) -> Entity<Self> {
 662        cx.try_global::<EditPredictionStoreGlobal>()
 663            .map(|global| global.0.clone())
 664            .unwrap_or_else(|| {
 665                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 666                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 667                ep_store
 668            })
 669    }
 670
 671    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 672        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 673        let data_collection_choice = Self::load_data_collection_choice();
 674
 675        let llm_token = LlmApiToken::default();
 676
 677        let (reject_tx, reject_rx) = mpsc::unbounded();
 678        cx.background_spawn({
 679            let client = client.clone();
 680            let llm_token = llm_token.clone();
 681            let app_version = AppVersion::global(cx);
 682            let background_executor = cx.background_executor().clone();
 683            async move {
 684                Self::handle_rejected_predictions(
 685                    reject_rx,
 686                    client,
 687                    llm_token,
 688                    app_version,
 689                    background_executor,
 690                )
 691                .await
 692            }
 693        })
 694        .detach();
 695
 696        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 697        cx.spawn(async move |this, cx| {
 698            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 699        })
 700        .detach();
 701
 702        let this = Self {
 703            projects: HashMap::default(),
 704            client,
 705            user_store,
 706            llm_token,
 707            _llm_token_subscription: cx.subscribe(
 708                &refresh_llm_token_listener,
 709                |this, _listener, _event, cx| {
 710                    let client = this.client.clone();
 711                    let llm_token = this.llm_token.clone();
 712                    cx.spawn(async move |_this, _cx| {
 713                        llm_token.refresh(&client).await?;
 714                        anyhow::Ok(())
 715                    })
 716                    .detach_and_log_err(cx);
 717                },
 718            ),
 719            update_required: false,
 720            edit_prediction_model: EditPredictionModel::Zeta2,
 721            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 722            sweep_ai: SweepAi::new(cx),
 723            mercury: Mercury::new(cx),
 724
 725            data_collection_choice,
 726            reject_predictions_tx: reject_tx,
 727            settled_predictions_tx,
 728            rated_predictions: Default::default(),
 729            shown_predictions: Default::default(),
 730            #[cfg(test)]
 731            settled_event_callback: None,
 732        };
 733
 734        this
 735    }
 736
 737    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 738        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 739        let format = ZetaFormat::parse(&version_str).ok()?;
 740        let model_id = env::var("ZED_ZETA_MODEL").ok();
 741        Some(Zeta2RawConfig { model_id, format })
 742    }
 743
    /// Selects the prediction model used by this store.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 747
    /// Replaces the raw Zeta2 configuration, overriding any value read from the
    /// environment at construction time.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 751
    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 755
 756    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 757        use ui::IconName;
 758        match self.edit_prediction_model {
 759            EditPredictionModel::Sweep => {
 760                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 761                    .with_disabled(IconName::SweepAiDisabled)
 762                    .with_up(IconName::SweepAiUp)
 763                    .with_down(IconName::SweepAiDown)
 764                    .with_error(IconName::SweepAiError)
 765            }
 766            EditPredictionModel::Mercury => {
 767                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 768            }
 769            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 770                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 771                    .with_disabled(IconName::ZedPredictDisabled)
 772                    .with_up(IconName::ZedPredictUp)
 773                    .with_down(IconName::ZedPredictDown)
 774                    .with_error(IconName::ZedPredictError)
 775            }
 776            EditPredictionModel::Fim { .. } => {
 777                let settings = &all_language_settings(None, cx).edit_predictions;
 778                match settings.provider {
 779                    EditPredictionProvider::Ollama => {
 780                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 781                    }
 782                    _ => {
 783                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 784                    }
 785                }
 786            }
 787        }
 788    }
 789
 790    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 791        self.sweep_ai.api_token.read(cx).has_key()
 792    }
 793
 794    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 795        self.mercury.api_token.read(cx).has_key()
 796    }
 797
 798    pub fn clear_history(&mut self) {
 799        for project_state in self.projects.values_mut() {
 800            project_state.events.clear();
 801            project_state.last_event.take();
 802        }
 803    }
 804
 805    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 806        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 807            project_state.events.clear();
 808            project_state.last_event.take();
 809        }
 810    }
 811
 812    pub fn edit_history_for_project(
 813        &self,
 814        project: &Entity<Project>,
 815        cx: &App,
 816    ) -> Vec<StoredEvent> {
 817        self.projects
 818            .get(&project.entity_id())
 819            .map(|project_state| project_state.events(cx))
 820            .unwrap_or_default()
 821    }
 822
 823    pub fn context_for_project<'a>(
 824        &'a self,
 825        project: &Entity<Project>,
 826        cx: &'a mut App,
 827    ) -> Vec<RelatedFile> {
 828        self.projects
 829            .get(&project.entity_id())
 830            .map(|project_state| {
 831                project_state.context.update(cx, |context, cx| {
 832                    context
 833                        .related_files_with_buffers(cx)
 834                        .map(|(mut related_file, buffer)| {
 835                            related_file.in_open_source_repo = buffer
 836                                .read(cx)
 837                                .file()
 838                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 839                            related_file
 840                        })
 841                        .collect()
 842                })
 843            })
 844            .unwrap_or_default()
 845    }
 846
 847    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 848        self.projects
 849            .get(&project.entity_id())
 850            .and_then(|project| project.copilot.clone())
 851    }
 852
 853    pub fn start_copilot_for_project(
 854        &mut self,
 855        project: &Entity<Project>,
 856        cx: &mut Context<Self>,
 857    ) -> Option<Entity<Copilot>> {
 858        if DisableAiSettings::get(None, cx).disable_ai {
 859            return None;
 860        }
 861        let state = self.get_or_init_project(project, cx);
 862
 863        if state.copilot.is_some() {
 864            return state.copilot.clone();
 865        }
 866        let _project = project.clone();
 867        let project = project.read(cx);
 868
 869        let node = project.node_runtime().cloned();
 870        if let Some(node) = node {
 871            let next_id = project.languages().next_language_server_id();
 872            let fs = project.fs().clone();
 873
 874            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 875            state.copilot = Some(copilot.clone());
 876            Some(copilot)
 877        } else {
 878            None
 879        }
 880    }
 881
 882    pub fn context_for_project_with_buffers<'a>(
 883        &'a self,
 884        project: &Entity<Project>,
 885        cx: &'a mut App,
 886    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 887        self.projects
 888            .get(&project.entity_id())
 889            .map(|project| {
 890                project.context.update(cx, |context, cx| {
 891                    context.related_files_with_buffers(cx).collect()
 892                })
 893            })
 894            .unwrap_or_default()
 895    }
 896
 897    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 898        if matches!(
 899            self.edit_prediction_model,
 900            EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
 901        ) {
 902            self.user_store.read(cx).edit_prediction_usage()
 903        } else {
 904            None
 905        }
 906    }
 907
 908    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 909        self.get_or_init_project(project, cx);
 910    }
 911
 912    pub fn register_buffer(
 913        &mut self,
 914        buffer: &Entity<Buffer>,
 915        project: &Entity<Project>,
 916        cx: &mut Context<Self>,
 917    ) {
 918        let project_state = self.get_or_init_project(project, cx);
 919        Self::register_buffer_impl(project_state, buffer, project, cx);
 920    }
 921
 922    fn get_or_init_project(
 923        &mut self,
 924        project: &Entity<Project>,
 925        cx: &mut Context<Self>,
 926    ) -> &mut ProjectState {
 927        let entity_id = project.entity_id();
 928        self.projects
 929            .entry(entity_id)
 930            .or_insert_with(|| ProjectState {
 931                context: {
 932                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 933                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 934                        this.handle_excerpt_store_event(entity_id, event);
 935                    })
 936                    .detach();
 937                    related_excerpt_store
 938                },
 939                events: VecDeque::new(),
 940                last_event: None,
 941                recent_paths: VecDeque::new(),
 942                debug_tx: None,
 943                registered_buffers: HashMap::default(),
 944                current_prediction: None,
 945                cancelled_predictions: HashSet::default(),
 946                pending_predictions: ArrayVec::new(),
 947                next_pending_prediction_id: 0,
 948                last_edit_prediction_refresh: None,
 949                last_jump_prediction_refresh: None,
 950                license_detection_watchers: HashMap::default(),
 951                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 952                _subscriptions: [
 953                    cx.subscribe(&project, Self::handle_project_event),
 954                    cx.observe_release(&project, move |this, _, cx| {
 955                        this.projects.remove(&entity_id);
 956                        cx.notify();
 957                    }),
 958                ],
 959                copilot: None,
 960            })
 961    }
 962
 963    pub fn remove_project(&mut self, project: &Entity<Project>) {
 964        self.projects.remove(&project.entity_id());
 965    }
 966
 967    fn handle_excerpt_store_event(
 968        &mut self,
 969        project_entity_id: EntityId,
 970        event: &RelatedExcerptStoreEvent,
 971    ) {
 972        if let Some(project_state) = self.projects.get(&project_entity_id) {
 973            if let Some(debug_tx) = project_state.debug_tx.clone() {
 974                match event {
 975                    RelatedExcerptStoreEvent::StartedRefresh => {
 976                        debug_tx
 977                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 978                                ContextRetrievalStartedDebugEvent {
 979                                    project_entity_id: project_entity_id,
 980                                    timestamp: Instant::now(),
 981                                    search_prompt: String::new(),
 982                                },
 983                            ))
 984                            .ok();
 985                    }
 986                    RelatedExcerptStoreEvent::FinishedRefresh {
 987                        cache_hit_count,
 988                        cache_miss_count,
 989                        mean_definition_latency,
 990                        max_definition_latency,
 991                    } => {
 992                        debug_tx
 993                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 994                                ContextRetrievalFinishedDebugEvent {
 995                                    project_entity_id: project_entity_id,
 996                                    timestamp: Instant::now(),
 997                                    metadata: vec![
 998                                        (
 999                                            "Cache Hits",
1000                                            format!(
1001                                                "{}/{}",
1002                                                cache_hit_count,
1003                                                cache_hit_count + cache_miss_count
1004                                            )
1005                                            .into(),
1006                                        ),
1007                                        (
1008                                            "Max LSP Time",
1009                                            format!("{} ms", max_definition_latency.as_millis())
1010                                                .into(),
1011                                        ),
1012                                        (
1013                                            "Mean LSP Time",
1014                                            format!("{} ms", mean_definition_latency.as_millis())
1015                                                .into(),
1016                                        ),
1017                                    ],
1018                                },
1019                            ))
1020                            .ok();
1021                    }
1022                }
1023            }
1024        }
1025    }
1026
1027    pub fn debug_info(
1028        &mut self,
1029        project: &Entity<Project>,
1030        cx: &mut Context<Self>,
1031    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1032        let project_state = self.get_or_init_project(project, cx);
1033        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1034        project_state.debug_tx = Some(debug_watch_tx);
1035        debug_watch_rx
1036    }
1037
1038    fn handle_project_event(
1039        &mut self,
1040        project: Entity<Project>,
1041        event: &project::Event,
1042        cx: &mut Context<Self>,
1043    ) {
1044        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1045            return;
1046        }
1047        // TODO [zeta2] init with recent paths
1048        match event {
1049            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1050                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1051                    return;
1052                };
1053                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1054                if let Some(path) = path {
1055                    if let Some(ix) = project_state
1056                        .recent_paths
1057                        .iter()
1058                        .position(|probe| probe == &path)
1059                    {
1060                        project_state.recent_paths.remove(ix);
1061                    }
1062                    project_state.recent_paths.push_front(path);
1063                }
1064            }
1065            project::Event::DiagnosticsUpdated { .. } => {
1066                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1067                    self.refresh_prediction_from_diagnostics(
1068                        project,
1069                        DiagnosticSearchScope::Global,
1070                        cx,
1071                    );
1072                }
1073            }
1074            _ => (),
1075        }
1076    }
1077
    /// Returns the `RegisteredBuffer` entry for `buffer` inside
    /// `project_state`, creating it if it does not exist yet.
    ///
    /// Before looking up the entry, this makes sure a
    /// `LicenseDetectionWatcher` exists for the worktree that contains the
    /// buffer's file (when the buffer has a file and the project knows its
    /// worktree).
    ///
    /// A newly created entry captures the buffer's current text snapshot and
    /// file, and holds two subscriptions: one that reports changes on
    /// `BufferEvent::Edited`, and one that removes the entry when the buffer
    /// is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                // Create at most one watcher per worktree.
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher once its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly; edits are only reported
                            // while the project can still be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1144
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// The method:
    /// 1. updates the registered snapshot and file of the buffer (returning
    ///    early when the buffer version did not change or no edits are found),
    /// 2. refreshes `last_edit_at` of pending settled predictions whose
    ///    editable range overlaps the edited range,
    /// 3. records a `UserActionRecord` classified from the edit sizes,
    /// 4. either coalesces the change into the project's `last_event` or
    ///    finalizes that event into `events` and starts a new `last_event`.
    ///
    /// `is_predicted` is stored on the event as `predicted`; a change whose
    /// value differs from the previous event's is never coalesced into it.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing to report if the buffer is still at the registered version.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Accumulate edit statistics and one anchor range that runs from the
        // start of the first edit to the end of the last one.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits were yielded for this version change.
        let Some(edit_range) = edit_range else {
            return;
        };

        // An edit inside a pending prediction's editable range restarts its
        // quiescence timer.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Classify the action: a total equal to the number of edits means one
        // unit per edit and is treated as a single-character action; any
        // other size counts as a selection action.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Falls back to 0 if the system clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            // The change directly follows the last event when it starts from
            // the exact snapshot that event ended on.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If editing paused for at least LAST_CHANGE_GROUPING_TIME,
                // remember the snapshot as it was before this change.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the previous event into the history,
        // evicting the oldest entry when the history is about to reach
        // EVENT_COUNT_MAX.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1278
1279    fn prediction_at(
1280        &mut self,
1281        buffer: &Entity<Buffer>,
1282        position: Option<language::Anchor>,
1283        project: &Entity<Project>,
1284        cx: &App,
1285    ) -> Option<BufferEditPrediction<'_>> {
1286        let project_state = self.projects.get_mut(&project.entity_id())?;
1287        if let Some(position) = position
1288            && let Some(buffer) = project_state
1289                .registered_buffers
1290                .get_mut(&buffer.entity_id())
1291        {
1292            buffer.last_position = Some(position);
1293        }
1294
1295        let CurrentEditPrediction {
1296            requested_by,
1297            prediction,
1298            ..
1299        } = project_state.current_prediction.as_ref()?;
1300
1301        if prediction.targets_buffer(buffer.read(cx)) {
1302            Some(BufferEditPrediction::Local { prediction })
1303        } else {
1304            let show_jump = match requested_by {
1305                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1306                    requested_by_buffer_id == &buffer.entity_id()
1307                }
1308                PredictionRequestedBy::DiagnosticsUpdate => true,
1309            };
1310
1311            if show_jump {
1312                Some(BufferEditPrediction::Jump { prediction })
1313            } else {
1314                None
1315            }
1316        }
1317    }
1318
1319    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1320        let Some(current_prediction) = self
1321            .projects
1322            .get_mut(&project.entity_id())
1323            .and_then(|project_state| project_state.current_prediction.take())
1324        else {
1325            return;
1326        };
1327
1328        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1329
1330        // can't hold &mut project_state ref across report_changes_for_buffer_call
1331        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1332            return;
1333        };
1334
1335        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1336            project_state.cancel_pending_prediction(pending_prediction, cx);
1337        }
1338
1339        match self.edit_prediction_model {
1340            EditPredictionModel::Sweep => {
1341                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1342            }
1343            EditPredictionModel::Mercury => {
1344                mercury::edit_prediction_accepted(
1345                    current_prediction.prediction.id,
1346                    self.client.http_client(),
1347                    cx,
1348                );
1349            }
1350            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1351                let is_cloud = !matches!(
1352                    all_language_settings(None, cx).edit_predictions.provider,
1353                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1354                );
1355                if is_cloud {
1356                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1357                }
1358            }
1359            EditPredictionModel::Fim { .. } => {}
1360        }
1361    }
1362
1363    async fn handle_rejected_predictions(
1364        rx: UnboundedReceiver<EditPredictionRejection>,
1365        client: Arc<Client>,
1366        llm_token: LlmApiToken,
1367        app_version: Version,
1368        background_executor: BackgroundExecutor,
1369    ) {
1370        let mut rx = std::pin::pin!(rx.peekable());
1371        let mut batched = Vec::new();
1372
1373        while let Some(rejection) = rx.next().await {
1374            batched.push(rejection);
1375
1376            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1377                select_biased! {
1378                    next = rx.as_mut().peek().fuse() => {
1379                        if next.is_some() {
1380                            continue;
1381                        }
1382                    }
1383                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1384                }
1385            }
1386
1387            let url = client
1388                .http_client()
1389                .build_zed_llm_url("/predict_edits/reject", &[])
1390                .unwrap();
1391
1392            let flush_count = batched
1393                .len()
1394                // in case items have accumulated after failure
1395                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1396            let start = batched.len() - flush_count;
1397
1398            let body = RejectEditPredictionsBodyRef {
1399                rejections: &batched[start..],
1400            };
1401
1402            let result = Self::send_api_request::<()>(
1403                |builder| {
1404                    let req = builder
1405                        .uri(url.as_ref())
1406                        .body(serde_json::to_string(&body)?.into());
1407                    anyhow::Ok(req?)
1408                },
1409                client.clone(),
1410                llm_token.clone(),
1411                app_version.clone(),
1412                true,
1413            )
1414            .await;
1415
1416            if result.log_err().is_some() {
1417                batched.drain(start..);
1418            }
1419        }
1420    }
1421
    /// Background loop that reports pending predictions once the text in
    /// their editable range has stopped changing.
    ///
    /// The loop sleeps until the scheduled wake time, or, when nothing is
    /// scheduled, waits for an enqueue timestamp on `rx`. After waking it
    /// visits every pending prediction of every registered buffer:
    /// - predictions at least `EDIT_PREDICTION_SETTLED_TTL` old are dropped
    ///   without being reported,
    /// - predictions not edited for at least
    ///   `EDIT_PREDICTION_SETTLED_QUIESCENCE` are reported through the
    ///   `EDIT_PREDICTION_SETTLED_EVENT` telemetry event and dropped,
    /// - all others are kept, and the next wake time is derived from the
    ///   oldest `last_edit_at` among them.
    ///
    /// The loop ends when `rx` is closed while waiting on it or when `this`
    /// can no longer be upgraded.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: block until a prediction is enqueued.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Drain the timestamps already queued; the wake time above was
                // computed from the first one received.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            // Re-read the clock: time has passed while the timer was awaited.
            let now = cx.background_executor().now();

            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    // Settled: capture the current text of the
                                    // editable range from the registered snapshot.
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                // Still being edited: keep it and track the
                                // earliest last-edit time for the next wake-up.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // `None` here sends the loop back to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1505
1506    pub(crate) fn enqueue_settled_prediction(
1507        &mut self,
1508        request_id: EditPredictionId,
1509        project: &Entity<Project>,
1510        edited_buffer: &Entity<Buffer>,
1511        edited_buffer_snapshot: &BufferSnapshot,
1512        editable_offset_range: Range<usize>,
1513        cx: &mut Context<Self>,
1514    ) {
1515        let project_state = self.get_or_init_project(project, cx);
1516        if let Some(buffer) = project_state
1517            .registered_buffers
1518            .get_mut(&edited_buffer.entity_id())
1519        {
1520            let now = cx.background_executor().now();
1521            buffer.pending_predictions.push(PendingSettledPrediction {
1522                request_id,
1523                editable_anchor_range: edited_buffer_snapshot
1524                    .anchor_range_around(editable_offset_range),
1525                enqueued_at: now,
1526                last_edit_at: now,
1527            });
1528            self.settled_predictions_tx.unbounded_send(now).ok();
1529        }
1530    }
1531
1532    fn reject_current_prediction(
1533        &mut self,
1534        reason: EditPredictionRejectReason,
1535        project: &Entity<Project>,
1536        cx: &App,
1537    ) {
1538        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1539            project_state.pending_predictions.clear();
1540            if let Some(prediction) = project_state.current_prediction.take() {
1541                let model_version = prediction.prediction.model_version.clone();
1542                self.reject_prediction(
1543                    prediction.prediction.id,
1544                    reason,
1545                    prediction.was_shown,
1546                    model_version,
1547                    cx,
1548                );
1549            }
1550        };
1551    }
1552
1553    fn did_show_current_prediction(
1554        &mut self,
1555        project: &Entity<Project>,
1556        display_type: edit_prediction_types::SuggestionDisplayType,
1557        cx: &mut Context<Self>,
1558    ) {
1559        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1560            return;
1561        };
1562
1563        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1564            return;
1565        };
1566
1567        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1568        let previous_shown_with = current_prediction.shown_with;
1569
1570        if previous_shown_with.is_none() || !is_jump {
1571            current_prediction.shown_with = Some(display_type);
1572        }
1573
1574        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1575
1576        if is_first_non_jump_show {
1577            current_prediction.was_shown = true;
1578        }
1579
1580        let display_type_changed = previous_shown_with != Some(display_type);
1581
1582        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1583            sweep_ai::edit_prediction_shown(
1584                &self.sweep_ai,
1585                self.client.clone(),
1586                &current_prediction.prediction,
1587                display_type,
1588                cx,
1589            );
1590        }
1591
1592        if is_first_non_jump_show {
1593            self.shown_predictions
1594                .push_front(current_prediction.prediction.clone());
1595            if self.shown_predictions.len() > 50 {
1596                let completion = self.shown_predictions.pop_back().unwrap();
1597                self.rated_predictions.remove(&completion.id);
1598            }
1599        }
1600    }
1601
1602    fn reject_prediction(
1603        &mut self,
1604        prediction_id: EditPredictionId,
1605        reason: EditPredictionRejectReason,
1606        was_shown: bool,
1607        model_version: Option<String>,
1608        cx: &App,
1609    ) {
1610        match self.edit_prediction_model {
1611            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1612                let is_cloud = !matches!(
1613                    all_language_settings(None, cx).edit_predictions.provider,
1614                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1615                );
1616                if is_cloud {
1617                    self.reject_predictions_tx
1618                        .unbounded_send(EditPredictionRejection {
1619                            request_id: prediction_id.to_string(),
1620                            reason,
1621                            was_shown,
1622                            model_version,
1623                        })
1624                        .log_err();
1625                }
1626            }
1627            EditPredictionModel::Mercury => {
1628                mercury::edit_prediction_rejected(
1629                    prediction_id,
1630                    was_shown,
1631                    reason,
1632                    self.client.http_client(),
1633                    cx,
1634                );
1635            }
1636            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1637        }
1638    }
1639
1640    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1641        self.projects
1642            .get(&project.entity_id())
1643            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1644    }
1645
1646    pub fn refresh_prediction_from_buffer(
1647        &mut self,
1648        project: Entity<Project>,
1649        buffer: Entity<Buffer>,
1650        position: language::Anchor,
1651        cx: &mut Context<Self>,
1652    ) {
1653        self.queue_prediction_refresh(
1654            project.clone(),
1655            PredictEditsRequestTrigger::Other,
1656            buffer.entity_id(),
1657            cx,
1658            move |this, cx| {
1659                let Some(request_task) = this
1660                    .update(cx, |this, cx| {
1661                        this.request_prediction(
1662                            &project,
1663                            &buffer,
1664                            position,
1665                            PredictEditsRequestTrigger::Other,
1666                            cx,
1667                        )
1668                    })
1669                    .log_err()
1670                else {
1671                    return Task::ready(anyhow::Ok(None));
1672                };
1673
1674                cx.spawn(async move |_cx| {
1675                    request_task.await.map(|prediction_result| {
1676                        prediction_result.map(|prediction_result| {
1677                            (
1678                                prediction_result,
1679                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1680                            )
1681                        })
1682                    })
1683                })
1684            },
1685        )
1686    }
1687
1688    pub fn refresh_prediction_from_diagnostics(
1689        &mut self,
1690        project: Entity<Project>,
1691        scope: DiagnosticSearchScope,
1692        cx: &mut Context<Self>,
1693    ) {
1694        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1695            return;
1696        }
1697
1698        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1699            return;
1700        };
1701
1702        // Prefer predictions from buffer
1703        if project_state.current_prediction.is_some() {
1704            return;
1705        }
1706
1707        self.queue_prediction_refresh(
1708            project.clone(),
1709            PredictEditsRequestTrigger::Diagnostics,
1710            project.entity_id(),
1711            cx,
1712            move |this, cx| {
1713                let Some((active_buffer, snapshot, cursor_point)) = this
1714                    .read_with(cx, |this, cx| {
1715                        let project_state = this.projects.get(&project.entity_id())?;
1716                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1717                        let snapshot = buffer.read(cx).snapshot();
1718
1719                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1720                            return None;
1721                        }
1722
1723                        let cursor_point = position
1724                            .map(|pos| pos.to_point(&snapshot))
1725                            .unwrap_or_default();
1726
1727                        Some((buffer, snapshot, cursor_point))
1728                    })
1729                    .log_err()
1730                    .flatten()
1731                else {
1732                    return Task::ready(anyhow::Ok(None));
1733                };
1734
1735                cx.spawn(async move |cx| {
1736                    let diagnostic_search_range = match scope {
1737                        DiagnosticSearchScope::Local => {
1738                            let diagnostic_search_start =
1739                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1740                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1741                            Point::new(diagnostic_search_start, 0)
1742                                ..Point::new(diagnostic_search_end, 0)
1743                        }
1744                        DiagnosticSearchScope::Global => Default::default(),
1745                    };
1746
1747                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1748                        active_buffer,
1749                        &snapshot,
1750                        diagnostic_search_range,
1751                        cursor_point,
1752                        &project,
1753                        cx,
1754                    )
1755                    .await?
1756                    else {
1757                        return anyhow::Ok(None);
1758                    };
1759
1760                    let Some(prediction_result) = this
1761                        .update(cx, |this, cx| {
1762                            this.request_prediction(
1763                                &project,
1764                                &jump_buffer,
1765                                jump_position,
1766                                PredictEditsRequestTrigger::Diagnostics,
1767                                cx,
1768                            )
1769                        })?
1770                        .await?
1771                    else {
1772                        return anyhow::Ok(None);
1773                    };
1774
1775                    this.update(cx, |this, cx| {
1776                        Some((
1777                            if this
1778                                .get_or_init_project(&project, cx)
1779                                .current_prediction
1780                                .is_none()
1781                            {
1782                                prediction_result
1783                            } else {
1784                                EditPredictionResult {
1785                                    id: prediction_result.id,
1786                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1787                                }
1788                            },
1789                            PredictionRequestedBy::DiagnosticsUpdate,
1790                        ))
1791                    })
1792                })
1793            },
1794        );
1795    }
1796
1797    fn predictions_enabled_at(
1798        snapshot: &BufferSnapshot,
1799        position: Option<language::Anchor>,
1800        cx: &App,
1801    ) -> bool {
1802        let file = snapshot.file();
1803        let all_settings = all_language_settings(file, cx);
1804        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1805            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1806        {
1807            return false;
1808        }
1809
1810        if let Some(last_position) = position {
1811            let settings = snapshot.settings_at(last_position, cx);
1812
1813            if !settings.edit_predictions_disabled_in.is_empty()
1814                && let Some(scope) = snapshot.language_scope_at(last_position)
1815                && let Some(scope_name) = scope.override_name()
1816                && settings
1817                    .edit_predictions_disabled_in
1818                    .iter()
1819                    .any(|s| s == scope_name)
1820            {
1821                return false;
1822            }
1823        }
1824
1825        true
1826    }
1827
1828    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1829}
1830
1831fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1832    match provider {
1833        EditPredictionProvider::Zed
1834        | EditPredictionProvider::Sweep
1835        | EditPredictionProvider::Mercury
1836        | EditPredictionProvider::Ollama
1837        | EditPredictionProvider::OpenAiCompatibleApi
1838        | EditPredictionProvider::Experimental(_) => true,
1839        EditPredictionProvider::None
1840        | EditPredictionProvider::Copilot
1841        | EditPredictionProvider::Supermaven
1842        | EditPredictionProvider::Codestral => false,
1843    }
1844}
1845
1846impl EditPredictionStore {
1847    fn queue_prediction_refresh(
1848        &mut self,
1849        project: Entity<Project>,
1850        request_trigger: PredictEditsRequestTrigger,
1851        throttle_entity: EntityId,
1852        cx: &mut Context<Self>,
1853        do_refresh: impl FnOnce(
1854            WeakEntity<Self>,
1855            &mut AsyncApp,
1856        )
1857            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1858        + 'static,
1859    ) {
1860        fn select_throttle(
1861            project_state: &mut ProjectState,
1862            request_trigger: PredictEditsRequestTrigger,
1863        ) -> &mut Option<(EntityId, Instant)> {
1864            match request_trigger {
1865                PredictEditsRequestTrigger::Diagnostics => {
1866                    &mut project_state.last_jump_prediction_refresh
1867                }
1868                _ => &mut project_state.last_edit_prediction_refresh,
1869            }
1870        }
1871
1872        let (needs_acceptance_tracking, max_pending_predictions) =
1873            match all_language_settings(None, cx).edit_predictions.provider {
1874                EditPredictionProvider::Zed
1875                | EditPredictionProvider::Sweep
1876                | EditPredictionProvider::Mercury
1877                | EditPredictionProvider::Experimental(_) => (true, 2),
1878                EditPredictionProvider::Ollama => (false, 1),
1879                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1880                EditPredictionProvider::None
1881                | EditPredictionProvider::Copilot
1882                | EditPredictionProvider::Supermaven
1883                | EditPredictionProvider::Codestral => {
1884                    log::error!("queue_prediction_refresh called with non-store provider");
1885                    return;
1886                }
1887            };
1888
1889        let drop_on_cancel = !needs_acceptance_tracking;
1890        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1891        let project_state = self.get_or_init_project(&project, cx);
1892        let pending_prediction_id = project_state.next_pending_prediction_id;
1893        project_state.next_pending_prediction_id += 1;
1894        let last_request = *select_throttle(project_state, request_trigger);
1895
1896        let task = cx.spawn(async move |this, cx| {
1897            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1898                if throttle_entity != last_entity {
1899                    return None;
1900                }
1901                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1902            }) {
1903                cx.background_executor().timer(timeout).await;
1904            }
1905
1906            // If this task was cancelled before the throttle timeout expired,
1907            // do not perform a request.
1908            let mut is_cancelled = true;
1909            this.update(cx, |this, cx| {
1910                let project_state = this.get_or_init_project(&project, cx);
1911                let was_cancelled = project_state
1912                    .cancelled_predictions
1913                    .remove(&pending_prediction_id);
1914                if !was_cancelled {
1915                    let new_refresh = (throttle_entity, Instant::now());
1916                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
1917                    is_cancelled = false;
1918                }
1919            })
1920            .ok();
1921            if is_cancelled {
1922                return None;
1923            }
1924
1925            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1926            let new_prediction_id = new_prediction_result
1927                .as_ref()
1928                .map(|(prediction, _)| prediction.id.clone());
1929
1930            // When a prediction completes, remove it from the pending list, and cancel
1931            // any pending predictions that were enqueued before it.
1932            this.update(cx, |this, cx| {
1933                let project_state = this.get_or_init_project(&project, cx);
1934
1935                let is_cancelled = project_state
1936                    .cancelled_predictions
1937                    .remove(&pending_prediction_id);
1938
1939                let new_current_prediction = if !is_cancelled
1940                    && let Some((prediction_result, requested_by)) = new_prediction_result
1941                {
1942                    match prediction_result.prediction {
1943                        Ok(prediction) => {
1944                            let new_prediction = CurrentEditPrediction {
1945                                requested_by,
1946                                prediction,
1947                                was_shown: false,
1948                                shown_with: None,
1949                            };
1950
1951                            if let Some(current_prediction) =
1952                                project_state.current_prediction.as_ref()
1953                            {
1954                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1955                                {
1956                                    this.reject_current_prediction(
1957                                        EditPredictionRejectReason::Replaced,
1958                                        &project,
1959                                        cx,
1960                                    );
1961
1962                                    Some(new_prediction)
1963                                } else {
1964                                    this.reject_prediction(
1965                                        new_prediction.prediction.id,
1966                                        EditPredictionRejectReason::CurrentPreferred,
1967                                        false,
1968                                        new_prediction.prediction.model_version,
1969                                        cx,
1970                                    );
1971                                    None
1972                                }
1973                            } else {
1974                                Some(new_prediction)
1975                            }
1976                        }
1977                        Err(reject_reason) => {
1978                            this.reject_prediction(
1979                                prediction_result.id,
1980                                reject_reason,
1981                                false,
1982                                None,
1983                                cx,
1984                            );
1985                            None
1986                        }
1987                    }
1988                } else {
1989                    None
1990                };
1991
1992                let project_state = this.get_or_init_project(&project, cx);
1993
1994                if let Some(new_prediction) = new_current_prediction {
1995                    project_state.current_prediction = Some(new_prediction);
1996                }
1997
1998                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1999                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2000                    if pending_prediction.id == pending_prediction_id {
2001                        pending_predictions.remove(ix);
2002                        for pending_prediction in pending_predictions.drain(0..ix) {
2003                            project_state.cancel_pending_prediction(pending_prediction, cx)
2004                        }
2005                        break;
2006                    }
2007                }
2008                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2009                cx.notify();
2010            })
2011            .ok();
2012
2013            new_prediction_id
2014        });
2015
2016        if project_state.pending_predictions.len() < max_pending_predictions {
2017            project_state.pending_predictions.push(PendingPrediction {
2018                id: pending_prediction_id,
2019                task,
2020                drop_on_cancel,
2021            });
2022        } else {
2023            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2024            project_state.pending_predictions.push(PendingPrediction {
2025                id: pending_prediction_id,
2026                task,
2027                drop_on_cancel,
2028            });
2029            project_state.cancel_pending_prediction(pending_prediction, cx);
2030        }
2031    }
2032
2033    pub fn request_prediction(
2034        &mut self,
2035        project: &Entity<Project>,
2036        active_buffer: &Entity<Buffer>,
2037        position: language::Anchor,
2038        trigger: PredictEditsRequestTrigger,
2039        cx: &mut Context<Self>,
2040    ) -> Task<Result<Option<EditPredictionResult>>> {
2041        self.request_prediction_internal(
2042            project.clone(),
2043            active_buffer.clone(),
2044            position,
2045            trigger,
2046            cx.has_flag::<Zeta2FeatureFlag>(),
2047            cx,
2048        )
2049    }
2050
2051    fn request_prediction_internal(
2052        &mut self,
2053        project: Entity<Project>,
2054        active_buffer: Entity<Buffer>,
2055        position: language::Anchor,
2056        trigger: PredictEditsRequestTrigger,
2057        allow_jump: bool,
2058        cx: &mut Context<Self>,
2059    ) -> Task<Result<Option<EditPredictionResult>>> {
2060        self.get_or_init_project(&project, cx);
2061        let project_state = self.projects.get(&project.entity_id()).unwrap();
2062        let stored_events = project_state.events(cx);
2063        let has_events = !stored_events.is_empty();
2064        let events: Vec<Arc<zeta_prompt::Event>> =
2065            stored_events.iter().map(|e| e.event.clone()).collect();
2066        let debug_tx = project_state.debug_tx.clone();
2067
2068        let snapshot = active_buffer.read(cx).snapshot();
2069        let cursor_point = position.to_point(&snapshot);
2070        let current_offset = position.to_offset(&snapshot);
2071
2072        let mut user_actions: Vec<UserActionRecord> =
2073            project_state.user_actions.iter().cloned().collect();
2074
2075        if let Some(last_action) = user_actions.last() {
2076            if last_action.buffer_id == active_buffer.entity_id()
2077                && current_offset != last_action.offset
2078            {
2079                let timestamp_epoch_ms = SystemTime::now()
2080                    .duration_since(UNIX_EPOCH)
2081                    .map(|d| d.as_millis() as u64)
2082                    .unwrap_or(0);
2083                user_actions.push(UserActionRecord {
2084                    action_type: UserActionType::CursorMovement,
2085                    buffer_id: active_buffer.entity_id(),
2086                    line_number: cursor_point.row,
2087                    offset: current_offset,
2088                    timestamp_epoch_ms,
2089                });
2090            }
2091        }
2092        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2093        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2094        let diagnostic_search_range =
2095            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2096
2097        let related_files = self.context_for_project(&project, cx);
2098
2099        let is_open_source = snapshot
2100            .file()
2101            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2102            && events.iter().all(|event| event.in_open_source_repo())
2103            && related_files.iter().all(|file| file.in_open_source_repo);
2104
2105        let can_collect_data = !cfg!(test)
2106            && is_open_source
2107            && self.is_data_collection_enabled(cx)
2108            && matches!(
2109                self.edit_prediction_model,
2110                EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2
2111            );
2112
2113        let inputs = EditPredictionModelInput {
2114            project: project.clone(),
2115            buffer: active_buffer.clone(),
2116            snapshot: snapshot,
2117            position,
2118            events,
2119            related_files,
2120            recent_paths: project_state.recent_paths.clone(),
2121            trigger,
2122            diagnostic_search_range: diagnostic_search_range,
2123            debug_tx,
2124            user_actions,
2125            can_collect_data,
2126            is_open_source,
2127        };
2128
2129        if can_collect_data && rand::random_ratio(1, 1000) {
2130            if let Some(task) = capture_example(
2131                project.clone(),
2132                active_buffer,
2133                position,
2134                stored_events,
2135                false,
2136                cx,
2137            ) {
2138                task.detach();
2139            }
2140        }
2141
2142        let task = match self.edit_prediction_model {
2143            EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
2144                self,
2145                inputs,
2146                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
2147                cx,
2148            ),
2149            EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
2150                self,
2151                inputs,
2152                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
2153                cx,
2154            ),
2155            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2156            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2157            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2158        };
2159
2160        cx.spawn(async move |this, cx| {
2161            let prediction = task.await?;
2162
2163            if prediction.is_none() && allow_jump && has_events {
2164                this.update(cx, |this, cx| {
2165                    this.refresh_prediction_from_diagnostics(
2166                        project,
2167                        DiagnosticSearchScope::Local,
2168                        cx,
2169                    );
2170                })?;
2171                return anyhow::Ok(None);
2172            }
2173
2174            Ok(prediction)
2175        })
2176    }
2177
2178    pub(crate) async fn next_diagnostic_location(
2179        active_buffer: Entity<Buffer>,
2180        active_buffer_snapshot: &BufferSnapshot,
2181        active_buffer_diagnostic_search_range: Range<Point>,
2182        active_buffer_cursor_point: Point,
2183        project: &Entity<Project>,
2184        cx: &mut AsyncApp,
2185    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2186        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2187            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2188            .flat_map(|(_, _, _, selections)| {
2189                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2190            })
2191            .collect();
2192
2193        let mut jump_location = active_buffer_snapshot
2194            .diagnostic_groups(None)
2195            .into_iter()
2196            .filter_map(|(_, group)| {
2197                let range = &group.entries[group.primary_ix]
2198                    .range
2199                    .to_point(&active_buffer_snapshot);
2200                if range.overlaps(&active_buffer_diagnostic_search_range) {
2201                    return None;
2202                }
2203                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2204                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2205                });
2206                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2207                    <= DIAGNOSTIC_LINES_RANGE;
2208                if near_collaborator && !near_local {
2209                    return None;
2210                }
2211                Some(range.start)
2212            })
2213            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2214            .map(|position| {
2215                (
2216                    active_buffer.clone(),
2217                    active_buffer_snapshot.anchor_before(position),
2218                )
2219            });
2220
2221        if jump_location.is_none() {
2222            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2223                let file = buffer.file()?;
2224
2225                Some(ProjectPath {
2226                    worktree_id: file.worktree_id(cx),
2227                    path: file.path().clone(),
2228                })
2229            });
2230
2231            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2232                project
2233                    .diagnostic_summaries(false, cx)
2234                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2235                    .map(|(path, _, _)| {
2236                        let shared_prefix = path
2237                            .path
2238                            .components()
2239                            .zip(
2240                                active_buffer_path
2241                                    .as_ref()
2242                                    .map(|p| p.path.components())
2243                                    .unwrap_or_default(),
2244                            )
2245                            .take_while(|(a, b)| a == b)
2246                            .count();
2247                        (path, shared_prefix)
2248                    })
2249                    .collect()
2250            });
2251
2252            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2253
2254            for (path, _) in candidates {
2255                let candidate_buffer = project
2256                    .update(cx, |project, cx| project.open_buffer(path, cx))
2257                    .await?;
2258
2259                let (has_collaborators, diagnostic_position) =
2260                    candidate_buffer.read_with(cx, |buffer, _cx| {
2261                        let snapshot = buffer.snapshot();
2262                        let has_collaborators = snapshot
2263                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2264                            .next()
2265                            .is_some();
2266                        let position = buffer
2267                            .buffer_diagnostics(None)
2268                            .into_iter()
2269                            .min_by_key(|entry| entry.diagnostic.severity)
2270                            .map(|entry| entry.range.start);
2271                        (has_collaborators, position)
2272                    });
2273
2274                if has_collaborators {
2275                    continue;
2276                }
2277
2278                if let Some(position) = diagnostic_position {
2279                    jump_location = Some((candidate_buffer, position));
2280                    break;
2281                }
2282            }
2283        }
2284
2285        anyhow::Ok(jump_location)
2286    }
2287
2288    async fn send_raw_llm_request(
2289        request: RawCompletionRequest,
2290        client: Arc<Client>,
2291        custom_url: Option<Arc<Url>>,
2292        llm_token: LlmApiToken,
2293        app_version: Version,
2294    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2295        let url = if let Some(custom_url) = custom_url {
2296            custom_url.as_ref().clone()
2297        } else {
2298            client
2299                .http_client()
2300                .build_zed_llm_url("/predict_edits/raw", &[])?
2301        };
2302
2303        Self::send_api_request(
2304            |builder| {
2305                let req = builder
2306                    .uri(url.as_ref())
2307                    .body(serde_json::to_string(&request)?.into());
2308                Ok(req?)
2309            },
2310            client,
2311            llm_token,
2312            app_version,
2313            true,
2314        )
2315        .await
2316    }
2317
2318    pub(crate) async fn send_v3_request(
2319        input: ZetaPromptInput,
2320        client: Arc<Client>,
2321        llm_token: LlmApiToken,
2322        app_version: Version,
2323        trigger: PredictEditsRequestTrigger,
2324    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2325        let url = client
2326            .http_client()
2327            .build_zed_llm_url("/predict_edits/v3", &[])?;
2328
2329        let request = PredictEditsV3Request { input, trigger };
2330
2331        let json_bytes = serde_json::to_vec(&request)?;
2332        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2333
2334        Self::send_api_request(
2335            |builder| {
2336                let req = builder
2337                    .uri(url.as_ref())
2338                    .header("Content-Encoding", "zstd")
2339                    .body(compressed.clone().into());
2340                Ok(req?)
2341            },
2342            client,
2343            llm_token,
2344            app_version,
2345            true,
2346        )
2347        .await
2348    }
2349
2350    fn handle_api_response<T>(
2351        this: &WeakEntity<Self>,
2352        response: Result<(T, Option<EditPredictionUsage>)>,
2353        cx: &mut gpui::AsyncApp,
2354    ) -> Result<T> {
2355        match response {
2356            Ok((data, usage)) => {
2357                if let Some(usage) = usage {
2358                    this.update(cx, |this, cx| {
2359                        this.user_store.update(cx, |user_store, cx| {
2360                            user_store.update_edit_prediction_usage(usage, cx);
2361                        });
2362                    })
2363                    .ok();
2364                }
2365                Ok(data)
2366            }
2367            Err(err) => {
2368                if err.is::<ZedUpdateRequiredError>() {
2369                    cx.update(|cx| {
2370                        this.update(cx, |this, _cx| {
2371                            this.update_required = true;
2372                        })
2373                        .ok();
2374
2375                        let error_message: SharedString = err.to_string().into();
2376                        show_app_notification(
2377                            NotificationId::unique::<ZedUpdateRequiredError>(),
2378                            cx,
2379                            move |cx| {
2380                                cx.new(|cx| {
2381                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2382                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2383                                })
2384                            },
2385                        );
2386                    });
2387                }
2388                Err(err)
2389            }
2390        }
2391    }
2392
2393    async fn send_api_request<Res>(
2394        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2395        client: Arc<Client>,
2396        llm_token: LlmApiToken,
2397        app_version: Version,
2398        require_auth: bool,
2399    ) -> Result<(Res, Option<EditPredictionUsage>)>
2400    where
2401        Res: DeserializeOwned,
2402    {
2403        let http_client = client.http_client();
2404
2405        let mut token = if require_auth {
2406            Some(llm_token.acquire(&client).await?)
2407        } else {
2408            llm_token.acquire(&client).await.ok()
2409        };
2410        let mut did_retry = false;
2411
2412        loop {
2413            let request_builder = http_client::Request::builder().method(Method::POST);
2414
2415            let mut request_builder = request_builder
2416                .header("Content-Type", "application/json")
2417                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2418
2419            // Only add Authorization header if we have a token
2420            if let Some(ref token_value) = token {
2421                request_builder =
2422                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2423            }
2424
2425            let request = build(request_builder)?;
2426
2427            let mut response = http_client.send(request).await?;
2428
2429            if let Some(minimum_required_version) = response
2430                .headers()
2431                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2432                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2433            {
2434                anyhow::ensure!(
2435                    app_version >= minimum_required_version,
2436                    ZedUpdateRequiredError {
2437                        minimum_version: minimum_required_version
2438                    }
2439                );
2440            }
2441
2442            if response.status().is_success() {
2443                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2444
2445                let mut body = Vec::new();
2446                response.body_mut().read_to_end(&mut body).await?;
2447                return Ok((serde_json::from_slice(&body)?, usage));
2448            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2449                did_retry = true;
2450                token = Some(llm_token.refresh(&client).await?);
2451            } else {
2452                let mut body = String::new();
2453                response.body_mut().read_to_string(&mut body).await?;
2454                anyhow::bail!(
2455                    "Request failed with status: {:?}\nBody: {}",
2456                    response.status(),
2457                    body
2458                );
2459            }
2460        }
2461    }
2462
2463    pub fn refresh_context(
2464        &mut self,
2465        project: &Entity<Project>,
2466        buffer: &Entity<language::Buffer>,
2467        cursor_position: language::Anchor,
2468        cx: &mut Context<Self>,
2469    ) {
2470        self.get_or_init_project(project, cx)
2471            .context
2472            .update(cx, |store, cx| {
2473                store.refresh(buffer.clone(), cursor_position, cx);
2474            });
2475    }
2476
2477    #[cfg(feature = "cli-support")]
2478    pub fn set_context_for_buffer(
2479        &mut self,
2480        project: &Entity<Project>,
2481        related_files: Vec<RelatedFile>,
2482        cx: &mut Context<Self>,
2483    ) {
2484        self.get_or_init_project(project, cx)
2485            .context
2486            .update(cx, |store, cx| {
2487                store.set_related_files(related_files, cx);
2488            });
2489    }
2490
2491    #[cfg(feature = "cli-support")]
2492    pub fn set_recent_paths_for_project(
2493        &mut self,
2494        project: &Entity<Project>,
2495        paths: impl IntoIterator<Item = project::ProjectPath>,
2496        cx: &mut Context<Self>,
2497    ) {
2498        let project_state = self.get_or_init_project(project, cx);
2499        project_state.recent_paths = paths.into_iter().collect();
2500    }
2501
2502    fn is_file_open_source(
2503        &self,
2504        project: &Entity<Project>,
2505        file: &Arc<dyn File>,
2506        cx: &App,
2507    ) -> bool {
2508        if !file.is_local() || file.is_private() {
2509            return false;
2510        }
2511        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2512            return false;
2513        };
2514        project_state
2515            .license_detection_watchers
2516            .get(&file.worktree_id(cx))
2517            .as_ref()
2518            .is_some_and(|watcher| watcher.is_project_open_source())
2519    }
2520
2521    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2522        self.data_collection_choice.is_enabled(cx)
2523    }
2524
2525    fn load_data_collection_choice() -> DataCollectionChoice {
2526        let choice = KEY_VALUE_STORE
2527            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2528            .log_err()
2529            .flatten();
2530
2531        match choice.as_deref() {
2532            Some("true") => DataCollectionChoice::Enabled,
2533            Some("false") => DataCollectionChoice::Disabled,
2534            Some(_) => {
2535                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2536                DataCollectionChoice::NotAnswered
2537            }
2538            None => DataCollectionChoice::NotAnswered,
2539        }
2540    }
2541
2542    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2543        self.data_collection_choice = self.data_collection_choice.toggle();
2544        let new_choice = self.data_collection_choice;
2545        let is_enabled = new_choice.is_enabled(cx);
2546        db::write_and_log(cx, move || {
2547            KEY_VALUE_STORE.write_kvp(
2548                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2549                is_enabled.to_string(),
2550            )
2551        });
2552    }
2553
2554    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2555        self.shown_predictions.iter()
2556    }
2557
2558    pub fn shown_completions_len(&self) -> usize {
2559        self.shown_predictions.len()
2560    }
2561
2562    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2563        self.rated_predictions.contains(id)
2564    }
2565
2566    pub fn rate_prediction(
2567        &mut self,
2568        prediction: &EditPrediction,
2569        rating: EditPredictionRating,
2570        feedback: String,
2571        cx: &mut Context<Self>,
2572    ) {
2573        let organization = self.user_store.read(cx).current_organization();
2574
2575        self.rated_predictions.insert(prediction.id.clone());
2576
2577        cx.background_spawn({
2578            let client = self.client.clone();
2579            let prediction_id = prediction.id.to_string();
2580            let inputs = serde_json::to_value(&prediction.inputs);
2581            let output = prediction
2582                .edit_preview
2583                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2584            async move {
2585                client
2586                    .cloud_client()
2587                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2588                        organization_id: organization.map(|organization| organization.id.clone()),
2589                        request_id: prediction_id,
2590                        rating: match rating {
2591                            EditPredictionRating::Positive => "positive".to_string(),
2592                            EditPredictionRating::Negative => "negative".to_string(),
2593                        },
2594                        inputs: inputs?,
2595                        output,
2596                        feedback,
2597                    })
2598                    .await?;
2599
2600                anyhow::Ok(())
2601            }
2602        })
2603        .detach_and_log_err(cx);
2604
2605        cx.notify();
2606    }
2607}
2608
2609fn merge_trailing_events_if_needed(
2610    events: &mut VecDeque<StoredEvent>,
2611    end_snapshot: &TextBufferSnapshot,
2612    latest_snapshot: &TextBufferSnapshot,
2613    latest_edit_range: &Range<Anchor>,
2614) {
2615    if let Some(last_event) = events.back() {
2616        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
2617            return;
2618        }
2619    }
2620
2621    let mut next_old_event = None;
2622    let mut mergeable_count = 0;
2623    for old_event in events.iter().rev() {
2624        if let Some(next_old_event) = &next_old_event
2625            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
2626        {
2627            break;
2628        }
2629        mergeable_count += 1;
2630        next_old_event = Some(old_event);
2631    }
2632
2633    if mergeable_count <= 1 {
2634        return;
2635    }
2636
2637    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
2638    let oldest_event = events_to_merge.peek().unwrap();
2639    let oldest_snapshot = oldest_event.old_snapshot.clone();
2640
2641    if let Some((diff, edited_range)) =
2642        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
2643    {
2644        let merged_event = match oldest_event.event.as_ref() {
2645            zeta_prompt::Event::BufferChange {
2646                old_path,
2647                path,
2648                in_open_source_repo,
2649                ..
2650            } => StoredEvent {
2651                event: Arc::new(zeta_prompt::Event::BufferChange {
2652                    old_path: old_path.clone(),
2653                    path: path.clone(),
2654                    diff,
2655                    in_open_source_repo: *in_open_source_repo,
2656                    predicted: events_to_merge.all(|e| {
2657                        matches!(
2658                            e.event.as_ref(),
2659                            zeta_prompt::Event::BufferChange {
2660                                predicted: true,
2661                                ..
2662                            }
2663                        )
2664                    }),
2665                }),
2666                old_snapshot: oldest_snapshot.clone(),
2667                edit_range: end_snapshot.anchor_before(edited_range.start)
2668                    ..end_snapshot.anchor_before(edited_range.end),
2669            },
2670        };
2671        events.truncate(events.len() - mergeable_count);
2672        events.push_back(merged_event);
2673    }
2674}
2675
2676pub(crate) fn filter_redundant_excerpts(
2677    mut related_files: Vec<RelatedFile>,
2678    cursor_path: &Path,
2679    cursor_row_range: Range<u32>,
2680) -> Vec<RelatedFile> {
2681    for file in &mut related_files {
2682        if file.path.as_ref() == cursor_path {
2683            file.excerpts.retain(|excerpt| {
2684                excerpt.row_range.start < cursor_row_range.start
2685                    || excerpt.row_range.end > cursor_row_range.end
2686            });
2687        }
2688    }
2689    related_files.retain(|file| !file.excerpts.is_empty());
2690    related_files
2691}
2692
2693#[derive(Error, Debug)]
2694#[error(
2695    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2696)]
2697pub struct ZedUpdateRequiredError {
2698    minimum_version: Version,
2699}
2700
/// The user's answer to the edit prediction data collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// Data collection was opted into.
    Enabled,
    /// Data collection was opted out of.
    Disabled,
}
2707
2708impl DataCollectionChoice {
2709    pub fn is_enabled(self, cx: &App) -> bool {
2710        if cx.is_staff() {
2711            return true;
2712        }
2713        match self {
2714            Self::Enabled => true,
2715            Self::NotAnswered | Self::Disabled => false,
2716        }
2717    }
2718
2719    #[must_use]
2720    pub fn toggle(&self) -> DataCollectionChoice {
2721        match self {
2722            Self::Enabled => Self::Disabled,
2723            Self::Disabled => Self::Enabled,
2724            Self::NotAnswered => Self::Enabled,
2725        }
2726    }
2727}
2728
2729impl From<bool> for DataCollectionChoice {
2730    fn from(value: bool) -> Self {
2731        match value {
2732            true => DataCollectionChoice::Enabled,
2733            false => DataCollectionChoice::Disabled,
2734        }
2735    }
2736}
2737
2738struct ZedPredictUpsell;
2739
2740impl Dismissable for ZedPredictUpsell {
2741    const KEY: &'static str = "dismissed-edit-predict-upsell";
2742
2743    fn dismissed() -> bool {
2744        // To make this backwards compatible with older versions of Zed, we
2745        // check if the user has seen the previous Edit Prediction Onboarding
2746        // before, by checking the data collection choice which was written to
2747        // the database once the user clicked on "Accept and Enable"
2748        if KEY_VALUE_STORE
2749            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2750            .log_err()
2751            .is_some_and(|s| s.is_some())
2752        {
2753            return true;
2754        }
2755
2756        KEY_VALUE_STORE
2757            .read_kvp(Self::KEY)
2758            .log_err()
2759            .is_some_and(|s| s.is_some())
2760    }
2761}
2762
2763pub fn should_show_upsell_modal() -> bool {
2764    !ZedPredictUpsell::dismissed()
2765}
2766
2767pub fn init(cx: &mut App) {
2768    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2769        workspace.register_action(
2770            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2771                ZedPredictModal::toggle(
2772                    workspace,
2773                    workspace.user_store().clone(),
2774                    workspace.client().clone(),
2775                    window,
2776                    cx,
2777                )
2778            },
2779        );
2780
2781        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2782            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2783                settings
2784                    .project
2785                    .all_languages
2786                    .edit_predictions
2787                    .get_or_insert_default()
2788                    .provider = Some(EditPredictionProvider::None)
2789            });
2790        });
2791        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2792            EditPredictionStore::try_global(cx).and_then(|store| {
2793                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2794            })
2795        }
2796
2797        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2798            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2799                copilot_ui::initiate_sign_in(copilot, window, cx);
2800            }
2801        });
2802        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2803            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2804                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2805            }
2806        });
2807        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2808            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2809                copilot_ui::initiate_sign_out(copilot, window, cx);
2810            }
2811        });
2812    })
2813    .detach();
2814}