edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
   3use cloud_api_client::LlmApiToken;
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PREDICT_EDITS_MODE_HEADER_NAME, PredictEditsMode, PredictEditsV3Request,
   7    PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   8};
   9use cloud_llm_client::{
  10    EditPredictionRejectReason, EditPredictionRejection,
  11    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  12    PREFERRED_EXPERIMENT_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
  13    ZED_VERSION_HEADER_NAME,
  14};
  15use collections::{HashMap, HashSet};
  16use copilot::{Copilot, Reinstall, SignIn, SignOut};
  17use credentials_provider::CredentialsProvider;
  18use db::kvp::{Dismissable, KeyValueStore};
  19use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PresenceFlag, register_feature_flag};
  21use futures::{
  22    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  23    channel::mpsc::{self, UnboundedReceiver},
  24    select_biased,
  25};
  26use gpui::BackgroundExecutor;
  27use gpui::http_client::Url;
  28use gpui::{
  29    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, WeakEntity, actions,
  30    http_client::{self, AsyncBody, Method},
  31    prelude::*,
  32};
  33use heapless::Vec as ArrayVec;
  34use language::{
  35    Anchor, Buffer, BufferSnapshot, EditPredictionsMode, EditPreview, File, OffsetRangeExt, Point,
  36    TextBufferSnapshot, ToOffset, ToPoint, language_settings::all_language_settings,
  37};
  38use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  39use release_channel::AppVersion;
  40use semver::Version;
  41use serde::de::DeserializeOwned;
  42use settings::{
  43    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  44};
  45use std::collections::{VecDeque, hash_map};
  46use std::env;
  47use text::{AnchorRangeExt, Edit};
  48use workspace::{AppState, Workspace};
  49use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  50
  51use std::mem;
  52use std::ops::Range;
  53use std::path::Path;
  54use std::rc::Rc;
  55use std::str::FromStr as _;
  56use std::sync::Arc;
  57use std::time::{Duration, Instant};
  58
  59use thiserror::Error;
  60use util::{RangeExt as _, ResultExt as _};
  61
  62pub mod cursor_excerpt;
  63pub mod example_spec;
  64pub mod fim;
  65mod license_detection;
  66pub mod mercury;
  67pub mod metrics;
  68pub mod ollama;
  69mod onboarding_modal;
  70pub mod open_ai_response;
  71mod prediction;
  72
  73pub mod udiff;
  74
  75mod capture_example;
  76pub mod open_ai_compatible;
  77mod zed_edit_prediction_delegate;
  78pub mod zeta;
  79
  80#[cfg(test)]
  81mod edit_prediction_tests;
  82
  83use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  84use crate::example_spec::ExampleSpec;
  85use crate::license_detection::LicenseDetectionWatcher;
  86use crate::mercury::Mercury;
  87pub use crate::metrics::{KeptRateResult, compute_kept_rate};
  88use crate::onboarding_modal::ZedPredictModal;
  89pub use crate::prediction::EditPrediction;
  90pub use crate::prediction::EditPredictionId;
  91use crate::prediction::EditPredictionResult;
  92pub use capture_example::capture_example;
  93pub use language_model::ApiKeyState;
  94pub use telemetry_events::EditPredictionRating;
  95pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  96
// Declares the `ResetOnboarding` and `ClearHistory` actions via gpui's `actions!` macro,
// with `edit_prediction` passed as the macro's first argument.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 106
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Row-distance threshold used by `StoredEvent::can_merge`: events within this many rows of the
/// latest edit are not merged, and events farther apart than this from each other are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the length of either text region (old or new) that
/// `compute_diff_between_snapshots_in_range` will diff; larger regions make it return `None`.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 118
/// Marker type for the `edit_prediction_jumps` feature flag.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    /// Name under which the flag is looked up.
    const NAME: &'static str = "edit_prediction_jumps";
    /// The flag carries a `PresenceFlag` value.
    type Value = PresenceFlag;
}
// Registers the flag through `register_feature_flag!`.
register_feature_flag!(EditPredictionJumpsFeatureFlag);
 126
/// Newtype holding the shared `EditPredictionStore` entity as a gpui global.
///
/// Installed by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 131
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): the sentence above refers to a "version", but this struct has no field by
/// that name; it presumably means `format` — confirm against the code that consumes this config.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Read from the `ZED_ZETA_MODEL` environment variable when built from the environment.
    pub model_id: Option<String>,
    /// Read from the `ZED_ZETA_ENVIRONMENT` environment variable when built from the environment.
    pub environment: Option<String>,
    /// Parsed from the `ZED_ZETA_FORMAT` environment variable when built from the environment.
    pub format: ZetaFormat,
}
 141
/// Application-wide state for edit predictions, created by `EditPredictionStore::new` and
/// shared through `EditPredictionStoreGlobal`.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Obtained from `global_llm_token` at construction.
    llm_token: LlmApiToken,
    /// Task spawned in `new` that waits for a current user and then, for staff users,
    /// calls `refresh_available_experiments`.
    _fetch_experiments_task: Task<()>,
    /// Per-project state keyed by an `EntityId`.
    projects: HashMap<EntityId, ProjectState>,
    /// Starts out `false`.
    update_required: bool,
    /// Starts out as `EditPredictionModel::Zeta`; changed via `set_edit_prediction_model`.
    edit_prediction_model: EditPredictionModel,
    /// Initialised from the environment (`zeta2_raw_config_from_env`); may be overridden
    /// via `set_zeta2_raw_config`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Takes precedence in `active_experiment` when set.
    preferred_experiment: Option<String>,
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    /// Loaded through `load_data_collection_choice` at construction.
    data_collection_choice: DataCollectionChoice,
    /// Sender whose receiving half is drained by `handle_rejected_predictions` on a
    /// background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender whose receiving half is drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Consulted by `active_experiment` for the first prediction that has a model version.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; starts out `None`.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    /// Obtained from `zed_credentials_provider::global` at construction.
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 163
/// Item type carried by `EditPredictionStore::reject_predictions_tx`: a rejection together
/// with an optional organization id.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 168
/// Which model the store uses for edit predictions.
///
/// `EditPredictionStore::new` starts with `Zeta`; `set_edit_prediction_model` changes it.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Carries the prompt format to use.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
 175
/// Bundle of inputs for a single prediction request.
///
/// NOTE(review): presumably handed to the model-specific request code (see the `zeta`,
/// `fim`, `mercury` modules) — confirm at the construction site, which is outside this view.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    mode: PredictEditsMode,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
 191
/// Events sent over the optional `debug_tx` channels (`mpsc::UnboundedSender<DebugEvent>`);
/// each variant wraps its own payload struct.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
 199
/// Payload of `DebugEvent::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
 206
/// Payload of `DebugEvent::ContextRetrievalFinished`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Pairs of a static label and a value.
    pub metadata: Vec<(&'static str, SharedString)>,
}
 213
/// Payload of `DebugEvent::EditPredictionStarted`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}
 220
/// Payload of `DebugEvent::EditPredictionFinished`.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    /// Weak handle, so the event does not keep the buffer alive.
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 227
/// An event with associated metadata for reconstructing buffer state.
///
/// Produced by `LastEvent::finalize`.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt event; `LastEvent::finalize` always stores a `BufferChange` here.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Version of the buffer snapshot from after the change.
    pub new_snapshot_version: clock::Global,
    /// Range covered by the change, anchored in the post-change snapshot.
    pub total_edit_range: Range<Anchor>,
}
 236
 237impl StoredEvent {
 238    fn can_merge(
 239        &self,
 240        next_old_event: &StoredEvent,
 241        latest_snapshot: &TextBufferSnapshot,
 242        latest_edit_range: &Range<Anchor>,
 243    ) -> bool {
 244        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 245        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 246            return false;
 247        }
 248        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 249            return false;
 250        }
 251        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 252            return false;
 253        }
 254        if !latest_snapshot
 255            .version
 256            .observed_all(&next_old_event.new_snapshot_version)
 257        {
 258            return false;
 259        }
 260
 261        let a_is_predicted = matches!(
 262            self.event.as_ref(),
 263            zeta_prompt::Event::BufferChange {
 264                predicted: true,
 265                ..
 266            }
 267        );
 268        let b_is_predicted = matches!(
 269            next_old_event.event.as_ref(),
 270            zeta_prompt::Event::BufferChange {
 271                predicted: true,
 272                ..
 273            }
 274        );
 275
 276        // If events come from the same source (both predicted or both manual) then
 277        // we would have coalesced them already.
 278        if a_is_predicted == b_is_predicted {
 279            return false;
 280        }
 281
 282        let left_range = self.total_edit_range.to_point(latest_snapshot);
 283        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 284        let latest_range = latest_edit_range.to_point(latest_snapshot);
 285
 286        // Events near to the latest edit are not merged if their sources differ.
 287        if lines_between_ranges(&left_range, &latest_range)
 288            .min(lines_between_ranges(&right_range, &latest_range))
 289            <= CHANGE_GROUPING_LINE_SPAN
 290        {
 291            return false;
 292        }
 293
 294        // Events that are distant from each other are not merged.
 295        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 296            return false;
 297        }
 298
 299        true
 300    }
 301}
 302
 303fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 304    if left.start > right.end {
 305        return left.start.row - right.end.row;
 306    }
 307    if right.start > left.end {
 308        return right.start.row - left.end.row;
 309    }
 310    0
 311}
 312
