edit_prediction.rs

   1use anyhow::Result;
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use db::kvp::{Dismissable, KEY_VALUE_STORE};
  16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::{
  19    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  20    channel::mpsc::{self, UnboundedReceiver},
  21    select_biased,
  22};
  23use gpui::BackgroundExecutor;
  24use gpui::http_client::Url;
  25use gpui::{
  26    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
  27    http_client::{self, AsyncBody, Method},
  28    prelude::*,
  29};
  30use language::language_settings::all_language_settings;
  31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  32use language::{BufferSnapshot, OffsetRangeExt};
  33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
  34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  35use release_channel::AppVersion;
  36use semver::Version;
  37use serde::de::DeserializeOwned;
  38use settings::{
  39    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  40};
  41use std::collections::{VecDeque, hash_map};
  42use std::env;
  43use text::{AnchorRangeExt, Edit};
  44use workspace::Workspace;
  45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  46
  47use std::mem;
  48use std::ops::Range;
  49use std::path::Path;
  50use std::rc::Rc;
  51use std::str::FromStr as _;
  52use std::sync::Arc;
  53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
  54use thiserror::Error;
  55use util::{RangeExt as _, ResultExt as _};
  56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  57
  58pub mod cursor_excerpt;
  59pub mod example_spec;
  60pub mod fim;
  61mod license_detection;
  62pub mod mercury;
  63pub mod ollama;
  64mod onboarding_modal;
  65pub mod open_ai_response;
  66mod prediction;
  67pub mod sweep_ai;
  68
  69pub mod udiff;
  70
  71mod capture_example;
  72pub mod open_ai_compatible;
  73mod zed_edit_prediction_delegate;
  74pub mod zeta;
  75
  76#[cfg(test)]
  77mod edit_prediction_tests;
  78
  79use crate::license_detection::LicenseDetectionWatcher;
  80use crate::mercury::Mercury;
  81use crate::onboarding_modal::ZedPredictModal;
  82pub use crate::prediction::EditPrediction;
  83pub use crate::prediction::EditPredictionId;
  84use crate::prediction::EditPredictionResult;
  85pub use crate::sweep_ai::SweepAi;
  86pub use capture_example::capture_example;
  87pub use language_model::ApiKeyState;
  88pub use telemetry_events::EditPredictionRating;
  89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  90
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance used when deciding whether edits are close enough to be grouped; see
/// `StoredEvent::can_merge`, which refuses to merge events farther apart than this.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window used when grouping consecutive changes.
/// NOTE(review): usage is not visible in this chunk — presumably edits within this interval of
/// the previous one extend the same event; confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key for the persisted data-collection choice (presumably in the key-value store; confirm).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for rejection reporting.
/// NOTE(review): usage is not visible in this chunk; confirm in `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
/// Event name used when a prediction is reported as settled (presumably a telemetry event).
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Five-minute lifetime for settled-prediction tracking (usage not visible here).
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
/// Ten-second quiet period for settled-prediction tracking (usage not visible here).
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 110
/// Feature flag named `edit_prediction_jumps`.
/// NOTE(review): presumably gates jump predictions (`BufferEditPrediction::Jump`); confirm.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 116
/// App-global wrapper holding the single `EditPredictionStore` entity.
/// It is installed by `EditPredictionStore::global` and read back by `try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 121
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::zeta2_raw_config_from_env` builds one from the `ZED_ZETA_FORMAT`
/// (required), `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` environment variables.
///
/// NOTE(review): this was previously documented as "the version is also used as the Baseten
/// environment name (lowercased)". Now that `environment` is a separate field, confirm whether
/// that fallback still applies.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier override (`ZED_ZETA_MODEL`), if any.
    pub model_id: Option<String>,
    /// Environment override (`ZED_ZETA_ENVIRONMENT`), if any.
    pub environment: Option<String>,
    /// Prompt format parsed from `ZED_ZETA_FORMAT`.
    pub format: ZetaFormat,
}
 131
/// Central edit-prediction state.
///
/// Holds per-project tracking, the selected model, the LLM token, and the channels that feed
/// the background rejection-reporting and settled-prediction workers started in `new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Waits for a signed-in user, then calls `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state (presumably keyed by the project's entity id; confirm).
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    /// Backend used for predictions; `new` initializes it to `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; `new` seeds it from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    preferred_experiment: Option<String>,
    /// Experiment names fetched from the `/edit_prediction_experiments` endpoint.
    available_experiments: Vec<String>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task running `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the task running `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook (presumably invoked when a settled event is emitted; confirm).
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
}
 154
/// A rejection queued on `reject_predictions_tx` for the background reporting task, together
/// with the organization id it belongs to, if any.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 159
/// The backend used to produce edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    /// Zed's Zeta model; the store's initial selection.
    Zeta,
    /// FIM-style model using `format` for its prompt. Its icon depends on whether the
    /// configured provider is Ollama or an OpenAI-compatible endpoint.
    Fim { format: EditPredictionPromptFormat },
    /// Sweep AI.
    Sweep,
    /// Mercury (shown with the Inception icon).
    Mercury,
}
 167
/// Inputs collected for one prediction request and handed to the selected model.
/// NOTE(review): construction is not visible in this chunk; field notes below are hedged.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` (presumably taken when the request was made).
    snapshot: BufferSnapshot,
    /// Position in `buffer` the prediction is requested for (presumably the cursor).
    position: Anchor,
    /// Recent edit-history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Row range searched for diagnostics (presumably sized by `DIAGNOSTIC_LINES_RANGE`).
    diagnostic_search_range: Range<Point>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
    pub user_actions: Vec<UserActionRecord>,
}
 184
/// Diagnostic events optionally emitted over a `debug_tx` channel.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when context retrieval ends, with labelled metadata values.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts; `prompt` is present when one was built.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes; `model_output` is present when available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 220
/// Capacity of the per-project user-action history; `ProjectState::record_user_action`
/// evicts the oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    /// Timestamp in milliseconds (presumably since the Unix epoch, per the name; confirm).
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action captured in a `UserActionRecord`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
 240
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event (e.g. a `BufferChange` with its diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Changed range, anchored in the post-change snapshot (see `LastEvent::finalize`).
    pub edit_range: Range<Anchor>,
}
 248
 249impl StoredEvent {
 250    fn can_merge(
 251        &self,
 252        next_old_event: &&&StoredEvent,
 253        new_snapshot: &TextBufferSnapshot,
 254        last_edit_range: &Range<Anchor>,
 255    ) -> bool {
 256        // Events must be for the same buffer
 257        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 258            return false;
 259        }
 260        if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
 261            return false;
 262        }
 263
 264        let a_is_predicted = matches!(
 265            self.event.as_ref(),
 266            zeta_prompt::Event::BufferChange {
 267                predicted: true,
 268                ..
 269            }
 270        );
 271        let b_is_predicted = matches!(
 272            next_old_event.event.as_ref(),
 273            zeta_prompt::Event::BufferChange {
 274                predicted: true,
 275                ..
 276            }
 277        );
 278
 279        // If events come from the same source (both predicted or both manual) then
 280        // we would have coalesced them already.
 281        if a_is_predicted == b_is_predicted {
 282            return false;
 283        }
 284
 285        let left_range = self.edit_range.to_point(new_snapshot);
 286        let right_range = next_old_event.edit_range.to_point(new_snapshot);
 287        let latest_range = last_edit_range.to_point(&new_snapshot);
 288
 289        // Events near to the latest edit are not merged if their sources differ.
 290        if lines_between_ranges(&left_range, &latest_range)
 291            .min(lines_between_ranges(&right_range, &latest_range))
 292            <= CHANGE_GROUPING_LINE_SPAN
 293        {
 294            return false;
 295        }
 296
 297        // Events that are distant from each other are not merged.
 298        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 299            return false;
 300        }
 301
 302        true
 303    }
 304}
 305
 306fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 307    if left.start > right.end {
 308        return left.start.row - right.end.row;
 309    }
 310    if right.start > left.end {
 311        return right.start.row - left.end.row;
 312    }
 313    0
 314}
 315
/// Edit-prediction state tracked for a single project.
struct ProjectState {
    /// Finalized edit-history events.
    events: VecDeque<StoredEvent>,
    /// The in-progress edit, finalized lazily by `ProjectState::events`.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers registered for tracking, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two, by `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers consulted when finalizing events.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (`USER_ACTION_HISTORY_SIZE` entries).
    user_actions: VecDeque<UserActionRecord>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 334
 335impl ProjectState {
 336    fn record_user_action(&mut self, action: UserActionRecord) {
 337        if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
 338            self.user_actions.pop_front();
 339        }
 340        self.user_actions.push_back(action);
 341    }
 342
 343    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 344        self.events
 345            .iter()
 346            .cloned()
 347            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 348                let (one, two) = event.split_by_pause();
 349                let one = one.finalize(&self.license_detection_watchers, cx);
 350                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 351                one.into_iter().chain(two)
 352            }))
 353            .collect()
 354    }
 355
 356    fn cancel_pending_prediction(
 357        &mut self,
 358        pending_prediction: PendingPrediction,
 359        cx: &mut Context<EditPredictionStore>,
 360    ) {
 361        self.cancelled_predictions.insert(pending_prediction.id);
 362
 363        if pending_prediction.drop_on_cancel {
 364            drop(pending_prediction.task);
 365        } else {
 366            cx.spawn(async move |this, cx| {
 367                let Some(prediction_id) = pending_prediction.task.await else {
 368                    return;
 369                };
 370
 371                this.update(cx, |this, cx| {
 372                    this.reject_prediction(
 373                        prediction_id,
 374                        EditPredictionRejectReason::Canceled,
 375                        false,
 376                        None,
 377                        cx,
 378                    );
 379                })
 380                .ok();
 381            })
 382            .detach()
 383        }
 384    }
 385
 386    fn active_buffer(
 387        &self,
 388        project: &Entity<Project>,
 389        cx: &App,
 390    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 391        let project = project.read(cx);
 392        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 393        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 394        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 395        Some((active_buffer, registered_buffer.last_position))
 396    }
 397}
 398
