edit_prediction.rs

   1use anyhow::Result;
   2use client::{Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, global_llm_token};
   3use cloud_api_client::LlmApiToken;
   4use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
   5use cloud_llm_client::predict_edits_v3::{
   6    PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
   7};
   8use cloud_llm_client::{
   9    EditPredictionRejectReason, EditPredictionRejection,
  10    MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
  11    PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
  12};
  13use collections::{HashMap, HashSet};
  14use copilot::{Copilot, Reinstall, SignIn, SignOut};
  15use credentials_provider::CredentialsProvider;
  16use db::kvp::{Dismissable, KeyValueStore};
  17use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
  18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  19use futures::{
  20    AsyncReadExt as _, FutureExt as _, StreamExt as _,
  21    channel::mpsc::{self, UnboundedReceiver},
  22    select_biased,
  23};
  24use gpui::BackgroundExecutor;
  25use gpui::http_client::Url;
  26use gpui::{
  27    App, AsyncApp, Entity, EntityId, Global, SharedString, Task, TaskExt, WeakEntity, actions,
  28    http_client::{self, AsyncBody, Method},
  29    prelude::*,
  30};
  31use heapless::Vec as ArrayVec;
  32use language::language_settings::all_language_settings;
  33use language::{Anchor, Buffer, EditPreview, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
  34use language::{BufferSnapshot, OffsetRangeExt};
  35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
  36use release_channel::AppVersion;
  37use semver::Version;
  38use serde::de::DeserializeOwned;
  39use settings::{
  40    EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
  41};
  42use std::collections::{VecDeque, hash_map};
  43use std::env;
  44use text::{AnchorRangeExt, Edit};
  45use workspace::{AppState, Workspace};
  46use zeta_prompt::{ZetaFormat, ZetaPromptInput};
  47
  48use std::mem;
  49use std::ops::Range;
  50use std::path::Path;
  51use std::rc::Rc;
  52use std::str::FromStr as _;
  53use std::sync::Arc;
  54use std::time::{Duration, Instant};
  55
  56use thiserror::Error;
  57use util::{RangeExt as _, ResultExt as _};
  58
  59pub mod cursor_excerpt;
  60pub mod example_spec;
  61pub mod fim;
  62mod license_detection;
  63pub mod mercury;
  64pub mod metrics;
  65pub mod ollama;
  66mod onboarding_modal;
  67pub mod open_ai_response;
  68mod prediction;
  69
  70pub mod udiff;
  71
  72mod capture_example;
  73pub mod open_ai_compatible;
  74mod zed_edit_prediction_delegate;
  75pub mod zeta;
  76
  77#[cfg(test)]
  78mod edit_prediction_tests;
  79
  80use crate::cursor_excerpt::expand_context_syntactically_then_linewise;
  81use crate::example_spec::ExampleSpec;
  82use crate::license_detection::LicenseDetectionWatcher;
  83use crate::mercury::Mercury;
  84pub use crate::metrics::{KeptRateResult, compute_kept_rate};
  85use crate::onboarding_modal::ZedPredictModal;
  86pub use crate::prediction::EditPrediction;
  87pub use crate::prediction::EditPredictionId;
  88use crate::prediction::EditPredictionResult;
  89pub use capture_example::capture_example;
  90pub use language_model::ApiKeyState;
  91pub use telemetry_events::EditPredictionRating;
  92pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
  93
// Declares the user-facing actions in the `edit_prediction` action namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
 103
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 10;
/// Line-distance threshold used by `StoredEvent::can_merge`: events within this
/// many lines of the latest edit are never merged, and events further apart
/// than this from each other are not merged.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Upper bound on the size (in buffer offsets) of the old and new text windows
/// diffed for a history event; larger windows produce no diff at all.
const EDIT_HISTORY_DIFF_SIZE_LIMIT: usize = 2048 * 3; // ~2048 tokens or ~50% of typical prompt budget
// NOTE(review): not used in this section; presumably a token budget for context
// gathered around collaborator edits — confirm at the use site.
const COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS: usize = 512;
// NOTE(review): not used in this section; presumably the time window within which
// consecutive edits are grouped into one event — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Presumably the key-value store key under which the data collection choice is
// persisted — confirm in `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// Presumably the debounce interval for batching rejection reports — confirm in
// `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
// Presumably the telemetry event name reported when a prediction settles.
const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
// NOTE(review): presumably how long a pending prediction is tracked before it is
// dropped, and how long the region must go unedited before it counts as settled
// — confirm in the settled-predictions worker.
const EDIT_PREDICTION_SETTLED_TTL: Duration = Duration::from_secs(60 * 5);
const EDIT_PREDICTION_SETTLED_QUIESCENCE: Duration = Duration::from_secs(10);
 115
/// Feature flag named `edit_prediction_jumps`.
///
/// NOTE(review): presumably gates jump-style predictions (see
/// `BufferEditPrediction::Jump`) — confirm where the flag is checked.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
 121
/// Holds the app-wide [`EditPredictionStore`] entity as a gpui global.
///
/// Installed by `EditPredictionStore::global` on first use and read back by
/// `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
 126
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
///
/// NOTE(review): the line above refers to a "version", but the struct now has a
/// separate `environment` field — confirm which value names the environment.
///
/// When built from the environment (`zeta2_raw_config_from_env`), `format`
/// comes from `ZED_ZETA_FORMAT`, `model_id` from `ZED_ZETA_MODEL`, and
/// `environment` from `ZED_ZETA_ENVIRONMENT`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (`ZED_ZETA_MODEL`).
    pub model_id: Option<String>,
    /// Optional environment name (`ZED_ZETA_ENVIRONMENT`).
    pub environment: Option<String>,
    /// Prompt format to use (`ZED_ZETA_FORMAT`).
    pub format: ZetaFormat,
}
 136
/// Central state for edit predictions: per-project tracking, model selection,
/// experiment selection, and channels feeding the background workers that
/// report rejected and settled predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Waits for a signed-in user, then refreshes available experiments for staff.
    _fetch_experiments_task: Task<()>,
    /// Per-project state; presumably keyed by the project's entity id — confirm.
    projects: HashMap<EntityId, ProjectState>,
    /// Starts as `false` in `new`. NOTE(review): presumably set when the server
    /// reports that a newer client version is required — confirm.
    update_required: bool,
    /// Backend used to produce predictions; `new` defaults it to `Zeta`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration; `new` initializes it from `ZED_ZETA_*`
    /// environment variables.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    /// Explicitly chosen experiment; takes precedence in `active_experiment`.
    preferred_experiment: Option<String>,
    /// Presumably filled by `refresh_available_experiments` — confirm.
    available_experiments: Vec<String>,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sender for the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejectionPayload>,
    /// Sender for the channel drained by `run_settled_predictions_worker`.
    settled_predictions_tx: mpsc::UnboundedSender<Instant>,
    /// Predictions that were shown; `active_experiment` reads their model versions.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Test-only hook; `None` after `new`.
    #[cfg(test)]
    settled_event_callback: Option<Box<dyn Fn(EditPredictionId, String)>>,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 158
/// Message sent over `EditPredictionStore::reject_predictions_tx`: a single
/// rejection plus the organization it should be attributed to, if any.
pub(crate) struct EditPredictionRejectionPayload {
    rejection: EditPredictionRejection,
    organization_id: Option<OrganizationId>,
}
 163
/// Which prediction backend the store uses (`EditPredictionStore::new`
/// defaults to `Zeta`).
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta,
    /// Fill-in-the-middle prompting with the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Mercury,
}
 170
/// Everything a prediction backend receives for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Presumably the cursor position the prediction is requested at — confirm.
    position: Anchor,
    /// Recent edit history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// When set, backends may report progress as [`DebugEvent`]s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    can_collect_data: bool,
    is_open_source: bool,
}
 185
/// Diagnostic events emitted over `debug_tx` channels while context is
/// retrieved and predictions are produced.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Context retrieval began for the project with `project_entity_id`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Context retrieval finished; `metadata` holds labeled values for display.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// A prediction request started. `prompt` is optional — presumably absent when
/// the prompt is not built client-side (confirm with the emitting backend).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// A prediction request finished, with the raw model output if available.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
 221
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event (path, diff, flags).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot before the event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Buffer version after the event's edits.
    pub new_snapshot_version: clock::Global,
    /// Edited region, anchored in the post-edit snapshot.
    pub total_edit_range: Range<Anchor>,
}
 230
 231impl StoredEvent {
 232    fn can_merge(
 233        &self,
 234        next_old_event: &StoredEvent,
 235        latest_snapshot: &TextBufferSnapshot,
 236        latest_edit_range: &Range<Anchor>,
 237    ) -> bool {
 238        // Events must be for the same buffer and be contiguous across included snapshots to be mergeable.
 239        if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
 240            return false;
 241        }
 242        if self.old_snapshot.remote_id() != latest_snapshot.remote_id() {
 243            return false;
 244        }
 245        if self.new_snapshot_version != next_old_event.old_snapshot.version {
 246            return false;
 247        }
 248        if !latest_snapshot
 249            .version
 250            .observed_all(&next_old_event.new_snapshot_version)
 251        {
 252            return false;
 253        }
 254
 255        let a_is_predicted = matches!(
 256            self.event.as_ref(),
 257            zeta_prompt::Event::BufferChange {
 258                predicted: true,
 259                ..
 260            }
 261        );
 262        let b_is_predicted = matches!(
 263            next_old_event.event.as_ref(),
 264            zeta_prompt::Event::BufferChange {
 265                predicted: true,
 266                ..
 267            }
 268        );
 269
 270        // If events come from the same source (both predicted or both manual) then
 271        // we would have coalesced them already.
 272        if a_is_predicted == b_is_predicted {
 273            return false;
 274        }
 275
 276        let left_range = self.total_edit_range.to_point(latest_snapshot);
 277        let right_range = next_old_event.total_edit_range.to_point(latest_snapshot);
 278        let latest_range = latest_edit_range.to_point(latest_snapshot);
 279
 280        // Events near to the latest edit are not merged if their sources differ.
 281        if lines_between_ranges(&left_range, &latest_range)
 282            .min(lines_between_ranges(&right_range, &latest_range))
 283            <= CHANGE_GROUPING_LINE_SPAN
 284        {
 285            return false;
 286        }
 287
 288        // Events that are distant from each other are not merged.
 289        if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
 290            return false;
 291        }
 292
 293        true
 294    }
 295}
 296
 297fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
 298    if left.start > right.end {
 299        return left.start.row - right.end.row;
 300    }
 301    if right.start > left.end {
 302        return right.start.row - left.end.row;
 303    }
 304    0
 305}
 306
/// Per-project edit prediction state held by `EditPredictionStore`.
struct ProjectState {
    /// Finalized edit history events.
    events: VecDeque<StoredEvent>,
    /// In-progress event; `events()` splits it at the last pause and finalizes it.
    last_event: Option<LastEvent>,
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Presumably the next id handed to a `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    /// In-flight predictions; the fixed capacity allows at most two at once.
    pending_predictions: ArrayVec<PendingPrediction, 2, u8>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Presumably the buffer and time of the last refresh of each kind — confirm.
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled via `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers; consulted to mark events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
 324
 325impl ProjectState {
 326    pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
 327        self.events
 328            .iter()
 329            .cloned()
 330            .chain(self.last_event.as_ref().iter().flat_map(|event| {
 331                let (one, two) = event.split_by_pause();
 332                let one = one.finalize(&self.license_detection_watchers, cx);
 333                let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
 334                one.into_iter().chain(two)
 335            }))
 336            .collect()
 337    }
 338
 339    fn cancel_pending_prediction(
 340        &mut self,
 341        pending_prediction: PendingPrediction,
 342        cx: &mut Context<EditPredictionStore>,
 343    ) {
 344        self.cancelled_predictions.insert(pending_prediction.id);
 345
 346        if pending_prediction.drop_on_cancel {
 347            drop(pending_prediction.task);
 348        } else {
 349            cx.spawn(async move |this, cx| {
 350                let Some(prediction_id) = pending_prediction.task.await else {
 351                    return;
 352                };
 353
 354                this.update(cx, |this, cx| {
 355                    this.reject_prediction(
 356                        prediction_id,
 357                        EditPredictionRejectReason::Canceled,
 358                        false,
 359                        None,
 360                        None,
 361                        cx,
 362                    );
 363                })
 364                .ok();
 365            })
 366            .detach()
 367        }
 368    }
 369
 370    fn active_buffer(
 371        &self,
 372        project: &Entity<Project>,
 373        cx: &App,
 374    ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
 375        let project = project.read(cx);
 376        let active_path = project.path_for_entry(project.active_entry()?, cx)?;
 377        let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
 378        let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
 379        Some((active_buffer, registered_buffer.last_position))
 380    }
 381}
 382
