edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
   3use cloud_api_client::LlmApiToken;
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
   7    PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   8};
   9use cloud_llm_client::{
  10    EditPredictionRejectReason, EditPredictionRejection,
  11    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  12    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  13};
  14use collections::{HashMap, HashSet};
  15use copilot::{Copilot, Reinstall, SignIn, SignOut};
  16use credentials_provider::CredentialsProvider;
  17use db::kvp::{Dismissable, KeyValueStore};
  18use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  19use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  20use futures::{
  21    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  22    channel::mpsc::{self, UnboundedReceiver},
  23    select_biased,
  24};
  25use gpui::BackgroundExecutor;
  26use gpui::http_client::Url;
  27use gpui::{
  28    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  29    http_client::{self, AsyncBody, Method},
  30    prelude::*,
  31};
  32use heapless::Vec as ArrayVec;
  33use language::{
  34    Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
  35    TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
  36};
  37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  38use release_channel::AppVersion;
  39use semver::Version;
  40use serde::de::DeserializeOwned;
  41use settings::{
  42    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  43};
  44use std::collections::{VecDeque, hash_map};
  45use std::env;
  46use text::{AnchorRangeExt, Edit};
  47use workspace::{AppState, Workspace};
  48use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  49
  50use std::mem;
  51use std::ops::Range;
  52use std::path::Path;
  53use std::rc::Rc;
  54use std::str::FromStr as _;
  55use std::sync::Arc;
  56use std::time::{Duration, Instant};
  57
  58use thiserror::Error;
  59use util::{RangeExt as _, ResultExt as _};
  60
  61pub mod cursor_excerpt;
  62pub mod example_spec;
  63pub mod fim;
  64mod license_detection;
  65pub mod mercury;
  66pub mod metrics;
  67pub mod ollama;
  68mod onboarding_modal;
  69pub mod open_ai_response;
  70mod prediction;
  71
  72pub mod udiff;
  73
  74mod capture_example;
  75pub mod open_ai_compatible;
  76mod zed_edit_prediction_delegate;
  77pub mod zeta;
  78
  79#[cfg(test)]
  80mod edit_prediction_tests;
  81
  82use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  83use crate::example_spec::ExampleSpec;
  84use crate::license_detection::LicenseDetectionWatcher;
  85use crate::mercury::Mercury;
  86pub use crate::metrics::{KeptRateResult, compute_kept_rate};
  87use crate::onboarding_modal::ZedPredictModal;
  88pub use crate::prediction::EditPrediction;
  89pub use crate::prediction::EditPredictionId;
  90use crate::prediction::EditPredictionResult;
  91pub use capture_example::capture_example;
  92pub use language_model::ApiKeyState;
  93pub use telemetry_events::EditPredictionRating;
  94pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  95
// Actions declared through gpui's `actions!` macro under the `edit_prediction`
// namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 105
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line distance used by `StoredEvent::can_merge` to decide whether edits are
/// "near" the latest edit and "near" each other.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length (in buffer offsets) of the old and new regions,
/// context lines included, that `compute_diff_between_snapshots_in_range` will
/// diff. Larger regions produce no diff, and therefore no history event.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not referenced in the visible part of this file; presumably a token
// budget for context around collaborators' edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): not referenced in the visible part of this file; presumably the time
// window within which consecutive changes are grouped into one event — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key name for the persisted data-collection choice (presumably in the key-value
// store read by `load_data_collection_choice` — confirm).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval for batching rejection reports — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Name of the "settled" event, presumably reported through telemetry — confirm.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably the maximum time a prediction waits to settle — confirm.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
// NOTE(review): presumably how long a region must go unedited to count as settled — confirm.
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 117
/// Feature flag registered under the name `edit_prediction_jumps`.
// NOTE(review): presumably gates jump-style predictions (`BufferEditPrediction::Jump`) — confirm.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}

/// App-global handle to the single `EditPredictionStore`. It is installed by
/// `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 128
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` initializes this from the `ZED_ZETA_FORMAT`
/// (required), `ZED_ZETA_MODEL` and `ZED_ZETA_ENVIRONMENT` environment variables.
// NOTE(review): an earlier version of this doc said "The version is also used as the
// Baseten environment name (lowercased)", but there is no `version` field. That
// sentence presumably refers to `environment` or `format` now — confirm.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Model identifier (from `ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Environment name (from `ZED_ZETA_ENVIRONMENT` when loaded from the environment).
    pub environment: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when loaded from the environment).
    pub format: ZetaFormat,
}
 138
/// App-wide edit prediction state. It holds per-project history and pending
/// requests, the selected backend and experiment, and channels to the background
/// tasks started in `EditPredictionStore::new`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; clones are handed to background tasks.
    llm_token: LlmApiToken,
    /// Keeps the experiment-fetching task alive for as long as the store exists.
    _fetch_experiments_task: Task<()>,
    /// Per-project state, keyed by entity id (presumably the project's).
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): presumably set when the server reports that this client version
    // is too old (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` starts with `EditPredictionModel::Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, initialized from environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Preferred experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Experiments known to be available (presumably filled by
    /// `refresh_available_experiments`).
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Feeds the background task running `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Feeds the task running `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown. `active_experiment` reads the first entry that
    /// carries a model version.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook, presumably invoked when a prediction settles — confirm.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 160
/// A rejection queued for reporting, with the organization it belongs to (if any).
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}

/// Which backend produces edit predictions.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Presumably fill-in-the-middle prompting (see the `fim` module), using the
    /// given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}

/// Inputs for a single prediction request, presumably handed to the selected
/// model backend — confirm at the call sites.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Position the prediction is requested at (presumably the cursor).
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    mode: PredictEditsMode,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional sink for `DebugEvent`s produced while handling this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}

/// Events that can be sent over a `debug_tx` channel to observe context
/// retrieval and prediction requests.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Key/value pairs describing the retrieval.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 224
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The event itself (presumably as included in the prompt's edit history).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer state before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the change. `can_merge` compares it with the next
    /// event's `old_snapshot.version` to check that the events are contiguous.
    pub new_snapshot_version: clock::Global,
    /// Anchors spanning the edited region. `LastEvent::finalize` creates them in
    /// the post-change snapshot.
    pub total_edit_range: Range<Anchor>,
}
 233
 234impl StoredEvent {
 235    fn can_merge(
 236        &self,
 237        next_old_event: &StoredEvent,
 238        latest_snapshot: &TextBufferSnapshot,
 239        latest_edit_range: &Range<Anchor>,
 240    ) -> bool {
 241        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 242        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 243            return false;
 244        }
 245        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 246            return false;
 247        }
 248        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 249            return false;
 250        }
 251        if !latest_snapshot
 252            .version
 253            .observed_all(&next_old_event.new_snapshot_version)
 254        {
 255            return false;
 256        }
 257
 258        let a_is_predicted = matches!(
 259            self.event.as_ref(),
 260            zeta_prompt::Event::BufferChange {
 261                predicted: true,
 262                ..
 263            }
 264        );
 265        let b_is_predicted = matches!(
 266            next_old_event.event.as_ref(),
 267            zeta_prompt::Event::BufferChange {
 268                predicted: true,
 269                ..
 270            }
 271        );
 272
 273        // If events come from the same source (both predicted or both manual) then
 274        // we would have coalesced them already.
 275        if a_is_predicted == b_is_predicted {
 276            return false;
 277        }
 278
 279        let left_range = self.total_edit_range.to_point(latest_snapshot);
 280        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 281        let latest_range = latest_edit_range.to_point(latest_snapshot);
 282
 283        // Events near to the latest edit are not merged if their sources differ.
 284        if lines_between_ranges(&left_range, &latest_range)
 285            .min(lines_between_ranges(&right_range, &latest_range))
 286            <= CHANGE_GROUPING_LINE_SPAN
 287        {
 288            return false;
 289        }
 290
 291        // Events that are distant from each other are not merged.
 292        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 293            return false;
 294        }
 295
 296        true
 297    }
 298}
 299
 300fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 301    if left.start > right.end {
 302        return left.start.row - right.end.row;
 303    }
 304    if right.start > left.end {
 305        return right.start.row - left.end.row;
 306    }
 307    0
 308}
 309
/// Edit prediction state kept per project by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit history. `events()` returns these, followed by the
    /// in-progress `last_event` finalized on the fly.
    events: VecDeque<StoredEvent>,
    /// The most recent edit, kept in its unfinalized form.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers, keyed by buffer entity id (looked up by `active_buffer`).
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; fixed capacity of two.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree. They are consulted when finalizing events to
    /// compute `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 327
 328impl ProjectState {
 329    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 330        self.events
 331            .iter()
 332            .cloned()
 333            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 334                let (one, two) = event.split_by_pause();
 335                let one = one.finalize(&self.license_detection_watchers, cx);
 336                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 337                one.into_iter().chain(two)
 338            }))
 339            .collect()
 340    }
 341
 342    fn cancel_pending_prediction(
 343        &mut self,
 344        pending_prediction: PendingPrediction,
 345        cx: &mut Context<EditPredictionStore>,
 346    ) {
 347        self.cancelled_predictions.insert(pending_prediction.id);
 348
 349        if pending_prediction.drop_on_cancel {
 350            drop(pending_prediction.task);
 351        } else {
 352            cx.spawn(async move |this, cx| {
 353                let Some(prediction_id) = pending_prediction.task.await else {
 354                    return;
 355                };
 356
 357                this.update(cx, |this, cx| {
 358                    this.reject_prediction(
 359                        prediction_id,
 360                        EditPredictionRejectReason::Canceled,
 361                        false,
 362                        None,
 363                        None,
 364                        cx,
 365                    );
 366                })
 367                .ok();
 368            })
 369            .detach()
 370        }
 371    }
 372
 373    fn active_buffer(
 374        &self,
 375        project: &Entity<Project>,
 376        cx: &App,
 377    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 378        let project = project.read(cx);
 379        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 380        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 381        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 382        Some((active_buffer, registered_buffer.last_position))
 383    }
 384}
 385