/// State the store keeps for one project (the values of `EditPredictionStore::projects`).
struct ProjectState {
    /// Already-finalized events; `events()` returns these followed by whatever
    /// `last_event` finalizes to.
    events: VecDeque<StoredEvent>,
    /// The most recent, not yet finalized event, if any.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers known to this project, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight predictions, held in a `heapless` vector declared with capacity 2.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    /// Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions passed to `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Consulted by `LastEvent::finalize` to decide whether a file's worktree is open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 330
 331impl ProjectState {
 332    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 333        self.events
 334            .iter()
 335            .cloned()
 336            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 337                let (one, two) = event.split_by_pause();
 338                let one = one.finalize(&self.license_detection_watchers, cx);
 339                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 340                one.into_iter().chain(two)
 341            }))
 342            .collect()
 343    }
 344
 345    fn cancel_pending_prediction(
 346        &mut self,
 347        pending_prediction: PendingPrediction,
 348        cx: &mut Context<EditPredictionStore>,
 349    ) {
 350        self.cancelled_predictions.insert(pending_prediction.id);
 351
 352        if pending_prediction.drop_on_cancel {
 353            drop(pending_prediction.task);
 354        } else {
 355            cx.spawn(async move |this, cx| {
 356                let Some((prediction_id, model_version)) = pending_prediction.task.await else {
 357                    return;
 358                };
 359
 360                this.update(cx, |this, cx| {
 361                    this.reject_prediction(
 362                        prediction_id,
 363                        EditPredictionRejectReason::Canceled,
 364                        false,
 365                        model_version,
 366                        None,
 367                        cx,
 368                    );
 369                })
 370                .ok();
 371            })
 372            .detach()
 373        }
 374    }
 375
 376    fn active_buffer(
 377        &self,
 378        project: &Entity<Project>,
 379        cx: &App,
 380    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 381        let project = project.read(cx);
 382        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 383        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 384        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 385        Some((active_buffer, registered_buffer.last_position))
 386    }
 387}
 388
/// A prediction held as a project's current one, plus bookkeeping about how it was
/// requested and displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request; compared against buffer ids in `should_replace_prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    pub e2e_latency: std::time::Duration,
}
 397
 398impl CurrentEditPrediction {
 399    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 400        let Some(new_edits) = self
 401            .prediction
 402            .interpolate(&self.prediction.buffer.read(cx))
 403        else {
 404            return false;
 405        };
 406
 407        if self.prediction.buffer != old_prediction.prediction.buffer {
 408            return true;
 409        }
 410
 411        let Some(old_edits) = old_prediction
 412            .prediction
 413            .interpolate(&old_prediction.prediction.buffer.read(cx))
 414        else {
 415            return true;
 416        };
 417
 418        let requested_by_buffer_id = self.requested_by.buffer_id();
 419
 420        // This reduces the occurrence of UI thrash from replacing edits
 421        //
 422        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 423        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 424            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 425            && old_edits.len() == 1
 426            && new_edits.len() == 1
 427        {
 428            let (old_range, old_text) = &old_edits[0];
 429            let (new_range, new_text) = &new_edits[0];
 430            new_range == old_range && new_text.starts_with(old_text.as_ref())
 431        } else {
 432            true
 433        }
 434    }
 435}
 436
 437#[derive(Debug, Clone)]
 438enum PredictionRequestedBy {
 439    DiagnosticsUpdate,
 440    Buffer(EntityId),
 441}
 442
 443impl PredictionRequestedBy {
 444    pub fn buffer_id(&self) -> Option<EntityId> {
 445        match self {
 446            PredictionRequestedBy::DiagnosticsUpdate => None,
 447            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 448        }
 449    }
 450}
 451
// NOTE(review): presumably the number of lines around a position searched for
// diagnostics — no use is visible in this part of the file; confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;
 453
/// Scope selector for a diagnostics search.
///
/// NOTE(review): the exact meaning of each variant is defined at the use sites, which are
/// outside this part of the file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 459
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Resolves to the prediction's id and an optional model version, or `None`.
    task: Task<Option<(EditPredictionId, Option<String>)>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 468
 469/// A prediction from the perspective of a buffer.
 470#[derive(Debug)]
 471enum BufferEditPrediction<'a> {
 472    Local { prediction: &'a EditPrediction },
 473    Jump { prediction: &'a EditPrediction },
 474}
 475
 476#[cfg(test)]
 477impl std::ops::Deref for BufferEditPrediction<'_> {
 478    type Target = EditPrediction;
 479
 480    fn deref(&self) -> &Self::Target {
 481        match self {
 482            BufferEditPrediction::Local { prediction } => prediction,
 483            BufferEditPrediction::Jump { prediction } => prediction,
 484        }
 485    }
 486}
 487
/// Bookkeeping for a prediction held in `RegisteredBuffer::pending_predictions`.
///
/// NOTE(review): presumably consumed by `run_settled_predictions_worker` to report the
/// "Edit Prediction Settled" event — confirm; that code is outside this part of the file.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    editable_anchor_range: Range<Anchor>,
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
 501
/// Per-buffer state stored in `ProjectState::registered_buffers`.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Returned alongside the buffer by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 509
/// The most recent, still-open buffer change of a project, before it is turned into a
/// `StoredEvent` by `finalize`.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot from before the change.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot from after the change.
    new_snapshot: TextBufferSnapshot,
    /// File used for the event's `old_path`; `None` falls back to an `untitled-<id>` path.
    old_file: Option<Arc<dyn File>>,
    /// File used for the event's `path`; `None` falls back to an `untitled-<id>` path.
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Range that `finalize` diffs between the two snapshots.
    total_edit_range: Range<Anchor>,
    /// Used by `split_by_pause` as the edit range of the pre-pause half; when `None`,
    /// `total_edit_range` is used instead.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Copied into the resulting `BufferChange` event.
    predicted: bool,
    /// Boundary snapshot at which `split_by_pause` splits; `None` means no split.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 523
 524impl LastEvent {
 525    pub fn finalize(
 526        &self,
 527        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 528        cx: &App,
 529    ) -> Option<StoredEvent> {
 530        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 531        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 532
 533        let in_open_source_repo =
 534            [self.new_file.as_ref(), self.old_file.as_ref()]
 535                .iter()
 536                .all(|file| {
 537                    file.is_some_and(|file| {
 538                        license_detection_watchers
 539                            .get(&file.worktree_id(cx))
 540                            .is_some_and(|watcher| watcher.is_project_open_source())
 541                    })
 542                });
 543
 544        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 545            &self.old_snapshot,
 546            &self.new_snapshot,
 547            &self.total_edit_range,
 548        )?;
 549
 550        if path == old_path && diff.is_empty() {
 551            None
 552        } else {
 553            Some(StoredEvent {
 554                event: Arc::new(zeta_prompt::Event::BufferChange {
 555                    old_path,
 556                    path,
 557                    diff,
 558                    in_open_source_repo,
 559                    predicted: self.predicted,
 560                }),
 561                old_snapshot: self.old_snapshot.clone(),
 562                new_snapshot_version: self.new_snapshot.version.clone(),
 563                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 564                    ..self.new_snapshot.anchor_before(edit_range.end),
 565            })
 566        }
 567    }
 568
 569    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 570        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 571            return (self.clone(), None);
 572        };
 573
 574        let total_edit_range_before_pause = self
 575            .total_edit_range_at_last_pause_boundary
 576            .clone()
 577            .unwrap_or_else(|| self.total_edit_range.clone());
 578
 579        let Some(total_edit_range_after_pause) =
 580            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 581        else {
 582            return (self.clone(), None);
 583        };
 584
 585        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 586        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 587
 588        let before = LastEvent {
 589            old_snapshot: self.old_snapshot.clone(),
 590            new_snapshot: boundary_snapshot.clone(),
 591            old_file: self.old_file.clone(),
 592            new_file: self.new_file.clone(),
 593            latest_edit_range: latest_edit_range_before_pause,
 594            total_edit_range: total_edit_range_before_pause,
 595            total_edit_range_at_last_pause_boundary: None,
 596            predicted: self.predicted,
 597            snapshot_after_last_editing_pause: None,
 598            last_edit_time: self.last_edit_time,
 599        };
 600
 601        let after = LastEvent {
 602            old_snapshot: boundary_snapshot.clone(),
 603            new_snapshot: self.new_snapshot.clone(),
 604            old_file: self.old_file.clone(),
 605            new_file: self.new_file.clone(),
 606            latest_edit_range: latest_edit_range_after_pause,
 607            total_edit_range: total_edit_range_after_pause,
 608            total_edit_range_at_last_pause_boundary: None,
 609            predicted: self.predicted,
 610            snapshot_after_last_editing_pause: None,
 611            last_edit_time: self.last_edit_time,
 612        };
 613
 614        (before, Some(after))
 615    }
 616}
 617
 618fn compute_total_edit_range_between_snapshots(
 619    old_snapshot: &TextBufferSnapshot,
 620    new_snapshot: &TextBufferSnapshot,
 621) -> Option<Range<Anchor>> {
 622    let edits: Vec<Edit<usize>> = new_snapshot
 623        .edits_since::<usize>(&old_snapshot.version)
 624        .collect();
 625
 626    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 627    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 628    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 629
 630    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 631}
 632
/// Maps `total_edit_range` (resolved against `new_snapshot`) back to a point range in
/// `old_snapshot`.
///
/// Walks the edits between the two snapshots in order while tracking `delta`, the
/// accumulated length difference (new minus old) of the edits already passed:
/// - an offset that falls before an edit is shifted back by the current `delta`;
/// - an offset that falls inside an edit, or exactly at its end, maps to that edit's old
///   start (for the range start) or old end (for the range end);
/// - an offset past every edit is shifted back by the final `delta`, saturating.
///
/// Returns `None` if shifting an offset that lies before an edit overflows.
fn compute_old_range_for_new_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<Range<Point>> {
    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);

    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();
    let mut old_start_offset = None;
    let mut old_end_offset = None;
    // Net growth (new length minus old length) of all edits visited so far.
    let mut delta: isize = 0;

    for edit in &edits {
        // Resolve the start once, at the first edit that ends at or after it.
        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
            old_start_offset = Some(if new_start_offset < edit.new.start {
                // Start lies before this edit: undo the shift of the earlier edits.
                new_start_offset.checked_add_signed(-delta)?
            } else {
                // Start lies within this edit: snap to where the edit began in the old text.
                edit.old.start
            });
        }

        // Same resolution for the end, snapping to the edit's old end instead.
        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
            old_end_offset = Some(if new_end_offset < edit.new.start {
                new_end_offset.checked_add_signed(-delta)?
            } else {
                edit.old.end
            });
        }

        delta += edit.new.len() as isize - edit.old.len() as isize;
    }

    // Offsets beyond the last edit only need the total shift removed.
    let old_start_offset =
        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
    let old_end_offset =
        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));

    Some(
        old_snapshot.offset_to_point(old_start_offset)
            ..old_snapshot.offset_to_point(old_end_offset),
    )
}
 678
