edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::SubmitEditPredictionFeedbackBody;
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72mod zed_edit_prediction_delegate;
  73pub mod zeta;
  74
  75#[cfg(test)]
  76mod edit_prediction_tests;
  77
  78use crate::license_detection::LicenseDetectionWatcher;
  79use crate::mercury::Mercury;
  80use crate::onboarding_modal::ZedPredictModal;
  81pub use crate::prediction::EditPrediction;
  82pub use crate::prediction::EditPredictionId;
  83use crate::prediction::EditPredictionResult;
  84pub use crate::sweep_ai::SweepAi;
  85pub use capture_example::capture_example;
  86pub use language_model::ApiKeyState;
  87pub use telemetry_events::EditPredictionRating;
  88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  89
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
  99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance (as measured by `lines_between_ranges`) within which two edit ranges count as
/// "near" each other when `StoredEvent::can_merge` decides whether events may be merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// NOTE(review): not referenced in this part of the file; presumably the time window within
/// which consecutive changes are grouped into one event — confirm at its use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Storage key for the user's data-collection choice (presumably in `KEY_VALUE_STORE`; verify
/// against `load_data_collection_choice`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// NOTE(review): presumably the debounce interval before batched rejections are sent —
/// confirm in `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name for a settled prediction (assumed to be a telemetry event; used outside this view).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five minutes. NOTE(review): presumably the maximum time a prediction is tracked before being
/// reported as settled — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// NOTE(review): presumably how long a prediction's region must go without edits before it is
/// considered settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 109
 110pub struct Zeta2FeatureFlag;
 111pub struct EditPredictionJumpsFeatureFlag;
 112
 113impl FeatureFlag for Zeta2FeatureFlag {
 114    const NAME: &'static str = "zeta2";
 115}
 116
 117impl FeatureFlag for EditPredictionJumpsFeatureFlag {
 118    const NAME: &'static str = "edit_prediction_jumps";
 119}
 120
/// App-global wrapper around the shared [`EditPredictionStore`] entity. It is set by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 125
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): this struct has no `version` field. "Version" presumably refers to `format`,
/// which `zeta2_raw_config_from_env` parses from `ZED_ZETA_FORMAT` — confirm and reword.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model id; filled from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when built from the environment.
    pub format: ZetaFormat,
}
 134
/// Central edit-prediction state. It holds per-project tracking, the active model, provider
/// clients, and the channels that feed the background rejection and settled-prediction workers.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of the store.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Starts as `false`. NOTE(review): presumably set when the server demands a newer client —
    /// confirm where it is written.
    update_required: bool,
    /// Active prediction model; `Zeta2` when the store is created.
    edit_prediction_model: EditPredictionModel,
    /// Raw-endpoint configuration, initially read from the environment.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender whose receiver is drained by `handle_rejected_predictions` on a background task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Sender whose receiver is consumed by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// NOTE(review): presumably recently shown predictions retained for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions that have been rated (assumed from the name; confirm).
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook. NOTE(review): presumably invoked instead of reporting a settled event.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 154
/// The backend used to produce edit predictions. It also selects the icon set returned by
/// `EditPredictionStore::icons`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle prompting using `format`.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
 163
/// The inputs gathered for a single prediction request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is computed against (assumed taken at request time).
    snapshot: BufferSnapshot,
    /// Position the prediction is requested at (presumably the cursor).
    position: Anchor,
    /// Recent edit history.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Rows searched for diagnostics (presumably around `position`; confirm).
    diagnostic_search_range: Range<Point>,
    /// When set, debug events for this request are sent here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
 178
/// Debug notifications about context retrieval and prediction requests. They are delivered
/// over an `mpsc::UnboundedSender<DebugEvent>` (see `ProjectState::debug_tx`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 186
/// Context retrieval started for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Context retrieval finished for a project.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Named details about the retrieval, as label/value pairs.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt sent to the model, if available.
    pub prompt: Option<String>,
}

/// A prediction request finished for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, if available.
    pub model_output: Option<String>,
}
 214
/// Maximum number of user actions kept per project; the oldest action is evicted first
/// (see `ProjectState::record_user_action`).
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    /// NOTE(review): zero- vs one-based is not established here — confirm at the recording site.
    pub line_number: u32,
    pub offset: usize,
    /// Presumably milliseconds since the UNIX epoch, per the field name.
    pub timestamp_epoch_ms: u64,
}

/// The kind of action captured in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 234
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchors spanning the changed text, created in the snapshot after the change.
    pub edit_range: Range<Anchor>,
}
 242
 243impl StoredEvent {
 244    fn can_merge(
 245        &self,
 246        next_old_event: &&&StoredEvent,
 247        new_snapshot: &TextBufferSnapshot,
 248        last_edit_range: &Range<Anchor>,
 249    ) -> bool {
 250        // Events must be for the same buffer
 251        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 252            return false;
 253        }
 254        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 255            return false;
 256        }
 257
 258        let a_is_predicted = matches!(
 259            self.event.as_ref(),
 260            zeta_prompt::Event::BufferChange {
 261                predicted: true,
 262                ..
 263            }
 264        );
 265        let b_is_predicted = matches!(
 266            next_old_event.event.as_ref(),
 267            zeta_prompt::Event::BufferChange {
 268                predicted: true,
 269                ..
 270            }
 271        );
 272
 273        // If events come from the same source (both predicted or both manual) then
 274        // we would have coalesced them already.
 275        if a_is_predicted == b_is_predicted {
 276            return false;
 277        }
 278
 279        let left_range = self.edit_range.to_point(new_snapshot);
 280        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 281        let latest_range = last_edit_range.to_point(&new_snapshot);
 282
 283        // Events near to the latest edit are not merged if their sources differ.
 284        if lines_between_ranges(&left_range, &latest_range)
 285            .min(lines_between_ranges(&right_range, &latest_range))
 286            <= CHANGE_GROUPING_LINE_SPAN
 287        {
 288            return false;
 289        }
 290
 291        // Events that are distant from each other are not merged.
 292        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 293            return false;
 294        }
 295
 296        true
 297    }
 298}
 299
 300fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 301    if left.start > right.end {
 302        return left.start.row - right.end.row;
 303    }
 304    if right.start > left.end {
 305        return right.start.row - left.end.row;
 306    }
 307    0
 308}
 309
/// Edit-prediction state tracked for one project.
struct ProjectState {
    /// Finalized edit events; `events()` returns these before the in-progress `last_event`.
    events: VecDeque<StoredEvent>,
    /// The edit currently being accumulated; finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id for the next pending prediction (presumably incremented per request).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; capacity is fixed at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When set, debug events for this project are sent here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Buffer and time of the latest edit-prediction refresh (assumed from the name; confirm).
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    /// Buffer and time of the latest jump-prediction refresh (assumed from the name; confirm).
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; consulted to decide whether files are open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started via `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
 328
 329impl ProjectState {
 330    fn record_user_action(&mut self, action: UserActionRecord) {
 331        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 332            self.user_actions.pop_front();
 333        }
 334        self.user_actions.push_back(action);
 335    }
 336
 337    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 338        self.events
 339            .iter()
 340            .cloned()
 341            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 342                let (one, two) = event.split_by_pause();
 343                let one = one.finalize(&self.license_detection_watchers, cx);
 344                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 345                one.into_iter().chain(two)
 346            }))
 347            .collect()
 348    }
 349
 350    fn cancel_pending_prediction(
 351        &mut self,
 352        pending_prediction: PendingPrediction,
 353        cx: &mut Context<EditPredictionStore>,
 354    ) {
 355        self.cancelled_predictions.insert(pending_prediction.id);
 356
 357        if pending_prediction.drop_on_cancel {
 358            drop(pending_prediction.task);
 359        } else {
 360            cx.spawn(async move |this, cx| {
 361                let Some(prediction_id) = pending_prediction.task.await else {
 362                    return;
 363                };
 364
 365                this.update(cx, |this, cx| {
 366                    this.reject_prediction(
 367                        prediction_id,
 368                        EditPredictionRejectReason::Canceled,
 369                        false,
 370                        None,
 371                        cx,
 372                    );
 373                })
 374                .ok();
 375            })
 376            .detach()
 377        }
 378    }
 379
 380    fn active_buffer(
 381        &self,
 382        project: &Entity<Project>,
 383        cx: &App,
 384    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 385        let project = project.read(cx);
 386        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 387        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 388        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 389        Some((active_buffer, registered_buffer.last_position))
 390    }
 391}
 392
/// The prediction currently offered for a project, with what requested it and how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown (assumed from the name; confirm at write sites).
    pub was_shown: bool,
    /// How the prediction was displayed, if it was (assumed from the name; confirm).
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 400
 401impl CurrentEditPrediction {
 402    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 403        let Some(new_edits) = self
 404            .prediction
 405            .interpolate(&self.prediction.buffer.read(cx))
 406        else {
 407            return false;
 408        };
 409
 410        if self.prediction.buffer != old_prediction.prediction.buffer {
 411            return true;
 412        }
 413
 414        let Some(old_edits) = old_prediction
 415            .prediction
 416            .interpolate(&old_prediction.prediction.buffer.read(cx))
 417        else {
 418            return true;
 419        };
 420
 421        let requested_by_buffer_id = self.requested_by.buffer_id();
 422
 423        // This reduces the occurrence of UI thrash from replacing edits
 424        //
 425        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 426        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 427            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 428            && old_edits.len() == 1
 429            && new_edits.len() == 1
 430        {
 431            let (old_range, old_text) = &old_edits[0];
 432            let (new_range, new_text) = &new_edits[0];
 433            new_range == old_range && new_text.starts_with(old_text.as_ref())
 434        } else {
 435            true
 436        }
 437    }
 438}
 439
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update; not tied to a particular buffer.
    DiagnosticsUpdate,
    /// The buffer with this entity id.
    Buffer(EntityId),
}
 445
 446impl PredictionRequestedBy {
 447    pub fn buffer_id(&self) -> Option<EntityId> {
 448        match self {
 449            PredictionRequestedBy::DiagnosticsUpdate => None,
 450            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 451        }
 452    }
 453}
 454
/// NOTE(review): presumably the number of rows around the cursor searched for diagnostics —
/// confirm at its use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search. The meaning is inferred from the variant names (nearby region
/// vs. whole project); confirm at use sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 462
/// An in-flight prediction request held in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Recorded in `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Yields the produced prediction's id. On `None`, cancellation has nothing to reject.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 471
 472/// A prediction from the perspective of a buffer.
 473#[derive(Debug)]
 474enum BufferEditPrediction<'a> {
 475    Local { prediction: &'a EditPrediction },
 476    Jump { prediction: &'a EditPrediction },
 477}
 478
 479#[cfg(test)]
 480impl std::ops::Deref for BufferEditPrediction<'_> {
 481    type Target = EditPrediction;
 482
 483    fn deref(&self) -> &Self::Target {
 484        match self {
 485            BufferEditPrediction::Local { prediction } => prediction,
 486            BufferEditPrediction::Jump { prediction } => prediction,
 487        }
 488    }
 489}
 490
/// A prediction being tracked until it is considered settled (assumed from the name and the
/// `EDIT_PREDICTION_SETTLED_*` constants; confirm).
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Anchors bounding the region the prediction was allowed to edit (per the field name).
    editable_anchor_range: Range<Anchor>,
    /// When tracking began.
    enqueued_at: Instant,
    /// Time of the latest edit observed (presumably within `editable_anchor_range`; confirm).
    last_edit_at: Instant,
}
 498