/// The prediction currently held for a project (`ProjectState::current_prediction`),
/// plus bookkeeping about how it was requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request; consulted by `should_replace_prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been shown.
    pub was_shown: bool,
    /// How the prediction was displayed, presumably `Some` once shown — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// End-to-end latency of the request (presumably request start to result — confirm).
    pub e2e_latency: std::time::Duration,
}
 394
 395impl CurrentEditPrediction {
 396    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 397        let Some(new_edits) = self
 398            .prediction
 399            .interpolate(&self.prediction.buffer.read(cx))
 400        else {
 401            return false;
 402        };
 403
 404        if self.prediction.buffer != old_prediction.prediction.buffer {
 405            return true;
 406        }
 407
 408        let Some(old_edits) = old_prediction
 409            .prediction
 410            .interpolate(&old_prediction.prediction.buffer.read(cx))
 411        else {
 412            return true;
 413        };
 414
 415        let requested_by_buffer_id = self.requested_by.buffer_id();
 416
 417        // This reduces the occurrence of UI thrash from replacing edits
 418        //
 419        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 420        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 421            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 422            && old_edits.len() == 1
 423            && new_edits.len() == 1
 424        {
 425            let (old_range, old_text) = &old_edits[0];
 426            let (new_range, new_text) = &new_edits[0];
 427            new_range == old_range && new_text.starts_with(old_text.as_ref())
 428        } else {
 429            true
 430        }
 431    }
 432}
 433
 434#[derive(Debug, Clone)]
 435enum PredictionRequestedBy {
 436    DiagnosticsUpdate,
 437    Buffer(EntityId),
 438}
 439
 440impl PredictionRequestedBy {
 441    pub fn buffer_id(&self) -> Option<EntityId> {
 442        match self {
 443            PredictionRequestedBy::DiagnosticsUpdate => None,
 444            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 445        }
 446    }
 447}
 448
// NOTE(review): not referenced in the visible part of this file; presumably the number
// of lines around the request position searched for diagnostics — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostics search.
// NOTE(review): the exact meaning of `Local` vs `Global` is defined at the use site,
// which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}

/// A prediction request that has been started but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier; added to `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the produced prediction's id. `None` presumably means no
    /// prediction was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 465
 466/// A prediction from the perspective of a buffer.
 467#[derive(Debug)]
 468enum BufferEditPrediction<'a> {
 469    Local { prediction: &'a EditPrediction },
 470    Jump { prediction: &'a EditPrediction },
 471}
 472
 473#[cfg(test)]
 474impl std::ops::Deref for BufferEditPrediction<'_> {
 475    type Target = EditPrediction;
 476
 477    fn deref(&self) -> &Self::Target {
 478        match self {
 479            BufferEditPrediction::Local { prediction } => prediction,
 480            BufferEditPrediction::Jump { prediction } => prediction,
 481        }
 482    }
 483}
 484
/// A prediction awaiting its "settled" report, with the before/after state of its
/// editable region.
// NOTE(review): presumably reported once the region stops changing
// (EDIT_PREDICTION_SETTLED_QUIESCENCE) or expires (EDIT_PREDICTION_SETTLED_TTL) — confirm.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}

/// A buffer tracked by `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Presumably the last snapshot observed for this buffer — confirm.
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last recorded position, returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}

/// An edit in progress for one buffer, spanning `old_snapshot` to `new_snapshot`.
/// It becomes a `StoredEvent` via `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Region covered by all edits in this event; `finalize` diffs this region.
    total_edit_range: Range<Anchor>,
    /// Total edit range as of the last editing pause; `split_by_pause` uses it
    /// for the pre-pause segment.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edit came from an accepted prediction rather than manual typing
    /// (presumably — confirm at the producer).
    predicted: bool,
    /// Buffer state at the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 520
 521impl LastEvent {
 522    pub fn finalize(
 523        &self,
 524        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 525        cx: &App,
 526    ) -> Option<StoredEvent> {
 527        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 528        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 529
 530        let in_open_source_repo =
 531            [self.new_file.as_ref(), self.old_file.as_ref()]
 532                .iter()
 533                .all(|file| {
 534                    file.is_some_and(|file| {
 535                        license_detection_watchers
 536                            .get(&file.worktree_id(cx))
 537                            .is_some_and(|watcher| watcher.is_project_open_source())
 538                    })
 539                });
 540
 541        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 542            &self.old_snapshot,
 543            &self.new_snapshot,
 544            &self.total_edit_range,
 545        )?;
 546
 547        if path == old_path && diff.is_empty() {
 548            None
 549        } else {
 550            Some(StoredEvent {
 551                event: Arc::new(zeta_prompt::Event::BufferChange {
 552                    old_path,
 553                    path,
 554                    diff,
 555                    in_open_source_repo,
 556                    predicted: self.predicted,
 557                }),
 558                old_snapshot: self.old_snapshot.clone(),
 559                new_snapshot_version: self.new_snapshot.version.clone(),
 560                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 561                    ..self.new_snapshot.anchor_before(edit_range.end),
 562            })
 563        }
 564    }
 565
 566    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 567        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 568            return (self.clone(), None);
 569        };
 570
 571        let total_edit_range_before_pause = self
 572            .total_edit_range_at_last_pause_boundary
 573            .clone()
 574            .unwrap_or_else(|| self.total_edit_range.clone());
 575
 576        let Some(total_edit_range_after_pause) =
 577            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 578        else {
 579            return (self.clone(), None);
 580        };
 581
 582        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 583        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 584
 585        let before = LastEvent {
 586            old_snapshot: self.old_snapshot.clone(),
 587            new_snapshot: boundary_snapshot.clone(),
 588            old_file: self.old_file.clone(),
 589            new_file: self.new_file.clone(),
 590            latest_edit_range: latest_edit_range_before_pause,
 591            total_edit_range: total_edit_range_before_pause,
 592            total_edit_range_at_last_pause_boundary: None,
 593            predicted: self.predicted,
 594            snapshot_after_last_editing_pause: None,
 595            last_edit_time: self.last_edit_time,
 596        };
 597
 598        let after = LastEvent {
 599            old_snapshot: boundary_snapshot.clone(),
 600            new_snapshot: self.new_snapshot.clone(),
 601            old_file: self.old_file.clone(),
 602            new_file: self.new_file.clone(),
 603            latest_edit_range: latest_edit_range_after_pause,
 604            total_edit_range: total_edit_range_after_pause,
 605            total_edit_range_at_last_pause_boundary: None,
 606            predicted: self.predicted,
 607            snapshot_after_last_editing_pause: None,
 608            last_edit_time: self.last_edit_time,
 609        };
 610
 611        (before, Some(after))
 612    }
 613}
 614
 615fn compute_total_edit_range_between_snapshots(
 616    old_snapshot: &TextBufferSnapshot,
 617    new_snapshot: &TextBufferSnapshot,
 618) -> Option<Range<Anchor>> {
 619    let edits: Vec<Edit<usize>> = new_snapshot
 620        .edits_since::<usize>(&old_snapshot.version)
 621        .collect();
 622
 623    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 624    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 625    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 626
 627    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 628}
 629
/// Maps `total_edit_range` back to a point range in `old_snapshot`. The range's
/// anchors are resolved in `new_snapshot`, and the mapping uses the edits made
/// between the two versions.
///
/// Each endpoint (as a `new_snapshot` offset) is translated independently:
/// - If it lies in unchanged text before an edit, it is shifted back by the net
///   length change (`delta`) of all earlier edits.
/// - If it lies within an edit's new range (end inclusive), it snaps to that
///   edit's old start (for the range start) or old end (for the range end).
/// - If it lies after every edit, it is shifted back by the total `delta`.
///
/// Returns `None` only if an in-loop shift would underflow. The post-loop
/// fallback saturates at 0 instead.
// NOTE(review): assumes `edits_since` yields edits in ascending position order — confirm.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net (new length - old length) of all edits processed so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // First edit ending at or after the start offset: the start is either in the
        // unchanged text before this edit, or inside it.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                new_start_offset.checked_add_signed(-delta)?
            } else {
                edit.old.start
            });
        }

        // Same for the end offset, snapping to the edit's old end when inside it.
        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Endpoints past the last edit are shifted by the total length change.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
 675