/// Produces a unified diff of the region around `total_edit_range` between two snapshots.
///
/// The region is widened to whole lines plus surrounding context rows on each side, clamped
/// to each snapshot's bounds. Returns the diff text together with the edit range as points
/// in `new_snapshot`.
///
/// Returns `None` when the matching old range cannot be computed, or when either text
/// region is longer than `EDIT_HISTORY_DIFF_SIZE_LIMIT`.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Leading context: step back CONTEXT_LINES rows, stopping at row 0.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Trailing context: clamped to the last row of each snapshot.
    // NOTE(review): the end row adds `1 + CONTEXT_LINES` here and the offsets below add a
    // further `+ 1`, so trailing context can be one row longer than leading context —
    // confirm this asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Convert rows to offsets: regions start at column 0 of the first context row and end at
    // column 0 of the row after the last context row, clamped to the snapshot's max point.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Skip changes whose surrounding region is too large to be worth diffing.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // The first row of each region is passed along with the region texts.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
 726
 727fn buffer_path_with_id_fallback(
 728    file: Option<&Arc<dyn File>>,
 729    snapshot: &TextBufferSnapshot,
 730    cx: &App,
 731) -> Arc<Path> {
 732    if let Some(file) = file {
 733        file.full_path(cx).into()
 734    } else {
 735        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 736    }
 737}
 738
 739impl EditPredictionStore {
 740    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 741        cx.try_global::<EditPredictionStoreGlobal>()
 742            .map(|global| global.0.clone())
 743    }
 744
 745    pub fn global(
 746        client: &Arc<Client>,
 747        user_store: &Entity<UserStore>,
 748        cx: &mut App,
 749    ) -> Entity<Self> {
 750        cx.try_global::<EditPredictionStoreGlobal>()
 751            .map(|global| global.0.clone())
 752            .unwrap_or_else(|| {
 753                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 754                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 755                ep_store
 756            })
 757    }
 758
 759    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 760        let data_collection_choice = Self::load_data_collection_choice(cx);
 761
 762        let llm_token = global_llm_token(cx);
 763
 764        let (reject_tx, reject_rx) = mpsc::unbounded();
 765        cx.background_spawn({
 766            let client = client.clone();
 767            let llm_token = llm_token.clone();
 768            let app_version = AppVersion::global(cx);
 769            let background_executor = cx.background_executor().clone();
 770            async move {
 771                Self::handle_rejected_predictions(
 772                    reject_rx,
 773                    client,
 774                    llm_token,
 775                    app_version,
 776                    background_executor,
 777                )
 778                .await
 779            }
 780        })
 781        .detach();
 782
 783        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 784        cx.spawn(async move |this, cx| {
 785            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 786        })
 787        .detach();
 788
 789        let mut current_user = user_store.read(cx).watch_current_user();
 790        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 791            while current_user.borrow().is_none() {
 792                current_user.next().await;
 793            }
 794
 795            this.update(cx, |this, cx| {
 796                if cx.is_staff() {
 797                    this.refresh_available_experiments(cx);
 798                }
 799            })
 800            .log_err();
 801        });
 802
 803        let credentials_provider = zed_credentials_provider::global(cx);
 804
 805        let this = Self {
 806            projects: HashMap::default(),
 807            client,
 808            user_store,
 809            llm_token,
 810            _fetch_experiments_task: fetch_experiments_task,
 811            update_required: false,
 812            edit_prediction_model: EditPredictionModel::Zeta,
 813            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 814            preferred_experiment: None,
 815            available_experiments: Vec::new(),
 816            mercury: Mercury::new(cx),
 817
 818            data_collection_choice,
 819            reject_predictions_tx: reject_tx,
 820            settled_predictions_tx,
 821            rated_predictions: Default::default(),
 822            shown_predictions: Default::default(),
 823            #[cfg(test)]
 824            settled_event_callback: None,
 825
 826            credentials_provider,
 827        };
 828
 829        this
 830    }
 831
 832    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 833        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 834        let format = ZetaFormat::parse(&version_str).ok()?;
 835        let model_id = env::var("ZED_ZETA_MODEL").ok();
 836        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 837        Some(Zeta2RawConfig {
 838            model_id,
 839            environment,
 840            format,
 841        })
 842    }
 843
    /// Replaces the model used for edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
 847
    /// Overrides the raw Zeta2 configuration, replacing any value that was
    /// read from the environment at construction time.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }
 851
    /// Returns the raw Zeta2 configuration, if one has been set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
 855
    /// Returns the explicitly preferred experiment name, if any.
    pub fn preferred_experiment(&self) -> Option<&str> {
        self.preferred_experiment.as_deref()
    }
 859
    /// Sets the preferred experiment; passing `None` clears the preference.
    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
        self.preferred_experiment = experiment;
    }
 863
    /// Returns the experiment names most recently stored by
    /// `refresh_available_experiments` (empty until a refresh succeeds).
    pub fn available_experiments(&self) -> &[String] {
        &self.available_experiments
    }
 867
 868    pub fn active_experiment(&self) -> Option<&str> {
 869        self.preferred_experiment.as_deref().or_else(|| {
 870            self.shown_predictions
 871                .iter()
 872                .find_map(|p| p.model_version.as_ref())
 873                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 874        })
 875    }
 876
 877    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 878        let client = self.client.clone();
 879        let llm_token = self.llm_token.clone();
 880        let app_version = AppVersion::global(cx);
 881        let organization_id = self
 882            .user_store
 883            .read(cx)
 884            .current_organization()
 885            .map(|organization| organization.id.clone());
 886
 887        cx.spawn(async move |this, cx| {
 888            let experiments = cx
 889                .background_spawn(async move {
 890                    let http_client = client.http_client();
 891                    let token = client
 892                        .acquire_llm_token(&llm_token, organization_id.clone())
 893                        .await?;
 894                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 895                    let request = http_client::Request::builder()
 896                        .method(Method::GET)
 897                        .uri(url.as_ref())
 898                        .header("Authorization", format!("Bearer {}", token))
 899                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 900                        .body(Default::default())?;
 901                    let mut response = http_client.send(request).await?;
 902                    if response.status().is_success() {
 903                        let mut body = Vec::new();
 904                        response.body_mut().read_to_end(&mut body).await?;
 905                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 906                        Ok(experiments)
 907                    } else {
 908                        let mut body = String::new();
 909                        response.body_mut().read_to_string(&mut body).await?;
 910                        anyhow::bail!(
 911                            "Failed to fetch experiments: {:?}\nBody: {}",
 912                            response.status(),
 913                            body
 914                        );
 915                    }
 916                })
 917                .await?;
 918            this.update(cx, |this, cx| {
 919                this.available_experiments = experiments;
 920                cx.notify();
 921            })?;
 922            anyhow::Ok(())
 923        })
 924        .detach_and_log_err(cx);
 925    }
 926
 927    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 928        use ui::IconName;
 929        match self.edit_prediction_model {
 930            EditPredictionModel::Mercury => {
 931                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 932            }
 933            EditPredictionModel::Zeta => {
 934                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 935                    .with_disabled(IconName::ZedPredictDisabled)
 936                    .with_up(IconName::ZedPredictUp)
 937                    .with_down(IconName::ZedPredictDown)
 938                    .with_error(IconName::ZedPredictError)
 939            }
 940            EditPredictionModel::Fim { .. } => {
 941                let settings = &all_language_settings(None, cx).edit_predictions;
 942                match settings.provider {
 943                    EditPredictionProvider::Ollama => {
 944                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 945                    }
 946                    _ => {
 947                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 948                    }
 949                }
 950            }
 951        }
 952    }
 953
    /// Returns whether the Mercury API token state reports having a key.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
 957
    /// Forwards to the Mercury client's payment-required error flag.
    pub fn mercury_has_payment_required_error(&self) -> bool {
        self.mercury.has_payment_required_error()
    }
 961
 962    pub fn clear_history(&mut self) {
 963        for project_state in self.projects.values_mut() {
 964            project_state.events.clear();
 965            project_state.last_event.take();
 966        }
 967    }
 968
 969    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 970        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 971            project_state.events.clear();
 972            project_state.last_event.take();
 973        }
 974    }
 975
 976    pub fn edit_history_for_project(
 977        &self,
 978        project: &Entity<Project>,
 979        cx: &App,
 980    ) -> Vec<StoredEvent> {
 981        self.projects
 982            .get(&project.entity_id())
 983            .map(|project_state| project_state.events(cx))
 984            .unwrap_or_default()
 985    }
 986
 987    pub fn context_for_project<'a>(
 988        &'a self,
 989        project: &Entity<Project>,
 990        cx: &'a mut App,
 991    ) -> Vec<RelatedFile> {
 992        self.projects
 993            .get(&project.entity_id())
 994            .map(|project_state| {
 995                project_state.context.update(cx, |context, cx| {
 996                    context
 997                        .related_files_with_buffers(cx)
 998                        .map(|(mut related_file, buffer)| {
 999                            related_file.in_open_source_repo = buffer
1000                                .read(cx)
1001                                .file()
1002                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
1003                            related_file
1004                        })
1005                        .collect()
1006                })
1007            })
1008            .unwrap_or_default()
1009    }
1010
1011    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1012        self.projects
1013            .get(&project.entity_id())
1014            .and_then(|project| project.copilot.clone())
1015    }
1016
1017    pub fn start_copilot_for_project(
1018        &mut self,
1019        project: &Entity<Project>,
1020        cx: &mut Context<Self>,
1021    ) -> Option<Entity<Copilot>> {
1022        if DisableAiSettings::get(None, cx).disable_ai {
1023            return None;
1024        }
1025        let state = self.get_or_init_project(project, cx);
1026
1027        if state.copilot.is_some() {
1028            return state.copilot.clone();
1029        }
1030        let _project = project.clone();
1031        let project = project.read(cx);
1032
1033        let node = project.node_runtime().cloned();
1034        if let Some(node) = node {
1035            let next_id = project.languages().next_language_server_id();
1036            let fs = project.fs().clone();
1037
1038            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1039            state.copilot = Some(copilot.clone());
1040            Some(copilot)
1041        } else {
1042            None
1043        }
1044    }
1045
1046    pub fn context_for_project_with_buffers<'a>(
1047        &'a self,
1048        project: &Entity<Project>,
1049        cx: &'a mut App,
1050    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1051        self.projects
1052            .get(&project.entity_id())
1053            .map(|project| {
1054                project.context.update(cx, |context, cx| {
1055                    context.related_files_with_buffers(cx).collect()
1056                })
1057            })
1058            .unwrap_or_default()
1059    }
1060
1061    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1062        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1063            self.user_store.read(cx).edit_prediction_usage()
1064        } else {
1065            None
1066        }
1067    }
1068
    /// Ensures per-project state exists for `project`, creating it if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