/// Bookkeeping for a buffer observed by a project's edit-prediction state.
struct RegisteredBuffer {
    /// The buffer's file, if it has one.
    file: Option<Arc<dyn File>>,
    /// Latest snapshot recorded for the buffer (assumed; confirm at update sites).
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer still awaiting the settled check.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position (presumably the cursor); returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 506
/// An edit still being accumulated for one buffer, spanning `old_snapshot` to `new_snapshot`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Copied into the finalized event's `predicted` flag.
    predicted: bool,
    /// Snapshot at the last pause in editing; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 518
 519impl LastEvent {
 520    pub fn finalize(
 521        &self,
 522        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 523        cx: &App,
 524    ) -> Option<StoredEvent> {
 525        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 526        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 527
 528        let in_open_source_repo =
 529            [self.new_file.as_ref(), self.old_file.as_ref()]
 530                .iter()
 531                .all(|file| {
 532                    file.is_some_and(|file| {
 533                        license_detection_watchers
 534                            .get(&file.worktree_id(cx))
 535                            .is_some_and(|watcher| watcher.is_project_open_source())
 536                    })
 537                });
 538
 539        let (diff, edit_range) =
 540            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 541
 542        if path == old_path && diff.is_empty() {
 543            None
 544        } else {
 545            Some(StoredEvent {
 546                event: Arc::new(zeta_prompt::Event::BufferChange {
 547                    old_path,
 548                    path,
 549                    diff,
 550                    in_open_source_repo,
 551                    predicted: self.predicted,
 552                }),
 553                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 554                    ..self.new_snapshot.anchor_before(edit_range.end),
 555                old_snapshot: self.old_snapshot.clone(),
 556            })
 557        }
 558    }
 559
 560    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 561        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 562            return (self.clone(), None);
 563        };
 564
 565        let before = LastEvent {
 566            old_snapshot: self.old_snapshot.clone(),
 567            new_snapshot: boundary_snapshot.clone(),
 568            old_file: self.old_file.clone(),
 569            new_file: self.new_file.clone(),
 570            edit_range: None,
 571            predicted: self.predicted,
 572            snapshot_after_last_editing_pause: None,
 573            last_edit_time: self.last_edit_time,
 574        };
 575
 576        let after = LastEvent {
 577            old_snapshot: boundary_snapshot.clone(),
 578            new_snapshot: self.new_snapshot.clone(),
 579            old_file: self.old_file.clone(),
 580            new_file: self.new_file.clone(),
 581            edit_range: None,
 582            predicted: self.predicted,
 583            snapshot_after_last_editing_pause: None,
 584            last_edit_time: self.last_edit_time,
 585        };
 586
 587        (before, Some(after))
 588    }
 589}
 590
 591pub(crate) fn compute_diff_between_snapshots(
 592    old_snapshot: &TextBufferSnapshot,
 593    new_snapshot: &TextBufferSnapshot,
 594) -> Option<(String, Range<Point>)> {
 595    let edits: Vec<Edit<usize>> = new_snapshot
 596        .edits_since::<usize>(&old_snapshot.version)
 597        .collect();
 598
 599    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 600
 601    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 602    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 603    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 604    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 605
 606    const CONTEXT_LINES: u32 = 3;
 607
 608    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 609    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 610    let old_context_end_row =
 611        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 612    let new_context_end_row =
 613        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 614
 615    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 616    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 617    let old_end_line_offset = old_snapshot
 618        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 619    let new_end_line_offset = new_snapshot
 620        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 621    let old_edit_range = old_start_line_offset..old_end_line_offset;
 622    let new_edit_range = new_start_line_offset..new_end_line_offset;
 623
 624    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 625    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 626
 627    let diff = language::unified_diff_with_offsets(
 628        &old_region_text,
 629        &new_region_text,
 630        old_context_start_row,
 631        new_context_start_row,
 632    );
 633
 634    Some((diff, new_start_point..new_end_point))
 635}
 636
 637fn buffer_path_with_id_fallback(
 638    file: Option<&Arc<dyn File>>,
 639    snapshot: &TextBufferSnapshot,
 640    cx: &App,
 641) -> Arc<Path> {
 642    if let Some(file) = file {
 643        file.full_path(cx).into()
 644    } else {
 645        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 646    }
 647}
 648
 649impl EditPredictionStore {
 650    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 651        cx.try_global::<EditPredictionStoreGlobal>()
 652            .map(|global| global.0.clone())
 653    }
 654
 655    pub fn global(
 656        client: &Arc<Client>,
 657        user_store: &Entity<UserStore>,
 658        cx: &mut App,
 659    ) -> Entity<Self> {
 660        cx.try_global::<EditPredictionStoreGlobal>()
 661            .map(|global| global.0.clone())
 662            .unwrap_or_else(|| {
 663                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 664                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 665                ep_store
 666            })
 667    }
 668
    /// Creates the store and starts its workers.
    ///
    /// - Spawns a background task that passes rejected predictions to
    ///   `handle_rejected_predictions`.
    /// - Spawns a task that runs `run_settled_predictions_worker`.
    /// - Subscribes to the global `RefreshLlmTokenListener` so the LLM token is refreshed each
    ///   time the listener emits.
    ///
    /// The store starts with the `Zeta2` model, and with a raw Zeta2 configuration taken from
    /// the environment if one is set.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections sent on `reject_tx` are consumed on the background executor.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        // Timestamps sent on `settled_predictions_tx` are consumed by the settled-prediction worker.
        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
        })
        .detach();

        let this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            llm_token,
            // Each listener event triggers a token refresh; failures are logged, not propagated.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            edit_prediction_model: EditPredictionModel::Zeta2,
            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),

            data_collection_choice,
            reject_predictions_tx: reject_tx,
            settled_predictions_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            #[cfg(test)]
            settled_event_callback: None,
        };

        this
    }
 734
 735    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 736        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 737        let format = ZetaFormat::parse(&version_str).ok()?;
 738        let model_id = env::var("ZED_ZETA_MODEL").ok();
 739        Some(Zeta2RawConfig { model_id, format })
 740    }
 741
    /// Sets the active prediction model. This also affects `icons` and `usage`.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Replaces the raw Zeta2 configuration, including any value read from the environment.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 753
 754    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 755        use ui::IconName;
 756        match self.edit_prediction_model {
 757            EditPredictionModel::Sweep => {
 758                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 759                    .with_disabled(IconName::SweepAiDisabled)
 760                    .with_up(IconName::SweepAiUp)
 761                    .with_down(IconName::SweepAiDown)
 762                    .with_error(IconName::SweepAiError)
 763            }
 764            EditPredictionModel::Mercury => {
 765                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 766            }
 767            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
 768                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 769                    .with_disabled(IconName::ZedPredictDisabled)
 770                    .with_up(IconName::ZedPredictUp)
 771                    .with_down(IconName::ZedPredictDown)
 772                    .with_error(IconName::ZedPredictError)
 773            }
 774            EditPredictionModel::Fim { .. } => {
 775                let settings = &all_language_settings(None, cx).edit_predictions;
 776                match settings.provider {
 777                    EditPredictionProvider::Ollama => {
 778                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 779                    }
 780                    _ => {
 781                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 782                    }
 783                }
 784            }
 785        }
 786    }
 787
    /// Returns whether an API key is available for the Sweep provider.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Returns whether an API key is available for the Mercury provider.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 795
 796    pub fn clear_history(&mut self) {
 797        for project_state in self.projects.values_mut() {
 798            project_state.events.clear();
 799            project_state.last_event.take();
 800        }
 801    }
 802
 803    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 804        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 805            project_state.events.clear();
 806            project_state.last_event.take();
 807        }
 808    }
 809
 810    pub fn edit_history_for_project(
 811        &self,
 812        project: &Entity<Project>,
 813        cx: &App,
 814    ) -> Vec<StoredEvent> {
 815        self.projects
 816            .get(&project.entity_id())
 817            .map(|project_state| project_state.events(cx))
 818            .unwrap_or_default()
 819    }
 820
 821    pub fn context_for_project<'a>(
 822        &'a self,
 823        project: &Entity<Project>,
 824        cx: &'a mut App,
 825    ) -> Vec<RelatedFile> {
 826        self.projects
 827            .get(&project.entity_id())
 828            .map(|project_state| {
 829                project_state.context.update(cx, |context, cx| {
 830                    context
 831                        .related_files_with_buffers(cx)
 832                        .map(|(mut related_file, buffer)| {
 833                            related_file.in_open_source_repo = buffer
 834                                .read(cx)
 835                                .file()
 836                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 837                            related_file
 838                        })
 839                        .collect()
 840                })
 841            })
 842            .unwrap_or_default()
 843    }
 844
 845    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 846        self.projects
 847            .get(&project.entity_id())
 848            .and_then(|project| project.copilot.clone())
 849    }
 850
 851    pub fn start_copilot_for_project(
 852        &mut self,
 853        project: &Entity<Project>,
 854        cx: &mut Context<Self>,
 855    ) -> Option<Entity<Copilot>> {
 856        if DisableAiSettings::get(None, cx).disable_ai {
 857            return None;
 858        }
 859        let state = self.get_or_init_project(project, cx);
 860
 861        if state.copilot.is_some() {
 862            return state.copilot.clone();
 863        }
 864        let _project = project.clone();
 865        let project = project.read(cx);
 866
 867        let node = project.node_runtime().cloned();
 868        if let Some(node) = node {
 869            let next_id = project.languages().next_language_server_id();
 870            let fs = project.fs().clone();
 871
 872            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 873            state.copilot = Some(copilot.clone());
 874            Some(copilot)
 875        } else {
 876            None
 877        }
 878    }
 879
 880    pub fn context_for_project_with_buffers<'a>(
 881        &'a self,
 882        project: &Entity<Project>,
 883        cx: &'a mut App,
 884    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 885        self.projects
 886            .get(&project.entity_id())
 887            .map(|project| {
 888                project.context.update(cx, |context, cx| {
 889                    context.related_files_with_buffers(cx).collect()
 890                })
 891            })
 892            .unwrap_or_default()
 893    }
 894
 895    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 896        if matches!(
 897            self.edit_prediction_model,
 898            EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
 899        ) {
 900            self.user_store.read(cx).edit_prediction_usage()
 901        } else {
 902            None
 903        }
 904    }
 905
 906    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 907        self.get_or_init_project(project, cx);
 908    }
 909
 910    pub fn register_buffer(
 911        &mut self,
 912        buffer: &Entity<Buffer>,
 913        project: &Entity<Project>,
 914        cx: &mut Context<Self>,
 915    ) {
 916        let project_state = self.get_or_init_project(project, cx);
 917        Self::register_buffer_impl(project_state, buffer, project, cx);
 918    }
 919
 920    fn get_or_init_project(
 921        &mut self,
 922        project: &Entity<Project>,
 923        cx: &mut Context<Self>,
 924    ) -> &mut ProjectState {
 925        let entity_id = project.entity_id();
 926        self.projects
 927            .entry(entity_id)
 928            .or_insert_with(|| ProjectState {
 929                context: {
 930                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
 931                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
 932                        this.handle_excerpt_store_event(entity_id, event);
 933                    })
 934                    .detach();
 935                    related_excerpt_store
 936                },
 937                events: VecDeque::new(),
 938                last_event: None,
 939                recent_paths: VecDeque::new(),
 940                debug_tx: None,
 941                registered_buffers: HashMap::default(),
 942                current_prediction: None,
 943                cancelled_predictions: HashSet::default(),
 944                pending_predictions: ArrayVec::new(),
 945                next_pending_prediction_id: 0,
 946                last_edit_prediction_refresh: None,
 947                last_jump_prediction_refresh: None,
 948                license_detection_watchers: HashMap::default(),
 949                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
 950                _subscriptions: [
 951                    cx.subscribe(&project, Self::handle_project_event),
 952                    cx.observe_release(&project, move |this, _, cx| {
 953                        this.projects.remove(&entity_id);
 954                        cx.notify();
 955                    }),
 956                ],
 957                copilot: None,
 958            })
 959    }
 960
 961    pub fn remove_project(&mut self, project: &Entity<Project>) {
 962        self.projects.remove(&project.entity_id());
 963    }
 964
 965    fn handle_excerpt_store_event(
 966        &mut self,
 967        project_entity_id: EntityId,
 968        event: &RelatedExcerptStoreEvent,
 969    ) {
 970        if let Some(project_state) = self.projects.get(&project_entity_id) {
 971            if let Some(debug_tx) = project_state.debug_tx.clone() {
 972                match event {
 973                    RelatedExcerptStoreEvent::StartedRefresh => {
 974                        debug_tx
 975                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
 976                                ContextRetrievalStartedDebugEvent {
 977                                    project_entity_id: project_entity_id,
 978                                    timestamp: Instant::now(),
 979                                    search_prompt: String::new(),
 980                                },
 981                            ))
 982                            .ok();
 983                    }
 984                    RelatedExcerptStoreEvent::FinishedRefresh {
 985                        cache_hit_count,
 986                        cache_miss_count,
 987                        mean_definition_latency,
 988                        max_definition_latency,
 989                    } => {
 990                        debug_tx
 991                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
 992                                ContextRetrievalFinishedDebugEvent {
 993                                    project_entity_id: project_entity_id,
 994                                    timestamp: Instant::now(),
 995                                    metadata: vec![
 996                                        (
 997                                            "Cache Hits",
 998                                            format!(
 999                                                "{}/{}",
1000                                                cache_hit_count,
1001                                                cache_hit_count + cache_miss_count
1002                                            )
1003                                            .into(),
1004                                        ),
1005                                        (
1006                                            "Max LSP Time",
1007                                            format!("{} ms", max_definition_latency.as_millis())
1008                                                .into(),
1009                                        ),
1010                                        (
1011                                            "Mean LSP Time",
1012                                            format!("{} ms", mean_definition_latency.as_millis())
1013                                                .into(),
1014                                        ),
1015                                    ],
1016                                },
1017                            ))
1018                            .ok();
1019                    }
1020                }
1021            }
1022        }
1023    }
1024
1025    pub fn debug_info(
1026        &mut self,
1027        project: &Entity<Project>,
1028        cx: &mut Context<Self>,
1029    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1030        let project_state = self.get_or_init_project(project, cx);
1031        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1032        project_state.debug_tx = Some(debug_watch_tx);
1033        debug_watch_rx
1034    }
1035
1036    fn handle_project_event(
1037        &mut self,
1038        project: Entity<Project>,
1039        event: &project::Event,
1040        cx: &mut Context<Self>,
1041    ) {
1042        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1043            return;
1044        }
1045        // TODO [zeta2] init with recent paths
1046        match event {
1047            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1048                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1049                    return;
1050                };
1051                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1052                if let Some(path) = path {
1053                    if let Some(ix) = project_state
1054                        .recent_paths
1055                        .iter()
1056                        .position(|probe| probe == &path)
1057                    {
1058                        project_state.recent_paths.remove(ix);
1059                    }
1060                    project_state.recent_paths.push_front(path);
1061                }
1062            }
1063            project::Event::DiagnosticsUpdated { .. } => {
1064                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1065                    self.refresh_prediction_from_diagnostics(
1066                        project,
1067                        DiagnosticSearchScope::Global,
1068                        cx,
1069                    );
1070                }
1071            }
1072            _ => (),
1073        }
1074    }
1075
    /// Returns the `RegisteredBuffer` entry for `buffer` in `project_state`,
    /// creating it the first time the buffer is seen.
    ///
    /// Side effect: if the buffer has a file whose worktree is found in
    /// `project`, a `LicenseDetectionWatcher` is created for that worktree (once
    /// per worktree). Its entry is removed when the worktree entity is released.
    ///
    /// A newly created entry captures the buffer's current text snapshot and file.
    /// It subscribes to `BufferEvent::Edited` (forwarded to
    /// `report_changes_for_buffer` with `is_predicted = false`) and removes itself
    /// from `registered_buffers` when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily set up license detection for the worktree containing this buffer's file.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        // The release callback looks the project up by id rather than holding the entity.
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                // Baseline snapshot/file that later edits are diffed against.
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: edits are ignored once the project can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer entity goes away.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1142
    /// Records the edits made to `buffer` since the snapshot stored for it in
    /// `registered_buffers`. The buffer (and its project) is registered first if
    /// needed.
    ///
    /// When the buffer's version has changed, this:
    /// - bumps `last_edit_at` on pending settled predictions whose editable range
    ///   overlaps the edited range,
    /// - appends a `UserActionRecord` classifying the edit,
    /// - either coalesces the edit into the in-progress `last_event`, or finalizes
    ///   that event into `events` and starts a new `last_event`.
    ///
    /// `is_predicted` marks edits that come from accepting a prediction. An edit
    /// is never coalesced with an event whose `predicted` flag differs.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the old snapshot: count, total removed/inserted
        // lengths, one anchor range spanning all edits, and the end offset of the last
        // edit. NOTE(review): `acc.start..anchor_range.end` assumes edits are yielded
        // in ascending buffer order.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // The version changed but no edits were reported; nothing to record.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart its quiet period.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Heuristic: if the inserted (or deleted) length equals the number of edits,
        // each edit touched exactly one unit, so it counts as typing/deleting a
        // character; otherwise it is treated as a selection-sized change.
        // NOTE(review): lengths are offset differences (likely bytes), so a single
        // multi-byte character is classified as a selection insert.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        // `last_offset` is always set whenever `edit_range` is; the action is
        // attributed to the end of the last edit.
        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        // Try to extend the in-progress event instead of starting a new one. This
        // requires that it ended exactly at `old_snapshot` of this same buffer, has
        // the same predicted/manual origin, and its edit range lies within
        // `CHANGE_GROUPING_LINE_SPAN` lines of this edit.
        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After at least `LAST_CHANGE_GROUPING_TIME` without edits, remember the
                // state just before this burst of editing resumed.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event into the bounded history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps at most `EVENT_COUNT_MAX - 1` finalized events
                // (presumably leaving room for `last_event`).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1276