/// Builds a unified diff of the region touched by `total_edit_range`, between
/// `old_snapshot` and `new_snapshot`.
///
/// The edited rows are widened by `CONTEXT_LINES` rows before and up to
/// `CONTEXT_LINES + 1` rows after (clamped to the buffer), and the resulting whole
/// lines are diffed.
///
/// Returns `None` in two cases:
/// - the range cannot be mapped back into `old_snapshot`;
/// - either widened region is longer than `EDIT_HISTORY_DIFF_SIZE_LIMIT`.
///
/// On success, returns the diff together with the unwidened edit range as points
/// in `new_snapshot`.
// NOTE(review): context is asymmetric (3 rows before, 4 after). The extra row is
// presumably there for ranges whose end sits at column 0 of the following row —
// confirm before changing.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // First and last rows (inclusive) of the widened regions, clamped to each buffer.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Regions run from the start of the first context row to the start of the row
    // after the last context row (or the buffer end), so whole lines are included.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Oversized regions are skipped entirely rather than truncated.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // The starting rows are passed along, presumably so hunk headers carry real
    // buffer line numbers — confirm in `language::unified_diff_with_offsets`.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
 723
 724fn buffer_path_with_id_fallback(
 725    file: Option<&Arc<dyn File>>,
 726    snapshot: &TextBufferSnapshot,
 727    cx: &App,
 728) -> Arc<Path> {
 729    if let Some(file) = file {
 730        file.full_path(cx).into()
 731    } else {
 732        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 733    }
 734}
 735
 736impl EditPredictionStore {
 737    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 738        cx.try_global::<EditPredictionStoreGlobal>()
 739            .map(|global| global.0.clone())
 740    }
 741
 742    pub fn global(
 743        client: &Arc<Client>,
 744        user_store: &Entity<UserStore>,
 745        cx: &mut App,
 746    ) -> Entity<Self> {
 747        cx.try_global::<EditPredictionStoreGlobal>()
 748            .map(|global| global.0.clone())
 749            .unwrap_or_else(|| {
 750                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 751                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 752                ep_store
 753            })
 754    }
 755
 756    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 757        let data_collection_choice = Self::load_data_collection_choice(cx);
 758
 759        let llm_token = global_llm_token(cx);
 760
 761        let (reject_tx, reject_rx) = mpsc::unbounded();
 762        cx.background_spawn({
 763            let client = client.clone();
 764            let llm_token = llm_token.clone();
 765            let app_version = AppVersion::global(cx);
 766            let background_executor = cx.background_executor().clone();
 767            async move {
 768                Self::handle_rejected_predictions(
 769                    reject_rx,
 770                    client,
 771                    llm_token,
 772                    app_version,
 773                    background_executor,
 774                )
 775                .await
 776            }
 777        })
 778        .detach();
 779
 780        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 781        cx.spawn(async move |this, cx| {
 782            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 783        })
 784        .detach();
 785
 786        let mut current_user = user_store.read(cx).watch_current_user();
 787        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 788            while current_user.borrow().is_none() {
 789                current_user.next().await;
 790            }
 791
 792            this.update(cx, |this, cx| {
 793                if cx.is_staff() {
 794                    this.refresh_available_experiments(cx);
 795                }
 796            })
 797            .log_err();
 798        });
 799
 800        let credentials_provider = zed_credentials_provider::global(cx);
 801
 802        let this = Self {
 803            projects: HashMap::default(),
 804            client,
 805            user_store,
 806            llm_token,
 807            _fetch_experiments_task: fetch_experiments_task,
 808            update_required: false,
 809            edit_prediction_model: EditPredictionModel::Zeta,
 810            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 811            preferred_experiment: None,
 812            available_experiments: Vec::new(),
 813            mercury: Mercury::new(cx),
 814
 815            data_collection_choice,
 816            reject_predictions_tx: reject_tx,
 817            settled_predictions_tx,
 818            rated_predictions: Default::default(),
 819            shown_predictions: Default::default(),
 820            #[cfg(test)]
 821            settled_event_callback: None,
 822
 823            credentials_provider,
 824        };
 825
 826        this
 827    }
 828
 829    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 830        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 831        let format = ZetaFormat::parse(&version_str).ok()?;
 832        let model_id = env::var("ZED_ZETA_MODEL").ok();
 833        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 834        Some(Zeta2RawConfig {
 835            model_id,
 836            environment,
 837            format,
 838        })
 839    }
 840
 841    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 842        self.edit_prediction_model = model;
 843    }
 844
 845    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 846        self.zeta2_raw_config = Some(config);
 847    }
 848
 849    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 850        self.zeta2_raw_config.as_ref()
 851    }
 852
 853    pub fn preferred_experiment(&self) -> Option<&str> {
 854        self.preferred_experiment.as_deref()
 855    }
 856
 857    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 858        self.preferred_experiment = experiment;
 859    }
 860
 861    pub fn available_experiments(&self) -> &[String] {
 862        &self.available_experiments
 863    }
 864
 865    pub fn active_experiment(&self) -> Option<&str> {
 866        self.preferred_experiment.as_deref().or_else(|| {
 867            self.shown_predictions
 868                .iter()
 869                .find_map(|p| p.model_version.as_ref())
 870                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 871        })
 872    }
 873
 874    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 875        let client = self.client.clone();
 876        let llm_token = self.llm_token.clone();
 877        let app_version = AppVersion::global(cx);
 878        let organization_id = self
 879            .user_store
 880            .read(cx)
 881            .current_organization()
 882            .map(|organization| organization.id.clone());
 883
 884        cx.spawn(async move |this, cx| {
 885            let experiments = cx
 886                .background_spawn(async move {
 887                    let http_client = client.http_client();
 888                    let token = client
 889                        .acquire_llm_token(&llm_token, organization_id.clone())
 890                        .await?;
 891                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 892                    let request = http_client::Request::builder()
 893                        .method(Method::GET)
 894                        .uri(url.as_ref())
 895                        .header("Authorization", format!("Bearer {}", token))
 896                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 897                        .body(Default::default())?;
 898                    let mut response = http_client.send(request).await?;
 899                    if response.status().is_success() {
 900                        let mut body = Vec::new();
 901                        response.body_mut().read_to_end(&mut body).await?;
 902                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 903                        Ok(experiments)
 904                    } else {
 905                        let mut body = String::new();
 906                        response.body_mut().read_to_string(&mut body).await?;
 907                        anyhow::bail!(
 908                            "Failed to fetch experiments: {:?}\nBody: {}",
 909                            response.status(),
 910                            body
 911                        );
 912                    }
 913                })
 914                .await?;
 915            this.update(cx, |this, cx| {
 916                this.available_experiments = experiments;
 917                cx.notify();
 918            })?;
 919            anyhow::Ok(())
 920        })
 921        .detach_and_log_err(cx);
 922    }
 923
 924    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 925        use ui::IconName;
 926        match self.edit_prediction_model {
 927            EditPredictionModel::Mercury => {
 928                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 929            }
 930            EditPredictionModel::Zeta => {
 931                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 932                    .with_disabled(IconName::ZedPredictDisabled)
 933                    .with_up(IconName::ZedPredictUp)
 934                    .with_down(IconName::ZedPredictDown)
 935                    .with_error(IconName::ZedPredictError)
 936            }
 937            EditPredictionModel::Fim { .. } => {
 938                let settings = &all_language_settings(None, cx).edit_predictions;
 939                match settings.provider {
 940                    EditPredictionProvider::Ollama => {
 941                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 942                    }
 943                    _ => {
 944                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 945                    }
 946                }
 947            }
 948        }
 949    }
 950
 951    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 952        self.mercury.api_token.read(cx).has_key()
 953    }
 954
 955    pub fn mercury_has_payment_required_error(&self) -> bool {
 956        self.mercury.has_payment_required_error()
 957    }
 958
 959    pub fn clear_history(&mut self) {
 960        for project_state in self.projects.values_mut() {
 961            project_state.events.clear();
 962            project_state.last_event.take();
 963        }
 964    }
 965
 966    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 967        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 968            project_state.events.clear();
 969            project_state.last_event.take();
 970        }
 971    }
 972
 973    pub fn edit_history_for_project(
 974        &self,
 975        project: &Entity<Project>,
 976        cx: &App,
 977    ) -> Vec<StoredEvent> {
 978        self.projects
 979            .get(&project.entity_id())
 980            .map(|project_state| project_state.events(cx))
 981            .unwrap_or_default()
 982    }
 983
 984    pub fn context_for_project<'a>(
 985        &'a self,
 986        project: &Entity<Project>,
 987        cx: &'a mut App,
 988    ) -> Vec<RelatedFile> {
 989        self.projects
 990            .get(&project.entity_id())
 991            .map(|project_state| {
 992                project_state.context.update(cx, |context, cx| {
 993                    context
 994                        .related_files_with_buffers(cx)
 995                        .map(|(mut related_file, buffer)| {
 996                            related_file.in_open_source_repo = buffer
 997                                .read(cx)
 998                                .file()
 999                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1000                            related_file
1001                        })
1002                        .collect()
1003                })
1004            })
1005            .unwrap_or_default()
1006    }
1007
1008    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1009        self.projects
1010            .get(&project.entity_id())
1011            .and_then(|project| project.copilot.clone())
1012    }
1013
1014    pub fn start_copilot_for_project(
1015        &mut self,
1016        project: &Entity<Project>,
1017        cx: &mut Context<Self>,
1018    ) -> Option<Entity<Copilot>> {
1019        if DisableAiSettings::get(None, cx).disable_ai {
1020            return None;
1021        }
1022        let state = self.get_or_init_project(project, cx);
1023
1024        if state.copilot.is_some() {
1025            return state.copilot.clone();
1026        }
1027        let _project = project.clone();
1028        let project = project.read(cx);
1029
1030        let node = project.node_runtime().cloned();
1031        if let Some(node) = node {
1032            let next_id = project.languages().next_language_server_id();
1033            let fs = project.fs().clone();
1034
1035            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1036            state.copilot = Some(copilot.clone());
1037            Some(copilot)
1038        } else {
1039            None
1040        }
1041    }
1042
1043    pub fn context_for_project_with_buffers<'a>(
1044        &'a self,
1045        project: &Entity<Project>,
1046        cx: &'a mut App,
1047    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1048        self.projects
1049            .get(&project.entity_id())
1050            .map(|project| {
1051                project.context.update(cx, |context, cx| {
1052                    context.related_files_with_buffers(cx).collect()
1053                })
1054            })
1055            .unwrap_or_default()
1056    }
1057
1058    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1059        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1060            self.user_store.read(cx).edit_prediction_usage()
1061        } else {
1062            None
1063        }
1064    }
1065
1066    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1067        self.get_or_init_project(project, cx);
1068    }
1069
1070    pub fn register_buffer(
1071        &mut self,
1072        buffer: &Entity<Buffer>,
1073        project: &Entity<Project>,
1074        cx: &mut Context<Self>,
1075    ) {
1076        let project_state = self.get_or_init_project(project, cx);
1077        Self::register_buffer_impl(project_state, buffer, project, cx);
1078    }
1079
    /// Returns the tracking state for `project`, creating it on first use.
    ///
    /// When the state is created, this also:
    /// - builds a `RelatedExcerptStore` for the project and forwards its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to the project's events via `handle_project_event`;
    /// - removes the state again (and notifies) once the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than stored in `_subscriptions`; presumably it ends
                    // once the excerpt store (owned by this state) is dropped.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state when the project itself goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1119