1072
    /// Starts tracking `buffer` within `project`, registering the project
    /// first if it has no state yet.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
1082
    /// Returns the state tracked for `project`, creating it on first access.
    ///
    /// A newly created state gets its own `RelatedExcerptStore`, empty
    /// history and prediction bookkeeping, and subscriptions that forward
    /// project events and remove the state when the project is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward excerpt-store events so they can be relayed to
                    // the project's debug channel. The subscription is
                    // detached rather than stored in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop this project's state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1122
    /// Drops all state tracked for `project`; a no-op when it is not registered.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
1126
1127    fn handle_excerpt_store_event(
1128        &mut self,
1129        project_entity_id: EntityId,
1130        event: &RelatedExcerptStoreEvent,
1131    ) {
1132        if let Some(project_state) = self.projects.get(&project_entity_id) {
1133            if let Some(debug_tx) = project_state.debug_tx.clone() {
1134                match event {
1135                    RelatedExcerptStoreEvent::StartedRefresh => {
1136                        debug_tx
1137                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1138                                ContextRetrievalStartedDebugEvent {
1139                                    project_entity_id: project_entity_id,
1140                                    timestamp: Instant::now(),
1141                                    search_prompt: String::new(),
1142                                },
1143                            ))
1144                            .ok();
1145                    }
1146                    RelatedExcerptStoreEvent::FinishedRefresh {
1147                        cache_hit_count,
1148                        cache_miss_count,
1149                        mean_definition_latency,
1150                        max_definition_latency,
1151                    } => {
1152                        debug_tx
1153                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1154                                ContextRetrievalFinishedDebugEvent {
1155                                    project_entity_id: project_entity_id,
1156                                    timestamp: Instant::now(),
1157                                    metadata: vec![
1158                                        (
1159                                            "Cache Hits",
1160                                            format!(
1161                                                "{}/{}",
1162                                                cache_hit_count,
1163                                                cache_hit_count + cache_miss_count
1164                                            )
1165                                            .into(),
1166                                        ),
1167                                        (
1168                                            "Max LSP Time",
1169                                            format!("{} ms", max_definition_latency.as_millis())
1170                                                .into(),
1171                                        ),
1172                                        (
1173                                            "Mean LSP Time",
1174                                            format!("{} ms", mean_definition_latency.as_millis())
1175                                                .into(),
1176                                        ),
1177                                    ],
1178                                },
1179                            ))
1180                            .ok();
1181                    }
1182                }
1183            }
1184        }
1185    }
1186
1187    pub fn debug_info(
1188        &mut self,
1189        project: &Entity<Project>,
1190        cx: &mut Context<Self>,
1191    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1192        let project_state = self.get_or_init_project(project, cx);
1193        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1194        project_state.debug_tx = Some(debug_watch_tx);
1195        debug_watch_rx
1196    }
1197
1198    fn handle_project_event(
1199        &mut self,
1200        project: Entity<Project>,
1201        event: &project::Event,
1202        cx: &mut Context<Self>,
1203    ) {
1204        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1205            return;
1206        }
1207        // TODO [zeta2] init with recent paths
1208        match event {
1209            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1210                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1211                    return;
1212                };
1213                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1214                if let Some(path) = path {
1215                    if let Some(ix) = project_state
1216                        .recent_paths
1217                        .iter()
1218                        .position(|probe| probe == &path)
1219                    {
1220                        project_state.recent_paths.remove(ix);
1221                    }
1222                    project_state.recent_paths.push_front(path);
1223                }
1224            }
1225            project::Event::DiagnosticsUpdated { .. } => {
1226                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1227                    self.refresh_prediction_from_diagnostics(
1228                        project,
1229                        DiagnosticSearchScope::Global,
1230                        cx,
1231                    );
1232                }
1233            }
1234            _ => (),
1235        }
1236    }
1237
    /// Returns the tracking entry for `buffer` in `project_state`, creating
    /// it (and its subscriptions) on first registration.
    ///
    /// Also makes sure a license detection watcher exists for the worktree
    /// that contains the buffer's file, when the buffer has a file and the
    /// worktree can be found in the project.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Remove the watcher again once its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        // Record every edit event; the project is held weakly
                        // so this subscription does not keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        // Forget the buffer once its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1306
    /// Folds the edits made to `buffer` since its last recorded snapshot into
    /// the project's edit history.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction;
    /// `is_local` distinguishes local edits from collaborator edits. The
    /// registered snapshot and file are always advanced to the current ones,
    /// even when the edit ends up excluded from the history.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse all edits since the old version into one range running from
        // the first edit's start to the last edit's end.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Refresh the last-edit time of pending predictions whose editable
        // range overlaps this edit.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Local edits always count; collaborator edits only when they overlap
        // the locality region.
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // An edit with no recordable diff still ends the in-progress event:
        // finalize it into the history and stop.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            // Coalesce only when this edit directly continues the in-progress
            // event, comes from the same source (typed vs. predicted), and
            // lands within the line-span threshold of the previous edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a long enough pause, remember the state before this
                // edit as a pause boundary within the coalesced event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescing: finalize the in-progress event (if any) into the
        // history before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1439