/// The prediction currently offered for a project, plus how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
 406
 407impl CurrentEditPrediction {
 408    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 409        let Some(new_edits) = self
 410            .prediction
 411            .interpolate(&self.prediction.buffer.read(cx))
 412        else {
 413            return false;
 414        };
 415
 416        if self.prediction.buffer != old_prediction.prediction.buffer {
 417            return true;
 418        }
 419
 420        let Some(old_edits) = old_prediction
 421            .prediction
 422            .interpolate(&old_prediction.prediction.buffer.read(cx))
 423        else {
 424            return true;
 425        };
 426
 427        let requested_by_buffer_id = self.requested_by.buffer_id();
 428
 429        // This reduces the occurrence of UI thrash from replacing edits
 430        //
 431        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 432        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 433            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 434            && old_edits.len() == 1
 435            && new_edits.len() == 1
 436        {
 437            let (old_range, old_text) = &old_edits[0];
 438            let (new_range, new_text) = &new_edits[0];
 439            new_range == old_range && new_text.starts_with(old_text.as_ref())
 440        } else {
 441            true
 442        }
 443    }
 444}
 445
 446#[derive(Debug, Clone)]
 447enum PredictionRequestedBy {
 448    DiagnosticsUpdate,
 449    Buffer(EntityId),
 450}
 451
 452impl PredictionRequestedBy {
 453    pub fn buffer_id(&self) -> Option<EntityId> {
 454        match self {
 455            PredictionRequestedBy::DiagnosticsUpdate => None,
 456            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 457        }
 458    }
 459}
 460
/// Row span used around a position when searching for diagnostics.
/// NOTE(review): usage is not visible in this chunk; confirm the exact meaning.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Whether a diagnostic search stays local or covers the whole project
/// (presumably; usage not visible here).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}

/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the prediction id, if the request produced one.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction applies to this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction points elsewhere (presumably another buffer or location; confirm).
    Jump { prediction: &'a EditPrediction },
}
 484
 485#[cfg(test)]
 486impl std::ops::Deref for BufferEditPrediction<'_> {
 487    type Target = EditPrediction;
 488
 489    fn deref(&self) -> &Self::Target {
 490        match self {
 491            BufferEditPrediction::Local { prediction } => prediction,
 492            BufferEditPrediction::Jump { prediction } => prediction,
 493        }
 494    }
 495}
 496