1120    pub fn remove_project(&mut self, project: &Entity<Project>) {
1121        self.projects.remove(&project.entity_id());
1122    }
1123
1124    fn handle_excerpt_store_event(
1125        &mut self,
1126        project_entity_id: EntityId,
1127        event: &RelatedExcerptStoreEvent,
1128    ) {
1129        if let Some(project_state) = self.projects.get(&project_entity_id) {
1130            if let Some(debug_tx) = project_state.debug_tx.clone() {
1131                match event {
1132                    RelatedExcerptStoreEvent::StartedRefresh => {
1133                        debug_tx
1134                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1135                                ContextRetrievalStartedDebugEvent {
1136                                    project_entity_id: project_entity_id,
1137                                    timestamp: Instant::now(),
1138                                    search_prompt: String::new(),
1139                                },
1140                            ))
1141                            .ok();
1142                    }
1143                    RelatedExcerptStoreEvent::FinishedRefresh {
1144                        cache_hit_count,
1145                        cache_miss_count,
1146                        mean_definition_latency,
1147                        max_definition_latency,
1148                    } => {
1149                        debug_tx
1150                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1151                                ContextRetrievalFinishedDebugEvent {
1152                                    project_entity_id: project_entity_id,
1153                                    timestamp: Instant::now(),
1154                                    metadata: vec![
1155                                        (
1156                                            "Cache Hits",
1157                                            format!(
1158                                                "{}/{}",
1159                                                cache_hit_count,
1160                                                cache_hit_count + cache_miss_count
1161                                            )
1162                                            .into(),
1163                                        ),
1164                                        (
1165                                            "Max LSP Time",
1166                                            format!("{} ms", max_definition_latency.as_millis())
1167                                                .into(),
1168                                        ),
1169                                        (
1170                                            "Mean LSP Time",
1171                                            format!("{} ms", mean_definition_latency.as_millis())
1172                                                .into(),
1173                                        ),
1174                                    ],
1175                                },
1176                            ))
1177                            .ok();
1178                    }
1179                }
1180            }
1181        }
1182    }
1183
1184    pub fn debug_info(
1185        &mut self,
1186        project: &Entity<Project>,
1187        cx: &mut Context<Self>,
1188    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1189        let project_state = self.get_or_init_project(project, cx);
1190        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1191        project_state.debug_tx = Some(debug_watch_tx);
1192        debug_watch_rx
1193    }
1194
1195    fn handle_project_event(
1196        &mut self,
1197        project: Entity<Project>,
1198        event: &project::Event,
1199        cx: &mut Context<Self>,
1200    ) {
1201        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1202            return;
1203        }
1204        // TODO [zeta2] init with recent paths
1205        match event {
1206            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1207                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1208                    return;
1209                };
1210                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1211                if let Some(path) = path {
1212                    if let Some(ix) = project_state
1213                        .recent_paths
1214                        .iter()
1215                        .position(|probe| probe == &path)
1216                    {
1217                        project_state.recent_paths.remove(ix);
1218                    }
1219                    project_state.recent_paths.push_front(path);
1220                }
1221            }
1222            project::Event::DiagnosticsUpdated { .. } => {
1223                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1224                    self.refresh_prediction_from_diagnostics(
1225                        project,
1226                        DiagnosticSearchScope::Global,
1227                        cx,
1228                    );
1229                }
1230            }
1231            _ => (),
1232        }
1233    }
1234
    /// Returns the registration record for `buffer` in `project_state`, creating it
    /// the first time the buffer is seen.
    ///
    /// Independently of registration, if the buffer has a file whose worktree is in
    /// `project`, a `LicenseDetectionWatcher` is created for that worktree (once per
    /// worktree). The watcher is removed again when the worktree is released.
    ///
    /// A new registration:
    /// - snapshots the buffer's text and file;
    /// - forwards `Edited` events to `report_changes_for_buffer` as non-predicted
    ///   edits;
    /// - removes itself from the project state when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget this worktree's watcher once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly; edits arriving after the project
                            // is gone are ignored because `upgrade()` fails.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Unregister the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1303
    /// Records the latest changes to `buffer` in `project`'s edit history.
    ///
    /// The buffer's current text snapshot is compared with the one stored on its
    /// registration. Nothing happens if the version is unchanged or no edits are
    /// found. Otherwise:
    /// 1. pending predictions whose editable range overlaps the edit get their
    ///    `last_edit_at` refreshed;
    /// 2. edits with `is_local == false` are only recorded when
    ///    `collaborator_edit_overlaps_locality_region` returns true;
    /// 3. if `compute_diff_between_snapshots_in_range` yields no diff, the open
    ///    `last_event` is finalized into `events` and no new event is started;
    /// 4. otherwise the edit is either coalesced into the open `last_event` or that
    ///    event is finalized and a new one begins.
    ///
    /// An edit is coalesced only when all of these hold: it is for the same buffer,
    /// it continues directly from the event's snapshot version, it has the same
    /// `is_predicted` source, and it is within `CHANGE_GROUPING_LINE_SPAN` lines of
    /// the previous edit.
    ///
    /// `is_predicted` marks edits produced by accepting a prediction.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        // Advance the registration to the current state, keeping the previous one
        // to diff against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse every edit since the old version into one range, running from the
        // first edit's start to the last edit's end (assumes edits are yielded in
        // buffer order).
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits touching a pending prediction's editable region restart the quiet
        // period that the settled-predictions worker waits for.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        if !is_recordable_history_edit {
            // Close out the open event; this edit itself is not recorded.
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Evict the oldest event so at most `EVENT_COUNT_MAX - 1` finalized
                    // events are retained.
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // If at least `LAST_CHANGE_GROUPING_TIME` passed since the previous
                // edit, remember the event's state at that pause boundary.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the open event (same eviction rule as above)
        // before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // NOTE(review): opaque helper; judging by its name it merges trailing
        // finalized events with this edit where appropriate — confirm in its definition.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1436
1437    fn prediction_at(
1438        &mut self,
1439        buffer: &Entity<Buffer>,
1440        position: Option<language::Anchor>,
1441        project: &Entity<Project>,
1442        cx: &App,
1443    ) -> Option<BufferEditPrediction<'_>> {
1444        let project_state = self.projects.get_mut(&project.entity_id())?;
1445        if let Some(position) = position
1446            && let Some(buffer) = project_state
1447                .registered_buffers
1448                .get_mut(&buffer.entity_id())
1449        {
1450            buffer.last_position = Some(position);
1451        }
1452
1453        let CurrentEditPrediction {
1454            requested_by,
1455            prediction,
1456            ..
1457        } = project_state.current_prediction.as_ref()?;
1458
1459        if prediction.targets_buffer(buffer.read(cx)) {
1460            Some(BufferEditPrediction::Local { prediction })
1461        } else {
1462            let show_jump = match requested_by {
1463                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1464                    requested_by_buffer_id == &buffer.entity_id()
1465                }
1466                PredictionRequestedBy::DiagnosticsUpdate => true,
1467            };
1468
1469            if show_jump {
1470                Some(BufferEditPrediction::Jump { prediction })
1471            } else {
1472                None
1473            }
1474        }
1475    }
1476
1477    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1478        let Some(current_prediction) = self
1479            .projects
1480            .get_mut(&project.entity_id())
1481            .and_then(|project_state| project_state.current_prediction.take())
1482        else {
1483            return;
1484        };
1485
1486        self.report_changes_for_buffer(
1487            &current_prediction.prediction.buffer,
1488            project,
1489            true,
1490            true,
1491            cx,
1492        );
1493
1494        // can't hold &mut project_state ref across report_changes_for_buffer_call
1495        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1496            return;
1497        };
1498
1499        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1500            project_state.cancel_pending_prediction(pending_prediction, cx);
1501        }
1502
1503        match self.edit_prediction_model {
1504            EditPredictionModel::Mercury => {
1505                mercury::edit_prediction_accepted(
1506                    current_prediction.prediction.id,
1507                    self.client.http_client(),
1508                    cx,
1509                );
1510            }
1511            EditPredictionModel::Zeta => {
1512                let is_cloud = !matches!(
1513                    all_language_settings(None, cx).edit_predictions.provider,
1514                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1515                );
1516                if is_cloud {
1517                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1518                }
1519            }
1520            EditPredictionModel::Fim { .. } => {}
1521        }
1522    }
1523
1524    async fn handle_rejected_predictions(
1525        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1526        client: Arc<Client>,
1527        llm_token: LlmApiToken,
1528        app_version: Version,
1529        background_executor: BackgroundExecutor,
1530    ) {
1531        let mut rx = std::pin::pin!(rx.peekable());
1532        let mut batched = Vec::new();
1533
1534        while let Some(EditPredictionRejectionPayload {
1535            rejection,
1536            organization_id,
1537        }) = rx.next().await
1538        {
1539            batched.push(rejection);
1540
1541            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1542                select_biased! {
1543                    next = rx.as_mut().peek().fuse() => {
1544                        if next.is_some() {
1545                            continue;
1546                        }
1547                    }
1548                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1549                }
1550            }
1551
1552            let url = client
1553                .http_client()
1554                .build_zed_llm_url("/predict_edits/reject", &[])
1555                .unwrap();
1556
1557            let flush_count = batched
1558                .len()
1559                // in case items have accumulated after failure
1560                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1561            let start = batched.len() - flush_count;
1562
1563            let body = RejectEditPredictionsBodyRef {
1564                rejections: &batched[start..],
1565            };
1566
1567            let result = Self::send_api_request::<()>(
1568                |builder| {
1569                    let req = builder
1570                        .uri(url.as_ref())
1571                        .body(serde_json::to_string(&body)?.into());
1572                    anyhow::Ok(req?)
1573                },
1574                client.clone(),
1575                llm_token.clone(),
1576                organization_id,
1577                app_version.clone(),
1578                true,
1579            )
1580            .await;
1581
1582            if result.log_err().is_some() {
1583                batched.drain(start..);
1584            }
1585        }
1586    }
1587
    /// Background loop that turns enqueued predictions into "settled" telemetry events.
    ///
    /// Each `Instant` received on `rx` is an enqueue notification. NOTE(review): presumably
    /// this is the receiving end of `settled_predictions_tx`, on which
    /// `enqueue_settled_prediction` sends its `now` timestamp; confirm at the construction
    /// site. After a notification, the worker sleeps until
    /// `EDIT_PREDICTION_SETTLED_QUIESCENCE` past that time. It then scans every registered
    /// buffer's pending predictions:
    ///
    /// - Predictions enqueued at least `EDIT_PREDICTION_SETTLED_TTL` ago are discarded
    ///   without emitting anything.
    /// - Predictions whose `last_edit_at` is at least `EDIT_PREDICTION_SETTLED_QUIESCENCE`
    ///   ago are removed, and the text of their editable anchor range is captured from the
    ///   registered buffer's snapshot. Kept-rate metrics are computed on the background
    ///   executor, and an `EDIT_PREDICTION_SETTLED_EVENT` telemetry event is emitted.
    /// - The rest stay pending. The next wake-up is scheduled for when the one with the
    ///   oldest `last_edit_at` could become quiet.
    ///
    /// The loop ends when `rx` is closed or the store entity can no longer be upgraded.
    async fn run_settled_predictions_worker(
        this: WeakEntity<Self>,
        mut rx: UnboundedReceiver<Instant>,
        cx: &mut AsyncApp,
    ) {
        let mut next_wake_time: Option<Instant> = None;
        loop {
            let now = cx.background_executor().now();
            if let Some(wake_time) = next_wake_time.take() {
                // A scan is scheduled: sleep until then. On current std,
                // `Instant::duration_since` saturates to zero if `wake_time` is already past.
                cx.background_executor()
                    .timer(wake_time.duration_since(now))
                    .await;
            } else {
                // Nothing scheduled: wait for the next enqueue notification, or exit once
                // the channel has been closed.
                let Some(new_enqueue_time) = rx.next().await else {
                    break;
                };
                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
                // Coalesce notifications that are already queued. The scan below inspects
                // every pending prediction anyway, so only the earliest timestamp is used.
                while rx.next().now_or_never().flatten().is_some() {}
                continue;
            }

            // Stop once the store itself has been released.
            let Some(this) = this.upgrade() else {
                break;
            };

            let now = cx.background_executor().now();
            // Earliest `last_edit_at` among predictions that remain pending after this scan.
            let mut oldest_edited_at = None;
            // Predictions that settled during this scan, paired with their region's text.
            let mut ready_predictions = Vec::new();

            this.update(cx, |this, _| {
                for (_, project_state) in this.projects.iter_mut() {
                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
                        // Entries are removed in place while scanning, so `pending_index`
                        // only advances past entries that are kept.
                        let mut pending_index = 0;
                        while pending_index < registered_buffer.pending_predictions.len() {
                            let pending_prediction =
                                &registered_buffer.pending_predictions[pending_index];
                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
                            if age >= EDIT_PREDICTION_SETTLED_TTL {
                                // Too old to report: drop it without emitting an event.
                                registered_buffer.pending_predictions.remove(pending_index);
                                continue;
                            }

                            let quiet_for =
                                now.saturating_duration_since(pending_prediction.last_edit_at);
                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
                                // Settled: capture the editable region's text as it reads in
                                // the registered buffer's current snapshot.
                                let pending_prediction =
                                    registered_buffer.pending_predictions.remove(pending_index);
                                let settled_editable_region = registered_buffer
                                    .snapshot
                                    .text_for_range(
                                        pending_prediction.editable_anchor_range.clone(),
                                    )
                                    .collect::<String>();
                                ready_predictions
                                    .push((pending_prediction, settled_editable_region));
                                continue;
                            }

                            // Still being edited: remember it to schedule the next wake-up.
                            if oldest_edited_at
                                .is_none_or(|time| pending_prediction.last_edit_at < time)
                            {
                                oldest_edited_at = Some(pending_prediction.last_edit_at);
                            }
                            pending_index += 1;
                        }
                    }
                }
            });

            for (pending_prediction, settled_editable_region) in ready_predictions {
                let PendingSettledPrediction {
                    request_id,
                    editable_region_before_prediction,
                    predicted_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    example,
                    e2e_latency,
                    ..
                } = pending_prediction;
                let settled_editable_region_for_metrics = settled_editable_region.clone();
                // Compare before / predicted / settled text on the background executor.
                let kept_rate_result = cx
                    .background_spawn(async move {
                        compute_kept_rate(
                            &editable_region_before_prediction,
                            &predicted_editable_region,
                            &settled_editable_region_for_metrics,
                        )
                    })
                    .await;

                // Test-only hook: report each settled request id and region to the callback.
                #[cfg(test)]
                {
                    let request_id = request_id.clone();
                    let settled_editable_region = settled_editable_region.clone();
                    this.update(cx, |this, _| {
                        if let Some(callback) = &this.settled_event_callback {
                            callback(request_id, settled_editable_region);
                        }
                    });
                }

                telemetry::event!(
                    EDIT_PREDICTION_SETTLED_EVENT,
                    request_id = request_id.0.clone(),
                    settled_editable_region,
                    ts_error_count_before_prediction,
                    ts_error_count_after_prediction,
                    edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
                    edit_bytes_reference_new = kept_rate_result.reference_new_chars,
                    edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
                    edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
                    edit_bytes_kept = kept_rate_result.kept_chars,
                    edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
                    edit_bytes_discarded = kept_rate_result.discarded_chars,
                    edit_bytes_context = kept_rate_result.context_chars,
                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
                    edit_bytes_recall_rate = kept_rate_result.recall_rate,
                    example,
                    e2e_latency = e2e_latency.as_millis(),
                );
            }

            // Reschedule for the earliest moment a still-pending prediction could be quiet.
            // `None` means nothing is pending, and the loop goes back to waiting on `rx`.
            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
        }
    }