1440    fn prediction_at(
1441        &mut self,
1442        buffer: &Entity<Buffer>,
1443        position: Option<language::Anchor>,
1444        project: &Entity<Project>,
1445        cx: &App,
1446    ) -> Option<BufferEditPrediction<'_>> {
1447        let project_state = self.projects.get_mut(&project.entity_id())?;
1448        if let Some(position) = position
1449            && let Some(buffer) = project_state
1450                .registered_buffers
1451                .get_mut(&buffer.entity_id())
1452        {
1453            buffer.last_position = Some(position);
1454        }
1455
1456        let CurrentEditPrediction {
1457            requested_by,
1458            prediction,
1459            ..
1460        } = project_state.current_prediction.as_ref()?;
1461
1462        if prediction.targets_buffer(buffer.read(cx)) {
1463            Some(BufferEditPrediction::Local { prediction })
1464        } else {
1465            let show_jump = match requested_by {
1466                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1467                    requested_by_buffer_id == &buffer.entity_id()
1468                }
1469                PredictionRequestedBy::DiagnosticsUpdate => true,
1470            };
1471
1472            if show_jump {
1473                Some(BufferEditPrediction::Jump { prediction })
1474            } else {
1475                None
1476            }
1477        }
1478    }
1479
1480    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1481        let Some(current_prediction) = self
1482            .projects
1483            .get_mut(&project.entity_id())
1484            .and_then(|project_state| project_state.current_prediction.take())
1485        else {
1486            return;
1487        };
1488
1489        self.report_changes_for_buffer(
1490            &current_prediction.prediction.buffer,
1491            project,
1492            true,
1493            true,
1494            cx,
1495        );
1496
1497        // can't hold &mut project_state ref across report_changes_for_buffer_call
1498        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1499            return;
1500        };
1501
1502        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1503            project_state.cancel_pending_prediction(pending_prediction, cx);
1504        }
1505
1506        match self.edit_prediction_model {
1507            EditPredictionModel::Mercury => {
1508                mercury::edit_prediction_accepted(
1509                    current_prediction.prediction.id,
1510                    self.client.http_client(),
1511                    cx,
1512                );
1513            }
1514            EditPredictionModel::Zeta => {
1515                let is_cloud = !matches!(
1516                    all_language_settings(None, cx).edit_predictions.provider,
1517                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1518                );
1519                if is_cloud {
1520                    zeta::edit_prediction_accepted(self, current_prediction, cx)
1521                }
1522            }
1523            EditPredictionModel::Fim { .. } => {}
1524        }
1525    }
1526
1527    async fn handle_rejected_predictions(
1528        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
1529        client: Arc<Client>,
1530        llm_token: LlmApiToken,
1531        app_version: Version,
1532        background_executor: BackgroundExecutor,
1533    ) {
1534        let mut rx = std::pin::pin!(rx.peekable());
1535        let mut batched = Vec::new();
1536
1537        while let Some(EditPredictionRejectionPayload {
1538            rejection,
1539            organization_id,
1540        }) = rx.next().await
1541        {
1542            batched.push(rejection);
1543
1544            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1545                select_biased! {
1546                    next = rx.as_mut().peek().fuse() => {
1547                        if next.is_some() {
1548                            continue;
1549                        }
1550                    }
1551                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1552                }
1553            }
1554
1555            let url = client
1556                .http_client()
1557                .build_zed_llm_url("/predict_edits/reject", &[])
1558                .unwrap();
1559
1560            let flush_count = batched
1561                .len()
1562                // in case items have accumulated after failure
1563                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1564            let start = batched.len() - flush_count;
1565
1566            let body = RejectEditPredictionsBodyRef {
1567                rejections: &batched[start..],
1568            };
1569
1570            let result = Self::send_api_request::<()>(
1571                |builder| {
1572                    let req = builder
1573                        .uri(url.as_ref())
1574                        .body(serde_json::to_string(&body)?.into());
1575                    anyhow::Ok(req?)
1576                },
1577                client.clone(),
1578                llm_token.clone(),
1579                organization_id,
1580                app_version.clone(),
1581                true,
1582            )
1583            .await;
1584
1585            if result.log_err().is_some() {
1586                batched.drain(start..);
1587            }
1588        }
1589    }
1590
1591    async fn run_settled_predictions_worker(
1592        this: WeakEntity<Self>,
1593        mut rx: UnboundedReceiver<Instant>,
1594        cx: &mut AsyncApp,
1595    ) {
1596        let mut next_wake_time: Option<Instant> = None;
1597        loop {
1598            let now = cx.background_executor().now();
1599            if let Some(wake_time) = next_wake_time.take() {
1600                cx.background_executor()
1601                    .timer(wake_time.duration_since(now))
1602                    .await;
1603            } else {
1604                let Some(new_enqueue_time) = rx.next().await else {
1605                    break;
1606                };
1607                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1608                while rx.next().now_or_never().flatten().is_some() {}
1609                continue;
1610            }
1611
1612            let Some(this) = this.upgrade() else {
1613                break;
1614            };
1615
1616            let now = cx.background_executor().now();
1617            let mut oldest_edited_at = None;
1618            let mut ready_predictions = Vec::new();
1619
1620            this.update(cx, |this, _| {
1621                for (_, project_state) in this.projects.iter_mut() {
1622                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1623                        let mut pending_index = 0;
1624                        while pending_index < registered_buffer.pending_predictions.len() {
1625                            let pending_prediction =
1626                                &registered_buffer.pending_predictions[pending_index];
1627                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
1628                            if age >= EDIT_PREDICTION_SETTLED_TTL {
1629                                registered_buffer.pending_predictions.remove(pending_index);
1630                                continue;
1631                            }
1632
1633                            let quiet_for =
1634                                now.saturating_duration_since(pending_prediction.last_edit_at);
1635                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1636                                let pending_prediction =
1637                                    registered_buffer.pending_predictions.remove(pending_index);
1638                                let settled_editable_region = registered_buffer
1639                                    .snapshot
1640                                    .text_for_range(
1641                                        pending_prediction.editable_anchor_range.clone(),
1642                                    )
1643                                    .collect::<String>();
1644                                ready_predictions
1645                                    .push((pending_prediction, settled_editable_region));
1646                                continue;
1647                            }
1648
1649                            if oldest_edited_at
1650                                .is_none_or(|time| pending_prediction.last_edit_at < time)
1651                            {
1652                                oldest_edited_at = Some(pending_prediction.last_edit_at);
1653                            }
1654                            pending_index += 1;
1655                        }
1656                    }
1657                }
1658            });
1659
1660            for (pending_prediction, settled_editable_region) in ready_predictions {
1661                let PendingSettledPrediction {
1662                    request_id,
1663                    editable_region_before_prediction,
1664                    predicted_editable_region,
1665                    ts_error_count_before_prediction,
1666                    ts_error_count_after_prediction,
1667                    example,
1668                    e2e_latency,
1669                    ..
1670                } = pending_prediction;
1671                let settled_editable_region_for_metrics = settled_editable_region.clone();
1672                let kept_rate_result = cx
1673                    .background_spawn(async move {
1674                        compute_kept_rate(
1675                            &editable_region_before_prediction,
1676                            &predicted_editable_region,
1677                            &settled_editable_region_for_metrics,
1678                        )
1679                    })
1680                    .await;
1681
1682                #[cfg(test)]
1683                {
1684                    let request_id = request_id.clone();
1685                    let settled_editable_region = settled_editable_region.clone();
1686                    this.update(cx, |this, _| {
1687                        if let Some(callback) = &this.settled_event_callback {
1688                            callback(request_id, settled_editable_region);
1689                        }
1690                    });
1691                }
1692
1693                telemetry::event!(
1694                    EDIT_PREDICTION_SETTLED_EVENT,
1695                    request_id = request_id.0.clone(),
1696                    settled_editable_region,
1697                    ts_error_count_before_prediction,
1698                    ts_error_count_after_prediction,
1699                    edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
1700                    edit_bytes_reference_new = kept_rate_result.reference_new_chars,
1701                    edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
1702                    edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
1703                    edit_bytes_kept = kept_rate_result.kept_chars,
1704                    edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
1705                    edit_bytes_discarded = kept_rate_result.discarded_chars,
1706                    edit_bytes_context = kept_rate_result.context_chars,
1707                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
1708                    edit_bytes_recall_rate = kept_rate_result.recall_rate,
1709                    example,
1710                    e2e_latency = e2e_latency.as_millis(),
1711                );
1712            }
1713
1714            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1715        }
1716    }
1717
1718    pub(crate) fn enqueue_settled_prediction(
1719        &mut self,
1720        request_id: EditPredictionId,
1721        project: &Entity<Project>,
1722        edited_buffer: &Entity<Buffer>,
1723        edited_buffer_snapshot: &BufferSnapshot,
1724        editable_offset_range: Range<usize>,
1725        edit_preview: &EditPreview,
1726        example: Option<ExampleSpec>,
1727        e2e_latency: std::time::Duration,
1728        cx: &mut Context<Self>,
1729    ) {
1730        let this = &mut *self;
1731        let project_state = this.get_or_init_project(project, cx);
1732        let Some(registered_buffer) = project_state
1733            .registered_buffers
1734            .get_mut(&edited_buffer.entity_id())
1735        else {
1736            return;
1737        };
1738
1739        let editable_region_before_prediction = edited_buffer_snapshot
1740            .text_for_range(editable_offset_range.clone())
1741            .collect::<String>();
1742        let editable_anchor_range_for_result =
1743            edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1744        let predicted_editable_region = edit_preview
1745            .result_text_snapshot()
1746            .text_for_range(editable_anchor_range_for_result.clone())
1747            .collect();
1748        let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1749            edited_buffer_snapshot
1750                .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1751        );
1752        let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1753            edit_preview.result_syntax_snapshot().layers_for_range(
1754                editable_anchor_range_for_result,
1755                edit_preview.result_text_snapshot(),
1756                true,
1757            ),
1758        );
1759        let editable_anchor_range =
1760            edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1761        let now = cx.background_executor().now();
1762        registered_buffer
1763            .pending_predictions
1764            .push(PendingSettledPrediction {
1765                request_id,
1766                editable_anchor_range,
1767                editable_region_before_prediction,
1768                predicted_editable_region,
1769                ts_error_count_before_prediction,
1770                ts_error_count_after_prediction,
1771                example,
1772                e2e_latency,
1773                enqueued_at: now,
1774                last_edit_at: now,
1775            });
1776        this.settled_predictions_tx.unbounded_send(now).ok();
1777    }
1778
1779    fn reject_current_prediction(
1780        &mut self,
1781        reason: EditPredictionRejectReason,
1782        project: &Entity<Project>,
1783        cx: &App,
1784    ) {
1785        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1786            project_state.pending_predictions.clear();
1787            if let Some(prediction) = project_state.current_prediction.take() {
1788                let model_version = prediction.prediction.model_version.clone();
1789                self.reject_prediction(
1790                    prediction.prediction.id,
1791                    reason,
1792                    prediction.was_shown,
1793                    model_version,
1794                    Some(prediction.e2e_latency),
1795                    cx,
1796                );
1797            }
1798        };
1799    }
1800
1801    fn did_show_current_prediction(
1802        &mut self,
1803        project: &Entity<Project>,
1804        display_type: edit_prediction_types::SuggestionDisplayType,
1805        _cx: &mut Context<Self>,
1806    ) {
1807        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1808            return;
1809        };
1810
1811        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1812            return;
1813        };
1814
1815        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1816        let previous_shown_with = current_prediction.shown_with;
1817
1818        if previous_shown_with.is_none() || !is_jump {
1819            current_prediction.shown_with = Some(display_type);
1820        }
1821
1822        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1823
1824        if is_first_non_jump_show {
1825            current_prediction.was_shown = true;
1826        }
1827
1828        if is_first_non_jump_show {
1829            self.shown_predictions
1830                .push_front(current_prediction.prediction.clone());
1831            if self.shown_predictions.len() > 50 {
1832                let completion = self.shown_predictions.pop_back().unwrap();
1833                self.rated_predictions.remove(&completion.id);
1834            }
1835        }
1836    }
1837
1838    fn reject_prediction(
1839        &mut self,
1840        prediction_id: EditPredictionId,
1841        reason: EditPredictionRejectReason,
1842        was_shown: bool,
1843        model_version: Option<String>,
1844        e2e_latency: Option<std::time::Duration>,
1845        cx: &App,
1846    ) {
1847        match self.edit_prediction_model {
1848            EditPredictionModel::Zeta => {
1849                let is_cloud = !matches!(
1850                    all_language_settings(None, cx).edit_predictions.provider,
1851                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1852                );
1853
1854                if is_cloud {
1855                    let organization_id = self
1856                        .user_store
1857                        .read(cx)
1858                        .current_organization()
1859                        .map(|organization| organization.id.clone());
1860
1861                    self.reject_predictions_tx
1862                        .unbounded_send(EditPredictionRejectionPayload {
1863                            rejection: EditPredictionRejection {
1864                                request_id: prediction_id.to_string(),
1865                                reason,
1866                                was_shown,
1867                                model_version,
1868                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1869                            },
1870                            organization_id,
1871                        })
1872                        .log_err();
1873                }
1874            }
1875            EditPredictionModel::Mercury => {
1876                mercury::edit_prediction_rejected(
1877                    prediction_id,
1878                    was_shown,
1879                    reason,
1880                    self.client.http_client(),
1881                    cx,
1882                );
1883            }
1884            EditPredictionModel::Fim { .. } => {}
1885        }
1886    }
1887
1888    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1889        self.projects
1890            .get(&project.entity_id())
1891            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1892    }
1893
1894    pub fn refresh_prediction_from_buffer(
1895        &mut self,
1896        project: Entity<Project>,
1897        buffer: Entity<Buffer>,
1898        position: language::Anchor,
1899        cx: &mut Context<Self>,
1900    ) {
1901        self.queue_prediction_refresh(
1902            project.clone(),
1903            PredictEditsRequestTrigger::Other,
1904            buffer.entity_id(),
1905            cx,
1906            move |this, cx| {
1907                let Some(request_task) = this
1908                    .update(cx, |this, cx| {
1909                        this.request_prediction(
1910                            &project,
1911                            &buffer,
1912                            position,
1913                            PredictEditsRequestTrigger::Other,
1914                            cx,
1915                        )
1916                    })
1917                    .log_err()
1918                else {
1919                    return Task::ready(anyhow::Ok(None));
1920                };
1921
1922                cx.spawn(async move |_cx| {
1923                    request_task.await.map(|prediction_result| {
1924                        prediction_result.map(|prediction_result| {
1925                            (
1926                                prediction_result,
1927                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1928                            )
1929                        })
1930                    })
1931                })
1932            },
1933        )
1934    }
1935
1936    pub fn refresh_prediction_from_diagnostics(
1937        &mut self,
1938        project: Entity<Project>,
1939        scope: DiagnosticSearchScope,
1940        cx: &mut Context<Self>,
1941    ) {
1942        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1943            return;
1944        }
1945
1946        if currently_following(&project, cx) {
1947            return;
1948        }
1949
1950        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1951            return;
1952        };
1953
1954        // Prefer predictions from buffer
1955        if project_state.current_prediction.is_some() {
1956            log::debug!(
1957                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
1958            );
1959            return;
1960        }
1961
1962        self.queue_prediction_refresh(
1963            project.clone(),
1964            PredictEditsRequestTrigger::Diagnostics,
1965            project.entity_id(),
1966            cx,
1967            move |this, cx| {
1968                let Some((active_buffer, snapshot, cursor_point)) = this
1969                    .read_with(cx, |this, cx| {
1970                        let project_state = this.projects.get(&project.entity_id())?;
1971                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
1972                        let snapshot = buffer.read(cx).snapshot();
1973
1974                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
1975                            return None;
1976                        }
1977
1978                        let cursor_point = position
1979                            .map(|pos| pos.to_point(&snapshot))
1980                            .unwrap_or_default();
1981
1982                        Some((buffer, snapshot, cursor_point))
1983                    })
1984                    .log_err()
1985                    .flatten()
1986                else {
1987                    return Task::ready(anyhow::Ok(None));
1988                };
1989
1990                cx.spawn(async move |cx| {
1991                    let diagnostic_search_range = match scope {
1992                        DiagnosticSearchScope::Local => {
1993                            let diagnostic_search_start =
1994                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1995                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1996                            Point::new(diagnostic_search_start, 0)
1997                                ..Point::new(diagnostic_search_end, 0)
1998                        }
1999                        DiagnosticSearchScope::Global => Default::default(),
2000                    };
2001
2002                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
2003                        active_buffer,
2004                        &snapshot,
2005                        diagnostic_search_range,
2006                        cursor_point,
2007                        &project,
2008                        cx,
2009                    )
2010                    .await?
2011                    else {
2012                        return anyhow::Ok(None);
2013                    };
2014
2015                    let Some(prediction_result) = this
2016                        .update(cx, |this, cx| {
2017                            this.request_prediction(
2018                                &project,
2019                                &jump_buffer,
2020                                jump_position,
2021                                PredictEditsRequestTrigger::Diagnostics,
2022                                cx,
2023                            )
2024                        })?
2025                        .await?
2026                    else {
2027                        return anyhow::Ok(None);
2028                    };
2029
2030                    this.update(cx, |this, cx| {
2031                        Some((
2032                            if this
2033                                .get_or_init_project(&project, cx)
2034                                .current_prediction
2035                                .is_none()
2036                            {
2037                                prediction_result
2038                            } else {
2039                                EditPredictionResult {
2040                                    id: prediction_result.id,
2041                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
2042                                    model_version: prediction_result.model_version,
2043                                    e2e_latency: prediction_result.e2e_latency,
2044                                }
2045                            },
2046                            PredictionRequestedBy::DiagnosticsUpdate,
2047                        ))
2048                    })
2049                })
2050            },
2051        );
2052    }
2053
2054    fn predictions_enabled_at(
2055        snapshot: &BufferSnapshot,
2056        position: Option<language::Anchor>,
2057        cx: &App,
2058    ) -> bool {
2059        let file = snapshot.file();
2060        let all_settings = all_language_settings(file, cx);
2061        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2062            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2063        {
2064            return false;
2065        }
2066
2067        if let Some(last_position) = position {
2068            let settings = snapshot.settings_at(last_position, cx);
2069
2070            if !settings.edit_predictions_disabled_in.is_empty()
2071                && let Some(scope) = snapshot.language_scope_at(last_position)
2072                && let Some(scope_name) = scope.override_name()
2073                && settings
2074                    .edit_predictions_disabled_in
2075                    .iter()
2076                    .any(|s| s == scope_name)
2077            {
2078                return false;
2079            }
2080        }
2081
2082        true
2083    }
2084
    /// Minimum interval that `queue_prediction_refresh` leaves between two
    /// refreshes throttled on the same entity.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2086}