/// The prediction currently held for a project, with display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Presumably whether the prediction has been displayed, and how — confirm.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
    /// Presumably the end-to-end request latency — confirm.
    pub e2e_latency: std::time::Duration,
}
 391
 392impl CurrentEditPrediction {
 393    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 394        let Some(new_edits) = self
 395            .prediction
 396            .interpolate(&self.prediction.buffer.read(cx))
 397        else {
 398            return false;
 399        };
 400
 401        if self.prediction.buffer != old_prediction.prediction.buffer {
 402            return true;
 403        }
 404
 405        let Some(old_edits) = old_prediction
 406            .prediction
 407            .interpolate(&old_prediction.prediction.buffer.read(cx))
 408        else {
 409            return true;
 410        };
 411
 412        let requested_by_buffer_id = self.requested_by.buffer_id();
 413
 414        // This reduces the occurrence of UI thrash from replacing edits
 415        //
 416        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 417        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
 418            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
 419            && old_edits.len() == 1
 420            && new_edits.len() == 1
 421        {
 422            let (old_range, old_text) = &old_edits[0];
 423            let (new_range, new_text) = &new_edits[0];
 424            new_range == old_range && new_text.starts_with(old_text.as_ref())
 425        } else {
 426            true
 427        }
 428    }
 429}
 430
 431#[derive(Debug, Clone)]
 432enum PredictionRequestedBy {
 433    DiagnosticsUpdate,
 434    Buffer(EntityId),
 435}
 436
 437impl PredictionRequestedBy {
 438    pub fn buffer_id(&self) -> Option<EntityId> {
 439        match self {
 440            PredictionRequestedBy::DiagnosticsUpdate => None,
 441            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 442        }
 443    }
 444}
 445
// NOTE(review): not used in this section; presumably the number of lines around a
// position searched for diagnostics — confirm at the use site.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostic search.
/// NOTE(review): presumably `Local` means near the current position and `Global`
/// means anywhere in the project — confirm with callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
 453
/// An in-flight prediction request tracked in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
 462
 463/// A prediction from the perspective of a buffer.
 464#[derive(Debug)]
 465enum BufferEditPrediction<'a> {
 466    Local { prediction: &'a EditPrediction },
 467    Jump { prediction: &'a EditPrediction },
 468}
 469
 470#[cfg(test)]
 471impl std::ops::Deref for BufferEditPrediction<'_> {
 472    type Target = EditPrediction;
 473
 474    fn deref(&self) -> &Self::Target {
 475        match self {
 476            BufferEditPrediction::Local { prediction } => prediction,
 477            BufferEditPrediction::Jump { prediction } => prediction,
 478        }
 479    }
 480}
 481
/// A prediction tracked per buffer (in `RegisteredBuffer::pending_predictions`)
/// until it settles.
/// NOTE(review): presumably "settled" means the region stopped changing for
/// `EDIT_PREDICTION_SETTLED_QUIESCENCE` — confirm in the settled worker.
#[derive(Clone)]
struct PendingSettledPrediction {
    request_id: EditPredictionId,
    /// Region the prediction could edit, anchored in the buffer.
    editable_anchor_range: Range<Anchor>,
    /// Region text before the prediction, and the text the model predicted.
    editable_region_before_prediction: String,
    predicted_editable_region: String,
    /// Presumably tree-sitter syntax error counts before/after — confirm.
    ts_error_count_before_prediction: usize,
    ts_error_count_after_prediction: usize,
    example: Option<ExampleSpec>,
    enqueued_at: Instant,
    last_edit_at: Instant,
    e2e_latency: std::time::Duration,
}
 495
/// A buffer tracked by a `ProjectState`, keyed there by buffer entity id.
struct RegisteredBuffer {
    file: Option<Arc<dyn File>>,
    /// Presumably the most recently observed snapshot — confirm at update sites.
    snapshot: TextBufferSnapshot,
    /// Predictions in this buffer still waiting to settle.
    pending_predictions: Vec<PendingSettledPrediction>,
    /// Last known position; returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    _subscriptions: [gpui::Subscription; 2],
}
 503
/// The in-progress edit event for a buffer, not yet finalized into a
/// [`StoredEvent`].
#[derive(Clone)]
struct LastEvent {
    /// Snapshots (and files) at the start and current end of the event.
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    latest_edit_range: Range<Anchor>,
    /// Region covered by all edits in the event; used to compute its diff.
    total_edit_range: Range<Anchor>,
    /// Total edit range as of the last pause; used by `split_by_pause`.
    total_edit_range_at_last_pause_boundary: Option<Range<Anchor>>,
    /// Whether the edits came from an accepted prediction.
    predicted: bool,
    /// Snapshot at the last editing pause; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
 517
 518impl LastEvent {
 519    pub fn finalize(
 520        &self,
 521        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
 522        cx: &App,
 523    ) -> Option<StoredEvent> {
 524        let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
 525        let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
 526
 527        let in_open_source_repo =
 528            [self.new_file.as_ref(), self.old_file.as_ref()]
 529                .iter()
 530                .all(|file| {
 531                    file.is_some_and(|file| {
 532                        license_detection_watchers
 533                            .get(&file.worktree_id(cx))
 534                            .is_some_and(|watcher| watcher.is_project_open_source())
 535                    })
 536                });
 537
 538        let (diff, edit_range) = compute_diff_between_snapshots_in_range(
 539            &self.old_snapshot,
 540            &self.new_snapshot,
 541            &self.total_edit_range,
 542        )?;
 543
 544        if path == old_path && diff.is_empty() {
 545            None
 546        } else {
 547            Some(StoredEvent {
 548                event: Arc::new(zeta_prompt::Event::BufferChange {
 549                    old_path,
 550                    path,
 551                    diff,
 552                    in_open_source_repo,
 553                    predicted: self.predicted,
 554                }),
 555                old_snapshot: self.old_snapshot.clone(),
 556                new_snapshot_version: self.new_snapshot.version.clone(),
 557                total_edit_range: self.new_snapshot.anchor_before(edit_range.start)
 558                    ..self.new_snapshot.anchor_before(edit_range.end),
 559            })
 560        }
 561    }
 562
 563    pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
 564        let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
 565            return (self.clone(), None);
 566        };
 567
 568        let total_edit_range_before_pause = self
 569            .total_edit_range_at_last_pause_boundary
 570            .clone()
 571            .unwrap_or_else(|| self.total_edit_range.clone());
 572
 573        let Some(total_edit_range_after_pause) =
 574            compute_total_edit_range_between_snapshots(boundary_snapshot, &self.new_snapshot)
 575        else {
 576            return (self.clone(), None);
 577        };
 578
 579        let latest_edit_range_before_pause = total_edit_range_before_pause.clone();
 580        let latest_edit_range_after_pause = total_edit_range_after_pause.clone();
 581
 582        let before = LastEvent {
 583            old_snapshot: self.old_snapshot.clone(),
 584            new_snapshot: boundary_snapshot.clone(),
 585            old_file: self.old_file.clone(),
 586            new_file: self.new_file.clone(),
 587            latest_edit_range: latest_edit_range_before_pause,
 588            total_edit_range: total_edit_range_before_pause,
 589            total_edit_range_at_last_pause_boundary: None,
 590            predicted: self.predicted,
 591            snapshot_after_last_editing_pause: None,
 592            last_edit_time: self.last_edit_time,
 593        };
 594
 595        let after = LastEvent {
 596            old_snapshot: boundary_snapshot.clone(),
 597            new_snapshot: self.new_snapshot.clone(),
 598            old_file: self.old_file.clone(),
 599            new_file: self.new_file.clone(),
 600            latest_edit_range: latest_edit_range_after_pause,
 601            total_edit_range: total_edit_range_after_pause,
 602            total_edit_range_at_last_pause_boundary: None,
 603            predicted: self.predicted,
 604            snapshot_after_last_editing_pause: None,
 605            last_edit_time: self.last_edit_time,
 606        };
 607
 608        (before, Some(after))
 609    }
 610}
 611
 612fn compute_total_edit_range_between_snapshots(
 613    old_snapshot: &TextBufferSnapshot,
 614    new_snapshot: &TextBufferSnapshot,
 615) -> Option<Range<Anchor>> {
 616    let edits: Vec<Edit<usize>> = new_snapshot
 617        .edits_since::<usize>(&old_snapshot.version)
 618        .collect();
 619
 620    let (first_edit, last_edit) = edits.first().zip(edits.last())?;
 621    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
 622    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
 623
 624    Some(new_snapshot.anchor_before(new_start_point)..new_snapshot.anchor_before(new_end_point))
 625}
 626
 627fn compute_old_range_for_new_range(
 628    old_snapshot: &TextBufferSnapshot,
 629    new_snapshot: &TextBufferSnapshot,
 630    total_edit_range: &Range<Anchor>,
 631) -> Option<Range<Point>> {
 632    let new_start_offset = total_edit_range.start.to_offset(new_snapshot);
 633    let new_end_offset = total_edit_range.end.to_offset(new_snapshot);
 634
 635    let edits: Vec<Edit<usize>> = new_snapshot
 636        .edits_since::<usize>(&old_snapshot.version)
 637        .collect();
 638    let mut old_start_offset = None;
 639    let mut old_end_offset = None;
 640    let mut delta: isize = 0;
 641
 642    for edit in &edits {
 643        if old_start_offset.is_none() && new_start_offset <= edit.new.end {
 644            old_start_offset = Some(if new_start_offset < edit.new.start {
 645                new_start_offset.checked_add_signed(-delta)?
 646            } else {
 647                edit.old.start
 648            });
 649        }
 650
 651        if old_end_offset.is_none() && new_end_offset <= edit.new.end {
 652            old_end_offset = Some(if new_end_offset < edit.new.start {
 653                new_end_offset.checked_add_signed(-delta)?
 654            } else {
 655                edit.old.end
 656            });
 657        }
 658
 659        delta += edit.new.len() as isize - edit.old.len() as isize;
 660    }
 661
 662    let old_start_offset =
 663        old_start_offset.unwrap_or_else(|| new_start_offset.saturating_add_signed(-delta));
 664    let old_end_offset =
 665        old_end_offset.unwrap_or_else(|| new_end_offset.saturating_add_signed(-delta));
 666
 667    Some(
 668        old_snapshot.offset_to_point(old_start_offset)
 669            ..old_snapshot.offset_to_point(old_end_offset),
 670    )
 671}
 672