1714
1715    pub(crate) fn enqueue_settled_prediction(
1716        &mut self,
1717        request_id: EditPredictionId,
1718        project: &Entity<Project>,
1719        edited_buffer: &Entity<Buffer>,
1720        edited_buffer_snapshot: &BufferSnapshot,
1721        editable_offset_range: Range<usize>,
1722        edit_preview: &EditPreview,
1723        example: Option<ExampleSpec>,
1724        e2e_latency: std::time::Duration,
1725        cx: &mut Context<Self>,
1726    ) {
1727        let this = &mut *self;
1728        let project_state = this.get_or_init_project(project, cx);
1729        let Some(registered_buffer) = project_state
1730            .registered_buffers
1731            .get_mut(&edited_buffer.entity_id())
1732        else {
1733            return;
1734        };
1735
1736        let editable_region_before_prediction = edited_buffer_snapshot
1737            .text_for_range(editable_offset_range.clone())
1738            .collect::<String>();
1739        let editable_anchor_range_for_result =
1740            edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1741        let predicted_editable_region = edit_preview
1742            .result_text_snapshot()
1743            .text_for_range(editable_anchor_range_for_result.clone())
1744            .collect();
1745        let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1746            edited_buffer_snapshot
1747                .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1748        );
1749        let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1750            edit_preview.result_syntax_snapshot().layers_for_range(
1751                editable_anchor_range_for_result,
1752                edit_preview.result_text_snapshot(),
1753                true,
1754            ),
1755        );
1756        let editable_anchor_range =
1757            edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1758        let now = cx.background_executor().now();
1759        registered_buffer
1760            .pending_predictions
1761            .push(PendingSettledPrediction {
1762                request_id,
1763                editable_anchor_range,
1764                editable_region_before_prediction,
1765                predicted_editable_region,
1766                ts_error_count_before_prediction,
1767                ts_error_count_after_prediction,
1768                example,
1769                e2e_latency,
1770                enqueued_at: now,
1771                last_edit_at: now,
1772            });
1773        this.settled_predictions_tx.unbounded_send(now).ok();
1774    }
1775
1776    fn reject_current_prediction(
1777        &mut self,
1778        reason: EditPredictionRejectReason,
1779        project: &Entity<Project>,
1780        cx: &App,
1781    ) {
1782        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1783            project_state.pending_predictions.clear();
1784            if let Some(prediction) = project_state.current_prediction.take() {
1785                let model_version = prediction.prediction.model_version.clone();
1786                self.reject_prediction(
1787                    prediction.prediction.id,
1788                    reason,
1789                    prediction.was_shown,
1790                    model_version,
1791                    Some(prediction.e2e_latency),
1792                    cx,
1793                );
1794            }
1795        };
1796    }
1797
1798    fn did_show_current_prediction(
1799        &mut self,
1800        project: &Entity<Project>,
1801        display_type: edit_prediction_types::SuggestionDisplayType,
1802        _cx: &mut Context<Self>,
1803    ) {
1804        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1805            return;
1806        };
1807
1808        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1809            return;
1810        };
1811
1812        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1813        let previous_shown_with = current_prediction.shown_with;
1814
1815        if previous_shown_with.is_none() || !is_jump {
1816            current_prediction.shown_with = Some(display_type);
1817        }
1818
1819        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1820
1821        if is_first_non_jump_show {
1822            current_prediction.was_shown = true;
1823        }
1824
1825        if is_first_non_jump_show {
1826            self.shown_predictions
1827                .push_front(current_prediction.prediction.clone());
1828            if self.shown_predictions.len() > 50 {
1829                let completion = self.shown_predictions.pop_back().unwrap();
1830                self.rated_predictions.remove(&completion.id);
1831            }
1832        }
1833    }
1834
1835    fn reject_prediction(
1836        &mut self,
1837        prediction_id: EditPredictionId,
1838        reason: EditPredictionRejectReason,
1839        was_shown: bool,
1840        model_version: Option<String>,
1841        e2e_latency: Option<std::time::Duration>,
1842        cx: &App,
1843    ) {
1844        match self.edit_prediction_model {
1845            EditPredictionModel::Zeta => {
1846                let is_cloud = !matches!(
1847                    all_language_settings(None, cx).edit_predictions.provider,
1848                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1849                );
1850
1851                if is_cloud {
1852                    let organization_id = self
1853                        .user_store
1854                        .read(cx)
1855                        .current_organization()
1856                        .map(|organization| organization.id.clone());
1857
1858                    self.reject_predictions_tx
1859                        .unbounded_send(EditPredictionRejectionPayload {
1860                            rejection: EditPredictionRejection {
1861                                request_id: prediction_id.to_string(),
1862                                reason,
1863                                was_shown,
1864                                model_version,
1865                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1866                            },
1867                            organization_id,
1868                        })
1869                        .log_err();
1870                }
1871            }
1872            EditPredictionModel::Mercury => {
1873                mercury::edit_prediction_rejected(
1874                    prediction_id,
1875                    was_shown,
1876                    reason,
1877                    self.client.http_client(),
1878                    cx,
1879                );
1880            }
1881            EditPredictionModel::Fim { .. } => {}
1882        }
1883    }
1884
1885    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1886        self.projects
1887            .get(&project.entity_id())
1888            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1889    }
1890
1891    pub fn refresh_prediction_from_buffer(
1892        &mut self,
1893        project: Entity<Project>,
1894        buffer: Entity<Buffer>,
1895        position: language::Anchor,
1896        cx: &mut Context<Self>,
1897    ) {
1898        self.queue_prediction_refresh(
1899            project.clone(),
1900            PredictEditsRequestTrigger::Other,
1901            buffer.entity_id(),
1902            cx,
1903            move |this, cx| {
1904                let Some(request_task) = this
1905                    .update(cx, |this, cx| {
1906                        this.request_prediction(
1907                            &project,
1908                            &buffer,
1909                            position,
1910                            PredictEditsRequestTrigger::Other,
1911                            cx,
1912                        )
1913                    })
1914                    .log_err()
1915                else {
1916                    return Task::ready(anyhow::Ok(None));
1917                };
1918
1919                cx.spawn(async move |_cx| {
1920                    request_task.await.map(|prediction_result| {
1921                        prediction_result.map(|prediction_result| {
1922                            (
1923                                prediction_result,
1924                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1925                            )
1926                        })
1927                    })
1928                })
1929            },
1930        )
1931    }
1932
    /// Queues a prediction refresh at the next diagnostic location.
    ///
    /// Returns without queuing in any of these cases:
    /// - the configured provider is not served by this store;
    /// - a workspace showing `project` has a leader for its active pane (it is following);
    /// - the store has no state for `project`;
    /// - a current prediction already exists.
    ///
    /// The queued work is throttled per project. It reads the project's active buffer and
    /// cursor, and bails out if predictions are disabled there. It then asks
    /// `next_diagnostic_location` for a target:
    /// - `Local` searches `DIAGNOSTIC_LINES_RANGE` rows around the cursor;
    /// - `Global` passes `Default::default()`. NOTE(review): presumably that is treated
    ///   as an unbounded search; confirm.
    ///
    /// A prediction is then requested at that target with the `Diagnostics` trigger. If a
    /// current prediction appeared meanwhile, the result is turned into a
    /// `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        // Do not jump around while following a collaborator.
        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer edits: skip while a current prediction exists.
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        // The project's own entity id is the throttle key for diagnostics refreshes.
        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            project.entity_id(),
            cx,
            move |this, cx| {
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a cursor position, fall back to `Point::default()`.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A prediction may have become current while this request was in
                    // flight. If so, report this one as `CurrentPreferred` instead.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2049
2050    fn predictions_enabled_at(
2051        snapshot: &BufferSnapshot,
2052        position: Option<language::Anchor>,
2053        cx: &App,
2054    ) -> bool {
2055        let file = snapshot.file();
2056        let all_settings = all_language_settings(file, cx);
2057        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2058            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2059        {
2060            return false;
2061        }
2062
2063        if let Some(last_position) = position {
2064            let settings = snapshot.settings_at(last_position, cx);
2065
2066            if !settings.edit_predictions_disabled_in.is_empty()
2067                && let Some(scope) = snapshot.language_scope_at(last_position)
2068                && let Some(scope_name) = scope.override_name()
2069                && settings
2070                    .edit_predictions_disabled_in
2071                    .iter()
2072                    .any(|s| s == scope_name)
2073            {
2074                return false;
2075            }
2076        }
2077
2078        true
2079    }
2080
    /// Minimum spacing between consecutive refreshes for the same throttle entity.
    /// `queue_prediction_refresh` delays a new refresh until this much time has passed
    /// since the previous one of the same trigger class.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2082}
2083
2084fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2085    let Some(app_state) = AppState::try_global(cx) else {
2086        return false;
2087    };
2088
2089    app_state
2090        .workspace_store
2091        .read(cx)
2092        .workspaces()
2093        .filter_map(|workspace| workspace.upgrade())
2094        .any(|workspace| {
2095            workspace.read(cx).project().entity_id() == project.entity_id()
2096                && workspace
2097                    .read(cx)
2098                    .leader_for_pane(workspace.read(cx).active_pane())
2099                    .is_some()
2100        })
2101}
2102
2103fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2104    match provider {
2105        EditPredictionProvider::Zed
2106        | EditPredictionProvider::Mercury
2107        | EditPredictionProvider::Ollama
2108        | EditPredictionProvider::OpenAiCompatibleApi
2109        | EditPredictionProvider::Experimental(_) => true,
2110        EditPredictionProvider::None
2111        | EditPredictionProvider::Copilot
2112        | EditPredictionProvider::Codestral => false,
2113    }
2114}
2115
2116impl EditPredictionStore {
2117    fn queue_prediction_refresh(
2118        &mut self,
2119        project: Entity<Project>,
2120        request_trigger: PredictEditsRequestTrigger,
2121        throttle_entity: EntityId,
2122        cx: &mut Context<Self>,
2123        do_refresh: impl FnOnce(
2124            WeakEntity<Self>,
2125            &mut AsyncApp,
2126        )
2127            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2128        + 'static,
2129    ) {
2130        fn select_throttle(
2131            project_state: &mut ProjectState,
2132            request_trigger: PredictEditsRequestTrigger,
2133        ) -> &mut Option<(EntityId, Instant)> {
2134            match request_trigger {
2135                PredictEditsRequestTrigger::Diagnostics => {
2136                    &mut project_state.last_jump_prediction_refresh
2137                }
2138                _ => &mut project_state.last_edit_prediction_refresh,
2139            }
2140        }
2141
2142        let (needs_acceptance_tracking, max_pending_predictions) =
2143            match all_language_settings(None, cx).edit_predictions.provider {
2144                EditPredictionProvider::Zed
2145                | EditPredictionProvider::Mercury
2146                | EditPredictionProvider::Experimental(_) => (true, 2),
2147                EditPredictionProvider::Ollama => (false, 1),
2148                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2149                EditPredictionProvider::None
2150                | EditPredictionProvider::Copilot
2151                | EditPredictionProvider::Codestral => {
2152                    log::error!("queue_prediction_refresh called with non-store provider");
2153                    return;
2154                }
2155            };
2156
2157        let drop_on_cancel = !needs_acceptance_tracking;
2158        let throttle_timeout = Self::THROTTLE_TIMEOUT;
2159        let project_state = self.get_or_init_project(&project, cx);
2160        let pending_prediction_id = project_state.next_pending_prediction_id;
2161        project_state.next_pending_prediction_id += 1;
2162        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2163
2164        let task = cx.spawn(async move |this, cx| {
2165            let throttle_wait = this
2166                .update(cx, |this, cx| {
2167                    let project_state = this.get_or_init_project(&project, cx);
2168                    let throttle = *select_throttle(project_state, request_trigger);
2169
2170                    let now = cx.background_executor().now();
2171                    throttle.and_then(|(last_entity, last_timestamp)| {
2172                        if throttle_entity != last_entity {
2173                            return None;
2174                        }
2175                        (last_timestamp + throttle_timeout).checked_duration_since(now)
2176                    })
2177                })
2178                .ok()
2179                .flatten();
2180
2181            if let Some(timeout) = throttle_wait {
2182                cx.background_executor().timer(timeout).await;
2183            }
2184
2185            // If this task was cancelled before the throttle timeout expired,
2186            // do not perform a request. Also skip if another task already
2187            // proceeded since we were enqueued (duplicate).
2188            let mut is_cancelled = true;
2189            this.update(cx, |this, cx| {
2190                let project_state = this.get_or_init_project(&project, cx);
2191                let was_cancelled = project_state
2192                    .cancelled_predictions
2193                    .remove(&pending_prediction_id);
2194                if was_cancelled {
2195                    return;
2196                }
2197
2198                // Another request has been already sent since this was enqueued
2199                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2200                    return;
2201                }
2202
2203                let new_refresh = (throttle_entity, cx.background_executor().now());
2204                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2205                is_cancelled = false;
2206            })
2207            .ok();
2208            if is_cancelled {
2209                return None;
2210            }
2211
2212            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2213            let new_prediction_id = new_prediction_result
2214                .as_ref()
2215                .map(|(prediction, _)| prediction.id.clone());
2216
2217            // When a prediction completes, remove it from the pending list, and cancel
2218            // any pending predictions that were enqueued before it.
2219            this.update(cx, |this, cx| {
2220                let project_state = this.get_or_init_project(&project, cx);
2221
2222                let is_cancelled = project_state
2223                    .cancelled_predictions
2224                    .remove(&pending_prediction_id);
2225
2226                let new_current_prediction = if !is_cancelled
2227                    && let Some((prediction_result, requested_by)) = new_prediction_result
2228                {
2229                    match prediction_result.prediction {
2230                        Ok(prediction) => {
2231                            let new_prediction = CurrentEditPrediction {
2232                                requested_by,
2233                                prediction,
2234                                was_shown: false,
2235                                shown_with: None,
2236                                e2e_latency: prediction_result.e2e_latency,
2237                            };
2238
2239                            if let Some(current_prediction) =
2240                                project_state.current_prediction.as_ref()
2241                            {
2242                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2243                                {
2244                                    this.reject_current_prediction(
2245                                        EditPredictionRejectReason::Replaced,
2246                                        &project,
2247                                        cx,
2248                                    );
2249
2250                                    Some(new_prediction)
2251                                } else {
2252                                    this.reject_prediction(
2253                                        new_prediction.prediction.id,
2254                                        EditPredictionRejectReason::CurrentPreferred,
2255                                        false,
2256                                        new_prediction.prediction.model_version,
2257                                        Some(new_prediction.e2e_latency),
2258                                        cx,
2259                                    );
2260                                    None
2261                                }
2262                            } else {
2263                                Some(new_prediction)
2264                            }
2265                        }
2266                        Err(reject_reason) => {
2267                            this.reject_prediction(
2268                                prediction_result.id,
2269                                reject_reason,
2270                                false,
2271                                None,
2272                                Some(prediction_result.e2e_latency),
2273                                cx,
2274                            );
2275                            None
2276                        }
2277                    }
2278                } else {
2279                    None
2280                };
2281
2282                let project_state = this.get_or_init_project(&project, cx);
2283
2284                if let Some(new_prediction) = new_current_prediction {
2285                    project_state.current_prediction = Some(new_prediction);
2286                }
2287
2288                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2289                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2290                    if pending_prediction.id == pending_prediction_id {
2291                        pending_predictions.remove(ix);
2292                        for pending_prediction in pending_predictions.drain(0..ix) {
2293                            project_state.cancel_pending_prediction(pending_prediction, cx)
2294                        }
2295                        break;
2296                    }
2297                }
2298                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2299                cx.notify();
2300            })
2301            .ok();
2302
2303            new_prediction_id
2304        });
2305
2306        if project_state.pending_predictions.len() < max_pending_predictions {
2307            project_state
2308                .pending_predictions
2309                .push(PendingPrediction {
2310                    id: pending_prediction_id,
2311                    task,
2312                    drop_on_cancel,
2313                })
2314                .unwrap();
2315        } else {
2316            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2317            project_state
2318                .pending_predictions
2319                .push(PendingPrediction {
2320                    id: pending_prediction_id,
2321                    task,
2322                    drop_on_cancel,
2323                })
2324                .unwrap();
2325            project_state.cancel_pending_prediction(pending_prediction, cx);
2326        }
2327    }
2328
2329    pub fn request_prediction(
2330        &mut self,
2331        project: &Entity<Project>,
2332        active_buffer: &Entity<Buffer>,
2333        position: language::Anchor,
2334        trigger: PredictEditsRequestTrigger,
2335        cx: &mut Context<Self>,
2336    ) -> Task<Result<Option<EditPredictionResult>>> {
2337        self.request_prediction_internal(
2338            project.clone(),
2339            active_buffer.clone(),
2340            position,
2341            trigger,
2342            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2343            cx,
2344        )
2345    }
2346
2347    fn request_prediction_internal(
2348        &mut self,
2349        project: Entity<Project>,
2350        active_buffer: Entity<Buffer>,
2351        position: language::Anchor,
2352        trigger: PredictEditsRequestTrigger,
2353        allow_jump: bool,
2354        cx: &mut Context<Self>,
2355    ) -> Task<Result<Option<EditPredictionResult>>> {
2356        self.get_or_init_project(&project, cx);
2357        let project_state = self.projects.get(&project.entity_id()).unwrap();
2358        let stored_events = project_state.events(cx);
2359        let has_events = !stored_events.is_empty();
2360        let events: Vec<Arc<zeta_prompt::Event>> =
2361            stored_events.iter().map(|e| e.event.clone()).collect();
2362        let debug_tx = project_state.debug_tx.clone();
2363
2364        let snapshot = active_buffer.read(cx).snapshot();
2365        let cursor_point = position.to_point(&snapshot);
2366        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2367        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2368        let diagnostic_search_range =
2369            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2370
2371        let related_files = self.context_for_project(&project, cx);
2372        let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
2373            EditPredictionsMode::Eager => PredictEditsMode::Eager,
2374            EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
2375        };
2376
2377        let is_open_source = snapshot
2378            .file()
2379            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2380            && events.iter().all(|event| event.in_open_source_repo())
2381            && related_files.iter().all(|file| file.in_open_source_repo);
2382
2383        let can_collect_data = !cfg!(test)
2384            && is_open_source
2385            && self.is_data_collection_enabled(cx)
2386            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2387        let inputs = EditPredictionModelInput {
2388            project: project.clone(),
2389            buffer: active_buffer,
2390            snapshot,
2391            position,
2392            events,
2393            related_files,
2394            mode,
2395            trigger,
2396            diagnostic_search_range,
2397            debug_tx,
2398            can_collect_data,
2399            is_open_source,
2400        };
2401
2402        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2403
2404        let task = match self.edit_prediction_model {
2405            EditPredictionModel::Zeta => {
2406                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2407            }
2408            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2409            EditPredictionModel::Mercury => {
2410                self.mercury
2411                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
2412            }
2413        };
2414
2415        cx.spawn(async move |this, cx| {
2416            let prediction = task.await?;
2417
2418            // Only fall back to diagnostics-based prediction if we got a
2419            // the model had nothing to suggest for the buffer
2420            if prediction.is_none()
2421                && allow_jump
2422                && has_events
2423                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2424            {
2425                this.update(cx, |this, cx| {
2426                    this.refresh_prediction_from_diagnostics(
2427                        project,
2428                        DiagnosticSearchScope::Local,
2429                        cx,
2430                    );
2431                })?;
2432                return anyhow::Ok(None);
2433            }
2434
2435            Ok(prediction)
2436        })
2437    }
2438
2439    pub(crate) async fn next_diagnostic_location(
2440        active_buffer: Entity<Buffer>,
2441        active_buffer_snapshot: &BufferSnapshot,
2442        active_buffer_diagnostic_search_range: Range<Point>,
2443        active_buffer_cursor_point: Point,
2444        project: &Entity<Project>,
2445        cx: &mut AsyncApp,
2446    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
2447        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
2448            .selections_in_range(
2449                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
2450                false,
2451            )
2452            .flat_map(|(_, _, _, selections)| {
2453                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2454            })
2455            .collect();
2456
2457        let mut jump_location = active_buffer_snapshot
2458            .diagnostic_groups(None)
2459            .into_iter()
2460            .filter_map(|(_, group)| {
2461                let range = &group.entries[group.primary_ix]
2462                    .range
2463                    .to_point(&active_buffer_snapshot);
2464                if range.overlaps(&active_buffer_diagnostic_search_range) {
2465                    return None;
2466                }
2467                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2468                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2469                });
2470                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2471                    <= DIAGNOSTIC_LINES_RANGE;
2472                if near_collaborator && !near_local {
2473                    return None;
2474                }
2475                Some(range.start)
2476            })
2477            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2478            .map(|position| {
2479                (
2480                    active_buffer.clone(),
2481                    active_buffer_snapshot.anchor_before(position),
2482                )
2483            });
2484
2485        if jump_location.is_none() {
2486            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2487                let file = buffer.file()?;
2488
2489                Some(ProjectPath {
2490                    worktree_id: file.worktree_id(cx),
2491                    path: file.path().clone(),
2492                })
2493            });
2494
2495            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2496                project
2497                    .diagnostic_summaries(false, cx)
2498                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2499                    .map(|(path, _, _)| {
2500                        let shared_prefix = path
2501                            .path
2502                            .components()
2503                            .zip(
2504                                active_buffer_path
2505                                    .as_ref()
2506                                    .map(|p| p.path.components())
2507                                    .unwrap_or_default(),
2508                            )
2509                            .take_while(|(a, b)| a == b)
2510                            .count();
2511                        (path, shared_prefix)
2512                    })
2513                    .collect()
2514            });
2515
2516            candidates.sort_by(|a, b| b.1.cmp(&a.1));
2517
2518            for (path, _) in candidates {
2519                let candidate_buffer = project
2520                    .update(cx, |project, cx| project.open_buffer(path, cx))
2521                    .await?;
2522
2523                let (has_collaborators, diagnostic_position) =
2524                    candidate_buffer.read_with(cx, |buffer, _cx| {
2525                        let snapshot = buffer.snapshot();
2526                        let has_collaborators = snapshot
2527                            .selections_in_range(
2528                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
2529                                false,
2530                            )
2531                            .next()
2532                            .is_some();
2533                        let position = buffer
2534                            .buffer_diagnostics(None)
2535                            .into_iter()
2536                            .min_by_key(|entry| entry.diagnostic.severity)
2537                            .map(|entry| entry.range.start);
2538                        (has_collaborators, position)
2539                    });
2540
2541                if has_collaborators {
2542                    continue;
2543                }
2544
2545                if let Some(position) = diagnostic_position {
2546                    jump_location = Some((candidate_buffer, position));
2547                    break;
2548                }
2549            }
2550        }
2551
2552        anyhow::Ok(jump_location)
2553    }
2554
2555    async fn send_raw_llm_request(
2556        request: RawCompletionRequest,
2557        client: Arc<Client>,
2558        custom_url: Option<Arc<Url>>,
2559        llm_token: LlmApiToken,
2560        organization_id: Option<OrganizationId>,
2561        app_version: Version,
2562    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2563        let url = if let Some(custom_url) = custom_url {
2564            custom_url.as_ref().clone()
2565        } else {
2566            client
2567                .http_client()
2568                .build_zed_llm_url("/predict_edits/raw", &[])?
2569        };
2570
2571        Self::send_api_request(
2572            |builder| {
2573                let req = builder
2574                    .uri(url.as_ref())
2575                    .body(serde_json::to_string(&request)?.into());
2576                Ok(req?)
2577            },
2578            client,
2579            llm_token,
2580            organization_id,
2581            app_version,
2582            true,
2583        )
2584        .await
2585    }
2586
2587    pub(crate) async fn send_v3_request(
2588        input: ZetaPromptInput,
2589        client: Arc<Client>,
2590        llm_token: LlmApiToken,
2591        organization_id: Option<OrganizationId>,
2592        app_version: Version,
2593        trigger: PredictEditsRequestTrigger,
2594        mode: PredictEditsMode,
2595    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2596        let url = client
2597            .http_client()
2598            .build_zed_llm_url("/predict_edits/v3", &[])?;
2599
2600        let request = PredictEditsV3Request { input, trigger };
2601
2602        let json_bytes = serde_json::to_vec(&request)?;
2603        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2604
2605        Self::send_api_request(
2606            |builder| {
2607                let req = builder
2608                    .uri(url.as_ref())
2609                    .header("Content-Encoding", "zstd")
2610                    .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref())
2611                    .body(compressed.clone().into());
2612                Ok(req?)
2613            },
2614            client,
2615            llm_token,
2616            organization_id,
2617            app_version,
2618            true,
2619        )
2620        .await
2621    }
2622
2623    async fn send_api_request<Res>(
2624        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2625        client: Arc<Client>,
2626        llm_token: LlmApiToken,
2627        organization_id: Option<OrganizationId>,
2628        app_version: Version,
2629        require_auth: bool,
2630    ) -> Result<(Res, Option<EditPredictionUsage>)>
2631    where
2632        Res: DeserializeOwned,
2633    {
2634        let http_client = client.http_client();
2635        let mut token = if require_auth {
2636            Some(
2637                client
2638                    .acquire_llm_token(&llm_token, organization_id.clone())
2639                    .await?,
2640            )
2641        } else {
2642            client
2643                .acquire_llm_token(&llm_token, organization_id.clone())
2644                .await
2645                .ok()
2646        };
2647        let mut did_retry = false;
2648
2649        loop {
2650            let request_builder = http_client::Request::builder().method(Method::POST);
2651
2652            let mut request_builder = request_builder
2653                .header("Content-Type", "application/json")
2654                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2655
2656            // Only add Authorization header if we have a token
2657            if let Some(ref token_value) = token {
2658                request_builder =
2659                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2660            }
2661
2662            let request = build(request_builder)?;
2663
2664            let mut response = http_client.send(request).await?;
2665
2666            if let Some(minimum_required_version) = response
2667                .headers()
2668                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2669                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2670            {
2671                anyhow::ensure!(
2672                    app_version >= minimum_required_version,
2673                    ZedUpdateRequiredError {
2674                        minimum_version: minimum_required_version
2675                    }
2676                );
2677            }
2678
2679            if response.status().is_success() {
2680                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2681
2682                let mut body = Vec::new();
2683                response.body_mut().read_to_end(&mut body).await?;
2684                return Ok((serde_json::from_slice(&body)?, usage));
2685            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2686                did_retry = true;
2687                token = Some(
2688                    client
2689                        .refresh_llm_token(&llm_token, organization_id.clone())
2690                        .await?,
2691                );
2692            } else {
2693                let mut body = String::new();
2694                response.body_mut().read_to_string(&mut body).await?;
2695                anyhow::bail!(
2696                    "Request failed with status: {:?}\nBody: {}",
2697                    response.status(),
2698                    body
2699                );
2700            }
2701        }
2702    }
2703
2704    pub fn refresh_context(
2705        &mut self,
2706        project: &Entity<Project>,
2707        buffer: &Entity<language::Buffer>,
2708        cursor_position: language::Anchor,
2709        cx: &mut Context<Self>,
2710    ) {
2711        self.get_or_init_project(project, cx)
2712            .context
2713            .update(cx, |store, cx| {
2714                store.refresh(buffer.clone(), cursor_position, cx);
2715            });
2716    }
2717
2718    #[cfg(feature = "cli-support")]
2719    pub fn set_context_for_buffer(
2720        &mut self,
2721        project: &Entity<Project>,
2722        related_files: Vec<RelatedFile>,
2723        cx: &mut Context<Self>,
2724    ) {
2725        self.get_or_init_project(project, cx)
2726            .context
2727            .update(cx, |store, cx| {
2728                store.set_related_files(related_files, cx);
2729            });
2730    }
2731
2732    #[cfg(feature = "cli-support")]
2733    pub fn set_recent_paths_for_project(
2734        &mut self,
2735        project: &Entity<Project>,
2736        paths: impl IntoIterator<Item = project::ProjectPath>,
2737        cx: &mut Context<Self>,
2738    ) {
2739        let project_state = self.get_or_init_project(project, cx);
2740        project_state.recent_paths = paths.into_iter().collect();
2741    }
2742
2743    fn is_file_open_source(
2744        &self,
2745        project: &Entity<Project>,
2746        file: &Arc<dyn File>,
2747        cx: &App,
2748    ) -> bool {
2749        if !file.is_local() || file.is_private() {
2750            return false;
2751        }
2752        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2753            return false;
2754        };
2755        project_state
2756            .license_detection_watchers
2757            .get(&file.worktree_id(cx))
2758            .as_ref()
2759            .is_some_and(|watcher| watcher.is_project_open_source())
2760    }
2761
2762    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2763        self.data_collection_choice.is_enabled(cx)
2764    }
2765
2766    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2767        let choice = KeyValueStore::global(cx)
2768            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2769            .log_err()
2770            .flatten();
2771
2772        match choice.as_deref() {
2773            Some("true") => DataCollectionChoice::Enabled,
2774            Some("false") => DataCollectionChoice::Disabled,
2775            Some(_) => {
2776                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2777                DataCollectionChoice::NotAnswered
2778            }
2779            None => DataCollectionChoice::NotAnswered,
2780        }
2781    }
2782
2783    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2784        self.data_collection_choice = self.data_collection_choice.toggle();
2785        let new_choice = self.data_collection_choice;
2786        let is_enabled = new_choice.is_enabled(cx);
2787        let kvp = KeyValueStore::global(cx);
2788        db::write_and_log(cx, move || async move {
2789            kvp.write_kvp(
2790                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2791                is_enabled.to_string(),
2792            )
2793            .await
2794        });
2795    }
2796
2797    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2798        self.shown_predictions.iter()
2799    }
2800
2801    pub fn shown_completions_len(&self) -> usize {
2802        self.shown_predictions.len()
2803    }
2804
2805    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2806        self.rated_predictions.contains(id)
2807    }
2808
2809    pub fn rate_prediction(
2810        &mut self,
2811        prediction: &EditPrediction,
2812        rating: EditPredictionRating,
2813        feedback: String,
2814        cx: &mut Context<Self>,
2815    ) {
2816        let organization = self.user_store.read(cx).current_organization();
2817
2818        self.rated_predictions.insert(prediction.id.clone());
2819
2820        cx.background_spawn({
2821            let client = self.client.clone();
2822            let prediction_id = prediction.id.to_string();
2823            let inputs = serde_json::to_value(&prediction.inputs);
2824            let output = prediction
2825                .edit_preview
2826                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2827            async move {
2828                client
2829                    .cloud_client()
2830                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2831                        organization_id: organization.map(|organization| organization.id.clone()),
2832                        request_id: prediction_id,
2833                        rating: match rating {
2834                            EditPredictionRating::Positive => "positive".to_string(),
2835                            EditPredictionRating::Negative => "negative".to_string(),
2836                        },
2837                        inputs: inputs?,
2838                        output,
2839                        feedback,
2840                    })
2841                    .await?;
2842
2843                anyhow::Ok(())
2844            }
2845        })
2846        .detach_and_log_err(cx);
2847
2848        cx.notify();
2849    }
2850}
2851
2852fn collaborator_edit_overlaps_locality_region(
2853    project_state: &ProjectState,
2854    project: &Entity<Project>,
2855    buffer: &Entity<Buffer>,
2856    snapshot: &BufferSnapshot,
2857    edit_range: &Range<Anchor>,
2858    cx: &App,
2859) -> bool {
2860    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2861        return false;
2862    };
2863
2864    if active_buffer.entity_id() != buffer.entity_id() {
2865        return false;
2866    }
2867
2868    let locality_point_range = expand_context_syntactically_then_linewise(
2869        snapshot,
2870        (position..position).to_point(snapshot),
2871        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2872    );
2873    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2874
2875    edit_range.overlaps(&locality_anchor_range, snapshot)
2876}
2877
/// Collapses the trailing run of mergeable events in `events` into a single
/// `StoredEvent`.
///
/// Nothing happens if:
/// - the last event's buffer (by `remote_id`) differs from `latest_snapshot`'s,
/// - `latest_snapshot` has not observed the last event's resulting version, or
/// - fewer than two trailing events can be merged.
///
/// Otherwise the mergeable events are replaced by one event. It diffs the
/// oldest event's `old_snapshot` against `end_snapshot` over the union of the
/// merged events' edit ranges, and it is marked `predicted` only if every
/// merged event was.
///
/// NOTE(review): if `compute_diff_between_snapshots_in_range` returns `None`,
/// the events are left unmerged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk from the newest event backwards, counting how many consecutive
    // events can merge with the (newer) event visited just before them.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` leaves the oldest event in the iterator, so it is also included in
    // the `predicted` check further down.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Union of the edit ranges of every event after the oldest one.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change is "predicted" only if all parts were.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2960