/// A prediction awaiting its "settled" report.
/// NOTE(review): lifecycle handling is not visible in this chunk; field notes are hedged.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region the prediction could edit, as anchors.
    editable_anchor_range: Range<Anchor>,
    /// When it was enqueued (presumably compared against `EDIT_PREDICTION_SETTLED_TTL`).
    enqueued_at: Instant,
    /// Time of the most recent edit seen (presumably compared against
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE`).
    last_edit_at: Instant,
}

/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position in this buffer; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// The in-progress (not yet finalized) edit to a buffer.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the edit began.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the most recent change.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Whether this change came from an accepted prediction.
    predicted: bool,
    /// Boundary snapshot at the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 524
 525impl LastEvent {
 526    pub fn finalize(
 527        &self,
 528        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 529        cx: &App,
 530    ) -> Option<StoredEvent> {
 531        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 532        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 533
 534        let in_open_source_repo =
 535            [self.new_file.as_ref(), self.old_file.as_ref()]
 536                .iter()
 537                .all(|file| {
 538                    file.is_some_and(|file| {
 539                        license_detection_watchers
 540                            .get(&file.worktree_id(cx))
 541                            .is_some_and(|watcher| watcher.is_project_open_source())
 542                    })
 543                });
 544
 545        let (diff, edit_range) =
 546            compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
 547
 548        if path == old_path && diff.is_empty() {
 549            None
 550        } else {
 551            Some(StoredEvent {
 552                event: Arc::new(zeta_prompt::Event::BufferChange {
 553                    old_path,
 554                    path,
 555                    diff,
 556                    in_open_source_repo,
 557                    predicted: self.predicted,
 558                }),
 559                edit_range: self.new_snapshot.anchor_before(edit_range.start)
 560                    ..self.new_snapshot.anchor_before(edit_range.end),
 561                old_snapshot: self.old_snapshot.clone(),
 562            })
 563        }
 564    }
 565
 566    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 567        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 568            return (self.clone(), None);
 569        };
 570
 571        let before = LastEvent {
 572            old_snapshot: self.old_snapshot.clone(),
 573            new_snapshot: boundary_snapshot.clone(),
 574            old_file: self.old_file.clone(),
 575            new_file: self.new_file.clone(),
 576            edit_range: None,
 577            predicted: self.predicted,
 578            snapshot_after_last_editing_pause: None,
 579            last_edit_time: self.last_edit_time,
 580        };
 581
 582        let after = LastEvent {
 583            old_snapshot: boundary_snapshot.clone(),
 584            new_snapshot: self.new_snapshot.clone(),
 585            old_file: self.old_file.clone(),
 586            new_file: self.new_file.clone(),
 587            edit_range: None,
 588            predicted: self.predicted,
 589            snapshot_after_last_editing_pause: None,
 590            last_edit_time: self.last_edit_time,
 591        };
 592
 593        (before, Some(after))
 594    }
 595}
 596
 597pub(crate) fn compute_diff_between_snapshots(
 598    old_snapshot: &TextBufferSnapshot,
 599    new_snapshot: &TextBufferSnapshot,
 600) -> Option<(String, Range<Point>)> {
 601    let edits: Vec<Edit<usize>> = new_snapshot
 602        .edits_since::<usize>(&old_snapshot.version)
 603        .collect();
 604
 605    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 606
 607    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
 608    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
 609    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 610    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 611
 612    const CONTEXT_LINES: u32 = 3;
 613
 614    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
 615    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
 616    let old_context_end_row =
 617        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
 618    let new_context_end_row =
 619        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
 620
 621    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
 622    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
 623    let old_end_line_offset = old_snapshot
 624        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
 625    let new_end_line_offset = new_snapshot
 626        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
 627    let old_edit_range = old_start_line_offset..old_end_line_offset;
 628    let new_edit_range = new_start_line_offset..new_end_line_offset;
 629
 630    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
 631    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
 632
 633    let diff = language::unified_diff_with_offsets(
 634        &old_region_text,
 635        &new_region_text,
 636        old_context_start_row,
 637        new_context_start_row,
 638    );
 639
 640    Some((diff, new_start_point..new_end_point))
 641}
 642
 643fn buffer_path_with_id_fallback(
 644    file: Option<&Arc<dyn File>>,
 645    snapshot: &TextBufferSnapshot,
 646    cx: &App,
 647) -> Arc<Path> {
 648    if let Some(file) = file {
 649        file.full_path(cx).into()
 650    } else {
 651        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 652    }
 653}
 654
 655impl EditPredictionStore {
 656    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 657        cx.try_global::<EditPredictionStoreGlobal>()
 658            .map(|global| global.0.clone())
 659    }
 660
 661    pub fn global(
 662        client: &Arc<Client>,
 663        user_store: &Entity<UserStore>,
 664        cx: &mut App,
 665    ) -> Entity<Self> {
 666        cx.try_global::<EditPredictionStoreGlobal>()
 667            .map(|global| global.0.clone())
 668            .unwrap_or_else(|| {
 669                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 670                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 671                ep_store
 672            })
 673    }
 674
 675    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 676        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 677        let data_collection_choice = Self::load_data_collection_choice();
 678
 679        let llm_token = LlmApiToken::default();
 680
 681        let (reject_tx, reject_rx) = mpsc::unbounded();
 682        cx.background_spawn({
 683            let client = client.clone();
 684            let llm_token = llm_token.clone();
 685            let app_version = AppVersion::global(cx);
 686            let background_executor = cx.background_executor().clone();
 687            async move {
 688                Self::handle_rejected_predictions(
 689                    reject_rx,
 690                    client,
 691                    llm_token,
 692                    app_version,
 693                    background_executor,
 694                )
 695                .await
 696            }
 697        })
 698        .detach();
 699
 700        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 701        cx.spawn(async move |this, cx| {
 702            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 703        })
 704        .detach();
 705
 706        let mut current_user = user_store.read(cx).watch_current_user();
 707        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 708            while current_user.borrow().is_none() {
 709                current_user.next().await;
 710            }
 711            this.update(cx, |this, cx| {
 712                this.refresh_available_experiments(cx);
 713            })
 714            .log_err();
 715        });
 716
 717        let this = Self {
 718            projects: HashMap::default(),
 719            client,
 720            user_store,
 721            llm_token,
 722            _fetch_experiments_task: fetch_experiments_task,
 723            _llm_token_subscription: cx.subscribe(
 724                &refresh_llm_token_listener,
 725                |this, _listener, _event, cx| {
 726                    let client = this.client.clone();
 727                    let llm_token = this.llm_token.clone();
 728                    let organization_id = this
 729                        .user_store
 730                        .read(cx)
 731                        .current_organization()
 732                        .map(|organization| organization.id.clone());
 733                    cx.spawn(async move |_this, _cx| {
 734                        llm_token.refresh(&client, organization_id).await?;
 735                        anyhow::Ok(())
 736                    })
 737                    .detach_and_log_err(cx);
 738                },
 739            ),
 740            update_required: false,
 741            edit_prediction_model: EditPredictionModel::Zeta,
 742            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 743            preferred_experiment: None,
 744            available_experiments: Vec::new(),
 745            sweep_ai: SweepAi::new(cx),
 746            mercury: Mercury::new(cx),
 747
 748            data_collection_choice,
 749            reject_predictions_tx: reject_tx,
 750            settled_predictions_tx,
 751            rated_predictions: Default::default(),
 752            shown_predictions: Default::default(),
 753            #[cfg(test)]
 754            settled_event_callback: None,
 755        };
 756
 757        this
 758    }
 759
 760    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 761        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 762        let format = ZetaFormat::parse(&version_str).ok()?;
 763        let model_id = env::var("ZED_ZETA_MODEL").ok();
 764        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 765        Some(Zeta2RawConfig {
 766            model_id,
 767            environment,
 768            format,
 769        })
 770    }
 771
    /// Selects the backend used for subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Installs a raw Zeta2 configuration, replacing any existing one.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }

    /// Returns the preferred experiment name, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }

    /// Sets (or clears, with `None`) the preferred experiment name.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }

    /// Experiment names last fetched by `refresh_available_experiments`.
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
 795
 796    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 797        let client = self.client.clone();
 798        let llm_token = self.llm_token.clone();
 799        let app_version = AppVersion::global(cx);
 800        let organization_id = self
 801            .user_store
 802            .read(cx)
 803            .current_organization()
 804            .map(|organization| organization.id.clone());
 805
 806        cx.spawn(async move |this, cx| {
 807            let experiments = cx
 808                .background_spawn(async move {
 809                    let http_client = client.http_client();
 810                    let token = llm_token.acquire(&client, organization_id).await?;
 811                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 812                    let request = http_client::Request::builder()
 813                        .method(Method::GET)
 814                        .uri(url.as_ref())
 815                        .header("Authorization", format!("Bearer {}", token))
 816                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 817                        .body(Default::default())?;
 818                    let mut response = http_client.send(request).await?;
 819                    if response.status().is_success() {
 820                        let mut body = Vec::new();
 821                        response.body_mut().read_to_end(&mut body).await?;
 822                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 823                        Ok(experiments)
 824                    } else {
 825                        let mut body = String::new();
 826                        response.body_mut().read_to_string(&mut body).await?;
 827                        anyhow::bail!(
 828                            "Failed to fetch experiments: {:?}\nBody: {}",
 829                            response.status(),
 830                            body
 831                        );
 832                    }
 833                })
 834                .await?;
 835            this.update(cx, |this, cx| {
 836                this.available_experiments = experiments;
 837                cx.notify();
 838            })?;
 839            anyhow::Ok(())
 840        })
 841        .detach_and_log_err(cx);
 842    }
 843
 844    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 845        use ui::IconName;
 846        match self.edit_prediction_model {
 847            EditPredictionModel::Sweep => {
 848                edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
 849                    .with_disabled(IconName::SweepAiDisabled)
 850                    .with_up(IconName::SweepAiUp)
 851                    .with_down(IconName::SweepAiDown)
 852                    .with_error(IconName::SweepAiError)
 853            }
 854            EditPredictionModel::Mercury => {
 855                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 856            }
 857            EditPredictionModel::Zeta => {
 858                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 859                    .with_disabled(IconName::ZedPredictDisabled)
 860                    .with_up(IconName::ZedPredictUp)
 861                    .with_down(IconName::ZedPredictDown)
 862                    .with_error(IconName::ZedPredictError)
 863            }
 864            EditPredictionModel::Fim { .. } => {
 865                let settings = &all_language_settings(None, cx).edit_predictions;
 866                match settings.provider {
 867                    EditPredictionProvider::Ollama => {
 868                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 869                    }
 870                    _ => {
 871                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 872                    }
 873                }
 874            }
 875        }
 876    }
 877
 878    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
 879        self.sweep_ai.api_token.read(cx).has_key()
 880    }
 881
 882    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 883        self.mercury.api_token.read(cx).has_key()
 884    }
 885
 886    pub fn clear_history(&mut self) {
 887        for project_state in self.projects.values_mut() {
 888            project_state.events.clear();
 889            project_state.last_event.take();
 890        }
 891    }
 892
 893    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 894        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 895            project_state.events.clear();
 896            project_state.last_event.take();
 897        }
 898    }
 899
 900    pub fn edit_history_for_project(
 901        &self,
 902        project: &Entity<Project>,
 903        cx: &App,
 904    ) -> Vec<StoredEvent> {
 905        self.projects
 906            .get(&project.entity_id())
 907            .map(|project_state| project_state.events(cx))
 908            .unwrap_or_default()
 909    }
 910
 911    pub fn context_for_project<'a>(
 912        &'a self,
 913        project: &Entity<Project>,
 914        cx: &'a mut App,
 915    ) -> Vec<RelatedFile> {
 916        self.projects
 917            .get(&project.entity_id())
 918            .map(|project_state| {
 919                project_state.context.update(cx, |context, cx| {
 920                    context
 921                        .related_files_with_buffers(cx)
 922                        .map(|(mut related_file, buffer)| {
 923                            related_file.in_open_source_repo = buffer
 924                                .read(cx)
 925                                .file()
 926                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 927                            related_file
 928                        })
 929                        .collect()
 930                })
 931            })
 932            .unwrap_or_default()
 933    }
 934
 935    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
 936        self.projects
 937            .get(&project.entity_id())
 938            .and_then(|project| project.copilot.clone())
 939    }
 940
 941    pub fn start_copilot_for_project(
 942        &mut self,
 943        project: &Entity<Project>,
 944        cx: &mut Context<Self>,
 945    ) -> Option<Entity<Copilot>> {
 946        if DisableAiSettings::get(None, cx).disable_ai {
 947            return None;
 948        }
 949        let state = self.get_or_init_project(project, cx);
 950
 951        if state.copilot.is_some() {
 952            return state.copilot.clone();
 953        }
 954        let _project = project.clone();
 955        let project = project.read(cx);
 956
 957        let node = project.node_runtime().cloned();
 958        if let Some(node) = node {
 959            let next_id = project.languages().next_language_server_id();
 960            let fs = project.fs().clone();
 961
 962            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
 963            state.copilot = Some(copilot.clone());
 964            Some(copilot)
 965        } else {
 966            None
 967        }
 968    }
 969
 970    pub fn context_for_project_with_buffers<'a>(
 971        &'a self,
 972        project: &Entity<Project>,
 973        cx: &'a mut App,
 974    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
 975        self.projects
 976            .get(&project.entity_id())
 977            .map(|project| {
 978                project.context.update(cx, |context, cx| {
 979                    context.related_files_with_buffers(cx).collect()
 980                })
 981            })
 982            .unwrap_or_default()
 983    }
 984
 985    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 986        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
 987            self.user_store.read(cx).edit_prediction_usage()
 988        } else {
 989            None
 990        }
 991    }
 992
 993    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 994        self.get_or_init_project(project, cx);
 995    }
 996
 997    pub fn register_buffer(
 998        &mut self,
 999        buffer: &Entity<Buffer>,
1000        project: &Entity<Project>,
1001        cx: &mut Context<Self>,
1002    ) {
1003        let project_state = self.get_or_init_project(project, cx);
1004        Self::register_buffer_impl(project_state, buffer, project, cx);
1005    }
1006
    /// Returns the state tracked for `project`, creating it on first access.
    ///
    /// A newly created state owns a `RelatedExcerptStore` for the project, whose events are
    /// forwarded to `handle_excerpt_store_event`. It starts with empty history and no
    /// predictions, subscribes to the project's events via `handle_project_event`, and is
    /// removed from `self.projects` when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Route the excerpt store's refresh events back here, keyed by project.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity itself is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1047
1048    pub fn remove_project(&mut self, project: &Entity<Project>) {
1049        self.projects.remove(&project.entity_id());
1050    }
1051
1052    fn handle_excerpt_store_event(
1053        &mut self,
1054        project_entity_id: EntityId,
1055        event: &RelatedExcerptStoreEvent,
1056    ) {
1057        if let Some(project_state) = self.projects.get(&project_entity_id) {
1058            if let Some(debug_tx) = project_state.debug_tx.clone() {
1059                match event {
1060                    RelatedExcerptStoreEvent::StartedRefresh => {
1061                        debug_tx
1062                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1063                                ContextRetrievalStartedDebugEvent {
1064                                    project_entity_id: project_entity_id,
1065                                    timestamp: Instant::now(),
1066                                    search_prompt: String::new(),
1067                                },
1068                            ))
1069                            .ok();
1070                    }
1071                    RelatedExcerptStoreEvent::FinishedRefresh {
1072                        cache_hit_count,
1073                        cache_miss_count,
1074                        mean_definition_latency,
1075                        max_definition_latency,
1076                    } => {
1077                        debug_tx
1078                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1079                                ContextRetrievalFinishedDebugEvent {
1080                                    project_entity_id: project_entity_id,
1081                                    timestamp: Instant::now(),
1082                                    metadata: vec![
1083                                        (
1084                                            "Cache Hits",
1085                                            format!(
1086                                                "{}/{}",
1087                                                cache_hit_count,
1088                                                cache_hit_count + cache_miss_count
1089                                            )
1090                                            .into(),
1091                                        ),
1092                                        (
1093                                            "Max LSP Time",
1094                                            format!("{} ms", max_definition_latency.as_millis())
1095                                                .into(),
1096                                        ),
1097                                        (
1098                                            "Mean LSP Time",
1099                                            format!("{} ms", mean_definition_latency.as_millis())
1100                                                .into(),
1101                                        ),
1102                                    ],
1103                                },
1104                            ))
1105                            .ok();
1106                    }
1107                }
1108            }
1109        }
1110    }
1111
1112    pub fn debug_info(
1113        &mut self,
1114        project: &Entity<Project>,
1115        cx: &mut Context<Self>,
1116    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1117        let project_state = self.get_or_init_project(project, cx);
1118        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1119        project_state.debug_tx = Some(debug_watch_tx);
1120        debug_watch_rx
1121    }
1122
1123    fn handle_project_event(
1124        &mut self,
1125        project: Entity<Project>,
1126        event: &project::Event,
1127        cx: &mut Context<Self>,
1128    ) {
1129        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1130            return;
1131        }
1132        // TODO [zeta2] init with recent paths
1133        match event {
1134            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1135                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1136                    return;
1137                };
1138                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1139                if let Some(path) = path {
1140                    if let Some(ix) = project_state
1141                        .recent_paths
1142                        .iter()
1143                        .position(|probe| probe == &path)
1144                    {
1145                        project_state.recent_paths.remove(ix);
1146                    }
1147                    project_state.recent_paths.push_front(path);
1148                }
1149            }
1150            project::Event::DiagnosticsUpdated { .. } => {
1151                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1152                    self.refresh_prediction_from_diagnostics(
1153                        project,
1154                        DiagnosticSearchScope::Global,
1155                        cx,
1156                    );
1157                }
1158            }
1159            _ => (),
1160        }
1161    }
1162
    /// Ensures `buffer` is registered in `project_state` and returns its [`RegisteredBuffer`].
    ///
    /// Side effects:
    /// - if the buffer's file belongs to a worktree of `project`, a `LicenseDetectionWatcher`
    ///   is created for that worktree the first time it is seen, and removed again when the
    ///   worktree entity is released;
    /// - a newly registered buffer records its current text snapshot and file, calls
    ///   `report_changes_for_buffer` on every `BufferEvent::Edited`, and removes its own entry
    ///   from `registered_buffers` when the buffer entity is released.
    ///
    /// An already-registered buffer is returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily start license detection for the buffer's worktree (one watcher per worktree).
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once the worktree entity is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            // Already registered: return the existing entry as-is.
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Only a weak handle to the project is kept; edits are ignored
                            // once it can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Unregister the buffer once its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1229
    /// Records the edits made to `buffer` since its last recorded snapshot.
    ///
    /// `is_predicted` marks whether the edit came from accepting a prediction. The method:
    /// - returns early if the buffer version has not changed or there are no edits;
    /// - bumps `last_edit_at` of pending settled predictions whose editable range overlaps
    ///   the edit;
    /// - records a `UserActionRecord` classifying the edit as a char/selection insert/delete;
    /// - coalesces the edit into `last_event` when it directly continues that event in the same
    ///   buffer, has the same `is_predicted` source, and lies within `CHANGE_GROUPING_LINE_SPAN`
    ///   lines of the previous edit;
    /// - otherwise finalizes `last_event` into the bounded `events` history and starts a new
    ///   in-progress event for this edit.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;
        let now = cx.background_executor().now();

        // Summarize every edit since the previous snapshot: counts, the combined anchor range
        // (first edit's start to the latest edit's end; assumes edits are yielded in buffer
        // order — TODO confirm), and the end offset of the last edit.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // A pending settled prediction touched by this edit restarts its quiet period.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Exactly one unit inserted/deleted per edit counts as a single-char action; anything
        // larger counts as a selection-sized action.
        // NOTE(review): the counts are `usize` range lengths (presumably bytes), so a single
        // multi-byte character may be classified as a selection — confirm.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        // Record the action at the position where the last edit ended.
        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        // Borrow the history separately so `last_event` can be mutated alongside it.
        let events = &mut project_state.events;

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the snapshot
                // as it was just before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: close out the previous event and append it to the bounded history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Keeps at most EVENT_COUNT_MAX - 1 finalized events (presumably leaving room
                // for `last_event` — confirm).
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // NOTE(review): helper defined elsewhere in this file; it adjusts the trailing finalized
        // events given this edit — confirm its exact merging rule there.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a fresh in-progress event for this edit.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1363
1364    fn prediction_at(
1365        &mut self,
1366        buffer: &Entity<Buffer>,
1367        position: Option<language::Anchor>,
1368        project: &Entity<Project>,
1369        cx: &App,
1370    ) -> Option<BufferEditPrediction<'_>> {
1371        let project_state = self.projects.get_mut(&project.entity_id())?;
1372        if let Some(position) = position
1373            && let Some(buffer) = project_state
1374                .registered_buffers
1375                .get_mut(&buffer.entity_id())
1376        {
1377            buffer.last_position = Some(position);
1378        }
1379
1380        let CurrentEditPrediction {
1381            requested_by,
1382            prediction,
1383            ..
1384        } = project_state.current_prediction.as_ref()?;
1385
1386        if prediction.targets_buffer(buffer.read(cx)) {
1387            Some(BufferEditPrediction::Local { prediction })
1388        } else {
1389            let show_jump = match requested_by {
1390                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1391                    requested_by_buffer_id == &buffer.entity_id()
1392                }
1393                PredictionRequestedBy::DiagnosticsUpdate => true,
1394            };
1395
1396            if show_jump {
1397                Some(BufferEditPrediction::Jump { prediction })
1398            } else {
1399                None
1400            }
1401        }
1402    }
1403
1404    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1405        let Some(current_prediction) = self
1406            .projects
1407            .get_mut(&project.entity_id())
1408            .and_then(|project_state| project_state.current_prediction.take())
1409        else {
1410            return;
1411        };
1412
1413        self.report_changes_for_buffer(&current_prediction.prediction.buffer, project, true, cx);
1414
1415        // can't hold &mut project_state ref across report_changes_for_buffer_call
1416        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1417            return;
1418        };
1419
1420        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1421            project_state.cancel_pending_prediction(pending_prediction, cx);
1422        }
1423
1424        match self.edit_prediction_model {
1425            EditPredictionModel::Sweep => {
1426                sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1427            }
1428            EditPredictionModel::Mercury => {
1429                mercury::edit_prediction_accepted(
1430                    current_prediction.prediction.id,
1431                    self.client.http_client(),
1432                    cx,
1433                );
1434            }
1435            EditPredictionModel::Zeta => {
1436                let is_cloud = !matches!(
1437                    all_language_settings(None, cx).edit_predictions.provider,
1438                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1439                );
1440                if is_cloud {
1441                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1442                }
1443            }
1444            EditPredictionModel::Fim { .. } => {}
1445        }
1446    }
1447
    /// Background loop that batches rejected predictions and posts them to
    /// `/predict_edits/reject`.
    ///
    /// While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections are
    /// batched, it waits up to `REJECT_REQUEST_DEBOUNCE` for another rejection; if one arrives
    /// first, it is added to the batch and the wait starts over. Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most recent rejections, sent with the
    /// organization id of the most recently received payload. Sent rejections are removed from
    /// the batch only when the request succeeds, so failed ones stay queued and are retried on
    /// the next flush (which happens after the next rejection arrives). The loop ends when the
    /// channel closes.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // Peekable so we can check for a queued rejection without consuming it.
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Debounce small batches: keep collecting if another rejection shows up before the
            // timer fires (a closed channel also resolves the peek, with `None`, and flushes).
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // Panics if the LLM URL cannot be built.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // On failure the error is logged and the rejections are kept for a later retry.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1511
    /// Background loop that emits `EDIT_PREDICTION_SETTLED_EVENT` telemetry for pending
    /// predictions.
    ///
    /// `rx` receives the enqueue time of each new pending prediction. For every pending
    /// prediction across all projects and registered buffers:
    /// - once it is older than `EDIT_PREDICTION_SETTLED_TTL` it is dropped without telemetry;
    /// - once `EDIT_PREDICTION_SETTLED_QUIESCENCE` has passed since its `last_edit_at`, the
    ///   text of its editable range in the buffer's last recorded snapshot is reported and the
    ///   prediction is dropped.
    ///
    /// Between passes the loop sleeps until the earliest remaining prediction could settle, or
    /// waits for a new enqueue when none remain. It exits when the channel closes or the store
    /// entity has been dropped.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: block until something is enqueued, then schedule a pass
                // one quiescence period after that enqueue time.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Discard any other already-queued notifications; a single pass covers them.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();

            // Earliest `last_edit_at` among predictions that are still pending after this pass.
            let mut oldest_edited_at = None;

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        registered_buffer
                            .pending_predictions
                            .retain_mut(|pending_prediction| {
                                let age =
                                    now.saturating_duration_since(pending_prediction.enqueued_at);
                                if age >= EDIT_PREDICTION_SETTLED_TTL {
                                    return false;
                                }

                                let quiet_for =
                                    now.saturating_duration_since(pending_prediction.last_edit_at);
                                if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                    let settled_editable_region = registered_buffer
                                        .snapshot
                                        .text_for_range(
                                            pending_prediction.editable_anchor_range.clone(),
                                        )
                                        .collect::<String>();

                                    #[cfg(test)]
                                    if let Some(callback) = &this.settled_event_callback {
                                        callback(
                                            pending_prediction.request_id.clone(),
                                            settled_editable_region.clone(),
                                        );
                                    }

                                    telemetry::event!(
                                        EDIT_PREDICTION_SETTLED_EVENT,
                                        request_id = pending_prediction.request_id.0.clone(),
                                        settled_editable_region,
                                    );

                                    return false;
                                }

                                if oldest_edited_at
                                    .is_none_or(|t| pending_prediction.last_edit_at < t)
                                {
                                    oldest_edited_at = Some(pending_prediction.last_edit_at);
                                }

                                true
                            });
                    }
                }
            });

            // Wake when the earliest remaining prediction could settle; `None` means wait for
            // the next enqueue.
            next_wake_time = oldest_edited_at.map(|t| t + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1595
1596    pub(crate) fn enqueue_settled_prediction(
1597        &mut self,
1598        request_id: EditPredictionId,
1599        project: &Entity<Project>,
1600        edited_buffer: &Entity<Buffer>,
1601        edited_buffer_snapshot: &BufferSnapshot,
1602        editable_offset_range: Range<usize>,
1603        cx: &mut Context<Self>,
1604    ) {
1605        let project_state = self.get_or_init_project(project, cx);
1606        if let Some(buffer) = project_state
1607            .registered_buffers
1608            .get_mut(&edited_buffer.entity_id())
1609        {
1610            let now = cx.background_executor().now();
1611            buffer.pending_predictions.push(PendingSettledPrediction {
1612                request_id,
1613                editable_anchor_range: edited_buffer_snapshot
1614                    .anchor_range_around(editable_offset_range),
1615                enqueued_at: now,
1616                last_edit_at: now,
1617            });
1618            self.settled_predictions_tx.unbounded_send(now).ok();
1619        }
1620    }
1621
1622    fn reject_current_prediction(
1623        &mut self,
1624        reason: EditPredictionRejectReason,
1625        project: &Entity<Project>,
1626        cx: &App,
1627    ) {
1628        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1629            project_state.pending_predictions.clear();
1630            if let Some(prediction) = project_state.current_prediction.take() {
1631                let model_version = prediction.prediction.model_version.clone();
1632                self.reject_prediction(
1633                    prediction.prediction.id,
1634                    reason,
1635                    prediction.was_shown,
1636                    model_version,
1637                    cx,
1638                );
1639            }
1640        };
1641    }
1642
1643    fn did_show_current_prediction(
1644        &mut self,
1645        project: &Entity<Project>,
1646        display_type: edit_prediction_types::SuggestionDisplayType,
1647        cx: &mut Context<Self>,
1648    ) {
1649        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1650            return;
1651        };
1652
1653        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1654            return;
1655        };
1656
1657        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1658        let previous_shown_with = current_prediction.shown_with;
1659
1660        if previous_shown_with.is_none() || !is_jump {
1661            current_prediction.shown_with = Some(display_type);
1662        }
1663
1664        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1665
1666        if is_first_non_jump_show {
1667            current_prediction.was_shown = true;
1668        }
1669
1670        let display_type_changed = previous_shown_with != Some(display_type);
1671
1672        if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1673            sweep_ai::edit_prediction_shown(
1674                &self.sweep_ai,
1675                self.client.clone(),
1676                &current_prediction.prediction,
1677                display_type,
1678                cx,
1679            );
1680        }
1681
1682        if is_first_non_jump_show {
1683            self.shown_predictions
1684                .push_front(current_prediction.prediction.clone());
1685            if self.shown_predictions.len() > 50 {
1686                let completion = self.shown_predictions.pop_back().unwrap();
1687                self.rated_predictions.remove(&completion.id);
1688            }
1689        }
1690    }
1691
1692    fn reject_prediction(
1693        &mut self,
1694        prediction_id: EditPredictionId,
1695        reason: EditPredictionRejectReason,
1696        was_shown: bool,
1697        model_version: Option<String>,
1698        cx: &App,
1699    ) {
1700        match self.edit_prediction_model {
1701            EditPredictionModel::Zeta => {
1702                let is_cloud = !matches!(
1703                    all_language_settings(None, cx).edit_predictions.provider,
1704                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1705                );
1706
1707                if is_cloud {
1708                    let organization_id = self
1709                        .user_store
1710                        .read(cx)
1711                        .current_organization()
1712                        .map(|organization| organization.id.clone());
1713
1714                    self.reject_predictions_tx
1715                        .unbounded_send(EditPredictionRejectionPayload {
1716                            rejection: EditPredictionRejection {
1717                                request_id: prediction_id.to_string(),
1718                                reason,
1719                                was_shown,
1720                                model_version,
1721                            },
1722                            organization_id,
1723                        })
1724                        .log_err();
1725                }
1726            }
1727            EditPredictionModel::Mercury => {
1728                mercury::edit_prediction_rejected(
1729                    prediction_id,
1730                    was_shown,
1731                    reason,
1732                    self.client.http_client(),
1733                    cx,
1734                );
1735            }
1736            EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1737        }
1738    }
1739
1740    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1741        self.projects
1742            .get(&project.entity_id())
1743            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1744    }
1745
1746    pub fn refresh_prediction_from_buffer(
1747        &mut self,
1748        project: Entity<Project>,
1749        buffer: Entity<Buffer>,
1750        position: language::Anchor,
1751        cx: &mut Context<Self>,
1752    ) {
1753        self.queue_prediction_refresh(
1754            project.clone(),
1755            PredictEditsRequestTrigger::Other,
1756            buffer.entity_id(),
1757            cx,
1758            move |this, cx| {
1759                let Some(request_task) = this
1760                    .update(cx, |this, cx| {
1761                        this.request_prediction(
1762                            &project,
1763                            &buffer,
1764                            position,
1765                            PredictEditsRequestTrigger::Other,
1766                            cx,
1767                        )
1768                    })
1769                    .log_err()
1770                else {
1771                    return Task::ready(anyhow::Ok(None));
1772                };
1773
1774                cx.spawn(async move |_cx| {
1775                    request_task.await.map(|prediction_result| {
1776                        prediction_result.map(|prediction_result| {
1777                            (
1778                                prediction_result,
1779                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1780                            )
1781                        })
1782                    })
1783                })
1784            },
1785        )
1786    }
1787
1788    pub fn refresh_prediction_from_diagnostics(
1789        &mut self,
1790        project: Entity<Project>,
1791        scope: DiagnosticSearchScope,
1792        cx: &mut Context<Self>,
1793    ) {
1794        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1795            return;
1796        }
1797
1798        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1799            return;
1800        };
1801
1802        // Prefer predictions from buffer
1803        if project_state.current_prediction.is_some() {
1804            log::debug!(
1805                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1806            );
1807            return;
1808        }
1809
1810        self.queue_prediction_refresh(
1811            project.clone(),
1812            PredictEditsRequestTrigger::Diagnostics,
1813            project.entity_id(),
1814            cx,
1815            move |this, cx| {
1816                let Some((active_buffer, snapshot, cursor_point)) = this
1817                    .read_with(cx, |this, cx| {
1818                        let project_state = this.projects.get(&project.entity_id())?;
1819                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1820                        let snapshot = buffer.read(cx).snapshot();
1821
1822                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1823                            return None;
1824                        }
1825
1826                        let cursor_point = position
1827                            .map(|pos| pos.to_point(&snapshot))
1828                            .unwrap_or_default();
1829
1830                        Some((buffer, snapshot, cursor_point))
1831                    })
1832                    .log_err()
1833                    .flatten()
1834                else {
1835                    return Task::ready(anyhow::Ok(None));
1836                };
1837
1838                cx.spawn(async move |cx| {
1839                    let diagnostic_search_range = match scope {
1840                        DiagnosticSearchScope::Local => {
1841                            let diagnostic_search_start =
1842                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1843                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1844                            Point::new(diagnostic_search_start, 0)
1845                                ..Point::new(diagnostic_search_end, 0)
1846                        }
1847                        DiagnosticSearchScope::Global => Default::default(),
1848                    };
1849
1850                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1851                        active_buffer,
1852                        &snapshot,
1853                        diagnostic_search_range,
1854                        cursor_point,
1855                        &project,
1856                        cx,
1857                    )
1858                    .await?
1859                    else {
1860                        return anyhow::Ok(None);
1861                    };
1862
1863                    let Some(prediction_result) = this
1864                        .update(cx, |this, cx| {
1865                            this.request_prediction(
1866                                &project,
1867                                &jump_buffer,
1868                                jump_position,
1869                                PredictEditsRequestTrigger::Diagnostics,
1870                                cx,
1871                            )
1872                        })?
1873                        .await?
1874                    else {
1875                        return anyhow::Ok(None);
1876                    };
1877
1878                    this.update(cx, |this, cx| {
1879                        Some((
1880                            if this
1881                                .get_or_init_project(&project, cx)
1882                                .current_prediction
1883                                .is_none()
1884                            {
1885                                prediction_result
1886                            } else {
1887                                EditPredictionResult {
1888                                    id: prediction_result.id,
1889                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1890                                }
1891                            },
1892                            PredictionRequestedBy::DiagnosticsUpdate,
1893                        ))
1894                    })
1895                })
1896            },
1897        );
1898    }
1899
1900    fn predictions_enabled_at(
1901        snapshot: &BufferSnapshot,
1902        position: Option<language::Anchor>,
1903        cx: &App,
1904    ) -> bool {
1905        let file = snapshot.file();
1906        let all_settings = all_language_settings(file, cx);
1907        if !all_settings.show_edit_predictions(snapshot.language(), cx)
1908            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1909        {
1910            return false;
1911        }
1912
1913        if let Some(last_position) = position {
1914            let settings = snapshot.settings_at(last_position, cx);
1915
1916            if !settings.edit_predictions_disabled_in.is_empty()
1917                && let Some(scope) = snapshot.language_scope_at(last_position)
1918                && let Some(scope_name) = scope.override_name()
1919                && settings
1920                    .edit_predictions_disabled_in
1921                    .iter()
1922                    .any(|s| s == scope_name)
1923            {
1924                return false;
1925            }
1926        }
1927
1928        true
1929    }
1930
    /// Minimum delay between two consecutive prediction refreshes that share the same
    /// throttle entity and trigger kind (see `queue_prediction_refresh`).
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1932}
1933
1934fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1935    match provider {
1936        EditPredictionProvider::Zed
1937        | EditPredictionProvider::Sweep
1938        | EditPredictionProvider::Mercury
1939        | EditPredictionProvider::Ollama
1940        | EditPredictionProvider::OpenAiCompatibleApi
1941        | EditPredictionProvider::Experimental(_) => true,
1942        EditPredictionProvider::None
1943        | EditPredictionProvider::Copilot
1944        | EditPredictionProvider::Codestral => false,
1945    }
1946}
1947
1948impl EditPredictionStore {
1949    fn queue_prediction_refresh(
1950        &mut self,
1951        project: Entity<Project>,
1952        request_trigger: PredictEditsRequestTrigger,
1953        throttle_entity: EntityId,
1954        cx: &mut Context<Self>,
1955        do_refresh: impl FnOnce(
1956            WeakEntity<Self>,
1957            &mut AsyncApp,
1958        )
1959            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1960        + 'static,
1961    ) {
1962        fn select_throttle(
1963            project_state: &mut ProjectState,
1964            request_trigger: PredictEditsRequestTrigger,
1965        ) -> &mut Option<(EntityId, Instant)> {
1966            match request_trigger {
1967                PredictEditsRequestTrigger::Diagnostics => {
1968                    &mut project_state.last_jump_prediction_refresh
1969                }
1970                _ => &mut project_state.last_edit_prediction_refresh,
1971            }
1972        }
1973
1974        let (needs_acceptance_tracking, max_pending_predictions) =
1975            match all_language_settings(None, cx).edit_predictions.provider {
1976                EditPredictionProvider::Zed
1977                | EditPredictionProvider::Sweep
1978                | EditPredictionProvider::Mercury
1979                | EditPredictionProvider::Experimental(_) => (true, 2),
1980                EditPredictionProvider::Ollama => (false, 1),
1981                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1982                EditPredictionProvider::None
1983                | EditPredictionProvider::Copilot
1984                | EditPredictionProvider::Codestral => {
1985                    log::error!("queue_prediction_refresh called with non-store provider");
1986                    return;
1987                }
1988            };
1989
1990        let drop_on_cancel = !needs_acceptance_tracking;
1991        let throttle_timeout = Self::THROTTLE_TIMEOUT;
1992        let project_state = self.get_or_init_project(&project, cx);
1993        let pending_prediction_id = project_state.next_pending_prediction_id;
1994        project_state.next_pending_prediction_id += 1;
1995        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
1996
1997        let task = cx.spawn(async move |this, cx| {
1998            let throttle_wait = this
1999                .update(cx, |this, cx| {
2000                    let project_state = this.get_or_init_project(&project, cx);
2001                    let throttle = *select_throttle(project_state, request_trigger);
2002
2003                    throttle.and_then(|(last_entity, last_timestamp)| {
2004                        if throttle_entity != last_entity {
2005                            return None;
2006                        }
2007                        (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
2008                    })
2009                })
2010                .ok()
2011                .flatten();
2012
2013            if let Some(timeout) = throttle_wait {
2014                cx.background_executor().timer(timeout).await;
2015            }
2016
2017            // If this task was cancelled before the throttle timeout expired,
2018            // do not perform a request. Also skip if another task already
2019            // proceeded since we were enqueued (duplicate).
2020            let mut is_cancelled = true;
2021            this.update(cx, |this, cx| {
2022                let project_state = this.get_or_init_project(&project, cx);
2023                let was_cancelled = project_state
2024                    .cancelled_predictions
2025                    .remove(&pending_prediction_id);
2026                if was_cancelled {
2027                    return;
2028                }
2029
2030                // Another request has been already sent since this was enqueued
2031                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2032                    return;
2033                }
2034
2035                let new_refresh = (throttle_entity, Instant::now());
2036                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2037                is_cancelled = false;
2038            })
2039            .ok();
2040            if is_cancelled {
2041                return None;
2042            }
2043
2044            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2045            let new_prediction_id = new_prediction_result
2046                .as_ref()
2047                .map(|(prediction, _)| prediction.id.clone());
2048
2049            // When a prediction completes, remove it from the pending list, and cancel
2050            // any pending predictions that were enqueued before it.
2051            this.update(cx, |this, cx| {
2052                let project_state = this.get_or_init_project(&project, cx);
2053
2054                let is_cancelled = project_state
2055                    .cancelled_predictions
2056                    .remove(&pending_prediction_id);
2057
2058                let new_current_prediction = if !is_cancelled
2059                    && let Some((prediction_result, requested_by)) = new_prediction_result
2060                {
2061                    match prediction_result.prediction {
2062                        Ok(prediction) => {
2063                            let new_prediction = CurrentEditPrediction {
2064                                requested_by,
2065                                prediction,
2066                                was_shown: false,
2067                                shown_with: None,
2068                            };
2069
2070                            if let Some(current_prediction) =
2071                                project_state.current_prediction.as_ref()
2072                            {
2073                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2074                                {
2075                                    this.reject_current_prediction(
2076                                        EditPredictionRejectReason::Replaced,
2077                                        &project,
2078                                        cx,
2079                                    );
2080
2081                                    Some(new_prediction)
2082                                } else {
2083                                    this.reject_prediction(
2084                                        new_prediction.prediction.id,
2085                                        EditPredictionRejectReason::CurrentPreferred,
2086                                        false,
2087                                        new_prediction.prediction.model_version,
2088                                        cx,
2089                                    );
2090                                    None
2091                                }
2092                            } else {
2093                                Some(new_prediction)
2094                            }
2095                        }
2096                        Err(reject_reason) => {
2097                            this.reject_prediction(
2098                                prediction_result.id,
2099                                reject_reason,
2100                                false,
2101                                None,
2102                                cx,
2103                            );
2104                            None
2105                        }
2106                    }
2107                } else {
2108                    None
2109                };
2110
2111                let project_state = this.get_or_init_project(&project, cx);
2112
2113                if let Some(new_prediction) = new_current_prediction {
2114                    project_state.current_prediction = Some(new_prediction);
2115                }
2116
2117                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2118                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2119                    if pending_prediction.id == pending_prediction_id {
2120                        pending_predictions.remove(ix);
2121                        for pending_prediction in pending_predictions.drain(0..ix) {
2122                            project_state.cancel_pending_prediction(pending_prediction, cx)
2123                        }
2124                        break;
2125                    }
2126                }
2127                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2128                cx.notify();
2129            })
2130            .ok();
2131
2132            new_prediction_id
2133        });
2134
2135        if project_state.pending_predictions.len() < max_pending_predictions {
2136            project_state.pending_predictions.push(PendingPrediction {
2137                id: pending_prediction_id,
2138                task,
2139                drop_on_cancel,
2140            });
2141        } else {
2142            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2143            project_state.pending_predictions.push(PendingPrediction {
2144                id: pending_prediction_id,
2145                task,
2146                drop_on_cancel,
2147            });
2148            project_state.cancel_pending_prediction(pending_prediction, cx);
2149        }
2150    }
2151
2152    pub fn request_prediction(
2153        &mut self,
2154        project: &Entity<Project>,
2155        active_buffer: &Entity<Buffer>,
2156        position: language::Anchor,
2157        trigger: PredictEditsRequestTrigger,
2158        cx: &mut Context<Self>,
2159    ) -> Task<Result<Option<EditPredictionResult>>> {
2160        self.request_prediction_internal(
2161            project.clone(),
2162            active_buffer.clone(),
2163            position,
2164            trigger,
2165            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2166            cx,
2167        )
2168    }
2169
2170    fn request_prediction_internal(
2171        &mut self,
2172        project: Entity<Project>,
2173        active_buffer: Entity<Buffer>,
2174        position: language::Anchor,
2175        trigger: PredictEditsRequestTrigger,
2176        allow_jump: bool,
2177        cx: &mut Context<Self>,
2178    ) -> Task<Result<Option<EditPredictionResult>>> {
2179        self.get_or_init_project(&project, cx);
2180        let project_state = self.projects.get(&project.entity_id()).unwrap();
2181        let stored_events = project_state.events(cx);
2182        let has_events = !stored_events.is_empty();
2183        let events: Vec<Arc<zeta_prompt::Event>> =
2184            stored_events.iter().map(|e| e.event.clone()).collect();
2185        let debug_tx = project_state.debug_tx.clone();
2186
2187        let snapshot = active_buffer.read(cx).snapshot();
2188        let cursor_point = position.to_point(&snapshot);
2189        let current_offset = position.to_offset(&snapshot);
2190
2191        let mut user_actions: Vec<UserActionRecord> =
2192            project_state.user_actions.iter().cloned().collect();
2193
2194        if let Some(last_action) = user_actions.last() {
2195            if last_action.buffer_id == active_buffer.entity_id()
2196                && current_offset != last_action.offset
2197            {
2198                let timestamp_epoch_ms = SystemTime::now()
2199                    .duration_since(UNIX_EPOCH)
2200                    .map(|d| d.as_millis() as u64)
2201                    .unwrap_or(0);
2202                user_actions.push(UserActionRecord {
2203                    action_type: UserActionType::CursorMovement,
2204                    buffer_id: active_buffer.entity_id(),
2205                    line_number: cursor_point.row,
2206                    offset: current_offset,
2207                    timestamp_epoch_ms,
2208                });
2209            }
2210        }
2211        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2212        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2213        let diagnostic_search_range =
2214            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2215
2216        let related_files = self.context_for_project(&project, cx);
2217
2218        let is_open_source = snapshot
2219            .file()
2220            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2221            && events.iter().all(|event| event.in_open_source_repo())
2222            && related_files.iter().all(|file| file.in_open_source_repo);
2223
2224        let can_collect_data = !cfg!(test)
2225            && is_open_source
2226            && self.is_data_collection_enabled(cx)
2227            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2228
2229        let inputs = EditPredictionModelInput {
2230            project: project.clone(),
2231            buffer: active_buffer.clone(),
2232            snapshot: snapshot,
2233            position,
2234            events,
2235            related_files,
2236            recent_paths: project_state.recent_paths.clone(),
2237            trigger,
2238            diagnostic_search_range: diagnostic_search_range,
2239            debug_tx,
2240            user_actions,
2241            can_collect_data,
2242            is_open_source,
2243        };
2244
2245        if can_collect_data && rand::random_ratio(1, 1000) {
2246            if let Some(task) = capture_example(
2247                project.clone(),
2248                active_buffer,
2249                position,
2250                stored_events,
2251                false,
2252                cx,
2253            ) {
2254                task.detach();
2255            }
2256        }
2257
2258        let task = match self.edit_prediction_model {
2259            EditPredictionModel::Zeta => zeta::request_prediction_with_zeta(self, inputs, cx),
2260            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2261            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
2262            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
2263        };
2264
2265        cx.spawn(async move |this, cx| {
2266            let prediction = task.await?;
2267
2268            // Only fall back to diagnostics-based prediction if we got a
2269            // the model had nothing to suggest for the buffer
2270            if prediction.is_none()
2271                && allow_jump
2272                && has_events
2273                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2274            {
2275                this.update(cx, |this, cx| {
2276                    this.refresh_prediction_from_diagnostics(
2277                        project,
2278                        DiagnosticSearchScope::Local,
2279                        cx,
2280                    );
2281                })?;
2282                return anyhow::Ok(None);
2283            }
2284
2285            Ok(prediction)
2286        })
2287    }
2288
2289    pub(crate) async fn next_diagnostic_location(
2290        active_buffer: Entity<Buffer>,
2291        active_buffer_snapshot: &BufferSnapshot,
2292        active_buffer_diagnostic_search_range: Range<Point>,
2293        active_buffer_cursor_point: Point,
2294        project: &Entity<Project>,
2295        cx: &mut AsyncApp,
2296    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2297        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2298            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2299            .flat_map(|(_, _, _, selections)| {
2300                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2301            })
2302            .collect();
2303
2304        let mut jump_location = active_buffer_snapshot
2305            .diagnostic_groups(None)
2306            .into_iter()
2307            .filter_map(|(_, group)| {
2308                let range = &group.entries[group.primary_ix]
2309                    .range
2310                    .to_point(&active_buffer_snapshot);
2311                if range.overlaps(&active_buffer_diagnostic_search_range) {
2312                    return None;
2313                }
2314                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2315                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2316                });
2317                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2318                    <= DIAGNOSTIC_LINES_RANGE;
2319                if near_collaborator && !near_local {
2320                    return None;
2321                }
2322                Some(range.start)
2323            })
2324            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2325            .map(|position| {
2326                (
2327                    active_buffer.clone(),
2328                    active_buffer_snapshot.anchor_before(position),
2329                )
2330            });
2331
2332        if jump_location.is_none() {
2333            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2334                let file = buffer.file()?;
2335
2336                Some(ProjectPath {
2337                    worktree_id: file.worktree_id(cx),
2338                    path: file.path().clone(),
2339                })
2340            });
2341
2342            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2343                project
2344                    .diagnostic_summaries(false, cx)
2345                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2346                    .map(|(path, _, _)| {
2347                        let shared_prefix = path
2348                            .path
2349                            .components()
2350                            .zip(
2351                                active_buffer_path
2352                                    .as_ref()
2353                                    .map(|p| p.path.components())
2354                                    .unwrap_or_default(),
2355                            )
2356                            .take_while(|(a, b)| a == b)
2357                            .count();
2358                        (path, shared_prefix)
2359                    })
2360                    .collect()
2361            });
2362
2363            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2364
2365            for (path, _) in candidates {
2366                let candidate_buffer = project
2367                    .update(cx, |project, cx| project.open_buffer(path, cx))
2368                    .await?;
2369
2370                let (has_collaborators, diagnostic_position) =
2371                    candidate_buffer.read_with(cx, |buffer, _cx| {
2372                        let snapshot = buffer.snapshot();
2373                        let has_collaborators = snapshot
2374                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2375                            .next()
2376                            .is_some();
2377                        let position = buffer
2378                            .buffer_diagnostics(None)
2379                            .into_iter()
2380                            .min_by_key(|entry| entry.diagnostic.severity)
2381                            .map(|entry| entry.range.start);
2382                        (has_collaborators, position)
2383                    });
2384
2385                if has_collaborators {
2386                    continue;
2387                }
2388
2389                if let Some(position) = diagnostic_position {
2390                    jump_location = Some((candidate_buffer, position));
2391                    break;
2392                }
2393            }
2394        }
2395
2396        anyhow::Ok(jump_location)
2397    }
2398
2399    async fn send_raw_llm_request(
2400        request: RawCompletionRequest,
2401        client: Arc<Client>,
2402        custom_url: Option<Arc<Url>>,
2403        llm_token: LlmApiToken,
2404        organization_id: Option<OrganizationId>,
2405        app_version: Version,
2406    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2407        let url = if let Some(custom_url) = custom_url {
2408            custom_url.as_ref().clone()
2409        } else {
2410            client
2411                .http_client()
2412                .build_zed_llm_url("/predict_edits/raw", &[])?
2413        };
2414
2415        Self::send_api_request(
2416            |builder| {
2417                let req = builder
2418                    .uri(url.as_ref())
2419                    .body(serde_json::to_string(&request)?.into());
2420                Ok(req?)
2421            },
2422            client,
2423            llm_token,
2424            organization_id,
2425            app_version,
2426            true,
2427        )
2428        .await
2429    }
2430
2431    pub(crate) async fn send_v3_request(
2432        input: ZetaPromptInput,
2433        client: Arc<Client>,
2434        llm_token: LlmApiToken,
2435        organization_id: Option<OrganizationId>,
2436        app_version: Version,
2437        trigger: PredictEditsRequestTrigger,
2438    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2439        let url = client
2440            .http_client()
2441            .build_zed_llm_url("/predict_edits/v3", &[])?;
2442
2443        let request = PredictEditsV3Request { input, trigger };
2444
2445        let json_bytes = serde_json::to_vec(&request)?;
2446        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2447
2448        Self::send_api_request(
2449            |builder| {
2450                let req = builder
2451                    .uri(url.as_ref())
2452                    .header("Content-Encoding", "zstd")
2453                    .body(compressed.clone().into());
2454                Ok(req?)
2455            },
2456            client,
2457            llm_token,
2458            organization_id,
2459            app_version,
2460            true,
2461        )
2462        .await
2463    }
2464
2465    fn handle_api_response<T>(
2466        this: &WeakEntity<Self>,
2467        response: Result<(T, Option<EditPredictionUsage>)>,
2468        cx: &mut gpui::AsyncApp,
2469    ) -> Result<T> {
2470        match response {
2471            Ok((data, usage)) => {
2472                if let Some(usage) = usage {
2473                    this.update(cx, |this, cx| {
2474                        this.user_store.update(cx, |user_store, cx| {
2475                            user_store.update_edit_prediction_usage(usage, cx);
2476                        });
2477                    })
2478                    .ok();
2479                }
2480                Ok(data)
2481            }
2482            Err(err) => {
2483                if err.is::<ZedUpdateRequiredError>() {
2484                    cx.update(|cx| {
2485                        this.update(cx, |this, _cx| {
2486                            this.update_required = true;
2487                        })
2488                        .ok();
2489
2490                        let error_message: SharedString = err.to_string().into();
2491                        show_app_notification(
2492                            NotificationId::unique::<ZedUpdateRequiredError>(),
2493                            cx,
2494                            move |cx| {
2495                                cx.new(|cx| {
2496                                    ErrorMessagePrompt::new(error_message.clone(), cx)
2497                                        .with_link_button("Update Zed", "https://zed.dev/releases")
2498                                })
2499                            },
2500                        );
2501                    });
2502                }
2503                Err(err)
2504            }
2505        }
2506    }
2507
2508    async fn send_api_request<Res>(
2509        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2510        client: Arc<Client>,
2511        llm_token: LlmApiToken,
2512        organization_id: Option<OrganizationId>,
2513        app_version: Version,
2514        require_auth: bool,
2515    ) -> Result<(Res, Option<EditPredictionUsage>)>
2516    where
2517        Res: DeserializeOwned,
2518    {
2519        let http_client = client.http_client();
2520
2521        let mut token = if require_auth {
2522            Some(llm_token.acquire(&client, organization_id.clone()).await?)
2523        } else {
2524            llm_token
2525                .acquire(&client, organization_id.clone())
2526                .await
2527                .ok()
2528        };
2529        let mut did_retry = false;
2530
2531        loop {
2532            let request_builder = http_client::Request::builder().method(Method::POST);
2533
2534            let mut request_builder = request_builder
2535                .header("Content-Type", "application/json")
2536                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2537
2538            // Only add Authorization header if we have a token
2539            if let Some(ref token_value) = token {
2540                request_builder =
2541                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2542            }
2543
2544            let request = build(request_builder)?;
2545
2546            let mut response = http_client.send(request).await?;
2547
2548            if let Some(minimum_required_version) = response
2549                .headers()
2550                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2551                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2552            {
2553                anyhow::ensure!(
2554                    app_version >= minimum_required_version,
2555                    ZedUpdateRequiredError {
2556                        minimum_version: minimum_required_version
2557                    }
2558                );
2559            }
2560
2561            if response.status().is_success() {
2562                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2563
2564                let mut body = Vec::new();
2565                response.body_mut().read_to_end(&mut body).await?;
2566                return Ok((serde_json::from_slice(&body)?, usage));
2567            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2568                did_retry = true;
2569                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
2570            } else {
2571                let mut body = String::new();
2572                response.body_mut().read_to_string(&mut body).await?;
2573                anyhow::bail!(
2574                    "Request failed with status: {:?}\nBody: {}",
2575                    response.status(),
2576                    body
2577                );
2578            }
2579        }
2580    }
2581
2582    pub fn refresh_context(
2583        &mut self,
2584        project: &Entity<Project>,
2585        buffer: &Entity<language::Buffer>,
2586        cursor_position: language::Anchor,
2587        cx: &mut Context<Self>,
2588    ) {
2589        self.get_or_init_project(project, cx)
2590            .context
2591            .update(cx, |store, cx| {
2592                store.refresh(buffer.clone(), cursor_position, cx);
2593            });
2594    }
2595
2596    #[cfg(feature = "cli-support")]
2597    pub fn set_context_for_buffer(
2598        &mut self,
2599        project: &Entity<Project>,
2600        related_files: Vec<RelatedFile>,
2601        cx: &mut Context<Self>,
2602    ) {
2603        self.get_or_init_project(project, cx)
2604            .context
2605            .update(cx, |store, cx| {
2606                store.set_related_files(related_files, cx);
2607            });
2608    }
2609
2610    #[cfg(feature = "cli-support")]
2611    pub fn set_recent_paths_for_project(
2612        &mut self,
2613        project: &Entity<Project>,
2614        paths: impl IntoIterator<Item = project::ProjectPath>,
2615        cx: &mut Context<Self>,
2616    ) {
2617        let project_state = self.get_or_init_project(project, cx);
2618        project_state.recent_paths = paths.into_iter().collect();
2619    }
2620
2621    fn is_file_open_source(
2622        &self,
2623        project: &Entity<Project>,
2624        file: &Arc<dyn File>,
2625        cx: &App,
2626    ) -> bool {
2627        if !file.is_local() || file.is_private() {
2628            return false;
2629        }
2630        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2631            return false;
2632        };
2633        project_state
2634            .license_detection_watchers
2635            .get(&file.worktree_id(cx))
2636            .as_ref()
2637            .is_some_and(|watcher| watcher.is_project_open_source())
2638    }
2639
2640    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2641        self.data_collection_choice.is_enabled(cx)
2642    }
2643
2644    fn load_data_collection_choice() -> DataCollectionChoice {
2645        let choice = KEY_VALUE_STORE
2646            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2647            .log_err()
2648            .flatten();
2649
2650        match choice.as_deref() {
2651            Some("true") => DataCollectionChoice::Enabled,
2652            Some("false") => DataCollectionChoice::Disabled,
2653            Some(_) => {
2654                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2655                DataCollectionChoice::NotAnswered
2656            }
2657            None => DataCollectionChoice::NotAnswered,
2658        }
2659    }
2660
2661    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2662        self.data_collection_choice = self.data_collection_choice.toggle();
2663        let new_choice = self.data_collection_choice;
2664        let is_enabled = new_choice.is_enabled(cx);
2665        db::write_and_log(cx, move || {
2666            KEY_VALUE_STORE.write_kvp(
2667                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2668                is_enabled.to_string(),
2669            )
2670        });
2671    }
2672
2673    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2674        self.shown_predictions.iter()
2675    }
2676
2677    pub fn shown_completions_len(&self) -> usize {
2678        self.shown_predictions.len()
2679    }
2680
2681    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2682        self.rated_predictions.contains(id)
2683    }
2684
2685    pub fn rate_prediction(
2686        &mut self,
2687        prediction: &EditPrediction,
2688        rating: EditPredictionRating,
2689        feedback: String,
2690        cx: &mut Context<Self>,
2691    ) {
2692        let organization = self.user_store.read(cx).current_organization();
2693
2694        self.rated_predictions.insert(prediction.id.clone());
2695
2696        cx.background_spawn({
2697            let client = self.client.clone();
2698            let prediction_id = prediction.id.to_string();
2699            let inputs = serde_json::to_value(&prediction.inputs);
2700            let output = prediction
2701                .edit_preview
2702                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2703            async move {
2704                client
2705                    .cloud_client()
2706                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2707                        organization_id: organization.map(|organization| organization.id.clone()),
2708                        request_id: prediction_id,
2709                        rating: match rating {
2710                            EditPredictionRating::Positive => "positive".to_string(),
2711                            EditPredictionRating::Negative => "negative".to_string(),
2712                        },
2713                        inputs: inputs?,
2714                        output,
2715                        feedback,
2716                    })
2717                    .await?;
2718
2719                anyhow::Ok(())
2720            }
2721        })
2722        .detach_and_log_err(cx);
2723
2724        cx.notify();
2725    }
2726}
2727
/// Collapses the run of most-recent events in `events` that can be merged with one another
/// into a single `BufferChange` event. The merged event's diff spans from the oldest merged
/// event's `old_snapshot` to `end_snapshot`, and its `edit_range` is anchored in
/// `end_snapshot`.
///
/// `events` is left unchanged when any of these hold:
/// - the newest event's snapshot has a different `remote_id` than `latest_snapshot`;
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest stored event refers to the same buffer as the latest edit.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk newest -> oldest, counting the trailing chain in which each event `can_merge`
    // with the (newer) event that follows it. The newest event always counts, because
    // there is nothing to compare it against on the first iteration.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A lone event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // Iterates the trailing run oldest-first. `peek` does not consume the oldest event, so
    // the `all(...)` below still visits every event in the run. The `unwrap` cannot fail
    // because `mergeable_count >= 2` here.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change is "predicted" only if every merged event was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the whole trailing run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2794
2795pub(crate) fn filter_redundant_excerpts(
2796    mut related_files: Vec<RelatedFile>,
2797    cursor_path: &Path,
2798    cursor_row_range: Range<u32>,
2799) -> Vec<RelatedFile> {
2800    for file in &mut related_files {
2801        if file.path.as_ref() == cursor_path {
2802            file.excerpts.retain(|excerpt| {
2803                excerpt.row_range.start < cursor_row_range.start
2804                    || excerpt.row_range.end > cursor_row_range.end
2805            });
2806        }
2807    }
2808    related_files.retain(|file| !file.excerpts.is_empty());
2809    related_files
2810}
2811
2812#[derive(Error, Debug)]
2813#[error(
2814    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2815)]
2816pub struct ZedUpdateRequiredError {
2817    minimum_version: Version,
2818}
2819
/// The user's answer to the edit prediction data-collection question.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No valid choice has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2826
2827impl DataCollectionChoice {
2828    pub fn is_enabled(self, cx: &App) -> bool {
2829        if cx.is_staff() {
2830            return true;
2831        }
2832        match self {
2833            Self::Enabled => true,
2834            Self::NotAnswered | Self::Disabled => false,
2835        }
2836    }
2837
2838    #[must_use]
2839    pub fn toggle(&self) -> DataCollectionChoice {
2840        match self {
2841            Self::Enabled => Self::Disabled,
2842            Self::Disabled => Self::Enabled,
2843            Self::NotAnswered => Self::Enabled,
2844        }
2845    }
2846}
2847
2848impl From<bool> for DataCollectionChoice {
2849    fn from(value: bool) -> Self {
2850        match value {
2851            true => DataCollectionChoice::Enabled,
2852            false => DataCollectionChoice::Disabled,
2853        }
2854    }
2855}
2856
2857struct ZedPredictUpsell;
2858
2859impl Dismissable for ZedPredictUpsell {
2860    const KEY: &'static str = "dismissed-edit-predict-upsell";
2861
2862    fn dismissed() -> bool {
2863        // To make this backwards compatible with older versions of Zed, we
2864        // check if the user has seen the previous Edit Prediction Onboarding
2865        // before, by checking the data collection choice which was written to
2866        // the database once the user clicked on "Accept and Enable"
2867        if KEY_VALUE_STORE
2868            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2869            .log_err()
2870            .is_some_and(|s| s.is_some())
2871        {
2872            return true;
2873        }
2874
2875        KEY_VALUE_STORE
2876            .read_kvp(Self::KEY)
2877            .log_err()
2878            .is_some_and(|s| s.is_some())
2879    }
2880}
2881
2882pub fn should_show_upsell_modal() -> bool {
2883    !ZedPredictUpsell::dismissed()
2884}
2885
2886pub fn init(cx: &mut App) {
2887    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2888        workspace.register_action(
2889            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2890                ZedPredictModal::toggle(
2891                    workspace,
2892                    workspace.user_store().clone(),
2893                    workspace.client().clone(),
2894                    window,
2895                    cx,
2896                )
2897            },
2898        );
2899
2900        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2901            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2902                settings
2903                    .project
2904                    .all_languages
2905                    .edit_predictions
2906                    .get_or_insert_default()
2907                    .provider = Some(EditPredictionProvider::None)
2908            });
2909        });
2910        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2911            EditPredictionStore::try_global(cx).and_then(|store| {
2912                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2913            })
2914        }
2915
2916        workspace.register_action(|workspace, _: &SignIn, window, cx| {
2917            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2918                copilot_ui::initiate_sign_in(copilot, window, cx);
2919            }
2920        });
2921        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2922            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2923                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2924            }
2925        });
2926        workspace.register_action(|workspace, _: &SignOut, window, cx| {
2927            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2928                copilot_ui::initiate_sign_out(copilot, window, cx);
2929            }
2930        });
2931    })
2932    .detach();
2933}