/// Produces a unified diff of the region around `total_edit_range` between
/// `old_snapshot` and `new_snapshot`.
///
/// Returns the diff text and the edit's point range in `new_snapshot`, without
/// the surrounding context. Returns `None` in two cases: the old range cannot be
/// recovered (`compute_old_range_for_new_range` fails), or either side's window
/// exceeds `EDIT_HISTORY_DIFF_SIZE_LIMIT` offsets.
fn compute_diff_between_snapshots_in_range(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
    total_edit_range: &Range<Anchor>,
) -> Option<(String, Range<Point>)> {
    let new_start_point = total_edit_range.start.to_point(new_snapshot);
    let new_end_point = total_edit_range.end.to_point(new_snapshot);
    let old_range = compute_old_range_for_new_range(old_snapshot, new_snapshot, total_edit_range)?;
    let old_start_point = old_range.start;
    let old_end_point = old_range.end;

    const CONTEXT_LINES: u32 = 3;

    // Context window: CONTEXT_LINES rows before the start row. After the end row,
    // the `+ 1` here combined with the `+ 1` in the end-offset computation below
    // includes up to CONTEXT_LINES + 1 rows (clamped to the buffer end).
    // NOTE(review): the before/after context is asymmetric — confirm intentional.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Windows run from the start of the first context row to the start of the
    // row after the last context row (or the end of the buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    // Skip oversized regions entirely rather than truncating them.
    if new_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
        || old_edit_range.len() > EDIT_HISTORY_DIFF_SIZE_LIMIT
    {
        return None;
    }

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // The start rows are passed through, presumably so hunk headers reflect
    // buffer line numbers — confirm in `language::unified_diff_with_offsets`.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
 720
 721fn buffer_path_with_id_fallback(
 722    file: Option<&Arc<dyn File>>,
 723    snapshot: &TextBufferSnapshot,
 724    cx: &App,
 725) -> Arc<Path> {
 726    if let Some(file) = file {
 727        file.full_path(cx).into()
 728    } else {
 729        Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
 730    }
 731}
 732
 733impl EditPredictionStore {
 734    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 735        cx.try_global::<EditPredictionStoreGlobal>()
 736            .map(|global| global.0.clone())
 737    }
 738
 739    pub fn global(
 740        client: &Arc<Client>,
 741        user_store: &Entity<UserStore>,
 742        cx: &mut App,
 743    ) -> Entity<Self> {
 744        cx.try_global::<EditPredictionStoreGlobal>()
 745            .map(|global| global.0.clone())
 746            .unwrap_or_else(|| {
 747                let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 748                cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
 749                ep_store
 750            })
 751    }
 752
 753    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 754        let data_collection_choice = Self::load_data_collection_choice(cx);
 755
 756        let llm_token = global_llm_token(cx);
 757
 758        let (reject_tx, reject_rx) = mpsc::unbounded();
 759        cx.background_spawn({
 760            let client = client.clone();
 761            let llm_token = llm_token.clone();
 762            let app_version = AppVersion::global(cx);
 763            let background_executor = cx.background_executor().clone();
 764            async move {
 765                Self::handle_rejected_predictions(
 766                    reject_rx,
 767                    client,
 768                    llm_token,
 769                    app_version,
 770                    background_executor,
 771                )
 772                .await
 773            }
 774        })
 775        .detach();
 776
 777        let (settled_predictions_tx, settled_predictions_rx) = mpsc::unbounded();
 778        cx.spawn(async move |this, cx| {
 779            Self::run_settled_predictions_worker(this, settled_predictions_rx, cx).await;
 780        })
 781        .detach();
 782
 783        let mut current_user = user_store.read(cx).watch_current_user();
 784        let fetch_experiments_task = cx.spawn(async move |this, cx| {
 785            while current_user.borrow().is_none() {
 786                current_user.next().await;
 787            }
 788
 789            this.update(cx, |this, cx| {
 790                if cx.is_staff() {
 791                    this.refresh_available_experiments(cx);
 792                }
 793            })
 794            .log_err();
 795        });
 796
 797        let credentials_provider = zed_credentials_provider::global(cx);
 798
 799        let this = Self {
 800            projects: HashMap::default(),
 801            client,
 802            user_store,
 803            llm_token,
 804            _fetch_experiments_task: fetch_experiments_task,
 805            update_required: false,
 806            edit_prediction_model: EditPredictionModel::Zeta,
 807            zeta2_raw_config: Self::zeta2_raw_config_from_env(),
 808            preferred_experiment: None,
 809            available_experiments: Vec::new(),
 810            mercury: Mercury::new(cx),
 811
 812            data_collection_choice,
 813            reject_predictions_tx: reject_tx,
 814            settled_predictions_tx,
 815            rated_predictions: Default::default(),
 816            shown_predictions: Default::default(),
 817            #[cfg(test)]
 818            settled_event_callback: None,
 819
 820            credentials_provider,
 821        };
 822
 823        this
 824    }
 825
 826    fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
 827        let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
 828        let format = ZetaFormat::parse(&version_str).ok()?;
 829        let model_id = env::var("ZED_ZETA_MODEL").ok();
 830        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
 831        Some(Zeta2RawConfig {
 832            model_id,
 833            environment,
 834            format,
 835        })
 836    }
 837
 838    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
 839        self.edit_prediction_model = model;
 840    }
 841
 842    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
 843        self.zeta2_raw_config = Some(config);
 844    }
 845
 846    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
 847        self.zeta2_raw_config.as_ref()
 848    }
 849
 850    pub fn preferred_experiment(&self) -> Option<&str> {
 851        self.preferred_experiment.as_deref()
 852    }
 853
 854    pub fn set_preferred_experiment(&mut self, experiment: Option<String>) {
 855        self.preferred_experiment = experiment;
 856    }
 857
 858    pub fn available_experiments(&self) -> &[String] {
 859        &self.available_experiments
 860    }
 861
 862    pub fn active_experiment(&self) -> Option<&str> {
 863        self.preferred_experiment.as_deref().or_else(|| {
 864            self.shown_predictions
 865                .iter()
 866                .find_map(|p| p.model_version.as_ref())
 867                .and_then(|model_version| model_version.strip_prefix("zeta2:"))
 868        })
 869    }
 870
 871    pub fn refresh_available_experiments(&mut self, cx: &mut Context<Self>) {
 872        let client = self.client.clone();
 873        let llm_token = self.llm_token.clone();
 874        let app_version = AppVersion::global(cx);
 875        let organization_id = self
 876            .user_store
 877            .read(cx)
 878            .current_organization()
 879            .map(|organization| organization.id.clone());
 880
 881        cx.spawn(async move |this, cx| {
 882            let experiments = cx
 883                .background_spawn(async move {
 884                    let http_client = client.http_client();
 885                    let token = client
 886                        .acquire_llm_token(&llm_token, organization_id.clone())
 887                        .await?;
 888                    let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
 889                    let request = http_client::Request::builder()
 890                        .method(Method::GET)
 891                        .uri(url.as_ref())
 892                        .header("Authorization", format!("Bearer {}", token))
 893                        .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 894                        .body(Default::default())?;
 895                    let mut response = http_client.send(request).await?;
 896                    if response.status().is_success() {
 897                        let mut body = Vec::new();
 898                        response.body_mut().read_to_end(&mut body).await?;
 899                        let experiments: Vec<String> = serde_json::from_slice(&body)?;
 900                        Ok(experiments)
 901                    } else {
 902                        let mut body = String::new();
 903                        response.body_mut().read_to_string(&mut body).await?;
 904                        anyhow::bail!(
 905                            "Failed to fetch experiments: {:?}\nBody: {}",
 906                            response.status(),
 907                            body
 908                        );
 909                    }
 910                })
 911                .await?;
 912            this.update(cx, |this, cx| {
 913                this.available_experiments = experiments;
 914                cx.notify();
 915            })?;
 916            anyhow::Ok(())
 917        })
 918        .detach_and_log_err(cx);
 919    }
 920
 921    pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
 922        use ui::IconName;
 923        match self.edit_prediction_model {
 924            EditPredictionModel::Mercury => {
 925                edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
 926            }
 927            EditPredictionModel::Zeta => {
 928                edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
 929                    .with_disabled(IconName::ZedPredictDisabled)
 930                    .with_up(IconName::ZedPredictUp)
 931                    .with_down(IconName::ZedPredictDown)
 932                    .with_error(IconName::ZedPredictError)
 933            }
 934            EditPredictionModel::Fim { .. } => {
 935                let settings = &all_language_settings(None, cx).edit_predictions;
 936                match settings.provider {
 937                    EditPredictionProvider::Ollama => {
 938                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
 939                    }
 940                    _ => {
 941                        edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
 942                    }
 943                }
 944            }
 945        }
 946    }
 947
 948    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
 949        self.mercury.api_token.read(cx).has_key()
 950    }
 951
 952    pub fn mercury_has_payment_required_error(&self) -> bool {
 953        self.mercury.has_payment_required_error()
 954    }
 955
 956    pub fn clear_history(&mut self) {
 957        for project_state in self.projects.values_mut() {
 958            project_state.events.clear();
 959            project_state.last_event.take();
 960        }
 961    }
 962
 963    pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
 964        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 965            project_state.events.clear();
 966            project_state.last_event.take();
 967        }
 968    }
 969
 970    pub fn edit_history_for_project(
 971        &self,
 972        project: &Entity<Project>,
 973        cx: &App,
 974    ) -> Vec<StoredEvent> {
 975        self.projects
 976            .get(&project.entity_id())
 977            .map(|project_state| project_state.events(cx))
 978            .unwrap_or_default()
 979    }
 980
 981    pub fn context_for_project<'a>(
 982        &'a self,
 983        project: &Entity<Project>,
 984        cx: &'a mut App,
 985    ) -> Vec<RelatedFile> {
 986        self.projects
 987            .get(&project.entity_id())
 988            .map(|project_state| {
 989                project_state.context.update(cx, |context, cx| {
 990                    context
 991                        .related_files_with_buffers(cx)
 992                        .map(|(mut related_file, buffer)| {
 993                            related_file.in_open_source_repo = buffer
 994                                .read(cx)
 995                                .file()
 996                                .map_or(false, |file| self.is_file_open_source(&project, file, cx));
 997                            related_file
 998                        })
 999                        .collect()
1000                })
1001            })
1002            .unwrap_or_default()
1003    }
1004
1005    pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
1006        self.projects
1007            .get(&project.entity_id())
1008            .and_then(|project| project.copilot.clone())
1009    }
1010
1011    pub fn start_copilot_for_project(
1012        &mut self,
1013        project: &Entity<Project>,
1014        cx: &mut Context<Self>,
1015    ) -> Option<Entity<Copilot>> {
1016        if DisableAiSettings::get(None, cx).disable_ai {
1017            return None;
1018        }
1019        let state = self.get_or_init_project(project, cx);
1020
1021        if state.copilot.is_some() {
1022            return state.copilot.clone();
1023        }
1024        let _project = project.clone();
1025        let project = project.read(cx);
1026
1027        let node = project.node_runtime().cloned();
1028        if let Some(node) = node {
1029            let next_id = project.languages().next_language_server_id();
1030            let fs = project.fs().clone();
1031
1032            let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
1033            state.copilot = Some(copilot.clone());
1034            Some(copilot)
1035        } else {
1036            None
1037        }
1038    }
1039
1040    pub fn context_for_project_with_buffers<'a>(
1041        &'a self,
1042        project: &Entity<Project>,
1043        cx: &'a mut App,
1044    ) -> Vec<(RelatedFile, Entity<Buffer>)> {
1045        self.projects
1046            .get(&project.entity_id())
1047            .map(|project| {
1048                project.context.update(cx, |context, cx| {
1049                    context.related_files_with_buffers(cx).collect()
1050                })
1051            })
1052            .unwrap_or_default()
1053    }
1054
1055    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
1056        if matches!(self.edit_prediction_model, EditPredictionModel::Zeta) {
1057            self.user_store.read(cx).edit_prediction_usage()
1058        } else {
1059            None
1060        }
1061    }
1062
1063    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1064        self.get_or_init_project(project, cx);
1065    }
1066
1067    pub fn register_buffer(
1068        &mut self,
1069        buffer: &Entity<Buffer>,
1070        project: &Entity<Project>,
1071        cx: &mut Context<Self>,
1072    ) {
1073        let project_state = self.get_or_init_project(project, cx);
1074        Self::register_buffer_impl(project_state, buffer, project, cx);
1075    }
1076
    /// Returns the per-project state for `project`, creating it on first use.
    ///
    /// When the state is created this:
    /// - builds a `RelatedExcerptStore` for the project and routes its events to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to the project's events via `handle_project_event`;
    /// - removes the state again (and notifies) when the project entity is released.
    ///
    /// For an already-registered project the existing state is returned untouched.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // Only runs for a vacant entry, so the store and subscriptions below are
            // created at most once per registration of a project.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than kept in `_subscriptions`.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_edit_prediction_refresh: None,
                last_jump_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project itself goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