2961fn merge_anchor_ranges(
2962    left: &Range<Anchor>,
2963    right: &Range<Anchor>,
2964    snapshot: &TextBufferSnapshot,
2965) -> Range<Anchor> {
2966    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2967        left.start
2968    } else {
2969        right.start
2970    };
2971    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2972        left.end
2973    } else {
2974        right.end
2975    };
2976    start..end
2977}
2978
/// Error returned by edit-prediction API requests when the server's
/// minimum-required-version response header names a version newer than the
/// running Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version the server reported as required.
    minimum_version: Version,
}
2986
/// The user's choice about edit-prediction data collection.
///
/// Stored in the key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`
/// as `"true"`/`"false"`. A missing or unrecognized value loads as
/// `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2993
2994impl DataCollectionChoice {
2995    pub fn is_enabled(self, cx: &App) -> bool {
2996        if cx.is_staff() {
2997            return true;
2998        }
2999        match self {
3000            Self::Enabled => true,
3001            Self::NotAnswered | Self::Disabled => false,
3002        }
3003    }
3004
3005    #[must_use]
3006    pub fn toggle(&self) -> DataCollectionChoice {
3007        match self {
3008            Self::Enabled => Self::Disabled,
3009            Self::Disabled => Self::Enabled,
3010            Self::NotAnswered => Self::Enabled,
3011        }
3012    }
3013}
3014
3015impl From<bool> for DataCollectionChoice {
3016    fn from(value: bool) -> Self {
3017        match value {
3018            true => DataCollectionChoice::Enabled,
3019            false => DataCollectionChoice::Disabled,
3020        }
3021    }
3022}
3023
3024struct ZedPredictUpsell;
3025
3026impl Dismissable for ZedPredictUpsell {
3027    const KEY: &'static str = "dismissed-edit-predict-upsell";
3028
3029    fn dismissed(cx: &App) -> bool {
3030        // To make this backwards compatible with older versions of Zed, we
3031        // check if the user has seen the previous Edit Prediction Onboarding
3032        // before, by checking the data collection choice which was written to
3033        // the database once the user clicked on "Accept and Enable"
3034        let kvp = KeyValueStore::global(cx);
3035        if kvp
3036            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3037            .log_err()
3038            .is_some_and(|s| s.is_some())
3039        {
3040            return true;
3041        }
3042
3043        kvp.read_kvp(Self::KEY)
3044            .log_err()
3045            .is_some_and(|s| s.is_some())
3046    }
3047}
3048
3049pub fn should_show_upsell_modal(cx: &App) -> bool {
3050    !ZedPredictUpsell::dismissed(cx)
3051}
3052
3053pub fn init(cx: &mut App) {
3054    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3055        workspace.register_action(
3056            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3057                ZedPredictModal::toggle(
3058                    workspace,
3059                    workspace.user_store().clone(),
3060                    workspace.client().clone(),
3061                    window,
3062                    cx,
3063                )
3064            },
3065        );
3066
3067        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3068            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3069                settings
3070                    .project
3071                    .all_languages
3072                    .edit_predictions
3073                    .get_or_insert_default()
3074                    .provider = Some(EditPredictionProvider::None)
3075            });
3076        });
3077        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3078            EditPredictionStore::try_global(cx).and_then(|store| {
3079                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3080            })
3081        }
3082
3083        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3084            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3085                copilot_ui::initiate_sign_in(copilot, window, cx);
3086            }
3087        });
3088        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3089            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3090                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3091            }
3092        });
3093        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3094            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3095                copilot_ui::initiate_sign_out(copilot, window, cx);
3096            }
3097        });
3098    })
3099    .detach();
3100}