1277    fn prediction_at(
1278        &mut self,
1279        buffer: &Entity<Buffer>,
1280        position: Option<language::Anchor>,
1281        project: &Entity<Project>,
1282        cx: &App,
1283    ) -> Option<BufferEditPrediction<'_>> {
1284        let project_state = self.projects.get_mut(&project.entity_id())?;
1285        if let Some(position) = position
1286            && let Some(buffer) = project_state
1287                .registered_buffers
1288                .get_mut(&buffer.entity_id())
1289        {
1290            buffer.last_position = Some(position);
1291        }
1292
1293        let CurrentEditPrediction {
1294            requested_by,
1295            prediction,
1296            ..
1297        } = project_state.current_prediction.as_ref()?;
1298
1299        if prediction.targets_buffer(buffer.read(cx)) {
1300            Some(BufferEditPrediction::Local { prediction })
1301        } else {
1302            let show_jump = match requested_by {
1303                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1304                    requested_by_buffer_id == &buffer.entity_id()
1305                }
1306                PredictionRequestedBy::DiagnosticsUpdate => true,
1307            };
1308
1309            if show_jump {
1310                Some(BufferEditPrediction::Jump { prediction })
1311            } else {
1312                None
1313            }
1314        }
1315    }
1316
    /// Accepts the project's current prediction, if there is one.
    ///
    /// Order of operations:
    /// 1. takes `current_prediction` out of the project state (no-op if absent),
    /// 2. records the prediction's buffer changes as a predicted edit,
    /// 3. cancels every pending prediction request for the project,
    /// 4. notifies the active model's backend: Sweep and Mercury always; Zeta only
    ///    when the provider is not Ollama or an OpenAI-compatible API; FIM never.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        // `true`: this edit originates from accepting a prediction.
        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);

        // Can't hold a `&mut project_state` reference across the
        // `report_changes_for_buffer` call, so look the state up again.
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Sweep => {
                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
            }
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
                // Ollama and OpenAI-compatible endpoints are not Zed's cloud; skip reporting.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1360