1116
1117    pub fn remove_project(&mut self, project: &Entity<Project>) {
1118        self.projects.remove(&project.entity_id());
1119    }
1120
1121    fn handle_excerpt_store_event(
1122        &mut self,
1123        project_entity_id: EntityId,
1124        event: &RelatedExcerptStoreEvent,
1125    ) {
1126        if let Some(project_state) = self.projects.get(&project_entity_id) {
1127            if let Some(debug_tx) = project_state.debug_tx.clone() {
1128                match event {
1129                    RelatedExcerptStoreEvent::StartedRefresh => {
1130                        debug_tx
1131                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
1132                                ContextRetrievalStartedDebugEvent {
1133                                    project_entity_id: project_entity_id,
1134                                    timestamp: Instant::now(),
1135                                    search_prompt: String::new(),
1136                                },
1137                            ))
1138                            .ok();
1139                    }
1140                    RelatedExcerptStoreEvent::FinishedRefresh {
1141                        cache_hit_count,
1142                        cache_miss_count,
1143                        mean_definition_latency,
1144                        max_definition_latency,
1145                    } => {
1146                        debug_tx
1147                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
1148                                ContextRetrievalFinishedDebugEvent {
1149                                    project_entity_id: project_entity_id,
1150                                    timestamp: Instant::now(),
1151                                    metadata: vec![
1152                                        (
1153                                            "Cache Hits",
1154                                            format!(
1155                                                "{}/{}",
1156                                                cache_hit_count,
1157                                                cache_hit_count + cache_miss_count
1158                                            )
1159                                            .into(),
1160                                        ),
1161                                        (
1162                                            "Max LSP Time",
1163                                            format!("{} ms", max_definition_latency.as_millis())
1164                                                .into(),
1165                                        ),
1166                                        (
1167                                            "Mean LSP Time",
1168                                            format!("{} ms", mean_definition_latency.as_millis())
1169                                                .into(),
1170                                        ),
1171                                    ],
1172                                },
1173                            ))
1174                            .ok();
1175                    }
1176                }
1177            }
1178        }
1179    }
1180
1181    pub fn debug_info(
1182        &mut self,
1183        project: &Entity<Project>,
1184        cx: &mut Context<Self>,
1185    ) -> mpsc::UnboundedReceiver<DebugEvent> {
1186        let project_state = self.get_or_init_project(project, cx);
1187        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1188        project_state.debug_tx = Some(debug_watch_tx);
1189        debug_watch_rx
1190    }
1191
1192    fn handle_project_event(
1193        &mut self,
1194        project: Entity<Project>,
1195        event: &project::Event,
1196        cx: &mut Context<Self>,
1197    ) {
1198        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1199            return;
1200        }
1201        // TODO [zeta2] init with recent paths
1202        match event {
1203            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1204                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1205                    return;
1206                };
1207                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1208                if let Some(path) = path {
1209                    if let Some(ix) = project_state
1210                        .recent_paths
1211                        .iter()
1212                        .position(|probe| probe == &path)
1213                    {
1214                        project_state.recent_paths.remove(ix);
1215                    }
1216                    project_state.recent_paths.push_front(path);
1217                }
1218            }
1219            project::Event::DiagnosticsUpdated { .. } => {
1220                if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1221                    self.refresh_prediction_from_diagnostics(
1222                        project,
1223                        DiagnosticSearchScope::Global,
1224                        cx,
1225                    );
1226                }
1227            }
1228            _ => (),
1229        }
1230    }
1231
    /// Ensures `buffer` is tracked in `project_state` and returns its
    /// `RegisteredBuffer` entry.
    ///
    /// Side effects:
    /// - If the buffer has a file whose worktree belongs to `project`, a
    ///   `LicenseDetectionWatcher` is created for that worktree (once per worktree
    ///   id). It is removed again when the worktree entity is released.
    /// - For a buffer seen for the first time, the current text snapshot and file
    ///   are recorded. `BufferEvent::Edited` events are then fed into
    ///   `report_changes_for_buffer`, and the entry removes itself from
    ///   `registered_buffers` when the buffer entity is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop this worktree's watcher once the worktree itself is gone.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    pending_predictions: Vec::new(),
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle: this subscription does not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited { is_local } = event
                                    && let Some(project) = project.upgrade()
                                {
                                    // Edits observed here are not prediction acceptances,
                                    // hence `is_predicted = false`.
                                    this.report_changes_for_buffer(
                                        &buffer, &project, false, *is_local, cx,
                                    );
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1300
    /// Records the edits made to `buffer` since it was last observed.
    ///
    /// Steps:
    /// 1. Diffs the buffer's current snapshot against the one stored in its
    ///    `RegisteredBuffer`, and returns early if the version is unchanged or no
    ///    edits are found.
    /// 2. Resets `last_edit_at` on every pending prediction whose editable range
    ///    overlaps the edit, which restarts that prediction's quiescence timer.
    /// 3. For non-local (collaborator) edits, continues only if
    ///    `collaborator_edit_overlaps_locality_region` accepts the edit.
    /// 4. Either extends the in-progress `last_event` (coalescing nearby,
    ///    consecutive edits of the same kind) or finalizes it into `events` and
    ///    starts a new one.
    ///
    /// `is_predicted` marks edits that came from accepting a prediction. Predicted
    /// and manual edits are never coalesced into the same event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        is_local: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we last stored.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut edit_range: Option<Range<Anchor>> = None;
        let now = cx.background_executor().now();

        // Collapse all edits since the old version into a single range spanning the
        // first edit's start to the last edit's end. NOTE(review): this assumes the
        // edits are yielded in buffer order — confirm against `anchored_edits_since`.
        for (_edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Touching a pending prediction's editable region postpones its "settled" check.
        for pending_prediction in &mut registered_buffer.pending_predictions {
            if edit_range.overlaps(&pending_prediction.editable_anchor_range, &new_snapshot) {
                pending_prediction.last_edit_at = now;
            }
        }

        // Local edits always count. Collaborator edits count only when the helper
        // accepts them (presumably when they land near the local user's activity —
        // see its definition).
        let include_in_history = is_local
            || collaborator_edit_overlaps_locality_region(
                project_state,
                project,
                buffer,
                &buf.snapshot(),
                &edit_range,
                cx,
            );

        if !include_in_history {
            return;
        }

        let is_recordable_history_edit =
            compute_diff_between_snapshots_in_range(&old_snapshot, &new_snapshot, &edit_range)
                .is_some();

        let events = &mut project_state.events;

        // Edits with no diff are not recorded themselves, but they still close out
        // the in-progress event so it is not coalesced across this edit.
        if !is_recordable_history_edit {
            if let Some(event) = project_state.last_event.take() {
                if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                    // Keeps `events` at no more than EVENT_COUNT_MAX - 1 entries
                    // (presumably leaving room for `last_event`).
                    if events.len() + 1 >= EVENT_COUNT_MAX {
                        events.pop_front();
                    }
                    events.push_back(event);
                }
            }
            return;
        }

        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit directly continues the in-progress event: same buffer, and the
            // old snapshot is exactly where that event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && lines_between_ranges(
                    &edit_range.to_point(&new_snapshot),
                    &last_event.latest_edit_range.to_point(&new_snapshot),
                ) <= CHANGE_GROUPING_LINE_SPAN;

            if should_coalesce {
                // After a long enough gap between edits, remember the state just
                // before this edit as a pause boundary inside the coalesced event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                    last_event.total_edit_range_at_last_pause_boundary =
                        Some(last_event.total_edit_range.clone());
                }

                last_event.latest_edit_range = edit_range.clone();
                last_event.total_edit_range =
                    merge_anchor_ranges(&last_event.total_edit_range, &edit_range, &new_snapshot);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: finalize the previous event (same cap as above) and
        // start a fresh one for this edit.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            latest_edit_range: edit_range.clone(),
            total_edit_range: edit_range,
            total_edit_range_at_last_pause_boundary: None,
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1433
1434    fn prediction_at(
1435        &mut self,
1436        buffer: &Entity<Buffer>,
1437        position: Option<language::Anchor>,
1438        project: &Entity<Project>,
1439        cx: &App,
1440    ) -> Option<BufferEditPrediction<'_>> {
1441        let project_state = self.projects.get_mut(&project.entity_id())?;
1442        if let Some(position) = position
1443            && let Some(buffer) = project_state
1444                .registered_buffers
1445                .get_mut(&buffer.entity_id())
1446        {
1447            buffer.last_position = Some(position);
1448        }
1449
1450        let CurrentEditPrediction {
1451            requested_by,
1452            prediction,
1453            ..
1454        } = project_state.current_prediction.as_ref()?;
1455
1456        if prediction.targets_buffer(buffer.read(cx)) {
1457            Some(BufferEditPrediction::Local { prediction })
1458        } else {
1459            let show_jump = match requested_by {
1460                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1461                    requested_by_buffer_id == &buffer.entity_id()
1462                }
1463                PredictionRequestedBy::DiagnosticsUpdate => true,
1464            };
1465
1466            if show_jump {
1467                Some(BufferEditPrediction::Jump { prediction })
1468            } else {
1469                None
1470            }
1471        }
1472    }
1473
    /// Accepts the current prediction for `project`, if there is one.
    ///
    /// The prediction is taken out of the project state. Its buffer's changes are
    /// then recorded as a predicted, local edit, and every entry in the project's
    /// `pending_predictions` is cancelled. Finally the backend is told about the
    /// acceptance:
    /// - Mercury: always.
    /// - Zeta: only when the provider is not Ollama or an OpenAI-compatible API.
    /// - FIM: never.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(current_prediction) = self
            .projects
            .get_mut(&project.entity_id())
            .and_then(|project_state| project_state.current_prediction.take())
        else {
            return;
        };

        // is_predicted = true, is_local = true
        self.report_changes_for_buffer(
            &current_prediction.prediction.buffer,
            project,
            true,
            true,
            cx,
        );

        // can't hold &mut project_state ref across the report_changes_for_buffer call
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        match self.edit_prediction_model {
            EditPredictionModel::Mercury => {
                mercury::edit_prediction_accepted(
                    current_prediction.prediction.id,
                    self.client.http_client(),
                    cx,
                );
            }
            EditPredictionModel::Zeta => {
                // Acceptance is only reported when Zeta is served by the Zed cloud,
                // not by a local/self-hosted provider.
                let is_cloud = !matches!(
                    all_language_settings(None, cx).edit_predictions.provider,
                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
                );
                if is_cloud {
                    zeta::edit_prediction_accepted(self, current_prediction, cx)
                }
            }
            EditPredictionModel::Fim { .. } => {}
        }
    }