2087
2088fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2089    let Some(app_state) = AppState::try_global(cx) else {
2090        return false;
2091    };
2092
2093    app_state
2094        .workspace_store
2095        .read(cx)
2096        .workspaces()
2097        .filter_map(|workspace| workspace.upgrade())
2098        .any(|workspace| {
2099            workspace.read(cx).project().entity_id() == project.entity_id()
2100                && workspace
2101                    .read(cx)
2102                    .leader_for_pane(workspace.read(cx).active_pane())
2103                    .is_some()
2104        })
2105}
2106
2107fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2108    match provider {
2109        EditPredictionProvider::Zed
2110        | EditPredictionProvider::Mercury
2111        | EditPredictionProvider::Ollama
2112        | EditPredictionProvider::OpenAiCompatibleApi
2113        | EditPredictionProvider::Experimental(_) => true,
2114        EditPredictionProvider::None
2115        | EditPredictionProvider::Copilot
2116        | EditPredictionProvider::Codestral => false,
2117    }
2118}
2119
2120impl EditPredictionStore {
2121    fn queue_prediction_refresh(
2122        &mut self,
2123        project: Entity<Project>,
2124        request_trigger: PredictEditsRequestTrigger,
2125        throttle_entity: EntityId,
2126        cx: &mut Context<Self>,
2127        do_refresh: impl FnOnce(
2128            WeakEntity<Self>,
2129            &mut AsyncApp,
2130        )
2131            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2132        + 'static,
2133    ) {
2134        fn select_throttle(
2135            project_state: &mut ProjectState,
2136            request_trigger: PredictEditsRequestTrigger,
2137        ) -> &mut Option<(EntityId, Instant)> {
2138            match request_trigger {
2139                PredictEditsRequestTrigger::Diagnostics => {
2140                    &mut project_state.last_jump_prediction_refresh
2141                }
2142                _ => &mut project_state.last_edit_prediction_refresh,
2143            }
2144        }
2145
2146        let (needs_acceptance_tracking, max_pending_predictions) =
2147            match all_language_settings(None, cx).edit_predictions.provider {
2148                EditPredictionProvider::Zed
2149                | EditPredictionProvider::Mercury
2150                | EditPredictionProvider::Experimental(_) => (true, 2),
2151                EditPredictionProvider::Ollama => (false, 1),
2152                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2153                EditPredictionProvider::None
2154                | EditPredictionProvider::Copilot
2155                | EditPredictionProvider::Codestral => {
2156                    log::error!("queue_prediction_refresh called with non-store provider");
2157                    return;
2158                }
2159            };
2160
2161        let drop_on_cancel = !needs_acceptance_tracking;
2162        let throttle_timeout = Self::THROTTLE_TIMEOUT;
2163        let project_state = self.get_or_init_project(&project, cx);
2164        let pending_prediction_id = project_state.next_pending_prediction_id;
2165        project_state.next_pending_prediction_id += 1;
2166        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2167
2168        let task = cx.spawn(async move |this, cx| {
2169            let throttle_wait = this
2170                .update(cx, |this, cx| {
2171                    let project_state = this.get_or_init_project(&project, cx);
2172                    let throttle = *select_throttle(project_state, request_trigger);
2173
2174                    let now = cx.background_executor().now();
2175                    throttle.and_then(|(last_entity, last_timestamp)| {
2176                        if throttle_entity != last_entity {
2177                            return None;
2178                        }
2179                        (last_timestamp + throttle_timeout).checked_duration_since(now)
2180                    })
2181                })
2182                .ok()
2183                .flatten();
2184
2185            if let Some(timeout) = throttle_wait {
2186                cx.background_executor().timer(timeout).await;
2187            }
2188
2189            // If this task was cancelled before the throttle timeout expired,
2190            // do not perform a request. Also skip if another task already
2191            // proceeded since we were enqueued (duplicate).
2192            let mut is_cancelled = true;
2193            this.update(cx, |this, cx| {
2194                let project_state = this.get_or_init_project(&project, cx);
2195                let was_cancelled = project_state
2196                    .cancelled_predictions
2197                    .remove(&pending_prediction_id);
2198                if was_cancelled {
2199                    return;
2200                }
2201
2202                // Another request has been already sent since this was enqueued
2203                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2204                    return;
2205                }
2206
2207                let new_refresh = (throttle_entity, cx.background_executor().now());
2208                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2209                is_cancelled = false;
2210            })
2211            .ok();
2212            if is_cancelled {
2213                return None;
2214            }
2215
2216            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2217            let new_prediction_metadata = new_prediction_result
2218                .as_ref()
2219                .map(|(prediction, _)| (prediction.id.clone(), prediction.model_version.clone()));
2220
2221            // When a prediction completes, remove it from the pending list, and cancel
2222            // any pending predictions that were enqueued before it.
2223            this.update(cx, |this, cx| {
2224                let project_state = this.get_or_init_project(&project, cx);
2225
2226                let is_cancelled = project_state
2227                    .cancelled_predictions
2228                    .remove(&pending_prediction_id);
2229
2230                let new_current_prediction = if !is_cancelled
2231                    && let Some((prediction_result, requested_by)) = new_prediction_result
2232                {
2233                    match prediction_result.prediction {
2234                        Ok(prediction) => {
2235                            let new_prediction = CurrentEditPrediction {
2236                                requested_by,
2237                                prediction,
2238                                was_shown: false,
2239                                shown_with: None,
2240                                e2e_latency: prediction_result.e2e_latency,
2241                            };
2242
2243                            if let Some(current_prediction) =
2244                                project_state.current_prediction.as_ref()
2245                            {
2246                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2247                                {
2248                                    this.reject_current_prediction(
2249                                        EditPredictionRejectReason::Replaced,
2250                                        &project,
2251                                        cx,
2252                                    );
2253
2254                                    Some(new_prediction)
2255                                } else {
2256                                    this.reject_prediction(
2257                                        new_prediction.prediction.id,
2258                                        EditPredictionRejectReason::CurrentPreferred,
2259                                        false,
2260                                        new_prediction.prediction.model_version,
2261                                        Some(new_prediction.e2e_latency),
2262                                        cx,
2263                                    );
2264                                    None
2265                                }
2266                            } else {
2267                                Some(new_prediction)
2268                            }
2269                        }
2270                        Err(reject_reason) => {
2271                            this.reject_prediction(
2272                                prediction_result.id,
2273                                reject_reason,
2274                                false,
2275                                prediction_result.model_version,
2276                                Some(prediction_result.e2e_latency),
2277                                cx,
2278                            );
2279                            None
2280                        }
2281                    }
2282                } else {
2283                    None
2284                };
2285
2286                let project_state = this.get_or_init_project(&project, cx);
2287
2288                if let Some(new_prediction) = new_current_prediction {
2289                    project_state.current_prediction = Some(new_prediction);
2290                }
2291
2292                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2293                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2294                    if pending_prediction.id == pending_prediction_id {
2295                        pending_predictions.remove(ix);
2296                        for pending_prediction in pending_predictions.drain(0..ix) {
2297                            project_state.cancel_pending_prediction(pending_prediction, cx)
2298                        }
2299                        break;
2300                    }
2301                }
2302                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2303                cx.notify();
2304            })
2305            .ok();
2306
2307            new_prediction_metadata
2308        });
2309
2310        if project_state.pending_predictions.len() < max_pending_predictions {
2311            project_state
2312                .pending_predictions
2313                .push(PendingPrediction {
2314                    id: pending_prediction_id,
2315                    task,
2316                    drop_on_cancel,
2317                })
2318                .unwrap();
2319        } else {
2320            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2321            project_state
2322                .pending_predictions
2323                .push(PendingPrediction {
2324                    id: pending_prediction_id,
2325                    task,
2326                    drop_on_cancel,
2327                })
2328                .unwrap();
2329            project_state.cancel_pending_prediction(pending_prediction, cx);
2330        }
2331    }
2332
2333    pub fn request_prediction(
2334        &mut self,
2335        project: &Entity<Project>,
2336        active_buffer: &Entity<Buffer>,
2337        position: language::Anchor,
2338        trigger: PredictEditsRequestTrigger,
2339        cx: &mut Context<Self>,
2340    ) -> Task<Result<Option<EditPredictionResult>>> {
2341        self.request_prediction_internal(
2342            project.clone(),
2343            active_buffer.clone(),
2344            position,
2345            trigger,
2346            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2347            cx,
2348        )
2349    }
2350
2351    fn request_prediction_internal(
2352        &mut self,
2353        project: Entity<Project>,
2354        active_buffer: Entity<Buffer>,
2355        position: language::Anchor,
2356        trigger: PredictEditsRequestTrigger,
2357        allow_jump: bool,
2358        cx: &mut Context<Self>,
2359    ) -> Task<Result<Option<EditPredictionResult>>> {
2360        self.get_or_init_project(&project, cx);
2361        let project_state = self.projects.get(&project.entity_id()).unwrap();
2362        let stored_events = project_state.events(cx);
2363        let has_events = !stored_events.is_empty();
2364        let events: Vec<Arc<zeta_prompt::Event>> =
2365            stored_events.iter().map(|e| e.event.clone()).collect();
2366        let debug_tx = project_state.debug_tx.clone();
2367
2368        let snapshot = active_buffer.read(cx).snapshot();
2369        let cursor_point = position.to_point(&snapshot);
2370        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2371        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2372        let diagnostic_search_range =
2373            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2374
2375        let related_files = self.context_for_project(&project, cx);
2376        let mode = match all_language_settings(snapshot.file(), cx).edit_predictions_mode() {
2377            EditPredictionsMode::Eager => PredictEditsMode::Eager,
2378            EditPredictionsMode::Subtle => PredictEditsMode::Subtle,
2379        };
2380
2381        let is_open_source = snapshot
2382            .file()
2383            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2384            && events.iter().all(|event| event.in_open_source_repo())
2385            && related_files.iter().all(|file| file.in_open_source_repo);
2386
2387        let can_collect_data = !cfg!(test)
2388            && is_open_source
2389            && self.is_data_collection_enabled(cx)
2390            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2391        let inputs = EditPredictionModelInput {
2392            project: project.clone(),
2393            buffer: active_buffer,
2394            snapshot,
2395            position,
2396            events,
2397            related_files,
2398            mode,
2399            trigger,
2400            diagnostic_search_range,
2401            debug_tx,
2402            can_collect_data,
2403            is_open_source,
2404        };
2405
2406        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2407
2408        let task = match self.edit_prediction_model {
2409            EditPredictionModel::Zeta => {
2410                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2411            }
2412            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2413            EditPredictionModel::Mercury => {
2414                self.mercury
2415                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
2416            }
2417        };
2418
2419        cx.spawn(async move |this, cx| {
2420            let prediction = task.await?;
2421
2422            // Only fall back to diagnostics-based prediction if we got a
2423            // the model had nothing to suggest for the buffer
2424            if prediction.is_none()
2425                && allow_jump
2426                && has_events
2427                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2428            {
2429                this.update(cx, |this, cx| {
2430                    this.refresh_prediction_from_diagnostics(
2431                        project,
2432                        DiagnosticSearchScope::Local,
2433                        cx,
2434                    );
2435                })?;
2436                return anyhow::Ok(None);
2437            }
2438
2439            Ok(prediction)
2440        })
2441    }
2442
    /// Finds a diagnostic location to jump to, preferring the active buffer.
    ///
    /// The active buffer is searched first: diagnostic groups whose primary
    /// range overlaps `active_buffer_diagnostic_search_range` are ignored, and
    /// the remaining group closest (by row) to `active_buffer_cursor_point`
    /// wins. If nothing is found there, other project paths with diagnostic
    /// summaries are tried, ordered by how many leading path components they
    /// share with the active buffer's path.
    ///
    /// Returns the buffer and anchor to jump to, or `None` when no suitable
    /// diagnostic exists.
    ///
    /// # Errors
    ///
    /// Returns an error if opening a candidate buffer fails.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported by `selections_in_range` over
        // the whole buffer.
        // NOTE(review): presumably these are remote collaborators' selections
        // only — confirm against `selections_in_range`.
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(
                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
                false,
            )
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search range around the cursor are
                // not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Skip diagnostics that sit next to one of the collected
                // cursors unless they are also close to the local cursor.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            // Choose the candidate whose row is closest to the local cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing usable in the active buffer: look at other project paths.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair every other path that has a diagnostic summary with the
            // number of leading path components it shares with the active
            // buffer's path (zero when the active buffer has no file).
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first; the sort is stable, so ties keep
            // the order produced by `diagnostic_summaries`.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        // True when `selections_in_range` yields at least one
                        // entry for this buffer.
                        let has_collaborators = snapshot
                            .selections_in_range(
                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
                                false,
                            )
                            .next()
                            .is_some();
                        // Start of the diagnostic with the smallest `severity`
                        // value.
                        // NOTE(review): presumably the smallest value is the
                        // most severe — confirm the severity ordering.
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Skip buffers that reported any selection above.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2558
2559    async fn send_raw_llm_request(
2560        request: RawCompletionRequest,
2561        client: Arc<Client>,
2562        custom_url: Option<Arc<Url>>,
2563        llm_token: LlmApiToken,
2564        organization_id: Option<OrganizationId>,
2565        app_version: Version,
2566    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2567        let url = if let Some(custom_url) = custom_url {
2568            custom_url.as_ref().clone()
2569        } else {
2570            client
2571                .http_client()
2572                .build_zed_llm_url("/predict_edits/raw", &[])?
2573        };
2574
2575        Self::send_api_request(
2576            |builder| {
2577                let req = builder
2578                    .uri(url.as_ref())
2579                    .body(serde_json::to_string(&request)?.into());
2580                Ok(req?)
2581            },
2582            client,
2583            llm_token,
2584            organization_id,
2585            app_version,
2586            true,
2587        )
2588        .await
2589    }
2590
2591    pub(crate) async fn send_v3_request(
2592        input: ZetaPromptInput,
2593        preferred_experiment: Option<String>,
2594        client: Arc<Client>,
2595        llm_token: LlmApiToken,
2596        organization_id: Option<OrganizationId>,
2597        app_version: Version,
2598        trigger: PredictEditsRequestTrigger,
2599        mode: PredictEditsMode,
2600    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2601        let url = client
2602            .http_client()
2603            .build_zed_llm_url("/predict_edits/v3", &[])?;
2604
2605        let request = PredictEditsV3Request { input, trigger };
2606
2607        let json_bytes = serde_json::to_vec(&request)?;
2608        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2609
2610        Self::send_api_request(
2611            |builder| {
2612                let builder = builder
2613                    .uri(url.as_ref())
2614                    .header("Content-Encoding", "zstd")
2615                    .header(PREDICT_EDITS_MODE_HEADER_NAME, mode.as_ref());
2616                let builder = if let Some(preferred_experiment) = preferred_experiment.as_deref() {
2617                    builder.header(PREFERRED_EXPERIMENT_HEADER_NAME, preferred_experiment)
2618                } else {
2619                    builder
2620                };
2621                let req = builder.body(compressed.clone().into());
2622                Ok(req?)
2623            },
2624            client,
2625            llm_token,
2626            organization_id,
2627            app_version,
2628            true,
2629        )
2630        .await
2631    }
2632
2633    async fn send_api_request<Res>(
2634        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2635        client: Arc<Client>,
2636        llm_token: LlmApiToken,
2637        organization_id: Option<OrganizationId>,
2638        app_version: Version,
2639        require_auth: bool,
2640    ) -> Result<(Res, Option<EditPredictionUsage>)>
2641    where
2642        Res: DeserializeOwned,
2643    {
2644        let http_client = client.http_client();
2645        let mut token = if require_auth {
2646            Some(
2647                client
2648                    .acquire_llm_token(&llm_token, organization_id.clone())
2649                    .await?,
2650            )
2651        } else {
2652            client
2653                .acquire_llm_token(&llm_token, organization_id.clone())
2654                .await
2655                .ok()
2656        };
2657        let mut did_retry = false;
2658
2659        loop {
2660            let request_builder = http_client::Request::builder().method(Method::POST);
2661
2662            let mut request_builder = request_builder
2663                .header("Content-Type", "application/json")
2664                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2665
2666            // Only add Authorization header if we have a token
2667            if let Some(ref token_value) = token {
2668                request_builder =
2669                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2670            }
2671
2672            let request = build(request_builder)?;
2673
2674            let mut response = http_client.send(request).await?;
2675
2676            if let Some(minimum_required_version) = response
2677                .headers()
2678                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2679                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2680            {
2681                anyhow::ensure!(
2682                    app_version >= minimum_required_version,
2683                    ZedUpdateRequiredError {
2684                        minimum_version: minimum_required_version
2685                    }
2686                );
2687            }
2688
2689            if response.status().is_success() {
2690                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2691
2692                let mut body = Vec::new();
2693                response.body_mut().read_to_end(&mut body).await?;
2694                return Ok((serde_json::from_slice(&body)?, usage));
2695            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2696                did_retry = true;
2697                token = Some(
2698                    client
2699                        .refresh_llm_token(&llm_token, organization_id.clone())
2700                        .await?,
2701                );
2702            } else {
2703                let mut body = String::new();
2704                response.body_mut().read_to_string(&mut body).await?;
2705                anyhow::bail!(
2706                    "Request failed with status: {:?}\nBody: {}",
2707                    response.status(),
2708                    body
2709                );
2710            }
2711        }
2712    }
2713
2714    pub fn refresh_context(
2715        &mut self,
2716        project: &Entity<Project>,
2717        buffer: &Entity<language::Buffer>,
2718        cursor_position: language::Anchor,
2719        cx: &mut Context<Self>,
2720    ) {
2721        self.get_or_init_project(project, cx)
2722            .context
2723            .update(cx, |store, cx| {
2724                store.refresh(buffer.clone(), cursor_position, cx);
2725            });
2726    }
2727
2728    #[cfg(feature = "cli-support")]
2729    pub fn set_context_for_buffer(
2730        &mut self,
2731        project: &Entity<Project>,
2732        related_files: Vec<RelatedFile>,
2733        cx: &mut Context<Self>,
2734    ) {
2735        self.get_or_init_project(project, cx)
2736            .context
2737            .update(cx, |store, cx| {
2738                store.set_related_files(related_files, cx);
2739            });
2740    }
2741
2742    #[cfg(feature = "cli-support")]
2743    pub fn set_recent_paths_for_project(
2744        &mut self,
2745        project: &Entity<Project>,
2746        paths: impl IntoIterator<Item = project::ProjectPath>,
2747        cx: &mut Context<Self>,
2748    ) {
2749        let project_state = self.get_or_init_project(project, cx);
2750        project_state.recent_paths = paths.into_iter().collect();
2751    }
2752
2753    fn is_file_open_source(
2754        &self,
2755        project: &Entity<Project>,
2756        file: &Arc<dyn File>,
2757        cx: &App,
2758    ) -> bool {
2759        if !file.is_local() || file.is_private() {
2760            return false;
2761        }
2762        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2763            return false;
2764        };
2765        project_state
2766            .license_detection_watchers
2767            .get(&file.worktree_id(cx))
2768            .as_ref()
2769            .is_some_and(|watcher| watcher.is_project_open_source())
2770    }
2771
    /// Returns whether data collection is currently enabled, as decided by
    /// the stored `DataCollectionChoice` for the given app context.
    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
        self.data_collection_choice.is_enabled(cx)
    }