1361    async fn handle_rejected_predictions(
1362        rx: UnboundedReceiver<EditPredictionRejection>,
1363        client: Arc<Client>,
1364        llm_token: LlmApiToken,
1365        app_version: Version,
1366        background_executor: BackgroundExecutor,
1367    ) {
1368        let mut rx = std::pin::pin!(rx.peekable());
1369        let mut batched = Vec::new();
1370
1371        while let Some(rejection) = rx.next().await {
1372            batched.push(rejection);
1373
1374            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1375                select_biased! {
1376                    next = rx.as_mut().peek().fuse() => {
1377                        if next.is_some() {
1378                            continue;
1379                        }
1380                    }
1381                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1382                }
1383            }
1384
1385            let url = client
1386                .http_client()
1387                .build_zed_llm_url("/predict_edits/reject", &[])
1388                .unwrap();
1389
1390            let flush_count = batched
1391                .len()
1392                // in case items have accumulated after failure
1393                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1394            let start = batched.len() - flush_count;
1395
1396            let body = RejectEditPredictionsBodyRef {
1397                rejections: &batched[start..],
1398            };
1399
1400            let result = Self::send_api_request::<()>(
1401                |builder| {
1402                    let req = builder
1403                        .uri(url.as_ref())
1404                        .body(serde_json::to_string(&body)?.into());
1405                    anyhow::Ok(req?)
1406                },
1407                client.clone(),
1408                llm_token.clone(),
1409                app_version.clone(),
1410                true,
1411            )
1412            .await;
1413
1414            if result.log_err().is_some() {
1415                batched.drain(start..);
1416            }
1417        }
1418    }
1419
    /// Background loop that reports pending predictions as "settled".
    ///
    /// A pending prediction settles once its editable region has gone
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` without edits. At that point the
    /// region's text is read from the registered buffer snapshot and emitted as an
    /// `EDIT_PREDICTION_SETTLED_EVENT` telemetry event, and the prediction is
    /// removed. Predictions older than `EDIT_PREDICTION_SETTLED_TTL` are dropped
    /// without being reported.
    ///
    /// `rx` carries the enqueue time of newly queued predictions and wakes the loop
    /// when it is idle. The loop exits when `rx` closes or the store is dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // NOTE(review): assuming `Instant` is `std::time::Instant`,
                // `duration_since` saturates to zero if `wake_time` already passed.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Idle: block until something is enqueued, then schedule the first check.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Coalesce any other notifications that are already queued.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that are still pending.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                // Expired: drop without reporting.
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                // Quiet long enough: report the settled region and drop it.
                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                // Still being edited: keep it and track the earliest edit time.
                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Wake when the earliest remaining prediction could settle; `None` means
            // nothing is pending, so the next iteration waits on `rx`.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1503
1504    pub(crate) fn enqueue_settled_prediction(
1505        &mut self,
1506        request_id: EditPredictionId,
1507        project: &Entity<Project>,
1508        edited_buffer: &Entity<Buffer>,
1509        edited_buffer_snapshot: &BufferSnapshot,
1510        editable_offset_range: Range<usize>,
1511        cx: &mut Context<Self>,
1512    ) {
1513        let project_state = self.get_or_init_project(project, cx);
1514        if let Some(buffer) = project_state
1515            .registered_buffers
1516            .get_mut(&edited_buffer.entity_id())
1517        {
1518            let now = cx.background_executor().now();
1519            buffer.pending_predictions.push(PendingSettledPrediction {
1520                request_id,
1521                editable_anchor_range: edited_buffer_snapshot
1522                    .anchor_range_around(editable_offset_range),
1523                enqueued_at: now,
1524                last_edit_at: now,
1525            });
1526            self.settled_predictions_tx.unbounded_send(now).ok();
1527        }
1528    }
1529
1530    fn reject_current_prediction(
1531        &mut self,
1532        reason: EditPredictionRejectReason,
1533        project: &Entity<Project>,
1534        cx: &App,
1535    ) {
1536        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1537            project_state.pending_predictions.clear();
1538            if let Some(prediction) = project_state.current_prediction.take() {
1539                let model_version = prediction.prediction.model_version.clone();
1540                self.reject_prediction(
1541                    prediction.prediction.id,
1542                    reason,
1543                    prediction.was_shown,
1544                    model_version,
1545                    cx,
1546                );
1547            }
1548        };
1549    }
1550
1551    fn did_show_current_prediction(
1552        &mut self,
1553        project: &Entity<Project>,
1554        display_type: edit_prediction_types::SuggestionDisplayType,
1555        cx: &mut Context<Self>,
1556    ) {
1557        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1558            return;
1559        };
1560
1561        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1562            return;
1563        };
1564
1565        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1566        let previous_shown_with = current_prediction.shown_with;
1567
1568        if previous_shown_with.is_none() || !is_jump {
1569            current_prediction.shown_with = Some(display_type);
1570        }
1571
1572        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1573
1574        if is_first_non_jump_show {
1575            current_prediction.was_shown = true;
1576        }
1577
1578        let display_type_changed = previous_shown_with != Some(display_type);
1579
1580        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1581            sweep_ai::edit_prediction_shown(
1582                &self.sweep_ai,
1583                self.client.clone(),
1584                &current_prediction.prediction,
1585                display_type,
1586                cx,
1587            );
1588        }
1589
1590        if is_first_non_jump_show {
1591            self.shown_predictions
1592                .push_front(current_prediction.prediction.clone());
1593            if self.shown_predictions.len() > 50 {
1594                let completion = self.shown_predictions.pop_back().unwrap();
1595                self.rated_predictions.remove(&completion.id);
1596            }
1597        }
1598    }
1599
1600    fn reject_prediction(
1601        &mut self,
1602        prediction_id: EditPredictionId,
1603        reason: EditPredictionRejectReason,
1604        was_shown: bool,
1605        model_version: Option<String>,
1606        cx: &App,
1607    ) {
1608        match self.edit_prediction_model {
1609            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1610                let is_cloud = !matches!(
1611                    all_language_settings(None, cx).edit_predictions.provider,
1612                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1613                );
1614                if is_cloud {
1615                    self.reject_predictions_tx
1616                        .unbounded_send(EditPredictionRejection {
1617                            request_id: prediction_id.to_string(),
1618                            reason,
1619                            was_shown,
1620                            model_version,
1621                        })
1622                        .log_err();
1623                }
1624            }
1625            EditPredictionModel::Mercury => {
1626                mercury::edit_prediction_rejected(
1627                    prediction_id,
1628                    was_shown,
1629                    reason,
1630                    self.client.http_client(),
1631                    cx,
1632                );
1633            }
1634            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1635        }
1636    }
1637
1638    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1639        self.projects
1640            .get(&project.entity_id())
1641            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1642    }
1643
1644    pub fn refresh_prediction_from_buffer(
1645        &mut self,
1646        project: Entity<Project>,
1647        buffer: Entity<Buffer>,
1648        position: language::Anchor,
1649        cx: &mut Context<Self>,
1650    ) {
1651        self.queue_prediction_refresh(
1652            project.clone(),
1653            PredictEditsRequestTrigger::Other,
1654            buffer.entity_id(),
1655            cx,
1656            move |this, cx| {
1657                let Some(request_task) = this
1658                    .update(cx, |this, cx| {
1659                        this.request_prediction(
1660                            &project,
1661                            &buffer,
1662                            position,
1663                            PredictEditsRequestTrigger::Other,
1664                            cx,
1665                        )
1666                    })
1667                    .log_err()
1668                else {
1669                    return Task::ready(anyhow::Ok(None));
1670                };
1671
1672                cx.spawn(async move |_cx| {
1673                    request_task.await.map(|prediction_result| {
1674                        prediction_result.map(|prediction_result| {
1675                            (
1676                                prediction_result,
1677                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1678                            )
1679                        })
1680                    })
1681                })
1682            },
1683        )
1684    }
1685
1686    pub fn refresh_prediction_from_diagnostics(
1687        &mut self,
1688        project: Entity<Project>,
1689        scope: DiagnosticSearchScope,
1690        cx: &mut Context<Self>,
1691    ) {
1692        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1693            return;
1694        }
1695
1696        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1697            return;
1698        };
1699
1700        // Prefer predictions from buffer
1701        if project_state.current_prediction.is_some() {
1702            return;
1703        }
1704
1705        self.queue_prediction_refresh(
1706            project.clone(),
1707            PredictEditsRequestTrigger::Diagnostics,
1708            project.entity_id(),
1709            cx,
1710            move |this, cx| {
1711                let Some((active_buffer, snapshot, cursor_point)) = this
1712                    .read_with(cx, |this, cx| {
1713                        let project_state = this.projects.get(&project.entity_id())?;
1714                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1715                        let snapshot = buffer.read(cx).snapshot();
1716
1717                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1718                            return None;
1719                        }
1720
1721                        let cursor_point = position
1722                            .map(|pos| pos.to_point(&snapshot))
1723                            .unwrap_or_default();
1724
1725                        Some((buffer, snapshot, cursor_point))
1726                    })
1727                    .log_err()
1728                    .flatten()
1729                else {
1730                    return Task::ready(anyhow::Ok(None));
1731                };
1732
1733                cx.spawn(async move |cx| {
1734                    let diagnostic_search_range = match scope {
1735                        DiagnosticSearchScope::Local => {
1736                            let diagnostic_search_start =
1737                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1738                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1739                            Point::new(diagnostic_search_start, 0)
1740                                ..Point::new(diagnostic_search_end, 0)
1741                        }
1742                        DiagnosticSearchScope::Global => Default::default(),
1743                    };
1744
1745                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1746                        active_buffer,
1747                        &snapshot,
1748                        diagnostic_search_range,
1749                        cursor_point,
1750                        &project,
1751                        cx,
1752                    )
1753                    .await?
1754                    else {
1755                        return anyhow::Ok(None);
1756                    };
1757
1758                    let Some(prediction_result) = this
1759                        .update(cx, |this, cx| {
1760                            this.request_prediction(
1761                                &project,
1762                                &jump_buffer,
1763                                jump_position,
1764                                PredictEditsRequestTrigger::Diagnostics,
1765                                cx,
1766                            )
1767                        })?
1768                        .await?
1769                    else {
1770                        return anyhow::Ok(None);
1771                    };
1772
1773                    this.update(cx, |this, cx| {
1774                        Some((
1775                            if this
1776                                .get_or_init_project(&project, cx)
1777                                .current_prediction
1778                                .is_none()
1779                            {
1780                                prediction_result
1781                            } else {
1782                                EditPredictionResult {
1783                                    id: prediction_result.id,
1784                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1785                                }
1786                            },
1787                            PredictionRequestedBy::DiagnosticsUpdate,
1788                        ))
1789                    })
1790                })
1791            },
1792        );
1793    }
1794
1795    fn predictions_enabled_at(
1796        snapshot: &BufferSnapshot,
1797        position: Option<language::Anchor>,
1798        cx: &App,
1799    ) -> bool {
1800        let file = snapshot.file();
1801        let all_settings = all_language_settings(file, cx);
1802        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1803            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1804        {
1805            return false;
1806        }
1807
1808        if let Some(last_position) = position {
1809            let settings = snapshot.settings_at(last_position, cx);
1810
1811            if !settings.edit_predictions_disabled_in.is_empty()
1812                && let Some(scope) = snapshot.language_scope_at(last_position)
1813                && let Some(scope_name) = scope.override_name()
1814                && settings
1815                    .edit_predictions_disabled_in
1816                    .iter()
1817                    .any(|s| s == scope_name)
1818            {
1819                return false;
1820            }
1821        }
1822
1823        true
1824    }
1825
    /// Minimum spacing between two consecutive refreshes of the same kind that target the same
    /// entity; `queue_prediction_refresh` delays a refresh until this much time has passed since
    /// the previous matching one.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1827}
1828
1829fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1830    match provider {
1831        EditPredictionProvider::Zed
1832        | EditPredictionProvider::Sweep
1833        | EditPredictionProvider::Mercury
1834        | EditPredictionProvider::Ollama
1835        | EditPredictionProvider::OpenAiCompatibleApi
1836        | EditPredictionProvider::Experimental(_) => true,
1837        EditPredictionProvider::None
1838        | EditPredictionProvider::Copilot
1839        | EditPredictionProvider::Supermaven
1840        | EditPredictionProvider::Codestral => false,
1841    }
1842}
1843
1844impl EditPredictionStore {
1845    fn queue_prediction_refresh(
1846        &mut self,
1847        project: Entity<Project>,
1848        request_trigger: PredictEditsRequestTrigger,
1849        throttle_entity: EntityId,
1850        cx: &mut Context<Self>,
1851        do_refresh: impl FnOnce(
1852            WeakEntity<Self>,
1853            &mut AsyncApp,
1854        )
1855            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1856        + 'static,
1857    ) {
1858        fn select_throttle(
1859            project_state: &mut ProjectState,
1860            request_trigger: PredictEditsRequestTrigger,
1861        ) -> &mut Option<(EntityId, Instant)> {
1862            match request_trigger {
1863                PredictEditsRequestTrigger::Diagnostics => {
1864                    &mut project_state.last_jump_prediction_refresh
1865                }
1866                _ => &mut project_state.last_edit_prediction_refresh,
1867            }
1868        }
1869
1870        let (needs_acceptance_tracking, max_pending_predictions) =
1871            match all_language_settings(None, cx).edit_predictions.provider {
1872                EditPredictionProvider::Zed
1873                | EditPredictionProvider::Sweep
1874                | EditPredictionProvider::Mercury
1875                | EditPredictionProvider::Experimental(_) => (true, 2),
1876                EditPredictionProvider::Ollama => (false, 1),
1877                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1878                EditPredictionProvider::None
1879                | EditPredictionProvider::Copilot
1880                | EditPredictionProvider::Supermaven
1881                | EditPredictionProvider::Codestral => {
1882                    log::error!("queue_prediction_refresh called with non-store provider");
1883                    return;
1884                }
1885            };
1886
1887        let drop_on_cancel = !needs_acceptance_tracking;
1888        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1889        let project_state = self.get_or_init_project(&project, cx);
1890        let pending_prediction_id = project_state.next_pending_prediction_id;
1891        project_state.next_pending_prediction_id += 1;
1892        let last_request = *select_throttle(project_state, request_trigger);
1893
1894        let task = cx.spawn(async move |this, cx| {
1895            if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1896                if throttle_entity != last_entity {
1897                    return None;
1898                }
1899                (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1900            }) {
1901                cx.background_executor().timer(timeout).await;
1902            }
1903
1904            // If this task was cancelled before the throttle timeout expired,
1905            // do not perform a request.
1906            let mut is_cancelled = true;
1907            this.update(cx, |this, cx| {
1908                let project_state = this.get_or_init_project(&project, cx);
1909                let was_cancelled = project_state
1910                    .cancelled_predictions
1911                    .remove(&pending_prediction_id);
1912                if !was_cancelled {
1913                    let new_refresh = (throttle_entity, Instant::now());
1914                    *select_throttle(project_state, request_trigger) = Some(new_refresh);
1915                    is_cancelled = false;
1916                }
1917            })
1918            .ok();
1919            if is_cancelled {
1920                return None;
1921            }
1922
1923            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1924            let new_prediction_id = new_prediction_result
1925                .as_ref()
1926                .map(|(prediction, _)| prediction.id.clone());
1927
1928            // When a prediction completes, remove it from the pending list, and cancel
1929            // any pending predictions that were enqueued before it.
1930            this.update(cx, |this, cx| {
1931                let project_state = this.get_or_init_project(&project, cx);
1932
1933                let is_cancelled = project_state
1934                    .cancelled_predictions
1935                    .remove(&pending_prediction_id);
1936
1937                let new_current_prediction = if !is_cancelled
1938                    && let Some((prediction_result, requested_by)) = new_prediction_result
1939                {
1940                    match prediction_result.prediction {
1941                        Ok(prediction) => {
1942                            let new_prediction = CurrentEditPrediction {
1943                                requested_by,
1944                                prediction,
1945                                was_shown: false,
1946                                shown_with: None,
1947                            };
1948
1949                            if let Some(current_prediction) =
1950                                project_state.current_prediction.as_ref()
1951                            {
1952                                if new_prediction.should_replace_prediction(&current_prediction, cx)
1953                                {
1954                                    this.reject_current_prediction(
1955                                        EditPredictionRejectReason::Replaced,
1956                                        &project,
1957                                        cx,
1958                                    );
1959
1960                                    Some(new_prediction)
1961                                } else {
1962                                    this.reject_prediction(
1963                                        new_prediction.prediction.id,
1964                                        EditPredictionRejectReason::CurrentPreferred,
1965                                        false,
1966                                        new_prediction.prediction.model_version,
1967                                        cx,
1968                                    );
1969                                    None
1970                                }
1971                            } else {
1972                                Some(new_prediction)
1973                            }
1974                        }
1975                        Err(reject_reason) => {
1976                            this.reject_prediction(
1977                                prediction_result.id,
1978                                reject_reason,
1979                                false,
1980                                None,
1981                                cx,
1982                            );
1983                            None
1984                        }
1985                    }
1986                } else {
1987                    None
1988                };
1989
1990                let project_state = this.get_or_init_project(&project, cx);
1991
1992                if let Some(new_prediction) = new_current_prediction {
1993                    project_state.current_prediction = Some(new_prediction);
1994                }
1995
1996                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1997                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1998                    if pending_prediction.id == pending_prediction_id {
1999                        pending_predictions.remove(ix);
2000                        for pending_prediction in pending_predictions.drain(0..ix) {
2001                            project_state.cancel_pending_prediction(pending_prediction, cx)
2002                        }
2003                        break;
2004                    }
2005                }
2006                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2007                cx.notify();
2008            })
2009            .ok();
2010
2011            new_prediction_id
2012        });
2013
2014        if project_state.pending_predictions.len() < max_pending_predictions {
2015            project_state.pending_predictions.push(PendingPrediction {
2016                id: pending_prediction_id,
2017                task,
2018                drop_on_cancel,
2019            });
2020        } else {
2021            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2022            project_state.pending_predictions.push(PendingPrediction {
2023                id: pending_prediction_id,
2024                task,
2025                drop_on_cancel,
2026            });
2027            project_state.cancel_pending_prediction(pending_prediction, cx);
2028        }
2029    }
2030
2031    pub fn request_prediction(
2032        &mut self,
2033        project: &Entity<Project>,
2034        active_buffer: &Entity<Buffer>,
2035        position: language::Anchor,
2036        trigger: PredictEditsRequestTrigger,
2037        cx: &mut Context<Self>,
2038    ) -> Task<Result<Option<EditPredictionResult>>> {
2039        self.request_prediction_internal(
2040            project.clone(),
2041            active_buffer.clone(),
2042            position,
2043            trigger,
2044            cx.has_flag::<Zeta2FeatureFlag>(),
2045            cx,
2046        )
2047    }
2048
2049    fn request_prediction_internal(
2050        &mut self,
2051        project: Entity<Project>,
2052        active_buffer: Entity<Buffer>,
2053        position: language::Anchor,
2054        trigger: PredictEditsRequestTrigger,
2055        allow_jump: bool,
2056        cx: &mut Context<Self>,
2057    ) -> Task<Result<Option<EditPredictionResult>>> {
2058        self.get_or_init_project(&project, cx);
2059        let project_state = self.projects.get(&project.entity_id()).unwrap();
2060        let stored_events = project_state.events(cx);
2061        let has_events = !stored_events.is_empty();
2062        let events: Vec<Arc<zeta_prompt::Event>> =
2063            stored_events.into_iter().map(|e| e.event).collect();
2064        let debug_tx = project_state.debug_tx.clone();
2065
2066        let snapshot = active_buffer.read(cx).snapshot();
2067        let cursor_point = position.to_point(&snapshot);
2068        let current_offset = position.to_offset(&snapshot);
2069
2070        let mut user_actions: Vec<UserActionRecord> =
2071            project_state.user_actions.iter().cloned().collect();
2072
2073        if let Some(last_action) = user_actions.last() {
2074            if last_action.buffer_id == active_buffer.entity_id()
2075                && current_offset != last_action.offset
2076            {
2077                let timestamp_epoch_ms = SystemTime::now()
2078                    .duration_since(UNIX_EPOCH)
2079                    .map(|d| d.as_millis() as u64)
2080                    .unwrap_or(0);
2081                user_actions.push(UserActionRecord {
2082                    action_type: UserActionType::CursorMovement,
2083                    buffer_id: active_buffer.entity_id(),
2084                    line_number: cursor_point.row,
2085                    offset: current_offset,
2086                    timestamp_epoch_ms,
2087                });
2088            }
2089        }
2090        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2091        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2092        let diagnostic_search_range =
2093            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2094
2095        let related_files = self.context_for_project(&project, cx);
2096
2097        let inputs = EditPredictionModelInput {
2098            project: project.clone(),
2099            buffer: active_buffer,
2100            snapshot: snapshot,
2101            position,
2102            events,
2103            related_files,
2104            recent_paths: project_state.recent_paths.clone(),
2105            trigger,
2106            diagnostic_search_range: diagnostic_search_range,
2107            debug_tx,
2108            user_actions,
2109        };
2110
2111        let task = match self.edit_prediction_model {
2112            EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
2113                self,
2114                inputs,
2115                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
2116                cx,
2117            ),
2118            EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
2119                self,
2120                inputs,
2121                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
2122                cx,
2123            ),
2124            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2125            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2126            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2127        };
2128
2129        cx.spawn(async move |this, cx| {
2130            let prediction = task.await?;
2131
2132            if prediction.is_none() && allow_jump && has_events {
2133                this.update(cx, |this, cx| {
2134                    this.refresh_prediction_from_diagnostics(
2135                        project,
2136                        DiagnosticSearchScope::Local,
2137                        cx,
2138                    );
2139                })?;
2140                return anyhow::Ok(None);
2141            }
2142
2143            Ok(prediction)
2144        })
2145    }
2146
2147    pub(crate) async fn next_diagnostic_location(
2148        active_buffer: Entity<Buffer>,
2149        active_buffer_snapshot: &BufferSnapshot,
2150        active_buffer_diagnostic_search_range: Range<Point>,
2151        active_buffer_cursor_point: Point,
2152        project: &Entity<Project>,
2153        cx: &mut AsyncApp,
2154    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2155        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2156            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2157            .flat_map(|(_, _, _, selections)| {
2158                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2159            })
2160            .collect();
2161
2162        let mut jump_location = active_buffer_snapshot
2163            .diagnostic_groups(None)
2164            .into_iter()
2165            .filter_map(|(_, group)| {
2166                let range = &group.entries[group.primary_ix]
2167                    .range
2168                    .to_point(&active_buffer_snapshot);
2169                if range.overlaps(&active_buffer_diagnostic_search_range) {
2170                    return None;
2171                }
2172                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2173                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2174                });
2175                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2176                    <= DIAGNOSTIC_LINES_RANGE;
2177                if near_collaborator && !near_local {
2178                    return None;
2179                }
2180                Some(range.start)
2181            })
2182            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2183            .map(|position| {
2184                (
2185                    active_buffer.clone(),
2186                    active_buffer_snapshot.anchor_before(position),
2187                )
2188            });
2189
2190        if jump_location.is_none() {
2191            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2192                let file = buffer.file()?;
2193
2194                Some(ProjectPath {
2195                    worktree_id: file.worktree_id(cx),
2196                    path: file.path().clone(),
2197                })
2198            });
2199
2200            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2201                project
2202                    .diagnostic_summaries(false, cx)
2203                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2204                    .map(|(path, _, _)| {
2205                        let shared_prefix = path
2206                            .path
2207                            .components()
2208                            .zip(
2209                                active_buffer_path
2210                                    .as_ref()
2211                                    .map(|p| p.path.components())
2212                                    .unwrap_or_default(),
2213                            )
2214                            .take_while(|(a, b)| a == b)
2215                            .count();
2216                        (path, shared_prefix)
2217                    })
2218                    .collect()
2219            });
2220
2221            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2222
2223            for (path, _) in candidates {
2224                let candidate_buffer = project
2225                    .update(cx, |project, cx| project.open_buffer(path, cx))
2226                    .await?;
2227
2228                let (has_collaborators, diagnostic_position) =
2229                    candidate_buffer.read_with(cx, |buffer, _cx| {
2230                        let snapshot = buffer.snapshot();
2231                        let has_collaborators = snapshot
2232                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2233                            .next()
2234                            .is_some();
2235                        let position = buffer
2236                            .buffer_diagnostics(None)
2237                            .into_iter()
2238                            .min_by_key(|entry| entry.diagnostic.severity)
2239                            .map(|entry| entry.range.start);
2240                        (has_collaborators, position)
2241                    });
2242
2243                if has_collaborators {
2244                    continue;
2245                }
2246
2247                if let Some(position) = diagnostic_position {
2248                    jump_location = Some((candidate_buffer, position));
2249                    break;
2250                }
2251            }
2252        }
2253
2254        anyhow::Ok(jump_location)
2255    }
2256
2257    async fn send_raw_llm_request(
2258        request: RawCompletionRequest,
2259        client: Arc<Client>,
2260        custom_url: Option<Arc<Url>>,
2261        llm_token: LlmApiToken,
2262        app_version: Version,
2263    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2264        let url = if let Some(custom_url) = custom_url {
2265            custom_url.as_ref().clone()
2266        } else {
2267            client
2268                .http_client()
2269                .build_zed_llm_url("/predict_edits/raw", &[])?
2270        };
2271
2272        Self::send_api_request(
2273            |builder| {
2274                let req = builder
2275                    .uri(url.as_ref())
2276                    .body(serde_json::to_string(&request)?.into());
2277                Ok(req?)
2278            },
2279            client,
2280            llm_token,
2281            app_version,
2282            true,
2283        )
2284        .await
2285    }
2286
2287    pub(crate) async fn send_v3_request(
2288        input: ZetaPromptInput,
2289        client: Arc<Client>,
2290        llm_token: LlmApiToken,
2291        app_version: Version,
2292        trigger: PredictEditsRequestTrigger,
2293    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2294        let url = client
2295            .http_client()
2296            .build_zed_llm_url("/predict_edits/v3", &[])?;
2297
2298        let request = PredictEditsV3Request { input, trigger };
2299
2300        let json_bytes = serde_json::to_vec(&request)?;
2301        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2302
2303        Self::send_api_request(
2304            |builder| {
2305                let req = builder
2306                    .uri(url.as_ref())
2307                    .header("Content-Encoding", "zstd")
2308                    .body(compressed.clone().into());
2309                Ok(req?)
2310            },
2311            client,
2312            llm_token,
2313            app_version,
2314            true,
2315        )
2316        .await
2317    }
2318
2319    fn handle_api_response<T>(
2320        this: &WeakEntity<Self>,
2321        response: Result<(T, Option<EditPredictionUsage>)>,
2322        cx: &mut gpui::AsyncApp,
2323    ) -> Result<T> {
2324        match response {
2325            Ok((data, usage)) => {
2326                if let Some(usage) = usage {
2327                    this.update(cx, |this, cx| {
2328                        this.user_store.update(cx, |user_store, cx| {
2329                            user_store.update_edit_prediction_usage(usage, cx);
2330                        });
2331                    })
2332                    .ok();
2333                }
2334                Ok(data)
2335            }
2336            Err(err) => {
2337                if err.is::<ZedUpdateRequiredError>() {
2338                    cx.update(|cx| {
2339                        this.update(cx, |this, _cx| {
2340                            this.update_required = true;
2341                        })
2342                        .ok();
2343
2344                        let error_message: SharedString = err.to_string().into();
2345                        show_app_notification(
2346                            NotificationId::unique::<ZedUpdateRequiredError>(),
2347                            cx,
2348                            move |cx| {
2349                                cx.new(|cx| {
2350                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2351                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2352                                })
2353                            },
2354                        );
2355                    });
2356                }
2357                Err(err)
2358            }
2359        }
2360    }
2361
2362    async fn send_api_request<Res>(
2363        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2364        client: Arc<Client>,
2365        llm_token: LlmApiToken,
2366        app_version: Version,
2367        require_auth: bool,
2368    ) -> Result<(Res, Option<EditPredictionUsage>)>
2369    where
2370        Res: DeserializeOwned,
2371    {
2372        let http_client = client.http_client();
2373
2374        let mut token = if require_auth {
2375            Some(llm_token.acquire(&client).await?)
2376        } else {
2377            llm_token.acquire(&client).await.ok()
2378        };
2379        let mut did_retry = false;
2380
2381        loop {
2382            let request_builder = http_client::Request::builder().method(Method::POST);
2383
2384            let mut request_builder = request_builder
2385                .header("Content-Type", "application/json")
2386                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2387
2388            // Only add Authorization header if we have a token
2389            if let Some(ref token_value) = token {
2390                request_builder =
2391                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2392            }
2393
2394            let request = build(request_builder)?;
2395
2396            let mut response = http_client.send(request).await?;
2397
2398            if let Some(minimum_required_version) = response
2399                .headers()
2400                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2401                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2402            {
2403                anyhow::ensure!(
2404                    app_version >= minimum_required_version,
2405                    ZedUpdateRequiredError {
2406                        minimum_version: minimum_required_version
2407                    }
2408                );
2409            }
2410
2411            if response.status().is_success() {
2412                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2413
2414                let mut body = Vec::new();
2415                response.body_mut().read_to_end(&mut body).await?;
2416                return Ok((serde_json::from_slice(&body)?, usage));
2417            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2418                did_retry = true;
2419                token = Some(llm_token.refresh(&client).await?);
2420            } else {
2421                let mut body = String::new();
2422                response.body_mut().read_to_string(&mut body).await?;
2423                anyhow::bail!(
2424                    "Request failed with status: {:?}\nBody: {}",
2425                    response.status(),
2426                    body
2427                );
2428            }
2429        }
2430    }
2431
2432    pub fn refresh_context(
2433        &mut self,
2434        project: &Entity<Project>,
2435        buffer: &Entity<language::Buffer>,
2436        cursor_position: language::Anchor,
2437        cx: &mut Context<Self>,
2438    ) {
2439        self.get_or_init_project(project, cx)
2440            .context
2441            .update(cx, |store, cx| {
2442                store.refresh(buffer.clone(), cursor_position, cx);
2443            });
2444    }
2445
2446    #[cfg(feature = "cli-support")]
2447    pub fn set_context_for_buffer(
2448        &mut self,
2449        project: &Entity<Project>,
2450        related_files: Vec<RelatedFile>,
2451        cx: &mut Context<Self>,
2452    ) {
2453        self.get_or_init_project(project, cx)
2454            .context
2455            .update(cx, |store, cx| {
2456                store.set_related_files(related_files, cx);
2457            });
2458    }
2459
2460    #[cfg(feature = "cli-support")]
2461    pub fn set_recent_paths_for_project(
2462        &mut self,
2463        project: &Entity<Project>,
2464        paths: impl IntoIterator<Item = project::ProjectPath>,
2465        cx: &mut Context<Self>,
2466    ) {
2467        let project_state = self.get_or_init_project(project, cx);
2468        project_state.recent_paths = paths.into_iter().collect();
2469    }
2470
2471    fn is_file_open_source(
2472        &self,
2473        project: &Entity<Project>,
2474        file: &Arc<dyn File>,
2475        cx: &App,
2476    ) -> bool {
2477        if !file.is_local() || file.is_private() {
2478            return false;
2479        }
2480        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2481            return false;
2482        };
2483        project_state
2484            .license_detection_watchers
2485            .get(&file.worktree_id(cx))
2486            .as_ref()
2487            .is_some_and(|watcher| watcher.is_project_open_source())
2488    }
2489
2490    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2491        self.data_collection_choice.is_enabled(cx)
2492    }
2493
2494    fn load_data_collection_choice() -> DataCollectionChoice {
2495        let choice = KEY_VALUE_STORE
2496            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2497            .log_err()
2498            .flatten();
2499
2500        match choice.as_deref() {
2501            Some("true") => DataCollectionChoice::Enabled,
2502            Some("false") => DataCollectionChoice::Disabled,
2503            Some(_) => {
2504                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2505                DataCollectionChoice::NotAnswered
2506            }
2507            None => DataCollectionChoice::NotAnswered,
2508        }
2509    }
2510
2511    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2512        self.data_collection_choice = self.data_collection_choice.toggle();
2513        let new_choice = self.data_collection_choice;
2514        let is_enabled = new_choice.is_enabled(cx);
2515        db::write_and_log(cx, move || {
2516            KEY_VALUE_STORE.write_kvp(
2517                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2518                is_enabled.to_string(),
2519            )
2520        });
2521    }
2522
2523    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2524        self.shown_predictions.iter()
2525    }
2526
2527    pub fn shown_completions_len(&self) -> usize {
2528        self.shown_predictions.len()
2529    }
2530
2531    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2532        self.rated_predictions.contains(id)
2533    }
2534
2535    pub fn rate_prediction(
2536        &mut self,
2537        prediction: &EditPrediction,
2538        rating: EditPredictionRating,
2539        feedback: String,
2540        cx: &mut Context<Self>,
2541    ) {
2542        let organization = self.user_store.read(cx).current_organization();
2543
2544        self.rated_predictions.insert(prediction.id.clone());
2545
2546        cx.background_spawn({
2547            let client = self.client.clone();
2548            let prediction_id = prediction.id.to_string();
2549            let inputs = serde_json::to_value(&prediction.inputs);
2550            let output = prediction
2551                .edit_preview
2552                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2553            async move {
2554                client
2555                    .cloud_client()
2556                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2557                        organization_id: organization.map(|organization| organization.id.clone()),
2558                        request_id: prediction_id,
2559                        rating: match rating {
2560                            EditPredictionRating::Positive => "positive".to_string(),
2561                            EditPredictionRating::Negative => "negative".to_string(),
2562                        },
2563                        inputs: inputs?,
2564                        output,
2565                        feedback,
2566                    })
2567                    .await?;
2568
2569                anyhow::Ok(())
2570            }
2571        })
2572        .detach_and_log_err(cx);
2573
2574        cx.notify();
2575    }
2576}
2577
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// The merged event's diff spans from the oldest merged event's `old_snapshot`
/// to `end_snapshot`. Its `edit_range` is anchored in `end_snapshot`.
///
/// Nothing is changed when any of the following holds:
/// - the newest event's snapshot has a different `remote_id` than
///   `latest_snapshot`;
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
///
/// Whether two neighbouring events can be merged is decided by
/// `StoredEvent::can_merge`, given `latest_snapshot` and `latest_edit_range`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest event refers to the same buffer as the latest snapshot.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk from newest to oldest, counting how many consecutive events can be merged
    // with their newer neighbour (`next_old_event`).
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single trailing event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2` here, so the range is non-empty and `peek` cannot fail.
    // `peek` does not consume, so the oldest event is also visited by `all` below.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event is "predicted" only if every merged event was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the merged trailing run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2644