1520
    /// Long-running worker that batches prediction rejections from `rx` and posts
    /// them to `/predict_edits/reject`.
    ///
    /// Batching: while the batch holds fewer than half of
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` items, the worker waits for
    /// either the next rejection (which is pulled immediately) or a
    /// `REJECT_REQUEST_DEBOUNCE` timer, whichever comes first. Each request sends
    /// at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the newest items.
    /// The sent items are removed only if the request succeeds; on failure they
    /// stay in `batched` and are retried alongside later rejections. The worker
    /// exits when `rx` closes.
    ///
    /// NOTE(review): `batched` has no upper bound, so persistent failures keep
    /// every rejection in memory. Each request uses the `organization_id` of the
    /// most recently received payload for the whole batch. URL construction
    /// failure panics via `unwrap`.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejectionPayload>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(EditPredictionRejectionPayload {
            rejection,
            organization_id,
        }) = rx.next().await
        {
            batched.push(rejection);

            // Small batch: wait briefly for more. If another item is already
            // available (peek resolves with `Some`), loop to pull it in. If the
            // stream ends or the debounce elapses, flush what we have.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` items; any older leftovers wait for a later request.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                organization_id,
                app_version.clone(),
                true,
            )
            .await;

            // Only discard what the server actually accepted; failures are logged and retried later.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1584
1585    async fn run_settled_predictions_worker(
1586        this: WeakEntity<Self>,
1587        mut rx: UnboundedReceiver<Instant>,
1588        cx: &mut AsyncApp,
1589    ) {
1590        let mut next_wake_time: Option<Instant> = None;
1591        loop {
1592            let now = cx.background_executor().now();
1593            if let Some(wake_time) = next_wake_time.take() {
1594                cx.background_executor()
1595                    .timer(wake_time.duration_since(now))
1596                    .await;
1597            } else {
1598                let Some(new_enqueue_time) = rx.next().await else {
1599                    break;
1600                };
1601                next_wake_time = Some(new_enqueue_time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1602                while rx.next().now_or_never().flatten().is_some() {}
1603                continue;
1604            }
1605
1606            let Some(this) = this.upgrade() else {
1607                break;
1608            };
1609
1610            let now = cx.background_executor().now();
1611            let mut oldest_edited_at = None;
1612            let mut ready_predictions = Vec::new();
1613
1614            this.update(cx, |this, _| {
1615                for (_, project_state) in this.projects.iter_mut() {
1616                    for (_, registered_buffer) in project_state.registered_buffers.iter_mut() {
1617                        let mut pending_index = 0;
1618                        while pending_index < registered_buffer.pending_predictions.len() {
1619                            let pending_prediction =
1620                                &registered_buffer.pending_predictions[pending_index];
1621                            let age = now.saturating_duration_since(pending_prediction.enqueued_at);
1622                            if age >= EDIT_PREDICTION_SETTLED_TTL {
1623                                registered_buffer.pending_predictions.remove(pending_index);
1624                                continue;
1625                            }
1626
1627                            let quiet_for =
1628                                now.saturating_duration_since(pending_prediction.last_edit_at);
1629                            if quiet_for >= EDIT_PREDICTION_SETTLED_QUIESCENCE {
1630                                let pending_prediction =
1631                                    registered_buffer.pending_predictions.remove(pending_index);
1632                                let settled_editable_region = registered_buffer
1633                                    .snapshot
1634                                    .text_for_range(
1635                                        pending_prediction.editable_anchor_range.clone(),
1636                                    )
1637                                    .collect::<String>();
1638                                ready_predictions
1639                                    .push((pending_prediction, settled_editable_region));
1640                                continue;
1641                            }
1642
1643                            if oldest_edited_at
1644                                .is_none_or(|time| pending_prediction.last_edit_at < time)
1645                            {
1646                                oldest_edited_at = Some(pending_prediction.last_edit_at);
1647                            }
1648                            pending_index += 1;
1649                        }
1650                    }
1651                }
1652            });
1653
1654            for (pending_prediction, settled_editable_region) in ready_predictions {
1655                let PendingSettledPrediction {
1656                    request_id,
1657                    editable_region_before_prediction,
1658                    predicted_editable_region,
1659                    ts_error_count_before_prediction,
1660                    ts_error_count_after_prediction,
1661                    example,
1662                    e2e_latency,
1663                    ..
1664                } = pending_prediction;
1665                let settled_editable_region_for_metrics = settled_editable_region.clone();
1666                let kept_rate_result = cx
1667                    .background_spawn(async move {
1668                        compute_kept_rate(
1669                            &editable_region_before_prediction,
1670                            &predicted_editable_region,
1671                            &settled_editable_region_for_metrics,
1672                        )
1673                    })
1674                    .await;
1675
1676                #[cfg(test)]
1677                {
1678                    let request_id = request_id.clone();
1679                    let settled_editable_region = settled_editable_region.clone();
1680                    this.update(cx, |this, _| {
1681                        if let Some(callback) = &this.settled_event_callback {
1682                            callback(request_id, settled_editable_region);
1683                        }
1684                    });
1685                }
1686
1687                telemetry::event!(
1688                    EDIT_PREDICTION_SETTLED_EVENT,
1689                    request_id = request_id.0.clone(),
1690                    settled_editable_region,
1691                    ts_error_count_before_prediction,
1692                    ts_error_count_after_prediction,
1693                    edit_bytes_candidate_new = kept_rate_result.candidate_new_chars,
1694                    edit_bytes_reference_new = kept_rate_result.reference_new_chars,
1695                    edit_bytes_candidate_deleted = kept_rate_result.candidate_deleted_chars,
1696                    edit_bytes_reference_deleted = kept_rate_result.reference_deleted_chars,
1697                    edit_bytes_kept = kept_rate_result.kept_chars,
1698                    edit_bytes_correctly_deleted = kept_rate_result.correctly_deleted_chars,
1699                    edit_bytes_discarded = kept_rate_result.discarded_chars,
1700                    edit_bytes_context = kept_rate_result.context_chars,
1701                    edit_bytes_kept_rate = kept_rate_result.kept_rate,
1702                    edit_bytes_recall_rate = kept_rate_result.recall_rate,
1703                    example,
1704                    e2e_latency = e2e_latency.as_millis(),
1705                );
1706            }
1707
1708            next_wake_time = oldest_edited_at.map(|time| time + EDIT_PREDICTION_SETTLED_QUIESCENCE);
1709        }
1710    }
1711
1712    pub(crate) fn enqueue_settled_prediction(
1713        &mut self,
1714        request_id: EditPredictionId,
1715        project: &Entity<Project>,
1716        edited_buffer: &Entity<Buffer>,
1717        edited_buffer_snapshot: &BufferSnapshot,
1718        editable_offset_range: Range<usize>,
1719        edit_preview: &EditPreview,
1720        example: Option<ExampleSpec>,
1721        e2e_latency: std::time::Duration,
1722        cx: &mut Context<Self>,
1723    ) {
1724        let this = &mut *self;
1725        let project_state = this.get_or_init_project(project, cx);
1726        let Some(registered_buffer) = project_state
1727            .registered_buffers
1728            .get_mut(&edited_buffer.entity_id())
1729        else {
1730            return;
1731        };
1732
1733        let editable_region_before_prediction = edited_buffer_snapshot
1734            .text_for_range(editable_offset_range.clone())
1735            .collect::<String>();
1736        let editable_anchor_range_for_result =
1737            edited_buffer_snapshot.anchor_range_inside(editable_offset_range.clone());
1738        let predicted_editable_region = edit_preview
1739            .result_text_snapshot()
1740            .text_for_range(editable_anchor_range_for_result.clone())
1741            .collect();
1742        let ts_error_count_before_prediction = crate::metrics::count_tree_sitter_errors(
1743            edited_buffer_snapshot
1744                .syntax_layers_for_range(editable_anchor_range_for_result.clone(), true),
1745        );
1746        let ts_error_count_after_prediction = crate::metrics::count_tree_sitter_errors(
1747            edit_preview.result_syntax_snapshot().layers_for_range(
1748                editable_anchor_range_for_result,
1749                edit_preview.result_text_snapshot(),
1750                true,
1751            ),
1752        );
1753        let editable_anchor_range =
1754            edited_buffer_snapshot.anchor_range_inside(editable_offset_range);
1755        let now = cx.background_executor().now();
1756        registered_buffer
1757            .pending_predictions
1758            .push(PendingSettledPrediction {
1759                request_id,
1760                editable_anchor_range,
1761                editable_region_before_prediction,
1762                predicted_editable_region,
1763                ts_error_count_before_prediction,
1764                ts_error_count_after_prediction,
1765                example,
1766                e2e_latency,
1767                enqueued_at: now,
1768                last_edit_at: now,
1769            });
1770        this.settled_predictions_tx.unbounded_send(now).ok();
1771    }
1772
1773    fn reject_current_prediction(
1774        &mut self,
1775        reason: EditPredictionRejectReason,
1776        project: &Entity<Project>,
1777        cx: &App,
1778    ) {
1779        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1780            project_state.pending_predictions.clear();
1781            if let Some(prediction) = project_state.current_prediction.take() {
1782                let model_version = prediction.prediction.model_version.clone();
1783                self.reject_prediction(
1784                    prediction.prediction.id,
1785                    reason,
1786                    prediction.was_shown,
1787                    model_version,
1788                    Some(prediction.e2e_latency),
1789                    cx,
1790                );
1791            }
1792        };
1793    }
1794
1795    fn did_show_current_prediction(
1796        &mut self,
1797        project: &Entity<Project>,
1798        display_type: edit_prediction_types::SuggestionDisplayType,
1799        _cx: &mut Context<Self>,
1800    ) {
1801        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1802            return;
1803        };
1804
1805        let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1806            return;
1807        };
1808
1809        let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1810        let previous_shown_with = current_prediction.shown_with;
1811
1812        if previous_shown_with.is_none() || !is_jump {
1813            current_prediction.shown_with = Some(display_type);
1814        }
1815
1816        let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1817
1818        if is_first_non_jump_show {
1819            current_prediction.was_shown = true;
1820        }
1821
1822        if is_first_non_jump_show {
1823            self.shown_predictions
1824                .push_front(current_prediction.prediction.clone());
1825            if self.shown_predictions.len() > 50 {
1826                let completion = self.shown_predictions.pop_back().unwrap();
1827                self.rated_predictions.remove(&completion.id);
1828            }
1829        }
1830    }
1831
1832    fn reject_prediction(
1833        &mut self,
1834        prediction_id: EditPredictionId,
1835        reason: EditPredictionRejectReason,
1836        was_shown: bool,
1837        model_version: Option<String>,
1838        e2e_latency: Option<std::time::Duration>,
1839        cx: &App,
1840    ) {
1841        match self.edit_prediction_model {
1842            EditPredictionModel::Zeta => {
1843                let is_cloud = !matches!(
1844                    all_language_settings(None, cx).edit_predictions.provider,
1845                    EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1846                );
1847
1848                if is_cloud {
1849                    let organization_id = self
1850                        .user_store
1851                        .read(cx)
1852                        .current_organization()
1853                        .map(|organization| organization.id.clone());
1854
1855                    self.reject_predictions_tx
1856                        .unbounded_send(EditPredictionRejectionPayload {
1857                            rejection: EditPredictionRejection {
1858                                request_id: prediction_id.to_string(),
1859                                reason,
1860                                was_shown,
1861                                model_version,
1862                                e2e_latency_ms: e2e_latency.map(|latency| latency.as_millis()),
1863                            },
1864                            organization_id,
1865                        })
1866                        .log_err();
1867                }
1868            }
1869            EditPredictionModel::Mercury => {
1870                mercury::edit_prediction_rejected(
1871                    prediction_id,
1872                    was_shown,
1873                    reason,
1874                    self.client.http_client(),
1875                    cx,
1876                );
1877            }
1878            EditPredictionModel::Fim { .. } => {}
1879        }
1880    }
1881
1882    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1883        self.projects
1884            .get(&project.entity_id())
1885            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1886    }
1887
1888    pub fn refresh_prediction_from_buffer(
1889        &mut self,
1890        project: Entity<Project>,
1891        buffer: Entity<Buffer>,
1892        position: language::Anchor,
1893        cx: &mut Context<Self>,
1894    ) {
1895        self.queue_prediction_refresh(
1896            project.clone(),
1897            PredictEditsRequestTrigger::Other,
1898            buffer.entity_id(),
1899            cx,
1900            move |this, cx| {
1901                let Some(request_task) = this
1902                    .update(cx, |this, cx| {
1903                        this.request_prediction(
1904                            &project,
1905                            &buffer,
1906                            position,
1907                            PredictEditsRequestTrigger::Other,
1908                            cx,
1909                        )
1910                    })
1911                    .log_err()
1912                else {
1913                    return Task::ready(anyhow::Ok(None));
1914                };
1915
1916                cx.spawn(async move |_cx| {
1917                    request_task.await.map(|prediction_result| {
1918                        prediction_result.map(|prediction_result| {
1919                            (
1920                                prediction_result,
1921                                PredictionRequestedBy::Buffer(buffer.entity_id()),
1922                            )
1923                        })
1924                    })
1925                })
1926            },
1927        )
1928    }
1929
    /// Queues a prediction refresh that jumps to a diagnostic location.
    ///
    /// Returns without queuing when any of these hold:
    /// - the configured provider is not served by this store;
    /// - the project's workspace is currently following a leader;
    /// - the project has no state yet;
    /// - a current prediction already exists.
    ///
    /// The refresh is throttled per project. It looks up the active buffer and cursor,
    /// finds the next diagnostic location via `next_diagnostic_location`, and requests a
    /// prediction there. If a current prediction appeared while the request was in
    /// flight, the result is turned into a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        scope: DiagnosticSearchScope,
        cx: &mut Context<Self>,
    ) {
        if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
            return;
        }

        // Don't generate predictions while following someone else's cursor.
        if currently_following(&project, cx) {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            log::debug!(
                "edit_prediction: diagnostic refresh skipped, current prediction already exists"
            );
            return;
        }

        self.queue_prediction_refresh(
            project.clone(),
            PredictEditsRequestTrigger::Diagnostics,
            // Diagnostic refreshes are throttled per project rather than per buffer.
            project.entity_id(),
            cx,
            move |this, cx| {
                // Resolve the active buffer, a snapshot of it, and the cursor point;
                // bail out when there is none or predictions are disabled there.
                let Some((active_buffer, snapshot, cursor_point)) = this
                    .read_with(cx, |this, cx| {
                        let project_state = this.projects.get(&project.entity_id())?;
                        let (buffer, position) = project_state.active_buffer(&project, cx)?;
                        let snapshot = buffer.read(cx).snapshot();

                        if !Self::predictions_enabled_at(&snapshot, position, cx) {
                            return None;
                        }

                        // Without a known cursor position, fall back to the start of the buffer.
                        let cursor_point = position
                            .map(|pos| pos.to_point(&snapshot))
                            .unwrap_or_default();

                        Some((buffer, snapshot, cursor_point))
                    })
                    .log_err()
                    .flatten()
                else {
                    return Task::ready(anyhow::Ok(None));
                };

                cx.spawn(async move |cx| {
                    // Local scope searches DIAGNOSTIC_LINES_RANGE rows around the cursor.
                    // Global scope passes an empty default range. NOTE(review): presumably
                    // `next_diagnostic_location` treats that as "search everywhere" — confirm.
                    let diagnostic_search_range = match scope {
                        DiagnosticSearchScope::Local => {
                            let diagnostic_search_start =
                                cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
                            let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
                            Point::new(diagnostic_search_start, 0)
                                ..Point::new(diagnostic_search_end, 0)
                        }
                        DiagnosticSearchScope::Global => Default::default(),
                    };

                    let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    let Some(prediction_result) = this
                        .update(cx, |this, cx| {
                            this.request_prediction(
                                &project,
                                &jump_buffer,
                                jump_position,
                                PredictEditsRequestTrigger::Diagnostics,
                                cx,
                            )
                        })?
                        .await?
                    else {
                        return anyhow::Ok(None);
                    };

                    // A buffer-originated prediction may have landed while this request
                    // was in flight; in that case keep it and mark this one as rejected.
                    this.update(cx, |this, cx| {
                        Some((
                            if this
                                .get_or_init_project(&project, cx)
                                .current_prediction
                                .is_none()
                            {
                                prediction_result
                            } else {
                                EditPredictionResult {
                                    id: prediction_result.id,
                                    prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                                    e2e_latency: prediction_result.e2e_latency,
                                }
                            },
                            PredictionRequestedBy::DiagnosticsUpdate,
                        ))
                    })
                })
            },
        );
    }