2775
2776    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2777        let choice = KeyValueStore::global(cx)
2778            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2779            .log_err()
2780            .flatten();
2781
2782        match choice.as_deref() {
2783            Some("true") => DataCollectionChoice::Enabled,
2784            Some("false") => DataCollectionChoice::Disabled,
2785            Some(_) => {
2786                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2787                DataCollectionChoice::NotAnswered
2788            }
2789            None => DataCollectionChoice::NotAnswered,
2790        }
2791    }
2792
2793    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2794        self.data_collection_choice = self.data_collection_choice.toggle();
2795        let new_choice = self.data_collection_choice;
2796        let is_enabled = new_choice.is_enabled(cx);
2797        let kvp = KeyValueStore::global(cx);
2798        db::write_and_log(cx, move || async move {
2799            kvp.write_kvp(
2800                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2801                is_enabled.to_string(),
2802            )
2803            .await
2804        });
2805    }
2806
    /// Iterates over the predictions stored in `shown_predictions`, in
    /// storage order. The iterator can also be walked from the back.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2810
    /// Returns the number of entries in `shown_predictions`.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2814
    /// Returns whether the prediction with the given `id` has already been
    /// recorded in `rated_predictions`.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2818
2819    pub fn rate_prediction(
2820        &mut self,
2821        prediction: &EditPrediction,
2822        rating: EditPredictionRating,
2823        feedback: String,
2824        cx: &mut Context<Self>,
2825    ) {
2826        let organization = self.user_store.read(cx).current_organization();
2827
2828        self.rated_predictions.insert(prediction.id.clone());
2829
2830        cx.background_spawn({
2831            let client = self.client.clone();
2832            let prediction_id = prediction.id.to_string();
2833            let inputs = serde_json::to_value(&prediction.inputs);
2834            let output = prediction
2835                .edit_preview
2836                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2837            async move {
2838                client
2839                    .cloud_client()
2840                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2841                        organization_id: organization.map(|organization| organization.id.clone()),
2842                        request_id: prediction_id,
2843                        rating: match rating {
2844                            EditPredictionRating::Positive => "positive".to_string(),
2845                            EditPredictionRating::Negative => "negative".to_string(),
2846                        },
2847                        inputs: inputs?,
2848                        output,
2849                        feedback,
2850                    })
2851                    .await?;
2852
2853                anyhow::Ok(())
2854            }
2855        })
2856        .detach_and_log_err(cx);
2857
2858        cx.notify();
2859    }
2860}
2861
2862fn collaborator_edit_overlaps_locality_region(
2863    project_state: &ProjectState,
2864    project: &Entity<Project>,
2865    buffer: &Entity<Buffer>,
2866    snapshot: &BufferSnapshot,
2867    edit_range: &Range<Anchor>,
2868    cx: &App,
2869) -> bool {
2870    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2871        return false;
2872    };
2873
2874    if active_buffer.entity_id() != buffer.entity_id() {
2875        return false;
2876    }
2877
2878    let locality_point_range = expand_context_syntactically_then_linewise(
2879        snapshot,
2880        (position..position).to_point(snapshot),
2881        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2882    );
2883    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2884
2885    edit_range.overlaps(&locality_anchor_range, snapshot)
2886}
2887
/// Collapses the run of mergeable events at the back of `events` into a
/// single event spanning from the oldest event's snapshot to `end_snapshot`.
///
/// Nothing happens when:
/// - the last event belongs to a different buffer than `latest_snapshot`, or
///   `latest_snapshot` has not observed that event's new version;
/// - fewer than two trailing events can be merged; or
/// - no diff can be computed for the merged range.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        // The trailing event must be for the same buffer as `latest_snapshot`.
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        // `latest_snapshot` must already include the trailing event's edits.
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Walk backwards from the newest event, counting events until one cannot
    // be merged with its newer neighbor.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one mergeable event leaves nothing to combine.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    // `mergeable_count > 1`, so the range is non-empty and `peek` succeeds.
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Grow the oldest event's range to cover every newer mergeable event.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        // Paths and the open-source flag are taken from the oldest event.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event is predicted only if every merged event
                    // was. This consumes the iterator, including the oldest
                    // event that was only peeked above.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2970