2645pub(crate) fn filter_redundant_excerpts(
2646    mut related_files: Vec<RelatedFile>,
2647    cursor_path: &Path,
2648    cursor_row_range: Range<u32>,
2649) -> Vec<RelatedFile> {
2650    for file in &mut related_files {
2651        if file.path.as_ref() == cursor_path {
2652            file.excerpts.retain(|excerpt| {
2653                excerpt.row_range.start < cursor_row_range.start
2654                    || excerpt.row_range.end > cursor_row_range.end
2655            });
2656        }
2657    }
2658    related_files.retain(|file| !file.excerpts.is_empty());
2659    related_files
2660}
2661
/// Error produced by `send_api_request` when a response's
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` header names a version newer than
/// the running app version.
///
/// `handle_api_response` detects this type to flag the store as requiring an
/// update and to show an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: Version,
}
2669
/// Whether edit prediction data collection has been opted into.
///
/// Persisted in the key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`
/// as `"true"` / `"false"`. A missing or unrecognized value loads as
/// `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been stored yet (or the stored value was unrecognized).
    NotAnswered,
    /// Data collection was opted into.
    Enabled,
    /// Data collection was opted out of.
    Disabled,
}
2676
2677impl DataCollectionChoice {
2678    pub fn is_enabled(self, cx: &App) -> bool {
2679        if cx.is_staff() {
2680            return true;
2681        }
2682        match self {
2683            Self::Enabled => true,
2684            Self::NotAnswered | Self::Disabled => false,
2685        }
2686    }
2687
2688    #[must_use]
2689    pub fn toggle(&self) -> DataCollectionChoice {
2690        match self {
2691            Self::Enabled => Self::Disabled,
2692            Self::Disabled => Self::Enabled,
2693            Self::NotAnswered => Self::Enabled,
2694        }
2695    }
2696}
2697
2698impl From<bool> for DataCollectionChoice {
2699    fn from(value: bool) -> Self {
2700        match value {
2701            true => DataCollectionChoice::Enabled,
2702            false => DataCollectionChoice::Disabled,
2703        }
2704    }
2705}
2706
/// Marker type for the dismissable edit prediction upsell. Its `Dismissable`
/// impl defines the storage key and how dismissal is detected.
struct ZedPredictUpsell;
2708
2709impl Dismissable for ZedPredictUpsell {
2710    const KEY: &'static str = "dismissed-edit-predict-upsell";
2711
2712    fn dismissed() -> bool {
2713        // To make this backwards compatible with older versions of Zed, we
2714        // check if the user has seen the previous Edit Prediction Onboarding
2715        // before, by checking the data collection choice which was written to
2716        // the database once the user clicked on "Accept and Enable"
2717        if KEY_VALUE_STORE
2718            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2719            .log_err()
2720            .is_some_and(|s| s.is_some())
2721        {
2722            return true;
2723        }
2724
2725        KEY_VALUE_STORE
2726            .read_kvp(Self::KEY)
2727            .log_err()
2728            .is_some_and(|s| s.is_some())
2729    }
2730}
2731
2732pub fn should_show_upsell_modal() -> bool {
2733    !ZedPredictUpsell::dismissed()
2734}
2735
/// Registers edit prediction actions on every newly created `Workspace`:
/// - opening the Zed Predict onboarding modal;
/// - `ResetOnboarding`, which sets the edit prediction provider to
///   `EditPredictionProvider::None` in the settings file;
/// - Copilot `SignIn`, `Reinstall` and `SignOut`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .edit_predictions
                    .get_or_insert_default()
                    .provider = Some(EditPredictionProvider::None)
            });
        });
        // Calls `start_copilot_for_project` on the global `EditPredictionStore`, if one
        // exists. Returns `None` when there is no global store or it yields no Copilot.
        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
            EditPredictionStore::try_global(cx).and_then(|store| {
                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
            })
        }

        // Each Copilot action is a no-op when no Copilot instance is available.
        workspace.register_action(|workspace, _: &SignIn, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
            }
        });
        workspace.register_action(|workspace, _: &SignOut, window, cx| {
            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
                copilot_ui::initiate_sign_out(copilot, window, cx);
            }
        });
    })
    .detach();
}