2046
2047    fn predictions_enabled_at(
2048        snapshot: &BufferSnapshot,
2049        position: Option<language::Anchor>,
2050        cx: &App,
2051    ) -> bool {
2052        let file = snapshot.file();
2053        let all_settings = all_language_settings(file, cx);
2054        if !all_settings.show_edit_predictions(snapshot.language(), cx)
2055            || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
2056        {
2057            return false;
2058        }
2059
2060        if let Some(last_position) = position {
2061            let settings = snapshot.settings_at(last_position, cx);
2062
2063            if !settings.edit_predictions_disabled_in.is_empty()
2064                && let Some(scope) = snapshot.language_scope_at(last_position)
2065                && let Some(scope_name) = scope.override_name()
2066                && settings
2067                    .edit_predictions_disabled_in
2068                    .iter()
2069                    .any(|s| s == scope_name)
2070            {
2071                return false;
2072            }
2073        }
2074
2075        true
2076    }
2077
    /// Minimum spacing between two prediction refreshes of the same trigger kind for the
    /// same throttle entity (a buffer or a project); see `queue_prediction_refresh`.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
2079}
2080
2081fn currently_following(project: &Entity<Project>, cx: &App) -> bool {
2082    let Some(app_state) = AppState::try_global(cx) else {
2083        return false;
2084    };
2085
2086    app_state
2087        .workspace_store
2088        .read(cx)
2089        .workspaces()
2090        .filter_map(|workspace| workspace.upgrade())
2091        .any(|workspace| {
2092            workspace.read(cx).project().entity_id() == project.entity_id()
2093                && workspace
2094                    .read(cx)
2095                    .leader_for_pane(workspace.read(cx).active_pane())
2096                    .is_some()
2097        })
2098}
2099
2100fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
2101    match provider {
2102        EditPredictionProvider::Zed
2103        | EditPredictionProvider::Mercury
2104        | EditPredictionProvider::Ollama
2105        | EditPredictionProvider::OpenAiCompatibleApi
2106        | EditPredictionProvider::Experimental(_) => true,
2107        EditPredictionProvider::None
2108        | EditPredictionProvider::Copilot
2109        | EditPredictionProvider::Codestral => false,
2110    }
2111}
2112
2113impl EditPredictionStore {
2114    fn queue_prediction_refresh(
2115        &mut self,
2116        project: Entity<Project>,
2117        request_trigger: PredictEditsRequestTrigger,
2118        throttle_entity: EntityId,
2119        cx: &mut Context<Self>,
2120        do_refresh: impl FnOnce(
2121            WeakEntity<Self>,
2122            &mut AsyncApp,
2123        )
2124            -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
2125        + 'static,
2126    ) {
2127        fn select_throttle(
2128            project_state: &mut ProjectState,
2129            request_trigger: PredictEditsRequestTrigger,
2130        ) -> &mut Option<(EntityId, Instant)> {
2131            match request_trigger {
2132                PredictEditsRequestTrigger::Diagnostics => {
2133                    &mut project_state.last_jump_prediction_refresh
2134                }
2135                _ => &mut project_state.last_edit_prediction_refresh,
2136            }
2137        }
2138
2139        let (needs_acceptance_tracking, max_pending_predictions) =
2140            match all_language_settings(None, cx).edit_predictions.provider {
2141                EditPredictionProvider::Zed
2142                | EditPredictionProvider::Mercury
2143                | EditPredictionProvider::Experimental(_) => (true, 2),
2144                EditPredictionProvider::Ollama => (false, 1),
2145                EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
2146                EditPredictionProvider::None
2147                | EditPredictionProvider::Copilot
2148                | EditPredictionProvider::Codestral => {
2149                    log::error!("queue_prediction_refresh called with non-store provider");
2150                    return;
2151                }
2152            };
2153
2154        let drop_on_cancel = !needs_acceptance_tracking;
2155        let throttle_timeout = Self::THROTTLE_TIMEOUT;
2156        let project_state = self.get_or_init_project(&project, cx);
2157        let pending_prediction_id = project_state.next_pending_prediction_id;
2158        project_state.next_pending_prediction_id += 1;
2159        let throttle_at_enqueue = *select_throttle(project_state, request_trigger);
2160
2161        let task = cx.spawn(async move |this, cx| {
2162            let throttle_wait = this
2163                .update(cx, |this, cx| {
2164                    let project_state = this.get_or_init_project(&project, cx);
2165                    let throttle = *select_throttle(project_state, request_trigger);
2166
2167                    let now = cx.background_executor().now();
2168                    throttle.and_then(|(last_entity, last_timestamp)| {
2169                        if throttle_entity != last_entity {
2170                            return None;
2171                        }
2172                        (last_timestamp + throttle_timeout).checked_duration_since(now)
2173                    })
2174                })
2175                .ok()
2176                .flatten();
2177
2178            if let Some(timeout) = throttle_wait {
2179                cx.background_executor().timer(timeout).await;
2180            }
2181
2182            // If this task was cancelled before the throttle timeout expired,
2183            // do not perform a request. Also skip if another task already
2184            // proceeded since we were enqueued (duplicate).
2185            let mut is_cancelled = true;
2186            this.update(cx, |this, cx| {
2187                let project_state = this.get_or_init_project(&project, cx);
2188                let was_cancelled = project_state
2189                    .cancelled_predictions
2190                    .remove(&pending_prediction_id);
2191                if was_cancelled {
2192                    return;
2193                }
2194
2195                // Another request has been already sent since this was enqueued
2196                if *select_throttle(project_state, request_trigger) != throttle_at_enqueue {
2197                    return;
2198                }
2199
2200                let new_refresh = (throttle_entity, cx.background_executor().now());
2201                *select_throttle(project_state, request_trigger) = Some(new_refresh);
2202                is_cancelled = false;
2203            })
2204            .ok();
2205            if is_cancelled {
2206                return None;
2207            }
2208
2209            let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
2210            let new_prediction_id = new_prediction_result
2211                .as_ref()
2212                .map(|(prediction, _)| prediction.id.clone());
2213
2214            // When a prediction completes, remove it from the pending list, and cancel
2215            // any pending predictions that were enqueued before it.
2216            this.update(cx, |this, cx| {
2217                let project_state = this.get_or_init_project(&project, cx);
2218
2219                let is_cancelled = project_state
2220                    .cancelled_predictions
2221                    .remove(&pending_prediction_id);
2222
2223                let new_current_prediction = if !is_cancelled
2224                    && let Some((prediction_result, requested_by)) = new_prediction_result
2225                {
2226                    match prediction_result.prediction {
2227                        Ok(prediction) => {
2228                            let new_prediction = CurrentEditPrediction {
2229                                requested_by,
2230                                prediction,
2231                                was_shown: false,
2232                                shown_with: None,
2233                                e2e_latency: prediction_result.e2e_latency,
2234                            };
2235
2236                            if let Some(current_prediction) =
2237                                project_state.current_prediction.as_ref()
2238                            {
2239                                if new_prediction.should_replace_prediction(&current_prediction, cx)
2240                                {
2241                                    this.reject_current_prediction(
2242                                        EditPredictionRejectReason::Replaced,
2243                                        &project,
2244                                        cx,
2245                                    );
2246
2247                                    Some(new_prediction)
2248                                } else {
2249                                    this.reject_prediction(
2250                                        new_prediction.prediction.id,
2251                                        EditPredictionRejectReason::CurrentPreferred,
2252                                        false,
2253                                        new_prediction.prediction.model_version,
2254                                        Some(new_prediction.e2e_latency),
2255                                        cx,
2256                                    );
2257                                    None
2258                                }
2259                            } else {
2260                                Some(new_prediction)
2261                            }
2262                        }
2263                        Err(reject_reason) => {
2264                            this.reject_prediction(
2265                                prediction_result.id,
2266                                reject_reason,
2267                                false,
2268                                None,
2269                                Some(prediction_result.e2e_latency),
2270                                cx,
2271                            );
2272                            None
2273                        }
2274                    }
2275                } else {
2276                    None
2277                };
2278
2279                let project_state = this.get_or_init_project(&project, cx);
2280
2281                if let Some(new_prediction) = new_current_prediction {
2282                    project_state.current_prediction = Some(new_prediction);
2283                }
2284
2285                let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
2286                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
2287                    if pending_prediction.id == pending_prediction_id {
2288                        pending_predictions.remove(ix);
2289                        for pending_prediction in pending_predictions.drain(0..ix) {
2290                            project_state.cancel_pending_prediction(pending_prediction, cx)
2291                        }
2292                        break;
2293                    }
2294                }
2295                this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
2296                cx.notify();
2297            })
2298            .ok();
2299
2300            new_prediction_id
2301        });
2302
2303        if project_state.pending_predictions.len() < max_pending_predictions {
2304            project_state
2305                .pending_predictions
2306                .push(PendingPrediction {
2307                    id: pending_prediction_id,
2308                    task,
2309                    drop_on_cancel,
2310                })
2311                .unwrap();
2312        } else {
2313            let pending_prediction = project_state.pending_predictions.pop().unwrap();
2314            project_state
2315                .pending_predictions
2316                .push(PendingPrediction {
2317                    id: pending_prediction_id,
2318                    task,
2319                    drop_on_cancel,
2320                })
2321                .unwrap();
2322            project_state.cancel_pending_prediction(pending_prediction, cx);
2323        }
2324    }
2325
2326    pub fn request_prediction(
2327        &mut self,
2328        project: &Entity<Project>,
2329        active_buffer: &Entity<Buffer>,
2330        position: language::Anchor,
2331        trigger: PredictEditsRequestTrigger,
2332        cx: &mut Context<Self>,
2333    ) -> Task<Result<Option<EditPredictionResult>>> {
2334        self.request_prediction_internal(
2335            project.clone(),
2336            active_buffer.clone(),
2337            position,
2338            trigger,
2339            cx.has_flag::<EditPredictionJumpsFeatureFlag>(),
2340            cx,
2341        )
2342    }
2343
2344    fn request_prediction_internal(
2345        &mut self,
2346        project: Entity<Project>,
2347        active_buffer: Entity<Buffer>,
2348        position: language::Anchor,
2349        trigger: PredictEditsRequestTrigger,
2350        allow_jump: bool,
2351        cx: &mut Context<Self>,
2352    ) -> Task<Result<Option<EditPredictionResult>>> {
2353        self.get_or_init_project(&project, cx);
2354        let project_state = self.projects.get(&project.entity_id()).unwrap();
2355        let stored_events = project_state.events(cx);
2356        let has_events = !stored_events.is_empty();
2357        let events: Vec<Arc<zeta_prompt::Event>> =
2358            stored_events.iter().map(|e| e.event.clone()).collect();
2359        let debug_tx = project_state.debug_tx.clone();
2360
2361        let snapshot = active_buffer.read(cx).snapshot();
2362        let cursor_point = position.to_point(&snapshot);
2363        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
2364        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
2365        let diagnostic_search_range =
2366            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
2367
2368        let related_files = self.context_for_project(&project, cx);
2369
2370        let is_open_source = snapshot
2371            .file()
2372            .map_or(false, |file| self.is_file_open_source(&project, file, cx))
2373            && events.iter().all(|event| event.in_open_source_repo())
2374            && related_files.iter().all(|file| file.in_open_source_repo);
2375
2376        let can_collect_data = !cfg!(test)
2377            && is_open_source
2378            && self.is_data_collection_enabled(cx)
2379            && matches!(self.edit_prediction_model, EditPredictionModel::Zeta);
2380
2381        let inputs = EditPredictionModelInput {
2382            project: project.clone(),
2383            buffer: active_buffer,
2384            snapshot,
2385            position,
2386            events,
2387            related_files,
2388            trigger,
2389            diagnostic_search_range: diagnostic_search_range,
2390            debug_tx,
2391            can_collect_data,
2392            is_open_source,
2393        };
2394
2395        let capture_data = (can_collect_data && rand::random_ratio(1, 1000)).then(|| stored_events);
2396
2397        let task = match self.edit_prediction_model {
2398            EditPredictionModel::Zeta => {
2399                zeta::request_prediction_with_zeta(self, inputs, capture_data, cx)
2400            }
2401            EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
2402            EditPredictionModel::Mercury => {
2403                self.mercury
2404                    .request_prediction(inputs, self.credentials_provider.clone(), cx)
2405            }
2406        };
2407
2408        cx.spawn(async move |this, cx| {
2409            let prediction = task.await?;
2410
2411            // Only fall back to diagnostics-based prediction if we got a
2412            // the model had nothing to suggest for the buffer
2413            if prediction.is_none()
2414                && allow_jump
2415                && has_events
2416                && !matches!(trigger, PredictEditsRequestTrigger::Diagnostics)
2417            {
2418                this.update(cx, |this, cx| {
2419                    this.refresh_prediction_from_diagnostics(
2420                        project,
2421                        DiagnosticSearchScope::Local,
2422                        cx,
2423                    );
2424                })?;
2425                return anyhow::Ok(None);
2426            }
2427
2428            Ok(prediction)
2429        })
2430    }
2431
    /// Finds a location to jump to for a diagnostics-based edit prediction.
    ///
    /// Phase 1 searches `active_buffer`. It picks the primary entry of the diagnostic group
    /// whose start row is closest to `active_buffer_cursor_point`, and skips:
    /// - groups that overlap `active_buffer_diagnostic_search_range`;
    /// - groups that start within `DIAGNOSTIC_LINES_RANGE` rows of a selection head returned by
    ///   `selections_in_range` (presumably collaborators' cursors), unless they are also within
    ///   that many rows of the local cursor.
    ///
    /// Phase 2 runs only if phase 1 finds nothing. It considers every other project path that
    /// has a diagnostic summary, ordered by how many leading path components it shares with the
    /// active buffer's path (most first). Each candidate is opened in turn. Buffers reporting
    /// any selections from `selections_in_range` are skipped. The first remaining buffer with a
    /// diagnostic yields the start of the entry with the smallest `severity` value (presumably
    /// the most severe; confirm the ordering of the severity type).
    ///
    /// Returns `Ok(None)` when no location is found. An error opening any candidate buffer is
    /// propagated and ends the search.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of every selection head reported for the whole buffer. The `false` argument
        // presumably excludes the local replica's selections (TODO confirm).
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(
                Anchor::min_max_range_for_buffer(active_buffer_snapshot.remote_id()),
                false,
            )
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        // Phase 1: the nearest qualifying diagnostic group in the active buffer.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the cursor's own search window are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Skip diagnostics next to someone else's cursor unless ours is close too.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Phase 2: fall back to other files in the project that have diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Pair each other path with the number of leading path components it shares with
            // the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Closest paths (longest shared prefix) first; the sort is stable for ties.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(
                                Anchor::min_max_range_for_buffer(snapshot.remote_id()),
                                false,
                            )
                            .next()
                            .is_some();
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                // Don't steer the user into a buffer where selections from others are present.
                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2547