2971fn merge_anchor_ranges(
2972    left: &Range<Anchor>,
2973    right: &Range<Anchor>,
2974    snapshot: &TextBufferSnapshot,
2975) -> Range<Anchor> {
2976    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2977        left.start
2978    } else {
2979        right.start
2980    };
2981    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2982        left.end
2983    } else {
2984        right.end
2985    };
2986    start..end
2987}
2988
/// Error raised when the running app version is older than the minimum
/// version required for edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The lowest version that is accepted; shown in the error message.
    minimum_version: Version,
}
2996
/// The user's answer to whether edit prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded; treated as not enabled.
    NotAnswered,
    /// Data collection has been enabled.
    Enabled,
    /// Data collection has been disabled.
    Disabled,
}
3003
3004impl DataCollectionChoice {
3005    pub fn is_enabled(self, cx: &App) -> bool {
3006        if cx.is_staff() {
3007            return true;
3008        }
3009        match self {
3010            Self::Enabled => true,
3011            Self::NotAnswered | Self::Disabled => false,
3012        }
3013    }
3014
3015    #[must_use]
3016    pub fn toggle(&self) -> DataCollectionChoice {
3017        match self {
3018            Self::Enabled => Self::Disabled,
3019            Self::Disabled => Self::Enabled,
3020            Self::NotAnswered => Self::Enabled,
3021        }
3022    }
3023}
3024
3025impl From<bool> for DataCollectionChoice {
3026    fn from(value: bool) -> Self {
3027        match value {
3028            true => DataCollectionChoice::Enabled,
3029            false => DataCollectionChoice::Disabled,
3030        }
3031    }
3032}
3033
3034struct ZedPredictUpsell;
3035
3036impl Dismissable for ZedPredictUpsell {
3037    const KEY: &'static str = "dismissed-edit-predict-upsell";
3038
3039    fn dismissed(cx: &App) -> bool {
3040        // To make this backwards compatible with older versions of Zed, we
3041        // check if the user has seen the previous Edit Prediction Onboarding
3042        // before, by checking the data collection choice which was written to
3043        // the database once the user clicked on "Accept and Enable"
3044        let kvp = KeyValueStore::global(cx);
3045        if kvp
3046            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3047            .log_err()
3048            .is_some_and(|s| s.is_some())
3049        {
3050            return true;
3051        }
3052
3053        kvp.read_kvp(Self::KEY)
3054            .log_err()
3055            .is_some_and(|s| s.is_some())
3056    }
3057}
3058
/// Returns `true` when the edit prediction upsell has not been dismissed
/// according to `ZedPredictUpsell::dismissed`.
pub fn should_show_upsell_modal(cx: &App) -> bool {
    !ZedPredictUpsell::dismissed(cx)
}
3062
3063pub fn init(cx: &mut App) {
3064    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3065        workspace.register_action(
3066            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3067                ZedPredictModal::toggle(
3068                    workspace,
3069                    workspace.user_store().clone(),
3070                    workspace.client().clone(),
3071                    window,
3072                    cx,
3073                )
3074            },
3075        );
3076
3077        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3078            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3079                settings
3080                    .project
3081                    .all_languages
3082                    .edit_predictions
3083                    .get_or_insert_default()
3084                    .provider = Some(EditPredictionProvider::None)
3085            });
3086        });
3087        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3088            EditPredictionStore::try_global(cx).and_then(|store| {
3089                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3090            })
3091        }
3092
3093        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3094            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3095                copilot_ui::initiate_sign_in(copilot, window, cx);
3096            }
3097        });
3098        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3099            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3100                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3101            }
3102        });
3103        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3104            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3105                copilot_ui::initiate_sign_out(copilot, window, cx);
3106            }
3107        });
3108    })
3109    .detach();
3110}