2548    async fn send_raw_llm_request(
2549        request: RawCompletionRequest,
2550        client: Arc<Client>,
2551        custom_url: Option<Arc<Url>>,
2552        llm_token: LlmApiToken,
2553        organization_id: Option<OrganizationId>,
2554        app_version: Version,
2555    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2556        let url = if let Some(custom_url) = custom_url {
2557            custom_url.as_ref().clone()
2558        } else {
2559            client
2560                .http_client()
2561                .build_zed_llm_url("/predict_edits/raw", &[])?
2562        };
2563
2564        Self::send_api_request(
2565            |builder| {
2566                let req = builder
2567                    .uri(url.as_ref())
2568                    .body(serde_json::to_string(&request)?.into());
2569                Ok(req?)
2570            },
2571            client,
2572            llm_token,
2573            organization_id,
2574            app_version,
2575            true,
2576        )
2577        .await
2578    }
2579
2580    pub(crate) async fn send_v3_request(
2581        input: ZetaPromptInput,
2582        client: Arc<Client>,
2583        llm_token: LlmApiToken,
2584        organization_id: Option<OrganizationId>,
2585        app_version: Version,
2586        trigger: PredictEditsRequestTrigger,
2587    ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2588        let url = client
2589            .http_client()
2590            .build_zed_llm_url("/predict_edits/v3", &[])?;
2591
2592        let request = PredictEditsV3Request { input, trigger };
2593
2594        let json_bytes = serde_json::to_vec(&request)?;
2595        let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2596
2597        Self::send_api_request(
2598            |builder| {
2599                let req = builder
2600                    .uri(url.as_ref())
2601                    .header("Content-Encoding", "zstd")
2602                    .body(compressed.clone().into());
2603                Ok(req?)
2604            },
2605            client,
2606            llm_token,
2607            organization_id,
2608            app_version,
2609            true,
2610        )
2611        .await
2612    }
2613
2614    async fn send_api_request<Res>(
2615        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2616        client: Arc<Client>,
2617        llm_token: LlmApiToken,
2618        organization_id: Option<OrganizationId>,
2619        app_version: Version,
2620        require_auth: bool,
2621    ) -> Result<(Res, Option<EditPredictionUsage>)>
2622    where
2623        Res: DeserializeOwned,
2624    {
2625        let http_client = client.http_client();
2626        let mut token = if require_auth {
2627            Some(
2628                client
2629                    .acquire_llm_token(&llm_token, organization_id.clone())
2630                    .await?,
2631            )
2632        } else {
2633            client
2634                .acquire_llm_token(&llm_token, organization_id.clone())
2635                .await
2636                .ok()
2637        };
2638        let mut did_retry = false;
2639
2640        loop {
2641            let request_builder = http_client::Request::builder().method(Method::POST);
2642
2643            let mut request_builder = request_builder
2644                .header("Content-Type", "application/json")
2645                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2646
2647            // Only add Authorization header if we have a token
2648            if let Some(ref token_value) = token {
2649                request_builder =
2650                    request_builder.header("Authorization", format!("Bearer {}", token_value));
2651            }
2652
2653            let request = build(request_builder)?;
2654
2655            let mut response = http_client.send(request).await?;
2656
2657            if let Some(minimum_required_version) = response
2658                .headers()
2659                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2660                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2661            {
2662                anyhow::ensure!(
2663                    app_version >= minimum_required_version,
2664                    ZedUpdateRequiredError {
2665                        minimum_version: minimum_required_version
2666                    }
2667                );
2668            }
2669
2670            if response.status().is_success() {
2671                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2672
2673                let mut body = Vec::new();
2674                response.body_mut().read_to_end(&mut body).await?;
2675                return Ok((serde_json::from_slice(&body)?, usage));
2676            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2677                did_retry = true;
2678                token = Some(
2679                    client
2680                        .refresh_llm_token(&llm_token, organization_id.clone())
2681                        .await?,
2682                );
2683            } else {
2684                let mut body = String::new();
2685                response.body_mut().read_to_string(&mut body).await?;
2686                anyhow::bail!(
2687                    "Request failed with status: {:?}\nBody: {}",
2688                    response.status(),
2689                    body
2690                );
2691            }
2692        }
2693    }
2694
2695    pub fn refresh_context(
2696        &mut self,
2697        project: &Entity<Project>,
2698        buffer: &Entity<language::Buffer>,
2699        cursor_position: language::Anchor,
2700        cx: &mut Context<Self>,
2701    ) {
2702        self.get_or_init_project(project, cx)
2703            .context
2704            .update(cx, |store, cx| {
2705                store.refresh(buffer.clone(), cursor_position, cx);
2706            });
2707    }
2708
2709    #[cfg(feature = "cli-support")]
2710    pub fn set_context_for_buffer(
2711        &mut self,
2712        project: &Entity<Project>,
2713        related_files: Vec<RelatedFile>,
2714        cx: &mut Context<Self>,
2715    ) {
2716        self.get_or_init_project(project, cx)
2717            .context
2718            .update(cx, |store, cx| {
2719                store.set_related_files(related_files, cx);
2720            });
2721    }
2722
2723    #[cfg(feature = "cli-support")]
2724    pub fn set_recent_paths_for_project(
2725        &mut self,
2726        project: &Entity<Project>,
2727        paths: impl IntoIterator<Item = project::ProjectPath>,
2728        cx: &mut Context<Self>,
2729    ) {
2730        let project_state = self.get_or_init_project(project, cx);
2731        project_state.recent_paths = paths.into_iter().collect();
2732    }
2733
2734    fn is_file_open_source(
2735        &self,
2736        project: &Entity<Project>,
2737        file: &Arc<dyn File>,
2738        cx: &App,
2739    ) -> bool {
2740        if !file.is_local() || file.is_private() {
2741            return false;
2742        }
2743        let Some(project_state) = self.projects.get(&project.entity_id()) else {
2744            return false;
2745        };
2746        project_state
2747            .license_detection_watchers
2748            .get(&file.worktree_id(cx))
2749            .as_ref()
2750            .is_some_and(|watcher| watcher.is_project_open_source())
2751    }
2752
2753    pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2754        self.data_collection_choice.is_enabled(cx)
2755    }
2756
2757    fn load_data_collection_choice(cx: &App) -> DataCollectionChoice {
2758        let choice = KeyValueStore::global(cx)
2759            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2760            .log_err()
2761            .flatten();
2762
2763        match choice.as_deref() {
2764            Some("true") => DataCollectionChoice::Enabled,
2765            Some("false") => DataCollectionChoice::Disabled,
2766            Some(_) => {
2767                log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2768                DataCollectionChoice::NotAnswered
2769            }
2770            None => DataCollectionChoice::NotAnswered,
2771        }
2772    }
2773
2774    fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2775        self.data_collection_choice = self.data_collection_choice.toggle();
2776        let new_choice = self.data_collection_choice;
2777        let is_enabled = new_choice.is_enabled(cx);
2778        let kvp = KeyValueStore::global(cx);
2779        db::write_and_log(cx, move || async move {
2780            kvp.write_kvp(
2781                ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2782                is_enabled.to_string(),
2783            )
2784            .await
2785        });
2786    }
2787
2788    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2789        self.shown_predictions.iter()
2790    }
2791
2792    pub fn shown_completions_len(&self) -> usize {
2793        self.shown_predictions.len()
2794    }
2795
2796    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2797        self.rated_predictions.contains(id)
2798    }
2799
2800    pub fn rate_prediction(
2801        &mut self,
2802        prediction: &EditPrediction,
2803        rating: EditPredictionRating,
2804        feedback: String,
2805        cx: &mut Context<Self>,
2806    ) {
2807        let organization = self.user_store.read(cx).current_organization();
2808
2809        self.rated_predictions.insert(prediction.id.clone());
2810
2811        cx.background_spawn({
2812            let client = self.client.clone();
2813            let prediction_id = prediction.id.to_string();
2814            let inputs = serde_json::to_value(&prediction.inputs);
2815            let output = prediction
2816                .edit_preview
2817                .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2818            async move {
2819                client
2820                    .cloud_client()
2821                    .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2822                        organization_id: organization.map(|organization| organization.id.clone()),
2823                        request_id: prediction_id,
2824                        rating: match rating {
2825                            EditPredictionRating::Positive => "positive".to_string(),
2826                            EditPredictionRating::Negative => "negative".to_string(),
2827                        },
2828                        inputs: inputs?,
2829                        output,
2830                        feedback,
2831                    })
2832                    .await?;
2833
2834                anyhow::Ok(())
2835            }
2836        })
2837        .detach_and_log_err(cx);
2838
2839        cx.notify();
2840    }
2841}
2842
2843fn collaborator_edit_overlaps_locality_region(
2844    project_state: &ProjectState,
2845    project: &Entity<Project>,
2846    buffer: &Entity<Buffer>,
2847    snapshot: &BufferSnapshot,
2848    edit_range: &Range<Anchor>,
2849    cx: &App,
2850) -> bool {
2851    let Some((active_buffer, Some(position))) = project_state.active_buffer(project, cx) else {
2852        return false;
2853    };
2854
2855    if active_buffer.entity_id() != buffer.entity_id() {
2856        return false;
2857    }
2858
2859    let locality_point_range = expand_context_syntactically_then_linewise(
2860        snapshot,
2861        (position..position).to_point(snapshot),
2862        COLLABORATOR_EDIT_LOCALITY_CONTEXT_TOKENS,
2863    );
2864    let locality_anchor_range = snapshot.anchor_range_inside(locality_point_range);
2865
2866    edit_range.overlaps(&locality_anchor_range, snapshot)
2867}
2868
/// Collapses the trailing run of mergeable events at the back of `events` into one event.
///
/// Returns without changes when the newest stored event belongs to a different buffer than
/// `latest_snapshot` (compared by `remote_id`), or when `latest_snapshot.version` has not
/// observed that event's `new_snapshot_version`.
///
/// Otherwise it walks `events` from newest to oldest, counting how many consecutive events can
/// be merged: each older event must `can_merge` with the next-newer one. With fewer than two
/// such events there is nothing to merge. The merged event:
/// - diffs the oldest event's `old_snapshot` against `end_snapshot`, restricted to the smallest
///   range covering every merged event's `total_edit_range` (via `merge_anchor_ranges`);
/// - keeps the oldest event's `old_path`, `path`, and `in_open_source_repo`;
/// - is `predicted` only if every merged event was `predicted`.
///
/// If `compute_diff_between_snapshots_in_range` yields `None`, `events` is left untouched.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
        if !latest_snapshot
            .version
            .observed_all(&last_event.new_snapshot_version)
        {
            return;
        }
    }

    // Count the trailing events that form a mergeable chain, newest first.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = next_old_event
            && !old_event.can_merge(next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not consume, so `events_to_merge.all(..)` below still includes the oldest
    // event of the run.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();
    let newest_snapshot = end_snapshot;
    let mut merged_edit_range = oldest_event.total_edit_range.clone();

    // Widen the range to cover every newer event in the run, comparing anchors in
    // `latest_snapshot`.
    for event in events.range(events.len() - mergeable_count + 1..) {
        merged_edit_range =
            merge_anchor_ranges(&merged_edit_range, &event.total_edit_range, latest_snapshot);
    }

    if let Some((diff, edit_range)) = compute_diff_between_snapshots_in_range(
        &oldest_snapshot,
        newest_snapshot,
        &merged_edit_range,
    ) {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                new_snapshot_version: newest_snapshot.version.clone(),
                // Re-anchor the diffed range in the end snapshot.
                total_edit_range: newest_snapshot.anchor_before(edit_range.start)
                    ..newest_snapshot.anchor_before(edit_range.end),
            },
        };
        // Replace the whole mergeable run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2951
2952fn merge_anchor_ranges(
2953    left: &Range<Anchor>,
2954    right: &Range<Anchor>,
2955    snapshot: &TextBufferSnapshot,
2956) -> Range<Anchor> {
2957    let start = if left.start.cmp(&right.start, snapshot).is_le() {
2958        left.start
2959    } else {
2960        right.start
2961    };
2962    let end = if left.end.cmp(&right.end, snapshot).is_ge() {
2963        left.end
2964    } else {
2965        right.end
2966    };
2967    start..end
2968}
2969
2970#[derive(Error, Debug)]
2971#[error(
2972    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2973)]
2974pub struct ZedUpdateRequiredError {
2975    minimum_version: Version,
2976}
2977
2978#[derive(Debug, Clone, Copy)]
2979pub enum DataCollectionChoice {
2980    NotAnswered,
2981    Enabled,
2982    Disabled,
2983}
2984
2985impl DataCollectionChoice {
2986    pub fn is_enabled(self, cx: &App) -> bool {
2987        if cx.is_staff() {
2988            return true;
2989        }
2990        match self {
2991            Self::Enabled => true,
2992            Self::NotAnswered | Self::Disabled => false,
2993        }
2994    }
2995
2996    #[must_use]
2997    pub fn toggle(&self) -> DataCollectionChoice {
2998        match self {
2999            Self::Enabled => Self::Disabled,
3000            Self::Disabled => Self::Enabled,
3001            Self::NotAnswered => Self::Enabled,
3002        }
3003    }
3004}
3005
3006impl From<bool> for DataCollectionChoice {
3007    fn from(value: bool) -> Self {
3008        match value {
3009            true => DataCollectionChoice::Enabled,
3010            false => DataCollectionChoice::Disabled,
3011        }
3012    }
3013}
3014
3015struct ZedPredictUpsell;
3016
3017impl Dismissable for ZedPredictUpsell {
3018    const KEY: &'static str = "dismissed-edit-predict-upsell";
3019
3020    fn dismissed(cx: &App) -> bool {
3021        // To make this backwards compatible with older versions of Zed, we
3022        // check if the user has seen the previous Edit Prediction Onboarding
3023        // before, by checking the data collection choice which was written to
3024        // the database once the user clicked on "Accept and Enable"
3025        let kvp = KeyValueStore::global(cx);
3026        if kvp
3027            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
3028            .log_err()
3029            .is_some_and(|s| s.is_some())
3030        {
3031            return true;
3032        }
3033
3034        kvp.read_kvp(Self::KEY)
3035            .log_err()
3036            .is_some_and(|s| s.is_some())
3037    }
3038}
3039
3040pub fn should_show_upsell_modal(cx: &App) -> bool {
3041    !ZedPredictUpsell::dismissed(cx)
3042}
3043
3044pub fn init(cx: &mut App) {
3045    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
3046        workspace.register_action(
3047            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
3048                ZedPredictModal::toggle(
3049                    workspace,
3050                    workspace.user_store().clone(),
3051                    workspace.client().clone(),
3052                    window,
3053                    cx,
3054                )
3055            },
3056        );
3057
3058        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
3059            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
3060                settings
3061                    .project
3062                    .all_languages
3063                    .edit_predictions
3064                    .get_or_insert_default()
3065                    .provider = Some(EditPredictionProvider::None)
3066            });
3067        });
3068        fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
3069            EditPredictionStore::try_global(cx).and_then(|store| {
3070                store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
3071            })
3072        }
3073
3074        workspace.register_action(|workspace, _: &SignIn, window, cx| {
3075            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3076                copilot_ui::initiate_sign_in(copilot, window, cx);
3077            }
3078        });
3079        workspace.register_action(|workspace, _: &Reinstall, window, cx| {
3080            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3081                copilot_ui::reinstall_and_sign_in(copilot, window, cx);
3082            }
3083        });
3084        workspace.register_action(|workspace, _: &SignOut, window, cx| {
3085            if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
3086                copilot_ui::initiate_sign_out(copilot, window, cx);
3087            }
3088        });
3089    })
3090    .detach();
3091}